ABAO77 committed on
Commit
351bcee
·
verified ·
1 Parent(s): d834c9f

Upload 22 files

Browse files
Files changed (4) hide show
  1. app.py +16 -30
  2. images.zip +2 -2
  3. model/db_vit_b_16.index +1 -1
  4. src/build_vector_database.py +13 -18
app.py CHANGED
@@ -3,13 +3,13 @@ import torch
3
  import faiss
4
  import base64
5
  from PIL import Image
6
- from fastapi import FastAPI, UploadFile, File, HTTPException
7
  from fastapi.responses import JSONResponse
8
  from io import BytesIO
9
  from src.modules import FeatureExtractor
10
- from src.config import DATA_DIR
11
  from fastapi.middleware.cors import CORSMiddleware
12
  import zipfile
 
13
 
14
  app = FastAPI(docs_url="/")
15
  origins = ["*"]
@@ -22,16 +22,13 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Load FAISS index and feature extractor
26
  index = faiss.read_index("./model/db_vit_b_16.index")
27
  feature_extractor = FeatureExtractor(base_model="vit_b_16")
28
 
29
- # Use MPS (Apple Silicon) if available, otherwise fallback to CPU
30
  if torch.backends.mps.is_built():
31
  torch.set_default_device("mps")
32
 
33
 
34
- # Helper function to convert base64 string to PIL image
35
  def base64_to_image(base64_str: str) -> Image.Image:
36
  try:
37
  image_data = base64.b64decode(base64_str)
@@ -41,14 +38,12 @@ def base64_to_image(base64_str: str) -> Image.Image:
41
  raise HTTPException(status_code=400, detail="Invalid Base64 image")
42
 
43
 
44
- # Helper function to convert PIL image to base64 string
45
  def image_to_base64(image: Image.Image) -> str:
46
  buffered = BytesIO()
47
  image.save(buffered, format="JPEG")
48
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
49
 
50
 
51
- # Helper function to convert PIL image to base64 string
52
  def image_to_base64(image: Image.Image) -> str:
53
  buffered = BytesIO()
54
  image.save(buffered, format="JPEG")
@@ -56,24 +51,26 @@ def image_to_base64(image: Image.Image) -> str:
56
 
57
 
58
  def unzip_folder(zip_file_path, extract_to_path):
59
- # Check if the zip file exists
60
  if not os.path.exists(zip_file_path):
61
  raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
62
-
63
- # Unzip the folder
64
  with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
65
- zip_ref.extractall(extract_to_path)
 
 
 
 
 
66
  print(f"Extracted all files to: {extract_to_path}")
67
 
68
 
69
- # Example usage
70
- zip_file = "./images.zip" # Replace with your zip file path
71
- extract_path = "./data" # Replace with the directory you want to extract to
72
  unzip_folder(zip_file, extract_path)
73
 
74
 
75
- image_dir = "./data/images"
76
- from pydantic import BaseModel, Field
 
77
 
78
 
79
  class ImageSearchBody(BaseModel):
@@ -83,29 +80,18 @@ class ImageSearchBody(BaseModel):
83
  @app.post("/search-image/")
84
  async def search_image(body: ImageSearchBody):
85
  try:
86
- # Convert the Base64 string to an image
87
  image = base64_to_image(body.base64_image)
88
-
89
- # Extract features from the image
90
  with torch.no_grad():
91
  output = feature_extractor.extract_features(image)
92
  output = output.view(output.size(0), -1)
93
  output = output / output.norm(p=2, dim=1, keepdim=True)
94
-
95
- # Perform FAISS search for the top 1 similar image
96
  D, I = index.search(output.cpu().numpy(), 1)
97
  print(D, I)
98
-
99
- image_list = sorted(os.listdir(image_dir))
100
-
101
- # Load the matched image from the directory
102
-
103
- image_list = sorted(os.listdir(image_dir))
104
  image_name = image_list[int(I[0][0])]
105
- matched_image_path = f"{image_dir}/{image_list[int(I[0][0])]}"
106
  matched_image = Image.open(matched_image_path)
