# orbittv2 / server.py
# Author: aniketkumar1106
# Revision: 03e77d7 (verified) — "Update server.py"
import logging
import os
import shutil
import threading
import unicodedata
import urllib.parse
import uuid
import zipfile

import anyio
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import snapshot_download
# 1. SETTINGS & PATHS
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DATASET_REPO = "aniketkumar1106/orbit-data"
IMAGE_DIR = "Productimages"
DB_TARGET_FOLDER = "orbiitt_db" # The folder ChromaDB expects
MIN_CONFIDENCE_THRESHOLD = 0.1
# BOOTSTRAP: Mandatory folder creation
os.makedirs(IMAGE_DIR, exist_ok=True)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 2. GLOBAL STATE
engine = None
loading_status = "System Booting..."
def normalize_filename(name):
if not name: return ""
name = urllib.parse.unquote(name)
name = name.replace('’', "'").replace('‘', "'").replace('“', '"').replace('”', '"').replace('–', '-')
return unicodedata.normalize('NFKD', name).encode('ascii', 'ignore').decode('ascii').strip()
# 3. BACKGROUND INITIALIZATION
def background_sync():
global engine, loading_status
token = os.environ.get("HF_TOKEN")
# Cleanup old images to prevent duplicates or stale data
logger.info("Cleaning up old image directory...")
for f in os.listdir(IMAGE_DIR):
p = os.path.join(IMAGE_DIR, f)
try:
if os.path.isfile(p): os.unlink(p)
elif os.path.isdir(p): shutil.rmtree(p)
except: pass
try:
loading_status = "Syncing Assets..."
logger.info(f"Downloading dataset from {DATASET_REPO}...")
# Download everything (including the new zip) to current directory
snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", token=token, local_dir=".")
# --- STEP A: HANDLE IMAGE ZIP (Productimages.zip) ---
if os.path.exists("Productimages.zip"):
loading_status = "Extracting Images..."
logger.info("Found Productimages.zip! Extracting...")
# Extract to a temp folder first
with zipfile.ZipFile("Productimages.zip", 'r') as z:
z.extractall("temp_images_zip")
# Move images from temp zip extract to the main IMAGE_DIR
count_zip_images = 0
for root, dirs, files in os.walk("temp_images_zip"):
for f in files:
# Ignore hidden files (like __MACOSX)
if f.startswith('.'): continue
if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
src = os.path.join(root, f)
clean_name = normalize_filename(f)
dst = os.path.join(IMAGE_DIR, clean_name)
try:
# Move and overwrite if necessary
shutil.move(src, dst)
count_zip_images += 1
except: pass
shutil.rmtree("temp_images_zip")
logger.info(f"Extracted {count_zip_images} images from Productimages.zip")
# --- STEP B: HANDLE DATABASE ZIP (orbiitt_db.zip) ---
if os.path.exists("orbiitt_db.zip"):
loading_status = "Extracting Database..."
logger.info("Extracting orbiitt_db.zip...")
with zipfile.ZipFile("orbiitt_db.zip", 'r') as z:
z.extractall("temp_extract")
# Smart Extraction: Find the ChromaDB folder logic
db_found = False
for root, dirs, files in os.walk("temp_extract"):
# Identifying ChromaDB by its signature file
if "chroma.sqlite3" in files:
if os.path.exists(DB_TARGET_FOLDER):
shutil.rmtree(DB_TARGET_FOLDER)
# Move the directory containing the sqlite3 file to our target location
shutil.move(root, DB_TARGET_FOLDER)
db_found = True
shutil.rmtree("temp_extract")
if not db_found and not os.path.exists(DB_TARGET_FOLDER):
os.makedirs(DB_TARGET_FOLDER, exist_ok=True)
# --- STEP C: CATCH ANY LEFTOVER LOOSE IMAGES ---
# If any images were downloaded loose (not in zip), move them too
for root, dirs, files in os.walk("."):
if IMAGE_DIR in root or ".git" in root or DB_TARGET_FOLDER in root:
continue
for f in files:
if f.startswith('.'): continue
if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
src = os.path.join(root, f)
clean_name = normalize_filename(f)
dst = os.path.join(IMAGE_DIR, clean_name)
if not os.path.exists(dst):
try: shutil.move(src, dst)
except: pass
# LOGGING FILE COUNT FOR VALIDATION
final_count = len(os.listdir(IMAGE_DIR))
logger.info(f"DISK VALIDATION: {final_count} images ready in {IMAGE_DIR}")
loading_status = "Loading AI Engine..."
try:
from orbiitt_engine import OrbiittEngine
# Initialize with the CORRECT folder path
engine = OrbiittEngine(db_path=f"./{DB_TARGET_FOLDER}")
loading_status = "Ready"
logger.info(">>> ENGINE ONLINE <<<")
except Exception as e:
loading_status = f"Engine Error: {str(e)}"
logger.error(f"Engine Failed: {e}")
except Exception as e:
loading_status = f"Sync Error: {str(e)}"
logger.error(f"Sync Failed: {e}")
@app.on_event("startup")
async def startup_event():
thread = threading.Thread(target=background_sync, daemon=True)
thread.start()
# Mount Static Images
app.mount("/Productimages", StaticFiles(directory=IMAGE_DIR), name="Productimages")
# Serve UI (Root Endpoint)
@app.get("/")
async def read_index():
if os.path.exists('index.html'):
return FileResponse('index.html')
return {"status": "Online", "message": "index.html not found, but server is running."}
@app.get("/health")
def health():
return {"status": loading_status, "ready": engine is not None}
# 4. FIXED SEARCH LOGIC
@app.post("/search")
async def search(text: str = Form(None), weight: float = Form(0.5), file: UploadFile = File(None)):
if not engine:
raise HTTPException(status_code=503, detail=f"Engine not ready: {loading_status}")
t_path = f"buffer_{os.getpid()}.jpg" if file else None
try:
actual_weight = weight
if not text and file: actual_weight = 0.0
if text and not file: actual_weight = 1.0
if file and t_path:
content = await file.read()
async with await anyio.open_file(t_path, "wb") as f:
await f.write(content)
# CORRECTED: Calling engine WITHOUT top_k
results = await anyio.to_thread.run_sync(
lambda: engine.search(
text_query=text,
image_file=t_path,
text_weight=actual_weight
)
)
all_files = os.listdir(IMAGE_DIR)
final_list = []
seen_ids = set()
for r in results:
score = r.get('score', 0)
pid = r.get('id', 'Product')
if score < MIN_CONFIDENCE_THRESHOLD or pid in seen_ids:
continue
# The engine returns a clean ID/path, let's match it to disk
fname_from_db = os.path.basename(r.get('id', ''))
fname = normalize_filename(fname_from_db)
match = None
if fname in all_files:
match = fname
else:
# Fuzzy fallback if exact match fails
for disk_f in all_files:
if fname[:15].lower() in disk_f.lower():
match = disk_f
break
if match:
final_list.append({
"id": pid,
# Ensure URL is properly encoded for web
"url": f"Productimages/{urllib.parse.quote(match)}",
"score": round(float(score), 4)
})
seen_ids.add(pid)
final_list.sort(key=lambda x: x['score'], reverse=True)
return {"results": final_list[:20]}
except Exception as e:
logger.error(f"Search Failure: {e}")
return {"results": [], "error": str(e)}
finally:
if t_path and os.path.exists(t_path):
try: os.remove(t_path)
except: pass
if __name__ == "__main__":
import uvicorn
# FIXED: Listen on all interfaces (0.0.0.0) and correct port 7860
uvicorn.run(app, host="0.0.0.0", port=7860)