File size: 5,906 Bytes
967868b
 
 
 
9ab4c8b
 
 
 
 
 
 
 
 
3b07301
58f4a9c
967868b
 
58f4a9c
967868b
 
 
9ab4c8b
967868b
9ab4c8b
967868b
 
58f4a9c
967868b
9ab4c8b
967868b
 
58f4a9c
9ab4c8b
 
3b07301
9ab4c8b
 
 
58f4a9c
9ab4c8b
 
 
3b07301
 
 
58f4a9c
 
 
 
 
 
3b07301
58f4a9c
 
9ab4c8b
58f4a9c
3b07301
9ab4c8b
 
3b07301
9ab4c8b
 
 
58f4a9c
 
 
 
 
 
9ab4c8b
 
 
967868b
 
58f4a9c
967868b
 
3b07301
 
58f4a9c
967868b
 
 
 
 
 
 
 
 
 
 
 
 
9ab4c8b
967868b
58f4a9c
 
967868b
 
 
3b07301
967868b
 
58f4a9c
 
9ab4c8b
 
 
967868b
58f4a9c
 
967868b
 
9ab4c8b
967868b
 
58f4a9c
 
 
 
 
 
967868b
 
 
 
 
 
 
58f4a9c
 
9ab4c8b
967868b
 
58f4a9c
967868b
 
9ab4c8b
 
967868b
3b07301
 
967868b
9ab4c8b
 
 
 
 
 
58f4a9c
3b07301
 
 
 
58f4a9c
967868b
58f4a9c
3b07301
9ab4c8b
 
967868b
 
 
3b07301
58f4a9c
967868b
9ab4c8b
967868b
 
58f4a9c
967868b
 
9ab4c8b
967868b
3b07301
58f4a9c
967868b
14cc6a3
 
5062b98
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import asyncio
import logging
import multiprocessing
import os
import secrets
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Any, List, Optional, Union

from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field

# Import the new MultiEmbeddingService
from model_service import MultiEmbeddingService

# ============================================================================
# LOGGING
# ============================================================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EmbedAPI")

# ============================================================================
# CONFIGURATION
# ============================================================================
AUTH_TOKEN = os.getenv('AUTH_TOKEN', None)
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')

# Global context container
ml_context = {
    "service": None,
    "executor": None
}

# ============================================================================
# LIFESPAN MANAGER
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager: create the thread pool and load models, then tear down.

    Startup creates the thread pool that endpoints use (via
    ``run_in_executor``) to keep blocking inference off the event loop, then
    loads every embedding model into ``ml_context``. If model loading fails,
    the pool is shut down before the error propagates, so no executor is
    leaked. Shutdown runs in a ``finally`` so the pool is released even when
    the server exits with an error.
    """
    # --- Startup ---
    logger.info("Initializing Multi-Dimensional Embedding Service...")

    # 1. Thread Pool
    cpu_count = multiprocessing.cpu_count()
    max_workers = cpu_count * 2
    executor = ThreadPoolExecutor(max_workers=max_workers)
    ml_context["executor"] = executor
    logger.info(f"Thread pool ready: {max_workers} workers")

    # 2. Load Models
    try:
        service = MultiEmbeddingService()
        service.load_all_models()  # Loads 384, 768, 1024 models
    except Exception as e:
        logger.critical(f"Critical error loading models: {e}", exc_info=True)
        # Startup is aborting: release the pool instead of leaking it.
        ml_context["executor"] = None
        executor.shutdown(wait=False)
        raise
    ml_context["service"] = service

    if AUTH_TOKEN:
        logger.info("🔒 Auth enabled.")
    else:
        logger.warning("AUTH_TOKEN not set: /embed accepts unauthenticated requests.")

    try:
        yield
    finally:
        # --- Shutdown ---
        logger.info("Shutting down...")
        # Unpublish first so late requests see 503 rather than submitting
        # work to an executor that is shutting down.
        ml_context["service"] = None
        ml_context["executor"] = None
        executor.shutdown(wait=True)

# ============================================================================
# APP SETUP
# ============================================================================
app = FastAPI(
    title="Multi-Dim Embedding API",
    version="3.0.0",
    lifespan=lifespan
)

# Credentialed CORS (cookies / HTTP auth) must not be combined with a wildcard
# origin: the CORS spec forbids it, and Starlette then reflects the caller's
# Origin, which would let any site make credentialed calls. Auth here is a
# bearer token the client sends explicitly, so credentials are only enabled
# when an explicit origin allow-list is configured.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# auto_error=False: a missing Authorization header yields None instead of an
# automatic error, so verify_token can allow anonymous access when auth is off.
security = HTTPBearer(auto_error=False)

async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
    """Dependency enforcing bearer-token auth when AUTH_TOKEN is configured.

    Returns:
        True when auth is disabled (AUTH_TOKEN unset or empty) or when the
        presented bearer token matches AUTH_TOKEN.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if a token is
            required and is missing or does not match.
    """
    if not AUTH_TOKEN:
        return True
    # compare_digest avoids content-based short-circuiting, so the token cannot
    # be recovered byte-by-byte from response timing. Bytes are compared because
    # compare_digest rejects str arguments containing non-ASCII characters.
    if credentials is None or not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), AUTH_TOKEN.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True

# ============================================================================
# MODELS
# ============================================================================
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``.

    ``data`` is either a single string or a list of strings. ``dimension``
    selects which loaded model produces the vectors; the endpoint rejects
    dimensions that are not among the service's loaded models with HTTP 400.
    """
    data: Union[str, List[str]] = Field(..., description="Text string or list of strings")
    dimension: int = Field(768, description="Target dimension (384, 768, or 1024)")

    # Example payload surfaced in the generated JSON schema / OpenAPI docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "data": ["Hello world", "Machine learning is great"],
                "dimension": 768
            }
        }
    }

