tanthinhdt commited on
Commit
291adba
·
verified ·
1 Parent(s): 47c18ca

fix(utils): adjust get_prediction

Browse files
Files changed (1) hide show
  1. utils.py +5 -4
utils.py CHANGED
@@ -116,13 +116,14 @@ def get_predictions(
116
  if inputs is None:
117
  return []
118
 
119
- outputs = model(**inputs)
120
- logits = outputs.logits
 
121
 
122
  # Get top-3 predictions
123
  topk_scores, topk_indices = torch.topk(logits, k, dim=1)
124
- topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().cpu().numpy()
125
- topk_indices = topk_indices.squeeze().cpu().numpy()
126
 
127
  return [
128
  {
 
116
  if inputs is None:
117
  return []
118
 
119
+ with torch.no_grad():
120
+ outputs = model(**inputs)
121
+ logits = outputs.logits
122
 
123
  # Get top-3 predictions
124
  topk_scores, topk_indices = torch.topk(logits, k, dim=1)
125
+ topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
126
+ topk_indices = topk_indices.squeeze().detach().numpy()
127
 
128
  return [
129
  {