saad003 committed on
Commit
cd5db07
·
verified ·
1 Parent(s): 1ea22f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -16
app.py CHANGED
@@ -1,5 +1,7 @@
1
  # app.py
2
  import io
 
 
3
  import faiss
4
  import torch
5
  import pandas as pd
@@ -24,20 +26,34 @@ app.add_middleware(
24
  allow_headers=["*"],
25
  )
26
 
27
- # ---------- Load index + metadata + model at startup ----------
 
 
 
 
 
 
 
28
 
29
- REPO_ID = "saad003/Red01" # your dataset repo
 
 
 
30
 
 
31
  print("Downloading FAISS index & metadata...")
32
  INDEX_PATH = hf_hub_download(
33
- repo_id=REPO_ID,
34
  filename="radiology_index.faiss",
35
  repo_type="dataset",
 
36
  )
 
37
  META_PATH = hf_hub_download(
38
- repo_id=REPO_ID,
39
  filename="radiology_metadata.csv",
40
  repo_type="dataset",
 
41
  )
42
 
43
  print("Loading FAISS index...")
@@ -46,9 +62,11 @@ index = faiss.read_index(INDEX_PATH)
46
  print("Loading metadata CSV...")
47
  metadata = pd.read_csv(META_PATH)
48
 
 
49
  print("Loading CLIP model...")
50
  MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
51
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
52
 
53
  clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
54
  clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
@@ -58,28 +76,44 @@ print("Backend ready ✅")
58
 
59
 
60
  # ---------- Helper: search by image ----------
61
-
62
- def _search_similar_by_image(image: Image.Image, k: int = 5):
 
 
 
63
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
64
  with torch.no_grad():
65
  feats = clip_model.get_image_features(**inputs)
66
 
 
67
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
68
  feats = feats.cpu().numpy().astype("float32")
69
 
70
- D, I = index.search(feats, k)
71
 
72
  rows = metadata.iloc[I[0]].copy()
73
  rows["score"] = D[0]
74
- # Only send useful columns
75
- return rows[["ID", "split", "img_path", "caption", "concepts_manual", "score"]]
76
 
 
 
 
 
 
77
 
78
- # ---------- API endpoint ----------
 
 
 
 
 
 
79
 
80
  @app.post("/search_by_image")
81
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
82
- # read image from request
 
 
 
83
  content = await file.read()
84
  image = Image.open(io.BytesIO(content)).convert("RGB")
85
 
@@ -87,8 +121,3 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
87
  results = results_df.to_dict(orient="records")
88
 
89
  return JSONResponse({"results": results})
90
-
91
-
92
- @app.get("/")
93
- def root():
94
- return {"status": "ok", "message": "Radiology retrieval API"}
 
1
  # app.py
2
  import io
3
+ import os
4
+
5
  import faiss
6
  import torch
7
  import pandas as pd
 
26
  allow_headers=["*"],
27
  )
28
 
29
+ # ---------- Config ----------
30
+ # Dataset with FAISS index + metadata
31
+ EMBED_REPO_ID = "saad003/Red01"
32
+ # Dataset with raw radiology images
33
+ IMAGE_REPO_ID = "saad003/images"
34
+
35
+ # Base URL for images (you uploaded ROCOv2_*.jpg directly in the root)
36
+ BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
37
 
38
+ # Optional: token if EMBED_REPO_ID is private (set HF_TOKEN secret in Space)
39
+ HF_TOKEN = os.environ.get("HF_TOKEN")
40
+ if HF_TOKEN is None:
41
+ print("⚠️ No HF_TOKEN env var found. If the dataset is private, this may fail.")
42
 
43
+ # ---------- Load index + metadata ----------
44
  print("Downloading FAISS index & metadata...")
45
  INDEX_PATH = hf_hub_download(
46
+ repo_id=EMBED_REPO_ID,
47
  filename="radiology_index.faiss",
48
  repo_type="dataset",
49
+ token=HF_TOKEN,
50
  )
51
+
52
  META_PATH = hf_hub_download(
53
+ repo_id=EMBED_REPO_ID,
54
  filename="radiology_metadata.csv",
55
  repo_type="dataset",
56
+ token=HF_TOKEN,
57
  )
58
 
59
  print("Loading FAISS index...")
 
62
  print("Loading metadata CSV...")
63
  metadata = pd.read_csv(META_PATH)
64
 
65
+ # ---------- Load CLIP model ----------
66
  print("Loading CLIP model...")
67
  MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
68
  device = "cuda" if torch.cuda.is_available() else "cpu"
69
+ print("Using device:", device)
70
 
71
  clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
72
  clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
 
76
 
77
 
78
  # ---------- Helper: search by image ----------
79
+ def _search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
80
+ """
81
+ Encode query image with CLIP, search FAISS, return top-k rows
82
+ with ID, split, caption, concepts, score, and image_url.
83
+ """
84
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
85
  with torch.no_grad():
86
  feats = clip_model.get_image_features(**inputs)
87
 
88
+ # L2-normalize to match index normalization
89
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
90
  feats = feats.cpu().numpy().astype("float32")
91
 
92
+ D, I = index.search(feats, k) # D: similarity scores, I: indices
93
 
94
  rows = metadata.iloc[I[0]].copy()
95
  rows["score"] = D[0]
 
 
96
 
97
+ # Build image URL for each retrieved ID
98
+ # Files are named like ROCOv2_2023_test_000001.jpg in saad003/images
99
+ rows["image_url"] = rows["ID"].apply(
100
+ lambda id_str: f"{BASE_IMAGE_URL}/{id_str}.jpg"
101
+ )
102
 
103
+ return rows[["ID", "split", "caption", "concepts_manual", "score", "image_url"]]
104
+
105
+
106
+ # ---------- Routes ----------
107
+ @app.get("/")
108
+ def root():
109
+ return {"status": "ok", "message": "Radiology retrieval API"}
110
 
111
  @app.post("/search_by_image")
112
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
113
+ """
114
+ Accepts an uploaded radiology image.
115
+ Returns top-k similar images (ID, caption, concepts, score, image_url).
116
+ """
117
  content = await file.read()
118
  image = Image.open(io.BytesIO(content)).convert("RGB")
119
 
 
121
  results = results_df.to_dict(orient="records")
122
 
123
  return JSONResponse({"results": results})