| from typing import Tuple, Optional |
| import torch |
|
|
|
|
def predict_topk(
    model: torch.nn.Module,
    img_tensor: torch.Tensor,
    top_k: int = 5,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]:
    """Run ``model`` on ``img_tensor`` and return its top-k class predictions.

    The model is switched to evaluation mode (``model.eval()``) and called
    under ``torch.no_grad()``. Only the first sample of the batch (index 0
    of the model output) is scored; any further samples are ignored.

    Args:
        model: Classifier to evaluate. Its output must be a
            ``(batch, num_classes)`` logits tensor, or a tuple/list whose
            first element is such a tensor (e.g. main + auxiliary outputs).
        img_tensor: Input passed directly to ``model``.
        top_k: Number of highest-probability classes to return; clipped to
            the number of classes. Must be non-negative.
        device: If given, ``img_tensor`` is moved to this device before the
            forward pass. The model itself is not moved.

    Returns:
        ``(top_prob, top_idx, num_classes, probabilities)``: the
        ``min(top_k, num_classes)`` largest softmax probabilities and their
        class indices (in descending probability order), the size of the
        class dimension, and the full softmax vector for the first sample.

    Raises:
        ValueError: If ``top_k`` is negative, or if the first sample's
            logits are not one-dimensional.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")
    if device is not None:
        img_tensor = img_tensor.to(device)
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
    # Models with several outputs (e.g. main + auxiliary heads) are handled
    # by taking the first element as the main logits. Lists are accepted as
    # well, otherwise the whole (batch, C) tensor would be mistaken for the
    # first sample and softmax would run over the batch axis.
    if isinstance(output, (tuple, list)):
        output = output[0]
    logits = output[0]  # first sample of the batch
    # Guard against outputs that are not (batch, num_classes): a 0-d logit
    # or a multi-dim per-sample map would otherwise fail obscurely or make
    # topk select along the wrong axis.
    if logits.dim() != 1:
        raise ValueError(
            "expected model output of shape (batch, num_classes); "
            f"got per-sample logits of shape {tuple(logits.shape)}"
        )
    probabilities = torch.nn.functional.softmax(logits, dim=0)
    num_classes = probabilities.shape[0]
    k = min(top_k, num_classes)
    top_prob, top_idx = torch.topk(probabilities, k)
    return top_prob, top_idx, num_classes, probabilities
|
|