# src/servers/model_server.py — synced from GitHub by CI (github-actions, commit 6ca4b94)
import os
import logging
import asyncio
import functools
from functools import lru_cache
from fastapi import FastAPI, Request
from pydantic import BaseModel
from contextlib import asynccontextmanager
from sentence_transformers import SentenceTransformer
from typing import Literal
from src.config import config
from src.utils.logging import context_logger
logger = logging.getLogger(__name__)
async def load_item_model(app: FastAPI) -> None:
    """Load the item-embedding SentenceTransformer into ``app.state.item_model``.

    The blocking model download/initialization runs in a worker thread via
    ``run_in_executor`` so the event loop stays responsive during startup.
    On any failure the error is logged and ``app.state.item_model`` is set to
    ``None``; the ``/encode`` endpoint checks for that sentinel.

    NOTE(review): the previous ``@lru_cache()`` decorator was removed — applying
    ``lru_cache`` to an async function caches the *coroutine object*, so a second
    call with the same argument returns an already-awaited coroutine and raises
    ``RuntimeError``. Idempotence, if needed, should be handled by the caller.
    """
    with context_logger(f"πŸ€– Loading model: {config.model.item_model_path}"):
        try:
            if config.model.item_model_access_token:
                # `HF_TOKEN` seems to be required in addition to the `token` parameter
                os.environ["HF_TOKEN"] = config.model.item_model_access_token
            # get_running_loop() replaces the deprecated get_event_loop()
            # (we are always inside a coroutine here).
            loop = asyncio.get_running_loop()
            app.state.item_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.item_model_path,
                    device=config.model.device,
                    token=config.model.item_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            # logger.exception keeps the traceback; message text unchanged.
            logger.exception(f"❌ Failed to load model: {e}")
            app.state.item_model = None
async def load_scale_model(app: FastAPI) -> None:
    """Load the scale-embedding SentenceTransformer into ``app.state.scale_model``.

    The blocking model download/initialization runs in a worker thread via
    ``run_in_executor`` so the event loop stays responsive during startup.
    On any failure the error is logged and ``app.state.scale_model`` is set to
    ``None``; the ``/encode`` endpoint checks for that sentinel.

    NOTE(review): the previous ``@lru_cache()`` decorator was removed — applying
    ``lru_cache`` to an async function caches the *coroutine object*, so a second
    call with the same argument returns an already-awaited coroutine and raises
    ``RuntimeError``. Idempotence, if needed, should be handled by the caller.
    """
    with context_logger(f"πŸ€– Loading model: {config.model.scale_model_path}"):
        try:
            # Moved inside the try for consistency with load_item_model, so a
            # failure here is also caught and leaves the None sentinel set.
            if config.model.scale_model_access_token:
                # `HF_TOKEN` seems to be required in addition to the `token` parameter
                os.environ["HF_TOKEN"] = config.model.scale_model_access_token
            # get_running_loop() replaces the deprecated get_event_loop().
            loop = asyncio.get_running_loop()
            app.state.scale_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.scale_model_path,
                    device=config.model.device,
                    token=config.model.scale_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            # logger.exception keeps the traceback; message text unchanged.
            logger.exception(f"❌ Failed to load model: {e}")
            app.state.scale_model = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load both embedding models before serving.

    There is nothing to release on shutdown — the models are simply dropped
    with the process — so control just yields back after startup.
    """
    await load_item_model(app)
    await load_scale_model(app)
    yield
# Application instance; the lifespan hook loads the models on startup.
app = FastAPI(lifespan=lifespan)
class EncodeRequest(BaseModel):
    """Request payload for the ``/encode`` endpoint."""

    # Texts to embed in a single batch.
    texts: list[str]
    # Selects which loaded model to use: app.state.item_model or app.state.scale_model.
    mode: Literal["item", "scale"]
@app.post("/encode")
async def encode(encode_request: EncodeRequest, request: Request):
    """Embed the request's texts with the selected model.

    Returns ``{"embeddings": [...]}`` where each entry is the embedding of the
    corresponding input text. Raises ``RuntimeError`` (surfaced as a 500) when
    the requested model failed to load at startup (state attribute is None).
    """
    if encode_request.mode == "item":
        model = request.app.state.item_model
    else:
        model = request.app.state.scale_model
    if model is None:
        raise RuntimeError(f"Model for request mode {encode_request.mode} not loaded!")
    with context_logger(f"Encoding {len(encode_request.texts)} texts"):
        # model.encode is synchronous and CPU/GPU-bound; run it in a worker
        # thread (as the model loaders do) so a large batch does not block the
        # event loop for every other in-flight request.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None,
            functools.partial(model.encode, encode_request.texts, convert_to_tensor=True),
        )
    return {"embeddings": embeddings.tolist()}