class EmbedResponse(BaseModel):
    """Response body for ``POST /embed``.

    ``count`` is 1 for a single-string request, otherwise the number of input
    strings. ``embeddings`` holds whatever the service returned; presumably a
    flat vector for a single string and one vector per input for a list (the
    Union accepts both shapes). TODO confirm against MultiEmbeddingService.
    """
    embeddings: Union[List[float], List[List[float]]] = Field(...)
    dimension: int  # echoes the requested dimension
    count: int

class DeEmbedRequest(BaseModel):
    """Request body carrying a single embedding vector to decode.

    NOTE(review): no endpoint visible in this module references this model;
    presumably reserved for a future decode endpoint. Confirm or remove.
    """
    vector: List[float] = Field(..., description="The embedding vector to decode")

# ============================================================================
# ENDPOINTS
# ============================================================================

@app.get("/health")
async def health_check():
    """Report readiness plus the embedding dimensions currently loaded.

    Raises:
        HTTPException: 503 while no embedding service is published in
            ``ml_context`` (before startup completes or after shutdown).
    """
    loaded = ml_context.get("service")
    if loaded:
        dims = list(loaded.models.keys())
        return {"status": "healthy", "loaded_dimensions": dims}
    raise HTTPException(status_code=503, detail="Service not ready")

@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
async def create_embeddings(request: EmbedRequest):
    """
    Generate embeddings for one string or a list of strings.

    The requested dimension must be one of the models loaded by the service
    (nominally 384, 768 and 1024; see `/health` for the live list).
    """
    service = ml_context.get("service")
    executor = ml_context.get("executor")

    if not service or not executor:
        raise HTTPException(status_code=503, detail="Service unavailable")

    if request.dimension not in service.models:
        # Report what is actually loaded rather than a hard-coded list, so the
        # message stays correct if the service loads a different set.
        supported = ", ".join(str(dim) for dim in service.models)
        raise HTTPException(
            status_code=400,
            detail=f"Dimension {request.dimension} not supported. Use one of: {supported}.",
        )

    is_single = isinstance(request.data, str)
    count = 1 if is_single else len(request.data)

    # Empty batch: nothing to embed, so skip the model round-trip entirely.
    if count == 0:
        return EmbedResponse(embeddings=[], dimension=request.dimension, count=0)

    try:
        # Inference is blocking; run it on the shared pool to keep the loop free.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            executor,
            service.generate_embedding,
            request.data,
            request.dimension,
        )
        return EmbedResponse(
            embeddings=embeddings,
            dimension=request.dimension,
            count=count,
        )
    except Exception as exc:
        # Keep the full traceback in server logs; don't echo internal error
        # text (paths, library messages) back to API clients.
        logger.exception("Inference error")
        raise HTTPException(status_code=500, detail="Embedding generation failed") from exc

@app.get("/ping")
async def ping():
    """Cheap liveness probe; does not consult the embedding service."""
    return dict(message="embed-api is alive!")

@app.get("/")
async def root():
    """Landing endpoint: report the API version and a running banner."""
    banner = "Multi-Dimensional Embedding API Server is running."
    return dict(version="3.0.0", message=banner)