# Spaces: Runtime error — Hugging Face Spaces status banner captured in the
# paste; not part of the source module itself.
import logging
import os
import shutil
import threading
import unicodedata
import urllib.parse
import uuid
import zipfile

import anyio
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import snapshot_download
# 1. SETTINGS & PATHS
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face dataset that ships the image zip and the prebuilt ChromaDB zip.
DATASET_REPO = "aniketkumar1106/orbit-data"
IMAGE_DIR = "Productimages"        # also served statically at /Productimages
DB_TARGET_FOLDER = "orbiitt_db"    # The folder ChromaDB expects
MIN_CONFIDENCE_THRESHOLD = 0.1     # search hits below this score are dropped

# BOOTSTRAP: Mandatory folder creation
# (the StaticFiles mount below raises at import time if the directory is missing)
os.makedirs(IMAGE_DIR, exist_ok=True)

app = FastAPI()
# Wide-open CORS: the UI may be hosted on a different origin than this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 2. GLOBAL STATE
# Written by the background sync thread, read by the endpoints.
engine = None                        # OrbiittEngine instance once boot finishes
loading_status = "System Booting..."
# Curly quotes / en-dash -> plain ASCII, applied BEFORE the NFKD/ASCII fold so
# these characters survive as punctuation instead of being dropped entirely.
_ASCII_PUNCT = str.maketrans({'’': "'", '‘': "'", '“': '"', '”': '"', '–': '-'})


def normalize_filename(name):
    """Return a disk-safe ASCII version of *name*.

    URL-decodes the string, flattens smart punctuation to ASCII equivalents,
    then NFKD-normalizes and strips every remaining non-ASCII character.
    Falsy input (None / "") returns "".
    """
    if not name:
        return ""
    # Single C-level pass replaces the previous chain of .replace() calls.
    name = urllib.parse.unquote(name).translate(_ASCII_PUNCT)
    return unicodedata.normalize('NFKD', name).encode('ascii', 'ignore').decode('ascii').strip()
# 3. BACKGROUND INITIALIZATION
_IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.webp')


