Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import List | |
| import os | |
| import shutil | |
| import uuid | |
| import re | |
| import inflect | |
| import cloudinary.api | |
| from src.models import AIModelManager | |
| from src.cloud_db import CloudDB | |
# ASGI application object for this service.
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): the CORS spec forbids a literal "*" origin on credentialed
# responses; confirm how the installed Starlette version handles this
# combination and restrict allow_origins if cookies/auth are ever relied upon.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Shared singletons, built once at import time and reused by every handler.
print("Loading AI Models and Cloud DB...")
ai = AIModelManager()  # provides process_image(), which turns an image file into vector dicts
db = CloudDB()  # provides upload_image(), add_vector() and search()
p = inflect.engine()  # English inflection engine; used to singularize category names
print("Ready!")

# Scratch area where handlers stage uploaded files before processing them.
os.makedirs("temp_uploads", exist_ok=True)
def standardize_category_name(name: str) -> str:
    """Normalize a free-text category name into a folder-safe identifier.

    The name is trimmed and lower-cased. Each run of whitespace becomes a
    single underscore, and any character that is not a word character is
    dropped (so ``"  Red Shoes "`` is cleaned to ``"red_shoes"``). The result
    is then passed through the module-level inflect engine ``p`` to
    singularize it.

    Note that punctuation surrounded by spaces leaves doubled underscores,
    e.g. ``"a - b"`` becomes ``"a__b"``.

    Args:
        name: Raw category name supplied by the client.

    Returns:
        The singular form reported by ``p.singular_noun`` when it returns
        one, otherwise the cleaned name unchanged.

    Raises:
        ValueError: If nothing remains after cleaning (e.g. ``""``, ``"   "``
            or ``"!!!"``). An empty string cannot identify a category.
    """
    clean_name = name.strip().lower()
    clean_name = re.sub(r'\s+', '_', clean_name)
    # Whitespace is already gone here, so this only strips punctuation and
    # symbols; \w keeps letters, digits and the underscores inserted above.
    clean_name = re.sub(r'[^\w\s]', '', clean_name)
    if not clean_name:
        # Reject before calling inflect rather than passing "" through and
        # producing an empty category/folder name.
        raise ValueError(f"Invalid category name: {name!r}")
    # singular_noun yields a falsy value when it has no singular form to
    # offer; fall back to the cleaned name in that case.
    singular_name = p.singular_noun(clean_name)
    return singular_name if singular_name else clean_name
def sanitize_filename(filename: str, fallback: str = "file") -> str:
    """Reduce a client-supplied filename to a safe, single path component.

    Each run of whitespace becomes an underscore. Every character other than
    word characters, dots and hyphens is removed, which also removes any
    path separators. Leading dots are then stripped, so the result can never
    be ``"."``, ``".."`` or a hidden-file name.

    Args:
        filename: Name reported by the client. ``None`` and ``""`` are
            tolerated because an uploaded file's name may be missing.
        fallback: Name returned when nothing usable remains.

    Returns:
        A non-empty name made only of word characters, dots and hyphens.
    """
    if not filename:
        return fallback
    clean_name = re.sub(r'\s+', '_', filename)
    clean_name = re.sub(r'[^\w\.\-]', '', clean_name)
    # "", "." or ".." would make "temp_uploads/<name>" refer to a directory;
    # stripping leading dots also avoids creating hidden files.
    clean_name = clean_name.lstrip('.')
    return clean_name or fallback
