Ashlee Kupor commited on
Commit ·
5a083df
1
Parent(s): e14355b
Add json converting
Browse files- handler.py +50 -2
handler.py
CHANGED
|
@@ -12,6 +12,7 @@ import webvtt
|
|
| 12 |
from datetime import datetime
|
| 13 |
import torch
|
| 14 |
import spacy
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
nlp = spacy.load("en_core_web_sm")
|
|
@@ -199,6 +200,18 @@ class EndpointHandler():
|
|
| 199 |
print(utterances_list)
|
| 200 |
return utterances_list
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
def __call__(self, data: str) -> List[Dict[str, Any]]:
|
| 203 |
''' data_file is a str pointing to filename of type .vtt '''
|
| 204 |
|
|
@@ -208,9 +221,12 @@ class EndpointHandler():
|
|
| 208 |
|
| 209 |
if data_file is None:
|
| 210 |
raise ValueError("no data file provided")
|
| 211 |
-
|
|
|
|
| 212 |
utterances_list = []
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
#TODO: filter out to only have SL utterances
|
| 215 |
if model_id == 'eliciting':
|
| 216 |
utterance_str, is_list = self.eliciting_utterance_to_str(utterance)
|
|
@@ -223,11 +239,16 @@ class EndpointHandler():
|
|
| 223 |
|
| 224 |
if is_list == 'list':
|
| 225 |
utterances_list.extend(utterance_str)
|
|
|
|
|
|
|
| 226 |
else:
|
| 227 |
utterances_list.append(utterance_str)
|
|
|
|
| 228 |
|
|
|
|
| 229 |
cuda_available = torch.cuda.is_available()
|
| 230 |
if model_id == 'eliciting':
|
|
|
|
| 231 |
self.model = ClassificationModel(
|
| 232 |
"roberta", "aekupor/eliciting", use_cuda=cuda_available
|
| 233 |
)
|
|
@@ -240,6 +261,8 @@ class EndpointHandler():
|
|
| 240 |
"roberta", "aekupor/probing", use_cuda=cuda_available
|
| 241 |
)
|
| 242 |
elif model_id == 'adding_on':
|
|
|
|
|
|
|
| 243 |
self.model = ClassificationModel(
|
| 244 |
"roberta", "aekupor/adding_on", use_cuda=cuda_available
|
| 245 |
)
|
|
@@ -247,6 +270,31 @@ class EndpointHandler():
|
|
| 247 |
raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
|
| 248 |
|
| 249 |
predictions, _ = self.model.predict(utterances_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
return predictions
|
| 252 |
|
|
|
|
| 12 |
from datetime import datetime
|
| 13 |
import torch
|
| 14 |
import spacy
|
| 15 |
+
import json
|
| 16 |
|
| 17 |
|
| 18 |
nlp = spacy.load("en_core_web_sm")
|
|
|
|
| 200 |
print(utterances_list)
|
| 201 |
return utterances_list
|
| 202 |
|
| 203 |
+
|
| 204 |
+
def utterance_list_to_json(self, utterances: List[Utterance], use_prior_text: bool) -> List[str]:
|
| 205 |
+
formatted = []
|
| 206 |
+
for utterance in utterances:
|
| 207 |
+
if not use_prior_text:
|
| 208 |
+
formatted.append({'speaker': utterance.speaker, 'data': utterance.text, 'time': utterance.starttime})
|
| 209 |
+
else:
|
| 210 |
+
formatted.append([{'speaker': utterance.prev_prev_utterance.speaker, 'data': utterance.prev_prev_utterance.text, 'time': utterance.prev_prev_utterance.starttime},
|
| 211 |
+
{'speaker': utterance.prev_utterance.speaker, 'data': utterance.prev_utterance.text, 'time': utterance.prev_utterance.starttime},
|
| 212 |
+
{'speaker': utterance.speaker, 'data': utterance.text, 'time': utterance.starttime}])
|
| 213 |
+
return formatted
|
| 214 |
+
|
| 215 |
def __call__(self, data: str) -> List[Dict[str, Any]]:
|
| 216 |
''' data_file is a str pointing to filename of type .vtt '''
|
| 217 |
|
|
|
|
| 221 |
|
| 222 |
if data_file is None:
|
| 223 |
raise ValueError("no data file provided")
|
| 224 |
+
|
| 225 |
+
full_transcript = self.process_vtt_transcript(data_file)
|
| 226 |
utterances_list = []
|
| 227 |
+
utterances_indexes = [] # entry corresponds to utterance in full_transcript
|
| 228 |
+
for i in range(len(full_transcript)):
|
| 229 |
+
utterance = full_transcript[i]
|
| 230 |
#TODO: filter out to only have SL utterances
|
| 231 |
if model_id == 'eliciting':
|
| 232 |
utterance_str, is_list = self.eliciting_utterance_to_str(utterance)
|
|
|
|
| 239 |
|
| 240 |
if is_list == 'list':
|
| 241 |
utterances_list.extend(utterance_str)
|
| 242 |
+
for j in range(len(utterance_str)):
|
| 243 |
+
utterances_indexes.append(i)
|
| 244 |
else:
|
| 245 |
utterances_list.append(utterance_str)
|
| 246 |
+
utterances_indexes.append(i)
|
| 247 |
|
| 248 |
+
talk_move = ""
|
| 249 |
cuda_available = torch.cuda.is_available()
|
| 250 |
if model_id == 'eliciting':
|
| 251 |
+
talk_move = 'getIdeas'
|
| 252 |
self.model = ClassificationModel(
|
| 253 |
"roberta", "aekupor/eliciting", use_cuda=cuda_available
|
| 254 |
)
|
|
|
|
| 261 |
"roberta", "aekupor/probing", use_cuda=cuda_available
|
| 262 |
)
|
| 263 |
elif model_id == 'adding_on':
|
| 264 |
+
# TODO: combine adding on and others
|
| 265 |
+
talk_move = 'buildIdeas'
|
| 266 |
self.model = ClassificationModel(
|
| 267 |
"roberta", "aekupor/adding_on", use_cuda=cuda_available
|
| 268 |
)
|
|
|
|
| 270 |
raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
|
| 271 |
|
| 272 |
predictions, _ = self.model.predict(utterances_list)
|
| 273 |
+
|
| 274 |
+
# json formating
|
| 275 |
+
full_transcript_json = json.dumps(self.utterance_list_to_json(full_transcript, False), separators=(',', ':'))
|
| 276 |
+
print("FULL TRANSCRIPT")
|
| 277 |
+
print(full_transcript_json)
|
| 278 |
+
|
| 279 |
+
utterance_talk_moves = set()
|
| 280 |
+
for i in range(len(predictions)):
|
| 281 |
+
if predictions[i] == 1:
|
| 282 |
+
utterance_talk_moves.add(full_transcript[utterances_indexes[i]])
|
| 283 |
+
|
| 284 |
+
utterance_talk_moves_json = ''
|
| 285 |
+
if model_id == 'elicting' or model_id == 'connecting':
|
| 286 |
+
utterance_talk_moves_json = json.dumps(self.utterance_list_to_json(utterance_talk_moves, False), separators=(',', ':'))
|
| 287 |
+
elif model_id == 'adding_on':
|
| 288 |
+
utterance_talk_moves_json = json.dumps(self.utterance_list_to_json(utterance_talk_moves, True), separators=(',', ':'))
|
| 289 |
+
|
| 290 |
+
print("TALK MOVES FOUND")
|
| 291 |
+
print(utterance_talk_moves_json)
|
| 292 |
+
|
| 293 |
+
print("TALK MOVE")
|
| 294 |
+
print(talk_move)
|
| 295 |
+
|
| 296 |
+
print("NUM TALK MOVES")
|
| 297 |
+
print(len(utterance_talk_moves))
|
| 298 |
|
| 299 |
return predictions
|
| 300 |
|