ABAO77 commited on
Commit
dfde318
·
verified ·
1 Parent(s): 4a36574

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -14,6 +14,9 @@ app = FastAPI(docs_url="/")
14
  index = faiss.read_index("./model/db_vit_b_16.index")
15
  feature_extractor = FeatureExtractor(base_model="vit_b_16")
16
 
 
 
 
17
 
18
 
19
  # Helper function to load image from uploaded file
@@ -21,6 +24,10 @@ def load_image(uploaded_file):
21
  image = Image.open(BytesIO(uploaded_file)).convert("RGB")
22
  return image
23
 
 
 
 
 
24
  @app.post("/search-image/")
25
  async def search_image(file: UploadFile = File(...)):
26
  try:
@@ -34,20 +41,19 @@ async def search_image(file: UploadFile = File(...)):
34
  output = output / output.norm(p=2, dim=1, keepdim=True)
35
 
36
  # Perform FAISS search for the top 1 similar image
37
- D, I = index.search(output.cpu().numpy(), 1) # Changed from 5 to 1
38
- print("I", I)
39
- print("D", D)
40
- # Load the list of image filenames (assuming you have image_list)
41
- image_list = sorted(os.listdir(os.path.join(DATA_DIR, "images")))
42
 
43
- # Get the path of the most similar image
44
- similar_image_path = os.path.join(DATA_DIR, "images", image_list[int(I[0][0])])
45
 
46
- # Return the image file itself
47
- return FileResponse(similar_image_path, media_type="image/jpeg")
 
48
  except Exception as e:
49
  return JSONResponse(content={"error": str(e)}, status_code=500)
50
 
 
51
  if __name__ == "__main__":
52
  import uvicorn
 
53
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
14
  index = faiss.read_index("./model/db_vit_b_16.index")
15
  feature_extractor = FeatureExtractor(base_model="vit_b_16")
16
 
17
+ # Use MPS (Apple Silicon) if available, otherwise fallback to CPU
18
+ if torch.backends.mps.is_built():
19
+ torch.set_default_device("mps")
20
 
21
 
22
  # Helper function to load image from uploaded file
 
24
  image = Image.open(BytesIO(uploaded_file)).convert("RGB")
25
  return image
26
 
27
+
28
+ image_dir = "./data/images"
29
+
30
+
31
  @app.post("/search-image/")
32
  async def search_image(file: UploadFile = File(...)):
33
  try:
 
41
  output = output / output.norm(p=2, dim=1, keepdim=True)
42
 
43
  # Perform FAISS search for the top 1 similar image
44
+ D, I = index.search(output.cpu().numpy(), 1)
45
+ print(D, I)
 
 
 
46
 
47
+ image_list = sorted(os.listdir(image_dir))
 
48
 
49
+ data_dir = f"{image_dir}/{image_list[int(I[0][0])]}"
50
+
51
+ return FileResponse(data_dir, media_type="image/jpeg")
52
  except Exception as e:
53
  return JSONResponse(content={"error": str(e)}, status_code=500)
54
 
55
+
56
  if __name__ == "__main__":
57
  import uvicorn
58
+
59
  uvicorn.run(app, host="0.0.0.0", port=8000)