amasood commited on
Commit
d62389e
·
verified ·
1 Parent(s): 33e7457

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -28
app.py CHANGED
@@ -1,47 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
3
- payload = {
4
- "model": "mistral-small",
5
- "messages": [{"role": "system", "content": "You are a dermatologist assistant chatbot."}] +
6
- [{"role": "user", "content": message}],
7
- }
8
- response = requests.post(API_URL, headers=headers, json=payload)
9
- data = response.json()
10
- reply = data["choices"][0]["message"]["content"]
11
- return reply
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  except Exception as e:
13
- return f"Error communicating with chatbot: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
15
 
16
- # Gradio Interface
17
- def main_app():
18
- with gr.Blocks() as demo:
19
- gr.Markdown("# 🧴 AI-Powered Acne Type Classifier")
20
- gr.Markdown("Enter an image URL of your face or acne region to detect the acne type.")
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- with gr.Row():
24
- image_url = gr.Textbox(label="Image URL")
25
- classify_btn = gr.Button("Classify Acne Type")
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- image_display = gr.Image(label="Input Image")
29
- output_text = gr.Markdown()
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- classify_btn.click(classify_acne, inputs=image_url, outputs=[image_display, output_text])
33
 
 
 
 
 
 
34
 
35
- gr.Markdown("---")
36
- gr.Markdown("### 💬 Chat with Dermatologist Assistant")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- chatbot = gr.ChatInterface(fn=chat_with_bot, title="Acne Query Chatbot")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- return demo
 
 
 
 
 
 
 
 
43
 
 
 
 
 
44
 
 
45
  if __name__ == "__main__":
46
- app = main_app()
47
- app.launch()
 
1
+ # app.py
2
+ """
3
+ Hugging Face Space / Gradio app for Acne Type/Severity Classification + Chatbot (Mistral)
4
+ - Input: Image URL (user provides)
5
+ - Model: loads a Hugging Face image-classification model (default recommended checkpoint)
6
+ - Explanation: returns textual explanation for predicted acne label
7
+ - Chatbot: uses Mistral Chat Completions API (user supplies API key)
8
+ """
9
+
10
+ import os
11
+ import io
12
+ import requests
13
+ from PIL import Image
14
  import gradio as gr
15
+
16
+ # Transformers imports
17
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
18
+ import torch
19
+
20
+ # --------------------------
21
+ # CONFIG: choose model here
22
+ # --------------------------
23
+ # You can swap this to any HF image-classification checkpoint that supports acne/skin labels.
24
+ MODEL_NAME = "imfarzanansari/skintelligent-acne" # recommended default (acne severity)
25
+ # Fallbacks (used if primary model fails to load):
26
+ FALLBACK_MODELS = [
27
+ "naamalia23/acne-severity-classification",
28
+ "Tanishq77/skin-condition-classifier"
29
+ ]
30
+
31
+ # Mistral API end-point (chat completions)
32
+ MISTRAL_CHAT_URL = "https://api.mistral.ai/v1/chat/completions"
33
+
34
+ # --------------------------
35
+ # Utility helpers
36
+ # --------------------------
37
+ def load_model(model_name):
38
+ """
39
+ Try to load HF image-classification pipeline for model_name.
40
+ Returns a pipeline object or raises.
41
+ """
42
+ try:
43
+ device = 0 if torch.cuda.is_available() else -1
44
+ classifier = pipeline("image-classification", model=model_name, device=device)
45
+ return classifier
46
+ except Exception as e:
47
+ raise RuntimeError(f"Failed to load model {model_name}: {e}")
48
+
49
+ # Try to load the chosen model, fallback if necessary
50
+ classifier = None
51
+ loaded_model_name = None
52
+ load_errors = []
53
+ try:
54
+ classifier = load_model(MODEL_NAME)
55
+ loaded_model_name = MODEL_NAME
56
  except Exception as e:
