"""FastAPI embedding service: encodes texts with an "item" or a "scale" SentenceTransformer model."""

import os
import logging
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Literal

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)

async def load_item_model(app: FastAPI) -> None:
    """Load the item encoder once at startup and store it on `app.state`.

    Note: `lru_cache` must not be used on a coroutine function; it caches the
    coroutine object itself and fails on any repeated call. The lifespan hook
    already ensures this runs only once.
    """

    with context_logger(f"πŸ€– Loading model: {config.model.item_model_path}"):

        try:
            if config.model.item_model_access_token:
                # `HF_TOKEN` seems to be required in addition to the `token` parameter
                os.environ["HF_TOKEN"] = config.model.item_model_access_token

            # SentenceTransformer() blocks while downloading/loading weights, so
            # run it in the default executor to keep the event loop responsive.
            loop = asyncio.get_running_loop()
            app.state.item_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.item_model_path,
                    device=config.model.device,
                    token=config.model.item_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            app.state.item_model = None

async def load_scale_model(app: FastAPI) -> None:
    """Load the scale encoder once at startup and store it on `app.state`."""

    with context_logger(f"πŸ€– Loading model: {config.model.scale_model_path}"):

        try:
            if config.model.scale_model_access_token:
                os.environ["HF_TOKEN"] = config.model.scale_model_access_token

            loop = asyncio.get_running_loop()
            app.state.scale_model = await loop.run_in_executor(
                None,
                functools.partial(
                    SentenceTransformer,
                    model_name_or_path=config.model.scale_model_path,
                    device=config.model.device,
                    token=config.model.scale_model_access_token,
                    trust_remote_code=True,
                    cache_folder="cache",
                ),
            )
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            app.state.scale_model = None

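# For reference, the `config.model` fields read above imply a settings object
# roughly shaped like the sketch below. This is an assumption for illustration;
# the actual definition lives in `src.config` and may differ.
#
#   class ModelConfig(BaseModel):
#       item_model_path: str
#       scale_model_path: str
#       item_model_access_token: str | None = None
#       scale_model_access_token: str | None = None
#       device: str = "cpu"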

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load both models before the app starts accepting requests.
    await load_item_model(app)
    await load_scale_model(app)
    yield

app = FastAPI(lifespan=lifespan)

class EncodeRequest(BaseModel):
    texts: list[str]
    mode: Literal["item", "scale"]

@app.post("/encode")
async def encode(encode_request: EncodeRequest, request: Request):

    if encode_request.mode == "item":
        model = request.app.state.item_model
    else:
        model = request.app.state.scale_model

    if model is None:
        # Fail with a proper HTTP error instead of an unhandled 500.
        raise HTTPException(
            status_code=503,
            detail=f"Model for request mode '{encode_request.mode}' not loaded!",
        )

    with context_logger(f"Encoding {len(encode_request.texts)} texts"):
        embeddings = model.encode(encode_request.texts, convert_to_tensor=True)
        return {"embeddings": embeddings.tolist()}
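
# Example usage (illustrative sketch, not part of the service). The module path
# `src.server` and the port are assumptions; adjust them to your layout.
#
# Start the server:
#
#   uvicorn src.server:app --host 0.0.0.0 --port 8000
#
# Then call the endpoint, e.g. with httpx:
#
#   import httpx
#
#   response = httpx.post(
#       "http://localhost:8000/encode",
#       json={"texts": ["a sentence to embed"], "mode": "item"},
#   )
#   embeddings = response.json()["embeddings"]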