File size: 8,482 Bytes
fea62df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from contextlib import asynccontextmanager
import time
from typing import Union, Optional

from fastapi import FastAPI, HTTPException
from loguru import logger

from core import ModelManager
from models import (
    EmbedRequest,
    EmbedResponse,
    BatchEmbedRequest,
    BatchEmbedResponse,
    SparseEmbedding,
    SparseEmbedResponse,
    BatchSparseEmbedResponse,
)


model_manager: Optional[ModelManager] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Async context manager that runs at application startup and shutdown.

    On startup, initializes a :class:`core.ModelManager` using the
    configuration file and preloads all models so the first inference call
    does not pay a cold-start penalty. On shutdown, it unloads all models to
    free memory.

    Args:
        app: FastAPI instance (passed by FastAPI runtime).

    Yields:
        None. This is an async contextmanager used by FastAPI's lifespan.
    """
    global model_manager
    
    logger.info("Starting embedding API...")
    try:
        model_manager = ModelManager("config.yaml")
        logger.info("Preloading all models...")
        model_manager.preload_all_models()
        logger.success("All models preloaded successfully!")
    except Exception:
        # log exception with traceback for easier debugging
        logger.exception("Failed to initialize models")
        raise
    
    yield
    
    # Shutdown
    logger.info("Shutting down embedding API...")
    if model_manager:
        model_manager.unload_all_models()



def create_app() -> FastAPI:
    """Create and return the FastAPI application instance.

    This function instantiates a temporary :class:`ModelManager` to generate
    the API description based on available models in the configuration, then
    deletes it to avoid keeping models loaded before the application lifespan
    runs.

    Returns:
        FastAPI: configured FastAPI application.
    """

    temp_manager = ModelManager("config.yaml")
    api_description = temp_manager.generate_api_description()

    # explicitly delete the temporary manager to avoid keeping models loaded
    del temp_manager

    return FastAPI(
        title="Unified Embedding API",
        description=api_description,
        version="3.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )

app = create_app()


@app.post("/embed", response_model=Union[EmbedResponse, SparseEmbedResponse])
async def create_embedding(request: EmbedRequest):
    """Create a single dense or sparse embedding for the given text.

    The request must include `model_id`. For sparse models (config type
    "sparse-embeddings") the endpoint returns a `SparseEmbedResponse`,
    otherwise a dense `EmbedResponse` is returned.

    Args:
        request: `EmbedRequest` pydantic model with text, prompt and model_id.

    Returns:
        Union[EmbedResponse, SparseEmbedResponse]: The embedding response.

    Raises:
        HTTPException: on validation or internal errors with appropriate
            HTTP status codes.
    """

    if not request.model_id:
        raise HTTPException(status_code=400, detail="model_id is required")

    try:
        assert model_manager is not None
        model = model_manager.get_model(request.model_id)
        start_time = time.time()

        config = model_manager.model_configs[request.model_id]

        if config.type == "sparse-embeddings":
            # Sparse embedding
            sparse_result = model.embed(request.text, prompt=request.prompt)
            processing_time = time.time() - start_time

            if isinstance(sparse_result, dict) and "indices" in sparse_result:
                sparse_embedding = SparseEmbedding(
                    text=request.text,
                    indices=sparse_result["indices"],
                    values=sparse_result["values"],
                )
            else:
                raise ValueError(f"Unexpected sparse result format: {sparse_result}")

            return SparseEmbedResponse(
                sparse_embedding=sparse_embedding,
                model_id=request.model_id,
                processing_time=processing_time,
            )

        # Dense embedding
        embedding = model.embed([request.text], request.prompt)[0]
        processing_time = time.time() - start_time

        return EmbedResponse(
            embedding=embedding,
            dimension=len(embedding),
            model_id=request.model_id,
            processing_time=processing_time,
        )

    except AssertionError:
        logger.exception("Model manager is not initialized")
        raise HTTPException(status_code=500, detail="Server not ready")
    except Exception:
        logger.exception("Error creating embedding")
        raise HTTPException(status_code=500, detail="Failed to create embedding")

@app.post(
    "/embed/batch",
    response_model=Union[BatchEmbedResponse, BatchSparseEmbedResponse],
)
async def create_batch_embedding(request: BatchEmbedRequest):
    """Create batch embeddings (dense or sparse) for a list of texts.

    Args:
        request: `BatchEmbedRequest` containing `texts`, optional `prompt`, and
            required `model_id`.

    Returns:
        Union[BatchEmbedResponse, BatchSparseEmbedResponse]: Batch embedding
            responses depending on model type.
    """

    if not request.texts:
        raise HTTPException(status_code=400, detail="texts list cannot be empty")
    if not request.model_id:
        raise HTTPException(status_code=400, detail="model_id is required")

    try:
        assert model_manager is not None
        model = model_manager.get_model(request.model_id)
        start_time = time.time()

        config = model_manager.model_configs[request.model_id]

        if config.type == "sparse-embeddings":
            # Sparse batch embedding
            sparse_embeddings_raw = model.embed_batch(request.texts, request.prompt)
            processing_time = time.time() - start_time

            sparse_embeddings = []
            for emb in sparse_embeddings_raw:
                if isinstance(emb, dict) and "sparse_embedding" in emb:
                    sparse_data = emb["sparse_embedding"]
                    text = str(emb.get("text", ""))
                    sparse_embeddings.append(
                        SparseEmbedding(
                            text=text,
                            indices=sparse_data["indices"],
                            values=sparse_data["values"],
                        )
                    )
                else:
                    raise ValueError(f"Unexpected sparse embedding format: {emb}")

            return BatchSparseEmbedResponse(
                embeddings=sparse_embeddings,
                model_id=request.model_id,
                processing_time=processing_time,
            )

        # Dense batch embedding
        embeddings = model.embed(request.texts, request.prompt)
        processing_time = time.time() - start_time

        return BatchEmbedResponse(
            embeddings=embeddings,
            dimension=len(embeddings[0]) if embeddings else 0,
            model_id=request.model_id,
            processing_time=processing_time,
        )

    except AssertionError:
        logger.exception("Model manager is not initialized")
        raise HTTPException(status_code=500, detail="Server not ready")
    except Exception:
        logger.exception("Error creating batch embeddings")
        raise HTTPException(status_code=500, detail="Failed to create batch embeddings")

@app.get("/models")
async def list_available_models():
    try:
        return model_manager.list_models()
    except Exception as e:
        logger.exception("Failed to list models")
        raise HTTPException(status_code=500, detail="Failed to list models")

@app.get("/health")
async def health_check():
    try:
        memory_info = model_manager.get_memory_usage()
        return {
            "status": "ok",
            "total_models": len(model_manager.model_configs),
            "loaded_models": memory_info["loaded_count"],
            "memory": memory_info,
            "startup_complete": True
        }
    except Exception as e:
        logger.exception("Health check failed")
        return {"status": "error", "error": str(e)}

@app.get("/")
async def root():
    return {
        "message": "Unified Embedding API - Dense & Sparse Embeddings",
        "version": "3.0.0"
    }