def _clear_image_dir():
    """Best-effort wipe of IMAGE_DIR so a re-sync never mixes in stale assets."""
    logger.info("Cleaning up old image directory...")
    for entry in os.listdir(IMAGE_DIR):
        path = os.path.join(IMAGE_DIR, entry)
        try:
            if os.path.isfile(path):
                os.unlink(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except OSError as e:
            # One stubborn entry must not abort the whole sync.
            logger.warning("Could not remove %s: %s", path, e)


def _extract_image_zip():
    """Unpack Productimages.zip (if present) into IMAGE_DIR under normalized names."""
    global loading_status
    if not os.path.exists("Productimages.zip"):
        return
    loading_status = "Extracting Images..."
    logger.info("Found Productimages.zip! Extracting...")
    with zipfile.ZipFile("Productimages.zip", 'r') as z:
        z.extractall("temp_images_zip")
    moved = 0
    for root, _dirs, files in os.walk("temp_images_zip"):
        for fname in files:
            # Skip hidden files (e.g. __MACOSX resource forks) and non-images.
            if fname.startswith('.') or not fname.lower().endswith(_IMAGE_EXTS):
                continue
            dst = os.path.join(IMAGE_DIR, normalize_filename(fname))
            try:
                # shutil.move overwrites an existing file at dst on POSIX.
                shutil.move(os.path.join(root, fname), dst)
                moved += 1
            except OSError as e:
                logger.warning("Skipping %s: %s", fname, e)
    shutil.rmtree("temp_images_zip")
    logger.info(f"Extracted {moved} images from Productimages.zip")


def _install_database():
    """Find the ChromaDB folder inside orbiitt_db.zip and move it to DB_TARGET_FOLDER."""
    global loading_status
    db_found = False
    if os.path.exists("orbiitt_db.zip"):
        loading_status = "Extracting Database..."
        logger.info("Extracting orbiitt_db.zip...")
        with zipfile.ZipFile("orbiitt_db.zip", 'r') as z:
            z.extractall("temp_extract")
        for root, _dirs, files in os.walk("temp_extract"):
            # chroma.sqlite3 is ChromaDB's signature file.
            if "chroma.sqlite3" in files:
                if os.path.exists(DB_TARGET_FOLDER):
                    shutil.rmtree(DB_TARGET_FOLDER)
                shutil.move(root, DB_TARGET_FOLDER)
                db_found = True
                # FIX: stop iterating — we just moved the tree os.walk was
                # traversing; continuing could move a second match INTO the
                # freshly installed target folder.
                break
        shutil.rmtree("temp_extract")
    if not db_found and not os.path.exists(DB_TARGET_FOLDER):
        # Give the engine an empty directory rather than a missing path.
        os.makedirs(DB_TARGET_FOLDER, exist_ok=True)


def _sweep_loose_images():
    """Move images that were downloaded loose (outside any zip) into IMAGE_DIR."""
    skip_dirs = {IMAGE_DIR, DB_TARGET_FOLDER, ".git"}
    for root, dirs, files in os.walk("."):
        # FIX: prune by exact path component instead of the old substring test
        # (`IMAGE_DIR in root`), and skip hidden dirs so the huggingface_hub
        # .cache blobs are never moved (moving them corrupts the cache).
        dirs[:] = [d for d in dirs if d not in skip_dirs and not d.startswith('.')]
        for fname in files:
            if fname.startswith('.') or not fname.lower().endswith(_IMAGE_EXTS):
                continue
            dst = os.path.join(IMAGE_DIR, normalize_filename(fname))
            if os.path.exists(dst):
                continue
            try:
                shutil.move(os.path.join(root, fname), dst)
            except OSError:
                pass  # best-effort, mirroring the zip path


def background_sync():
    """Full boot sequence, run on a daemon thread.

    Downloads the dataset snapshot, stages images and the ChromaDB folder,
    then instantiates the engine. Progress is published via the module-global
    `loading_status`; on success the module-global `engine` is set.
    """
    global engine, loading_status
    token = os.environ.get("HF_TOKEN")
    _clear_image_dir()
    try:
        loading_status = "Syncing Assets..."
        logger.info(f"Downloading dataset from {DATASET_REPO}...")
        # Everything (including the zips) lands in the current directory.
        snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", token=token, local_dir=".")

        _extract_image_zip()      # STEP A: Productimages.zip
        _install_database()       # STEP B: orbiitt_db.zip
        _sweep_loose_images()     # STEP C: any leftover loose images

        # LOGGING FILE COUNT FOR VALIDATION
        final_count = len(os.listdir(IMAGE_DIR))
        logger.info(f"DISK VALIDATION: {final_count} images ready in {IMAGE_DIR}")

        loading_status = "Loading AI Engine..."
        try:
            from orbiitt_engine import OrbiittEngine
            # Initialize with the folder STEP B just installed.
            engine = OrbiittEngine(db_path=f"./{DB_TARGET_FOLDER}")
            loading_status = "Ready"
            logger.info(">>> ENGINE ONLINE <<<")
        except Exception as e:
            loading_status = f"Engine Error: {str(e)}"
            logger.error(f"Engine Failed: {e}")
    except Exception as e:
        # Top-level boundary for the worker thread: record, don't crash silently.
        loading_status = f"Sync Error: {str(e)}"
        logger.error(f"Sync Failed: {e}")
# NOTE(review): no @app.on_event("startup") decorator is visible in this
# paste — without it FastAPI never invokes this hook and the sync never runs.
# Confirm the decorator exists in the real source.
async def startup_event():
    # Run the heavy download/extract work on a daemon thread so the server
    # can answer health/UI requests immediately while assets sync.
    thread = threading.Thread(target=background_sync, daemon=True)
    thread.start()
# Mount Static Images
# Safe because IMAGE_DIR is created unconditionally at import time above.
app.mount("/Productimages", StaticFiles(directory=IMAGE_DIR), name="Productimages")
# Serve UI (Root Endpoint)
# NOTE(review): presumably decorated with @app.get("/") in the real source —
# the decorator is not visible in this paste; confirm.
async def read_index():
    """Serve the bundled single-page UI, or a plain liveness payload if absent."""
    if not os.path.exists('index.html'):
        return {"status": "Online", "message": "index.html not found, but server is running."}
    return FileResponse('index.html')
def health():
    """Report boot progress plus a boolean readiness flag for the UI poller."""
    is_ready = engine is not None
    return {"status": loading_status, "ready": is_ready}
# 4. FIXED SEARCH LOGIC
# NOTE(review): presumably decorated with @app.post("/search") in the real
# source — the decorator is not visible in this paste; confirm.
async def search(text: str = Form(None), weight: float = Form(0.5), file: UploadFile = File(None)):
    """Hybrid text/image search.

    Accepts an optional text query, an optional image upload, and a text/image
    blend weight; returns up to 20 on-disk-matched results sorted by score,
    or {"results": [], "error": ...} on failure. 503 while the engine boots.
    """
    if not engine:
        raise HTTPException(status_code=503, detail=f"Engine not ready: {loading_status}")
    # FIX: unique per-request scratch file. The old PID-based name collided
    # when two concurrent requests hit the same worker process.
    t_path = f"buffer_{uuid.uuid4().hex}.jpg" if file else None
    try:
        # Force pure-modal weights when only one modality was supplied.
        actual_weight = weight
        if not text and file:
            actual_weight = 0.0
        if text and not file:
            actual_weight = 1.0
        if file and t_path:
            content = await file.read()
            async with await anyio.open_file(t_path, "wb") as f:
                await f.write(content)
        # Engine call is blocking; run it off the event loop.
        # CORRECTED: Calling engine WITHOUT top_k
        results = await anyio.to_thread.run_sync(
            lambda: engine.search(
                text_query=text,
                image_file=t_path,
                text_weight=actual_weight,
            )
        )
        all_files = os.listdir(IMAGE_DIR)
        final_list = []
        seen_ids = set()
        for r in results:
            score = r.get('score', 0)
            pid = r.get('id', 'Product')
            if score < MIN_CONFIDENCE_THRESHOLD or pid in seen_ids:
                continue
            # The engine returns an ID/path; match it to a file on disk.
            fname = normalize_filename(os.path.basename(r.get('id', '')))
            match = None
            if fname in all_files:
                match = fname
            elif fname:
                # Fuzzy prefix fallback. FIX: guarded on non-empty fname — an
                # empty prefix would "match" the first file on disk.
                for disk_f in all_files:
                    if fname[:15].lower() in disk_f.lower():
                        match = disk_f
                        break
            if match:
                final_list.append({
                    "id": pid,
                    # Ensure URL is properly encoded for web
                    "url": f"Productimages/{urllib.parse.quote(match)}",
                    "score": round(float(score), 4),
                })
                seen_ids.add(pid)
        final_list.sort(key=lambda x: x['score'], reverse=True)
        return {"results": final_list[:20]}
    except Exception as e:
        # Endpoint boundary: degrade to empty results rather than a 500.
        logger.error(f"Search Failure: {e}")
        return {"results": [], "error": str(e)}
    finally:
        if t_path and os.path.exists(t_path):
            try:
                os.remove(t_path)
            except OSError:
                pass  # scratch file cleanup is best-effort
if __name__ == "__main__":
    import uvicorn
    # FIXED: Listen on all interfaces (0.0.0.0) and correct port 7860
    # (7860 is the port Hugging Face Spaces routes external traffic to).
    uvicorn.run(app, host="0.0.0.0", port=7860)