saad003 commited on
Commit
dde6aa9
·
verified ·
1 Parent(s): c41d28a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -44
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # app.py
2
  import io
3
  import os
 
4
 
5
  import faiss
6
  import torch
@@ -18,7 +19,6 @@ from transformers import BlipForConditionalGeneration, AutoProcessor
18
  # ---------- FastAPI app ----------
19
  app = FastAPI()
20
 
21
- # Allow your React app to call this API
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"], # you can restrict later
@@ -28,15 +28,13 @@ app.add_middleware(
28
  )
29
 
30
  # ---------- Config ----------
31
- # Dataset with FAISS index + radiology_metadata.csv
32
  EMBED_REPO_ID = "saad003/Red01"
33
 
34
- # Dataset with ALL radiology images (flat: ID + ".jpg" in root)
35
- # e.g. ROCOv2_2023_valid_000001.jpg
36
  IMAGE_REPO_ID = "saad003/images"
37
- BASE_IMAGE_URL = f"https://huggingface.co/datasets/saad003/images"
38
 
39
- # Optional: token if Red01 is private (set HF_TOKEN secret on the Space)
40
  HF_TOKEN = os.environ.get("HF_TOKEN")
41
 
42
  # ---------- Download index + metadata ----------
@@ -62,7 +60,6 @@ index = faiss.read_index(INDEX_PATH)
62
  print("Loading metadata CSV...")
63
  metadata = pd.read_csv(META_PATH)
64
 
65
- # We will only rely on: vec_index, ID, caption, concepts_manual
66
  required_cols = {"vec_index", "ID", "caption", "concepts_manual"}
67
  missing = required_cols - set(metadata.columns)
68
  if missing:
@@ -92,53 +89,72 @@ caption_model.eval()
92
  print("Backend ready ✅")
93
 
94
 
95
- # ---------- Helper: build image URL from ID ----------
96
  def id_to_image_url(image_id: str) -> str:
 
 
 
 
 
 
 
 
97
  """
98
- Your images dataset `saad003/images` has files like:
99
- ROCOv2_2023_valid_000001.jpg
100
- where filename = ID + ".jpg".
101
  """
102
  if not isinstance(image_id, str):
103
  return None
 
104
  filename = f"{image_id}.jpg"
105
- return f"{BASE_IMAGE_URL}/{filename}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
- # ---------- Helper: search by image ----------
109
  def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
110
  """
111
  Encode query image with CLIP, search FAISS, return top-k rows
112
- containing vec_index, ID, caption, concepts_manual, score, image_url.
113
  """
114
- # Encode image
115
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
116
  with torch.no_grad():
117
  feats = clip_model.get_image_features(**inputs)
118
 
119
- # Normalize (must match how index was built)
120
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
121
  feats = feats.cpu().numpy().astype("float32")
122
 
123
- # Search FAISS
124
- D, I = index.search(feats, k) # D: distances, I: indices
125
 
126
- # Get rows for top-k indices
127
  rows = metadata.iloc[I[0]].copy()
128
  rows["score"] = D[0]
129
 
130
- # Build URL from ID only
131
  rows["image_url"] = rows["ID"].apply(id_to_image_url)
 
132
 
