| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from pydantic import BaseModel |
| from typing import List, Union |
| import torch |
| from transformers import AutoImageProcessor, AutoModel |
| from ultralytics import YOLO |
| import faiss |
| import numpy as np |
| from PIL import Image |
| import io |
| import os |
|
|
# Application object; every route below is registered on it via decorators.
app = FastAPI(
    version="1.0.0",
    title="Devam Jersey Server",
    description="Jersey similarity and detection server using YOLOv8 and DINOv2",
)
|
|
| |
# Use the GPU when CUDA is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
| |
# Load the DINOv2 image processor and backbone once, at import time. On any
# failure both globals are set to None so /dino and / can report the model as
# unavailable instead of the server failing to start.
try:
    processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
    dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
except Exception as e:
    print(f"❌ Error loading DINOv2 model: {e}")
    processor = None
    dino_model = None
else:
    # Success message lives in `else` so a failure while printing cannot be
    # mistaken for a load failure and discard the loaded model.
    print(f"✅ DINOv2 model loaded successfully on {device}")
|
|
# Load the YOLOv8 segmentation weights from a relative path. On failure
# yolo_model is None and /yolo reports the model as not loaded.
try:
    yolo_model = YOLO("models/deepfashion2_yolov8s-seg.pt")
except Exception as e:
    print(f"❌ Error loading YOLOv8 model: {e}")
    yolo_model = None
else:
    print("✅ YOLOv8 model loaded successfully")
|
|
| |
# FAISS index used by /faiss and the mapping from FAISS row id -> catalogue
# file name. Both stay at these defaults if loading fails below.
faiss_index = None
index_to_path = {}


def metadata_to_index_map(loaded_data):
    """Convert the object returned by ``np.load`` for the metadata file into ``{row_id: path}``.

    ``np.save`` stores a dict as a 0-d object array, so ``np.load`` returns an
    ``ndarray`` wrapping the dict rather than the dict itself; it is unwrapped
    with ``.item()`` first. For a 1-d array, entry ``i`` is taken as the path
    for FAISS row ``i``.

    Raises:
        TypeError: if the loaded object is neither a dict nor an array.
    """
    if isinstance(loaded_data, np.ndarray) and loaded_data.ndim == 0:
        loaded_data = loaded_data.item()
    if isinstance(loaded_data, dict):
        return {int(k): v for k, v in loaded_data.items()}
    if isinstance(loaded_data, np.ndarray):
        return {i: str(item) for i, item in enumerate(loaded_data)}
    raise TypeError(f"Unsupported jersey metadata type: {type(loaded_data).__name__}")


# The index and the metadata are loaded independently, so a failure in one
# does not prevent the other from loading.
try:
    if os.path.exists("index/jersey_index.faiss"):
        faiss_index = faiss.read_index("index/jersey_index.faiss")
        print("✅ FAISS index loaded successfully")
    else:
        print("⚠️ FAISS index not found, creating dummy index")
        dimension = 768  # presumably the dinov2-base embedding width — confirm
        dummy_vectors = np.random.random((100, dimension)).astype('float32')
        # /faiss L2-normalises queries before an inner-product search; normalise
        # the stored vectors as well so scores are cosine similarities.
        faiss.normalize_L2(dummy_vectors)
        dummy_index = faiss.IndexFlatIP(dimension)
        dummy_index.add(dummy_vectors)
        faiss_index = dummy_index
except Exception as e:
    print(f"❌ Error loading FAISS index: {e}")

try:
    if os.path.exists("index/jersey_metadata.npy"):
        # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, which
        # can execute code; only point this at trusted files.
        loaded_data = np.load("index/jersey_metadata.npy", allow_pickle=True)
        index_to_path = metadata_to_index_map(loaded_data)
        print("✅ Jersey metadata loaded successfully")
    else:
        print("⚠️ Jersey metadata not found, using dummy data")
        index_to_path = {i: f"jersey_{i}.jpg" for i in range(100)}
except Exception as e:
    print(f"❌ Error loading jersey metadata: {e}")
|
|
class FeaturesRequest(BaseModel):
    """JSON body accepted by ``POST /faiss``.

    ``features`` holds either one embedding (a flat list of floats) or several
    embeddings (a list of such lists).
    """

    features: Union[List[float], List[List[float]]]
|
|
@app.get("/")
async def root():
    """Report the service name, its endpoints, and which components loaded at startup."""
    loaded = {
        "dino": dino_model is not None,
        "yolo": yolo_model is not None,
        "faiss": faiss_index is not None,
    }
    return dict(
        message="Devam Jersey Server",
        status="running",
        endpoints=["/dino", "/faiss", "/yolo"],
        models_loaded=loaded,
    )