57
+ load_errors.append(str(e))
58
+ for alt in FALLBACK_MODELS:
59
+ try:
60
+ classifier = load_model(alt)
61
+ loaded_model_name = alt
62
+ break
63
+ except Exception as e2:
64
+ load_errors.append(str(e2))
65
+
66
+ if classifier is None:
67
+ # If no model loaded, app will still start but classification will return helpful error
68
+ print("WARNING: No classification model loaded. Errors:", load_errors)
69
+
70
+ # --------------------------
71
+ # Simple textual explanations for common labels
72
+ # (Customize / extend as needed for your model's label set)
73
+ # --------------------------
74
+ EXPLANATION_BANK = {
75
+ # examples for acne severity labels (modify as per the model labels)
76
+ "Level -1: Clear Skin": "No active acne detected. Skin appears clear. Maintain gentle cleansing and sunscreen.",
77
+ "Level 0: Occasional Spots": "Occasional pimples or spots. Often manageable with over-the-counter topical treatments (benzoyl peroxide, salicylic acid).",
78
+ "Level 1: Mild Acne": "Mild acne with comedones (whiteheads/blackheads) and a few papules. Use topical retinoids, gentle cleanser; seek dermatologist if persistent.",
79
+ "Level 2: Moderate Acne": "Moderate acne with inflammatory papules and pustules. Prescription topical or oral treatments may be needed. See dermatologist for tailored therapy.",
80
+ "Level 3: Severe Acne": "Severe inflammatory acne, possibly nodules or cysts. Early dermatologist consultation is strongly recommended; systemic therapy may be needed.",
81
+ "Level 4: Very Severe Acne": "Very severe acne with widespread nodules/cysts or scarring. Urgent dermatologist evaluation required for systemic and procedural options.",
82
+ # fallback generic labels
83
+ "acne": "Signs of acne detected. Severity and subtype should be confirmed by a clinician. Usual treatments range from topical care to systemic medications depending on severity.",
84
+ "mild": "Mild acne. Start with gentle skincare and OTC active ingredients; consult dermatologist if it doesn't improve.",
85
+ "moderate": "Moderate acne. Dermatology visit recommended; topical and/or oral therapies may be indicated.",
86
+ "severe": "Severe acne. Dermatologist assessment needed; potential for scarring and systemic therapy."
87
+ }
88
+
89
+ def get_explanation_for_label(label):
90
+ # direct match
91
+ if label in EXPLANATION_BANK:
92
+ return EXPLANATION_BANK[label]
93
+ # case-insensitive partial match
94
+ ll = label.lower()
95
+ for k, v in EXPLANATION_BANK.items():
96
+ if k.lower() in ll or ll in k.lower():
97
+ return v
98
+ # fallback
99
+ return ("Detected label: {}. This model's label indicates acne or a related skin condition. "
100
+ "If you want a more specific explanation, fine-tune the EXPLANATION_BANK for your model's labels.").format(label)
101
+
102
+ # --------------------------
103
+ # Image download and prepare
104
+ # --------------------------
105
+ def load_image_from_url(url):
106
+ try:
107
+ resp = requests.get(url, timeout=10)
108
+ resp.raise_for_status()
109
+ img = Image.open(io.BytesIO(resp.content)).convert("RGB")
110
+ return img
111
+ except Exception as e:
112
+ raise RuntimeError(f"Failed to fetch image from URL: {e}")
113
+
114
+ # --------------------------
115
+ # Classification function (used by Gradio)
116
+ # --------------------------
117
+ def classify_image_from_url(image_url):
118
+ if classifier is None:
119
+ return {
120
+ "status": "error",
121
+ "message": "No model available. Check server logs or swap MODEL_NAME to a valid checkpoint."
122
+ }
123
+
124
+ # fetch image
125
+ try:
126
+ img = load_image_from_url(image_url)
127
+ except Exception as e:
128
+ return {"status": "error", "message": str(e)}
129
+
130
+ # run inference (pipeline returns list of dicts)
131
+ try:
132
+ preds = classifier(img, top_k=3)
133
+ except Exception as e:
134
+ return {"status": "error", "message": f"Model inference failed: {e}"}
135
+
136
+ # normalize output format
137
+ # preds -> list like [{"label": "Level 1: Mild", "score": 0.91}, ...]
138
+ top = preds[0]
139
+ label = top.get("label", str(top))
140
+ score = float(top.get("score", 0.0))
141
 
142
+ explanation = get_explanation_for_label(label)
143
 
