import torch


@torch.no_grad()
def predict_injury(
    model,
    tokenizer,
    text,
    prior_injuries,
    injury_type_id,
    position_id,
    label_map_type,
    label_map_duration,
    attention_score=0.5,
    max_length=128,
):
    """Predict the injury type and duration class for a single text.

    Runs one gradient-free forward pass. The model is put into eval mode for
    the pass, and its previous train/eval mode is restored afterwards.

    Args:
        model: Module called with ``input_ids``, ``attention_mask``,
            ``prior_injuries``, ``injury_type_id``, ``position_id`` and
            ``attention_score`` keyword arguments. Its output's ``.logits``
            must unpack into ``(logits_type, logits_duration)``.
        tokenizer: Callable that tokenizes ``text`` and returns a mapping
            with ``input_ids`` / ``attention_mask`` that supports
            ``.to(device)`` (e.g. a Hugging Face tokenizer).
        text: The input text to classify.
        prior_injuries: Scalar feature, passed as a ``[1, 1]`` float32 tensor.
        injury_type_id: Scalar feature, passed as a ``[1, 1]`` float32 tensor.
        position_id: Scalar feature, passed as a ``[1, 1]`` float32 tensor.
        label_map_type: Indexable map from type class index to label.
        label_map_duration: Indexable map from duration class index to label.
        attention_score: Scalar feature, passed as a ``[1, 1]`` float32 tensor.
        max_length: Maximum tokenized length; longer input is truncated.

    Returns:
        ``(type_label, type_prob, duration_label, duration_prob)``, where each
        probability is the softmax probability of the predicted class.
    """
    # Remember the caller's mode. A bare ``model.eval()`` would leave a model
    # that is mid-training in eval mode (dropout/batch-norm frozen) afterwards.
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=max_length,
        ).to(device)

        def _as_feature(value):
            # Wrap a scalar as a [1, 1] float32 tensor on the model's device.
            return torch.tensor([[value]], dtype=torch.float32, device=device)

        # NOTE(review): the ID features are passed as float32, as before. If
        # the model feeds them to an nn.Embedding they would need
        # dtype=torch.long. Confirm against the model definition.
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            prior_injuries=_as_feature(prior_injuries),
            injury_type_id=_as_feature(injury_type_id),
            position_id=_as_feature(position_id),
            attention_score=_as_feature(attention_score),
        )
        logits_type, logits_duration = outputs.logits

        type_probs = torch.softmax(logits_type, dim=1)
        duration_probs = torch.softmax(logits_duration, dim=1)
        # .item() requires a single prediction, i.e. a batch of one text.
        pred_type = type_probs.argmax(dim=1).item()
        pred_duration = duration_probs.argmax(dim=1).item()

        return (
            label_map_type[pred_type],
            type_probs[0, pred_type].item(),
            label_map_duration[pred_duration],
            duration_probs[0, pred_duration].item(),
        )
    finally:
        model.train(was_training)