Spaces:
Sleeping
Sleeping
root16285
Fix: Use ultralytics YOLO API instead of DetectMultiBackend to avoid export module dependency
ebfcb4b
"""
YOLOv5 Web Application Backend with FastAPI.

Real-time object detection with WebSocket support. Models are loaded through
the ``ultralytics`` package: a local ``<model>.pt`` checkpoint (under ``/app/``
or in the working directory) is used when present; otherwise ultralytics is
asked to fetch the weights by name.
"""
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import os | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import uvicorn | |
| from fastapi import FastAPI, File, UploadFile, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, HTMLResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from PIL import Image | |
| # YOLOv5 imports - Use ultralytics package (more reliable for deployment) | |
| import sys | |
| # Import YOLOv5 via ultralytics package | |
try:
    from ultralytics import YOLO
except Exception as e:
    # Any import failure (package missing, broken torch install, ...) only
    # disables detection; load_model() checks this flag before using YOLO.
    print(f"❌ Erreur d'import Ultralytics: {e}")
    YOLOV5_AVAILABLE = False
else:
    YOLOV5_AVAILABLE = True
    print("✅ Ultralytics YOLO importé avec succès")
# The FastAPI application every endpoint and middleware attaches to.
app = FastAPI(
    version="1.0.0",
    title="ZKA Detection API",
    description="API de détection d'objets en temps réel avec IA",
)

# Frontend assets live in a "static" directory next to the folder containing
# this file; create it if missing so it can be mounted with StaticFiles.
STATIC_DIR = Path(__file__).parent.parent / "static"
if not STATIC_DIR.exists():
    STATIC_DIR.mkdir(exist_ok=True, parents=True)

