aekupor commited on
Commit
c2d84ec
·
1 Parent(s): 51e0034

Add probing

Browse files
Files changed (1) hide show
  1. handler.py +57 -0
handler.py CHANGED
@@ -48,6 +48,53 @@ class EndpointHandler():
48
  if len(doc) > token_limit:
49
  return self.handle_long_utterances(doc)
50
  return utterance.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def handle_long_utterances(self, doc: str) -> List[str]:
53
  split_count = 1
@@ -153,6 +200,12 @@ class EndpointHandler():
153
  utterances_list.append(self.eliciting_utterance_to_str(utterance))
154
  elif model_id == 'connecting':
155
  utterances_list.append(self.connecting_utterance_to_str(utterance))
 
 
 
 
 
 
156
 
157
  cuda_available = torch.cuda.is_available()
158
  if model_id == 'eliciting':
@@ -163,6 +216,10 @@ class EndpointHandler():
163
  self.model = ClassificationModel(
164
  "roberta", "aekupor/connecting", use_cuda=cuda_available
165
  )
 
 
 
 
166
  else:
167
  raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
168
 
 
48
  if len(doc) > token_limit:
49
  return self.handle_long_utterances(doc)
50
  return utterance.text
51
+
52
+ def probing_utterance_to_str(self, utterance: Utterance) -> str:
53
+ #probing using prior text and truncates end of the prior text
54
+
55
+ doc = nlp(utterance.text)
56
+ prior_text = self.truncate_end(self.get_prior_text(utterance))
57
+
58
+ if len(doc) > token_limit:
59
+ utterance_text_list = self.handle_long_utterances(doc)
60
+ utterance_with_prior_text = []
61
+ for text in utterance_text_list:
62
+ utterance_with_prior_text.append([prior_text, text])
63
+ return utterance_with_prior_text, 'list'
64
+
65
+ else:
66
+ return [prior_text, utterance.text], 'single'
67
+
68
+ def truncate_end(self, prior_text: str) -> str:
69
+ max_seq_length = 512
70
+ prior_text_max_length = int(max_seq_length / 2) #divide by 2 because 2 columns
71
+
72
+ if len(prior_text) > prior_text_max_length:
73
+ starting_index = len(prior_text) - prior_text_max_length
74
+ return prior_text[starting_index:]
75
+ return prior_text
76
+
77
+ def format_speaker(self, speaker: str, source: str) -> str:
78
+ prior_text = ''
79
+ if speaker == 'student':
80
+ prior_text += '***STUDENT '
81
+ else:
82
+ prior_text += '***SECTION_LEADER '
83
+ if source == 'not chat':
84
+ prior_text += '(audio)*** : '
85
+ else:
86
+ prior_text += '(chat)*** : '
87
+ return prior_text
88
+
89
+ def get_prior_text(self, utterance: Utterance) -> str:
90
+ prior_text = ''
91
+ if utterance.prev_utterance != None and utterance.prev_prev_utterance != None:
92
+ #TODO: add in the source
93
+ prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
94
+ prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
95
+ else:
96
+ prior_text = 'No prior utterance'
97
+ return prior_text
98
 
99
  def handle_long_utterances(self, doc: str) -> List[str]:
100
  split_count = 1
 
200
  utterances_list.append(self.eliciting_utterance_to_str(utterance))
201
  elif model_id == 'connecting':
202
  utterances_list.append(self.connecting_utterance_to_str(utterance))
203
+ elif model_id == 'probing':
204
+ utterance_str, is_list = self.probing_utterance_to_str(utterance)
205
+ if is_list == 'list':
206
+ utterances_list.extend(utterance_str)
207
+ else:
208
+ utterances_list.append(utterance_str)
209
 
210
  cuda_available = torch.cuda.is_available()
211
  if model_id == 'eliciting':
 
216
  self.model = ClassificationModel(
217
  "roberta", "aekupor/connecting", use_cuda=cuda_available
218
  )
219
+ elif model_id == 'probing':
220
+ self.model = ClassificationModel(
221
+ "roberta", "aekupor/probing", use_cuda=cuda_available
222
+ )
223
  else:
224
  raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
225