ViTViz / utils /inference.py
lucasddmc's picture
feat: primeira versão
e11edb1
raw
history blame contribute delete
936 Bytes
from typing import Tuple, Optional
import torch
def predict_topk(
    model: torch.nn.Module,
    img_tensor: torch.Tensor,
    top_k: int = 5,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]:
    """Run ``model`` on ``img_tensor`` and return top-k class probabilities.

    Only the first sample of the model output is scored. The model is put in
    eval mode (and left there) and the forward pass runs without gradient
    tracking.

    Args:
        model: Classifier to evaluate. Its output may be a logits tensor, a
            tuple/list whose first item is the logits tensor, or an object
            exposing a ``logits`` attribute (e.g. a Hugging Face
            ``ModelOutput``).
        img_tensor: Input passed straight to ``model``; presumably a batched
            image tensor such as (B, C, H, W) — TODO confirm against callers.
        top_k: Number of highest-probability classes to return. Values larger
            than the number of classes are clamped to it.
        device: If given, ``img_tensor`` is moved to this device before the
            forward pass. The model itself is NOT moved; the caller must
            place it on the same device.

    Returns:
        ``(top_prob, top_idx, num_classes, probabilities)``: the ``k``
        largest probabilities and their class indices (descending order, as
        produced by ``torch.topk``), the size of the class axis, and the full
        softmax probability vector for the first sample.
    """
    if device is not None:
        img_tensor = img_tensor.to(device)
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
        # Unwrap common non-tensor outputs. The ``logits`` check must come
        # first: a Hugging Face ModelOutput is not a tuple, and ``output[0]``
        # on it yields the whole (B, C) logits tensor, which would make the
        # softmax below run across the batch instead of across classes.
        if hasattr(output, "logits"):
            output = output.logits
        elif isinstance(output, (tuple, list)):
            output = output[0]
        # Batched output: score only the first sample. A 1-D output already
        # holds a single sample's logits, so it is used as-is.
        logits = output[0] if output.dim() > 1 else output
        probabilities = torch.nn.functional.softmax(logits, dim=0)
        num_classes = probabilities.shape[0]
        k = min(top_k, num_classes)
        top_prob, top_idx = torch.topk(probabilities, k)
    return top_prob, top_idx, num_classes, probabilities