update handler for new output
Browse files- handler.py +123 -9
handler.py
CHANGED
|
@@ -3,6 +3,9 @@ from scipy.special import softmax
|
|
| 3 |
import numpy as np
|
| 4 |
import weakref
|
| 5 |
import re
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from utils import clean_str, clean_str_nopunct
|
| 8 |
import torch
|
|
@@ -10,7 +13,7 @@ from utils import MultiHeadModel, BertInputBuilder, get_num_words, MATH_PREFIXES
|
|
| 10 |
|
| 11 |
import transformers
|
| 12 |
from transformers import BertTokenizer, BertForSequenceClassification
|
| 13 |
-
|
| 14 |
|
| 15 |
transformers.logging.set_verbosity_debug()
|
| 16 |
|
|
@@ -30,9 +33,15 @@ class Utterance:
|
|
| 30 |
self.endtime = endtime
|
| 31 |
self.transcript = weakref.ref(transcript) if transcript else None
|
| 32 |
self.props = kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
self.num_math_terms = None
|
| 34 |
self.math_terms = None
|
| 35 |
|
|
|
|
| 36 |
self.uptake = None
|
| 37 |
self.reasoning = None
|
| 38 |
self.question = None
|
|
@@ -62,6 +71,21 @@ class Utterance:
|
|
| 62 |
**self.props
|
| 63 |
}
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def __repr__(self):
|
| 66 |
return f"Utterance(speaker='{self.speaker}'," \
|
| 67 |
f"text='{self.text}', uid={self.uid}," \
|
|
@@ -91,6 +115,86 @@ class Transcript:
|
|
| 91 |
def length(self):
|
| 92 |
return len(self.utterances)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def to_dict(self):
|
| 95 |
return {
|
| 96 |
'utterances': [utterance.to_dict() for utterance in self.utterances],
|
|
@@ -218,8 +322,6 @@ class UptakeModel:
|
|
| 218 |
return_pooler_output=False)
|
| 219 |
return output
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
class FocusingQuestionModel:
|
| 224 |
def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
|
| 225 |
print("Loading models...")
|
|
@@ -254,8 +356,7 @@ class FocusingQuestionModel:
|
|
| 254 |
output = self.model(input_ids=instance["input_ids"],
|
| 255 |
attention_mask=instance["attention_mask"],
|
| 256 |
token_type_ids=instance["token_type_ids"])
|
| 257 |
-
return output
|
| 258 |
-
|
| 259 |
|
| 260 |
def load_math_terms():
|
| 261 |
math_terms = []
|
|
@@ -283,7 +384,7 @@ def run_math_density(transcript):
|
|
| 283 |
matches = [match for match in matches if not any(match.start() in range(existing[0], existing[1]) for existing in matched_positions)]
|
| 284 |
if len(matches) > 0:
|
| 285 |
match_list.append(math_terms_dict[term])
|
| 286 |
-
# Update
|
| 287 |
matched_positions.update((match.start(), match.end()) for match in matches)
|
| 288 |
num_matches += len(matches)
|
| 289 |
utt.num_math_terms = num_matches
|
|
@@ -319,13 +420,13 @@ class EndpointHandler():
|
|
| 319 |
transcript.add_utterance(Utterance(**utt))
|
| 320 |
|
| 321 |
print("Running inference on %d examples..." % transcript.length())
|
| 322 |
-
|
| 323 |
# Uptake
|
| 324 |
uptake_model = UptakeModel(
|
| 325 |
self.device, self.tokenizer, self.input_builder)
|
|
|
|
| 326 |
uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
|
| 327 |
uptake_speaker=uptake_speaker)
|
| 328 |
-
|
| 329 |
# Reasoning
|
| 330 |
reasoning_model = ReasoningModel(
|
| 331 |
self.device, self.tokenizer, self.input_builder)
|
|
@@ -343,4 +444,17 @@ class EndpointHandler():
|
|
| 343 |
|
| 344 |
run_math_density(transcript)
|
| 345 |
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import weakref
|
| 5 |
import re
|
| 6 |
+
import nltk
|
| 7 |
+
from nltk.corpus import stopwords
|
| 8 |
+
nltk.download('stopwords')
|
| 9 |
|
| 10 |
from utils import clean_str, clean_str_nopunct
|
| 11 |
import torch
|
|
|
|
| 13 |
|
| 14 |
import transformers
|
| 15 |
from transformers import BertTokenizer, BertForSequenceClassification
|
| 16 |
+
from transformers.utils import logging
|
| 17 |
|
| 18 |
transformers.logging.set_verbosity_debug()
|
| 19 |
|
|
|
|
| 33 |
self.endtime = endtime
|
| 34 |
self.transcript = weakref.ref(transcript) if transcript else None
|
| 35 |
self.props = kwargs
|
| 36 |
+
self.role = None
|
| 37 |
+
self.word_count = self.get_num_words()
|
| 38 |
+
self.timestamp = [starttime, endtime]
|
| 39 |
+
self.unit_measure = None
|
| 40 |
+
self.aggregate_unit_measure = endtime
|
| 41 |
self.num_math_terms = None
|
| 42 |
self.math_terms = None
|
| 43 |
|
| 44 |
+
# moments
|
| 45 |
self.uptake = None
|
| 46 |
self.reasoning = None
|
| 47 |
self.question = None
|
|
|
|
| 71 |
**self.props
|
| 72 |
}
|
| 73 |
|
| 74 |
+
def to_talk_timeline_dict(self):
    """Serialize this utterance for the frontend talk-timeline view.

    Returns a dict of speaker/text metadata plus a 'moments' sub-dict of
    boolean flags derived from the per-utterance model outputs.
    """
    # NOTE(review): self.focusing_question is read here but its
    # initialization is not visible in this view — confirm it is always
    # set before this method is called.
    moments = {
        'reasoning': bool(self.reasoning),
        'questioning': bool(self.question),
        'uptake': bool(self.uptake),
        'focusingQuestion': bool(self.focusing_question),
    }
    return {
        'speaker': self.speaker,
        'text': self.text,
        'uid': self.uid,
        'role': self.role,
        'timestamp': self.timestamp,
        'moments': moments,
        'unitMeasure': self.unit_measure,
        'aggregateUnitMeasure': self.aggregate_unit_measure,
        'wordCount': self.word_count,
        'numMathTerms': self.num_math_terms,
        'mathTerms': self.math_terms,
    }
