IFMedTechdemo committed on
Commit
1d3a8a8
·
verified ·
1 Parent(s): b256e42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -47
app.py CHANGED
@@ -1,71 +1,102 @@
 
1
  import torch
2
- from torchvision import models, transforms
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
- # Updated class names with 'plaque' in front of 'calculus' and 'gingivitis'
7
- class_names = [
8
- "Plaque Calculus",
9
- "Caries",
10
- "Plaque Gingivitis",
11
- "Hypodontia",
12
- "Mouth Ulcer",
13
- "Tooth Discoloration"
14
- ]
15
-
16
- # Load the model and update the final fully connected layer
17
- model = models.resnet50(weights=None)
18
- model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
19
-
20
- # Load the model weights from tooth_model.pth
21
- model.load_state_dict(torch.load('tooth_model.pth', map_location=torch.device('cpu')))
22
- model.eval()
23
 
24
- # Preprocessing steps for input images
25
- preprocess = transforms.Compose([
26
- transforms.Resize((224, 224)),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29
- ])
30
 
31
- def predict_image(image):
32
- # Preprocess the image and add a batch dimension
33
- processed_image = preprocess(image).unsqueeze(0)
 
 
34
 
35
- with torch.no_grad():
36
- outputs = model(processed_image)
37
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
38
- top_probs, top_indices = torch.topk(probabilities, 2) # Get top 2 predictions
39
 
40
- top_class_1 = class_names[top_indices[0][0]]
41
- top_prob_1 = top_probs[0][0].item()
 
 
 
 
 
 
 
42
 
43
- # Initialize result with the top prediction
44
- result = top_class_1
 
 
 
45
 
46
- # Include the second prediction if the top prediction's probability is less than 80%
47
- if top_prob_1 < 0.8:
48
- top_class_2 = class_names[top_indices[0][1]]
49
- result += f", {top_class_2}"
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  return result
52
 
53
- # Example images to use as input
54
  examples = [
55
  ["example_image1.jfif"],
56
  ["example_image2.jfif"],
57
  ["example_image3.jfif"]
58
  ]
59
 
60
- # Set up the Gradio interface
61
  iface = gr.Interface(
62
  fn=predict_image,
63
  inputs=gr.Image(type="pil"),
64
- outputs="text", # Output will be text listing the predictions
65
- title="Dental Image Classification",
66
- description="Upload an image or select from the examples below to predict its class. The model accounts for the possibility of multiple dental flaws in a single image.",
67
- examples=examples # Add example images
68
  )
69
 
70
- # Launch the interface
71
- iface.launch()
 
1
+ from huggingface_hub import hf_hub_download
2
  import torch
3
+ from transformers import ViTImageProcessor
4
  from PIL import Image
5
  import gradio as gr
6
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
7
 
8
# Repository configuration
REPO_ID = "IFMedTech/Dental_Q"          # Hugging Face Hub repo holding model + processor config
MODEL_FILENAME = "quantized_model.ptl"  # TorchScript archive (quantized, mobile-interpreter format)

# Download the model file from Hugging Face Hub (cached locally after the first run)
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME
)

# Download the processor files (assuming they're in the same repo)
# NOTE(review): `processor_path` is never used below —
# ViTImageProcessor.from_pretrained(REPO_ID) fetches this config itself,
# so this download is likely redundant; confirm and remove if so.
processor_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="preprocessor_config.json"
)

# Load Processor & Quantized Model
processor = ViTImageProcessor.from_pretrained(REPO_ID)
quantized_model = torch.jit.load(model_path, map_location="cpu")  # scripted module, CPU-only
quantized_model.eval()  # inference mode: disables dropout/batch-norm updates
28
 
29
# Define Inference Preprocessing — reproduce the ViT processor's settings
# with torchvision transforms so a plain PIL image can be fed to the model.
size = processor.size['height']  # ViT expects square inputs (height == width)
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
inference_transform = Compose([
    Resize(size),       # resize shorter side to `size`, preserving aspect ratio
    CenterCrop(size),   # square crop to (size, size)
    ToTensor(),         # HWC uint8 [0, 255] -> CHW float [0, 1]
    normalize           # per-channel standardization with the processor's stats
])
38
 
39
# Multi-label class names: prefer the label mapping stored on the model,
# falling back to the known training labels when the scripted model
# carries no `config` attribute.
try:
    id2label = quantized_model.config.id2label
    label_names = [id2label[i] for i in range(len(id2label))]
except AttributeError:
    label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]
44
 
45
def preprocess_image(image):
    """Convert an input image into a normalized (1, C, H, W) tensor.

    Args:
        image: A PIL image or numpy array (Gradio delivers a PIL image with
            type="pil", or None when the input is cleared).

    Returns:
        A float tensor with a leading batch dimension, ready for the model.

    Raises:
        ValueError: If no image was provided.
    """
    # Fail fast with a clear message instead of a cryptic Image.fromarray error.
    if image is None:
        raise ValueError("No image provided")
    # Accept raw arrays too, even though the Gradio input is configured as PIL.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force 3-channel RGB so the normalization statistics line up.
    image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension the model expects.
    return inference_transform(image).unsqueeze(0)
51
 
52
def predict_image(image, threshold=0.5):
    """Run multi-label inference on *image* and describe detected conditions.

    Args:
        image: PIL image (or numpy array) from the Gradio input.
        threshold: Sigmoid probability above which a label counts as detected
            (default 0.5, matching the original hard-coded cutoff).

    Returns:
        A human-readable string listing detected conditions with confidences,
        or a "no issues" message when nothing crosses the threshold.
    """
    pixel_values = preprocess_image(image)

    with torch.no_grad():
        # NOTE(review): assumes the scripted model returns raw logits directly
        # (not a ModelOutput/tuple) — confirm against the exported model.
        logits = quantized_model(pixel_values)

    # Independent sigmoid per class: this is multi-label, not softmax.
    probs = torch.sigmoid(logits).squeeze(0)

    detected_conditions = [
        f"{label} (confidence: {prob.item():.2%})"
        for label, prob in zip(label_names, probs)
        if prob.item() > threshold
    ]

    # Flag borderline caries (between 0.3 and the threshold) as "possible",
    # since missing caries is the costlier error for a screening tool.
    try:
        caries_prob = probs[label_names.index("Caries")].item()
        if 0.3 <= caries_prob < threshold:
            detected_conditions.append(f"Possible Caries (confidence: {caries_prob:.2%})")
    except ValueError:
        # This model's label set has no "Caries" entry; nothing to flag.
        pass

    if detected_conditions:
        result = "Detected: " + ", ".join(detected_conditions)
    else:
        result = "No dental issues detected"

    return result
83
 
84
# Example images bundled with the app (JFIF-encoded JPEGs in the repo root)
examples = [
    ["example_image1.jfif"],
    ["example_image2.jfif"],
    ["example_image3.jfif"]
]
90
 
91
# Gradio interface: single PIL-image input -> plain-text prediction summary
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),   # type="pil" hands predict_image a PIL image
    outputs="text",
    title="Dental Image Multi-Label Classification",
    description="Upload an image or select from the examples below to predict dental conditions. The model can detect multiple dental issues in a single image.",
    examples=examples
)
100
 
101
# Start the web UI only when executed as a script (not when imported)
if __name__ == "__main__":
    iface.launch()