saad003 committed
Commit 2501ddf · verified · 1 Parent(s): aaea08e

Update app.py

Files changed (1):
  1. app.py +28 -45
app.py CHANGED
@@ -1,7 +1,6 @@
 # app.py
 import io
 import os
-import base64

 import faiss
 import torch
@@ -21,20 +20,23 @@ app = FastAPI()

 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # you can restrict later
+    allow_origins=["*"],  # later you can restrict to your frontend domain
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )

 # ---------- Config ----------
-# FAISS index + radiology_metadata.csv
+# Dataset with FAISS index + radiology_metadata.csv
 EMBED_REPO_ID = "saad003/Red01"

-# All radiology images, filenames like ROCOv2_2023_valid_000001.jpg
+# Dataset with ALL radiology images (flat, filenames = ID + ".jpg")
 IMAGE_REPO_ID = "saad003/images"
-BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
+BASE_IMAGE_URL = (
+    f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
+)

+# Optional token (if Red01 / images are private). Set HF_TOKEN in Space secrets.
 HF_TOKEN = os.environ.get("HF_TOKEN")

 # ---------- Download index + metadata ----------
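With BASE_IMAGE_URL in place, each result row can point at a raw file on the Hub instead of shipping bytes. A quick sanity check outside the app, assuming the images dataset is public and `requests` is installed (neither is guaranteed by this diff; a private repo would additionally need an "Authorization: Bearer" header built from HF_TOKEN):

import requests

BASE_IMAGE_URL = "https://huggingface.co/datasets/saad003/images/resolve/main"
sample_id = "ROCOv2_2023_test_000040"  # example ID taken from the docstring below

url = f"{BASE_IMAGE_URL}/{sample_id}.jpg"
# HEAD follows the Hub's CDN redirect; a 200 means a browser <img> tag
# can load the URL directly, which is the point of this commit.
resp = requests.head(url, allow_redirects=True, timeout=10)
print(url, "->", resp.status_code)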
@@ -60,6 +62,7 @@ index = faiss.read_index(INDEX_PATH)
 print("Loading metadata CSV...")
 metadata = pd.read_csv(META_PATH)

+# We only need these columns
 required_cols = {"vec_index", "ID", "caption", "concepts_manual"}
 missing = required_cols - set(metadata.columns)
 if missing:
@@ -83,55 +86,34 @@ print("Loading BLIP radiology captioning model...")
 CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"

 caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
-caption_model = BlipForConditionalGeneration.from_pretrained(CAPTION_MODEL_ID).to(device)
+caption_model = BlipForConditionalGeneration.from_pretrained(
+    CAPTION_MODEL_ID
+).to(device)
 caption_model.eval()

 print("Backend ready ✅")


-# ---------- Helpers for images ----------
+# ---------- Helpers ----------
 def id_to_image_url(image_id: str) -> str:
-    """Public HF URL (optional, for debugging/click)."""
-    if not isinstance(image_id, str):
-        return None
-    filename = f"{image_id}.jpg"
-    return f"{BASE_IMAGE_URL}/{filename}"
-
-
-def id_to_image_base64(image_id: str) -> str | None:
     """
-    Download the image from `saad003/images` (cached by hf_hub_download),
-    then return base64-encoded bytes so frontend can display directly.
+    Build raw image URL.
+
+    Example:
+        ID = "ROCOv2_2023_test_000040"
+        -> https://huggingface.co/datasets/saad003/images/resolve/main/ROCOv2_2023_test_000040.jpg
     """
     if not isinstance(image_id, str):
         return None
-
+    image_id = image_id.strip()
     filename = f"{image_id}.jpg"
-    try:
-        local_path = hf_hub_download(
-            repo_id=IMAGE_REPO_ID,
-            filename=filename,
-            repo_type="dataset",
-            token=HF_TOKEN,
-        )
-    except Exception as e:
-        print(f"Error downloading image for ID={image_id}: {e}")
-        return None
-
-    try:
-        with open(local_path, "rb") as f:
-            data = f.read()
-        return base64.b64encode(data).decode("utf-8")
-    except Exception as e:
-        print(f"Error reading image file for ID={image_id}: {e}")
-        return None
+    return f"{BASE_IMAGE_URL}/{filename}"


-# ---------- Retrieval ----------
 def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
     """
-    Encode query image with CLIP, search FAISS, return top-k rows
-    with vec_index, ID, caption, concepts_manual, score, image_url, image_base64.
+    Encode query image with CLIP, search FAISS, and return top-k rows
+    with vec_index, ID, caption, concepts_manual, score, image_url.
     """
     inputs = clip_processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
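The retrieval path above assumes the FAISS index rows line up one-to-one with `vec_index` in the metadata CSV and that `index.search` scores are cosine similarities. A minimal sketch of how a compatible index might have been built; the IndexFlatIP type, the 512 dimension, and the file name are assumptions, since index construction is not part of this commit:

import faiss
import numpy as np

d = 512                                  # e.g. CLIP ViT-B/32 feature size (assumed)
embs = np.random.rand(1000, d).astype("float32")  # stand-in for real CLIP features

faiss.normalize_L2(embs)                 # in place; inner product == cosine afterwards
index = faiss.IndexFlatIP(d)
index.add(embs)                          # row i must match metadata["vec_index"] == i
faiss.write_index(index, "faiss.index")  # whatever file INDEX_PATH points at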
@@ -144,17 +126,15 @@ def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:

     rows = metadata.iloc[I[0]].copy()
     rows["score"] = D[0]
-
     rows["image_url"] = rows["ID"].apply(id_to_image_url)
-    rows["image_base64"] = rows["ID"].apply(id_to_image_base64)

     return rows[
-        ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url", "image_base64"]
+        ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]
     ]


-# ---------- Captioning ----------
 def generate_query_caption(image: Image.Image) -> str:
+    """Generate a medical caption for the query image using BLIP."""
     inputs = caption_processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
         out = caption_model.generate(**inputs, max_new_tokens=64)
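The hunk cuts off before `generate_query_caption` returns; the usual transformers finish is to decode the generated ids back to text (implied by the `-> str` annotation, not quoted from app.py):

# `caption_processor` and `out` are the names from the hunk above.
caption = caption_processor.decode(out[0], skip_special_tokens=True)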
@@ -163,6 +143,7 @@ def generate_query_caption(image: Image.Image) -> str:


 def infer_modality_from_caption(caption: str) -> str:
+    """Heuristic to infer modality from caption text."""
     if not caption:
         return "Unknown"

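Beyond the empty-caption guard, the body of `infer_modality_from_caption` falls outside this diff. A keyword table of roughly this shape would fit the new docstring; the terms below are illustrative, not the app's actual list:

import re

def infer_modality_sketch(caption: str) -> str:
    """Illustrative heuristic; word boundaries keep 'ct' from matching
    inside words like 'structure'."""
    patterns = {
        "CT": [r"\bct\b", r"computed tomography"],
        "MRI": [r"\bmri\b", r"magnetic resonance"],
        "X-ray": [r"x[- ]?ray", r"radiograph"],
        "Ultrasound": [r"ultrasound", r"sonograph", r"doppler"],
    }
    text = caption.lower()
    for modality, terms in patterns.items():
        if any(re.search(t, text) for t in terms):
            return modality
    return "Unknown"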
@@ -197,21 +178,23 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
     - query_caption: BLIP caption for query image
     - modality: inferred imaging modality
     - results: list of similar images with
-      vec_index, ID, concepts_manual, score,
-      image_url, image_base64
+      vec_index, ID, concepts_manual, score, image_url
     """
     content = await file.read()
     image = Image.open(io.BytesIO(content)).convert("RGB")

+    # 1) Retrieval
     results_df = search_similar_by_image(image, k=k)
     results = results_df.to_dict(orient="records")

+    # 2) Caption for query image
     try:
         query_caption = generate_query_caption(image)
     except Exception as e:
         print("Error generating caption:", e)
         query_caption = None

+    # 3) Modality from caption
     modality = infer_modality_from_caption(query_caption or "")

     return JSONResponse(
 
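Net effect for clients: each result now carries an `image_url` to load directly instead of an `image_base64` payload to decode. A client-side sketch, assuming the handler is mounted at /search and the app runs on the default Spaces port 7860 (the route decorator and host are outside this diff):

import requests

with open("query.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/search",   # route/host assumed, not shown in diff
        files={"file": ("query.jpg", f, "image/jpeg")},
        params={"k": 5},
    )
payload = resp.json()

print(payload["query_caption"], "|", payload["modality"])
for hit in payload["results"]:
    # image_url is a plain https link into saad003/images; usable as an <img> src
    print(f'{hit["ID"]}  score={hit["score"]:.3f}  {hit["image_url"]}')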