Spaces:
Sleeping
Sleeping
File size: 1,472 Bytes
19cbfd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
@torch.no_grad()
def predict_injury(model, tokenizer, text, prior_injuries, injury_type_id, position_id,
                   label_map_type, label_map_duration, attention_score=0.5):
    """Run one inference pass and decode both classification heads.

    The model is put in eval mode and called with the tokenized ``text``
    plus four scalar side features, each lifted to a ``[1, 1]`` float32
    tensor on the model's device. ``model(...)`` must expose ``.logits``
    as a pair ``(type_logits, duration_logits)``.

    Args:
        model: Dual-head classifier; device is inferred from its parameters.
        tokenizer: HF-style tokenizer returning a batch-encoding with ``.to()``.
        text: Raw input string (truncated/padded to 128 tokens).
        prior_injuries: Scalar count feature.
        injury_type_id: Scalar categorical id.
            # NOTE(review): ids are fed as float32, not long — presumably the
            # model consumes them as dense features rather than embedding
            # indices; confirm against the model definition.
        position_id: Scalar categorical id (same float32 caveat).
        label_map_type: Maps type-class index -> label.
        label_map_duration: Maps duration-class index -> label.
        attention_score: Optional scalar feature, defaults to 0.5.

    Returns:
        Tuple of (type_label, type_confidence, duration_label,
        duration_confidence), confidences being softmax probabilities.
    """
    model.eval()
    device = next(model.parameters()).device

    # Tokenized text arrives as [1, seq_len] tensors, moved to the model device.
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    ).to(device)

    def as_feature(value):
        # Scalar -> float32 tensor shaped [1, 1] (batch of one, one feature).
        return torch.tensor([[value]], dtype=torch.float32, device=device)

    outputs = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        prior_injuries=as_feature(prior_injuries),
        injury_type_id=as_feature(injury_type_id),
        position_id=as_feature(position_id),
        attention_score=as_feature(attention_score),
    )

    logits_type, logits_duration = outputs.logits
    probs_type = torch.softmax(logits_type, dim=1)
    probs_duration = torch.softmax(logits_duration, dim=1)

    idx_type = int(probs_type.argmax(dim=1))
    idx_duration = int(probs_duration.argmax(dim=1))

    return (
        label_map_type[idx_type],
        probs_type[0, idx_type].item(),
        label_map_duration[idx_duration],
        probs_duration[0, idx_duration].item(),
    )
|