Devam / inference_server.py
Devam0's picture
Fix Python 3.9 compatibility - replace union operator with Union type
89677b3
Raw
History Blame Contribute Delete
5.7 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoImageProcessor, AutoModel
from ultralytics import YOLO
import faiss
import numpy as np
from PIL import Image
import io
import os
# The ASGI application; every endpoint below registers itself on it via the
# @app.get / @app.post decorators. Title, description and version are what
# FastAPI shows in the generated API docs.
app = FastAPI(
    title="Devam Jersey Server",
    description="Jersey similarity and detection server using YOLOv8 and DINOv2",
    version="1.0.0",
)
# Load models and index ONCE at startup.
# Pick the torch device a single time: the first CUDA device when one is
# visible, otherwise the CPU. The DINOv2 model is moved onto it below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Initialize models with error handling.
# Load the DINOv2 image processor and backbone. On any failure both handles
# are set to None so the server still starts; /dino then answers with a 500
# and / reports the model as not loaded.
_DINO_CHECKPOINT = 'facebook/dinov2-base'
try:
    processor = AutoImageProcessor.from_pretrained(_DINO_CHECKPOINT)
    dino_model = AutoModel.from_pretrained(_DINO_CHECKPOINT).to(device)
    print(f"βœ… DINOv2 model loaded successfully on {device}")
except Exception as e:
    print(f"❌ Error loading DINOv2 model: {e}")
    # Reset both together, so a non-None dino_model always implies a usable processor.
    processor = dino_model = None
# Load the YOLOv8 segmentation weights from the local models/ directory.
# On failure the handle is None; /yolo then answers with a 500.
_YOLO_WEIGHTS = "models/deepfashion2_yolov8s-seg.pt"
try:
    yolo_model = YOLO(_YOLO_WEIGHTS)
    print("βœ… YOLOv8 model loaded successfully")
except Exception as e:
    print(f"❌ Error loading YOLOv8 model: {e}")
    yolo_model = None
# Initialize FAISS index and metadata.
# Each artefact falls back to placeholder data when its file is missing, so the
# server still starts; search results are not meaningful in that case.


def load_index_metadata(path):
    """Load the FAISS row-id -> catalogue-path mapping stored at *path*.

    The file is a ``.npy`` written with ``np.save``. Two layouts are accepted:

    * a ``dict`` saved directly. ``np.save`` stores it as a 0-d object array,
      and ``np.load`` returns that array, never a plain ``dict``. So the
      array is unwrapped with ``.item()`` before use.
    * an array of paths, where the position in the array is the FAISS row id.

    Args:
        path: Location of the ``.npy`` metadata file.

    Returns:
        dict mapping ``int`` row id to the stored path. Dict values are kept
        as-is; array items are converted with ``str``.

    Raises:
        TypeError: if the file holds any other kind of object.
    """
    # allow_pickle is needed for object arrays (e.g. a saved dict). Pickle
    # can execute code, so only point this at trusted, locally produced files.
    loaded = np.load(path, allow_pickle=True)
    if isinstance(loaded, np.ndarray) and loaded.ndim == 0:
        loaded = loaded.item()
    if isinstance(loaded, dict):
        return {int(k): v for k, v in loaded.items()}
    if isinstance(loaded, np.ndarray):
        return {i: str(item) for i, item in enumerate(loaded)}
    raise TypeError(f"Unsupported jersey metadata type: {type(loaded).__name__}")


faiss_index = None
index_to_path = {}
try:
    if os.path.exists("index/jersey_index.faiss"):
        faiss_index = faiss.read_index("index/jersey_index.faiss")
        print("βœ… FAISS index loaded successfully")
    else:
        # Create a dummy index if the real one doesn't exist: 100 random
        # (un-normalised) vectors in an inner-product index.
        print("⚠️ FAISS index not found, creating dummy index")
        dimension = 768  # DINOv2 base dimension
        faiss_index = faiss.IndexFlatIP(dimension)
        faiss_index.add(np.random.random((100, dimension)).astype('float32'))
    if os.path.exists("index/jersey_metadata.npy"):
        index_to_path = load_index_metadata("index/jersey_metadata.npy")
        print("βœ… Jersey metadata loaded successfully")
    else:
        print("⚠️ Jersey metadata not found, using dummy data")
        index_to_path = {i: f"jersey_{i}.jpg" for i in range(100)}
except Exception as e:
    # Best effort: whatever loaded before the failure is kept. Any handle that
    # was not reached stays at its default (None / {}).
    print(f"❌ Error loading FAISS index or metadata: {e}")
# Either one feature vector or a batch of them. Spelled with typing.Union
# rather than the `|` operator to stay compatible with Python 3.9.
_FeaturePayload = Union[List[float], List[List[float]]]


class FeaturesRequest(BaseModel):
    """JSON body accepted by the /faiss endpoint."""

    # A single vector (list of floats) or a list of such vectors.
    features: _FeaturePayload
@app.get("/")
async def root():
    """Health/info endpoint.

    Returns the service name, a fixed "running" status, the list of inference
    endpoints, and for each backend whether its module-level handle is
    non-None (i.e. whether loading at startup succeeded).
    """
    readiness = {
        "dino": dino_model is not None,
        "yolo": yolo_model is not None,
        "faiss": faiss_index is not None,
    }
    return {
        "message": "Devam Jersey Server",
        "status": "running",
        "endpoints": ["/dino", "/faiss", "/yolo"],
        "models_loaded": readiness,
    }