# CORS: accept every origin, method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# permissive combination (the CORS spec disallows "*" with credentials) —
# confirm this is intended before adding any authentication.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# French display labels for the COCO class names reported by the model
# (`result.names`). Names missing from this table are shown unchanged, because
# process_image looks them up with .get(name, name).
CLASS_NAMES_FR = {
    'person': 'personne',
    'bicycle': 'vélo',
    'car': 'voiture',
    'motorcycle': 'moto',
    'airplane': 'avion',
    'bus': 'bus',
    'train': 'train',
    'truck': 'camion',
    'boat': 'bateau',
    'traffic light': 'feu de circulation',
    'fire hydrant': 'borne d\'incendie',
    'stop sign': 'panneau stop',
    'parking meter': 'parcomètre',
    'bench': 'banc',
    'bird': 'oiseau',
    'cat': 'chat',
    'dog': 'chien',
    'horse': 'cheval',
    'sheep': 'mouton',
    'cow': 'vache',
    'elephant': 'éléphant',
    'bear': 'ours',
    'zebra': 'zèbre',
    'giraffe': 'girafe',
    'backpack': 'sac à dos',
    'umbrella': 'parapluie',
    'handbag': 'sac à main',
    'tie': 'cravate',
    'suitcase': 'valise',
    'frisbee': 'frisbee',
    'skis': 'skis',
    'snowboard': 'snowboard',
    'sports ball': 'ballon',
    'kite': 'cerf-volant',
    'baseball bat': 'batte de baseball',
    'baseball glove': 'gant de baseball',
    'skateboard': 'skateboard',
    'surfboard': 'planche de surf',
    'tennis racket': 'raquette de tennis',
    'bottle': 'bouteille',
    'wine glass': 'verre à vin',
    'cup': 'tasse',
    'fork': 'fourchette',
    'knife': 'couteau',
    'spoon': 'cuillère',
    'bowl': 'bol',
    'banana': 'banane',
    'apple': 'pomme',
    'sandwich': 'sandwich',
    'orange': 'orange',
    'broccoli': 'brocoli',
    'carrot': 'carotte',
    'hot dog': 'hot-dog',
    'pizza': 'pizza',
    'donut': 'donut',
    'cake': 'gâteau',
    'chair': 'chaise',
    'couch': 'canapé',
    'potted plant': 'plante en pot',
    'bed': 'lit',
    'dining table': 'table à manger',
    'toilet': 'toilettes',
    'tv': 'télévision',
    'laptop': 'ordinateur portable',
    'mouse': 'souris',
    'remote': 'télécommande',
    'keyboard': 'clavier',
    'cell phone': 'téléphone portable',
    'microwave': 'micro-ondes',
    'oven': 'four',
    'toaster': 'grille-pain',
    'sink': 'évier',
    'refrigerator': 'réfrigérateur',
    'book': 'livre',
    'clock': 'horloge',
    'vase': 'vase',
    'scissors': 'ciseaux',
    'teddy bear': 'ours en peluche',
    'hair drier': 'sèche-cheveux',
    'toothbrush': 'brosse à dents'
}
# Expose the frontend's JS/CSS assets from STATIC_DIR under the /static prefix.
app.mount(
    "/static",
    StaticFiles(directory=os.fspath(STATIC_DIR)),
    name="static",
)
# Process-wide mutable state; kept in memory only, so it resets on restart.
models_cache = {}          # model name -> loaded model object
device = None              # chosen by load_model() on its first call
detection_history = []     # per-request entries appended by detect_image
statistics = dict(
    total_detections=0,
    total_images_processed=0,
    objects_detected={},
    avg_confidence=0,
    processing_times=[],
)
# WebSocket connection manager
class ConnectionManager:
    """Registry of open WebSocket connections with a fan-out helper."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; a no-op if it is not registered."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every registered connection.

        A connection whose send fails is logged and unregistered instead of
        aborting the broadcast. Iteration runs over a snapshot because other
        coroutines may connect or disconnect while ``send_json`` is awaited.
        Only ``Exception`` is caught, so task cancellation still propagates
        (the previous bare ``except:`` swallowed it).
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as exc:
                print(f"WebSocket broadcast error: {exc}")
                self.disconnect(connection)
# Single module-wide registry of live WebSocket connections (used by websocket_endpoint).
manager = ConnectionManager()
def load_model(model_name: str = "yolov5s"):
    """Return the cached ultralytics YOLO model *model_name*, loading it on first use.

    Weights are looked up as ``/app/<model_name>.pt``, then ``<model_name>.pt``
    relative to the working directory. If neither exists, ``YOLO("<model_name>.pt")``
    lets ultralytics resolve/download them. If loading any model other than
    ``yolov5s`` fails, ``yolov5s`` is returned instead.

    Args:
        model_name: Model identifier without the ``.pt`` suffix. It comes from
            request input (query parameters, WebSocket payloads), so anything
            that is not a plain file stem (path separators, leading dot,
            non-string) is rejected and ``yolov5s`` is used instead.

    Returns:
        The loaded ``YOLO`` model instance.

    Raises:
        Exception: if ultralytics could not be imported, or if ``yolov5s``
            itself fails to load.
    """
    global device
    import re

    if not YOLOV5_AVAILABLE:
        raise Exception("Ultralytics YOLO non disponible - vérifiez l'installation")

    # Pick the inference device once, on the first call.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"📊 Device sélectionné: {device}")

    # Only accept a bare file stem: the name is interpolated into filesystem
    # paths, and loading a .pt checkpoint unpickles it, so a client-chosen
    # path (e.g. "../../tmp/x") must never reach YOLO().
    if not isinstance(model_name, str) or not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_.-]*", model_name):
        print(f"⚠️ Nom de modèle invalide: {model_name!r}, utilisation de yolov5s")
        return load_model("yolov5s")

    if model_name not in models_cache:
        try:
            print(f"🔄 Chargement du modèle {model_name}...")
            # Prefer a checkpoint bundled with the deployment, then one in the
            # working directory; otherwise let ultralytics fetch it by name.
            for candidate in (f"/app/{model_name}.pt", f"{model_name}.pt"):
                if os.path.exists(candidate):
                    model = YOLO(candidate)
                    break
            else:
                print(f"⚠️ {model_name}.pt non trouvé, téléchargement depuis ultralytics...")
                model = YOLO(f"{model_name}.pt")
            models_cache[model_name] = model
            print(f"✅ Modèle {model_name} chargé avec succès!")
        except Exception as e:
            print(f"❌ Erreur lors du chargement du modèle {model_name}: {e}")
            import traceback
            traceback.print_exc()
            # Fall back to the default model; re-raise if the default itself fails.
            if model_name != "yolov5s":
                return load_model("yolov5s")
            raise
    return models_cache[model_name]
def _draw_detection(canvas, x1, y1, x2, y2, label):
    """Draw one green bounding box and its filled label tag onto *canvas* in place.

    The tag normally sits just above the box. When the box touches the top
    edge there is no room above it, so the tag is drawn just inside the box's
    top edge instead of being clipped off the image.
    """
    green, white = (0, 255, 0), (255, 255, 255)
    cv2.rectangle(canvas, (x1, y1), (x2, y2), green, 2)
    (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
    tag_top = y1 - label_h - 10
    if tag_top < 0:
        tag_top = y1
    tag_bottom = tag_top + label_h + 10
    cv2.rectangle(canvas, (x1, tag_top), (x1 + label_w, tag_bottom), green, -1)
    cv2.putText(canvas, label, (x1, tag_bottom - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, white, 2)


def process_image(image: Image.Image, model_name: str = "yolov5s", conf_threshold: float = 0.25,
                  iou_threshold: float = 0.45):
    """Run detection on *image* and return the detections plus an annotated copy.

    Args:
        image: Input PIL image (callers in this module pass RGB images).
        model_name: Model to use, resolved through ``load_model``.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model as ``iou``.

    Returns:
        Dict with:

        - ``detections``: a list of dicts with ``bbox`` ([x1, y1, x2, y2] ints),
          ``confidence``, ``class`` (French label when known) and ``class_id``.
        - ``image``: a PIL image with the boxes drawn.
        - ``processing_time``: seconds.
        - ``timestamp``: ISO 8601.

        Successful runs also update the global statistics. Any error is logged
        and yields the original image with no detections.

    NOTE: synchronous and compute-heavy; the async endpoints call it directly,
    so the event loop is blocked while it runs.
    """
    start_time = time.time()
    try:
        model = load_model(model_name)

        # Ultralytics reads PIL images as RGB but numpy arrays as BGR (OpenCV
        # order); handing it the PIL image keeps the colour channels correct.
        results = model(image, conf=conf_threshold, iou=iou_threshold, verbose=False)

        # Drawing canvas: a fresh RGB array copy of the input.
        canvas = np.array(image)
        detections = []
        result = results[0]  # single image in, single result out
        boxes = result.boxes
        if boxes is not None and len(boxes) > 0:
            # One device->host transfer per tensor rather than one per box.
            coords = boxes.xyxy.cpu().numpy()
            scores = boxes.conf.cpu().numpy()
            class_ids = boxes.cls.cpu().numpy()
            for box, score, cls in zip(coords, scores, class_ids):
                x1, y1, x2, y2 = map(int, box)
                confidence = float(score)
                class_id = int(cls)
                class_name_en = result.names[class_id]
                class_name = CLASS_NAMES_FR.get(class_name_en, class_name_en)
                detections.append({
                    "bbox": [x1, y1, x2, y2],
                    "confidence": confidence,
                    "class": class_name,
                    "class_id": class_id
                })
                _draw_detection(canvas, x1, y1, x2, y2, f"{class_name} {confidence:.2f}")

        img_pil = Image.fromarray(canvas)
        processing_time = time.time() - start_time
        print(f"✅ Détection terminée: {len(detections)} objets trouvés en {processing_time:.3f}s")

        update_statistics(detections, processing_time)
        return {
            "detections": detections,
            "image": img_pil,
            "processing_time": processing_time,
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        print(f"❌ Erreur lors du traitement de l'image: {e}")
        import traceback
        traceback.print_exc()
        # Best effort: hand back the untouched input so callers can still respond.
        return {
            "detections": [],
            "image": image,
            "processing_time": time.time() - start_time,
            "timestamp": datetime.now().isoformat()
        }
def update_statistics(detections: List[dict], processing_time: float):
    """Fold one processed image into the global ``statistics`` dict.

    Args:
        detections: Detection dicts for the image; each must provide
            ``"class"`` and ``"confidence"``.
        processing_time: Seconds spent processing the image.

    ``avg_confidence`` is a running mean over every detection recorded so far.
    Previously it was overwritten with the latest image's mean. Images without
    detections leave it unchanged. Only the 100 most recent processing times
    are kept.
    """
    previous_total = statistics["total_detections"]
    statistics["total_images_processed"] += 1
    statistics["total_detections"] = previous_total + len(detections)

    times = statistics["processing_times"]
    times.append(processing_time)
    # Trim in place to the 100 most recent samples (no-op while shorter).
    del times[:-100]

    counts = statistics["objects_detected"]
    for det in detections:
        counts[det["class"]] = counts.get(det["class"], 0) + 1

    if detections:
        # avg * previous_total recovers the sum of every earlier confidence.
        batch_sum = sum(d["confidence"] for d in detections)
        statistics["avg_confidence"] = (
            statistics["avg_confidence"] * previous_total + batch_sum
        ) / statistics["total_detections"]
async def root():
    """Serve STATIC_DIR/index.html, or a small JSON banner when it is absent."""
    index_file = STATIC_DIR / "index.html"
    if not index_file.exists():
        return JSONResponse({"message": "YOLOv5 Detection API", "version": "1.0.0"})
    return FileResponse(index_file)
async def api_info():
    """Describe the API: name, version, model availability and endpoint paths."""
    endpoints = dict(
        detect="/detect",
        detect_batch="/detect/batch",
        statistics="/statistics",
        history="/history",
        models="/models",
        websocket="/ws",
    )
    info = {"message": "ZKA Detection API", "version": "1.0.0"}
    info["yolov5_available"] = YOLOV5_AVAILABLE
    info["endpoints"] = endpoints
    return info
async def detect_image(
    file: UploadFile = File(...),
    model: str = "yolov5s",
    confidence: float = 0.35
):
    """Detect objects in one uploaded image and record it in the history.

    Args:
        file: Uploaded image in any format Pillow can decode.
        model: Model name forwarded to ``process_image``.
        confidence: Minimum detection confidence.

    Returns:
        JSONResponse with ``success``, ``detections``, the annotated image as
        base64 JPEG (``image``), ``processing_time`` and ``timestamp``. The
        request is appended to ``detection_history`` (100 most recent kept).
        An undecodable upload yields HTTP 400; any other failure yields HTTP
        500. Both carry ``{"success": False, "error": ...}``.
    """
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Not a decodable image: a client error, not a server failure.
            return JSONResponse({
                "success": False,
                "error": str(e)
            }, status_code=400)

        result = process_image(image, model, confidence)

        # Encode the annotated image for the JSON payload.
        buffered = io.BytesIO()
        result["image"].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        detection_history.append({
            "timestamp": result["timestamp"],
            "filename": file.filename,
            "detections": result["detections"],
            "processing_time": result["processing_time"]
        })
        # Trim in place to the 100 most recent entries.
        del detection_history[:-100]

        return JSONResponse({
            "success": True,
            "detections": result["detections"],
            "image": img_str,
            "processing_time": result["processing_time"],
            "timestamp": result["timestamp"]
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def detect_batch(files: List[UploadFile] = File(...), model: str = "yolov5s", confidence: float = 0.35):
    """Run detection on each uploaded file independently.

    Each file gets its own entry in ``results``: the detections, annotated
    base64 JPEG and processing time on success, or ``{"filename", "error"}``
    when that file fails. One bad file never aborts the others, and the
    response always reports ``success: True``.
    """
    per_file = []
    for upload in files:
        try:
            raw = await upload.read()
            decoded = Image.open(io.BytesIO(raw)).convert("RGB")
            outcome = process_image(decoded, model, confidence)
            jpeg = io.BytesIO()
            outcome["image"].save(jpeg, format="JPEG")
            encoded = base64.b64encode(jpeg.getvalue()).decode()
            entry = {
                "filename": upload.filename,
                "detections": outcome["detections"],
                "image": encoded,
                "processing_time": outcome["processing_time"],
            }
        except Exception as err:
            entry = {"filename": upload.filename, "error": str(err)}
        per_file.append(entry)
    return JSONResponse({"success": True, "results": per_file})
async def get_statistics():
    """Summarize global statistics, adding mean processing time and derived FPS."""
    samples = statistics["processing_times"]
    if samples:
        mean_time = sum(samples) / len(samples)
    else:
        mean_time = 0
    copied_keys = ("total_detections", "total_images_processed", "objects_detected", "avg_confidence")
    payload = {key: statistics[key] for key in copied_keys}
    payload["avg_processing_time"] = mean_time
    payload["fps"] = 1 / mean_time if mean_time > 0 else 0
    return JSONResponse(payload)
async def get_history(limit: int = 50):
    """Return the most recent *limit* history entries, oldest first.

    ``limit <= 0`` yields an empty list. A bare ``history[-limit:]`` slice
    would return the whole history for ``limit == 0`` and, for a negative
    limit, drop the oldest entries instead of returning none.
    """
    count = max(limit, 0)
    recent = detection_history[-count:] if count else []
    return JSONResponse({
        "history": recent
    })
async def clear_history():
    """Empty the in-memory detection history (in place) and confirm."""
    del detection_history[:]
    confirmation = {"success": True, "message": "History cleared"}
    return JSONResponse(confirmation)
async def list_models():
    """List the YOLOv5 model variants advertised to clients (n, s, m, l, x)."""
    return JSONResponse({"models": [f"yolov5{size}" for size in "nsmlx"]})
async def websocket_endpoint(websocket: WebSocket):
    """Real-time detection loop over a WebSocket.

    Each incoming text message is a JSON object with ``"frame"`` (an image as a
    ``data:image/...;base64,`` URL or as bare base64) and optional ``"model"``
    and ``"confidence"``. Each frame is answered with its detections, the
    annotated frame as a JPEG data URL, the processing time and a timestamp.
    The loop ends when the client disconnects or a frame cannot be handled.
    The connection is always unregistered from ``manager``.
    """
    await manager.connect(websocket)
    try:
        while True:
            frame_data = json.loads(await websocket.receive_text())

            # Accept both data URLs and bare base64; split(",")[1] used to
            # raise IndexError (ending the session) on bare base64 input.
            encoded = frame_data["frame"]
            _, sep, payload = encoded.partition(",")
            img_data = base64.b64decode(payload if sep else encoded)
            image = Image.open(io.BytesIO(img_data)).convert("RGB")

            model_name = frame_data.get("model", "yolov5s")
            raw_conf = frame_data.get("confidence")
            try:
                # JSON clients may send the threshold as a number, a numeric
                # string or null; anything unusable falls back to the default.
                confidence = 0.25 if raw_conf is None else float(raw_conf)
            except (TypeError, ValueError):
                confidence = 0.25
            result = process_image(image, model_name, confidence)

            buffered = io.BytesIO()
            result["image"].save(buffered, format="JPEG", quality=85)
            img_str = base64.b64encode(buffered.getvalue()).decode()

            await websocket.send_json({
                "detections": result["detections"],
                "image": f"data:image/jpeg;base64,{img_str}",
                "processing_time": result["processing_time"],
                "timestamp": result["timestamp"]
            })
    except WebSocketDisconnect:
        pass  # normal client close; cleanup happens in `finally`
    except Exception as e:
        print(f"WebSocket error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Runs on every exit path, including task cancellation, so the
        # registry never keeps a dead socket.
        manager.disconnect(websocket)
if __name__ == "__main__":
    print("🚀 Démarrage du serveur ZKA...")
    if not YOLOV5_AVAILABLE:
        print("❌ YOLOv5 non disponible. Vérifiez l'installation.")
        sys.exit(1)

    # Warm the model cache so the first request does not pay the load cost.
    # A failure is not fatal: load_model() is called again on each request.
    print("\n📦 Préchargement du modèle IA...")
    try:
        load_model("yolov5s")
    except Exception as e:
        print(f"❌ Erreur lors du préchargement: {e}")
        print("⚠️ Le serveur démarre quand même, le modèle se chargera à la première utilisation\n")
    else:
        print("✅ Modèle chargé avec succès!\n")

    base_url = "http://localhost:8001"
    print(f"🔗 Documentation API: {base_url}/docs")
    print(f"🌐 Application ZKA: {base_url}")
    print("\n💡 Appuyez sur Ctrl+C pour arrêter le serveur\n")
    uvicorn.run(app, host="0.0.0.0", port=8001)