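"""FastAPI service exposing SentenceTransformer text embeddings.

An "item" model and a "scale" model are loaded at startup (paths and tokens
come from src.config) and both are served through a single POST /encode
endpoint.
"""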
import asyncio
import functools
import logging
import os
from contextlib import asynccontextmanager
from typing import Literal

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)


async def load_item_model(app: FastAPI) -> None:
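    """Load the item embedding model and store it on app.state.item_model (None on failure)."""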
    with context_logger(f"🤖 Loading model: {config.model.item_model_path}"):
        try:
            if config.model.item_model_access_token:
                os.environ["HF_TOKEN"] = config.model.item_model_access_token

            # SentenceTransformer construction is blocking (it may download and
            # load weights onto the device), so run it in the default executor
            # to keep the event loop responsive during startup.
            loop = asyncio.get_running_loop()
            app.state.item_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.item_model_path,
                    device=config.model.device,
                    token=config.model.item_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            app.state.item_model = None


async def load_scale_model(app: FastAPI) -> None:
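    """Load the scale embedding model and store it on app.state.scale_model (None on failure)."""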
    with context_logger(f"🤖 Loading model: {config.model.scale_model_path}"):
        if config.model.scale_model_access_token:
            os.environ["HF_TOKEN"] = config.model.scale_model_access_token

        try:
            loop = asyncio.get_running_loop()
            app.state.scale_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.scale_model_path,
                    device=config.model.device,
                    token=config.model.scale_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            app.state.scale_model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
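    """Load both models before the application starts serving requests."""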
    await load_item_model(app)
    await load_scale_model(app)
    yield


app = FastAPI(lifespan=lifespan)


class EncodeRequest(BaseModel):
    texts: list[str]
    mode: Literal["item", "scale"]


@app.post("/encode")
async def encode(encode_request: EncodeRequest, request: Request):
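    """Encode the request texts with the model selected by `mode` and return the embeddings."""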
    if encode_request.mode == "item":
        model = request.app.state.item_model
    else:
        model = request.app.state.scale_model

    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"Model for request mode {encode_request.mode} not loaded!",
        )

    with context_logger(f"Encoding {len(encode_request.texts)} texts"):
        embeddings = model.encode(encode_request.texts, convert_to_tensor=True)
        return {"embeddings": embeddings.tolist()}
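

# Example request (illustrative only; assumes the service is run locally with
# something like `uvicorn main:app --port 8000`, adjust module path and port
# to match your setup):
#
#   curl -X POST http://localhost:8000/encode \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello world"], "mode": "item"}'
#
# The response contains one embedding per input text:
#   {"embeddings": [[0.01, -0.02, ...]]}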