Spaces:
Runtime error
Runtime error
Adding word sense disambiguation + definitions to synonym generation
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import gradio as gr
|
|
| 11 |
import readability
|
| 12 |
import seaborn as sns
|
| 13 |
import torch
|
|
|
|
| 14 |
from fuzzywuzzy import fuzz
|
| 15 |
from nltk.corpus import stopwords
|
| 16 |
from nltk.corpus import wordnet as wn
|
|
@@ -18,6 +19,8 @@ from nltk.tokenize import word_tokenize
|
|
| 18 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 19 |
from transformers import DistilBertTokenizer
|
| 20 |
from transformers import pipeline
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
nltk.download('wordnet')
|
|
@@ -442,6 +445,218 @@ def vocab_level_inter(text):
|
|
| 442 |
interp.append(('', 0))
|
| 443 |
return {'original': text, 'interpretation': interp}, f'{level(sum/total*4*2.5)[1:]} Level Vocabulary'
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
with gr.Blocks(title="Automatic Literacy and Speech Assesmen") as demo:
|
| 446 |
gr.HTML("""<center><h7 style="font-size: 35px">Automatic Literacy and Speech Assesment</h7></center>""")
|
| 447 |
gr.HTML("""<center><h7 style="font-size: 15px">This may take 60s to generate all statistics</h7></center>""")
|
|
@@ -460,8 +675,8 @@ with gr.Blocks(title="Automatic Literacy and Speech Assesmen") as demo:
|
|
| 460 |
audio_file = gr.Audio(source="microphone",type="filepath")
|
| 461 |
grade1 = gr.Button("Grade Your Speech")
|
| 462 |
with gr.Group():
|
| 463 |
-
gr.Markdown("Reading Level Based Synonyms | Enter
|
| 464 |
-
words = gr.Textbox(label="
|
| 465 |
lvl = gr.Dropdown(choices=["Elementary Level", "Middle School Level", "High School Level", "College Level" ], label="Intended Reading Level For Synonym")
|
| 466 |
get_syns = gr.Button("Get Synonyms")
|
| 467 |
reccos = gr.Label()
|
|
@@ -532,6 +747,6 @@ with gr.Blocks(title="Automatic Literacy and Speech Assesmen") as demo:
|
|
| 532 |
grade.click(vocab_level_inter, inputs=in_text, outputs=[interpretation3, vocab_output])
|
| 533 |
grade1.click(speech_to_score, inputs=audio_file, outputs=diff_output)
|
| 534 |
b1.click(speech_to_text, inputs=[audio_file1, target], outputs=[text, some_val, phones])
|
| 535 |
-
get_syns.click(
|
| 536 |
find_sim.click(get_sim_words, inputs=[in_text, words1], outputs=sims)
|
| 537 |
demo.launch(debug=True)
|
|
|
|
| 11 |
import readability
|
| 12 |
import seaborn as sns
|
| 13 |
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
from fuzzywuzzy import fuzz
|
| 16 |
from nltk.corpus import stopwords
|
| 17 |
from nltk.corpus import wordnet as wn
|
|
|
|
| 19 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 20 |
from transformers import DistilBertTokenizer
|
| 21 |
from transformers import pipeline
|
| 22 |
+
from transformers import BertTokenizer
|
| 23 |
+
from transformers import AutoTokenizer, BertForSequenceClassification
|
| 24 |
|
| 25 |
|
| 26 |
nltk.download('wordnet')
|
|
|
|
| 445 |
interp.append(('', 0))
|
| 446 |
return {'original': text, 'interpretation': interp}, f'{level(sum/total*4*2.5)[1:]} Level Vocabulary'
|
| 447 |
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
logger = logging.getLogger(__name__)
|
| 451 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
| 452 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
| 453 |
+
level = logging.INFO)
|
| 454 |
+
tokenizer4 = AutoTokenizer.from_pretrained('kanishka/GlossBERT')
|
| 455 |
+
|
| 456 |
+
def construct_context_gloss_pairs_through_nltk(input, target_start_id, target_end_id):
|
| 457 |
+
"""
|
| 458 |
+
construct context gloss pairs like sent_cls_ws
|
| 459 |
+
:param input: str, a sentence
|
| 460 |
+
:param target_start_id: int
|
| 461 |
+
:param target_end_id: int
|
| 462 |
+
:param lemma: lemma of the target word
|
| 463 |
+
:return: candidate lists
|
| 464 |
+
"""
|
| 465 |
+
|
| 466 |
+
sent = tokenizer4.tokenize(input)
|
| 467 |
+
assert 0 <= target_start_id and target_start_id < target_end_id and target_end_id <= len(sent)
|
| 468 |
+
target = " ".join(sent[target_start_id:target_end_id])
|
| 469 |
+
if len(sent) > target_end_id:
|
| 470 |
+
sent = sent[:target_start_id] + ['"'] + sent[target_start_id:target_end_id] + ['"'] + sent[target_end_id:]
|
| 471 |
+
else:
|
| 472 |
+
sent = sent[:target_start_id] + ['"'] + sent[target_start_id:target_end_id] + ['"']
|
| 473 |
+
|
| 474 |
+
sent = " ".join(sent)
|
| 475 |
+
|
| 476 |
+
candidate = []
|
| 477 |
+
syns = wn.synsets(target)
|
| 478 |
+
|
| 479 |
+
for syn in syns:
|
| 480 |
+
if target == syn.name().split('.')[0]:
|
| 481 |
+
continue
|
| 482 |
+
|
| 483 |
+
gloss = (syn.definition(), syn.name())
|
| 484 |
+
candidate.append((sent, f"{target} : {gloss}", target, gloss))
|
| 485 |
+
|
| 486 |
+
assert len(candidate) != 0, f'there is no candidate sense of "{target}" in WordNet, please check'
|
| 487 |
+
# print(f'there are {len(candidate)} candidate senses of "{target}"')
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
return candidate
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class InputFeatures(object):
|
| 494 |
+
"""A single set of features of data."""
|
| 495 |
+
|
| 496 |
+
def __init__(self, input_ids, input_mask, segment_ids):
|
| 497 |
+
self.input_ids = input_ids
|
| 498 |
+
self.input_mask = input_mask
|
| 499 |
+
self.segment_ids = segment_ids
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def convert_to_features(candidate, tokenizer3, max_seq_length=512):
|
| 503 |
+
|
| 504 |
+
candidate_results = []
|
| 505 |
+
features = []
|
| 506 |
+
for item in candidate:
|
| 507 |
+
text_a = item[0] # sentence
|
| 508 |
+
text_b = item[1] # gloss
|
| 509 |
+
candidate_results.append((item[-2], item[-1])) # (target, gloss)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
tokens_a = tokenizer3.tokenize(text_a)
|
| 513 |
+
tokens_b = tokenizer3.tokenize(text_b)
|
| 514 |
+
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
| 515 |
+
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
|
| 516 |
+
segment_ids = [0] * len(tokens)
|
| 517 |
+
tokens += tokens_b + ["[SEP]"]
|
| 518 |
+
segment_ids += [1] * (len(tokens_b) + 1)
|
| 519 |
+
|
| 520 |
+
input_ids = tokenizer3.convert_tokens_to_ids(tokens)
|
| 521 |
+
|
| 522 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
| 523 |
+
# tokens are attended to.
|
| 524 |
+
input_mask = [1] * len(input_ids)
|
| 525 |
+
|
| 526 |
+
# Zero-pad up to the sequence length.
|
| 527 |
+
padding = [0] * (max_seq_length - len(input_ids))
|
| 528 |
+
input_ids += padding
|
| 529 |
+
input_mask += padding
|
| 530 |
+
segment_ids += padding
|
| 531 |
+
|
| 532 |
+
assert len(input_ids) == max_seq_length
|
| 533 |
+
assert len(input_mask) == max_seq_length
|
| 534 |
+
assert len(segment_ids) == max_seq_length
|
| 535 |
+
|
| 536 |
+
features.append(
|
| 537 |
+
InputFeatures(input_ids=input_ids,
|
| 538 |
+
input_mask=input_mask,
|
| 539 |
+
segment_ids=segment_ids))
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
return features, candidate_results
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
| 547 |
+
"""Truncates a sequence pair in place to the maximum length."""
|
| 548 |
+
|
| 549 |
+
# This is a simple heuristic which will always truncate the longer sequence
|
| 550 |
+
# one token at a time. This makes more sense than truncating an equal percent
|
| 551 |
+
# of tokens from each, since if one sequence is very short then each token
|
| 552 |
+
# that's truncated likely contains more information than a longer sequence.
|
| 553 |
+
while True:
|
| 554 |
+
total_length = len(tokens_a) + len(tokens_b)
|
| 555 |
+
if total_length <= max_length:
|
| 556 |
+
break
|
| 557 |
+
if len(tokens_a) > len(tokens_b):
|
| 558 |
+
tokens_a.pop()
|
| 559 |
+
else:
|
| 560 |
+
tokens_b.pop()
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def infer(input, target_start_id, target_end_id, args):
|
| 564 |
+
sent = tokenizer4.tokenize(input)
|
| 565 |
+
assert 0 <= target_start_id and target_start_id < target_end_id and target_end_id <= len(sent)
|
| 566 |
+
target = " ".join(sent[target_start_id:target_end_id])
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
label_list = ["0", "1"]
|
| 573 |
+
num_labels = len(label_list)
|
| 574 |
+
|
| 575 |
+
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
| 576 |
+
num_labels=num_labels)
|
| 577 |
+
model.to(device)
|
| 578 |
+
|
| 579 |
+
# print(f"input: {input}\ntarget: {target}")
|
| 580 |
+
examples = construct_context_gloss_pairs_through_nltk(input, target_start_id, target_end_id)
|
| 581 |
+
eval_features, candidate_results = convert_to_features(examples, tokenizer4)
|
| 582 |
+
input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
| 583 |
+
input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
| 584 |
+
segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
model.eval()
|
| 588 |
+
input_ids = input_ids.to(device)
|
| 589 |
+
input_mask = input_mask.to(device)
|
| 590 |
+
segment_ids = segment_ids.to(device)
|
| 591 |
+
with torch.no_grad():
|
| 592 |
+
logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None).logits
|
| 593 |
+
logits_ = F.softmax(logits, dim=-1)
|
| 594 |
+
logits_ = logits_.detach().cpu().numpy()
|
| 595 |
+
output = np.argmax(logits_, axis=0)[1]
|
| 596 |
+
results= []
|
| 597 |
+
for idx, i in enumerate(logits_):
|
| 598 |
+
results.append((candidate_results[idx][1], i[1]*100))
|
| 599 |
+
sorted_results = sorted(results, key=lambda x: x[1], reverse=True)
|
| 600 |
+
|
| 601 |
+
return sorted_results
|
| 602 |
+
|
| 603 |
+
def format_for_gradio(inp):
|
| 604 |
+
retval = ''
|
| 605 |
+
for idx, i in enumerate(inp):
|
| 606 |
+
if idx == len(inp)-1:
|
| 607 |
+
retval += i.split('.')[0]
|
| 608 |
+
break
|
| 609 |
+
retval += f'''{i.split('.')[0]} | '''
|
| 610 |
+
return retval
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def smart_synonyms(text, level):
|
| 614 |
+
parser = argparse.ArgumentParser()
|
| 615 |
+
parser.add_argument("--bert_model", default="kanishka/GlossBERT", type=str)
|
| 616 |
+
parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available")
|
| 617 |
+
args, unknown = parser.parse_known_args()
|
| 618 |
+
|
| 619 |
+
location = 0
|
| 620 |
+
word = ''
|
| 621 |
+
tokens = tokenizer4.tokenize(text)
|
| 622 |
+
school_to_level = {"Elementary Level":'1', "Middle School Level":'2', "High School Level":'3', "College Level":'4'}
|
| 623 |
+
for idx, i in enumerate(tokens):
|
| 624 |
+
if i[0] == '@':
|
| 625 |
+
location = idx
|
| 626 |
+
text = text.replace('@', '')
|
| 627 |
+
word = tokens[location]
|
| 628 |
+
break
|
| 629 |
+
raw_syns = []
|
| 630 |
+
raw_defs = []
|
| 631 |
+
raw_scores = []
|
| 632 |
+
syns = []
|
| 633 |
+
defs = []
|
| 634 |
+
scores = []
|
| 635 |
+
preds = infer(text, location, location+1, args)
|
| 636 |
+
for i in preds:
|
| 637 |
+
if not i[0][1].split('.')[0] in data[school_to_level[level]]:
|
| 638 |
+
continue
|
| 639 |
+
raw_syns.append(i[0][1])
|
| 640 |
+
raw_defs.append(i[0][0])
|
| 641 |
+
raw_scores.append(i[1])
|
| 642 |
+
if i[1] > 5:
|
| 643 |
+
syns.append(i[0][1])
|
| 644 |
+
defs.append(i[0][0])
|
| 645 |
+
scores.append(i[1])
|
| 646 |
+
|
| 647 |
+
if not syns:
|
| 648 |
+
top_syns = int(len(raw_syns)*.25//1+1)
|
| 649 |
+
syns = raw_syns[:top_syns]
|
| 650 |
+
defs = raw_defs[:top_syns]
|
| 651 |
+
scores = raw_scores[:top_syns]
|
| 652 |
+
|
| 653 |
+
cleaned_syns = format_for_gradio(syns)
|
| 654 |
+
cleaend_defs = format_for_gradio(defs)
|
| 655 |
+
|
| 656 |
+
return f'{cleaned_syns}: Definition- {cleaend_defs} | '
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
|
| 660 |
with gr.Blocks(title="Automatic Literacy and Speech Assesmen") as demo:
|
| 661 |
gr.HTML("""<center><h7 style="font-size: 35px">Automatic Literacy and Speech Assesment</h7></center>""")
|
| 662 |
gr.HTML("""<center><h7 style="font-size: 15px">This may take 60s to generate all statistics</h7></center>""")
|
|
|
|
| 675 |
audio_file = gr.Audio(source="microphone",type="filepath")
|
| 676 |
grade1 = gr.Button("Grade Your Speech")
|
| 677 |
with gr.Group():
|
| 678 |
+
gr.Markdown("""Reading Level Based Synonyms | Enter a sentence with the word you want a synonym | Add an @ before the target word for synonym, e.g. - "Today is an @amazing day"- target word = amazing" """)
|
| 679 |
+
words = gr.Textbox(label="Text with word for synonyms")
|
| 680 |
lvl = gr.Dropdown(choices=["Elementary Level", "Middle School Level", "High School Level", "College Level" ], label="Intended Reading Level For Synonym")
|
| 681 |
get_syns = gr.Button("Get Synonyms")
|
| 682 |
reccos = gr.Label()
|
|
|
|
| 747 |
grade.click(vocab_level_inter, inputs=in_text, outputs=[interpretation3, vocab_output])
|
| 748 |
grade1.click(speech_to_score, inputs=audio_file, outputs=diff_output)
|
| 749 |
b1.click(speech_to_text, inputs=[audio_file1, target], outputs=[text, some_val, phones])
|
| 750 |
+
get_syns.click(smart_synonyms, inputs=[words, lvl], outputs=reccos)
|
| 751 |
find_sim.click(get_sim_words, inputs=[in_text, words1], outputs=sims)
|
| 752 |
demo.launch(debug=True)
|