Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import concurrent.futures | |
| # import io # DISABLED (OOM mitigation) — only used by vision | |
| import os | |
| import time | |
| from typing import Annotated, List, Optional | |
| # import httpx # DISABLED (OOM mitigation) — only used by vision | |
| from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status | |
| # from PIL import Image # DISABLED (OOM mitigation) | |
| from app.api.deps import require_auth, get_embeddings_service | |
| from app.config import get_settings | |
| from app.core.logger import get_logger | |
| from app.models.schemas import EmbeddingItem, EmbeddingRequest, EmbeddingResponse # , VisionUrlRequest # DISABLED (OOM mitigation) | |
| from app.services.embeddings_service import EmbeddingService | |
router = APIRouter()
_logger = get_logger(__name__)
_settings = get_settings()

# Shared executor for the blocking model calls made by the handlers below.
# Its size uses the min(32, cpu_count + 4) heuristic. os.cpu_count() may
# return None, in which case a single CPU is assumed.
_MAX_WORKERS = min((os.cpu_count() or 1) + 4, 32)
_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=_MAX_WORKERS)
| # _MAX_VISION_ITEMS = 5 # DISABLED (OOM mitigation) | |
| # _MAX_IMAGE_BYTES = 15 * 1024 * 1024 # DISABLED (OOM mitigation) | |
| # def _validate_image(raw: bytes, source: str) -> Image.Image: # DISABLED (OOM mitigation) | |
| # # FIX: Check if the file is completely empty (0 bytes) | |
| # if not raw: | |
| # raise ValueError(f"File {source} is empty (0 bytes).") | |
| # | |
| # if len(raw) > _MAX_IMAGE_BYTES: | |
| # raise ValueError(f"Image {source} exceeds 15 MB limit") | |
| # | |
| # try: | |
| # img = Image.open(io.BytesIO(raw)) | |
| # img.load() | |
| # if img.mode != "RGB": | |
| # img = img.convert("RGB") | |
| # return img | |
| # except Exception as exc: | |
| # raise ValueError(f"Invalid image {source}: {exc}") | |
| # async def _download_image(url: str) -> bytes: # DISABLED (OOM mitigation) | |
| # try: | |
| # async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: | |
| # resp = await client.get(url) | |
| # resp.raise_for_status() | |
| # ctype = resp.headers.get("content-type", "") | |
| # if not ctype.startswith("image/"): | |
| # raise ValueError(f"URL {url} returned non-image Content-Type: {ctype}") | |
| # return resp.content | |
| # except httpx.HTTPError as exc: | |
| # raise ValueError(f"Failed to download {url}: {exc}") | |
# NOTE(review): no `@router.post(...)` decorator is visible on this handler
# (the disabled vision handlers in this file do have one), so as shown it is
# not registered on `router`. Confirm it is wired up elsewhere (e.g. via
# `router.add_api_route`) or restore the decorator.
async def create_embeddings(
    body: EmbeddingRequest,
    token: str = Depends(require_auth),
    embedding_service: EmbeddingService = Depends(get_embeddings_service),
) -> EmbeddingResponse:
    """Generate text embeddings for every item in ``body.content``.

    The blocking ``embedding_service.generate_embedding`` call runs on the
    module-level thread pool, so the event loop is not blocked during inference.

    Args:
        body: Request carrying the items to embed (``content``) and the
            requested embedding ``dimension``.
        token: Result of the ``require_auth`` dependency. It is not read here;
            declaring it is what makes the auth check run for this handler.
        embedding_service: Service that owns the loaded embedding models.

    Returns:
        An ``EmbeddingResponse``.

        - On success, it holds one successful ``EmbeddingItem`` per returned
          vector, each reporting the average per-item latency.
        - If the model call raises, it has ``success=False`` and one failed
          ``EmbeddingItem`` per input item, each carrying the error text.

    Raises:
        HTTPException: 503 when no model is loaded for ``body.dimension``.
    """
    item_count = len(body.content)
    _logger.info("Embedding request: dim=%s, items=%s", body.dimension, item_count)

    if not embedding_service.is_loaded(body.dimension):
        _logger.error("Model dim=%s not loaded. Loaded: %s", body.dimension, embedding_service.loaded_dimensions)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={
                "success": False,
                "message": f"Model for dimension {body.dimension} not loaded. Loaded: {embedding_service.loaded_dimensions}",
            },
        )

    start = time.perf_counter()
    try:
        loop = asyncio.get_running_loop()
        vectors = await loop.run_in_executor(
            _thread_pool,
            embedding_service.generate_embedding,
            body.content,
            body.dimension,
        )
    except Exception as exc:
        # Boundary handler: report the failure in the response body instead of
        # raising. logger.exception records the traceback, which error() dropped.
        elapsed = (time.perf_counter() - start) * 1000
        _logger.exception("Embedding error: %s", exc)
        message = str(exc)
        return EmbeddingResponse(
            success=False,
            time_ms=round(elapsed, 3),
            success_count=0,
            failed_count=item_count,
            error_message=message,
            results=[
                EmbeddingItem(success=False, time_ms=0, error_message=message)
                for _ in body.content
            ],
        )

    total_ms = (time.perf_counter() - start) * 1000
    # Guard the average. An empty batch used to raise ZeroDivisionError here,
    # turning a valid (if trivial) request into an unhandled 500.
    per_item_ms = total_ms / item_count if item_count else 0.0
    results = [
        EmbeddingItem(
            success=True,
            time_ms=round(per_item_ms, 3),
            embeddings=vec,
            dimension=body.dimension,
        )
        for vec in vectors
    ]

    _logger.info("Embedding success: dim=%s, items=%s, total_ms=%s", body.dimension, len(results), round(total_ms, 3))
    return EmbeddingResponse(
        success=True,
        time_ms=round(total_ms, 3),
        success_count=len(results),
        failed_count=0,
        results=results,
    )
| # @router.post( # DISABLED (OOM mitigation) | |
| # "/embeddings/vision/file", | |
| # response_model=EmbeddingResponse, | |
| # summary="Generate embeddings from uploaded images", | |
| # ) | |
| # async def create_vision_embeddings_file( | |
| # files: Annotated[List[UploadFile], File(description="Image files to embed (max 5)")], | |
| # token: str = Depends(require_auth), | |
| # embedding_service: EmbeddingService = Depends(get_embeddings_service), | |
| # ) -> EmbeddingResponse: | |
| # if not files: | |
| # raise HTTPException(status_code=400, detail={"success": False, "message": "No files provided."}) | |
| # if len(files) > _MAX_VISION_ITEMS: | |
| # raise HTTPException(status_code=400, detail={"success": False, "message": f"Maximum {_MAX_VISION_ITEMS} images per request."}) | |
| # | |
| # if not embedding_service._vision_loaded: | |
| # raise HTTPException(status_code=503, detail={"success": False, "message": "Vision model not loaded."}) | |
| # | |
| # _logger.info("Vision embedding file request: files=%s", len(files)) | |
| # | |
| # dim = embedding_service.vision_dimension | |
| # start = time.perf_counter() | |
| # images: List[Image.Image] = [] | |
| # item_results: List[EmbeddingItem] = [] | |
| # | |
| # for f in files: | |
| # t0 = time.perf_counter() | |
| # try: | |
| # # FIX: Guarantee the file cursor is at the beginning before reading! | |
| # await f.seek(0) | |
| # raw = await f.read() | |
| # | |
| # img = await asyncio.get_running_loop().run_in_executor(_thread_pool, _validate_image, raw, f.filename or "unknown") | |
| # images.append(img) | |
| # except Exception as exc: | |
| # elapsed = (time.perf_counter() - t0) * 1000 | |
| # item_results.append(EmbeddingItem( | |
| # success=False, | |
| # time_ms=round(elapsed, 3), | |
| # error_message=str(exc), | |
| # )) | |
| # | |
| # if not images: | |
| # total_ms = (time.perf_counter() - start) * 1000 | |
| # return EmbeddingResponse( | |
| # success=False, | |
| # time_ms=round(total_ms, 3), | |
| # success_count=0, | |
| # failed_count=len(item_results), | |
| # error_message="No valid images could be processed.", | |
| # results=item_results, | |
| # ) | |
| # | |
| # try: | |
| # loop = asyncio.get_running_loop() | |
| # vectors = await loop.run_in_executor( | |
| # _thread_pool, | |
| # embedding_service.generate_image_embedding, | |
| # images, | |
| # ) | |
| # except Exception as exc: | |
| # elapsed = (time.perf_counter() - start) * 1000 | |
| # _logger.error("Vision embedding error: %s", exc) | |
| # for _ in range(len(images) - len(item_results)): | |
| # item_results.append(EmbeddingItem(success=False, time_ms=0, error_message=str(exc))) | |
| # return EmbeddingResponse( | |
| # success=False, | |
| # time_ms=round(elapsed, 3), | |
| # success_count=0, | |
| # failed_count=len(item_results), | |
| # error_message=str(exc), | |
| # results=item_results, | |
| # ) | |
| # | |
| # total_ms = (time.perf_counter() - start) * 1000 | |
| # for i, vec in enumerate(vectors): | |
| # item_results.append(EmbeddingItem( | |
| # success=True, | |
| # time_ms=round(total_ms / len(vectors), 3), | |
| # embeddings=vec, | |
| # dimension=dim, | |
| # )) | |
| # | |
| # success_count = sum(1 for r in item_results if r.success) | |
| # failed_count = len(item_results) - success_count | |
| # _logger.info("Vision embedding success: items=%s, success=%s, failed=%s, total_ms=%s", | |
| # len(item_results), success_count, failed_count, round(total_ms, 3)) | |
| # return EmbeddingResponse( | |
| # success=failed_count == 0, | |
| # time_ms=round(total_ms, 3), | |
| # success_count=success_count, | |
| # failed_count=failed_count, | |
| # results=item_results, | |
| # ) | |
| # | |
| # | |
| # @router.post( # DISABLED (OOM mitigation) | |
| # "/embeddings/vision/url", | |
| # response_model=EmbeddingResponse, | |
| # summary="Generate embeddings from image URLs", | |
| # ) | |
| # async def create_vision_embeddings_url( | |
| # body: VisionUrlRequest, | |
| # token: str = Depends(require_auth), | |
| # embedding_service: EmbeddingService = Depends(get_embeddings_service), | |
| # ) -> EmbeddingResponse: | |
| # if not embedding_service._vision_loaded: | |
| # raise HTTPException(status_code=503, detail={"success": False, "message": "Vision model not loaded."}) | |
| # | |
| # _logger.info("Vision embedding URL request: urls=%s", len(body.urls)) | |
| # | |
| # dim = embedding_service.vision_dimension | |
| # start = time.perf_counter() | |
| # images: List[Image.Image] = [] | |
| # item_results: List[EmbeddingItem] = [] | |
| # | |
| # for url in body.urls: | |
| # t0 = time.perf_counter() | |
| # try: | |
| # raw = await _download_image(url) | |
| # img = await asyncio.get_running_loop().run_in_executor(_thread_pool, _validate_image, raw, url) | |
| # images.append(img) | |
| # except Exception as exc: | |
| # elapsed = (time.perf_counter() - t0) * 1000 | |
| # item_results.append(EmbeddingItem( | |
| # success=False, | |
| # time_ms=round(elapsed, 3), | |
| # error_message=str(exc), | |
| # )) | |
| # | |
| # if not images: | |
| # total_ms = (time.perf_counter() - start) * 1000 | |
| # return EmbeddingResponse( | |
| # success=False, | |
| # time_ms=round(total_ms, 3), | |
| # success_count=0, | |
| # failed_count=len(item_results), | |
| # error_message="No valid images could be downloaded.", | |
| # results=item_results, | |
| # ) | |
| # | |
| # try: | |
| # loop = asyncio.get_running_loop() | |
| # vectors = await loop.run_in_executor( | |
| # _thread_pool, | |
| # embedding_service.generate_image_embedding, | |
| # images, | |
| # ) | |
| # except Exception as exc: | |
| # elapsed = (time.perf_counter() - start) * 1000 | |
| # _logger.error("Vision embedding error: %s", exc) | |
| # for _ in range(len(images) - len(item_results)): | |
| # item_results.append(EmbeddingItem(success=False, time_ms=0, error_message=str(exc))) | |
| # return EmbeddingResponse( | |
| # success=False, | |
| # time_ms=round(elapsed, 3), | |
| # success_count=0, | |
| # failed_count=len(item_results), | |
| # error_message=str(exc), | |
| # results=item_results, | |
| # ) | |
| # | |
| # total_ms = (time.perf_counter() - start) * 1000 | |
| # for i, vec in enumerate(vectors): | |
| # item_results.append(EmbeddingItem( | |
| # success=True, | |
| # time_ms=round(total_ms / len(vectors), 3), | |
| # embeddings=vec, | |
| # dimension=dim, | |
| # )) | |
| # | |
| # success_count = sum(1 for r in item_results if r.success) | |
| # failed_count = len(item_results) - success_count | |
| # _logger.info("Vision embedding success: items=%s, success=%s, failed=%s, total_ms=%s", | |
| # len(item_results), success_count, failed_count, round(total_ms, 3)) | |
| # return EmbeddingResponse( | |
| # success=failed_count == 0, | |
| # time_ms=round(total_ms, 3), | |
| # success_count=success_count, | |
| # failed_count=failed_count, | |
| # results=item_results, | |
| # ) |