import torch

from config.settings import DEVICE
def post_processed_probs(probs, labels):
    """Map each candidate label to its predicted probability.

    Args:
        probs: tensor of shape (1, num_labels) — softmax scores for a
            single image; only row 0 is read.
        labels: sequence of label strings, one per column of ``probs``.

    Returns:
        dict mapping each label to its probability as a Python float.
    """
    # zip pairs labels with scores directly instead of range(len(...)) indexing.
    return {label: score.item() for label, score in zip(labels, probs[0])}
def generate_ouput(model, processor, image, texts):
    """Run one zero-shot classification forward pass.

    Args:
        model: CLIP-style model whose output exposes ``logits_per_image``.
        processor: paired processor that encodes both text and image.
        image: input image in whatever form the processor accepts.
        texts: list of candidate label strings.

    Returns:
        Softmax probabilities over ``texts``, shape (1, len(texts)).
    """
    # NOTE(review): the "ouput" typo in the name is kept on purpose —
    # callers (e.g. infer) depend on this exact identifier.
    encoded = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    encoded = encoded.to(DEVICE)
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits_per_image
    return logits.softmax(dim=1)
def infer(model, processor, image, candidate_labels):
    """Classify an image against a comma-separated list of labels.

    Args:
        model: CLIP-style model.
        processor: processor matching ``model``.
        image: image to classify.
        candidate_labels: comma-separated label string, e.g. "cat, dog".

    Returns:
        dict mapping each (stripped) label to its predicted probability.
    """
    label_list = [name.strip() for name in candidate_labels.split(",")]
    scores = generate_ouput(model, processor, image, label_list)
    return post_processed_probs(scores, label_list)