MedhaCodes commited on
Commit
b72b7f2
·
verified ·
1 Parent(s): 6b51665

Update app/predict.py

Browse files
Files changed (1) hide show
  1. app/predict.py +46 -32
app/predict.py CHANGED
@@ -12,41 +12,55 @@ from app.utils import clean_name, risk_from_prob, generate_leaf_report
12
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13
  CLASSES_PATH = os.path.join(BASE_DIR, "..", "classes.json")
14
 
 
 
15
  with open(CLASSES_PATH, "r") as f:
16
- CLASS_NAMES = json.load(f)
 
 
17
 
18
  model = load_model()
19
 
20
 
21
  def predict_pil_image(pil_image: Image.Image):
22
- img = pil_image.convert("RGB")
23
- input_tensor = model.transform(img).unsqueeze(0)
24
-
25
- with torch.no_grad():
26
- logits = model.model(input_tensor)
27
- probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
28
-
29
- pred_idx = int(np.argmax(probs))
30
- pred_name = clean_name(CLASS_NAMES[pred_idx])
31
- confidence = round(float(probs[pred_idx]) * 100, 2)
32
-
33
- risk = risk_from_prob(probs[pred_idx])
34
- description, treatment = generate_leaf_report(pred_name)
35
-
36
- top5_idx = probs.argsort()[-5:][::-1]
37
- top5 = [
38
- {
39
- "class": clean_name(CLASS_NAMES[i]),
40
- "probability": float(probs[i])
41
- }
42
- for i in top5_idx
43
- ]
44
-
45
- return {
46
- "prediction": pred_name,
47
- "confidence": confidence,
48
- "risk": risk,
49
- "description": description,
50
- "treatment": treatment,
51
- "top5": top5
52
- }
 
 
 
 
 
 
 
 
 
 
 
12
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13
  CLASSES_PATH = os.path.join(BASE_DIR, "..", "classes.json")
14
 
15
+
16
+
17
  with open(CLASSES_PATH, "r") as f:
18
+ CLASS_NAMES = json.load(f)
19
+
20
+
21
 
22
  model = load_model()
23
 
24
 
25
  def predict_pil_image(pil_image: Image.Image):
26
+ img = pil_image.convert("RGB")
27
+ input_tensor = model.transform(img).unsqueeze(0)
28
+
29
+
30
+
31
+ with torch.no_grad():
32
+ logits = model.model(input_tensor)
33
+ probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
34
+
35
+
36
+
37
+ pred_idx = int(np.argmax(probs))
38
+ pred_name = clean_name(CLASS_NAMES[pred_idx])
39
+ confidence = round(float(probs[pred_idx]) * 100, 2)
40
+
41
+
42
+
43
+ risk = risk_from_prob(probs[pred_idx])
44
+ description, treatment = generate_leaf_report(pred_name)
45
+
46
+
47
+
48
+ top5_idx = probs.argsort()[-5:][::-1]
49
+ top5 = [
50
+ {
51
+ "class": clean_name(CLASS_NAMES[i]),
52
+ "probability": float(probs[i])
53
+ }
54
+ for i in top5_idx
55
+ ]
56
+
57
+
58
+
59
+ return {
60
+ "prediction": pred_name,
61
+ "confidence": confidence,
62
+ "risk": risk,
63
+ "description": description,
64
+ "treatment": treatment,
65
+ "top5": top5
66
+ }