IFMedTech commited on
Commit
8b6ff1c
·
verified ·
1 Parent(s): af4d1dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -3,6 +3,7 @@ from torchvision import models, transforms
3
  from PIL import Image
4
  import gradio as gr
5
 
 
6
  class_names = [
7
  "plaque_calculus",
8
  "caries",
@@ -33,18 +34,37 @@ def predict_image(image):
33
 
34
  with torch.no_grad():
35
  outputs = model(processed_image)
36
- _, top_indices = torch.topk(outputs, 2) # Get top 2 predictions
37
- top_classes = [class_names[idx] for idx in top_indices[0]]
38
 
39
- return ", ".join(top_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Set up the Gradio interface
42
  iface = gr.Interface(
43
  fn=predict_image,
44
  inputs=gr.Image(type="pil"),
45
- outputs="text", # Output will be text listing
46
- title="Dental Classifier",
47
- description="Upload an image to predict its class."
 
48
  )
49
 
50
  # Launch the interface
 
3
  from PIL import Image
4
  import gradio as gr
5
 
6
+ # Updated class names with 'plaque' in front of 'calculus' and 'gingivitis'
7
  class_names = [
8
  "plaque_calculus",
9
  "caries",
 
34
 
35
  with torch.no_grad():
36
  outputs = model(processed_image)
37
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
38
+ top_probs, top_indices = torch.topk(probabilities, 2) # Get top 2 predictions
39
 
40
+ top_class_1 = class_names[top_indices[0][0]]
41
+ top_prob_1 = top_probs[0][0].item()
42
+
43
+ # Initialize result with the top prediction
44
+ result = top_class_1
45
+
46
+ # Include the second prediction if the top prediction's probability is less than 80%
47
+ if top_prob_1 < 0.8:
48
+ top_class_2 = class_names[top_indices[0][1]]
49
+ result += f", {top_class_2}"
50
+
51
+ return result
52
+
53
+ # Example images to use as input
54
+ examples = [
55
+ ["example_image1.jpg"],
56
+ ["example_image2.jpg"],
57
+ ["example_image3.jpg"]
58
+ ]
59
 
60
  # Set up the Gradio interface
61
  iface = gr.Interface(
62
  fn=predict_image,
63
  inputs=gr.Image(type="pil"),
64
+ outputs="text", # Output will be text listing the predictions
65
+ title="Medical Image Classification",
66
+ description="Upload an image to predict its class. Displays the top 2 predictions if needed.",
67
+ examples=examples # Add example images
68
  )
69
 
70
  # Launch the interface