saad003 committed on
Commit
9e37ce2
·
verified ·
1 Parent(s): 52d9d85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -89
app.py CHANGED
@@ -30,14 +30,14 @@ app.add_middleware(
30
  # Dataset with FAISS index + radiology_metadata.csv
31
  EMBED_REPO_ID = "saad003/Red01"
32
 
33
- # Dataset with ALL radiology images (flat, filenames = ID + ".jpg")
34
- IMAGE_REPO_ID = "saad003/images"
 
35
  BASE_IMAGE_URL = (
36
  f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
37
  )
38
 
39
- # Optional token (if Red01 / images are private). Set HF_TOKEN in Space secrets.
40
- HF_TOKEN = os.environ.get("HF_TOKEN")
41
 
42
  # ---------- Download index + metadata ----------
43
  print("Downloading FAISS index & metadata from Hugging Face...")
@@ -62,7 +62,6 @@ index = faiss.read_index(INDEX_PATH)
62
  print("Loading metadata CSV...")
63
  metadata = pd.read_csv(META_PATH)
64
 
65
- # We only need these columns
66
  required_cols = {"vec_index", "ID", "caption", "concepts_manual"}
67
  missing = required_cols - set(metadata.columns)
68
  if missing:
@@ -95,112 +94,157 @@ print("Backend ready ✅")
95
 
96
 
97
  # ---------- Helpers ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def id_to_image_url(image_id: str) -> str:
99
- """
100
- Build raw image URL.
101
-
102
- Example:
103
- ID = "ROCOv2_2023_test_000040"
104
- -> https://huggingface.co/datasets/saad003/images/resolve/main/ROCOv2_2023_test_000040.jpg
105
- """
106
- if not isinstance(image_id, str):
107
- return None
108
- image_id = image_id.strip()
109
- filename = f"{image_id}.jpg"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  return f"{BASE_IMAGE_URL}/{filename}"
111
 
112
 
113
  def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
114
- """
115
- Encode query image with CLIP, search FAISS, and return top-k rows
116
- with vec_index, ID, caption, concepts_manual, score, image_url.
117
- """
118
- inputs = clip_processor(images=image, return_tensors="pt").to(device)
119
- with torch.no_grad():
120
- feats = clip_model.get_image_features(**inputs)
121
 
122
- feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
123
- feats = feats.cpu().numpy().astype("float32")
124
 
125
- D, I = index.search(feats, k)
126
 
127
- rows = metadata.iloc[I[0]].copy()
128
- rows["score"] = D[0]
129
- rows["image_url"] = rows["ID"].apply(id_to_image_url)
130
 
131
- return rows[
132
- ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]
133
- ]
134
 
135
 
136
  def generate_query_caption(image: Image.Image) -> str:
137
- """Generate a medical caption for the query image using BLIP."""
138
- inputs = caption_processor(images=image, return_tensors="pt").to(device)
139
- with torch.no_grad():
140
- out = caption_model.generate(**inputs, max_new_tokens=64)
141
- caption = caption_processor.batch_decode(out, skip_special_tokens=True)[0]
142
- return caption.strip()
143
 
144
 
145
  def infer_modality_from_caption(caption: str) -> str:
146
- """Heuristic to infer modality from caption text."""
147
- if not caption:
148
- return "Unknown"
149
-
150
- text = caption.lower()
151
-
152
- if any(w in text for w in ["ct scan", "ct of", "computed tomography"]):
153
- return "CT"
154
- if any(w in text for w in ["mri", "magnetic resonance"]):
155
- return "MRI"
156
- if any(w in text for w in ["x-ray", "x ray", "radiograph", "chest xray", "chest x-ray"]):
157
- return "X-ray"
158
- if any(w in text for w in ["ultrasound", "sonography", "sonogram"]):
159
- return "Ultrasound"
160
- if any(w in text for w in ["pet-ct", "pet ct", "pet scan", "positron emission tomography"]):
161
- return "PET/CT"
162
-
163
  return "Unknown"
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  # ---------- Routes ----------
167
  @app.get("/")
168
  def root():
169
- return {"status": "ok", "message": "Radiology retrieval + captioning API"}
170
 
171
 
172
  @app.post("/search_by_image")
173
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
174
- """
175
- Upload a radiology image.
176
-
177
- Returns:
178
- - query_caption: BLIP caption for query image
179
- - modality: inferred imaging modality
180
- - results: list of similar images with
181
- vec_index, ID, concepts_manual, score, image_url
182
- """
183
- content = await file.read()
184
- image = Image.open(io.BytesIO(content)).convert("RGB")
185
-
186
- # 1) Retrieval
187
- results_df = search_similar_by_image(image, k=k)
188
- results = results_df.to_dict(orient="records")
189
-
190
- # 2) Caption for query image
191
- try:
192
- query_caption = generate_query_caption(image)
193
- except Exception as e:
194
- print("Error generating caption:", e)
195
- query_caption = None
196
-
197
- # 3) Modality from caption
198
- modality = infer_modality_from_caption(query_caption or "")
199
-
200
- return JSONResponse(
201
- {
202
- "query_caption": query_caption,
203
- "modality": modality,
204
- "results": results,
205
- }
206
- )
 
# Dataset with FAISS index + radiology_metadata.csv
EMBED_REPO_ID = "saad003/Red01"

# NEW dataset with images organized into subfolders:
# test, valid, train01, train02, ..., train07
IMAGE_REPO_ID = "saad003/images04"

# Base URL for resolving raw image files out of the image dataset repo.
BASE_IMAGE_URL = (
    f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
)

# Optional access token — set HF_TOKEN in the Space secrets if needed.
HF_TOKEN = os.environ.get("HF_TOKEN")
 
41
 
