fokan commited on
Commit
1739e75
Β·
verified Β·
1 Parent(s): 62d8126

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -26
app.py CHANGED
@@ -1,15 +1,14 @@
1
  # app.py
2
  import os
3
  import json
 
4
  import gradio as gr
5
- from transformers import pipeline
6
  from huggingface_hub import login
 
7
 
8
  # ============= πŸ” AUTHENTICATION =============
9
- # Hugging Face token (set in your Space settings)
10
- # Go to your Space β†’ Settings β†’ Repository secrets β†’ add key: HF_TOKEN, value: your access token
11
  hf_token = os.getenv("HF_TOKEN")
12
-
13
  if hf_token:
14
  login(token=hf_token)
15
  else:
@@ -17,26 +16,45 @@ else:
17
 
18
  # ============= πŸ“¦ LOAD LABELS =============
19
  with open("labels.json", "r", encoding="utf-8") as f:
20
- candidate_labels = json.load(f)
21
-
22
- # ============= 🧠 MODEL PIPELINE =============
23
- # Use the authenticated pipeline for MedSigLIP (gated model)
24
- pipe = pipeline(
25
- "zero-shot-image-classification",
26
- model="google/medsiglip-448",
27
- use_auth_token=hf_token
28
- )
29
 
30
- # ============= βš™οΈ PREDICTION FUNCTION =============
31
- def classify_medical_image(image):
32
- """
33
- Run zero-shot classification using MedSigLIP.
34
- """
 
 
 
 
 
35
  try:
36
- results = pipe(image, candidate_labels=candidate_labels)
37
- formatted = {r["label"]: round(r["score"] * 100, 2) for r in results}
38
- sorted_result = dict(sorted(formatted.items(), key=lambda x: x[1], reverse=True))
39
- return sorted_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception as e:
41
  return {"Error": str(e)}
42
 
@@ -47,9 +65,8 @@ demo = gr.Interface(
47
  outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
48
  title="🩻 MedSigLIP Zero-Shot Medical Classifier",
49
  description=(
50
- "This demo uses Google's **MedSigLIP (448x448)** model for zero-shot medical image understanding. "
51
- "Upload a medical image (e.g., chest X-ray, skin lesion, fundus photo, histopathology slide) and the model "
52
- "will compute similarity against known medical findings loaded from `labels.json`."
53
  ),
54
  examples=[
55
  ["https://storage.googleapis.com/dx-scin-public-data/dataset/images/3445096909671059178.png"],
@@ -60,4 +77,4 @@ demo = gr.Interface(
60
 
61
  # ============= πŸš€ RUN APP =============
62
  if __name__ == "__main__":
63
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  # app.py
2
  import os
3
  import json
4
+ import torch
5
  import gradio as gr
6
+ from PIL import Image
7
  from huggingface_hub import login
8
+ from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
9
 
10
  # ============= πŸ” AUTHENTICATION =============
 
 
11
  hf_token = os.getenv("HF_TOKEN")
 
12
  if hf_token:
13
  login(token=hf_token)
14
  else:
 
16
 
17
  # ============= πŸ“¦ LOAD LABELS =============
18
  with open("labels.json", "r", encoding="utf-8") as f:
19
+ all_labels = json.load(f)
20
+
21
+ # ============= 🧠 MODEL & PROCESSOR =============
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
23
 
24
+ processor = AutoProcessor.from_pretrained("google/medsiglip-448", use_auth_token=hf_token)
25
+ model = AutoModelForZeroShotImageClassification.from_pretrained(
26
+ "google/medsiglip-448",
27
+ use_auth_token=hf_token,
28
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
29
+ ).to(device)
30
+ model.eval()
31
+
32
+ # ============= βš™οΈ INFERENCE FUNCTION =============
33
+ def classify_medical_image(image_path):
34
  try:
35
+ image = Image.open(image_path).convert("RGB")
36
+
37
+ # πŸ”Ή Split labels into small batches to avoid memory overload
38
+ batch_size = 50 # adjust if needed
39
+ results = []
40
+
41
+ for i in range(0, len(all_labels), batch_size):
42
+ batch = all_labels[i:i+batch_size]
43
+ inputs = processor(text=batch, images=image, return_tensors="pt", padding=True).to(device)
44
+
45
+ with torch.no_grad():
46
+ outputs = model(**inputs)
47
+ logits_per_image = outputs.logits_per_image
48
+ probs = torch.softmax(logits_per_image, dim=1)[0]
49
+
50
+ for label, score in zip(batch, probs.tolist()):
51
+ results.append((label, round(score * 100, 2)))
52
+
53
+ # πŸ”Ή Sort final results
54
+ results.sort(key=lambda x: x[1], reverse=True)
55
+ top5 = dict(results[:5])
56
+ return top5
57
+
58
  except Exception as e:
59
  return {"Error": str(e)}
60
 
 
65
  outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
66
  title="🩻 MedSigLIP Zero-Shot Medical Classifier",
67
  description=(
68
+ "Efficient version using Google's **MedSigLIP (448x448)** model for medical image understanding. "
69
+ "Optimized for Hugging Face CPU Spaces. Uses batched label processing to reduce memory load."
 
70
  ),
71
  examples=[
72
  ["https://storage.googleapis.com/dx-scin-public-data/dataset/images/3445096909671059178.png"],
 
77
 
78
  # ============= πŸš€ RUN APP =============
79
  if __name__ == "__main__":
80
+ demo.launch(server_name="0.0.0.0", server_port=7860, queue=True)