# NLP_WSD / code.txt
# Source: gkc55 — "Add Flask-based Word Sense Disambiguation Tool with
# Enhanced Lesk Algorithm" (commit 3fe126f)
import hashlib
import json
import os
import random
import re
from collections import Counter

from flask import Flask, render_template, request, redirect, url_for, jsonify, session

import nltk
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
from nltk.tag import pos_tag
from nltk.tokenize import word_tokenize, sent_tokenize
# Download required NLTK resources (quiet=True keeps startup logs clean;
# downloads are no-ops when the corpora are already present).
nltk.download('wordnet', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('stopwords', quiet=True)

app = Flask(__name__)
# NOTE(review): prefer a secret key from the environment; the hard-coded
# fallback keeps local development working but must not be used in production.
app.secret_key = os.environ.get('SECRET_KEY', 'wsd_secret_key_2023')

# Path for storing feedback data
FEEDBACK_FILE = 'feedback_data.json'
class EnhancedLesk:
    """Word Sense Disambiguation via an enhanced Lesk algorithm.

    Combines several evidence sources, in decreasing priority:
      1. collocation phrases (check_collocations),
      2. hand-coded indicator-word rules (apply_rules),
      3. weighted gloss/example overlap (calculate_overlap_score),
      4. optional BERT context-gloss similarity (bert_similarity),
      5. accumulated user feedback (add_feedback).
    """

    def __init__(self):
        # Per-(word, context) sense scores learned from user feedback,
        # persisted to FEEDBACK_FILE across runs.
        self.feedback = self.load_feedback()
        self.lemmatizer = WordNetLemmatizer()
        self.stopwords = set(nltk.corpus.stopwords.words('english'))
        # Try to load BERT models if available; degrade gracefully when the
        # transformers/torch stack is not installed.
        try:
            from transformers import AutoTokenizer, AutoModel
            import torch
            # Load pre-trained model and tokenizer
            print("Loading BERT models...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.bert_model = AutoModel.from_pretrained('bert-base-uncased')
            self.bert_available = True
            print("BERT models loaded successfully")
        except Exception as e:
            print(f"BERT models not available: {e}")
            print("Continuing without BERT embeddings")
            self.bert_available = False

    def load_feedback(self):
        """Return the persisted feedback dict, or {} when no file exists."""
        if os.path.exists(FEEDBACK_FILE):
            with open(FEEDBACK_FILE) as f:
                return json.load(f)
        return {}

    def save_feedback(self):
        """Persist the in-memory feedback dict to FEEDBACK_FILE."""
        with open(FEEDBACK_FILE, 'w') as f:
            json.dump(self.feedback, f)

    def get_wordnet_pos(self, treebank_tag):
        """Convert a Penn Treebank POS tag to a WordNet POS constant (or None)."""
        if treebank_tag.startswith('J'):
            return wn.ADJ
        elif treebank_tag.startswith('V'):
            return wn.VERB
        elif treebank_tag.startswith('N'):
            return wn.NOUN
        elif treebank_tag.startswith('R'):
            return wn.ADV
        else:
            return None

    @staticmethod
    def _feedback_key(word, context):
        """Build a stable feedback-store key from a word and its context words.

        FIX: uses md5 instead of the builtin hash() because hash() is salted
        per process in Python 3, so hash()-based keys would not survive a
        restart. Both disambiguate() and add_feedback() now derive the key
        from the same first-10 context words; previously they hashed
        different strings (full sentence vs. joined context), so stored
        feedback could never be found again at lookup time.
        """
        context_str = ' '.join(context[:10])
        digest = int(hashlib.md5(context_str.encode('utf-8')).hexdigest(), 16)
        return f"{word}_{digest % 10000}"

    def process_context(self, sentence, target_word):
        """Process context words with positional weighting.

        Returns a list of lemmatized, non-stopword context words in which
        words closer to the target occur multiple times (up to 6x for the
        target's immediate neighbours), so proximity acts as a weight.
        """
        words = word_tokenize(sentence.lower())
        # Find target word position
        target_pos = -1
        for i, word in enumerate(words):
            if word.lower() == target_word.lower():
                target_pos = i
                break
        # Process context words with proximity weighting
        context_words = []
        for i, word in enumerate(words):
            if word.isalpha() and word not in self.stopwords:
                lemma = self.lemmatizer.lemmatize(word)
                if target_pos >= 0:
                    # Weight by proximity to target word (closer = more important)
                    distance = abs(i - target_pos)
                    weight = max(1, 6 - distance) if distance <= 5 else 1
                    context_words.extend([lemma] * weight)
                else:
                    context_words.append(lemma)
        return context_words

    def calculate_overlap_score(self, sense, context):
        """Calculate overlap between the sense signature and the context.

        The signature pools definition words (double weight), example words,
        and the definitions of hypernyms, hyponyms, meronyms and holonyms.
        The score is the frequency-weighted overlap with the context words.
        """
        signature = []
        # Definition words carry double weight.
        def_words = [w.lower() for w in word_tokenize(sense.definition())
                     if w.isalpha() and w not in self.stopwords]
        signature.extend(def_words * 2)
        # Example sentences.
        for example in sense.examples():
            ex_words = [w.lower() for w in word_tokenize(example)
                        if w.isalpha() and w not in self.stopwords]
            signature.extend(ex_words)
        # Definitions of related synsets enrich the signature.
        for hypernym in sense.hypernyms():
            hyper_words = [w.lower() for w in word_tokenize(hypernym.definition())
                           if w.isalpha() and w not in self.stopwords]
            signature.extend(hyper_words)
        for hyponym in sense.hyponyms():
            hypo_words = [w.lower() for w in word_tokenize(hyponym.definition())
                          if w.isalpha() and w not in self.stopwords]
            signature.extend(hypo_words)
        for meronym in sense.part_meronyms() + sense.substance_meronyms():
            meronym_words = [w.lower() for w in word_tokenize(meronym.definition())
                             if w.isalpha() and w not in self.stopwords]
            signature.extend(meronym_words)
        for holonym in sense.part_holonyms() + sense.substance_holonyms():
            holonym_words = [w.lower() for w in word_tokenize(holonym.definition())
                             if w.isalpha() and w not in self.stopwords]
            signature.extend(holonym_words)
        # Frequency-weighted overlap via Counters.
        context_counter = Counter(context)
        signature_counter = Counter(signature)
        overlap_score = 0
        for word, count in context_counter.items():
            if word in signature_counter:
                # Product of frequencies, signature side capped at 5 so one
                # over-represented gloss word cannot dominate the score.
                overlap_score += count * min(signature_counter[word], 5)
        return overlap_score

    def bert_similarity(self, sense, context_sentence, target_word):
        """Calculate semantic similarity using BERT embeddings (0 if unavailable)."""
        if not hasattr(self, 'bert_available') or not self.bert_available:
            return 0
        try:
            import torch
            # Create context-gloss pair as in GlossBERT
            gloss = sense.definition()
            inputs = self.tokenizer(context_sentence, gloss, return_tensors="pt",
                                    padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                outputs = self.bert_model(**inputs)
            # Compare the [CLS] embedding with the first gloss-token embedding
            # (the token right after the first [SEP]).
            # FIX: dim=0 is required — both operands are 1-D vectors, and the
            # default dim=1 raised, which previously made this method always
            # fall into the except branch and return 0.
            sep_index = inputs.input_ids[0].tolist().index(self.tokenizer.sep_token_id)
            similarity = torch.cosine_similarity(
                outputs.last_hidden_state[0, 0],
                outputs.last_hidden_state[0, sep_index + 1],
                dim=0
            ).item()
            return similarity * 10  # Scale up to be comparable with other scores
        except Exception as e:
            print(f"Error in BERT similarity calculation: {e}")
            return 0

    def check_collocations(self, sentence, target_word):
        """Check for common collocations that indicate specific senses.

        Returns (sense, 15) when a known phrase for the target word appears
        in the sentence, else (None, 0).
        """
        collocations = {
            "bat": {
                "noun.animal": ["flying bat", "bat flying", "bat wings", "vampire bat", "fruit bat", "bat in the dark", "bat at night"],
                "noun.artifact": ["baseball bat", "cricket bat", "swing the bat", "wooden bat", "hit with bat"]
            },
            "bank": {
                "noun.artifact": ["bank account", "bank manager", "bank loan", "bank robbery", "money in bank"],
                "noun.object": ["river bank", "bank of the river", "west bank", "bank erosion", "along the bank"]
            },
            "bass": {
                "noun.animal": ["bass fish", "catch bass", "fishing bass", "largemouth bass"],
                "noun.attribute": ["bass sound", "bass guitar", "bass player", "bass note", "bass drum"]
            },
            "spring": {
                "noun.time": ["spring season", "this spring", "last spring", "spring weather", "spring flowers"],
                "noun.artifact": ["metal spring", "spring coil", "spring mechanism"],
                "noun.object": ["water spring", "hot spring", "spring water"]
            },
            "crane": {
                "noun.animal": ["crane bird", "crane flew", "crane nest", "crane species"],
                "noun.artifact": ["construction crane", "crane operator", "crane lifted"]
            }
        }
        if target_word not in collocations:
            return None, 0
        sentence_lower = sentence.lower()
        for domain, phrases in collocations[target_word].items():
            for phrase in phrases:
                if phrase.lower() in sentence_lower:
                    # Map the matched domain back to a WordNet sense.
                    for sense in wn.synsets(target_word):
                        if sense.lexname() == domain:
                            return sense, 15  # Very high confidence for collocations
        return None, 0

    def apply_rules(self, word, context, senses):
        """Apply hand-coded rules for common ambiguous words.

        Returns (boost_score, sense) when an indicator word selects a sense,
        else (0, None).
        """
        word = word.lower()
        context_words = set(context)
        # Rules for "bat"
        if word == "bat":
            animal_indicators = {"fly", "flying", "flew", "wing", "wings", "night",
                                 "dark", "cave", "nocturnal", "mammal", "animal", "leather", "leathery"}
            if any(indicator in context_words for indicator in animal_indicators):
                for sense in senses:
                    if sense.lexname() == "noun.animal":
                        return 10, sense  # High confidence boost
            sports_indicators = {"hit", "swing", "ball", "baseball", "cricket",
                                 "player", "game", "sport", "team", "wooden"}
            if any(indicator in context_words for indicator in sports_indicators):
                for sense in senses:
                    if sense.lexname() == "noun.artifact":
                        return 8, sense  # High confidence boost
        # Rules for "bank"
        elif word == "bank":
            finance_indicators = {"money", "account", "deposit", "withdraw", "loan",
                                  "credit", "debit", "financial", "cash", "check"}
            if any(indicator in context_words for indicator in finance_indicators):
                for sense in senses:
                    if "financial" in sense.definition() or "money" in sense.definition():
                        return 10, sense
            river_indicators = {"river", "stream", "water", "flow", "shore", "beach"}
            if any(indicator in context_words for indicator in river_indicators):
                for sense in senses:
                    if "river" in sense.definition() or "stream" in sense.definition():
                        return 10, sense
        # Rules for "bass"
        elif word == "bass":
            fish_indicators = {"fish", "fishing", "catch", "caught", "water", "lake", "river"}
            if any(indicator in context_words for indicator in fish_indicators):
                for sense in senses:
                    if sense.lexname() == "noun.animal":
                        return 10, sense
            music_indicators = {"music", "sound", "guitar", "player", "band", "note", "tone", "instrument", "concert", "loud"}
            if any(indicator in context_words for indicator in music_indicators):
                for sense in senses:
                    if sense.lexname() == "noun.attribute" or "music" in sense.definition():
                        return 10, sense
        # No rule matched with high confidence
        return 0, None

    def safe_compare_synsets(self, synset1, synset2):
        """Safely compare two synsets, handling None values."""
        if synset1 is None or synset2 is None:
            return synset1 is synset2  # True only if both are None
        try:
            return synset1 == synset2
        except AttributeError:
            return False  # If comparison fails, they're not equal

    def disambiguate(self, sentence, word):
        """Disambiguate `word` in `sentence`.

        Returns (best_sense, up_to_3_alternative_senses), or (None, []) when
        the word has no WordNet senses.
        """
        word = word.lower()
        # POS-tag the sentence and locate the target word's tag.
        word_tokens = word_tokenize(sentence)
        pos_tags = pos_tag(word_tokens)
        word_pos = None
        for token, pos in pos_tags:
            if token.lower() == word:
                word_pos = self.get_wordnet_pos(pos)
                break
        # Prefer senses matching the tagged POS; fall back to all senses.
        if word_pos:
            senses = [s for s in wn.synsets(word) if s.pos() == word_pos]
            if not senses:
                senses = wn.synsets(word)
        else:
            senses = wn.synsets(word)
        if not senses:
            return None, []
        # Process context with positional weighting
        context = self.process_context(sentence, word)
        # 1. Check for collocations first (highest priority)
        collocation_sense, collocation_score = self.check_collocations(sentence, word)
        if collocation_sense and collocation_score > 0:
            top_senses = [s for s in senses if not self.safe_compare_synsets(s, collocation_sense)][:3]
            return collocation_sense, top_senses
        # 2. Apply rules for common ambiguous words
        rule_score, rule_sense = self.apply_rules(word, context, senses)
        # FIX: the feedback key is loop-invariant and must match the key that
        # add_feedback() writes; the original hashed the raw sentence here but
        # the joined context words on write, so feedback never applied.
        feedback_key = self._feedback_key(word, context)
        scored_senses = []
        for sense in senses:
            # Safe comparison prevents AttributeError on odd synset objects.
            rule_boost = rule_score if (rule_sense is not None and self.safe_compare_synsets(sense, rule_sense)) else 0
            overlap_score = self.calculate_overlap_score(sense, context)
            bert_score = 0
            if hasattr(self, 'bert_available') and self.bert_available:
                bert_score = self.bert_similarity(sense, sentence, word)
            feedback_score = self.feedback.get(feedback_key, {}).get(sense.name(), 0)
            # Final score: weighted combination of all evidence sources.
            final_score = (
                overlap_score * 0.4 +
                bert_score * 0.3 +
                rule_boost * 0.2 +
                feedback_score * 0.1
            )
            scored_senses.append((final_score, sense))
        scored_senses.sort(reverse=True, key=lambda x: x[0])
        if not scored_senses:
            return None, []
        best_sense = scored_senses[0][1]
        top_senses = [s[1] for s in scored_senses[1:4]]
        return best_sense, top_senses

    def add_feedback(self, word, context, correct_sense):
        """Store user feedback to improve future disambiguation.

        Boosts `correct_sense` for this (word, context) key, mildly decays
        competing senses, persists the store, and returns the corrected
        sense's definition/examples (or None if `correct_sense` is unknown).
        """
        # Same key derivation as disambiguate() — see _feedback_key.
        key = self._feedback_key(word, context)
        if key not in self.feedback:
            self.feedback[key] = {}
        # Increase score for the correct sense
        self.feedback[key][correct_sense] = self.feedback[key].get(correct_sense, 0) + 5
        # Decrease scores for other senses already seen for this key.
        for sense in wn.synsets(word):
            if sense.name() != correct_sense and sense.name() in self.feedback[key]:
                self.feedback[key][sense.name()] = max(0, self.feedback[key][sense.name()] - 1)
        self.save_feedback()
        # Return the updated sense information
        for sense in wn.synsets(word):
            if sense.name() == correct_sense:
                return {
                    'definition': sense.definition(),
                    'examples': sense.examples()
                }
        return None
# Initialize the Lesk processor — module-level singleton shared by all
# request handlers; loads feedback (and BERT, if available) once at startup.
lesk_processor = EnhancedLesk()
@app.route('/', methods=['GET', 'POST'])
def index():
    """Landing page; on POST, redirect to the results view with the form data."""
    if request.method == 'POST':
        # .get avoids an unhandled KeyError (HTTP 400) when 'text' is
        # missing from the submitted form.
        text = request.form.get('text', '')
        target_word = request.form.get('target_word', '')
        return redirect(url_for('results', text=text, word=target_word))
    return render_template('index.html')
@app.route('/results')
def results():
    """Disambiguate the target word within the submitted text and render results.

    If no target word is supplied, the first word with multiple WordNet
    senses is used. Only the first sentence containing the target word is
    disambiguated.
    """
    text = request.args.get('text', '')
    target_word = request.args.get('word', '').lower()
    if not target_word:
        # Find ambiguous words (with multiple senses) and use the first one.
        words = word_tokenize(text.lower())
        ambiguous_words = [w for w in words if w.isalpha() and len(wn.synsets(w)) > 1]
        if ambiguous_words:
            target_word = ambiguous_words[0]
    best_sense = None
    top_senses = []
    highlighted_text = text
    sentence = ""
    context_words = []
    if target_word:
        for sent in sent_tokenize(text):
            if re.search(r'\b' + re.escape(target_word) + r'\b', sent, re.I):
                sentence = sent
                context_words = lesk_processor.process_context(sent, target_word)
                try:
                    best_sense, top_senses = lesk_processor.disambiguate(sent, target_word)
                except Exception as e:
                    print(f"Disambiguation error: {e}")
                    return render_template('error.html',
                                           error_message=f"Could not disambiguate the word '{target_word}'. Please try a different word or sentence.",
                                           error_details=str(e))
                # FIX: use a callable replacement — target_word is user input,
                # and backslashes/'\g' in a plain replacement string are
                # interpreted by re.sub as escape sequences (error or wrong
                # output). m.group(0) also preserves the matched casing.
                highlighted_text = re.sub(
                    r'\b' + re.escape(target_word) + r'\b',
                    lambda m: f'<span class="highlight-word">{m.group(0)}</span>',
                    text,
                    flags=re.IGNORECASE
                )
                break
    # Store in session so the /feedback endpoint can key on the same context.
    if best_sense:
        session['last_disambiguation'] = {
            'word': target_word,
            'context': context_words,
            'sentence': sentence
        }
    return render_template('results.html',
                           text=text,
                           highlighted_text=highlighted_text,
                           target_word=target_word,
                           best_sense=best_sense,
                           top_senses=top_senses,
                           sentence=sentence,
                           context_words=', '.join([w for w in set(context_words)][:10]))  # Show unique context words
@app.route('/feedback', methods=['POST'])
def feedback():
    """Accept JSON feedback {word, context, correct_sense} and update the model.

    Returns the corrected sense info as JSON, or 400 on invalid payloads.
    """
    # FIX: silent=True returns None instead of raising on absent/malformed
    # JSON, and `or {}` guards the .get calls, so bad requests produce the
    # intended 400 response rather than an unhandled 500.
    data = request.get_json(silent=True) or {}
    word = data.get('word')
    context = data.get('context', [])
    correct_sense = data.get('correct_sense')
    if word and correct_sense:
        updated_sense = lesk_processor.add_feedback(word, context, correct_sense)
        return jsonify(updated_sense)
    return jsonify({'error': 'Invalid feedback data'}), 400
@app.route('/lesk-explained')
def lesk_explained():
    """Serve the static explainer page for the Lesk algorithm."""
    return render_template('lesk_explained.html')
# Generic error page, driven entirely by query-string parameters.
@app.route('/error')
def error():
    """Render the error template with an optional message and details."""
    message = request.args.get('message', 'An unknown error occurred')
    details = request.args.get('details', '')
    return render_template('error.html',
                           error_message=message,
                           error_details=details)
if __name__ == '__main__':
    # NOTE(review): debug=True enables Flask's interactive debugger and code
    # reloader — convenient locally, but must be disabled in production.
    app.run(debug=True)