# NOTE(review): the following metadata lines were code-viewer artifacts
# (file size, commit hash, line-number gutter) accidentally captured with
# the source; commented out so the module parses.
import asyncio
import functools
import logging
import os
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Literal

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from src.config import config
from src.utils.logging import context_logger
logger = logging.getLogger(__name__)
async def load_item_model(app: FastAPI) -> None:
    """Load the item-encoding SentenceTransformer onto ``app.state.item_model``.

    The blocking model construction runs in the default thread-pool executor
    so the event loop stays responsive during startup. On any failure the
    error is logged and ``app.state.item_model`` is set to ``None`` (callers
    must check for ``None`` before use) rather than aborting startup.

    NOTE: the previous ``@lru_cache()`` decorator was removed — ``lru_cache``
    on an ``async def`` caches the coroutine object itself, so a second call
    would return an already-awaited coroutine and raise ``RuntimeError``.
    """
    with context_logger(f"π€ Loading model: {config.model.item_model_path}"):
        try:
            if config.model.item_model_access_token:
                # `HF_TOKEN` seems to be required in addition to the `token` parameter
                os.environ["HF_TOKEN"] = config.model.item_model_access_token
            # get_running_loop() is the correct (non-deprecated) call inside a coroutine.
            loop = asyncio.get_running_loop()
            app.state.item_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.item_model_path,
                    device=config.model.device,
                    token=config.model.item_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"β Failed to load model: {e}")
            app.state.item_model = None
async def load_scale_model(app: FastAPI) -> None:
    """Load the scale-encoding SentenceTransformer onto ``app.state.scale_model``.

    Mirrors ``load_item_model``: the blocking load runs in the default
    thread-pool executor; on failure the error is logged and
    ``app.state.scale_model`` is set to ``None`` instead of raising.

    NOTE: the previous ``@lru_cache()`` decorator was removed — ``lru_cache``
    on an ``async def`` caches the coroutine object itself, so a second call
    would return an already-awaited coroutine and raise ``RuntimeError``.
    """
    with context_logger(f"π€ Loading model: {config.model.scale_model_path}"):
        try:
            # Setting HF_TOKEN now happens inside the try, matching load_item_model.
            if config.model.scale_model_access_token:
                os.environ["HF_TOKEN"] = config.model.scale_model_access_token
            # get_running_loop() is the correct (non-deprecated) call inside a coroutine.
            loop = asyncio.get_running_loop()
            app.state.scale_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.scale_model_path,
                    device=config.model.device,
                    token=config.model.scale_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"β Failed to load model: {e}")
            app.state.scale_model = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load both models at startup.

    Model-load failures are swallowed inside the loaders (they set the
    corresponding ``app.state`` attribute to ``None``), so startup never
    raises here. No shutdown cleanup is needed, hence no try/finally —
    the previous ``try: ... finally: pass`` was a no-op.
    """
    await load_item_model(app)
    await load_scale_model(app)
    yield
# Application instance; lifespan loads both models into app.state at startup.
app = FastAPI(lifespan=lifespan)
class EncodeRequest(BaseModel):
    # Texts to embed in a single batch.
    texts: list[str]
    # Selects which model encodes: "item" -> app.state.item_model,
    # "scale" -> app.state.scale_model.
    mode: Literal["item", "scale"]
@app.post("/encode")
async def encode(encode_request: EncodeRequest, request: Request):
    """Encode a batch of texts with the model selected by ``mode``.

    Returns ``{"embeddings": [...]}`` where embeddings is a nested list
    (tensor converted via ``.tolist()``).

    Raises:
        HTTPException: 503 when the requested model failed to load at startup
            (the loaders set ``app.state.<mode>_model`` to ``None`` on failure).
    """
    if encode_request.mode == "item":
        model = request.app.state.item_model
    else:
        model = request.app.state.scale_model
    if model is None:
        # 503 (not an opaque 500 from a bare RuntimeError) tells the client
        # the service is up but this model is unavailable.
        raise HTTPException(
            status_code=503,
            detail=f"Model for request mode {encode_request.mode} not loaded!",
        )
    with context_logger(f"Encoding {len(encode_request.texts)} texts"):
        # model.encode is blocking CPU/GPU work; run it in the default
        # executor so the event loop stays responsive, consistent with
        # how the model loaders already offload their blocking calls.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None,
            functools.partial(
                model.encode, encode_request.texts, convert_to_tensor=True
            ),
        )
    return {"embeddings": embeddings.tolist()}