saad003 committed on
Commit
aaf4ae5
·
verified ·
1 Parent(s): 928975f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -15
app.py CHANGED
@@ -18,10 +18,10 @@ from transformers import BlipForConditionalGeneration, AutoProcessor
18
  # ---------- FastAPI app ----------
19
  app = FastAPI()
20
 
21
- # CORS so your React app can call this API
22
  app.add_middleware(
23
  CORSMiddleware,
24
- allow_origins=["*"], # later you can restrict to your frontend domain
25
  allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
@@ -31,11 +31,11 @@ app.add_middleware(
31
  # Dataset with FAISS index + radiology_metadata.csv
32
  EMBED_REPO_ID = "saad003/Red01"
33
 
34
- # Dataset with all radiology images you uploaded
35
  IMAGE_REPO_ID = "saad003/images02"
36
  BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
37
 
38
- # Optional: token if Red01 is private
39
  HF_TOKEN = os.environ.get("HF_TOKEN")
40
 
41
  # ---------- Download index + metadata ----------
@@ -61,10 +61,11 @@ index = faiss.read_index(INDEX_PATH)
61
  print("Loading metadata CSV...")
62
  metadata = pd.read_csv(META_PATH)
63
 
64
- # Make sure the index and metadata have same length
65
  assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
66
 
67
  # ---------- Load CLIP (retrieval) ----------
 
68
  print("Loading PubMedCLIP model for retrieval...")
69
  CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
70
 
@@ -86,13 +87,33 @@ caption_model.eval()
86
  print("Backend ready ✅")
87
 
88
 
89
- # ---------- Helper: build image URL ----------
90
- def id_to_image_url(image_id: str) -> str:
91
  """
92
- Build a public URL to the image in saad003/images02.
93
- Assumes filenames are exactly f\"{image_id}.jpg\".
 
 
 
 
 
 
 
 
94
  """
95
- return f"{BASE_IMAGE_URL}/{image_id}.jpg"
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  # ---------- Helper: search by image ----------
@@ -106,19 +127,19 @@ def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
106
  with torch.no_grad():
107
  feats = clip_model.get_image_features(**inputs)
108
 
109
- # Normalize (very important, matches how you created the index)
110
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
111
  feats = feats.cpu().numpy().astype("float32")
112
 
113
- # Search in FAISS
114
  D, I = index.search(feats, k) # D: distances/similarity, I: indices
115
 
116
  # Get metadata rows for top-k indices
117
  rows = metadata.iloc[I[0]].copy()
118
  rows["score"] = D[0]
119
 
120
- # Add image_url for each result
121
- rows["image_url"] = rows["ID"].apply(id_to_image_url)
122
 
123
  return rows[["ID", "split", "caption", "concepts_manual", "score", "image_url"]]
124
 
@@ -136,6 +157,39 @@ def generate_query_caption(image: Image.Image) -> str:
136
  return caption.strip()
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # ---------- Routes ----------
140
 
141
  @app.get("/")
@@ -147,8 +201,10 @@ def root():
147
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
148
  """
149
  Upload a radiology image.
 
150
  Returns:
151
- - query_caption: generated caption for the query image
 
152
  - results: list of similar images with their captions, concepts, score, image_url
153
  """
154
  content = await file.read()
@@ -165,9 +221,13 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
165
  print("Error generating caption:", e)
166
  query_caption = None
167
 
 
 
 
168
  return JSONResponse(
169
  {
170
  "query_caption": query_caption,
 
171
  "results": results,
172
  }
173
  )
 
18
  # ---------- FastAPI app ----------
19
  app = FastAPI()
20
 
21
+ # Allow your React app to call this API
22
  app.add_middleware(
23
  CORSMiddleware,
24
+ allow_origins=["*"], # You can later restrict to your domain
25
  allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
 
31
  # Dataset with FAISS index + radiology_metadata.csv
32
  EMBED_REPO_ID = "saad003/Red01"
33
 
34
+ # Dataset with all radiology images (you uploaded here)
35
  IMAGE_REPO_ID = "saad003/images02"
36
  BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
37
 
38
+ # Optional: token if Red01 is private (set HF_TOKEN secret in Space)
39
  HF_TOKEN = os.environ.get("HF_TOKEN")
40
 
41
  # ---------- Download index + metadata ----------
 
61
  print("Loading metadata CSV...")
62
  metadata = pd.read_csv(META_PATH)
63
 
64
+ # Sanity check
65
  assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
66
 
67
  # ---------- Load CLIP (retrieval) ----------
68
+ # IMPORTANT: must match the model you used to build the index.
69
  print("Loading PubMedCLIP model for retrieval...")
70
  CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
71
 
 
87
  print("Backend ready ✅")
88
 
89
 
90
# ---------- Helper: build image URL from img_path ----------
def img_path_to_image_url(img_path: str) -> "str | None":
    """
    Map the original Kaggle img_path from the metadata CSV to a public
    Hugging Face dataset URL.

    Example img_path in the CSV:
        /kaggle/input/radiology/8333645/train_images/train/ROCOv2_2023_train_000001.jpg

    The images were uploaded to the HF dataset preserving the folder layout
    after the Kaggle dataset root ('8333645/'), so the relative path after
    that marker is appended to BASE_IMAGE_URL:
        https://huggingface.co/datasets/saad003/images02/resolve/main/train_images/train/ROCOv2_2023_train_000001.jpg

    Returns None when img_path is not a string (e.g. NaN for a missing CSV
    cell) — hence the "str | None" return type, which the original
    annotation (-> str) misstated.
    """
    if not isinstance(img_path, str):
        # Pandas yields float('nan') for missing CSV cells — surface as None
        # so the JSON response carries null instead of a bogus URL.
        return None

    # Cut everything up to (and including) the Kaggle dataset root folder.
    marker = "8333645/"
    if marker in img_path:
        rel = img_path.split(marker, 1)[1]
    else:
        # Fallback for unexpected layouts: keep just the filename.
        rel = os.path.basename(img_path)

    rel = rel.lstrip("/")  # safety: avoid '//' in the final URL
    return f"{BASE_IMAGE_URL}/{rel}"
117
 
118
 
119
  # ---------- Helper: search by image ----------
 
127
  with torch.no_grad():
128
  feats = clip_model.get_image_features(**inputs)
129
 
130
+ # Normalize (very important, must match index construction)
131
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
132
  feats = feats.cpu().numpy().astype("float32")
133
 
134
+ # Search FAISS
135
  D, I = index.search(feats, k) # D: distances/similarity, I: indices
136
 
137
  # Get metadata rows for top-k indices
138
  rows = metadata.iloc[I[0]].copy()
139
  rows["score"] = D[0]
140
 
141
+ # Add image_url using original img_path column
142
+ rows["image_url"] = rows["img_path"].apply(img_path_to_image_url)
143
 
144
  return rows[["ID", "split", "caption", "concepts_manual", "score", "image_url"]]
145
 
 
157
  return caption.strip()
158
 
159
 
160
# ---------- Helper: infer modality from caption ----------
def infer_modality_from_caption(caption: str) -> str:
    """
    Heuristically map a free-text caption to an imaging modality.

    Rules are checked in order and the first matching keyword wins; an
    empty/None caption, or one matching no rule, yields "Unknown".

    Returns one of: "CT", "MRI", "X-ray", "Ultrasound", "PET/CT", "Unknown".
    """
    if not caption:
        return "Unknown"

    text = caption.lower()

    # Ordered (modality, keywords) rules; first hit wins. Subsumed keywords
    # from the original chain were dropped without changing behavior:
    # any text containing "ct of the" already contains "ct of", and any
    # text containing "chest x-ray" already contains "x-ray".
    rules = [
        ("CT", ("ct scan", "computed tomography", "ct of", "ct image")),
        ("MRI", ("mri", "magnetic resonance")),
        ("X-ray", ("x-ray", "x ray", "radiograph", "chest xray")),
        ("Ultrasound", ("ultrasound", "sonography", "sonogram")),
        ("PET/CT", ("pet-ct", "pet ct", "pet scan", "positron emission tomography")),
    ]
    for modality, keywords in rules:
        if any(word in text for word in keywords):
            return modality

    return "Unknown"
191
+
192
+
193
  # ---------- Routes ----------
194
 
195
  @app.get("/")
 
201
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
202
  """
203
  Upload a radiology image.
204
+
205
  Returns:
206
+ - query_caption: generated caption for the query image (BLIP)
207
+ - modality: inferred imaging modality from the caption
208
  - results: list of similar images with their captions, concepts, score, image_url
209
  """
210
  content = await file.read()
 
221
  print("Error generating caption:", e)
222
  query_caption = None
223
 
224
+ # 3) Infer modality
225
+ modality = infer_modality_from_caption(query_caption or "")
226
+
227
  return JSONResponse(
228
  {
229
  "query_caption": query_caption,
230
+ "modality": modality,
231
  "results": results,
232
  }
233
  )