only run reasoning for non-teacher utterances
Browse files- handler.py +3 -3
handler.py
CHANGED
|
@@ -256,11 +256,11 @@ class ReasoningModel:
|
|
| 256 |
self.model = BertForSequenceClassification.from_pretrained(path)
|
| 257 |
self.model.to(self.device)
|
| 258 |
|
| 259 |
-
def run_inference(self, transcript, min_num_words=8):
|
| 260 |
self.model.eval()
|
| 261 |
with torch.no_grad():
|
| 262 |
for i, utt in enumerate(transcript.utterances):
|
| 263 |
-
if utt.get_num_words() >= min_num_words:
|
| 264 |
instance = self.input_builder.build_inputs([], utt.text,
|
| 265 |
max_length=self.max_length,
|
| 266 |
input_str=True)
|
|
@@ -430,7 +430,7 @@ class EndpointHandler():
|
|
| 430 |
# Reasoning
|
| 431 |
reasoning_model = ReasoningModel(
|
| 432 |
self.device, self.tokenizer, self.input_builder)
|
| 433 |
-
reasoning_model.run_inference(transcript)
|
| 434 |
|
| 435 |
# Question
|
| 436 |
question_model = QuestionModel(
|
|
|
|
| 256 |
self.model = BertForSequenceClassification.from_pretrained(path)
|
| 257 |
self.model.to(self.device)
|
| 258 |
|
| 259 |
+
def run_inference(self, transcript, min_num_words=8, uptake_speaker=None):
|
| 260 |
self.model.eval()
|
| 261 |
with torch.no_grad():
|
| 262 |
for i, utt in enumerate(transcript.utterances):
|
| 263 |
+
if utt.get_num_words() >= min_num_words and utt.speaker != uptake_speaker:
|
| 264 |
instance = self.input_builder.build_inputs([], utt.text,
|
| 265 |
max_length=self.max_length,
|
| 266 |
input_str=True)
|
|
|
|
| 430 |
# Reasoning
|
| 431 |
reasoning_model = ReasoningModel(
|
| 432 |
self.device, self.tokenizer, self.input_builder)
|
| 433 |
+
reasoning_model.run_inference(transcript, uptake_speaker=uptake_speaker)
|
| 434 |
|
| 435 |
# Question
|
| 436 |
question_model = QuestionModel(
|