Aditya Kulkarni committed on
Commit ·
0b7903a
1
Parent(s): c6e316a
feat: add POST /predict endpoint with MiniLM embedding generation
Browse files- .gitignore +13 -0
- .python-version +1 -0
- app/__init__.py +0 -0
- app/config.py +18 -0
- app/main.py +58 -0
- app/model.py +28 -0
- app/schemas.py +20 -0
- pyproject.toml +13 -0
- tests/__init__.py +0 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
|
| 12 |
+
# AI agent files
|
| 13 |
+
.claude
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
app/__init__.py
ADDED
|
File without changes
|
app/config.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def _detect_device() -> str:
    """Pick the default torch device string for this machine.

    Preference order is CUDA, then Apple MPS, then CPU, so the default
    uses an accelerator whenever torch reports one as available.
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    pydantic-settings reads each field from the environment automatically
    and falls back to the defaults declared below. To use namespaced
    variables such as INFERENCE_MODEL_NAME, set
    ``model_config = SettingsConfigDict(env_prefix="INFERENCE_")``.
    """

    # HuggingFace model ID handed to the model loader.
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Torch device string; the default is computed once, at import time.
    device: str = _detect_device()
    # Default bind address and port for the server.
    host: str = "0.0.0.0"
    port: int = 8000


# Module-level singleton imported by the rest of the application.
settings = Settings()
|
app/main.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI, Request
|
| 4 |
+
|
| 5 |
+
from .schemas import PredictRequest, PredictResponse
|
| 6 |
+
from .model import load_model, predict
|
| 7 |
+
from .config import settings
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the embedding model at startup and expose it on ``app.state``.

    Everything before ``yield`` runs once when the server starts;
    anything placed after it would run on shutdown. No shutdown cleanup
    is performed at the moment.

    Docs: https://fastapi.tiangolo.com/advanced/events/#lifespan
    """
    # Startup: request handlers read the model back from app.state.model.
    app.state.model = load_model(settings.model_name, settings.device)
    print(f"model loaded on {settings.device}")

    yield
    # Shutdown: nothing to release explicitly.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ASGI application; the lifespan hook loads the model before serving.
app = FastAPI(title="Embedding Inference Server", version="0.1.0", lifespan=lifespan)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@app.get("/health")
async def health():
    """Report that the server is up, plus the configured model and device.

    The model name and device come from ``settings``; this endpoint does
    not inspect the loaded model object itself.
    """
    return dict(
        status="ok",
        model=settings.model_name,
        device=settings.device,
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@app.post("/predict", response_model=PredictResponse)
async def predict_endpoint(request: Request, body: PredictRequest):
    """Generate embeddings for the input texts.

    Args:
        request: Incoming request; used to reach the model stored on
            ``app.state`` by the lifespan hook.
        body: Validated request payload carrying the texts to embed.

    Returns:
        A dict matching ``PredictResponse``: the embedding vectors, the
        model's embedding dimension, and the number of input texts.
    """
    import asyncio

    model = request.app.state.model
    # Encoding is blocking, CPU/GPU-bound work. Running it directly in this
    # coroutine would freeze the event loop (and /health) for the whole
    # inference call, so it is handed off to a worker thread instead.
    embeddings = await asyncio.to_thread(predict, model, body.texts)
    return {
        "embeddings": embeddings,
        "dim": model.get_sentence_embedding_dimension(),
        "num_texts": len(body.texts),
    }
|
app/model.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import SentenceTransformer
|
| 2 |
+
|
| 3 |
+
def load_model(model_name: str, device: str) -> SentenceTransformer:
    """Construct a SentenceTransformer for ``model_name`` on ``device``.

    Args:
        model_name: HuggingFace model ID
            (e.g. "sentence-transformers/all-MiniLM-L6-v2").
        device: torch device string ("cpu", "mps", "cuda").

    Returns:
        The constructed SentenceTransformer instance.
    """
    return SentenceTransformer(model_name, device=device)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def predict(model: SentenceTransformer, texts: list[str]) -> list[list[float]]:
    """Embed each string in ``texts`` with ``model``.

    Args:
        model: Loaded SentenceTransformer model.
        texts: Strings to embed.

    Returns:
        One embedding vector per input text, converted to plain nested
        lists via ``.tolist()`` on the encoder output.
    """
    encoded = model.encode(texts)
    return encoded.tolist()
|
app/schemas.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class PredictRequest(BaseModel):
    """Request body for the /predict endpoint.

    ``texts`` must contain at least one item; an empty list is rejected
    during validation instead of being passed to the model. Individual
    strings are not length-checked, so empty strings are still accepted.
    """

    # At least one text is required.
    texts: list[str] = Field(..., min_length=1)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PredictResponse(BaseModel):
    """Response body for the /predict endpoint.

    Carries the embedding vectors together with the embedding dimension
    and the number of texts that were submitted.
    """

    # One vector of floats per input text.
    embeddings: list[list[float]]
    # Embedding dimension; declared nullable.
    dim: int | None
    # Number of texts in the request.
    num_texts: int
|
pyproject.toml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "inference-server"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "FastAPI inference server that generates MiniLM sentence embeddings"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"fastapi>=0.133.1",
|
| 9 |
+
"pydantic-settings>=2.13.1",
|
| 10 |
+
"sentence-transformers>=5.2.3",
|
| 11 |
+
"torch>=2.10.0",
|
| 12 |
+
"uvicorn>=0.41.0",
|
| 13 |
+
]
|
tests/__init__.py
ADDED
|
File without changes
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|