Spaces:
Build error
Build error
Upload predict_injury.py
Browse files- src/predict_injury.py +36 -0
src/predict_injury.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
@torch.no_grad()
|
| 4 |
+
def predict_injury(model, tokenizer, text, prior_injuries, injury_type_id, position_id,
|
| 5 |
+
label_map_type, label_map_duration, attention_score=0.5):
|
| 6 |
+
model.eval()
|
| 7 |
+
device = next(model.parameters()).device
|
| 8 |
+
|
| 9 |
+
encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
|
| 10 |
+
|
| 11 |
+
# create float32 tensors and unsqueeze to [1, 1]
|
| 12 |
+
prior_injuries = torch.tensor([[prior_injuries]], dtype=torch.float32).to(device)
|
| 13 |
+
injury_type_id = torch.tensor([[injury_type_id]], dtype=torch.float32).to(device)
|
| 14 |
+
position_id = torch.tensor([[position_id]], dtype=torch.float32).to(device)
|
| 15 |
+
attention_score = torch.tensor([[attention_score]], dtype=torch.float32).to(device)
|
| 16 |
+
|
| 17 |
+
outputs = model(
|
| 18 |
+
input_ids=encoded["input_ids"],
|
| 19 |
+
attention_mask=encoded["attention_mask"],
|
| 20 |
+
prior_injuries=prior_injuries,
|
| 21 |
+
injury_type_id=injury_type_id,
|
| 22 |
+
position_id=position_id,
|
| 23 |
+
attention_score=attention_score
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
logits_type, logits_duration = outputs.logits
|
| 27 |
+
type_probs = torch.softmax(logits_type, dim=1)
|
| 28 |
+
duration_probs = torch.softmax(logits_duration, dim=1)
|
| 29 |
+
|
| 30 |
+
pred_type = type_probs.argmax(dim=1).item()
|
| 31 |
+
pred_duration = duration_probs.argmax(dim=1).item()
|
| 32 |
+
|
| 33 |
+
return (
|
| 34 |
+
label_map_type[pred_type], type_probs[0, pred_type].item(),
|
| 35 |
+
label_map_duration[pred_duration], duration_probs[0, pred_duration].item()
|
| 36 |
+
)
|