AdarshDRC's picture
Update main.py
eb23bfb verified
raw
history blame
3.65 kB
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import os
import shutil
import uuid
import re
import inflect
import cloudinary.api
from src.models import AIModelManager
from src.cloud_db import CloudDB
# Application instance that the route decorators in this module attach to.
app = FastAPI()
# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Module-level singletons, created once at import time and shared by all
# request handlers: the model manager, the cloud database client and the
# inflect engine used for singularizing category names.
print("Loading AI Models and Cloud DB...")
ai = AIModelManager()
db = CloudDB()
p = inflect.engine()
print("Ready!")
# Scratch directory that the upload/search handlers write incoming files into.
os.makedirs("temp_uploads", exist_ok=True)
def standardize_category_name(name: str) -> str:
    """Normalize a user-supplied category name into a canonical folder key.

    The name is trimmed, lower-cased, has whitespace runs replaced by a
    single underscore and every remaining non-word character removed. The
    result is then passed through ``p.singular_noun`` so that plural and
    singular spellings map to the same category.

    Args:
        name: Raw category name as typed by the client.

    Returns:
        The singular form when ``p.singular_noun`` yields one, otherwise the
        cleaned name. An input with no usable characters yields ``""``.
    """
    clean_name = name.strip().lower()
    clean_name = re.sub(r'\s+', '_', clean_name)
    # Whitespace is already gone at this point, so this drops everything
    # that is not a letter, digit or underscore.
    clean_name = re.sub(r'[^\w\s]', '', clean_name)
    if not clean_name:
        # Nothing to singularize. Some inflect releases reject an empty word
        # with a validation error, so skip the lookup entirely.
        return clean_name
    # singular_noun returns a falsy value when it does not treat the word as
    # a plural; keep the cleaned name in that case.
    singular_name = p.singular_noun(clean_name)
    return singular_name if singular_name else clean_name
def sanitize_filename(filename: str) -> str:
    """Reduce an uploaded file name to a safe single path component.

    Whitespace runs become one underscore and every character other than
    word characters, dots and hyphens is removed (which also strips path
    separators).

    Args:
        filename: Client-supplied file name. ``None`` is tolerated because
            multipart uploads may arrive without a name.

    Returns:
        The cleaned name, or ``"upload"`` when nothing usable remains (an
        empty result or one made only of dots such as ``"."`` / ``".."``).
    """
    clean_name = re.sub(r'\s+', '_', filename or "")
    clean_name = re.sub(r'[^\w\.\-]', '', clean_name)
    # "", "." and ".." are not valid file names to write to; fall back to a
    # fixed placeholder instead of producing an unusable path.
    if not clean_name.strip("."):
        return "upload"
    return clean_name
@app.post("/api/upload")
async def upload_new_images(files: List[UploadFile] = File(...), folder_name: str = Form(...)):
    """Upload images to cloud storage and index their feature vectors.

    Each file is written to a scratch path, uploaded with
    ``db.upload_image`` into the standardized category folder, run through
    ``ai.process_image`` and every resulting vector is stored with
    ``db.add_vector`` under a fresh UUID.

    Args:
        files: Image files from the multipart form.
        folder_name: Category name supplied by the client; normalized with
            ``standardize_category_name`` before use.

    Returns:
        ``{"message": "Success!", "urls": [...]}`` with one URL per file, in
        the order the files were processed.

    Raises:
        HTTPException: Status 500 carrying the error text if any step fails.
    """
    uploaded_urls = []
    standardized_folder = standardize_category_name(folder_name)
    # Per-request scratch directory: concurrent requests that carry the same
    # file name can no longer overwrite each other's temp file, and the file
    # keeps its sanitized base name for downstream consumers.
    request_dir = f"temp_uploads/{uuid.uuid4().hex}"
    try:
        os.makedirs(request_dir)
        for file in files:
            safe_filename = sanitize_filename(file.filename)
            temp_path = f"{request_dir}/{safe_filename}"
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            image_url = db.upload_image(temp_path, standardized_folder)
            vectors_to_save = ai.process_image(temp_path, is_query=False)
            for vec_dict in vectors_to_save:
                image_id = str(uuid.uuid4())
                db.add_vector(vec_dict, image_url, image_id)
            # Free the disk space as soon as this file has been indexed.
            os.remove(temp_path)
            uploaded_urls.append(image_url)
        return {"message": "Success!", "urls": uploaded_urls}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on success and on failure, so a file left behind by a failed
        # step is removed together with the scratch directory.
        shutil.rmtree(request_dir, ignore_errors=True)
@app.post("/api/search")
async def search_database(file: UploadFile = File(...)):
    """Find stored images similar to the uploaded query image.

    The query file is written to a scratch path and passed to
    ``ai.process_image`` with ``is_query=True``; each returned vector is
    searched with ``db.search`` (``top_k=10``). Hits are de-duplicated by
    their ``"url"`` keeping the highest ``"score"``, sorted by score in
    descending order and truncated to ten entries.

    Args:
        file: Query image from the multipart form.

    Returns:
        ``{"results": [...]}`` with at most ten result entries.

    Raises:
        HTTPException: Status 500 carrying the error text if any step fails.
    """
    # Per-request scratch directory: concurrent queries that carry the same
    # file name can no longer overwrite each other's temp file.
    request_dir = f"temp_uploads/{uuid.uuid4().hex}"
    try:
        safe_filename = sanitize_filename(file.filename)
        os.makedirs(request_dir)
        temp_path = f"{request_dir}/query_{safe_filename}"
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        vectors_to_search = ai.process_image(temp_path, is_query=True)
        all_results = []
        for vec_dict in vectors_to_search:
            results = db.search(vec_dict, top_k=10)
            all_results.extend(results)
        # Several query vectors may hit the same image; keep only the best
        # scoring entry per URL.
        unique_results = {}
        for r in all_results:
            url = r["url"]
            if url not in unique_results or r["score"] > unique_results[url]["score"]:
                unique_results[url] = r
        final_results = sorted(unique_results.values(), key=lambda x: x["score"], reverse=True)
        return {"results": final_results[:10]}
    except Exception as e:
        print(f"Production Search Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on success and on failure, so the query file is never left
        # behind when processing or searching raises.
        shutil.rmtree(request_dir, ignore_errors=True)
@app.get("/api/categories")
async def get_categories():
    """List the category names known to Cloudinary.

    Reads the ``"folders"`` entries from ``cloudinary.api.root_folders()``
    and collects each entry's ``"name"``.

    Returns:
        ``{"categories": [...]}``; the list is empty when the lookup fails
        (the error is printed rather than raised).
    """
    try:
        response = cloudinary.api.root_folders()
        names = []
        for entry in response.get("folders", []):
            names.append(entry["name"])
        return {"categories": names}
    except Exception as e:
        # Best effort: an unavailable backend degrades to "no categories".
        print(f"Error fetching categories from Cloudinary: {e}")
        return {"categories": []}