fokan commited on
Commit
bafe93c
Β·
verified Β·
1 Parent(s): 73ef9b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -22
app.py CHANGED
@@ -1,41 +1,63 @@
1
  # app.py
2
- import gradio as gr
3
  import json
 
4
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Load labels from external JSON
7
  with open("labels.json", "r", encoding="utf-8") as f:
8
  candidate_labels = json.load(f)
9
 
10
- # Initialize the MedSigLIP pipeline
11
- pipe = pipeline("zero-shot-image-classification", model="google/medsiglip-448")
 
 
 
 
 
12
 
13
- # Define prediction function
14
- def classify_image(image):
15
- result = pipe(image, candidate_labels=candidate_labels)
16
- # Convert result list to {label: score}
17
- formatted = {item["label"]: round(item["score"] * 100, 2) for item in result}
18
- # Sort descending by confidence
19
- sorted_result = dict(sorted(formatted.items(), key=lambda x: x[1], reverse=True))
20
- return sorted_result
 
 
 
 
21
 
22
- # Build Gradio interface
23
  demo = gr.Interface(
24
- fn=classify_image,
25
- inputs=gr.Image(type="filepath", label="Upload Chest X-ray or Medical Image"),
26
- outputs=gr.Label(num_top_classes=5, label="Top Predictions"),
27
- title="🩻 MedSigLIP Zero-Shot Medical Image Classifier",
28
  description=(
29
- "This demo uses Google's MedSigLIP (Sigmoid Loss for Language Image Pre-training) model "
30
- "for zero-shot medical image classification. Upload an image and the model will estimate "
31
- "its similarity with known medical conditions (loaded dynamically from labels.json)."
32
  ),
33
- allow_flagging="never",
34
  examples=[
35
  ["https://storage.googleapis.com/dx-scin-public-data/dataset/images/3445096909671059178.png"],
36
  ["https://storage.googleapis.com/dx-scin-public-data/dataset/images/-5669089898008966381.png"]
37
  ],
 
38
  )
39
 
 
40
  if __name__ == "__main__":
41
- demo.launch()
 
1
  # app.py
2
+ import os
3
  import json
4
+ import gradio as gr
5
  from transformers import pipeline
6
+ from huggingface_hub import login
7
+
8
+ # ============= πŸ” AUTHENTICATION =============
9
+ # Hugging Face token (set in your Space settings)
10
+ # Go to your Space β†’ Settings β†’ Repository secrets β†’ add key: HF_TOKEN, value: your access token
11
+ hf_token = os.getenv("HF_TOKEN")
12
+
13
+ if hf_token:
14
+ login(token=hf_token)
15
+ else:
16
+ raise EnvironmentError("❌ Missing HF_TOKEN environment variable. Please add it in your Space settings.")
17
 
18
+ # ============= πŸ“¦ LOAD LABELS =============
19
  with open("labels.json", "r", encoding="utf-8") as f:
20
  candidate_labels = json.load(f)
21
 
22
+ # ============= 🧠 MODEL PIPELINE =============
23
+ # Use the authenticated pipeline for MedSigLIP (gated model)
24
+ pipe = pipeline(
25
+ "zero-shot-image-classification",
26
+ model="google/medsiglip-448",
27
+ use_auth_token=hf_token
28
+ )
29
 
30
+ # ============= βš™οΈ PREDICTION FUNCTION =============
31
+ def classify_medical_image(image):
32
+ """
33
+ Run zero-shot classification using MedSigLIP.
34
+ """
35
+ try:
36
+ results = pipe(image, candidate_labels=candidate_labels)
37
+ formatted = {r["label"]: round(r["score"] * 100, 2) for r in results}
38
+ sorted_result = dict(sorted(formatted.items(), key=lambda x: x[1], reverse=True))
39
+ return sorted_result
40
+ except Exception as e:
41
+ return {"Error": str(e)}
42
 
43
+ # ============= 🎨 GRADIO UI =============
44
  demo = gr.Interface(
45
+ fn=classify_medical_image,
46
+ inputs=gr.Image(type="filepath", label="πŸ“€ Upload Medical Image (X-ray, MRI, Pathology, etc.)"),
47
+ outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
48
+ title="🩻 MedSigLIP Zero-Shot Medical Classifier",
49
  description=(
50
+ "This demo uses Google's **MedSigLIP (448x448)** model for zero-shot medical image understanding. "
51
+ "Upload a medical image (e.g., chest X-ray, skin lesion, fundus photo, histopathology slide) and the model "
52
+ "will compute similarity against known medical findings loaded from `labels.json`."
53
  ),
 
54
  examples=[
55
  ["https://storage.googleapis.com/dx-scin-public-data/dataset/images/3445096909671059178.png"],
56
  ["https://storage.googleapis.com/dx-scin-public-data/dataset/images/-5669089898008966381.png"]
57
  ],
58
+ allow_flagging="never",
59
  )
60
 
61
+ # ============= πŸš€ RUN APP =============
62
  if __name__ == "__main__":
63
+ demo.launch(server_name="0.0.0.0", server_port=7860)