File size: 1,472 Bytes
19cbfd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch

@torch.no_grad()
def predict_injury(model, tokenizer, text, prior_injuries, injury_type_id, position_id,
                   label_map_type, label_map_duration, attention_score=0.5):
    """Predict the injury type and duration label for a single text input.

    The model is switched to eval mode (and left that way), the text is
    tokenized with truncation at 128 tokens, and the four scalar features
    are each wrapped into a float32 tensor of shape [1, 1] before being
    passed to the model as keyword arguments. Gradients are disabled for
    the whole call via the ``torch.no_grad()`` decorator.

    Args:
        model: Module called with ``input_ids``, ``attention_mask``,
            ``prior_injuries``, ``injury_type_id``, ``position_id`` and
            ``attention_score``. Its result must expose ``.logits`` that
            unpacks into exactly two tensors (type logits, duration logits).
        tokenizer: Callable accepting the text plus ``return_tensors``,
            ``truncation``, ``padding`` and ``max_length``; its result must
            support ``.to(device)`` and key lookup.
        text: Input passed to the tokenizer.
        prior_injuries: Scalar feature value.
        injury_type_id: Scalar feature value (encoded as float32, not as an
            integer index -- NOTE(review): confirm the model expects that).
        position_id: Scalar feature value (also encoded as float32).
        label_map_type: Mapping from predicted type index to label.
        label_map_duration: Mapping from predicted duration index to label.
        attention_score: Scalar feature value; defaults to 0.5.

    Returns:
        A 4-tuple ``(type_label, type_probability, duration_label,
        duration_probability)`` where each probability is the softmax value
        of the winning class as a Python float.

    Note:
        Assumes a single example per call: the winning index is read with
        ``.item()`` and the probability is taken from row 0.
    """
    model.eval()
    # Everything is moved to wherever the model's first parameter lives.
    target = next(model.parameters()).device

    def as_feature(value):
        # Scalar -> float32 tensor of shape [1, 1] on the model's device.
        return torch.tensor([[value]], dtype=torch.float32).to(target)

    def top_class(logits):
        # Softmax over dim 1, then the winning index and its probability.
        probs = torch.softmax(logits, dim=1)
        best = probs.argmax(dim=1).item()
        return best, probs[0, best].item()

    batch = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(target)

    extras = {
        "prior_injuries": as_feature(prior_injuries),
        "injury_type_id": as_feature(injury_type_id),
        "position_id": as_feature(position_id),
        "attention_score": as_feature(attention_score),
    }

    result = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **extras,
    )

    # The model's logits are expected to be a (type, duration) pair.
    type_logits, duration_logits = result.logits
    type_idx, type_conf = top_class(type_logits)
    duration_idx, duration_conf = top_class(duration_logits)

    return (
        label_map_type[type_idx], type_conf,
        label_map_duration[duration_idx], duration_conf,
    )