# Spaces: Sleeping  (appears to be status text copied from the hosting page; not part of the program)
import torch


def predict_injury(model, tokenizer, text, prior_injuries, injury_type_id, position_id,
                   label_map_type, label_map_duration, attention_score=0.5):
    """Run one forward pass and return the top injury type and duration predictions.

    The model is switched to eval mode (and left in eval mode). The forward pass
    runs under ``torch.no_grad()`` so no autograd graph is built during inference.

    Args:
        model: A ``torch.nn.Module`` with at least one parameter; its device is
            taken from ``next(model.parameters())``. Its forward must accept the
            keyword arguments ``input_ids``, ``attention_mask``, ``prior_injuries``,
            ``injury_type_id``, ``position_id`` and ``attention_score``, and return
            an object whose ``.logits`` unpacks into
            ``(logits_type, logits_duration)``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` that supports ``.to(device)``. Presumably a
            Hugging Face tokenizer, given the ``return_tensors="pt"`` call; confirm
            against the caller.
        text: A single input text. The ``.item()`` calls below require the
            batch to contain exactly one example.
        prior_injuries: Scalar feature, passed as a float32 tensor of shape [1, 1].
        injury_type_id: Scalar feature, passed as a float32 tensor of shape [1, 1].
            NOTE(review): IDs are sent as float32 (as in the original); if the
            model embeds them it must cast to long itself. Confirm.
        position_id: Scalar feature, float32 [1, 1] (same caveat as above).
        label_map_type: Indexable (list or dict) from type class index to label.
        label_map_duration: Indexable from duration class index to label.
        attention_score: Scalar feature, float32 [1, 1]. Defaults to 0.5.

    Returns:
        Tuple ``(type_label, type_prob, duration_label, duration_prob)``. Each
        probability is the softmax score of the corresponding predicted class.
    """
    model.eval()
    device = next(model.parameters()).device

    def _as_feature(value):
        # Build a [1, 1] float32 tensor directly on the model's device
        # (avoids a CPU allocation followed by a copy).
        return torch.tensor([[value]], dtype=torch.float32, device=device)

    # Inference only: disable gradient tracking so no autograd graph is kept.
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", truncation=True,
                            padding=True, max_length=128).to(device)
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            prior_injuries=_as_feature(prior_injuries),
            injury_type_id=_as_feature(injury_type_id),
            position_id=_as_feature(position_id),
            attention_score=_as_feature(attention_score),
        )
        logits_type, logits_duration = outputs.logits
        type_probs = torch.softmax(logits_type, dim=1)
        duration_probs = torch.softmax(logits_duration, dim=1)

    pred_type = type_probs.argmax(dim=1).item()
    pred_duration = duration_probs.argmax(dim=1).item()
    return (
        label_map_type[pred_type], type_probs[0, pred_type].item(),
        label_map_duration[pred_duration], duration_probs[0, pred_duration].item(),
    )