133
- return rows[["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]]
 
 
134
 
135
 
136
- # ---------- Helper: generate caption for query image ----------
137
  def generate_query_caption(image: Image.Image) -> str:
138
- """
139
- Generate a medical radiology caption for the query image using BLIP
140
- fine-tuned on ROCO.
141
- """
142
  inputs = caption_processor(images=image, return_tensors="pt").to(device)
143
  with torch.no_grad():
144
  out = caption_model.generate(**inputs, max_new_tokens=64)
@@ -146,11 +162,7 @@ def generate_query_caption(image: Image.Image) -> str:
146
  return caption.strip()
147
 
148
 
149
- # ---------- Helper: infer modality from caption ----------
150
  def infer_modality_from_caption(caption: str) -> str:
151
- """
152
- Simple heuristic to infer imaging modality (CT, MRI, X-ray, etc.).
153
- """
154
  if not caption:
155
  return "Unknown"
156
 
@@ -158,16 +170,12 @@ def infer_modality_from_caption(caption: str) -> str:
158
 
159
  if any(w in text for w in ["ct scan", "ct of", "computed tomography"]):
160
  return "CT"
161
-
162
  if any(w in text for w in ["mri", "magnetic resonance"]):
163
  return "MRI"
164
-
165
  if any(w in text for w in ["x-ray", "x ray", "radiograph", "chest xray", "chest x-ray"]):
166
  return "X-ray"
167
-
168
  if any(w in text for w in ["ultrasound", "sonography", "sonogram"]):
169
  return "Ultrasound"
170
-
171
  if any(w in text for w in ["pet-ct", "pet ct", "pet scan", "positron emission tomography"]):
172
  return "PET/CT"
173
 
@@ -175,7 +183,6 @@ def infer_modality_from_caption(caption: str) -> str:
175
 
176
 
177
  # ---------- Routes ----------
178
-
179
  @app.get("/")
180
  def root():
181
  return {"status": "ok", "message": "Radiology retrieval + captioning API"}
@@ -187,25 +194,24 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
187
  Upload a radiology image.
188
 
189
  Returns:
190
- - query_caption: generated caption for the query image (BLIP)
191
- - modality: inferred imaging modality from the caption
192
- - results: list of similar images with vec_index, ID, concepts_manual, score, image_url
 
 
193
  """
194
  content = await file.read()
195
  image = Image.open(io.BytesIO(content)).convert("RGB")
196
 
197
- # 1) Retrieval
198
  results_df = search_similar_by_image(image, k=k)
199
  results = results_df.to_dict(orient="records")
200
 
201
- # 2) Captioning for the query image
202
  try:
203
- query_caption = generate_query_caption(image)
204
  except Exception as e:
205
- print("Error generating caption:", e)
206
- query_caption = None
207
 
208
- # 3) Infer modality
209
  modality = infer_modality_from_caption(query_caption or "")
210
 
211
  return JSONResponse(
 
1
  # app.py
2
  import io
3
  import os
4
+ import base64
5
 
6
  import faiss
7
  import torch
 
19
  # ---------- FastAPI app ----------
20
  app = FastAPI()
21
 
 
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"], # you can restrict later
 
28
  )
29
 
30
  # ---------- Config ----------
31
+ # FAISS index + radiology_metadata.csv
32
  EMBED_REPO_ID = "saad003/Red01"
33
 
34
+ # All radiology images, filenames like ROCOv2_2023_valid_000001.jpg
 
35
  IMAGE_REPO_ID = "saad003/images"
36
+ BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
37
 
 
38
  HF_TOKEN = os.environ.get("HF_TOKEN")
39
 
40
  # ---------- Download index + metadata ----------
 
60
  print("Loading metadata CSV...")
61
  metadata = pd.read_csv(META_PATH)
62
 
 
63
  required_cols = {"vec_index", "ID", "caption", "concepts_manual"}
64
  missing = required_cols - set(metadata.columns)
65
  if missing:
 
89
  print("Backend ready ✅")
90
 
91
 
92
+ # ---------- Helpers for images ----------
93
  def id_to_image_url(image_id: str) -> str:
94
+ """Public HF URL (optional, for debugging/click)."""
95
+ if not isinstance(image_id, str):
96
+ return None
97
+ filename = f"{image_id}.jpg"
98
+ return f"{BASE_IMAGE_URL}/{filename}"
99
+
100
+
101
+ def id_to_image_base64(image_id: str) -> str | None:
102
  """
103
+ Download the image from `saad003/images` (cached by hf_hub_download),
104
+ then return base64-encoded bytes so frontend can display directly.
 
105
  """
106
  if not isinstance(image_id, str):
107
  return None
108
+
109
  filename = f"{image_id}.jpg"
110
+ try:
111
+ local_path = hf_hub_download(
112
+ repo_id=IMAGE_REPO_ID,
113
+ filename=filename,
114
+ repo_type="dataset",
115
+ token=HF_TOKEN,
116
+ )
117
+ except Exception as e:
118
+ print(f"Error downloading image for ID={image_id}: {e}")
119
+ return None
120
+
121
+ try:
122
+ with open(local_path, "rb") as f:
123
+ data = f.read()
124
+ return base64.b64encode(data).decode("utf-8")
125
+ except Exception as e:
126
+ print(f"Error reading image file for ID={image_id}: {e}")
127
+ return None
128
 
129
 
130
+ # ---------- Retrieval ----------
131
  def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
132
  """
133
  Encode query image with CLIP, search FAISS, return top-k rows
134
+ with vec_index, ID, caption, concepts_manual, score, image_url, image_base64.
135
  """
 
136
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
137
  with torch.no_grad():
138
  feats = clip_model.get_image_features(**inputs)
139
 
 
140
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
141
  feats = feats.cpu().numpy().astype("float32")
142
 
143
+ D, I = index.search(feats, k)
 
144
 
 
145
  rows = metadata.iloc[I[0]].copy()
146
  rows["score"] = D[0]
147
 
 
148
  rows["image_url"] = rows["ID"].apply(id_to_image_url)
149
+ rows["image_base64"] = rows["ID"].apply(id_to_image_base64)
150
 
151
+ return rows[
152
+ ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url", "image_base64"]
153
+ ]
154
 
155
 
156
+ # ---------- Captioning ----------
157
  def generate_query_caption(image: Image.Image) -> str:
 
 
 
 
158
  inputs = caption_processor(images=image, return_tensors="pt").to(device)
159
  with torch.no_grad():
160
  out = caption_model.generate(**inputs, max_new_tokens=64)
 
162
  return caption.strip()
163
 
164
 
 
165
  def infer_modality_from_caption(caption: str) -> str:
 
 
 
166
  if not caption:
167
  return "Unknown"
168
 
 
170
 
171
  if any(w in text for w in ["ct scan", "ct of", "computed tomography"]):
172
  return "CT"
 
173
  if any(w in text for w in ["mri", "magnetic resonance"]):
174
  return "MRI"
 
175
  if any(w in text for w in ["x-ray", "x ray", "radiograph", "chest xray", "chest x-ray"]):
176
  return "X-ray"
 
177
  if any(w in text for w in ["ultrasound", "sonography", "sonogram"]):
178
  return "Ultrasound"
 
179
  if any(w in text for w in ["pet-ct", "pet ct", "pet scan", "positron emission tomography"]):
180
  return "PET/CT"
181
 
 
183
 
184
 
185
  # ---------- Routes ----------
 
186
  @app.get("/")
187
  def root():
188
  return {"status": "ok", "message": "Radiology retrieval + captioning API"}
 
194
  Upload a radiology image.
195
 
196
  Returns:
197
+ - query_caption: BLIP caption for query image
198
+ - modality: inferred imaging modality
199
+ - results: list of similar images with
200
+ vec_index, ID, concepts_manual, score,
201
+ image_url, image_base64
202
  """
203
  content = await file.read()
204
  image = Image.open(io.BytesIO(content)).convert("RGB")
205
 
 
206
  results_df = search_similar_by_image(image, k=k)
207
  results = results_df.to_dict(orient="records")
208
 
 
209
  try:
210
+ query_caption = generate_query_caption(image)
211
  except Exception as e:
212
+ print("Error generating caption:", e)
213
+ query_caption = None
214
 
 
215
  modality = infer_modality_from_caption(query_caption or "")
216
 
217
  return JSONResponse(