@app.post("/dino")
async def dino_inference(file: UploadFile = File(...)):
    """Embed an uploaded image with DINOv2.

    The upload is decoded with PIL and converted to RGB. It is run through
    the DINOv2 processor and model without gradient tracking. The model's
    ``last_hidden_state`` is mean-pooled over dim 1 to produce one vector.

    Returns:
        ``{"features": [...]}``: the pooled vector for the image, as floats.

    Raises:
        HTTPException(500): if the model is not loaded, or if decoding or
            inference raises for any reason.
    """
    if dino_model is None:
        raise HTTPException(status_code=500, detail="DINOv2 model not loaded")
    try:
        upload_bytes = await file.read()
        rgb_image = Image.open(io.BytesIO(upload_bytes)).convert('RGB')
        with torch.no_grad():
            model_inputs = processor(images=rgb_image, return_tensors="pt").to(device)
            hidden_states = dino_model(**model_inputs).last_hidden_state
            # Average over dim 1, presumably the token axis of a
            # (batch, tokens, hidden) tensor. Then drop to a NumPy array on CPU.
            pooled_batch = hidden_states.mean(dim=1).detach().cpu().numpy()
        # Only one image was submitted, so keep the first row of the batch.
        embedding = pooled_batch[0]
        return {"features": embedding.tolist()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/faiss")
async def faiss_search(request: FeaturesRequest):
    """Find the 15 nearest catalogue images for a feature vector.

    ``request.features`` is either one vector or a list of vectors. The
    vectors are converted to a float32 matrix, L2-normalised in place, and
    searched with ``k=15``.

    NOTE: only the hits for the FIRST vector are returned, even when a batch
    is sent.

    Input problems are returned as a normal response of the form
    ``{"error": "<message>"}``, matching the existing dimension check. These
    cover an empty payload, a ragged batch, and a vector whose length differs
    from the index dimension.

    Returns:
        ``{"results": [...]}``. Each hit contains:

        * ``rank``: 1-based position among the 15 raw hits. Ranks can skip
          when a hit is dropped.
        * ``distance``: the raw score from ``faiss_index.search``. This is a
          similarity, not a distance, for an inner-product index.
        * ``file_path``: the path stored for that id.
        * ``full_path``: ``file_path`` prefixed with ``catalogue/``.

        Hits whose id has no entry in ``index_to_path`` are omitted.

    Raises:
        HTTPException(500): if the index is not loaded, or the search fails.
    """
    if faiss_index is None:
        raise HTTPException(status_code=500, detail="FAISS index not loaded")
    features = request.features
    # Guard before indexing features[0]; an empty list previously raised
    # IndexError and surfaced as an opaque 500.
    if not features:
        return {"error": "Feature vector is empty"}
    try:
        # A single vector becomes a batch of one; a nested list is already a batch.
        rows = features if isinstance(features[0], list) else [features]
        try:
            vector = np.array(rows, dtype=np.float32)
        except ValueError:
            # Rows of different lengths cannot form a 2-D float32 matrix.
            return {"error": "All feature vectors must have the same length"}
        if vector.shape[1] != faiss_index.d:
            error_msg = f"Feature vector length {vector.shape[1]} does not match FAISS index dimension {faiss_index.d}"
            return {"error": error_msg}
        faiss.normalize_L2(vector)
        distances, indices = faiss_index.search(vector, 15)
        results = []
        for rank, (distance, idx) in enumerate(zip(distances[0], indices[0]), start=1):
            row_id = int(idx)
            # Skip ids without metadata. This includes the -1 placeholder FAISS
            # uses when fewer than k neighbours exist.
            if row_id not in index_to_path:
                continue
            file_path = index_to_path[row_id]
            results.append({
                "rank": rank,
                "distance": float(distance),
                "file_path": file_path,
                "full_path": f"catalogue/{file_path}",
            })
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error searching FAISS index: {str(e)}")
@app.post("/yolo")
async def yolo_inference(file: UploadFile = File(...)):
    """Run YOLOv8 segmentation on an uploaded image and return mask outlines.

    The upload is decoded with PIL and converted to RGB, then passed to the
    model: on GPU 0 when CUDA is available, otherwise on the CPU. Only the
    first result object is used.

    Returns:
        ``{"polygons": [...]}``: one entry per mask in ``result.masks.xy``,
        converted with ``tolist()``. This is presumably a list of (x, y)
        points per mask; confirm against the ultralytics version in use. The
        list is empty when the result carries no masks.

    Raises:
        HTTPException(500): if the model is not loaded, or if decoding or
            inference raises for any reason.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLOv8 model not loaded")
    try:
        upload_bytes = await file.read()
        rgb_image = Image.open(io.BytesIO(upload_bytes)).convert('RGB')
        run_device = 0 if torch.cuda.is_available() else 'cpu'
        first_result = yolo_model(rgb_image, device=run_device, verbose=False)[0]
        mask_set = getattr(first_result, 'masks', None)
        # Detection-only results (or no detections) carry no usable masks.
        if mask_set is None or not hasattr(mask_set, 'xy'):
            return {"polygons": []}
        return {"polygons": [outline.tolist() for outline in mask_set.xy]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image with YOLO: {str(e)}")
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script,
    # listening on all interfaces.
    import uvicorn

    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)