107
-
108
- # Convert the matched image to Base64 string
109
  matched_image_base64 = image_to_base64(matched_image)
110
 
111
  return JSONResponse(
 
3
  import faiss
4
  import base64
5
  from PIL import Image
6
+ from fastapi import FastAPI, HTTPException
7
  from fastapi.responses import JSONResponse
8
  from io import BytesIO
9
  from src.modules import FeatureExtractor
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  import zipfile
12
+ from pydantic import BaseModel, Field
13
 
14
  app = FastAPI(docs_url="/")
15
  origins = ["*"]
 
22
  allow_headers=["*"],
23
  )
24
 
 
25
  index = faiss.read_index("./model/db_vit_b_16.index")
26
  feature_extractor = FeatureExtractor(base_model="vit_b_16")
27
 
 
28
  if torch.backends.mps.is_built():
29
  torch.set_default_device("mps")
30
 
31
 
 
32
  def base64_to_image(base64_str: str) -> Image.Image:
33
  try:
34
  image_data = base64.b64decode(base64_str)
 
38
  raise HTTPException(status_code=400, detail="Invalid Base64 image")
39
 
40
 
 
41
  def image_to_base64(image: Image.Image) -> str:
42
  buffered = BytesIO()
43
  image.save(buffered, format="JPEG")
44
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
45
 
46
 
 
47
  def image_to_base64(image: Image.Image) -> str:
48
  buffered = BytesIO()
49
  image.save(buffered, format="JPEG")
 
51
 
52
 
53
  def unzip_folder(zip_file_path, extract_to_path):
 
54
  if not os.path.exists(zip_file_path):
55
  raise FileNotFoundError(f"Zip file not found: {zip_file_path}")
 
 
56
  with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
57
+ for member in zip_ref.infolist():
58
+ filename = member.filename.encode("cp437").decode("utf-8")
59
+ extracted_path = os.path.join(extract_to_path, filename)
60
+ os.makedirs(os.path.dirname(extracted_path), exist_ok=True)
61
+ with zip_ref.open(member) as source, open(extracted_path, "wb") as target:
62
+ target.write(source.read())
63
  print(f"Extracted all files to: {extract_to_path}")
64
 
65
 
66
+ zip_file = "./images.zip"
67
+ extract_path = "./data"
 
68
  unzip_folder(zip_file, extract_path)
69
 
70
 
