Aditya Kulkarni committed on
Commit
0b7903a
·
1 Parent(s): c6e316a

feat: add POST /predict endpoint with MiniLM embedding generation

Browse files
Files changed (10) hide show
  1. .gitignore +13 -0
  2. .python-version +1 -0
  3. app/__init__.py +0 -0
  4. app/config.py +18 -0
  5. app/main.py +58 -0
  6. app/model.py +28 -0
  7. app/schemas.py +20 -0
  8. pyproject.toml +13 -0
  9. tests/__init__.py +0 -0
  10. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # AI agent files
13
+ .claude
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
app/__init__.py ADDED
File without changes
app/config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+ import torch
3
+
4
def _detect_default_device() -> str:
    """Pick the default torch device string: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    pydantic-settings reads each field from the environment automatically.
    To namespace the variables (e.g. INFERENCE_MODEL_NAME), add
    model_config = SettingsConfigDict(env_prefix="INFERENCE_").
    """

    # HuggingFace model ID of the embedding model to load.
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Torch device string; the default is probed once, when this class is defined.
    device: str = _detect_default_device()
    # Bind address and port for the server.
    host: str = "0.0.0.0"
    port: int = 8000


# Shared settings instance imported by the rest of the application.
settings = Settings()
app/main.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+
3
+ from fastapi import FastAPI, Request
4
+
5
+ from .schemas import PredictRequest, PredictResponse
6
+ from .model import load_model, predict
7
+ from .config import settings
8
+
9
+
10
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests, then hand over control.

    Everything before ``yield`` is startup: the embedding model is loaded with
    the configured name and device and stored on ``app.state.model``.
    Anything placed after ``yield`` would run at shutdown; nothing is needed yet.

    Docs: https://fastapi.tiangolo.com/advanced/events/#lifespan
    """
    app.state.model = load_model(settings.model_name, settings.device)
    print(f"model loaded on {settings.device}")
    yield
    # No shutdown cleanup is performed.
27
+
28
+
29
# Application object; the lifespan hook loads the embedding model at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Embedding Inference Server",
    version="0.1.0",
)
34
+
35
+
36
@app.get("/health")
async def health():
    """Health check endpoint.

    Returns a fixed "ok" status together with the configured model name and
    device, so a caller can confirm the server is up and how it is configured.
    """
    return dict(
        status="ok",
        model=settings.model_name,
        device=settings.device,
    )
47
+
48
+
49
@app.post("/predict", response_model=PredictResponse)
async def predict_endpoint(request: Request, body: PredictRequest):
    """Generate embeddings for the texts in the request body.

    Args:
        request: Incoming request, used to reach the model stored on
            ``app.state`` at startup.
        body: Parsed request payload carrying the texts to embed.

    Returns:
        A dict with the embeddings, the model's embedding dimension and the
        number of input texts, shaped by ``PredictResponse``.
    """
    encoder = request.app.state.model
    # NOTE(review): predict() runs synchronously inside this coroutine; if
    # encoding is slow it presumably stalls other requests — confirm whether
    # it should be offloaded to a worker thread.
    vectors = predict(encoder, body.texts)
    embedding_dim = encoder.get_sentence_embedding_dimension()
    text_count = len(body.texts)
    return {
        "embeddings": vectors,
        "dim": embedding_dim,
        "num_texts": text_count,
    }
app/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+
3
def load_model(model_name: str, device: str) -> SentenceTransformer:
    """Construct a SentenceTransformer on the requested device.

    Args:
        model_name: HuggingFace model ID (e.g. "sentence-transformers/all-MiniLM-L6-v2")
        device: torch device string ("cpu", "mps", "cuda")

    Returns:
        The SentenceTransformer instance built from the given name and device
    """
    return SentenceTransformer(model_name, device=device)
15
+
16
+
17
def predict(model: SentenceTransformer, texts: list[str]) -> list[list[float]]:
    """Embed a batch of strings and return the result as plain Python lists.

    Args:
        model: Loaded SentenceTransformer model
        texts: Strings to embed

    Returns:
        The output of ``model.encode(texts)`` converted with ``.tolist()``
    """
    encoded = model.encode(texts)
    return encoded.tolist()
app/schemas.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
class PredictRequest(BaseModel):
    """Payload accepted by the /predict endpoint.

    No validation beyond the type is applied yet. Possible additions:
    - require a non-empty list (min_length=1)
    - require each individual string to be non-empty
    """

    # Texts to embed.
    texts: list[str]
13
+
14
+
15
class PredictResponse(BaseModel):
    """Payload returned by the /predict endpoint."""

    # Embedding vectors, each a list of floats.
    embeddings: list[list[float]]
    # Embedding dimension; the field is required but may be None.
    dim: int | None
    # Count of texts in the request.
    num_texts: int
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "inference-server"
3
+ version = "0.1.0"
4
+ description = "FastAPI inference server that generates MiniLM sentence embeddings"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "fastapi>=0.133.1",
9
+ "pydantic-settings>=2.13.1",
10
+ "sentence-transformers>=5.2.3",
11
+ "torch>=2.10.0",
12
+ "uvicorn>=0.41.0",
13
+ ]
tests/__init__.py ADDED
File without changes
uv.lock ADDED
The diff for this file is too large to render. See raw diff