Commit a7d2f40
Ashlee Kupor committed
Parent(s): 89779fa

Add model
Files changed:
- __pycache__/handler.cpython-310.pyc +0 -0
- __pycache__/handler.cpython-311.pyc +0 -0
- config.json +28 -0
- eval_results.txt +12 -0
- handler.py +171 -0
- merges.txt +0 -0
- model_args.json +1 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +15 -0
- test_run_handler.py +13 -0
- tokenizer.json +0 -0
- tokenizer_config.json +16 -0
- training_args.bin +3 -0
- training_progress_scores.csv +8 -0
- vocab.json +0 -0
__pycache__/handler.cpython-310.pyc
ADDED
Binary file (4.56 kB)
__pycache__/handler.cpython-311.pyc
ADDED
Binary file (7.8 kB)
config.json
ADDED
@@ -0,0 +1,28 @@
+{
+  "_name_or_path": "roberta-base",
+  "architectures": [
+    "RobertaForSequenceClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "bos_token_id": 0,
+  "classifier_dropout": null,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 514,
+  "model_type": "roberta",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 1,
+  "position_embedding_type": "absolute",
+  "problem_type": "single_label_classification",
+  "torch_dtype": "float32",
+  "transformers_version": "4.28.0",
+  "type_vocab_size": 1,
+  "use_cache": true,
+  "vocab_size": 50265
+}
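
Note: this is a standard RoBERTa-base sequence-classification config, so the checkpoint should also load with plain transformers. A minimal sketch, assuming a local clone of this repo as the working directory:

    # Sketch: load this checkpoint with vanilla transformers ("." is an assumption).
    from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

    config = AutoConfig.from_pretrained(".")        # reads the config.json above
    tokenizer = AutoTokenizer.from_pretrained(".")  # reads vocab.json / merges.txt / tokenizer.json
    model = AutoModelForSequenceClassification.from_pretrained(".")  # reads pytorch_model.bin

    print(config.architectures)  # ['RobertaForSequenceClassification']
    print(config.num_labels)     # 2 (the transformers default when no id2label is set)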
eval_results.txt
ADDED
@@ -0,0 +1,12 @@
+accuracy = 1.0
+auprc = 1.0
+auroc = 1.0
+eval_loss = 5.9566273298529246e-05
+f1 = 1.0
+fn = 0
+fp = 0
+mcc = 1.0
+precision = 1.0
+recall = 1.0
+tn = 2262
+tp = 241
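
The headline metrics are consistent with the confusion counts: with fp = fn = 0, every derived score is exactly 1.0. A quick sanity check in plain Python (not part of the commit):

    # Derive the metrics above from the confusion counts.
    tp, tn, fp, fn = 241, 2262, 0, 0

    precision = tp / (tp + fp)                          # 241/241 = 1.0
    recall = tp / (tp + fn)                             # 241/241 = 1.0
    f1 = 2 * precision * recall / (precision + recall)  # 1.0
    accuracy = (tp + tn) / (tp + tn + fp + fn)          # 2503/2503 = 1.0
    print(precision, recall, f1, accuracy)              # matches eval_results.txt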
handler.py
ADDED
@@ -0,0 +1,171 @@
+from simpletransformers.classification import ClassificationModel, ClassificationArgs
+from typing import Dict, List, Any
+import pandas as pd
+import webvtt
+from datetime import datetime
+import torch
+import spacy
+
+nlp = spacy.load("en_core_web_sm")
+tokenizer = nlp.tokenizer
+token_limit = 200
+
+class Utterance(object):
+
+    def __init__(self, starttime, endtime, speaker, text,
+                 idx, prev_utterance, prev_prev_utterance):
+        self.starttime = starttime
+        self.endtime = endtime
+        self.speaker = speaker
+        self.text = text
+        self.idx = idx
+        self.prev_utterance = prev_utterance
+        self.prev_prev_utterance = prev_prev_utterance
+
+class EndpointHandler():
+    def __init__(self, path="."):
+        print("Loading models...")
+        cuda_available = torch.cuda.is_available()
+        self.model = ClassificationModel(
+            "roberta", path, use_cuda=cuda_available
+        )
+
+    def utterance_to_str(self, utterance: Utterance) -> (List[str], str):
+        # model utterance uses prior text
+
+        doc = nlp(utterance.text)
+        prior_text = self.get_prior_text(utterance)
+
+        if len(doc) > token_limit:
+            utterance_text_list = self.handle_long_utterances(doc)
+            utterance_with_prior_text = []
+            for text in utterance_text_list:
+                utterance_with_prior_text.append([prior_text, text])
+            return utterance_with_prior_text, 'list'
+
+        else:
+            return [prior_text, utterance.text], 'single'
+
+    def format_speaker(self, speaker: str, source: str) -> str:
+        prior_text = ''
+        if speaker == 'student':
+            prior_text += '***STUDENT '
+        else:
+            prior_text += '***SECTION_LEADER '
+        if source == 'not chat':
+            prior_text += '(audio)*** : '
+        else:
+            prior_text += '(chat)*** : '
+        return prior_text
+
+    def get_prior_text(self, utterance: Utterance) -> str:
+        prior_text = ''
+        if utterance.prev_utterance != None and utterance.prev_prev_utterance != None:
+            # TODO: add in the source
+            prior_text = '\"' + self.format_speaker(utterance.prev_prev_utterance.speaker, 'not chat') + utterance.prev_prev_utterance.text + ' \n '
+            prior_text += self.format_speaker(utterance.prev_utterance.speaker, 'not chat') + utterance.prev_utterance.text + ' \n '
+        else:
+            prior_text = 'No prior utterance'
+        return prior_text
+
+    def handle_long_utterances(self, doc: str) -> List[str]:
+        split_count = 1
+        total_sent = len([x for x in doc.sents])
+        sent_count = 0
+        token_count = 0
+        split_utterance = ''
+        utterances = []
+        for sent in doc.sents:
+            # add a sentence to split
+            split_utterance = split_utterance + ' ' + sent.text
+            token_count += len(sent)
+            sent_count += 1
+            if token_count >= token_limit or sent_count == total_sent:
+                # save utterance segment
+                utterances.append(split_utterance)
+
+                # restart count
+                split_utterance = ''
+                token_count = 0
+                split_count += 1
+
+        return utterances
+
+    def convert_time(self, time_str):
+        time = datetime.strptime(time_str, "%H:%M:%S.%f")
+        return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000
+
+    def process_vtt_transcript(self, vttfile) -> List[Utterance]:
+        """Process raw vtt file."""
+
+        utterances_list = []
+        text = ""
+        prev_start = "00:00:00.000"
+        prev_end = "00:00:00.000"
+        idx = 0
+        prev_speaker = None
+        prev_utterance = None
+        prev_prev_utterance = None
+        for caption in webvtt.read(vttfile):
+
+            # Get speaker
+            check_for_speaker = caption.text.split(":")
+            if len(check_for_speaker) > 1:  # the speaker was changed or restated
+                speaker = check_for_speaker[0]
+            else:
+                speaker = prev_speaker
+
+            # Get utterance
+            new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]
+
+            # If speaker was changed, start new batch
+            if (prev_speaker is not None) and (speaker != prev_speaker):
+                utterance = Utterance(starttime=self.convert_time(prev_start),
+                                      endtime=self.convert_time(prev_end),
+                                      speaker=prev_speaker,
+                                      text=text.strip(),
+                                      idx=idx,
+                                      prev_utterance=prev_utterance,
+                                      prev_prev_utterance=prev_prev_utterance)
+
+                utterances_list.append(utterance)
+
+                # Start new batch
+                prev_start = caption.start
+                text = ""
+                prev_prev_utterance = prev_utterance
+                prev_utterance = utterance
+                idx += 1
+            text += new_text + " "
+            prev_end = caption.end
+            prev_speaker = speaker
+
+        # Append last one
+        if prev_speaker is not None:
+            utterance = Utterance(starttime=self.convert_time(prev_start),
+                                  endtime=self.convert_time(prev_end),
+                                  speaker=prev_speaker,
+                                  text=text.strip(),
+                                  idx=idx,
+                                  prev_utterance=prev_utterance,
+                                  prev_prev_utterance=prev_prev_utterance)
+            utterances_list.append(utterance)
+
+        return utterances_list
+
+
+    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
+        ''' data_file is a str pointing to filename of type .vtt '''
+
+        utterances_list = []
+        for utterance in self.process_vtt_transcript(data_file):
+            # TODO: filter out to only have SL utterances
+            utterance_str, is_list = self.utterance_to_str(utterance)
+            if is_list == 'list':
+                utterances_list.extend(utterance_str)
+            else:
+                utterances_list.append(utterance_str)
+
+        predictions, raw_outputs = self.model.predict(utterances_list)
+
+        return predictions
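
handler.py expects a WebVTT transcript whose cue payloads carry "speaker: text" lines. A minimal sketch of an input and invocation; the file contents and speaker names are illustrative assumptions, and the checkpoint files from this commit must be present in the working directory:

    # Sketch: build a tiny WebVTT file and run it through the handler.
    import textwrap

    sample_vtt = textwrap.dedent("""\
        WEBVTT

        00:00:00.000 --> 00:00:04.000
        section_leader: Welcome everyone, let's get started.

        00:00:04.000 --> 00:00:09.000
        student: I had a question about the assignment.
        """)

    with open("sample.vtt", "w") as f:
        f.write(sample_vtt)

    from handler import EndpointHandler
    handler = EndpointHandler(path=".")  # "." must contain the files in this commit
    preds = handler("sample.vtt")        # one 0/1 label per (possibly split) utterance
    print(preds)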
merges.txt
ADDED
The diff for this file is too large to render.
model_args.json
ADDED
@@ -0,0 +1 @@
+{"adafactor_beta1": null, "adafactor_clip_threshold": 1.0, "adafactor_decay_rate": -0.8, "adafactor_eps": [1e-30, 0.001], "adafactor_relative_step": true, "adafactor_scale_parameter": true, "adafactor_warmup_init": true, "adam_betas": [0.9, 0.999], "adam_epsilon": 1e-08, "best_model_dir": "outputs/roberta/model_utterance_FINAL_MODEL/best_model_all_transcripts", "cache_dir": "outputs/roberta/model_utterance_FINAL_MODEL/cache", "config": {}, "cosine_schedule_num_cycles": 0.5, "custom_layer_parameters": [], "custom_parameter_groups": [], "dataloader_num_workers": 0, "do_lower_case": false, "dynamic_quantize": false, "early_stopping_consider_epochs": false, "early_stopping_delta": 0, "early_stopping_metric": "eval_loss", "early_stopping_metric_minimize": true, "early_stopping_patience": 3, "encoding": null, "eval_batch_size": 8, "evaluate_during_training": true, "evaluate_during_training_silent": true, "evaluate_during_training_steps": 565, "evaluate_during_training_verbose": false, "evaluate_each_epoch": true, "fp16": false, "gradient_accumulation_steps": 2, "learning_rate": 4e-05, "local_rank": -1, "logging_steps": 50, "loss_type": null, "loss_args": {}, "manual_seed": null, "max_grad_norm": 1.0, "max_seq_length": 256, "model_name": "roberta-base", "model_type": "roberta", "multiprocessing_chunksize": -1, "n_gpu": 1, "no_cache": false, "no_save": false, "not_saved_args": [], "num_train_epochs": 5, "optimizer": "AdamW", "output_dir": "outputs/roberta/model_utterance_FINAL_MODEL", "overwrite_output_dir": true, "polynomial_decay_schedule_lr_end": 1e-07, "polynomial_decay_schedule_power": 1.0, "process_count": 1, "quantized_model": false, "reprocess_input_data": true, "save_best_model": true, "save_eval_checkpoints": false, "save_model_every_epoch": false, "save_optimizer_and_scheduler": true, "save_steps": 2000, "scheduler": "linear_schedule_with_warmup", "silent": false, "skip_special_tokens": true, "tensorboard_dir": "outputs/roberta/model_utterance_FINAL_MODEL/tensorboard", "thread_count": null, "tokenizer_name": "roberta-base", "tokenizer_type": null, "train_batch_size": 8, "train_custom_parameters_only": false, "use_cached_eval_features": false, "use_early_stopping": false, "use_hf_datasets": false, "use_multiprocessing": false, "use_multiprocessing_for_evaluation": false, "wandb_kwargs": {"reinit": true}, "wandb_project": "model_utterance_all_transcripts", "warmup_ratio": 0.06, "warmup_steps": 85, "weight_decay": 0.0, "model_class": "ClassificationModel", "labels_list": [0, 1], "labels_map": {}, "lazy_delimiter": "\t", "lazy_labels_column": 1, "lazy_loading": false, "lazy_loading_start_line": 1, "lazy_text_a_column": null, "lazy_text_b_column": null, "lazy_text_column": 0, "onnx": false, "regression": false, "sliding_window": false, "special_tokens_list": [], "stride": 0.8, "tie_value": 1}
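
model_args.json holds the simpletransformers training configuration; when handler.py constructs ClassificationModel("roberta", path), these saved args should be picked up from the checkpoint directory automatically. A sketch of overriding one value at load time (the batch size here is an illustrative assumption, not from the commit):

    # Sketch: load the checkpoint the way handler.py does, overriding a saved arg.
    from simpletransformers.classification import ClassificationModel

    model = ClassificationModel(
        "roberta", ".",                # "." = local clone of this repo (assumption)
        args={"eval_batch_size": 32},  # a dict updates args loaded from model_args.json
        use_cuda=False,
    )
    print(model.args.max_seq_length)   # 256, as saved during training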
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ba1fcd9cd1ccf8636c01436ef0ba4a0b736244197e7936e2acc6be587d51197
+size 498662069
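
Only the Git LFS pointer appears above; the actual ~499 MB weights live in LFS storage. A sketch of fetching them with huggingface_hub; the repo id is a placeholder, since the commit does not name it:

    # Sketch: download the real weights behind the LFS pointer (repo id is hypothetical).
    from huggingface_hub import hf_hub_download

    weights_path = hf_hub_download(
        repo_id="<user>/<this-repo>",  # placeholder: substitute the actual repo id
        filename="pytorch_model.bin",
    )
    print(weights_path)  # cached local path; file size should be 498662069 bytes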
special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "bos_token": "<s>",
+  "cls_token": "<s>",
+  "eos_token": "</s>",
+  "mask_token": {
+    "content": "<mask>",
+    "lstrip": true,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "unk_token": "<unk>"
+}
test_run_handler.py
ADDED
@@ -0,0 +1,13 @@
+from handler import EndpointHandler
+
+# init handler
+my_handler = EndpointHandler(path=".")
+
+# prepare sample payload
+test_payload = 'test.transcript.vtt'
+
+# test the handler
+test_pred = my_handler(test_payload)
+
+# show results
+print("test_pred", test_pred)
tokenizer.json
ADDED
|
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,16 @@
+{
+  "add_prefix_space": false,
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "<s>",
+  "do_lower_case": false,
+  "eos_token": "</s>",
+  "errors": "replace",
+  "mask_token": "<mask>",
+  "model_max_length": 512,
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "tokenizer_class": "RobertaTokenizer",
+  "trim_offsets": true,
+  "unk_token": "<unk>"
+}
training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90c1b64e7276b2247df250b6e1c458a2e5b2ce6dfad37780d09a9761df9499ee
+size 3451
training_progress_scores.csv
ADDED
@@ -0,0 +1,8 @@
+global_step,train_loss,mcc,tp,tn,fp,fn,auroc,auprc,accuracy,precision,recall,f1,eval_loss
+283,0.3241749703884125,0.8551549806677315,228,2205,57,13,0.9911931203246127,0.9301238510748596,0.9720335597283261,0.946058091286307,0.8,0.8669201520912547,0.07226471802533661
+565,0.0068567488342523575,0.9461676303820653,240,2238,24,1,0.9994827035891566,0.9972630809896331,0.9900119856172593,0.995850622406639,0.9090909090909091,0.9504950495049505,0.03605068988365176
+566,0.00029088457813486457,0.9172646185925842,240,2223,39,1,0.9993634686008416,0.9968337615467682,0.9840191769876149,0.995850622406639,0.8602150537634409,0.923076923076923,0.05396954482425928
+849,0.00020018930081278086,1.0,241,2262,0,0,1.0,1.0,1.0,1.0,1.0,1.0,0.0004376148809996856
+1130,7.240189734147862e-05,1.0,241,2262,0,0,1.0,1.0,1.0,1.0,1.0,1.0,0.00017808274872814522
+1132,6.916876009199768e-05,1.0,241,2262,0,0,1.0,1.0,1.0,1.0,1.0,1.0,0.0001942521630669758
+1415,6.425171159207821e-05,1.0,241,2262,0,0,1.0,1.0,1.0,1.0,1.0,1.0,5.9566273298529246e-05
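
The final evaluation step (1415) has the lowest eval_loss, matching eval_results.txt above. A small sketch for picking the best checkpoint from this log with pandas (assumes a local copy of the CSV):

    # Sketch: find the evaluation step with the lowest eval_loss.
    import pandas as pd

    scores = pd.read_csv("training_progress_scores.csv")
    best = scores.loc[scores["eval_loss"].idxmin()]
    print(int(best["global_step"]), best["eval_loss"])  # 1415 5.9566273298529246e-05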
vocab.json
ADDED
The diff for this file is too large to render.