IFMedTechdemo committed on
Commit
0514a9d
·
verified ·
1 Parent(s): db79803

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -73
app.py CHANGED
@@ -1,30 +1,51 @@
1
- from huggingface_hub import hf_hub_download
2
  import torch
 
3
  from transformers import ViTImageProcessor
4
  from PIL import Image
5
  import gradio as gr
6
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
7
 
8
  # Repository configuration
9
  REPO_ID = "IFMedTech/Dental_Q"
10
  MODEL_FILENAME = "quantized_model.ptl"
11
 
12
# Fetch the quantized model weights from the Hub.
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# Fetch the processor config (assumed to live in the same repo).
processor_path = hf_hub_download(repo_id=REPO_ID, filename="preprocessor_config.json")

# Load the image processor and the TorchScript-quantized model for CPU inference.
processor = ViTImageProcessor.from_pretrained(REPO_ID)
quantized_model = torch.jit.load(model_path, map_location="cpu")
quantized_model.eval()
28
 
29
  # Define Inference Preprocessing
30
  size = processor.size['height']
@@ -43,60 +64,4 @@ except AttributeError:
43
  label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]
44
 
45
def preprocess_image(image):
    """Coerce the input to an RGB PIL image and return a batched tensor.

    Accepts either a PIL image or an array-like (e.g. what Gradio hands
    over); the result is the transformed tensor with a leading batch
    dimension of 1.
    """
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    rgb_image = pil_image.convert("RGB")
    tensor = inference_transform(rgb_image)
    return tensor.unsqueeze(0)
51
-
52
def predict_image(image):
    """Run inference on an image and return a human-readable multi-label result."""
    batch = preprocess_image(image)

    # TorchScript model emits raw logits; sigmoid yields independent
    # per-label probabilities (multi-label, not softmax).
    with torch.no_grad():
        raw_output = quantized_model(batch)
    probabilities = torch.sigmoid(raw_output).squeeze(0)

    # Labels whose probability clears the 0.5 decision threshold.
    findings = [
        f"{name} (confidence: {probabilities[idx].item():.2%})"
        for idx, name in enumerate(label_names)
        if probabilities[idx].item() > 0.5
    ]

    # Borderline caries (0.3 <= p < 0.5) gets a softer "Possible" mention.
    if "Caries" in label_names:
        caries_idx = label_names.index("Caries")
        caries_prob = probabilities[caries_idx].item()
        if 0.3 <= caries_prob < 0.5:
            findings.append(f"Possible Caries (confidence: {caries_prob:.2%})")

    if not findings:
        return "No dental issues detected"
    return "Detected: " + ", ".join(findings)
83
-
84
# Example images bundled with the Space (shown below the upload widget).
examples = [[f"example_image{i}.jfif"] for i in (1, 2, 3)]

# Gradio wiring: single image in, plain-text verdict out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Dental Image Multi-Label Classification",
    description="Upload an image or select from the examples below to predict dental conditions. The model can detect multiple dental issues in a single image.",
    examples=examples,
)

if __name__ == "__main__":
    iface.launch()
 
1
+ import os
2
  import torch
3
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
4
  from transformers import ViTImageProcessor
5
  from PIL import Image
6
  import gradio as gr
7
+ from huggingface_hub import hf_hub_download
8
 
9
  # Repository configuration
10
  REPO_ID = "IFMedTech/Dental_Q"
11
  MODEL_FILENAME = "quantized_model.ptl"
12
 
13
def download_model_from_hub():
    """Download the quantized model file from the private Hugging Face repo.

    Reads the access token from the HUGGINGFACE_TOKEN environment variable
    (required because the repo is private).

    Returns:
        str: Local filesystem path of the downloaded model file.

    Raises:
        ValueError: If HUGGINGFACE_TOKEN is not set.
        RuntimeError: If the download fails for any other reason; the
            original exception is chained as the cause.
    """
    token = os.environ.get("HUGGINGFACE_TOKEN")

    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable is required for private repo access. "
            "Please set it in your Space settings under 'Repository secrets'."
        )

    try:
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=token,
        )
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise RuntimeError(f"Failed to download model from {REPO_ID}: {e}") from e
32
 
33
def load_model_and_processor():
    """Fetch and load the quantized TorchScript model plus its ViT processor.

    Returns:
        tuple: (quantized TorchScript model in eval mode, ViTImageProcessor).
    """
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")

    # Model weights: pull the file locally, then load as TorchScript on CPU.
    local_model_path = download_model_from_hub()
    model = torch.jit.load(local_model_path, map_location="cpu")
    model.eval()

    # Image processor comes straight from the (private) hub repo.
    image_processor = ViTImageProcessor.from_pretrained(REPO_ID, token=hf_token)

    return model, image_processor
46
 
47
# Initialize the model and processor once at import time so every Gradio
# request reuses the same loaded instances.
quantized_model, processor = load_model_and_processor()
 
 
49
 
50
  # Define Inference Preprocessing
51
  size = processor.size['height']
 
64
  label_names = ["Background", "Caries", "Normal Teeth", "Plaque"]
65
 
66
  def preprocess_image(image):
67
+ """Load and preprocess a