async def upload_new_images(files: List[UploadFile] = File(...), folder_name: str = Form(...)):
    """Upload images to cloud storage under a category and index their vectors.

    NOTE(review): no route decorator precedes this handler in the reviewed
    copy of the file; confirm it is registered on ``app``.

    Each file is processed in four steps:

    1. Its bytes are written to a private scratch directory under
       ``temp_uploads/``.
    2. The file is sent to ``db.upload_image`` in the folder derived from
       ``folder_name``.
    3. It is embedded with ``ai.process_image(..., is_query=False)``.
    4. Every returned vector dict is stored via ``db.add_vector``, together
       with the image URL and a fresh UUID.

    NOTE(review): the file copy, cloud upload and inference below are
    synchronous calls inside an ``async def``, so they run on the event-loop
    thread.

    Args:
        files: Uploaded image files (multipart form field).
        folder_name: Free-text category, normalized by
            ``standardize_category_name`` before use.

    Returns:
        ``{"message": "Success!", "urls": [...]}`` with one URL per file,
        in request order.

    Raises:
        HTTPException: Status 400 if normalizing ``folder_name`` raises
            ``ValueError``. Status 500, with the error text, if any later
            step fails. Files processed before a failure remain uploaded
            and indexed.
    """
    try:
        standardized_folder = standardize_category_name(folder_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    uploaded_urls = []
    try:
        for file in files:
            safe_filename = sanitize_filename(file.filename)
            # One scratch directory per file. Concurrent requests that send
            # the same filename can no longer overwrite each other, and the
            # basename handed to db.upload_image is unchanged.
            temp_dir = os.path.join("temp_uploads", uuid.uuid4().hex)
            os.makedirs(temp_dir)
            try:
                temp_path = os.path.join(temp_dir, safe_filename)
                with open(temp_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                image_url = db.upload_image(temp_path, standardized_folder)
                vectors_to_save = ai.process_image(temp_path, is_query=False)
                for vec_dict in vectors_to_save:
                    db.add_vector(vec_dict, image_url, str(uuid.uuid4()))
            finally:
                # Always clean up; previously any failure left the file behind.
                shutil.rmtree(temp_dir, ignore_errors=True)
            uploaded_urls.append(image_url)
        return {"message": "Success!", "urls": uploaded_urls}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def search_database(file: UploadFile = File(...)):
    """Find stored images similar to an uploaded query image.

    NOTE(review): no route decorator precedes this handler in the reviewed
    copy of the file; confirm it is registered on ``app``.

    Processing steps:

    1. The query is written to a private scratch directory under
       ``temp_uploads/``.
    2. It is embedded with ``ai.process_image(..., is_query=True)``.
    3. Each resulting vector dict is searched with
       ``db.search(..., top_k=10)``.
    4. Hits are merged so that every URL appears once, keeping its highest
       score.

    NOTE(review): file I/O, inference and the searches are synchronous calls
    inside an ``async def``, so they run on the event-loop thread.

    Args:
        file: The query image (multipart form field).

    Returns:
        ``{"results": [...]}`` holding at most 10 result dicts, sorted by
        descending ``"score"``. Each dict has at least ``"url"`` and
        ``"score"`` keys.

    Raises:
        HTTPException: Status 500, with the error text, if any step fails.
            The error is also printed.
    """
    try:
        safe_filename = sanitize_filename(file.filename)
        # Per-request scratch directory. Without it, two clients searching
        # with the same filename (e.g. "image.jpg") could overwrite each
        # other's query file between the write and the embedding step.
        temp_dir = os.path.join("temp_uploads", uuid.uuid4().hex)
        os.makedirs(temp_dir)
        try:
            temp_path = os.path.join(temp_dir, f"query_{safe_filename}")
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            vectors_to_search = ai.process_image(temp_path, is_query=True)
            all_results = []
            for vec_dict in vectors_to_search:
                all_results.extend(db.search(vec_dict, top_k=10))
        finally:
            # Always remove the query file; previously errors leaked it.
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Several query vectors may return the same URL; keep only its
        # best-scoring hit. Insertion order is preserved for equal scores.
        best_by_url = {}
        for r in all_results:
            url = r["url"]
            if url not in best_by_url or r["score"] > best_by_url[url]["score"]:
                best_by_url[url] = r
        final_results = sorted(best_by_url.values(), key=lambda x: x["score"], reverse=True)
        return {"results": final_results[:10]}
    except Exception as e:
        print(f"Production Search Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_categories():
    """List the top-level Cloudinary folders, which serve as category names.

    NOTE(review): no route decorator precedes this handler in the reviewed
    copy of the file; confirm it is registered on ``app``.

    Follows ``next_cursor`` across pages, so accounts with more root folders
    than one Admin API page still get the complete list.

    NOTE(review): ``cloudinary.api.root_folders`` is a blocking call inside an
    ``async def``.

    Returns:
        ``{"categories": [name, ...]}``. The endpoint is deliberately
        best-effort: on any error the error is printed and
        ``{"categories": []}`` is returned instead of failing the request.
    """
    try:
        folders = []
        cursor = None
        seen_cursors = set()
        while True:
            options = {"next_cursor": cursor} if cursor else {}
            result = cloudinary.api.root_folders(**options)
            next_cursor = result.get("next_cursor")
            if next_cursor and next_cursor in seen_cursors:
                # The SDK ignored the cursor and returned a page we already
                # have; stop rather than loop forever or duplicate names.
                break
            folders.extend(folder["name"] for folder in result.get("folders", []))
            if not next_cursor:
                break
            seen_cursors.add(next_cursor)
            cursor = next_cursor
        return {"categories": folders}
    except Exception as e:
        print(f"Error fetching categories from Cloudinary: {e}")
        return {"categories": []}