IFMedTech commited on
Commit
57b42a5
·
verified ·
1 Parent(s): 3acab63

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ class_names = [
7
+ "plaque_calculus",
8
+ "caries",
9
+ "plaque_gingivitis",
10
+ "hypodontia",
11
+ "mouth_ulcer",
12
+ "tooth_discoloration"
13
+ ]
14
+
15
+ # Load the model and update the final fully connected layer
16
+ model = models.resnet50(weights=None)
17
+ model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
18
+
19
+ # Load the model weights from tooth_model.pth
20
+ model.load_state_dict(torch.load('tooth_model.pth', map_location=torch.device('cpu')))
21
+ model.eval()
22
+
23
+ # Preprocessing steps for input images
24
+ preprocess = transforms.Compose([
25
+ transforms.Resize((224, 224)),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
28
+ ])
29
+
30
+ def predict_image(image):
31
+ # Preprocess the image and add a batch dimension
32
+ processed_image = preprocess(image).unsqueeze(0)
33
+
34
+ with torch.no_grad():
35
+ outputs = model(processed_image)
36
+ _, top_indices = torch.topk(outputs, 2) # Get top 2 predictions
37
+ top_classes = [class_names[idx] for idx in top_indices[0]]
38
+
39
+ return ", ".join(top_classes)
40
+
41
+ # Set up the Gradio interface
42
+ iface = gr.Interface(
43
+ fn=predict_image,
44
+ inputs=gr.Image(type="pil"),
45
+ outputs="text", # Output will be text listing
46
+ title="Medical Image Classification",
47
+ description="Upload an image to predict its class."
48
+ )
49
+
50
+ # Launch the interface
51
+ iface.launch()