Shashwat98 commited on
Commit
f6c40fc
·
verified ·
1 Parent(s): 5c5cb0a

Update src/inference/lr_model.py

Browse files
Files changed (1) hide show
  1. src/inference/lr_model.py +25 -8
src/inference/lr_model.py CHANGED
@@ -43,21 +43,38 @@ class LRModel:
43
  {
44
  "class_id": int,
45
  "class_name": str,
46
- "probabilities": {class_name: prob, ...}
 
 
 
 
47
  }
48
  """
49
  x = self.preprocess(image)
50
  probs = self.model.predict_proba(x)[0]
51
- class_id = int(np.argmax(probs))
52
- class_name = self.labels[class_id]
 
53
 
54
- # Build probability dict (optional)
55
  prob_dict = {
56
  self.labels[i]: float(probs[i]) for i in range(len(probs))
57
  }
58
-
 
 
 
 
 
 
 
 
 
 
 
 
59
  return {
60
- "class_id": class_id,
61
- "class_name": class_name,
62
- "probabilities": prob_dict
 
63
  }
 
43
  {
44
  "class_id": int,
45
  "class_name": str,
46
+ "probabilities": {class_name: prob},
47
+ "top_k": [
48
+ {"class_id": int, "class_name": str, "probability": float},
49
+ ...
50
+ ]
51
  }
52
  """
53
  x = self.preprocess(image)
54
  probs = self.model.predict_proba(x)[0]
55
+
56
+ pred_id = int(np.argmax(probs))
57
+ pred_name = self.labels[pred_id]
58
 
 
59
  prob_dict = {
60
  self.labels[i]: float(probs[i]) for i in range(len(probs))
61
  }
62
+
63
+ # Top-k (sorted)
64
+ sorted_indices = np.argsort(probs)[::-1]
65
+ top_k = min(top_k, len(sorted_indices))
66
+ top_k_list = []
67
+ for i in range(top_k):
68
+ cid = int(sorted_indices[i])
69
+ top_k_list.append({
70
+ "class_id": cid,
71
+ "class_name": self.labels[cid],
72
+ "probability": float(probs[cid]),
73
+ })
74
+
75
  return {
76
+ "class_id": pred_id,
77
+ "class_name": pred_name,
78
+ "probabilities": prob_dict,
79
+ "top_k": top_k_list,
80
  }