Ashlee Kupor
commited on
Commit
·
a6cedd3
1
Parent(s):
c5c8993
Add revoicing
Browse files- handler.py +26 -4
handler.py
CHANGED
|
@@ -66,6 +66,22 @@ class EndpointHandler():
|
|
| 66 |
else:
|
| 67 |
return [prior_text, utterance.text], 'single'
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
def adding_on_utterance_to_str(self, utterance: Utterance) -> str:
|
| 70 |
#adding_on uses prior text
|
| 71 |
|
|
@@ -230,7 +246,9 @@ class EndpointHandler():
|
|
| 230 |
utterance_str, is_list = self.probing_utterance_to_str(utterance)
|
| 231 |
elif model_id == 'adding_on':
|
| 232 |
utterance_str, is_list = self.adding_on_utterance_to_str(utterance)
|
| 233 |
-
|
|
|
|
|
|
|
| 234 |
if is_list == 'list':
|
| 235 |
utterances_list.extend(utterance_str)
|
| 236 |
for j in range(len(utterance_str)):
|
|
@@ -257,11 +275,14 @@ class EndpointHandler():
|
|
| 257 |
self.model = ClassificationModel(
|
| 258 |
"roberta", "aekupor/probing", use_cuda=cuda_available
|
| 259 |
)
|
| 260 |
-
elif model_id == 'adding_on':
|
| 261 |
-
# TODO: combine adding on and others
|
| 262 |
self.model = ClassificationModel(
|
| 263 |
"roberta", "aekupor/adding_on", use_cuda=cuda_available
|
| 264 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
predictions, _ = self.model.predict(utterances_list)
|
| 267 |
return utterances_list, utterances_indexes, predictions
|
|
@@ -296,7 +317,8 @@ class EndpointHandler():
|
|
| 296 |
utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
|
| 297 |
utterances_list, utterances_indexes, predictions = self.do_prediction(full_transcript, 'adding_on')
|
| 298 |
utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
|
| 299 |
-
|
|
|
|
| 300 |
else:
|
| 301 |
raise ValueError("no valid talk move provided")
|
| 302 |
|
|
|
|
| 66 |
else:
|
| 67 |
return [prior_text, utterance.text], 'single'
|
| 68 |
|
| 69 |
+
def revoicing_utterance_to_str(self, utterance: Utterance) -> str:
|
| 70 |
+
# revoicing uses prior text and truncates end of the prior text
|
| 71 |
+
|
| 72 |
+
doc = nlp(utterance.text)
|
| 73 |
+
prior_text = self.truncate_end(self.get_prior_text(utterance))
|
| 74 |
+
|
| 75 |
+
if len(doc) > token_limit:
|
| 76 |
+
utterance_text_list = self.handle_long_utterances(doc)
|
| 77 |
+
utterance_with_prior_text = []
|
| 78 |
+
for text in utterance_text_list:
|
| 79 |
+
utterance_with_prior_text.append([prior_text, text])
|
| 80 |
+
return utterance_with_prior_text, 'list'
|
| 81 |
+
|
| 82 |
+
else:
|
| 83 |
+
return [prior_text, utterance.text], 'single'
|
| 84 |
+
|
| 85 |
def adding_on_utterance_to_str(self, utterance: Utterance) -> str:
|
| 86 |
#adding_on uses prior text
|
| 87 |
|
|
|
|
| 246 |
utterance_str, is_list = self.probing_utterance_to_str(utterance)
|
| 247 |
elif model_id == 'adding_on':
|
| 248 |
utterance_str, is_list = self.adding_on_utterance_to_str(utterance)
|
| 249 |
+
elif model_id == 'revoicing':
|
| 250 |
+
utterance_str, is_list = self.revoicing_utterance_to_str(utterance)
|
| 251 |
+
|
| 252 |
if is_list == 'list':
|
| 253 |
utterances_list.extend(utterance_str)
|
| 254 |
for j in range(len(utterance_str)):
|
|
|
|
| 275 |
self.model = ClassificationModel(
|
| 276 |
"roberta", "aekupor/probing", use_cuda=cuda_available
|
| 277 |
)
|
| 278 |
+
elif model_id == 'adding_on':
|
|
|
|
| 279 |
self.model = ClassificationModel(
|
| 280 |
"roberta", "aekupor/adding_on", use_cuda=cuda_available
|
| 281 |
)
|
| 282 |
+
elif model_id == 'revoicing':
|
| 283 |
+
self.model = ClassificationModel(
|
| 284 |
+
"roberta", "aekupor/revoicing", use_cuda=cuda_available
|
| 285 |
+
)
|
| 286 |
|
| 287 |
predictions, _ = self.model.predict(utterances_list)
|
| 288 |
return utterances_list, utterances_indexes, predictions
|
|
|
|
| 317 |
utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
|
| 318 |
utterances_list, utterances_indexes, predictions = self.do_prediction(full_transcript, 'adding_on')
|
| 319 |
utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
|
| 320 |
+
utterances_list, utterances_indexes, predictions = self.do_prediction(full_transcript, 'revoicing')
|
| 321 |
+
utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
|
| 322 |
else:
|
| 323 |
raise ValueError("no valid talk move provided")
|
| 324 |
|