saad003 committed on
Commit
bdbadcd
·
verified ·
1 Parent(s): e73f631

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import faiss
4
+ import torch
5
+ import pandas as pd
6
+
7
+ from PIL import Image
8
+ from fastapi import FastAPI, File, UploadFile
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse
11
+
12
+ from huggingface_hub import hf_hub_download
13
+ from transformers import CLIPProcessor, CLIPModel
14
+
15
+ # ---------- FastAPI app ----------
16
app = FastAPI()

# Allow the React frontend to call this API from any origin.
#
# Credentials are disabled on purpose: the CORS rules do not permit combining
# a wildcard origin with credentialed requests, and this API does not rely on
# cookies or HTTP auth. If credentialed requests are ever needed, replace the
# wildcard with the explicit frontend origin(s) before re-enabling them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # later you can restrict to your frontend domain
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
26
+
27
+ # ---------- Load index + metadata + model at startup ----------
28
+
29
REPO_ID = "saad003/Red01"  # your dataset repo


def _fetch_dataset_file(filename):
    """Download *filename* from the ``REPO_ID`` dataset repo and return its path."""
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        repo_type="dataset",
    )


# Everything below runs once, at import time, so the first request does not
# pay for downloads or model loading.
print("Downloading FAISS index & metadata...")
INDEX_PATH = _fetch_dataset_file("radiology_index.faiss")
META_PATH = _fetch_dataset_file("radiology_metadata.csv")

print("Loading FAISS index...")
index = faiss.read_index(INDEX_PATH)

print("Loading metadata CSV...")
metadata = pd.read_csv(META_PATH)

print("Loading CLIP model...")
MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

clip_model = CLIPModel.from_pretrained(MODEL_NAME)
clip_model = clip_model.to(device)
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
clip_model.eval()  # inference only

print("Backend ready ✅")
58
+
59
+
60
+ # ---------- Helper: search by image ----------
61
+
62
def _search_similar_by_image(image: Image.Image, k: int = 5):
    """Return metadata rows for the indexed images most similar to *image*.

    The image is embedded with the module-level CLIP model, the embedding is
    L2-normalised, and the module-level FAISS ``index`` is queried with it.

    Args:
        image: Query image.
        k: Maximum number of neighbours to return. Must be at least 1.

    Returns:
        A DataFrame with the columns ``ID``, ``split``, ``img_path``,
        ``caption``, ``concepts_manual`` and ``score`` (the raw value FAISS
        reports for each hit). It may hold fewer than ``k`` rows when the
        index cannot supply that many neighbours.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    columns = ["ID", "split", "img_path", "caption", "concepts_manual", "score"]

    # Never ask FAISS for more neighbours than the index holds.
    limit = min(k, index.ntotal)
    if limit == 0:
        return metadata.iloc[0:0].assign(score=[])[columns]

    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        feats = clip_model.get_image_features(**inputs)

    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    feats = feats.cpu().numpy().astype("float32")

    distances, ids = index.search(feats, limit)
    hit_ids, hit_scores = ids[0], distances[0]

    # FAISS pads missing neighbours with -1. Without this filter,
    # ``iloc[-1]`` would silently return the last metadata row as a fake hit.
    found = hit_ids >= 0

    rows = metadata.iloc[hit_ids[found]].copy()
    rows["score"] = hit_scores[found]
    # Only send useful columns
    return rows[columns]
76
+
77
+
78
+ # ---------- API endpoint ----------
79
+
80
@app.post("/search_by_image")
async def search_by_image(file: UploadFile = File(...), k: int = 5):
    """Find the indexed images most similar to an uploaded image.

    Args:
        file: Uploaded image file (multipart form field ``file``).
        k: Maximum number of results to return. Must be at least 1.

    Returns:
        A JSON response of the form ``{"results": [...]}`` with one record
        per hit, as produced by ``_search_similar_by_image``.

    Raises:
        HTTPException: 422 if ``k`` is smaller than 1; 400 if the upload
            cannot be decoded as an image.
    """
    # Imported locally so this handler does not depend on the file-level
    # import list being extended.
    from fastapi import HTTPException

    if k < 1:
        raise HTTPException(status_code=422, detail="k must be a positive integer")

    # read image from request
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (OSError, ValueError) as exc:
        # PIL raises UnidentifiedImageError (an OSError) for non-image data
        # and OSError for truncated files: that is a client error, not a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image",
        ) from exc

    results_df = _search_similar_by_image(image, k=k)

    # Empty CSV cells load as NaN, which JSONResponse refuses to serialise
    # (NaN is not valid JSON), so send them as null instead.
    results = [
        {key: (None if pd.isna(value) else value) for key, value in record.items()}
        for record in results_df.to_dict(orient="records")
    ]

    return JSONResponse({"results": results})
90
+
91
+
92
@app.get("/")
def root():
    """Report that the service is up (simple liveness endpoint)."""
    payload = {"status": "ok", "message": "Radiology retrieval API"}
    return payload