import os
import logging
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Literal

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)


async def load_item_model(app: FastAPI):
    """Load the item embedding model into app.state without blocking the event loop."""
    with context_logger(f"🤖 Loading model: {config.model.item_model_path}"):
        try:
            if config.model.item_model_access_token:
                # `HF_TOKEN` seems to be required in addition to the `token` parameter
                os.environ["HF_TOKEN"] = config.model.item_model_access_token
            # SentenceTransformer() downloads and initializes the model synchronously,
            # so run it in the default thread-pool executor.
            loop = asyncio.get_running_loop()
            app.state.item_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.item_model_path,
                    device=config.model.device,
                    token=config.model.item_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            app.state.item_model = None


async def load_scale_model(app: FastAPI):
    """Load the scale embedding model into app.state without blocking the event loop."""
    with context_logger(f"🤖 Loading model: {config.model.scale_model_path}"):
        try:
            if config.model.scale_model_access_token:
                os.environ["HF_TOKEN"] = config.model.scale_model_access_token
            loop = asyncio.get_running_loop()
            app.state.scale_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.scale_model_path,
                    device=config.model.device,
                    token=config.model.scale_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            app.state.scale_model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load both models once at startup; each loader stores its result on app.state
    # (or None on failure, which the /encode endpoint reports as 503).
    await load_item_model(app)
    await load_scale_model(app)
    yield


app = FastAPI(lifespan=lifespan)


class EncodeRequest(BaseModel):
    texts: list[str]
    mode: Literal["item", "scale"]


@app.post("/encode")
async def encode(encode_request: EncodeRequest, request: Request):
    model = (
        request.app.state.item_model
        if encode_request.mode == "item"
        else request.app.state.scale_model
    )
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"Model for request mode {encode_request.mode!r} not loaded!",
        )
    with context_logger(f"Encoding {len(encode_request.texts)} texts"):
        embeddings = model.encode(encode_request.texts, convert_to_tensor=True)
        return {"embeddings": embeddings.tolist()}
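

# Usage sketch (an assumption, not part of the service itself): with the app exposed
# as `app` in this module, it could be served with uvicorn and queried over HTTP.
# The module path `src.encode_service:app` and port 8000 are hypothetical placeholders.
#
#   uvicorn src.encode_service:app --port 8000
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/encode",
#       json={"texts": ["an example sentence"], "mode": "item"},
#   )
#   vectors = resp.json()["embeddings"]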