File size: 3,001 Bytes
dcc55a4
75c14d3
dcc55a4
75c14d3
 
 
 
 
dcc55a4
 
75c14d3
dcc55a4
75c14d3
 
 
dcc55a4
75c14d3
 
dcc55a4
75c14d3
 
dcc55a4
 
 
75c14d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcc55a4
 
 
75c14d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcc55a4
75c14d3
 
 
 
 
 
 
dcc55a4
75c14d3
 
 
dcc55a4
 
 
75c14d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2897aba
75c14d3
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
from typing import List, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from huggingface_hub import HfApi
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv


# Load environment variables

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = os.getenv("MODEL_REPO")   # Example: "sentence-transformers/all-MiniLM-L6-v2"

# Fail fast at import time when required configuration is absent.
for _var_name, _var_value in (("HF_TOKEN", HF_TOKEN), ("MODEL_REPO", MODEL_REPO)):
    if _var_value is None:
        raise ValueError(f"{_var_name} is missing in environment variables")
del _var_name, _var_value



# Request / Response Schemas

class EmbeddingInput(BaseModel):
    """Request body for POST /embed: one text or a batch of texts to encode."""

    # Accepts either a single string or a list of strings; the endpoint
    # normalizes both forms to a list before encoding.
    text: Union[str, List[str]]


class EmbeddingOutput(BaseModel):
    """Response body for POST /embed."""

    # One embedding vector (list of floats) per input text, in input order.
    embeddings: List[List[float]]
    # Model repo tag used, or a fallback string when the Hub lookup fails.
    model_version: str



# Global model instance

# Populated by the lifespan handler at startup; None until loading completes.
model: Union[SentenceTransformer, None] = None
# Hub client used to resolve the model repo's latest tag for version reporting.
api = HfApi()



# Lifespan (FastAPI startup/shutdown)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SentenceTransformer at startup; release it on shutdown.

    Registered as FastAPI's lifespan hook so the slow model download and
    deserialization finish before the server accepts requests.

    Args:
        app: the FastAPI application (unused, required by the hook signature).
    """
    global model
    print("🔄 Loading embedding model...")

    model = SentenceTransformer(
        MODEL_REPO,
        token=HF_TOKEN,
        device="cpu"    # Force CPU for stability on 16GB RAM host
    )

    # Put model into eval mode (more correct, avoids any unnecessary training state)
    model.eval()

    print("✅ Model loaded successfully.")
    yield

    # Drop the global reference so the model weights can be garbage-collected
    # during shutdown instead of lingering for the process lifetime.
    model = None
    print("🛑 Shutting down.")



# FastAPI App

app = FastAPI(
    title="Custom Embedding Model API",
    description="Lightweight SentenceTransformer-based embedding service.",
    version="1.0.0",
    lifespan=lifespan  # startup/shutdown hook that loads the embedding model
)



# Health / Info Route

@app.get("/")
def root():
    """Health/info endpoint: reports liveness, model state and model version.

    NOTE: this performs a Hugging Face Hub network round-trip on every call;
    it degrades gracefully to "unknown" when the lookup fails, so frequent
    health checks still answer quickly on network errors.
    """
    # Resolve the repo's latest tag; the Hub call can fail (network, auth,
    # repo without tags) and the endpoint must still respond.
    try:
        latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
    except Exception:
        latest_tag = "unknown"

    return {
        "message": "Embedding API is running.",
        # `model is not None` instead of truthiness: explicit, and independent
        # of whatever __bool__ semantics the model object might have.
        "model_loaded": model is not None,
        "model_version": latest_tag,
    }



# Embedding Route

@app.post("/embed", response_model=EmbeddingOutput)
def generate_embedding(input_data: EmbeddingInput):
    """Encode one text or a batch of texts into normalized embedding vectors.

    Args:
        input_data: request body holding a single string or a list of strings.

    Returns:
        EmbeddingOutput with one vector per input text (input order preserved)
        and the model repo's latest tag (or "unknown" if the lookup fails).

    Raises:
        HTTPException: 503 if the model has not finished loading,
            500 if embedding generation fails for any other reason.
    """
    if model is None:
        # Startup has not completed (or shutdown already released the model);
        # surface a clear 503 instead of an opaque AttributeError-driven 500.
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    try:
        # Normalize input to a list so single strings and batches share one path.
        texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]

        # Compute embeddings (in batch for performance)
        vectors = model.encode(
            texts,
            batch_size=32,
            show_progress_bar=False,
            normalize_embeddings=True   # helps with vector similarity use cases
        ).tolist()

        # Fetch model version. NOTE: a Hub network round-trip on every request;
        # fall back to "unknown" for consistency with the root endpoint.
        try:
            latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
        except Exception:
            latest_tag = "unknown"

        return EmbeddingOutput(
            embeddings=vectors,
            model_version=latest_tag
        )

    except Exception as e:
        # Boundary handler: convert any encoding failure into a clean 500.
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}"
        )