Ashlee Kupor committed on
Commit
a6cedd3
·
1 Parent(s): c5c8993

Add revoicing

Browse files
Files changed (1) hide show
  1. handler.py +26 -4
handler.py CHANGED
@@ -66,6 +66,22 @@ class EndpointHandler():
66
  else:
67
  return [prior_text, utterance.text], 'single'
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def adding_on_utterance_to_str(self, utterance: Utterance) -> str:
70
  #adding_on uses prior text
71
 
@@ -230,7 +246,9 @@ class EndpointHandler():
230
  utterance_str, is_list = self.probing_utterance_to_str(utterance)
231
  elif model_id == 'adding_on':
232
  utterance_str, is_list = self.adding_on_utterance_to_str(utterance)
233
-
 
 
234
  if is_list == 'list':
235
  utterances_list.extend(utterance_str)
236
  for j in range(len(utterance_str)):
@@ -257,11 +275,14 @@ class EndpointHandler():
257
  self.model = ClassificationModel(
258
  "roberta", "aekupor/probing", use_cuda=cuda_available
259
  )
260
- elif model_id == 'adding_on':
261
- # TODO: combine adding on and others
262
  self.model = ClassificationModel(
263
  "roberta", "aekupor/adding_on", use_cuda=cuda_available
264
  )
 
 
 
 
265
 
266
  predictions, _ = self.model.predict(utterances_list)
267
  return utterances_list, utterances_indexes, predictions
@@ -296,7 +317,8 @@ class EndpointHandler():
296
  utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
297
  utterances_list, utterances_indexes, predictions = self.do_prediction(full_transcript, 'adding_on')
298
  utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
299
- #TODO: add in revoicing
 
300
  else:
301
  raise ValueError("no valid talk move provided")
302
 
 
66
  else:
67
  return [prior_text, utterance.text], 'single'
68
 
69
+ def revoicing_utterance_to_str(self, utterance: Utterance) -> str:
70
+ # revoicing uses prior text and truncates end of the prior text
71
+
72
+ doc = nlp(utterance.text)
73
+ prior_text = self.truncate_end(self.get_prior_text(utterance))
74
+
75
+ if len(doc) > token_limit:
76
+ utterance_text_list = self.handle_long_utterances(doc)
77
+ utterance_with_prior_text = []
78
+ for text in utterance_text_list:
79
+ utterance_with_prior_text.append([prior_text, text])
80
+ return utterance_with_prior_text, 'list'
81
+
82
+ else:
83
+ return [prior_text, utterance.text], 'single'
84
+
85
  def adding_on_utterance_to_str(self, utterance: Utterance) -> str:
86
  #adding_on uses prior text
87
 
 
246
  utterance_str, is_list = self.probing_utterance_to_str(utterance)
247
  elif model_id == 'adding_on':
248
  utterance_str, is_list = self.adding_on_utterance_to_str(utterance)
249
+ elif model_id == 'revoicing':
250
+ utterance_str, is_list = self.revoicing_utterance_to_str(utterance)
251
+
252
  if is_list == 'list':
253
  utterances_list.extend(utterance_str)
254
  for j in range(len(utterance_str)):
 
275
  self.model = ClassificationModel(
276
  "roberta", "aekupor/probing", use_cuda=cuda_available
277
  )
278
+ elif model_id == 'adding_on':
 
279
  self.model = ClassificationModel(
280
  "roberta", "aekupor/adding_on", use_cuda=cuda_available
281
  )
282
+ elif model_id == 'revoicing':
283
+ self.model = ClassificationModel(
284
+ "roberta", "aekupor/revoicing", use_cuda=cuda_available
285
+ )
286
 
287
  predictions, _ = self.model.predict(utterances_list)
288
  return utterances_list, utterances_indexes, predictions
 
317
  utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
318
  utterances_list, utterances_indexes, predictions = self.do_prediction(full_transcript, 'adding_on')
319
  utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
320
+ utterances_list, utterances_indexes, predictions = self.do_prediction(full_transcript, 'revoicing')
321
+ utterance_talk_moves = self.add_preds_to_list(utterance_talk_moves, predictions, utterances_indexes, full_transcript)
322
  else:
323
  raise ValueError("no valid talk move provided")
324