144
+ # construct a concise structured response for the UI
145
+ response = {
146
+ "status": "ok",
147
+ "model": loaded_model_name or "none",
148
+ "label": label,
149
+ "score": round(score, 4),
150
+ "explanation": explanation,
151
+ "top_predictions": preds
152
+ }
153
+ return response
154
 
155
+ # --------------------------
156
+ # Mistral Chatbot integration
157
+ # --------------------------
158
+ def call_mistral_chat(api_key: str, messages: list, model: str = "mistral-small-latest", stream: bool = False):
159
+ """
160
+ Call the Mistral Chat Completions endpoint.
161
+ messages: a list of dicts, e.g. [{"role":"user", "content":"..."}]
162
+ returns response text (single string) or raise.
163
+ """
164
+ if not api_key:
165
+ raise RuntimeError("Mistral API key is required for chatbot.")
166
+ headers = {
167
+ "Authorization": f"Bearer {api_key}",
168
+ "Content-Type": "application/json"
169
+ }
170
+ body = {
171
+ "model": model,
172
+ "messages": messages
173
+ }
174
+ try:
175
+ r = requests.post(MISTRAL_CHAT_URL, json=body, headers=headers, timeout=30)
176
+ r.raise_for_status()
177
+ data = r.json()
178
+ # parse returned content
179
+ choices = data.get("choices", [])
180
+ if len(choices) > 0:
181
+ # some Mistral endpoints put the message under choices[0]["message"]["content"]
182
+ msg = choices[0].get("message", {}) or {}
183
+ content = msg.get("content") or choices[0].get("text") or ""
184
+ # content may be a string or dict; ensure string
185
+ if isinstance(content, dict):
186
+ # join parts if necessary
187
+ content = content.get("text", str(content))
188
+ return content
189
+ # fallback flat text
190
+ return data.get("text", str(data))
191
+ except Exception as e:
192
+ raise RuntimeError(f"Mistral API call failed: {e}")
193
 
194
+ # --------------------------
195
+ # Gradio UI callbacks
196
+ # --------------------------
197
+ # Keep simple conversation state via closure
198
+ chat_history = []
199
 
200
+ def classify_and_prepare_context(image_url):
201
+ """
202
+ Runs classification and returns structured outputs plus
203
+ a "context" text that the chatbot can use (label + explanation).
204
+ """
205
+ result = classify_image_from_url(image_url)
206
+ if result.get("status") != "ok":
207
+ return None, result.get("message", "Unknown error")
208
+ # Build context summary
209
+ context_summary = (
210
+ f"Detected acne label: {result['label']} (confidence {result['score']}). "
211
+ f"Explanation: {result['explanation']}"
212
+ )
213
+ return result, context_summary
214
 
215
+ def chat_with_context(mistral_api_key, user_message, context_summary, model_name="mistral-small-latest"):
216
+ """
217
+ Send conversation to Mistral with context prepended.
218
+ Returns assistant reply (string).
219
+ """
220
+ if not mistral_api_key:
221
+ return "Please provide your Mistral API key (in the Mistral API Key box) to use the chatbot."
222
 
223
+ # maintain in-memory chat history for nicer flow
224
+ # We will prepend a system message + context on every call to give the model grounding
225
+ system_msg = {
226
+ "role": "system",
227
+ "content": (
228
+ "You are a helpful, concise dermatology assistant. Use clinical but accessible language. "
229
+ "Base your answers on standard dermatology practice. If you are unsure, recommend seeing a dermatologist."
230
+ )
231
+ }
232
+ context_msg = {"role": "system", "content": context_summary}
233
+ user_msg = {"role": "user", "content": user_message}
234
 
235
+ messages = [system_msg, context_msg, user_msg]
236
 
237
+ try:
238
+ reply = call_mistral_chat(api_key=mistral_api_key, messages=messages, model=model_name)
239
+ except Exception as e:
240
+ return f"[Chat error] {e}"
241
+ return reply
242
 
