amasood commited on
Commit
e3b59c9
·
verified ·
1 Parent(s): 5223021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -264
app.py CHANGED
@@ -1,292 +1,120 @@
1
- # app.py
2
 
3
- ```python
4
- # app.py
5
- # Hugging Face Space (Gradio) for Acne Type Classification + Explainability + Mistral Chatbot
6
 
7
- import os
8
- import io
9
- import requests
10
- from PIL import Image
11
- import numpy as np
12
- import matplotlib.pyplot as plt
13
- import gradio as gr
14
 
 
 
 
15
  import torch
16
- from torchvision import transforms, models
17
- from transformers import pipeline
18
-
19
- # Optional: Grad-CAM
20
- from pytorch_grad_cam import GradCAM
21
- from pytorch_grad_cam.utils.image import show_cam_on_image
22
-
23
- # -----------------------------
24
- # Config / Models
25
- # -----------------------------
26
-
27
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
- # Zero-shot candidate labels: many acne types (try to cover maximum common types)
30
- ACNE_LABELS = [
31
- "Acne vulgaris",
32
- "Comedonal acne",
33
- "Inflammatory acne",
34
- "Papular acne",
35
- "Pustular acne",
36
- "Nodulocystic acne",
37
- "Nodular acne",
38
- "Cystic acne",
39
- "Conglobate acne",
40
- "Acne rosacea",
41
- "Hormonal acne",
42
- "Acne mechanica",
43
- "Acne keloidalis",
44
- "Acneiform eruption",
45
- "Post-inflammatory hyperpigmentation",
46
- "Milia",
47
- "Folliculitis",
48
- "Perioral dermatitis",
49
- "Seborrheic dermatitis",
50
- "Other skin lesion"
51
  ]
52
 
53
- # Zero-shot pipeline using CLIP (open-source)
54
- print("Loading zero-shot CLIP pipeline...")
55
- try:
56
- zsl_pipe = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
57
- except Exception as e:
58
- print("Failed to load CLIP pipeline:", e)
59
- zsl_pipe = None
60
-
61
- # A pretrained CNN for Grad-CAM (ResNet50)
62
- print("Loading ResNet50 for Grad-CAM...")
63
- resnet_model = models.resnet50(pretrained=True)
64
- resnet_model.eval()
65
- resnet_model.to(DEVICE)
66
-
67
- # We will use the final conv layer of resnet50
68
- CAM_TARGET_LAYER = resnet_model.layer4[-1]
69
-
70
- # Preprocessing for ResNet
71
- resnet_transforms = transforms.Compose([
72
- transforms.Resize((224, 224)),
73
- transforms.ToTensor(),
74
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
75
- std=[0.229, 0.224, 0.225])
76
- ])
77
-
78
- # Preprocessing for CLIP pipeline (it accepts PIL images directly)
79
-
80
- # -----------------------------
81
- # Utilities
82
- # -----------------------------
83
-
84
- def load_image_from_url(url_or_path):
85
- # Accepts either a URL or a local file path
86
  try:
87
- if url_or_path.startswith("http://") or url_or_path.startswith("https://"):
88
- resp = requests.get(url_or_path, timeout=10)
89
- resp.raise_for_status()
90
- img = Image.open(io.BytesIO(resp.content)).convert("RGB")
91
- else:
92
- img = Image.open(url_or_path).convert("RGB")
93
- return img
94
- except Exception as e:
95
- raise RuntimeError(f"Could not load image: {e}")
 
 
 
 
96
 
97
-
98
- def run_zero_shot(pil_img, candidate_labels=ACNE_LABELS, top_k=3):
99
- if zsl_pipe is None:
100
- return [("Model not loaded", 0.0)]
101
- results = zsl_pipe(pil_img, candidate_labels)
102
- # pipeline returns list of dicts or single dict depending on version
103
- if isinstance(results, list):
104
- res = results[0]
105
- else:
106
- res = results
107
- # Keep top_k
108
- out = list(zip(res.get("labels", []), res.get("scores", [])))[:top_k]
109
- return out
110
-
111
-
112
- def make_gradcam(pil_img):
113
- # returns overlayed cam image as PIL
114
- try:
115
- img_np = np.array(pil_img).astype(np.float32) / 255.0
116
- input_tensor = resnet_transforms(pil_img).unsqueeze(0).to(DEVICE)
117
-
118
- cam = GradCAM(model=resnet_model, target_layer=CAM_TARGET_LAYER, use_cuda=(DEVICE=="cuda"))
119
- grayscale_cam = cam(input_tensor=input_tensor, targets=None)
120
- grayscale_cam = grayscale_cam[0, :]
121
-
122
- visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
123
- vis_pil = Image.fromarray(visualization)
124
- return vis_pil
125
  except Exception as e:
126
- print("Grad-CAM failed:", e)
127
- return None
128
-
129
 
130
- def explain_acne_label(label):
131
- # Simple rule-based explanations for acne types. You can later replace with richer medical text
132
  explanations = {
133
- "Acne vulgaris": "Common acne characterized by comedones (blackheads/whiteheads), papules, pustules; seen on face, chest, back.",
134
- "Comedonal acne": "Non-inflammatory acne featuring comedones (blackheads and whiteheads).",
135
- "Inflammatory acne": "Inflammatory lesions such as papules and pustules, often red and tender.",
136
- "Papular acne": "Small red raised bumps (papules) often from inflammation of follicles.",
137
- "Pustular acne": "Pus-filled lesions (pustules) typically inflammatory.",
138
- "Nodulocystic acne": "Severe form with deep nodules and cysts; may cause scarring and requires clinical care.",
139
- "Nodular acne": "Large, painful lumps beneath the skin surface (nodules).",
140
- "Cystic acne": "Deep, inflamed cysts filled with pus that can be painful and lead to scarring.",
141
- "Conglobate acne": "Severe, widespread acne with interconnected nodules and abscesses — needs specialist treatment.",
142
- "Acne rosacea": "Rosacea-related acne-like bumps often with facial redness and flushing.",
143
- "Hormonal acne": "Flare-ups linked to hormonal changes; often in lower face and jawline.",
144
- "Acne mechanica": "Triggered by friction, pressure, or occlusion (e.g., helmets, masks).",
145
- "Acne keloidalis": "Keloid-like papules typically on the back of the neck — more common in men of African descent.",
146
- "Acneiform eruption": "Acne-like lesions caused by medications or other triggers.",
147
- "Post-inflammatory hyperpigmentation": "Dark spots left after acne lesions heal; common in darker skin tones.",
148
- "Milia": "Small white cysts often mistaken for closed comedones.",
149
- "Folliculitis": "Infection/inflammation of hair follicles that can look like acne.",
150
- "Perioral dermatitis": "Rash around the mouth that may resemble acne but has different causes.",
151
- "Seborrheic dermatitis": "Greasy scales and redness that can coexist or be mistaken for acne.",
152
- "Other skin lesion": "Lesion not typical for acne; consider dermatology consultation."
153
  }
154
- return explanations.get(label, "No explanation available for this label.")
155
 
156
-
157
- # -----------------------------
158
- # Gradio Interface
159
- # -----------------------------
160
-
161
- def classify_and_explain(image_input, image_url):
162
- # image_input is from file uploader, image_url is optional
163
- pil_img = None
164
- try:
165
- if image_input is not None:
166
- pil_img = Image.fromarray(image_input).convert("RGB")
167
- elif image_url:
168
- pil_img = load_image_from_url(image_url)
169
- else:
170
- return "No image provided", None, None
171
- except Exception as e:
172
- return f"Error loading image: {e}", None, None
173
-
174
- # Run zero-shot classification
175
- try:
176
- zsl = run_zero_shot(pil_img, ACNE_LABELS, top_k=3)
177
- except Exception as e:
178
- zsl = [("Error", 0.0)]
179
- print("ZSL error:", e)
180
-
181
- top_label, top_score = zsl[0]
182
- explanation = explain_acne_label(top_label)
183
-
184
- # Grad-CAM image
185
- cam_img = make_gradcam(pil_img)
186
-
187
- # Prepare textual output + simple suggestion
188
- suggestion = "This result is for informational purposes only and does not substitute a medical diagnosis. For severe or persistent acne, consult a dermatologist."
189
-
190
- text_out = f"**Detected:** {top_label} (score: {top_score:.2f})\n\n**Explanation:** {explanation}\n\n{suggestion}"
191
-
192
- return text_out, pil_img, cam_img
193
-
194
-
195
- # Simple Mistral chat wrapper — this will attempt to call a Mistral-style API if configured
196
  MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
197
- MISTRAL_API_URL = os.getenv("https://api.mistral.ai/v1/chat/completions") # e.g., set to your model endpoint
198
-
199
-
200
- def mistral_chat(user_message, context_summary=None):
201
- # If no API configured, return a fallback response
202
- if not MISTRAL_API_KEY or not MISTRAL_API_URL:
203
- fallback = (
204
- "[Local fallback] I don't have a Mistral API key configured in the environment.\n"
205
- "You can set MISTRAL_API_KEY and MISTRAL_API_URL in the Space secrets.\n"
206
- "Meanwhile, here's a simple suggestion: maintain gentle skincare, avoid aggressive scrubs, consult a dermatologist for nodulocystic acne."
207
- )
208
- return fallback
209
-
210
- headers = {
211
- "Authorization": f"Bearer {MISTRAL_API_KEY}",
212
- "Content-Type": "application/json",
213
- }
214
-
215
- prompt = (
216
- "You are a helpful dermatology assistant. Answer user queries about acne types, treatments, risks, and when to seek a doctor. "
217
- f"Context: {context_summary}\nUser: {user_message}\nAssistant:"
218
- )
219
-
220
- payload = {
221
- "prompt": prompt,
222
- "max_tokens": 300,
223
- "temperature": 0.2
224
- }
225
 
 
226
  try:
227
- r = requests.post(MISTRAL_API_URL, json=payload, headers=headers, timeout=20)
228
- r.raise_for_status()
229
- jr = r.json()
230
- # Different Mistral endpoints return different shapes attempt to extract text
231
- text = jr.get("text") or jr.get("generated_text") or jr.get("output") or str(jr)
232
- return text
 
 
 
 
233
  except Exception as e:
234
- print("Mistral API call failed:", e)
235
- return "Mistral API call failed or returned an unexpected response. Check logs and your MISTRAL_API_URL/MISTRAL_API_KEY."
236
-
237
 
238
- with gr.Blocks(title="DermaBot - Acne Type Classifier") as demo:
239
- gr.Markdown("# DermaBot — Acne Type Classification (Zero-shot)\nUpload an image or provide a URL. The app will try to classify acne subtype using a CLIP zero-shot model and show a Grad-CAM heatmap from ResNet50. Use the chat box for follow-up questions (Mistral API optional).")
240
-
241
- with gr.Row():
242
- with gr.Column(scale=1):
243
- image_input = gr.Image(type="numpy", label="Upload image (face/chest/back)")
244
- image_url = gr.Textbox(label="Or provide an image URL (http://...)")
245
- classify_btn = gr.Button("Classify & Explain")
246
- output_text = gr.Markdown()
247
- with gr.Column(scale=1):
248
- image_display = gr.Image(label="Input Image", interactive=False)
249
- cam_display = gr.Image(label="Grad-CAM Overlay", interactive=False)
250
-
251
- classify_btn.click(fn=classify_and_explain, inputs=[image_input, image_url], outputs=[output_text, image_display, cam_display])
252
 
253
- gr.Markdown("---")
254
- gr.Markdown("## Ask about the detected acne (Chatbot) — optional Mistral API")
 
255
 
256
- with gr.Row():
257
- chat_input = gr.Textbox(label="Your question to the assistant")
258
- chat_btn = gr.Button("Send")
259
- chat_output = gr.Textbox(label="Assistant reply", lines=6)
260
 
261
- def chat_with_context(user_q, txt_out):
262
- # We pass the last detection summary as context to the chatbot to make answers specific
263
- context_summary = txt_out if txt_out else "No prior detection available."
264
- return mistral_chat(user_q, context_summary=context_summary)
265
 
266
- chat_btn.click(fn=chat_with_context, inputs=[chat_input, output_text], outputs=chat_output)
 
267
 
268
- gr.Markdown("\n---\n**Notes:** Set `MISTRAL_API_KEY` and `MISTRAL_API_URL` in Space secrets to enable the chatbot. The app uses zero-shot CLIP classification and a ResNet50 Grad-CAM visualization for explainability. Results are informational only.")
269
 
 
270
 
271
  if __name__ == "__main__":
272
- demo.launch()
273
- ```
274
-
275
- ---
276
-
277
- # requirements.txt
278
-
279
- ```text
280
- gradio
281
- transformers
282
- torch
283
- torchvision
284
- pillow
285
- requests
286
- pytorch-grad-cam
287
- matplotlib
288
- timm
289
- \# Optional but useful
290
- scikit-learn
291
 
292
- ```
 
