michaela299 commited on
Commit
a9b9cc0
·
1 Parent(s): e7ecfd9

limit predictions

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -16,30 +16,35 @@ model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cp
16
  model.eval()
17
 
18
  def predict(input_image):
19
- #apply the transform
20
- image_tensor = val_test_transform(input_image)
21
 
22
- #add batch dimension
23
- image_tensor = image_tensor.unsqueeze(0)
24
 
25
- #run inference
26
  with torch.no_grad():
27
- output = model(image_tensor)
28
-
29
- #get probabilitiees
30
- probabilities = torch.nn.functional.softmax(output,dim=1)[0]
 
 
 
31
 
32
- #create the output dictionary
33
- # Convert to dictionary
34
- all_results = {CLASS_NAMES[i]: probabilities[i].item() for i in range(len(probabilities))}
35
 
36
- # Sort by probability (highest first)
37
- sorted_results = dict(sorted(all_results.items(), key=lambda x: x[1], reverse=True))
 
 
 
 
38
 
39
- # Keep only top 10
40
- top_10 = dict(list(sorted_results.items())[:10])
41
 
42
- return top_10
43
 
44
 
45
  with gr.Blocks(title="Plant Disease Classifier") as app:
 
16
  model.eval()
17
 
18
  def predict(input_image):
19
+ # 1. Transform the image (resize, normalize, etc.)
20
+ processed_image = val_test_transform(input_image)
21
 
22
+ # 2. Add a batch dimension because the model expects [batch, channels, height, width]
23
+ processed_image = processed_image.unsqueeze(0)
24
 
25
+ # 3. Run the image through the model
26
  with torch.no_grad():
27
+ model_output = model(processed_image)
28
+
29
+ # 4. Convert raw model scores into probabilities
30
+ probabilities = torch.nn.functional.softmax(model_output, dim=1)[0]
31
+
32
+ # 5. Choose how many results you want to show
33
+ number_of_predictions_to_show = 5
34
 
35
+ # 6. Get the top-k highest probability classes
36
+ top_probabilities, top_class_indices = torch.topk(probabilities, number_of_predictions_to_show)
 
37
 
38
+ # 7. Build a dictionary of the results
39
+ results = {}
40
+ for rank in range(number_of_predictions_to_show):
41
+ class_index = top_class_indices[rank].item()
42
+ class_name = CLASS_NAMES[class_index]
43
+ probability_value = top_probabilities[rank].item()
44
 
45
+ results[class_name] = probability_value
 
46
 
47
+ return results
48
 
49
 
50
  with gr.Blocks(title="Plant Disease Classifier") as app: