Spaces:
Sleeping
Sleeping
| """ | |
| Production-focused FastAPI wrapper for SmolVLM2 video highlights. | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from typing import Dict, Optional | |
| from fastapi import FastAPI, File, HTTPException, UploadFile | |
| from fastapi.concurrency import run_in_threadpool | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
# Set cache directories to writable locations for HuggingFace Spaces
# Use /tmp which is guaranteed to be writable in containers
CACHE_DIR = os.path.join("/tmp", ".cache", "huggingface")
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(os.path.join("/tmp", ".cache", "torch"), exist_ok=True)
# These must be exported BEFORE the detector import below, which presumably
# pulls in transformers/torch and freezes their cache locations on import.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["TORCH_HOME"] = os.path.join("/tmp", ".cache", "torch")
os.environ["XDG_CACHE_HOME"] = os.path.join("/tmp", ".cache")
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Add src directory to path for imports
sys.path.append(str(Path(__file__).parent / "src"))
try:
    from huggingface_exact_approach import VideoHighlightDetector
except ImportError:
    # Fail fast: the whole service is useless without the detector implementation.
    print("Cannot import huggingface_exact_approach.py")
    sys.exit(1)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Runtime configuration (each value overridable via environment variable)
APP_START_TIME = time.time()  # used by the health endpoint to report uptime
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL_NAME", "HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
MODEL_DEVICE = os.getenv("MODEL_DEVICE", "auto").lower()
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(512 * 1024 * 1024)))  # 512MB
MAX_CONCURRENT_JOBS = int(os.getenv("MAX_CONCURRENT_JOBS", "1"))
PROCESS_TIMEOUT_SECONDS = int(os.getenv("PROCESS_TIMEOUT_SECONDS", "3600"))
# Directories
TEMP_DIR = os.path.join("/tmp", "temp")  # uploaded videos land here temporarily
OUTPUTS_DIR = os.path.join("/tmp", "outputs")  # rendered highlights + analysis JSON
os.makedirs(OUTPUTS_DIR, mode=0o755, exist_ok=True)
os.makedirs(TEMP_DIR, mode=0o755, exist_ok=True)
# Validate the device choice up front so misconfiguration fails at startup,
# not on the first inference request.
if MODEL_DEVICE not in {"auto", "cpu", "cuda", "mps"}:
    raise RuntimeError(f"Invalid MODEL_DEVICE '{MODEL_DEVICE}'. Use auto/cpu/cuda/mps.")
class AnalysisResponse(BaseModel):
    """Response payload returned by the video upload/analysis endpoint."""

    success: bool  # True when processing completed without error
    message: str  # human-readable status summary
    video_description: str  # enriched description chosen from the highlight sets
    highlights: str  # filesystem path of the rendered highlights video
    analysis_file: str  # filesystem path of the persisted analysis JSON
| def _sentence_count(text: str) -> int: | |
| return len([s.strip() for s in re.split(r"[.!?]+", text or "") if s.strip()]) | |
def _device_for_detector() -> Optional[str]:
    """Map the configured MODEL_DEVICE to the detector's device argument.

    'auto' becomes None (let the detector pick); anything else is passed through.
    """
    if MODEL_DEVICE == "auto":
        return None
    return MODEL_DEVICE
class DetectorRegistry:
    """Caches one VideoHighlightDetector per model name (in-process singleton)."""

    def __init__(self) -> None:
        # model name -> loaded detector instance
        self._detectors: Dict[str, VideoHighlightDetector] = {}
        # serializes first-time loads so a model is never constructed twice
        self._lock = asyncio.Lock()

    async def get(self, model_name: str) -> VideoHighlightDetector:
        """Return the cached detector for *model_name*, loading it on first use.

        Model construction runs in a worker thread so the event loop stays
        responsive during the (potentially long) load.
        """
        cached = self._detectors.get(model_name)
        if cached is not None:
            return cached
        async with self._lock:
            # Re-check: another coroutine may have loaded it while we waited.
            cached = self._detectors.get(model_name)
            if cached is not None:
                return cached
            logger.info("Loading detector model '%s' (device=%s)", model_name, MODEL_DEVICE)
            detector = await run_in_threadpool(
                VideoHighlightDetector,
                model_name,
                _device_for_detector(),
                16,
            )
            self._detectors[model_name] = detector
            logger.info("Model '%s' loaded and cached", model_name)
            return detector

    async def warmup(self, model_name: str) -> None:
        """Eagerly load *model_name* so the first request doesn't pay the cost."""
        await self.get(model_name)

    def loaded_models(self) -> Dict[str, str]:
        """Map each loaded model name to the device its detector reports."""
        return {
            name: getattr(det, "device", "unknown")
            for name, det in self._detectors.items()
        }
# Process-wide singletons: the model cache and the job-limiting semaphore.
detector_registry = DetectorRegistry()
# NOTE(review): created at import time; on Python < 3.10 asyncio primitives
# bind to the loop current at construction — confirm the runtime is 3.10+.
processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_JOBS)
app = FastAPI(
    title="SmolVLM2 Optimized HuggingFace Video Highlights API",
    description="Generate intelligent video highlights using SmolVLM2 segment-based approach",
    version="3.0.0",
    # Interactive docs and the OpenAPI schema are disabled for this deployment.
    openapi_url=None,
    docs_url=None,
    redoc_url=None,
)
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # disallowed by the CORS spec (browsers reject `*` on credentialed
    # responses) — confirm whether credentials are actually required.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
async def _startup() -> None:
    """Log runtime configuration and warm up the default model (best effort).

    NOTE(review): no startup-hook registration (e.g. @app.on_event("startup"))
    is visible in this chunk — confirm it wasn't lost in formatting.
    """
    logger.info("Startup: default_model=%s, model_device=%s", DEFAULT_MODEL, MODEL_DEVICE)
    try:
        await detector_registry.warmup(DEFAULT_MODEL)
    except Exception:
        # Warmup is best effort: the model is lazily loaded on first request.
        logger.exception("Model warmup failed")
async def _save_upload_stream(upload: UploadFile, path: str) -> int:
    """Stream *upload* to *path* in 1 MiB chunks, enforcing MAX_UPLOAD_BYTES.

    Returns the total number of bytes written. Raises HTTP 413 as soon as the
    running total exceeds the limit; any partially written file is left on
    disk for the caller's cleanup path to remove.
    """
    total = 0
    with open(path, "wb") as sink:
        while chunk := await upload.read(1024 * 1024):
            total += len(chunk)
            if total > MAX_UPLOAD_BYTES:
                raise HTTPException(
                    status_code=413,
                    detail=f"Uploaded file too large. Max size is {MAX_UPLOAD_BYTES} bytes.",
                )
            sink.write(chunk)
    return total
async def _acquire_processing_slot() -> None:
    """Acquire a slot from the processing semaphore or fail fast with HTTP 429.

    The 50 ms timeout means a saturated server rejects immediately instead of
    queueing new requests behind long-running video jobs.
    """
    try:
        await asyncio.wait_for(processing_semaphore.acquire(), timeout=0.05)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=429, detail="Server is busy. Try again shortly.")
async def health_check():
    """Liveness report: status, uptime, and the currently resident models.

    NOTE(review): route decorator not visible in this chunk — confirm the
    endpoint registration wasn't lost in formatting.
    """
    uptime_seconds = int(time.time() - APP_START_TIME)
    payload = {
        "status": "healthy",
        "uptime_seconds": uptime_seconds,
        "default_model": DEFAULT_MODEL,
        "loaded_models": detector_registry.loaded_models(),
    }
    return payload
async def root():
    """Service index: a small map of the useful endpoints."""
    return dict(
        service="SmolVLM2 Video Highlights API",
        status="ok",
        health="/health",
        ready="/ready",
        upload="/upload-video",
    )
async def readiness_check():
    """Readiness report: 'ready' only once the default model has been loaded."""
    loaded = detector_registry.loaded_models()
    status = "ready" if DEFAULT_MODEL in loaded else "not_ready"
    return {
        "status": status,
        "default_model": DEFAULT_MODEL,
        "loaded_models": loaded,
    }
async def get_output_file(filename: str):
    """Serve a generated file from OUTPUTS_DIR.

    basename() strips any directory components from the client-supplied name,
    which blocks path-traversal attempts. Unknown names return HTTP 404.
    """
    safe_name = os.path.basename(filename)
    target = os.path.join(OUTPUTS_DIR, safe_name)
    if not os.path.exists(target):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path=target, filename=safe_name)
async def upload_video(
    video: UploadFile = File(...),
    segment_length: float = 5.0,
    model_name: str = DEFAULT_MODEL,
    with_effects: bool = True,
):
    """Process an uploaded video and return its description plus output paths.

    Streams the upload to a temp file, runs the detector in a worker thread
    (bounded by the concurrency semaphore and a wall-clock timeout), picks the
    richest description from the model's candidate highlight sets, persists
    the full analysis as JSON, and always releases the slot and removes the
    temp upload.

    NOTE(review): no route decorator is visible in this chunk — confirm the
    POST registration (e.g. @app.post("/upload-video")) wasn't lost in
    formatting.
    """
    if not video.content_type or not video.content_type.startswith("video/"):
        raise HTTPException(status_code=400, detail="File must be a video")
    if segment_length <= 0:
        raise HTTPException(status_code=400, detail="segment_length must be > 0")
    # A unique job id keeps concurrent uploads from clobbering each other's files.
    job_id = str(uuid.uuid4())
    temp_video_path = os.path.join(TEMP_DIR, f"{job_id}_input.mp4")
    output_filename = f"{job_id}_highlights.mp4"
    analysis_filename = f"{job_id}_analysis.json"
    output_path = os.path.join(OUTPUTS_DIR, output_filename)
    analysis_path = os.path.join(OUTPUTS_DIR, analysis_filename)
    # Rejects immediately with 429 when all processing slots are taken.
    await _acquire_processing_slot()
    try:
        await _save_upload_stream(video, temp_video_path)
        detector = await detector_registry.get(model_name)
        # Heavy synchronous model work runs in a thread so the event loop
        # stays responsive; wait_for enforces the processing deadline.
        results = await asyncio.wait_for(
            run_in_threadpool(
                detector.process_video,
                temp_video_path,
                output_path,
                segment_length,
                with_effects,
            ),
            timeout=PROCESS_TIMEOUT_SECONDS,
        )
        if "error" in results:
            raise HTTPException(status_code=500, detail=results["error"])
        # The detector appears to propose two candidate highlight descriptions
        # ("highlights1"/"highlights2") and may mark which set it selected.
        # Prefer the explicitly selected set; otherwise fall back to whichever
        # candidate contains more sentences.
        selected_set = str(results.get("selected_set", "")).strip()
        h1 = results.get("highlights1", "")
        h2 = results.get("highlights2", "")
        base_desc = results.get("video_description", "")
        if selected_set == "1":
            enriched_description = h1
        elif selected_set == "2":
            enriched_description = h2
        else:
            enriched_description = h1 or h2 or base_desc
            if _sentence_count(h1) > _sentence_count(enriched_description):
                enriched_description = h1
            if _sentence_count(h2) > _sentence_count(enriched_description):
                enriched_description = h2
        # Last-resort fallback to the plain description if everything was empty.
        if not enriched_description:
            enriched_description = base_desc
        logger.info(
            "API response selected_set=%s video_description=%s",
            selected_set or "fallback",
            enriched_description,
        )
        results["video_description"] = enriched_description
        # Persist the complete analysis so it can be fetched again later.
        with open(analysis_path, "w") as f:
            json.dump(results, f, indent=2)
        return AnalysisResponse(
            success=True,
            message="Video description generated successfully",
            video_description=enriched_description,
            highlights=f"/tmp/outputs/{output_filename}",
            analysis_file=f"/tmp/outputs/{analysis_filename}",
        )
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Processing timed out")
    except HTTPException:
        # Re-raise untouched so FastAPI serves the intended status code.
        raise
    except Exception as e:
        logger.exception("Upload processing failed")
        raise HTTPException(status_code=500, detail=f"Failed to process upload: {str(e)}")
    finally:
        # Always free the slot and remove the temp upload, even on failure.
        processing_semaphore.release()
        try:
            await video.close()
        except Exception:
            pass
        if os.path.exists(temp_video_path):
            os.unlink(temp_video_path)
if __name__ == "__main__":
    # Local/dev entry point; port 7860 is the HuggingFace Spaces convention.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)