|
| 88 |
+
|
| 89 |
def __repr__(self):
|
| 90 |
return f"Utterance(speaker='{self.speaker}'," \
|
| 91 |
f"text='{self.text}', uid={self.uid}," \
|
|
|
|
| 115 |
def length(self):
|
| 116 |
return len(self.utterances)
|
| 117 |
|
| 118 |
+
def update_utterance_roles(self, uptake_speaker):
    """Assign a role to every utterance: 'teacher' when its speaker
    matches uptake_speaker, 'student' otherwise."""
    for utterance in self.utterances:
        utterance.role = 'teacher' if utterance.speaker == uptake_speaker else 'student'
|
| 124 |
+
|
| 125 |
+
def get_talk_distribution_and_length(self, uptake_speaker):
    """Compute teacher/student talk statistics for the transcript.

    Side effect: assigns utt.role ('teacher'/'student') on every
    utterance, mirroring update_utterance_roles.

    Returns:
        None when uptake_speaker is None; otherwise a pair of dicts:
        ({'teacher': pct, 'student': pct},
         {'teacher': avg_words_per_utt, 'student': avg_words_per_utt}).
        Percentages are integers that sum to 100 (or both 0 for an
        empty transcript).
    """
    if uptake_speaker is None:
        return None
    teacher_words = 0
    teacher_utt_count = 0
    student_words = 0
    student_utt_count = 0
    for utt in self.utterances:
        if utt.speaker == uptake_speaker:
            utt.role = 'teacher'
            teacher_words += utt.get_num_words()
            teacher_utt_count += 1
        else:
            utt.role = 'student'
            student_words += utt.get_num_words()
            student_utt_count += 1
    total_words = teacher_words + student_words
    # Guard all three divisions: the original raised ZeroDivisionError on an
    # empty transcript or when one side had no utterances at all.
    teacher_percentage = round((teacher_words / total_words) * 100) if total_words else 0
    student_percentage = 100 - teacher_percentage if total_words else 0
    avg_teacher_length = teacher_words / teacher_utt_count if teacher_utt_count else 0
    avg_student_length = student_words / student_utt_count if student_utt_count else 0
    return ({'teacher': teacher_percentage, 'student': student_percentage},
            {'teacher': avg_teacher_length, 'student': avg_student_length})
|
| 147 |
+
|
| 148 |
+
def get_word_cloud_dicts(self):
    """Build word-frequency lists for the word-cloud views.

    Counts non-stopword tokens from each utterance's cleaned text,
    split by the utterance's role ('teacher' vs everyone else), and
    separately for teacher utterances flagged as uptake (utt.uptake == 1).

    Returns:
        (top_words, top_uptake_words): the 50 highest-count entries of
        each list, as dicts {'text', 'value', 'category'} sorted by
        count descending.
    """
    teacher_dict = {}
    student_dict = {}
    uptake_teacher_dict = {}
    # set() instead of the raw list: O(1) membership tests in the loop.
    stop_words = set(stopwords.words('english'))
    for utt in self.utterances:
        for word in utt.get_clean_text(remove_punct=True).split(' '):
            # Skip stopwords and the empty tokens that split(' ') yields
            # for repeated/leading/trailing spaces (previously counted
            # as a "word").
            if not word or word in stop_words:
                continue
            if utt.role == 'teacher':
                teacher_dict[word] = teacher_dict.get(word, 0) + 1
                if utt.uptake == 1:
                    uptake_teacher_dict[word] = uptake_teacher_dict.get(word, 0) + 1
            else:
                student_dict[word] = student_dict.get(word, 0) + 1
    uptake_dict_list = [{'text': w, 'value': c, 'category': 'teacher'}
                        for w, c in uptake_teacher_dict.items()]
    dict_list = [{'text': w, 'value': c, 'category': 'teacher'}
                 for w, c in teacher_dict.items()]
    dict_list += [{'text': w, 'value': c, 'category': 'student'}
                  for w, c in student_dict.items()]
    sorted_dict_list = sorted(dict_list, key=lambda d: d['value'], reverse=True)
    sorted_uptake_dict_list = sorted(uptake_dict_list, key=lambda d: d['value'], reverse=True)
    return sorted_dict_list[:50], sorted_uptake_dict_list[:50]
