InjuryDetection / src /predict_utils.py
lrschuman17's picture
Rename src/predict_injury.py to src/predict_utils.py
2159d56 verified
import torch
@torch.no_grad()
def predict_injury(model, tokenizer, text, prior_injuries, injury_type_id, position_id,
                   label_map_type, label_map_duration, attention_score=0.5):
    """Run a single-example forward pass and return the top type/duration labels.

    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        model: Module called with the keyword arguments ``input_ids``,
            ``attention_mask``, ``prior_injuries``, ``injury_type_id``,
            ``position_id`` and ``attention_score``. Its return value must
            expose ``.logits`` that unpacks into two tensors
            ``(logits_type, logits_duration)``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True,
            padding=True, max_length=128)``; the result must support
            ``.to(device)`` and lookup of ``"input_ids"`` /
            ``"attention_mask"``.
        text: Text handed to the tokenizer. The final ``.item()`` calls only
            succeed when the model yields exactly one prediction, so this is
            presumably a single string -- confirm against callers.
        prior_injuries: Scalar wrapped into a ``[1, 1]`` float32 tensor.
        injury_type_id: Scalar wrapped into a ``[1, 1]`` float32 tensor.
        position_id: Scalar wrapped into a ``[1, 1]`` float32 tensor.
        label_map_type: Lookup indexed by the predicted type class index.
        label_map_duration: Lookup indexed by the predicted duration class
            index.
        attention_score: Scalar wrapped into a ``[1, 1]`` float32 tensor
            (default ``0.5``).

    Returns:
        A 4-tuple ``(type_label, type_probability, duration_label,
        duration_probability)`` where each probability is the softmax score
        of the arg-max class as a Python float.

    Note:
        ``model.eval()`` is called and the model is left in eval mode.
    """
    model.eval()

    # Place inputs on the model's device. A module without parameters used to
    # raise StopIteration here; fall back to its buffers, then to CPU.
    anchor = next(model.parameters(), None)
    if anchor is None:
        anchor = next(model.buffers(), None)
    device = anchor.device if anchor is not None else torch.device("cpu")

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    def _as_feature(value):
        # Shape [1, 1] float32, allocated directly on the target device
        # instead of on CPU followed by a copy.
        return torch.tensor([[value]], dtype=torch.float32, device=device)

    # NOTE(review): the id features are passed as float32, as in the original
    # code -- confirm the model does not expect integer (long) indices.
    outputs = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        prior_injuries=_as_feature(prior_injuries),
        injury_type_id=_as_feature(injury_type_id),
        position_id=_as_feature(position_id),
        attention_score=_as_feature(attention_score),
    )

    logits_type, logits_duration = outputs.logits
    type_probs = torch.softmax(logits_type, dim=1)
    duration_probs = torch.softmax(logits_duration, dim=1)

    # .item() requires a single-element result, i.e. one example per call.
    pred_type = type_probs.argmax(dim=1).item()
    pred_duration = duration_probs.argmax(dim=1).item()

    return (
        label_map_type[pred_type], type_probs[0, pred_type].item(),
        label_map_duration[pred_duration], duration_probs[0, pred_duration].item(),
    )