"""Zero-shot image classification helpers: score an image against text labels."""

import torch

from config.settings import DEVICE


def post_processed_probs(probs, labels):
    """Map each label to its probability taken from the first row of *probs*.

    Args:
        probs: Tensor-like object whose first row (``probs[0]``) supports
            ``.tolist()`` and holds one score per label, in label order.
        labels: Sequence of label strings.

    Returns:
        A ``dict`` mapping label -> Python ``float``. If a label occurs more
        than once, the score at its last position wins.

    Raises:
        IndexError: If the first row of *probs* has fewer entries than
            *labels*.
    """
    if not labels:
        # Nothing to map; do not touch `probs` at all.
        return {}
    # One bulk conversion instead of a separate `.item()` call per label.
    scores = probs[0].tolist()
    return {label: scores[index] for index, label in enumerate(labels)}


def generate_ouput(model, processor, image, texts):
    """Run *model* on *image* and *texts* and return per-image probabilities.

    The function name keeps its original spelling because callers import it
    under this name.

    Args:
        model: Callable accepting the processor output as keyword arguments
            and returning an object with a ``logits_per_image`` attribute
            (presumably a CLIP-style model -- confirm against the caller).
        processor: Callable accepting ``text``, ``images``, ``return_tensors``
            and ``padding``; its result must support ``.to(DEVICE)`` and
            ``**`` unpacking.
        image: Image (or images) forwarded to the processor unchanged.
        texts: List of text prompts forwarded to the processor unchanged.

    Returns:
        ``logits_per_image`` after a softmax over dimension 1.
    """
    inputs = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.logits_per_image.softmax(dim=1)


def infer(model, processor, image, candidate_labels):
    """Classify *image* against a comma-separated string of candidate labels.

    Labels are stripped of surrounding whitespace. Blank entries (for example
    from a trailing or doubled comma) are dropped so they cannot take
    probability mass from the real labels. Repeated labels are collapsed to
    their first occurrence so the returned probabilities still sum to 1.

    Args:
        model: Passed through to ``generate_ouput``.
        processor: Passed through to ``generate_ouput``.
        image: Passed through to ``generate_ouput``.
        candidate_labels: Comma-separated label string, e.g. ``"cat, dog"``.

    Returns:
        A ``dict`` mapping each distinct, non-empty label to its probability,
        or an empty ``dict`` when no usable label was supplied.
    """
    stripped = (part.strip() for part in candidate_labels.split(","))
    # dict.fromkeys de-duplicates while keeping first-seen order.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        # Without labels there is nothing to score; skip the model call.
        return {}

    probs = generate_ouput(model, processor, image, labels)
    return post_processed_probs(probs, labels)