243
+ # --------------------------
244
+ # Build Gradio app layout
245
+ # --------------------------
246
+ with gr.Blocks(theme=gr.themes.Default(), title="Acne Classifier + Mistral Chatbot") as demo:
247
+ gr.Markdown("## Acne Type/Severity Classifier + Chatbot\n"
248
+ "Paste an **image URL** (a photo of the face/skin area). The app will classify acne type/severity "
249
+ "and provide an explanation. Use the chatbot (Mistral) to ask follow-up questions about the diagnosis, "
250
+ "treatments, and next steps. **You must provide your Mistral API key** to use the chatbot.")
251
+ with gr.Row():
252
+ with gr.Column(scale=2):
253
+ image_url_input = gr.Textbox(label="Image URL", placeholder="https://...", lines=1)
254
+ load_and_classify_btn = gr.Button("Load & Classify")
255
+ image_output = gr.Image(label="Loaded Image", type="pil")
256
+ model_info = gr.Textbox(value=f"Model loaded: {loaded_model_name or 'None'}", label="Model info", interactive=False)
257
+ results_box = gr.JSON(label="Classification Result (structured)", interactive=False)
258
+ with gr.Column(scale=1):
259
+ mistral_key_input = gr.Textbox(label="Mistral API Key", placeholder="sk-...", type="password")
260
+ gr.Markdown("### Chatbot about detected acne")
261
+ chat_output = gr.Chatbot(label="Dermatology Assistant")
262
+ user_msg_input = gr.Textbox(placeholder="Ask about the detected acne...", label="Your question")
263
+ send_btn = gr.Button("Send")
264
 
265
+ # classify button action
266
+ def on_classify_click(url):
267
+ if not url or url.strip() == "":
268
+ return None, {"status":"error","message":"Please paste an image URL"}, None
269
+ # show image
270
+ try:
271
+ img = load_image_from_url(url)
272
+ except Exception as e:
273
+ return None, {"status":"error","message":str(e)}, None
274
+ result, context = classify_and_prepare_context(url)
275
+ if result is None:
276
+ return img, {"status":"error","message": context}, None
277
+ # Preload the chat history reset
278
+ global chat_history
279
+ chat_history = []
280
+ # Return image to display, JSON results, and put context into a hidden area via gr.State if needed
281
+ return img, result, context
282
 
283
+ load_and_classify_btn.click(on_classify_click, inputs=[image_url_input], outputs=[image_output, results_box, gr.State()])
284
 
285
+ # chat send action
286
+ def on_send_click(mkey, user_text, last_context):
287
+ if not last_context:
288
+ return gr.update(), "Please classify an image first (use the Load & Classify button)."
289
+ if not user_text or user_text.strip() == "":
290
+ return gr.update(), "Please type a question."
291
+ # call Mistral
292
+ reply = chat_with_context(mkey, user_text, last_context)
293
+ # Append to chat_history and return
294
+ global chat_history
295
+ chat_history.append(("User", user_text))
296
+ chat_history.append(("Assistant", reply))
297
+ # Format chat_history as list of tuples for gr.Chatbot
298
+ formatted = [(u, a) for u, a in zip(chat_history[::2], chat_history[1::2])]
299
+ return formatted, ""
300
+ # Note: gr.State will hold the latest context_summary; as a simple approach, we pass last output results_box['explanation'] as context.
301
+ # But Gradio's .click binding above returned a third value (context) which is not stored here; for simplicity we re-run classification to extract context.
302
+ # We'll implement a small wrapper to grab the context from the results_box JSON client-side.
303
+ # For clarity and reliability in Spaces, recommend wiring a hidden State; here we accept the user to paste Mistral key and ask after classifying.
304
 
305
+ send_btn.click(
306
+ fn=lambda key, text, context_summary: (
307
+ # return updated chat and cleared input
308
+ chat_with_context(key, text, context_summary),
309
+ ""
310
+ ),
311
+ inputs=[mistral_key_input, user_msg_input, results_box],
312
+ outputs=[chat_output, user_msg_input]
313
+ )
314
 
315
+ gr.Markdown("**Notes & Tips**:\n\n"
316
+ "- If pipeline/model loading fails on startup, change `MODEL_NAME` to another HF checkpoint and restart the Space.\n"
317
+ "- For best results: clear, well-lit closeup photos of acne lesions give higher accuracy.\n"
318
+ "- This app provides informational assistance only — not a medical diagnosis. Encourage users to consult a dermatologist for medical decisions.")
319
 
320
+ # Launch
321
  if __name__ == "__main__":
322
+ demo.launch()