Ashlee Kupor committed on
Commit
5a083df
·
1 Parent(s): e14355b

Add json converting

Browse files
Files changed (1) hide show
  1. handler.py +50 -2
handler.py CHANGED
@@ -12,6 +12,7 @@ import webvtt
12
  from datetime import datetime
13
  import torch
14
  import spacy
 
15
 
16
 
17
  nlp = spacy.load("en_core_web_sm")
@@ -199,6 +200,18 @@ class EndpointHandler():
199
  print(utterances_list)
200
  return utterances_list
201
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  def __call__(self, data: str) -> List[Dict[str, Any]]:
203
  ''' data_file is a str pointing to filename of type .vtt '''
204
 
@@ -208,9 +221,12 @@ class EndpointHandler():
208
 
209
  if data_file is None:
210
  raise ValueError("no data file provided")
211
-
 
212
  utterances_list = []
213
- for utterance in self.process_vtt_transcript(data_file):
 
 
214
  #TODO: filter out to only have SL utterances
215
  if model_id == 'eliciting':
216
  utterance_str, is_list = self.eliciting_utterance_to_str(utterance)
@@ -223,11 +239,16 @@ class EndpointHandler():
223
 
224
  if is_list == 'list':
225
  utterances_list.extend(utterance_str)
 
 
226
  else:
227
  utterances_list.append(utterance_str)
 
228
 
 
229
  cuda_available = torch.cuda.is_available()
230
  if model_id == 'eliciting':
 
231
  self.model = ClassificationModel(
232
  "roberta", "aekupor/eliciting", use_cuda=cuda_available
233
  )
@@ -240,6 +261,8 @@ class EndpointHandler():
240
  "roberta", "aekupor/probing", use_cuda=cuda_available
241
  )
242
  elif model_id == 'adding_on':
 
 
243
  self.model = ClassificationModel(
244
  "roberta", "aekupor/adding_on", use_cuda=cuda_available
245
  )
@@ -247,6 +270,31 @@ class EndpointHandler():
247
  raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
248
 
249
  predictions, _ = self.model.predict(utterances_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  return predictions
252
 
 
12
  from datetime import datetime
13
  import torch
14
  import spacy
15
+ import json
16
 
17
 
18
  nlp = spacy.load("en_core_web_sm")
 
200
  print(utterances_list)
201
  return utterances_list
202
 
203
+
204
def utterance_list_to_json(self, utterances: "List[Utterance]", use_prior_text: bool) -> List[Any]:
    """Convert utterances into JSON-serializable records.

    Each utterance is rendered as a dict with the keys ``speaker``,
    ``data`` (the utterance's ``text``) and ``time`` (its ``starttime``).

    Args:
        utterances: Iterable of objects exposing ``speaker``, ``text`` and
            ``starttime``. When ``use_prior_text`` is true they must also
            expose ``prev_utterance`` and ``prev_prev_utterance``.
        use_prior_text: If false, emit one dict per utterance. If true,
            emit one list per utterance holding the records for
            ``prev_prev_utterance``, ``prev_utterance`` and the utterance
            itself, in that order.

    Returns:
        A list of dicts when ``use_prior_text`` is false, otherwise a list
        of lists of dicts. The result can be passed to ``json.dumps`` as
        long as the attribute values themselves are serializable.
    """
    def to_record(item):
        # Single place that defines the record schema.
        return {'speaker': item.speaker, 'data': item.text, 'time': item.starttime}

    if not use_prior_text:
        return [to_record(utterance) for utterance in utterances]

    formatted = []
    for utterance in utterances:
        context = (utterance.prev_prev_utterance, utterance.prev_utterance, utterance)
        # A predecessor that is None is left out of the context list
        # instead of raising AttributeError on ``None.speaker``.
        formatted.append([to_record(item) for item in context if item is not None])
    return formatted
214
+
215
  def __call__(self, data: str) -> List[Dict[str, Any]]:
216
  ''' data_file is a str pointing to filename of type .vtt '''
217
 
 
221
 
222
  if data_file is None:
223
  raise ValueError("no data file provided")
224
+
225
+ full_transcript = self.process_vtt_transcript(data_file)
226
  utterances_list = []
227
+ utterances_indexes = [] # entry corresponds to utterance in full_transcript
228
+ for i in range(len(full_transcript)):
229
+ utterance = full_transcript[i]
230
  #TODO: filter out to only have SL utterances
231
  if model_id == 'eliciting':
232
  utterance_str, is_list = self.eliciting_utterance_to_str(utterance)
 
239
 
240
  if is_list == 'list':
241
  utterances_list.extend(utterance_str)
242
+ for j in range(len(utterance_str)):
243
+ utterances_indexes.append(i)
244
  else:
245
  utterances_list.append(utterance_str)
246
+ utterances_indexes.append(i)
247
 
248
+ talk_move = ""
249
  cuda_available = torch.cuda.is_available()
250
  if model_id == 'eliciting':
251
+ talk_move = 'getIdeas'
252
  self.model = ClassificationModel(
253
  "roberta", "aekupor/eliciting", use_cuda=cuda_available
254
  )
 
261
  "roberta", "aekupor/probing", use_cuda=cuda_available
262
  )
263
  elif model_id == 'adding_on':
264
+ # TODO: combine adding on and others
265
+ talk_move = 'buildIdeas'
266
  self.model = ClassificationModel(
267
  "roberta", "aekupor/adding_on", use_cuda=cuda_available
268
  )
 
270
  raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
271
 
272
  predictions, _ = self.model.predict(utterances_list)
273
+
274
+ # JSON formatting
275
+ full_transcript_json = json.dumps(self.utterance_list_to_json(full_transcript, False), separators=(',', ':'))
276
+ print("FULL TRANSCRIPT")
277
+ print(full_transcript_json)
278
+
279
+ utterance_talk_moves = set()
280
+ for i in range(len(predictions)):
281
+ if predictions[i] == 1:
282
+ utterance_talk_moves.add(full_transcript[utterances_indexes[i]])
283
+
284
+ utterance_talk_moves_json = ''
285
+ if model_id == 'elicting' or model_id == 'connecting':
286
+ utterance_talk_moves_json = json.dumps(self.utterance_list_to_json(utterance_talk_moves, False), separators=(',', ':'))
287
+ elif model_id == 'adding_on':
288
+ utterance_talk_moves_json = json.dumps(self.utterance_list_to_json(utterance_talk_moves, True), separators=(',', ':'))
289
+
290
+ print("TALK MOVES FOUND")
291
+ print(utterance_talk_moves_json)
292
+
293
+ print("TALK MOVE")
294
+ print(talk_move)
295
+
296
+ print("NUM TALK MOVES")
297
+ print(len(utterance_talk_moves))
298
 
299
  return predictions
300