1
+ # app.py and requirements.txt for Hugging Face Space: Acne Type Classifier
2
 
3
+ ---
 
 
4
 
5
+ ## app.py
 
 
 
 
 
 
6
 
7
+ ```python
8
+ import gradio as gr
9
+ import requests
10
  import torch
11
+ from PIL import Image
12
+ from transformers import CLIPProcessor, CLIPModel
13
+ import os
14
+ import io
 
 
 
 
 
 
 
 
15
 
16
+ # Initialize model and processor for zero-shot acne classification
17
+ model_id = "openai/clip-vit-base-patch32"
18
+ model = CLIPModel.from_pretrained(model_id)
19
+ processor = CLIPProcessor.from_pretrained(model_id)
20
+
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ model.to(device)
23
+
24
+ # Define acne types for classification
25
+ acne_types = [
26
+ "blackheads",
27
+ "whiteheads",
28
+ "papules",
29
+ "pustules",
30
+ "nodules",
31
+ "cysts",
32
+ "fungal acne",
33
+ "acne scars",
34
+ "mild acne",
35
+ "moderate acne",
36
+ "severe acne"
 
37
  ]
38
 
39
+ # Function to classify acne type from image URL
40
+ def classify_acne(image_url):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  try:
42
+ response = requests.get(image_url)
43
+ image = Image.open(io.BytesIO(response.content)).convert("RGB")
44
+ inputs = processor(text=acne_types, images=image, return_tensors="pt", padding=True).to(device)
45
+ outputs = model(**inputs)
46
+ logits_per_image = outputs.logits_per_image
47
+ probs = logits_per_image.softmax(dim=1).cpu().detach().numpy()[0]
48
+ result_idx = probs.argmax()
49
+ detected_type = acne_types[result_idx]
50
+ confidence = probs[result_idx]
51
+
52
+ explanation = f"Detected acne type: **{detected_type}** (confidence: {confidence:.2f}).\\n\\n"
53
+ explanation += explain_acne(detected_type)
54
+ return image, explanation
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  except Exception as e:
57
+ return None, f"Error: {str(e)}"
 
 
58
 
59
+ # Function to give acne explanation
60
+ def explain_acne(acne_type):
61
  explanations = {
62
+ "blackheads": "Blackheads are open clogged pores with oxidized oil, often caused by excess sebum and dead skin.",
63
+ "whiteheads": "Whiteheads are closed clogged pores that form small white bumps on the skin.",
64
+ "papules": "Papules are small red bumps caused by inflamed hair follicles.",
65
+ "pustules": "Pustules are pimples with visible pus, indicating bacterial infection.",
66
+ "nodules": "Nodules are large, painful acne lesions deep under the skin.",
67
+ "cysts": "Cystic acne involves pus-filled lesions beneath the skin, often leading to scars.",
68
+ "fungal acne": "Fungal acne is caused by yeast infection and appears similar to whiteheads.",
69
+ "acne scars": "Scars are skin indentations or pigmentation left after acne heals.",
70
+ "mild acne": "Mild acne includes occasional blackheads or small pimples.",
71
+ "moderate acne": "Moderate acne has more inflamed pimples and occasional nodules.",
72
+ "severe acne": "Severe acne includes numerous inflamed cysts and nodules, often painful."
 
 
 
 
 
 
 
 
 
73
  }
74
+ return explanations.get(acne_type, "No explanation available.")
75
 
76
+ # Chatbot section using Mistral API
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
78
+ API_URL = "https://api.mistral.ai/v1/chat/completions"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ def chat_with_bot(message, history):
81
  try:
82
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
83
+ payload = {
84
+ "model": "mistral-small",
85
+ "messages": [{"role": "system", "content": "You are a dermatologist assistant chatbot."}] +
86
+ [{"role": "user", "content": message}],
87
+ }
88
+ response = requests.post(API_URL, headers=headers, json=payload)
89
+ data = response.json()
90
+ reply = data["choices"][0]["message"]["content"]
91
+ return reply
92
  except Exception as e:
93
+ return f"Error communicating with chatbot: {str(e)}"
 
 
94
 
95
+ # Gradio Interface
96
+ def main_app():
97
+ with gr.Blocks() as demo:
98
+ gr.Markdown("# 🧴 AI-Powered Acne Type Classifier")
99
+ gr.Markdown("Enter an image URL of your face or acne region to detect the acne type.")
 
 
 
 
 
 
 
 
 
100
 
101
+ with gr.Row():
102
+ image_url = gr.Textbox(label="Image URL")
103
+ classify_btn = gr.Button("Classify Acne Type")
104
 
105
+ image_display = gr.Image(label="Input Image")
106
+ output_text = gr.Markdown()
 
 
107
 
108
+ classify_btn.click(classify_acne, inputs=image_url, outputs=[image_display, output_text])
 
 
 
109
 
110
+ gr.Markdown("---")
111
+ gr.Markdown("### 💬 Chat with Dermatologist Assistant")
112
 
113
+ chatbot = gr.ChatInterface(fn=chat_with_bot, title="Acne Query Chatbot")
114
 
115
+ return demo
116
 
117
  if __name__ == "__main__":
118
+ app = main_app()
119
+ app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120