71
def is_image_file(filename):
    """Tell whether ``filename`` ends with a known image extension.

    The comparison is case-insensitive. The extension list determines which
    files of the extracted folder take part in the sorted listing that FAISS
    result ids are looked up in.
    """
    lowered = filename.lower()
    for extension in (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp"):
        if lowered.endswith(extension):
            return True
    return False
74
 
75
 
76
  class ImageSearchBody(BaseModel):
 
80
  @app.post("/search-image/")
81
  async def search_image(body: ImageSearchBody):
82
  try:
 
83
  image = base64_to_image(body.base64_image)
 
 
84
  with torch.no_grad():
85
  output = feature_extractor.extract_features(image)
86
  output = output.view(output.size(0), -1)
87
  output = output / output.norm(p=2, dim=1, keepdim=True)
 
 
88
  D, I = index.search(output.cpu().numpy(), 1)
89
  print(D, I)
90
+ image_list = sorted([f for f in os.listdir(extract_path) if is_image_file(f)])
91
+ print(image_list)
 
 
 
 
92
  image_name = image_list[int(I[0][0])]
93
+ matched_image_path = f"{extract_path}/{image_list[int(I[0][0])]}"
94
  matched_image = Image.open(matched_image_path)
 
 
95
  matched_image_base64 = image_to_base64(matched_image)
96
 
97
  return JSONResponse(
images.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0c65c754c9eb4694987f102d2fb9d1b957bfd4dcf44d5e1dbfb4b2e40e590fee
3
- size 29598354
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb42184dd971df46852b4c7b7ae6b5a2891abdc4a39006e83923245ae7b5e66b
3
+ size 29594676
model/db_vit_b_16.index CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e26e2231dfeee665f11a89639ef30e94cf47780aa44460353529a15c8f3691b4
3
  size 276525
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ca38802b326da27ea0d3fd79c78672da86dd82b61c307d51201092cfaf0c107
3
  size 276525
src/build_vector_database.py CHANGED
@@ -1,36 +1,32 @@
1
- # Description:
2
- # This script is used to build the vector database for the images in the dataset.
3
- # The script uses the FeatureExtractor class to extract the features from the images and saves them to a Faiss index.
4
- #
5
- # Usage:
6
- # To use this script, you can run the following commands:
7
- # python3 build_vector_database.py
8
- # python3 build_vector_database.py --feat_extractor vit_l_32
9
- # python3 build_vector_database.py --feat_extractor resnet101
10
- #
11
  import faulthandler
 
12
  faulthandler.enable()
13
 
14
  import torch
15
  from tqdm import tqdm
16
  import argparse
17
  import faiss
18
- import torch
19
  import PIL
20
  import os
21
 
22
  from modules import FeatureExtractor
23
  from config import *
24
 
25
- images_dir = "../webp_images/images"
26
- data_dir = "../webp_images"
 
 
 
 
 
27
  def main(args=None):
28
  # initialize the feature extractor with the base model specified in the arguments
29
  feature_extractor = FeatureExtractor(base_model=args.feat_extractor)
30
  # initialize the vector database indexing
31
  index = faiss.IndexFlatIP(feature_extractor.feat_dims)
32
- # get the list of images in sorted order
33
- image_list = sorted(os.listdir(images_dir))
 
34
 
35
  with torch.no_grad():
36
  # iterate over the images and add their extracted features to the index
@@ -47,10 +43,9 @@ def main(args=None):
47
  index.add(output.numpy())
48
 
49
  # save the index
50
- index_filepath = os.path.join(data_dir, f"db_{args.feat_extractor}.index")
51
  faiss.write_index(index, index_filepath)
52
 
53
-
54
  if __name__ == "__main__":
55
  # parse arguments
56
  args = argparse.ArgumentParser()
@@ -63,4 +58,4 @@ if __name__ == "__main__":
63
  args = args.parse_args()
64
 
65
  # run the main function
66
- main(args)
 
 
 
 
 
 
 
 
 
 
 
1
  import faulthandler
2
+
3
  faulthandler.enable()
4
 
5
  import torch
6
  from tqdm import tqdm
7
  import argparse
8
  import faiss
 
9
  import PIL
10
  import os
11
 
12
  from modules import FeatureExtractor
13
  from config import *
14
 
15
+ images_dir = "../data"
16
+ model_dir = "../model"
17
+
18
def is_image_file(filename):
    """Tell whether ``filename`` ends with a known image extension.

    The comparison is case-insensitive. Only files accepted here are listed,
    sorted and added to the index by this script.
    """
    lowered = filename.lower()
    for extension in ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp'):
        if lowered.endswith(extension):
            return True
    return False
21
+
22
  def main(args=None):
23
  # initialize the feature extractor with the base model specified in the arguments
24
  feature_extractor = FeatureExtractor(base_model=args.feat_extractor)
25
  # initialize the vector database indexing
26
  index = faiss.IndexFlatIP(feature_extractor.feat_dims)
27
+ # get the list of images in sorted order and filter out non-image files
28
+ image_list = sorted([f for f in os.listdir(images_dir) if is_image_file(f)])
29
+ # print(image_list)
30
 
31
  with torch.no_grad():
32
  # iterate over the images and add their extracted features to the index
 
43
  index.add(output.numpy())
44
 
45
  # save the index
46
+ index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
47
  faiss.write_index(index, index_filepath)
48
 
 
49
  if __name__ == "__main__":
50
  # parse arguments
51
  args = argparse.ArgumentParser()
 
58
  args = args.parse_args()
59
 
60
  # run the main function
61
+ main(args)