"""Devam Jersey Server.

FastAPI service exposing three endpoints:

* ``POST /dino``  - embed an uploaded image with DINOv2 (mean-pooled tokens).
* ``POST /faiss`` - nearest-neighbour lookup of feature vectors in a FAISS index.
* ``POST /yolo``  - run YOLOv8 segmentation and return mask polygons.

Models, the FAISS index and its metadata are loaded once at import time.
Any component that fails to load is left as ``None`` and the matching
endpoint answers with HTTP 500.
"""
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoImageProcessor, AutoModel
from ultralytics import YOLO
import faiss
import numpy as np
from PIL import Image
import io
import os

app = FastAPI(
    title="Devam Jersey Server",
    description="Jersey similarity and detection server using YOLOv8 and DINOv2",
    version="1.0.0"
)

# Number of neighbours requested from FAISS for every query vector.
TOP_K = 15

# Load models and index ONCE at startup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models with error handling
try:
    processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
    dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
    print(f"✅ DINOv2 model loaded successfully on {device}")
except Exception as e:
    print(f"❌ Error loading DINOv2 model: {e}")
    processor = None
    dino_model = None

try:
    yolo_model = YOLO("models/deepfashion2_yolov8s-seg.pt")
    print("✅ YOLOv8 model loaded successfully")
except Exception as e:
    print(f"❌ Error loading YOLOv8 model: {e}")
    yolo_model = None


def _load_index_to_path(metadata_path):
    """Load the FAISS-id -> file-path mapping stored at *metadata_path*.

    Two on-disk layouts are accepted:

    * a pickled ``dict`` (keys are converted to ``int``), and
    * an array of paths, where the position is the FAISS id.

    Returns an empty dict for any other payload.
    """
    # SECURITY NOTE: allow_pickle=True executes arbitrary pickle payloads.
    # Only load metadata files that come from a trusted source.
    loaded = np.load(metadata_path, allow_pickle=True)

    # np.save() wraps a dict in a 0-d object array, so np.load() never hands
    # back a plain dict. Unwrap it, otherwise the dict branch is unreachable
    # and iterating the 0-d array raises.
    if isinstance(loaded, np.ndarray) and loaded.ndim == 0:
        loaded = loaded.item()

    if isinstance(loaded, dict):
        return {int(k): v for k, v in loaded.items()}
    if isinstance(loaded, np.ndarray):
        return {i: str(item) for i, item in enumerate(loaded)}
    return {}


# Initialize FAISS index and metadata
faiss_index = None
index_to_path = {}

try:
    if os.path.exists("index/jersey_index.faiss"):
        faiss_index = faiss.read_index("index/jersey_index.faiss")
        print("✅ FAISS index loaded successfully")
    else:
        # Create a dummy index if the real one doesn't exist
        print("⚠️ FAISS index not found, creating dummy index")
        dimension = 768  # DINOv2 base dimension
        faiss_index = faiss.IndexFlatIP(dimension)
        faiss_index.add(np.random.random((100, dimension)).astype('float32'))

    if os.path.exists("index/jersey_metadata.npy"):
        index_to_path = _load_index_to_path("index/jersey_metadata.npy")
        print("✅ Jersey metadata loaded successfully")
    else:
        print("⚠️ Jersey metadata not found, using dummy data")
        index_to_path = {i: f"jersey_{i}.jpg" for i in range(100)}
except Exception as e:
    print(f"❌ Error loading FAISS index or metadata: {e}")


class FeaturesRequest(BaseModel):
    """Request body for ``/faiss``: one feature vector or a batch of them."""

    features: Union[List[float], List[List[float]]]


@app.get("/")
async def root():
    """Return service status and which components loaded successfully."""
    return {
        "message": "Devam Jersey Server",
        "status": "running",
        "endpoints": ["/dino", "/faiss", "/yolo"],
        "models_loaded": {
            "dino": dino_model is not None,
            "yolo": yolo_model is not None,
            "faiss": faiss_index is not None
        }
    }


@app.post("/dino")
async def dino_inference(file: UploadFile = File(...)):
    """Embed the uploaded image with DINOv2.

    Returns ``{"features": [...]}`` where the vector is the mean of the
    model's last hidden state over the token axis.

    Raises:
        HTTPException(500): the model is not loaded, or the image could not
            be decoded / processed.
    """
    # The processor and the model are loaded together; guard both so a
    # missing processor cannot surface as an opaque "NoneType" error.
    if dino_model is None or processor is None:
        raise HTTPException(status_code=500, detail="DINOv2 model not loaded")

    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(device)
            outputs = dino_model(**inputs)
            features = outputs.last_hidden_state.mean(dim=1).detach().cpu().numpy()[0]

        return {"features": features.tolist()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.post("/faiss")
async def faiss_search(request: FeaturesRequest):
    """Search the FAISS index for the nearest neighbours of the given features.

    Accepts a single vector or a batch; only the matches of the FIRST query
    vector are returned. The query is L2-normalised before searching.

    Returns:
        ``{"results": [...]}`` with ``rank``, ``distance``, ``file_path`` and
        ``full_path`` per hit, or ``{"error": "..."}`` when the vector length
        does not match the index dimension.

    Raises:
        HTTPException(400): ``features`` is empty.
        HTTPException(500): the index is not loaded or the search failed.
    """
    if faiss_index is None:
        raise HTTPException(status_code=500, detail="FAISS index not loaded")

    features = request.features
    # Validate outside the try block so this client error is not rewrapped
    # as a 500 by the broad handler below (previously an IndexError).
    if not features:
        raise HTTPException(status_code=400, detail="Feature vector must not be empty")

    try:
        if isinstance(features[0], list):
            vector = np.array(features, dtype=np.float32)
        else:
            vector = np.array([features], dtype=np.float32)

        if vector.shape[1] != faiss_index.d:
            error_msg = f"Feature vector length {vector.shape[1]} does not match FAISS index dimension {faiss_index.d}"
            # NOTE(review): reported in the body with HTTP 200 rather than as
            # an HTTPException; existing clients may rely on the "error" key.
            return {"error": error_msg}

        faiss.normalize_L2(vector)
        distances, indices = faiss_index.search(vector, TOP_K)

        results = []
        for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
            # FAISS returns numpy integers (and -1 padding when fewer than
            # TOP_K neighbours exist); use a plain int as the lookup key.
            key = int(idx)
            if key in index_to_path:
                results.append({
                    "rank": i + 1,
                    "distance": float(distance),
                    "file_path": index_to_path[key],
                    "full_path": f"catalogue/{index_to_path[key]}"
                })

        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error searching FAISS index: {str(e)}")


@app.post("/yolo")
async def yolo_inference(file: UploadFile = File(...)):
    """Run YOLOv8 segmentation on the uploaded image.

    Returns ``{"polygons": [...]}`` with one list of ``[x, y]`` points per
    detected mask (an empty list when nothing is detected).

    Raises:
        HTTPException(500): the model is not loaded, or inference failed.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLOv8 model not loaded")

    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        results = yolo_model(image, device=0 if torch.cuda.is_available() else 'cpu', verbose=False)[0]

        polygons = []
        if hasattr(results, 'masks') and results.masks is not None and hasattr(results.masks, 'xy'):
            polygons = [mask.tolist() for mask in results.masks.xy]

        return {"polygons": polygons}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image with YOLO: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)