AdarshDRC commited on
Commit
eb23bfb
·
verified ·
1 Parent(s): 46f6ce1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +46 -41
main.py CHANGED
@@ -1,13 +1,12 @@
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from typing import List
4
- from PIL import Image
5
  import os
6
  import shutil
7
  import uuid
8
  import re
9
- import inflect # <-- NEW: Import inflect
10
-
11
  from src.models import AIModelManager
12
  from src.cloud_db import CloudDB
13
 
@@ -15,7 +14,7 @@ app = FastAPI()
15
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
- allow_origins=["*"],
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
@@ -23,82 +22,88 @@ app.add_middleware(
23
 
24
  print("Loading AI Models and Cloud DB...")
25
  ai = AIModelManager()
26
- db = CloudDB()
27
- # Initialize the inflect engine
28
  p = inflect.engine()
29
  print("Ready!")
30
 
31
  os.makedirs("temp_uploads", exist_ok=True)
32
 
33
- # --- NEW: Standardization Function ---
34
  def standardize_category_name(name: str) -> str:
35
- """Converts ' Cows ', 'COWS', or 'cow' all into 'cow'."""
36
- # 1. Lowercase and strip accidental edge spaces
37
  clean_name = name.strip().lower()
38
-
39
- # 2. Replace inner spaces with underscores (e.g., 'sports cars' -> 'sports_cars')
40
  clean_name = re.sub(r'\s+', '_', clean_name)
41
-
42
- # 3. Remove weird special characters just in case (keep only letters, numbers, underscores)
43
  clean_name = re.sub(r'[^\w\s]', '', clean_name)
44
-
45
- # 4. Convert plural to singular (if it's already singular, it returns False, so we keep the clean_name)
46
  singular_name = p.singular_noun(clean_name)
47
- if singular_name:
48
- return singular_name
49
-
 
 
50
  return clean_name
51
- # -------------------------------------
52
 
53
  @app.post("/api/upload")
54
  async def upload_new_images(files: List[UploadFile] = File(...), folder_name: str = Form(...)):
55
- """Handles bulk uploading of multiple images at once."""
56
  uploaded_urls = []
57
-
58
- # Clean the folder name before doing anything else!
59
  standardized_folder = standardize_category_name(folder_name)
60
 
61
  try:
62
  for file in files:
63
- temp_path = f"temp_uploads/{file.filename}"
 
64
  with open(temp_path, "wb") as buffer:
65
  shutil.copyfileobj(file.file, buffer)
66
 
67
- # Upload to Cloudinary using the perfectly clean folder name
68
  image_url = db.upload_image(temp_path, standardized_folder)
69
 
70
- img = Image.open(temp_path).convert('RGB')
71
- vector = ai.encode_image(img)
72
 
73
- image_id = str(uuid.uuid4())
74
- db.add_vector(vector, image_url, image_id)
 
75
 
76
  os.remove(temp_path)
77
  uploaded_urls.append(image_url)
78
 
79
- # Return the standardized name so the frontend knows what was actually saved
80
- return {
81
- "message": f"Successfully added {len(files)} images to category '{standardized_folder}'!",
82
- "urls": uploaded_urls
83
- }
84
-
85
  except Exception as e:
86
  raise HTTPException(status_code=500, detail=str(e))
87
 
88
- # ... (Your /api/search endpoint stays exactly the same) ...
89
  @app.post("/api/search")
90
  async def search_database(file: UploadFile = File(...)):
91
  try:
92
- temp_path = f"temp_uploads/query_{file.filename}"
 
 
93
  with open(temp_path, "wb") as buffer:
94
  shutil.copyfileobj(file.file, buffer)
95
 
96
- img = Image.open(temp_path).convert('RGB')
97
- vector = ai.encode_image(img)
98
 
99
- results = db.search(vector, top_k=10)
 
 
 
 
100
  os.remove(temp_path)
101
 
102
- return {"results": results}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  except Exception as e:
104
- raise HTTPException(status_code=500, detail=str(e))
 
 
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from typing import List
 
4
  import os
5
  import shutil
6
  import uuid
7
  import re
8
+ import inflect
9
+ import cloudinary.api
10
  from src.models import AIModelManager
11
  from src.cloud_db import CloudDB
12
 
 
14
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
+ allow_origins=["*"],
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
 
22
 
23
  print("Loading AI Models and Cloud DB...")
24
  ai = AIModelManager()
25
+ db = CloudDB()
 
26
  p = inflect.engine()
27
  print("Ready!")
28
 
29
  os.makedirs("temp_uploads", exist_ok=True)
30
 
 
31
  def standardize_category_name(name: str) -> str:
 
 
32
  clean_name = name.strip().lower()
 
 
33
  clean_name = re.sub(r'\s+', '_', clean_name)
 
 
34
  clean_name = re.sub(r'[^\w\s]', '', clean_name)
 
 
35
  singular_name = p.singular_noun(clean_name)
36
+ return singular_name if singular_name else clean_name
37
+
38
+ def sanitize_filename(filename: str) -> str:
39
+ clean_name = re.sub(r'\s+', '_', filename)
40
+ clean_name = re.sub(r'[^\w\.\-]', '', clean_name)
41
  return clean_name
 
42
 
43
  @app.post("/api/upload")
44
  async def upload_new_images(files: List[UploadFile] = File(...), folder_name: str = Form(...)):
 
45
  uploaded_urls = []
 
 
46
  standardized_folder = standardize_category_name(folder_name)
47
 
48
  try:
49
  for file in files:
50
+ safe_filename = sanitize_filename(file.filename)
51
+ temp_path = f"temp_uploads/{safe_filename}"
52
  with open(temp_path, "wb") as buffer:
53
  shutil.copyfileobj(file.file, buffer)
54
 
 
55
  image_url = db.upload_image(temp_path, standardized_folder)
56
 
57
+ vectors_to_save = ai.process_image(temp_path, is_query=False)
 
58
 
59
+ for vec_dict in vectors_to_save:
60
+ image_id = str(uuid.uuid4())
61
+ db.add_vector(vec_dict, image_url, image_id)
62
 
63
  os.remove(temp_path)
64
  uploaded_urls.append(image_url)
65
 
66
+ return {"message": "Success!", "urls": uploaded_urls}
 
 
 
 
 
67
  except Exception as e:
68
  raise HTTPException(status_code=500, detail=str(e))
69
 
 
70
  @app.post("/api/search")
71
  async def search_database(file: UploadFile = File(...)):
72
  try:
73
+ safe_filename = sanitize_filename(file.filename)
74
+ temp_path = f"temp_uploads/query_{safe_filename}"
75
+
76
  with open(temp_path, "wb") as buffer:
77
  shutil.copyfileobj(file.file, buffer)
78
 
79
+ vectors_to_search = ai.process_image(temp_path, is_query=True)
 
80
 
81
+ all_results = []
82
+ for vec_dict in vectors_to_search:
83
+ results = db.search(vec_dict, top_k=10)
84
+ all_results.extend(results)
85
+
86
  os.remove(temp_path)
87
 
88
+ unique_results = {}
89
+ for r in all_results:
90
+ url = r["url"]
91
+ if url not in unique_results or r["score"] > unique_results[url]["score"]:
92
+ unique_results[url] = r
93
+
94
+ final_results = sorted(unique_results.values(), key=lambda x: x["score"], reverse=True)
95
+
96
+ return {"results": final_results[:10]}
97
+ except Exception as e:
98
+ print(f"Production Search Error: {str(e)}")
99
+ raise HTTPException(status_code=500, detail=str(e))
100
+
101
+ @app.get("/api/categories")
102
+ async def get_categories():
103
+ try:
104
+ result = cloudinary.api.root_folders()
105
+ folders = [folder["name"] for folder in result.get("folders", [])]
106
+ return {"categories": folders}
107
  except Exception as e:
108
+ print(f"Error fetching categories from Cloudinary: {e}")
109
+ return {"categories": []}