lrschuman17 commited on
Commit
19cbfd0
·
verified ·
1 Parent(s): 8a31222

Upload predict_injury.py

Browse files
Files changed (1) hide show
  1. src/predict_injury.py +36 -0
src/predict_injury.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ @torch.no_grad()
4
+ def predict_injury(model, tokenizer, text, prior_injuries, injury_type_id, position_id,
5
+ label_map_type, label_map_duration, attention_score=0.5):
6
+ model.eval()
7
+ device = next(model.parameters()).device
8
+
9
+ encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
10
+
11
+ # create float32 tensors and unsqueeze to [1, 1]
12
+ prior_injuries = torch.tensor([[prior_injuries]], dtype=torch.float32).to(device)
13
+ injury_type_id = torch.tensor([[injury_type_id]], dtype=torch.float32).to(device)
14
+ position_id = torch.tensor([[position_id]], dtype=torch.float32).to(device)
15
+ attention_score = torch.tensor([[attention_score]], dtype=torch.float32).to(device)
16
+
17
+ outputs = model(
18
+ input_ids=encoded["input_ids"],
19
+ attention_mask=encoded["attention_mask"],
20
+ prior_injuries=prior_injuries,
21
+ injury_type_id=injury_type_id,
22
+ position_id=position_id,
23
+ attention_score=attention_score
24
+ )
25
+
26
+ logits_type, logits_duration = outputs.logits
27
+ type_probs = torch.softmax(logits_type, dim=1)
28
+ duration_probs = torch.softmax(logits_duration, dim=1)
29
+
30
+ pred_type = type_probs.argmax(dim=1).item()
31
+ pred_duration = duration_probs.argmax(dim=1).item()
32
+
33
+ return (
34
+ label_map_type[pred_type], type_probs[0, pred_type].item(),
35
+ label_map_duration[pred_duration], duration_probs[0, pred_duration].item()
36
+ )