|
| 184 |
+
|
| 185 |
+
def get_talk_timeline(self):
    """Return the timeline dict of every utterance, in transcript order."""
    timeline = []
    for utterance in self.utterances:
        timeline.append(utterance.to_talk_timeline_dict())
    return timeline
|
| 187 |
+
|
| 188 |
+
def calculate_aggregate_word_count(self):
    """Backfill per-utterance unit measures from word counts.

    If any utterance is missing a unit_measure, recompute all of them:
    each utterance's unit_measure becomes its own word count and its
    aggregate_unit_measure becomes the running total up to and
    including that utterance. If every unit_measure is already set,
    this is a no-op.
    """
    if all(utt.unit_measure is not None for utt in self.utterances):
        return
    running_total = 0
    for utt in self.utterances:
        words = utt.get_num_words()
        running_total += words
        utt.unit_measure = words
        utt.aggregate_unit_measure = running_total
|
| 196 |
+
|
| 197 |
+
|
| 198 |
def to_dict(self):
|
| 199 |
return {
|
| 200 |
'utterances': [utterance.to_dict() for utterance in self.utterances],
|
|
|
|
| 322 |
return_pooler_output=False)
|
| 323 |
return output
|
| 324 |
|
|
|
|
|
|
|
| 325 |
class FocusingQuestionModel:
|
| 326 |
def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
|
| 327 |
print("Loading models...")
|
|
|
|
| 356 |
output = self.model(input_ids=instance["input_ids"],
|
| 357 |
attention_mask=instance["attention_mask"],
|
| 358 |
token_type_ids=instance["token_type_ids"])
|
| 359 |
+
return output
|
|
|
|
| 360 |
|
| 361 |
def load_math_terms():
|
| 362 |
math_terms = []
|
|
|
|
| 384 |
matches = [match for match in matches if not any(match.start() in range(existing[0], existing[1]) for existing in matched_positions)]
|
| 385 |
if len(matches) > 0:
|
| 386 |
match_list.append(math_terms_dict[term])
|
| 387 |
+
# Update matched positions
|
| 388 |
matched_positions.update((match.start(), match.end()) for match in matches)
|
| 389 |
num_matches += len(matches)
|
| 390 |
utt.num_math_terms = num_matches
|
|
|
|
| 420 |
transcript.add_utterance(Utterance(**utt))
|
| 421 |
|
| 422 |
print("Running inference on %d examples..." % transcript.length())
|
| 423 |
+
logging.set_verbosity_info()
|
| 424 |
# Uptake
|
| 425 |
uptake_model = UptakeModel(
|
| 426 |
self.device, self.tokenizer, self.input_builder)
|
| 427 |
+
uptake_speaker = params.pop("uptake_speaker", None)
|
| 428 |
uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
|
| 429 |
uptake_speaker=uptake_speaker)
|
|
|
|
| 430 |
# Reasoning
|
| 431 |
reasoning_model = ReasoningModel(
|
| 432 |
self.device, self.tokenizer, self.input_builder)
|
|
|
|
| 444 |
|
| 445 |
run_math_density(transcript)
|
| 446 |
|
| 447 |
+
transcript.update_utterance_roles(uptake_speaker)
|
| 448 |
+
transcript.calculate_aggregate_word_count()
|
| 449 |
+
return_dict = {'talkDistribution': None, 'talkLength': None, 'talkMoments': None, 'commonTopWords': None, 'uptakeTopWords': None}
|
| 450 |
+
talk_dist, talk_len = transcript.get_talk_distribution_and_length(uptake_speaker)
|
| 451 |
+
return_dict['talkDistribution'] = talk_dist
|
| 452 |
+
return_dict['talkLength'] = talk_len
|
| 453 |
+
talk_moments = transcript.get_talk_timeline()
|
| 454 |
+
return_dict['talkMoments'] = talk_moments
|
| 455 |
+
word_cloud, uptake_word_cloud = transcript.get_word_cloud_dicts()
|
| 456 |
+
return_dict['commonTopWords'] = word_cloud
|
| 457 |
+
return_dict['uptakeTopWords'] = uptake_word_cloud
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
return return_dict
|