Spaces:
Build error
Build error
fix(utils): adjust get_prediction
Browse files
utils.py
CHANGED
|
@@ -116,13 +116,14 @@ def get_predictions(
|
|
| 116 |
if inputs is None:
|
| 117 |
return []
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
|
|
|
| 121 |
|
| 122 |
# Get top-3 predictions
|
| 123 |
topk_scores, topk_indices = torch.topk(logits, k, dim=1)
|
| 124 |
-
topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().
|
| 125 |
-
topk_indices = topk_indices.squeeze().
|
| 126 |
|
| 127 |
return [
|
| 128 |
{
|
|
|
|
| 116 |
if inputs is None:
|
| 117 |
return []
|
| 118 |
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
outputs = model(**inputs)
|
| 121 |
+
logits = outputs.logits
|
| 122 |
|
| 123 |
# Get top-3 predictions
|
| 124 |
topk_scores, topk_indices = torch.topk(logits, k, dim=1)
|
| 125 |
+
topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
|
| 126 |
+
topk_indices = topk_indices.squeeze().detach().numpy()
|
| 127 |
|
| 128 |
return [
|
| 129 |
{
|