42
  # ---------- Download index + metadata ----------
43
  print("Downloading FAISS index & metadata from Hugging Face...")
 
62
  print("Loading metadata CSV...")
63
  metadata = pd.read_csv(META_PATH)
64
 
 
65
  required_cols = {"vec_index", "ID", "caption", "concepts_manual"}
66
  missing = required_cols - set(metadata.columns)
67
  if missing:
 
94
 
95
 
96
  # ---------- Helpers ----------
97
def train_folder_from_id(image_id: str) -> str:
    """
    Map an ID like 'ROCOv2_2023_train_000001' to its trainXX shard folder.

    Training images are sharded into train01..train07 in blocks of 9000,
    keyed on the trailing numeric suffix of the ID. Falls back to
    'train01' when the suffix cannot be parsed.
    """
    try:
        suffix = int(image_id.rsplit("_", 1)[-1])  # e.g. "000001" -> 1
    except Exception:
        return "train01"  # safe default for malformed IDs

    # 1..9000 -> train01, 9001..18000 -> train02, ..., 54001+ -> train07.
    shard = min(max((suffix - 1) // 9000 + 1, 1), 7)
    return f"train{shard:02d}"
122
+
123
+
124
def id_to_image_url(image_id: str) -> str:
    """
    Build the raw image URL for a metadata ID based on the repo's
    folder structure.

    Examples:
        ROCOv2_2023_test_000001  -> test/ROCOv2_2023_test_000001.jpg
        ROCOv2_2023_valid_000005 -> valid/ROCOv2_2023_valid_000005.jpg
        ROCOv2_2023_train_000001 -> train01/ROCOv2_2023_train_000001.jpg
        ROCOv2_2023_train_009001 -> train02/ROCOv2_2023_train_009001.jpg

    Returns None when image_id is not a string.
    """
    if not isinstance(image_id, str):
        return None

    image_id = image_id.strip()

    if "test_" in image_id:
        folder = "test"
    elif "valid_" in image_id:
        folder = "valid"
    elif "train_" in image_id:
        folder = train_folder_from_id(image_id)
    else:
        # Fallback: resolve directly at the repo root for unexpected IDs.
        folder = ""

    filename = f"{image_id}.jpg"

    # NOTE(review): the scraped source showed "(unknown)" here; the
    # filename placeholder is reconstructed from the docstring examples.
    if folder:
        return f"{BASE_IMAGE_URL}/{folder}/{filename}"
    return f"{BASE_IMAGE_URL}/{filename}"
155
 
156
 
157
def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
    """
    Encode the query image with CLIP, run a top-k FAISS search, and
    return the matching metadata rows.

    Columns returned: vec_index, ID, caption, concepts_manual, score,
    image_url.
    """
    batch = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        embedding = clip_model.get_image_features(**batch)

    # Unit-normalize the embedding before querying the index.
    embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
    query = embedding.cpu().numpy().astype("float32")

    scores, neighbour_ids = index.search(query, k)

    hits = metadata.iloc[neighbour_ids[0]].copy()
    hits["score"] = scores[0]
    hits["image_url"] = hits["ID"].apply(id_to_image_url)

    columns = ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]
    return hits[columns]
178
 
179
 
180
def generate_query_caption(image: Image.Image) -> str:
    """Generate a medical caption for the query image using BLIP."""
    batch = caption_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated = caption_model.generate(**batch, max_new_tokens=64)
    decoded = caption_processor.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].strip()
187
 
188
 
189
def infer_modality_from_caption(caption: str) -> str:
    """
    Heuristically infer the imaging modality from caption text.

    Keyword groups are checked in a fixed priority order (CT, MRI,
    X-ray, Ultrasound, PET/CT); returns "Unknown" if nothing matches
    or the caption is empty.
    """
    if not caption:
        return "Unknown"

    text = caption.lower()

    # Ordered rule table: first group with a matching keyword wins.
    rules = [
        ("CT", ("ct scan", "ct of", "computed tomography")),
        ("MRI", ("mri", "magnetic resonance")),
        ("X-ray", ("x-ray", "x ray", "radiograph", "chest xray", "chest x-ray")),
        ("Ultrasound", ("ultrasound", "sonography", "sonogram")),
        ("PET/CT", ("pet-ct", "pet ct", "pet scan", "positron emission tomography")),
    ]
    for modality, keywords in rules:
        if any(keyword in text for keyword in keywords):
            return modality

    return "Unknown"
208
+
209
 
210
  # ---------- Routes ----------
211
  @app.get("/")
212
  def root():
213
+ return {"status": "ok", "message": "Radiology retrieval + captioning API"}
214
 
215
 
216
  @app.post("/search_by_image")
217
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
218
+ """
219
+ Upload a radiology image.
220
+
221
+ Returns:
222
+ - query_caption: BLIP caption for query image
223
+ - modality: inferred imaging modality
224
+ - results: list of similar images with
225
+ vec_index, ID, concepts_manual, score, image_url
226
+ """
227
+ content = await file.read()
228
+ image = Image.open(io.BytesIO(content)).convert("RGB")
229
+
230
+ # 1) Retrieval
231
+ results_df = search_similar_by_image(image, k=k)
232
+ results = results_df.to_dict(orient="records")
233
+
234
+ # 2) Caption for query image
235
+ try:
236
+ query_caption = generate_query_caption(image)
237
+ except Exception as e:
238
+ print("Error generating caption:", e)
239
+ query_caption = None
240
+
241
+ # 3) Modality
242
+ modality = infer_modality_from_caption(query_caption or "")
243
+
244
+ return JSONResponse(
245
+ {
246
+ "query_caption": query_caption,
247
+ "modality": modality,
248
+ "results": results,
249
+ }
250
+ )