File size: 706 Bytes
5aa6736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from config.settings import DEVICE

def post_processed_probs(probs, labels):
    """Map each candidate label to its probability.

    Args:
        probs: tensor of shape (1, num_labels) — a single image's
            softmaxed scores (row 0 is the only row read).
        labels: list of label strings, aligned with the columns of probs.

    Returns:
        dict mapping label -> float probability.
    """
    # zip pairs each label with its score directly instead of the
    # index-based range(len(...)) loop; .item() converts the 0-d
    # tensor element to a plain Python float.
    return {label: score.item() for label, score in zip(labels, probs[0])}

def generate_ouput(model, processor, image, texts):
    """Score one image against candidate texts with a CLIP-style model.

    Args:
        model: model exposing ``logits_per_image`` on its output.
        processor: callable that tokenizes/encodes text+image into
            model inputs (HuggingFace-processor style — assumed; verify).
        image: the input image accepted by the processor.
        texts: list of candidate label strings.

    Returns:
        tensor of shape (1, len(texts)) with softmaxed probabilities.
    """
    batch = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    batch = batch.to(DEVICE)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**batch).logits_per_image
        scores = logits.softmax(dim=1)

    return scores

def infer(model, processor, image, candidate_labels):
    """Classify an image against comma-separated candidate labels.

    Args:
        model: CLIP-style model (see ``generate_ouput``).
        processor: matching processor for the model.
        image: input image.
        candidate_labels: comma-separated label string, e.g. "cat, dog".

    Returns:
        dict mapping each label to its probability.
    """
    # Split the comma-separated string and drop surrounding whitespace.
    labels = [part.strip() for part in candidate_labels.split(",")]
    scores = generate_ouput(model, processor, image, labels)
    return post_processed_probs(scores, labels)