anyonehomep1mane
Code Changes
5aa6736
import torch
from config.settings import DEVICE
def post_processed_probs(probs, labels):
    """Map each label to its probability taken from the first row of ``probs``.

    Args:
        probs: Tensor of probabilities. Only ``probs[0]`` is read; it is
            presumably shaped ``(1, len(labels))`` for a single image
            (TODO confirm against callers).
        labels: Sequence of label strings, in the same order as the
            columns of ``probs[0]``.

    Returns:
        dict mapping ``labels[i]`` to ``probs[0][i]`` as a Python float.
    """
    # One tensor->list conversion instead of a separate ``.item()`` call
    # (and device sync) for every label.
    row = probs[0].tolist()
    return {label: prob for label, prob in zip(labels, row)}
def generate_ouput(model, processor, image, texts):
    """Score ``texts`` against ``image`` and return softmax probabilities.

    The (misspelled) name is kept as-is because callers import it.

    Args:
        model: Model called with the processor's outputs as keyword
            arguments; its output must expose ``logits_per_image``
            (presumably a CLIP-style model — TODO confirm).
        processor: Callable accepting ``text=``, ``images=``,
            ``return_tensors="pt"`` and ``padding=True``.
        image: Image input forwarded to the processor.
        texts: Candidate text prompts forwarded to the processor.

    Returns:
        ``logits_per_image`` softmax-normalised over dim 1 (assumed to be
        the text axis, i.e. shape ``(n_images, n_texts)`` — TODO confirm).
    """
    encoded = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    encoded = encoded.to(DEVICE)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits_per_image

    return torch.softmax(logits, dim=1)
def infer(model, processor, image, candidate_labels):
    """Classify ``image`` against a comma-separated list of candidate labels.

    Args:
        model: Forwarded unchanged to ``generate_ouput``.
        processor: Forwarded unchanged to ``generate_ouput``.
        image: Forwarded unchanged to ``generate_ouput``.
        candidate_labels: Comma-separated label string, e.g. ``"cat, dog"``.
            Whitespace around each label is stripped, empty entries (from
            ``"cat,,dog"`` or a trailing comma) are dropped, and repeated
            labels are kept only once, in first-seen order.

    Returns:
        dict mapping each label to its probability, as built by
        ``post_processed_probs``.

    Raises:
        ValueError: if ``candidate_labels`` contains no non-blank label.
    """
    stripped = (label.strip() for label in candidate_labels.split(","))
    # Blank entries would be scored as an extra "" prompt, and duplicates
    # would each receive a share of the softmax while the returned dict
    # keeps only one of them; drop both before calling the model.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise ValueError("candidate_labels must contain at least one non-empty label")
    probs = generate_ouput(model, processor, image, labels)
    return post_processed_probs(probs, labels)