ABAO77 commited on
Commit
712da4d
·
verified ·
1 Parent(s): 95c2d79

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -6
app.py CHANGED
@@ -10,6 +10,7 @@ from src.modules import FeatureExtractor
10
  from src.config import DATA_DIR
11
  from fastapi.middleware.cors import CORSMiddleware
12
  import zipfile
 
13
  app = FastAPI(docs_url="/")
14
  origins = ["*"]
15
 
@@ -45,33 +46,46 @@ def image_to_base64(image: Image.Image) -> str:
45
  buffered = BytesIO()
46
  image.save(buffered, format="JPEG")
47
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
 
 
 
 
 
 
 
48
  def unzip_folder(zip_file_path, extract_to_path):
49
  # Check if the zip file exists
50
  if not os.path.exists(zip_file_path):
51
  raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
52
-
53
  # Unzip the folder
54
- with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
55
  zip_ref.extractall(extract_to_path)
56
  print(f"Extracted all files to: {extract_to_path}")
57
 
 
58
  # Example usage
59
- zip_file = './images.zip' # Replace with your zip file path
60
- extract_path = './data' # Replace with the directory you want to extract to
61
  unzip_folder(zip_file, extract_path)
62
 
63
 
64
  image_dir = "./data/images"
65
  from pydantic import BaseModel, Field
 
 
66
  class ImageSearchBody(BaseModel):
67
  base64_image: str = Field(..., title="Base64 Image String")
68
 
 
69
  @app.post("/search-image/")
70
  async def search_image(body: ImageSearchBody):
71
  try:
72
  # Convert the Base64 string to an image
73
  image = base64_to_image(body.base64_image)
74
-
75
  # Extract features from the image
76
  with torch.no_grad():
77
  output = feature_extractor.extract_features(image)
@@ -85,13 +99,23 @@ async def search_image(body: ImageSearchBody):
85
  image_list = sorted(os.listdir(image_dir))
86
 
87
  # Load the matched image from the directory
 
 
 
88
  matched_image_path = f"{image_dir}/{image_list[int(I[0][0])]}"
89
  matched_image = Image.open(matched_image_path)
90
 
91
  # Convert the matched image to Base64 string
92
  matched_image_base64 = image_to_base64(matched_image)
93
 
94
- return JSONResponse(content={"image_base64": matched_image_base64})
 
 
 
 
 
 
 
95
 
96
  except Exception as e:
97
  return JSONResponse(content={"error": str(e)}, status_code=500)
 
10
  from src.config import DATA_DIR
11
  from fastapi.middleware.cors import CORSMiddleware
12
  import zipfile
13
+
14
  app = FastAPI(docs_url="/")
15
  origins = ["*"]
16
 
 
46
  buffered = BytesIO()
47
  image.save(buffered, format="JPEG")
48
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
49
+
50
+
51
+ # Helper function to convert PIL image to base64 string
52
+ def image_to_base64(image: Image.Image) -> str:
53
+ buffered = BytesIO()
54
+ image.save(buffered, format="JPEG")
55
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
56
+
57
+
58
  def unzip_folder(zip_file_path, extract_to_path):
59
  # Check if the zip file exists
60
  if not os.path.exists(zip_file_path):
61
  raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
62
+
63
  # Unzip the folder
64
+ with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
65
  zip_ref.extractall(extract_to_path)
66
  print(f"Extracted all files to: {extract_to_path}")
67
 
68
+
69
  # Example usage
70
+ zip_file = "./images.zip" # Replace with your zip file path
71
+ extract_path = "./data" # Replace with the directory you want to extract to
72
  unzip_folder(zip_file, extract_path)
73
 
74
 
75
  image_dir = "./data/images"
76
  from pydantic import BaseModel, Field
77
+
78
+
79
  class ImageSearchBody(BaseModel):
80
  base64_image: str = Field(..., title="Base64 Image String")
81
 
82
+
83
  @app.post("/search-image/")
84
  async def search_image(body: ImageSearchBody):
85
  try:
86
  # Convert the Base64 string to an image
87
  image = base64_to_image(body.base64_image)
88
+
89
  # Extract features from the image
90
  with torch.no_grad():
91
  output = feature_extractor.extract_features(image)
 
99
  image_list = sorted(os.listdir(image_dir))
100
 
101
  # Load the matched image from the directory
102
+
103
+ image_list = sorted(os.listdir(image_dir))
104
+ image_name = image_list[int(I[0][0])]
105
  matched_image_path = f"{image_dir}/{image_list[int(I[0][0])]}"
106
  matched_image = Image.open(matched_image_path)
107
 
108
  # Convert the matched image to Base64 string
109
  matched_image_base64 = image_to_base64(matched_image)
110
 
111
+ return JSONResponse(
112
+ content={
113
+ "image_base64": matched_image_base64,
114
+ "image_name": image_name,
115
+ "similarity_score": float(D[0][0]),
116
+ },
117
+ status_code=200,
118
+ )
119
 
120
  except Exception as e:
121
  return JSONResponse(content={"error": str(e)}, status_code=500)