|
|
@app.post("/dino")
async def dino_inference(file: UploadFile = File(...)):
    """Embed an uploaded image with DINOv2.

    Returns ``{"features": [...]}``: ``last_hidden_state`` averaged over dim 1
    (the token axis of the Hugging Face output) for the single input image.

    Raises:
        HTTPException(500): the processor/model did not load at startup, or
            inference failed.
        HTTPException(400): the upload could not be decoded as an image.
    """
    if dino_model is None or processor is None:
        raise HTTPException(status_code=500, detail="DINOv2 model not loaded")

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # Undecodable/truncated/oversized uploads are client errors, not 500s.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    # NOTE(review): model inference runs synchronously inside an async handler
    # and blocks the event loop for its duration; consider a threadpool.
    try:
        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(device)
            outputs = dino_model(**inputs)
            features = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e
    return {"features": features.tolist()}
|
|
@app.post("/faiss")
async def faiss_search(request: FeaturesRequest):
    """Return the nearest catalogue entries (top 15) for the first submitted vector.

    ``request.features`` may be a single vector or a list of vectors. Every
    vector is L2-normalised and searched, but only the hits for the first one
    are returned. Hits whose row id has no entry in ``index_to_path``
    (including the ``-1`` FAISS returns when there are fewer neighbours than
    requested) are skipped, so ranks can have gaps.

    Returns:
        ``{"results": [{"rank", "distance", "file_path", "full_path"}, ...]}``,
        or ``{"error": ...}`` (with HTTP 200, kept for existing clients) when
        the vector length does not match the index dimension.

    Raises:
        HTTPException(500): the index is not loaded, or the search failed.
        HTTPException(400): ``features`` is empty or its rows have unequal lengths.
    """
    if faiss_index is None:
        raise HTTPException(status_code=500, detail="FAISS index not loaded")

    top_k = 15
    features = request.features
    # Input validation happens outside the broad try below so client errors
    # are not re-wrapped as 500 responses.
    if not features:
        raise HTTPException(status_code=400, detail="'features' must not be empty")
    rows = features if isinstance(features[0], list) else [features]
    try:
        vector = np.array(rows, dtype=np.float32)
    except ValueError as e:
        # Ragged nested lists cannot form a 2-D array.
        raise HTTPException(status_code=400, detail=f"Invalid feature vectors: {e}") from e

    if vector.shape[1] != faiss_index.d:
        error_msg = (
            f"Feature vector length {vector.shape[1]} does not match "
            f"FAISS index dimension {faiss_index.d}"
        )
        return {"error": error_msg}

    try:
        faiss.normalize_L2(vector)
        distances, indices = faiss_index.search(vector, top_k)
        results = []
        for rank, (distance, idx) in enumerate(zip(distances[0], indices[0]), start=1):
            key = int(idx)  # FAISS returns numpy ints; metadata keys are Python ints
            if key not in index_to_path:
                continue
            results.append({
                "rank": rank,
                "distance": float(distance),
                "file_path": index_to_path[key],
                "full_path": f"catalogue/{index_to_path[key]}",
            })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error searching FAISS index: {e}") from e
    return {"results": results}
|
|
@app.post("/yolo")
async def yolo_inference(file: UploadFile = File(...)):
    """Segment an uploaded image with YOLOv8 and return the mask outlines.

    Returns ``{"polygons": [...]}`` with one ``tolist()``-converted entry per
    element of ``results.masks.xy`` for the single input image, or an empty
    list when the result carries no masks. Each entry is presumably a list of
    ``[x, y]`` points per Ultralytics ``Masks.xy`` — confirm against the
    installed version.

    Raises:
        HTTPException(500): the model did not load at startup, or inference failed.
        HTTPException(400): the upload could not be decoded as an image.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLOv8 model not loaded")

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # Undecodable/truncated/oversized uploads are client errors, not 500s.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    # NOTE(review): inference runs synchronously inside an async handler and
    # blocks the event loop for its duration; consider a threadpool.
    try:
        results = yolo_model(image, device=0 if torch.cuda.is_available() else 'cpu', verbose=False)[0]
        masks = getattr(results, "masks", None)
        if masks is None or not hasattr(masks, "xy"):
            polygons = []
        else:
            polygons = [mask.tolist() for mask in masks.xy]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image with YOLO: {e}") from e
    return {"polygons": polygons}
|
|
if __name__ == "__main__":
    # Development entry point: serve the API with uvicorn on every interface, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")