Soumik Bose commited on
Commit
967868b
·
0 Parent(s):

first commit

Browse files
.gitignore ADDED
File without changes
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Use the official Python 3.11 slim image
FROM python:3.11-slim

# Install curl for the keep-alive script (and clean apt lists to keep the image small)
RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*

# Set the working directory inside the container
WORKDIR /app

# Environment variables for optimization and logging
ENV PYTHONUNBUFFERED=1
ENV PYTHONIOENCODING=UTF-8
ENV HF_HOME=/tmp/cache

# Copy the requirements file first so the dependency layer is cached
COPY requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code
COPY . .

# Create cache directory (world-writable: HF Spaces runs as an arbitrary uid)
RUN mkdir -p ${HF_HOME} && chmod 777 ${HF_HOME}

# Expose port 7860 (required by Hugging Face Spaces)
EXPOSE 7860

# Keep-alive loop + start Uvicorn with optimized workers.
# FIX: the original used `curl ... && sleep 300`, so any failed curl skipped the
# sleep and the loop busy-spun, hammering the endpoint. Always sleep between pings.
CMD bash -c "while true; do curl -s https://sasasas635-database-chat.hf.space/ping >/dev/null; sleep 300; done & uvicorn main:app --host 0.0.0.0 --port 7860 --workers 4 --loop asyncio"
README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
title: My Embeddings API
emoji: 🤩
colorFrom: orange
colorTo: blue
sdk: docker
app_file: main.py
pinned: false
__pycache__/model_service.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
download_setup.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from sentence_transformers import SentenceTransformer

# Configuration
MODEL_NAME = 'BAAI/bge-base-en-v1.5'  # The 768-dimension model
SAVE_PATH = './models/bge-base-en-v1.5'


def download_model():
    """Fetch the embedding model from the Hub and persist it under SAVE_PATH."""
    print(f"Downloading model: {MODEL_NAME}...")

    # Pull the model (downloads on first use, then loads it into memory)
    embedder = SentenceTransformer(MODEL_NAME)

    # Persist to the local folder, creating it if necessary
    os.makedirs(SAVE_PATH, exist_ok=True)
    print(f"Saving model to: {SAVE_PATH}...")
    embedder.save(SAVE_PATH)

    print("✅ Model downloaded and saved successfully.")

    # Report the on-disk size of the main weights file, if present
    weights_file = os.path.join(SAVE_PATH, 'model.safetensors')
    if os.path.exists(weights_file):
        size_mb = os.path.getsize(weights_file) / (1024 * 1024)
        print(f"Model file size: {size_mb:.2f} MB")

    print(f"Model dimension: {embedder.get_sentence_embedding_dimension()}")


if __name__ == "__main__":
    download_model()
main.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import asyncio
import logging
import multiprocessing
import os
import secrets
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union, Optional

# Third-party
from fastapi import FastAPI, HTTPException, Security, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field

# Local
from model_service import LocalEmbeddingService
# ----------------------------------------------------------------------------
# Logging
# ----------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# ----------------------------------------------------------------------------
# Configuration (all overridable via environment variables)
# ----------------------------------------------------------------------------
LOCAL_MODEL_PATH = os.getenv('MODEL_PATH', './models/bge-base-en-v1.5')
AUTH_TOKEN = os.getenv('AUTH_TOKEN', None)  # Bearer token; auth disabled when unset
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')

# Size the thread pool from the machine: 2x cores suits I/O-bound workloads.
CPU_COUNT = multiprocessing.cpu_count()
MAX_WORKERS = CPU_COUNT * 2
logger.info(f"Detected {CPU_COUNT} CPU cores. Using {MAX_WORKERS} max workers for thread pool.")

# ----------------------------------------------------------------------------
# FastAPI application
# ----------------------------------------------------------------------------
app = FastAPI(
    title="BGE Embedding API",
    description="Production-grade embedding inference API using BAAI/bge-base-en-v1.5",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# ----------------------------------------------------------------------------
# CORS
# ----------------------------------------------------------------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
logger.info(f"CORS enabled for origins: {ALLOWED_ORIGINS}")
59
+
60
+ # ============================================================================
61
+ # SECURITY
62
+ # ============================================================================
63
+ security = HTTPBearer(auto_error=False)
64
+
65
+ async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
66
+ """Verify Bearer token if AUTH_TOKEN is set."""
67
+ if AUTH_TOKEN is None:
68
+ # No authentication required
69
+ return True
70
+
71
+ if credentials is None:
72
+ logger.warning("Authentication required but no token provided")
73
+ raise HTTPException(
74
+ status_code=401,
75
+ detail="Authentication required",
76
+ headers={"WWW-Authenticate": "Bearer"},
77
+ )
78
+
79
+ if credentials.credentials != AUTH_TOKEN:
80
+ logger.warning(f"Invalid token attempt: {credentials.credentials[:10]}...")
81
+ raise HTTPException(
82
+ status_code=401,
83
+ detail="Invalid authentication token",
84
+ headers={"WWW-Authenticate": "Bearer"},
85
+ )
86
+
87
+ return True
# ----------------------------------------------------------------------------
# Global state (populated by the startup event)
# ----------------------------------------------------------------------------
service = None   # LocalEmbeddingService once the model is loaded
executor = None  # ThreadPoolExecutor used to keep the event loop unblocked


@app.on_event("startup")
async def startup_event():
    """Initialize the thread pool and load the embedding model."""
    global service, executor

    try:
        banner = "=" * 60
        logger.info(banner)
        logger.info("Starting BGE Embedding Service")
        logger.info(banner)

        # Thread pool so blocking embedding work never stalls the event loop
        executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
        logger.info(f"Thread pool executor initialized with {MAX_WORKERS} workers")

        # Load the model from local disk
        logger.info(f"Loading model from: {LOCAL_MODEL_PATH}")
        service = LocalEmbeddingService(LOCAL_MODEL_PATH)
        logger.info(f"✅ Model loaded successfully! Dimension: {service.embedding_dim}")

        # Announce whether requests will require a Bearer token
        if AUTH_TOKEN:
            logger.info("🔒 Authentication enabled (Bearer token required)")
        else:
            logger.warning("⚠️ Authentication disabled (no AUTH_TOKEN set)")

        logger.info(banner)
        logger.info("Service ready to accept requests")
        logger.info(banner)

    except Exception as e:
        logger.error(f"❌ Failed to initialize service: {e}", exc_info=True)
        raise


@app.on_event("shutdown")
async def shutdown_event():
    """Tear down the thread pool on shutdown."""
    global executor
    logger.info("Shutting down service...")

    if executor:
        executor.shutdown(wait=True)
        logger.info("Thread pool executor shut down")

    logger.info("Service shutdown complete")
# ----------------------------------------------------------------------------
# Request/response models
# ----------------------------------------------------------------------------
class EmbedRequest(BaseModel):
    """Payload for /embed and /embeddings."""

    text: Union[str, List[str]] = Field(
        ...,
        description="Single text string or list of texts to embed"
    )

    # FIX: the project pins pydantic 2.x (requirements: pydantic==2.10.3), where
    # `class Config: schema_extra` is silently ignored. Pydantic v2 reads the
    # example from model_config["json_schema_extra"].
    model_config = {
        "json_schema_extra": {
            "example": {
                "text": "Ginger was also a smart giraffe. She knew what was wrong."
            }
        }
    }


class EmbedResponse(BaseModel):
    """Embedding result returned to the caller."""

    embeddings: Union[List[float], List[List[float]]] = Field(
        ...,
        description="Generated embedding(s)"
    )
    dimension: int = Field(..., description="Embedding dimension")
    count: int = Field(..., description="Number of texts processed")
# ----------------------------------------------------------------------------
# Read-only endpoints
# ----------------------------------------------------------------------------

@app.get("/")
async def root():
    """Top-level service metadata and endpoint directory."""
    return {
        "message": "BGE Embedding API - Production Ready",
        "model": "BAAI/bge-base-en-v1.5",
        "dimension": 768,
        "version": "2.0.0",
        "authentication": "enabled" if AUTH_TOKEN else "disabled",
        "endpoints": {
            "health": "/health",
            "ping": "/ping",
            "embed": "/embed",
            "embeddings": "/embeddings",
            "docs": "/docs",
        },
    }


@app.get("/health")
async def health_check():
    """Report readiness; responds 503 until the model has been loaded."""
    if service is None:
        logger.error("Health check failed: service not initialized")
        raise HTTPException(status_code=503, detail="Service not initialized")

    return {
        "status": "healthy",
        "model_dimension": service.embedding_dim,
        "model_path": LOCAL_MODEL_PATH,
        "max_workers": MAX_WORKERS,
        "cpu_count": CPU_COUNT,
    }


@app.get("/ping")
async def ping():
    """Cheap liveness probe targeted by the Dockerfile keep-alive loop."""
    return {"status": "ok", "message": "pong"}
@app.post("/embed", response_model=EmbedResponse)
async def create_embeddings(
    request: EmbedRequest,
    authenticated: bool = Depends(verify_token)
):
    """
    Generate embeddings for the provided text(s) - Non-blocking operation.

    - **text**: Single string or list of strings to embed

    Returns normalized 768-dimensional embeddings suitable for cosine similarity.

    Requires Bearer token authentication if AUTH_TOKEN is set.
    """
    if service is None:
        logger.error("Embedding request failed: service not initialized")
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        # Determine input type and count
        is_single = isinstance(request.text, str)
        count = 1 if is_single else len(request.text)

        logger.info(f"Processing embedding request for {count} text(s)")

        # Off-load the CPU-heavy encode to the thread pool so the event loop
        # stays responsive. FIX: asyncio.get_event_loop() is deprecated inside
        # coroutines (Python 3.10+); get_running_loop() is the correct call and
        # always returns the loop this coroutine runs on.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            executor,
            service.generate_embedding,
            request.text
        )

        logger.info(f"✅ Successfully generated {count} embedding(s)")

        return EmbedResponse(
            embeddings=embeddings,
            dimension=service.embedding_dim,
            count=count
        )

    except Exception as e:
        logger.error(f"❌ Embedding generation failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}"
        )
@app.post("/embeddings", response_model=EmbedResponse)
async def create_embeddings_batch(
    request: EmbedRequest,
    authenticated: bool = Depends(verify_token)
):
    """
    Alias for /embed endpoint - Non-blocking batch embedding generation.

    Requires Bearer token authentication if AUTH_TOKEN is set.
    """
    # Delegate straight to the primary handler so the two routes can never drift.
    return await create_embeddings(request, authenticated)
model_service.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from typing import List, Union
from sentence_transformers import SentenceTransformer

class LocalEmbeddingService:
    """Service for generating embeddings using a locally stored model."""

    def __init__(self, model_folder: str):
        """
        Initialize the service by loading the model from a local path.

        Args:
            model_folder: Path to the folder containing the saved model

        Raises:
            FileNotFoundError: If the model folder does not exist yet.
        """
        if not os.path.exists(model_folder):
            # FIX: the setup script in this repo is download_setup.py — the old
            # message pointed users at download_model.py, which does not exist.
            raise FileNotFoundError(
                f"Model folder not found at: {model_folder}. "
                "Please run download_setup.py first."
            )

        print(f"Loading model from {model_folder}...")
        self.model = SentenceTransformer(model_folder)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"✅ Model loaded successfully. Dimension: {self.embedding_dim}")

    def generate_embedding(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Generate embeddings for the given text(s).

        Args:
            text: A single string or list of strings to embed

        Returns:
            A single embedding (list of floats) or list of embeddings
        """
        # Normalized vectors let callers use a plain dot product as cosine similarity
        embeddings = self.model.encode(
            text,
            normalize_embeddings=True,
            convert_to_tensor=False
        )

        # FIX: the original had an isinstance(text, str) branch whose both arms
        # returned the identical expression; .tolist() already yields a flat
        # list for a single input and a nested list for a batch.
        return embeddings.tolist()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Core dependencies
fastapi==0.115.5
uvicorn[standard]==0.32.1
pydantic==2.10.3

# ML dependencies
sentence-transformers==3.3.1
torch==2.5.1
numpy==1.26.4

# Production dependencies
python-multipart==0.0.20
aiofiles==24.1.0
test_local.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import time
from model_service import LocalEmbeddingService

# Configuration: path where download_setup.py saves the model
LOCAL_MODEL_PATH = './models/bge-base-en-v1.5'

def test_single_text():
    """Test embedding generation for a single text."""
    service = LocalEmbeddingService(LOCAL_MODEL_PATH)

    text = "Ginger was also a smart giraffe. She knew what was wrong."

    print(f"\n{'='*60}")
    print("Testing single text embedding")
    print(f"{'='*60}")
    print(f"Text: '{text}'")

    start_time = time.time()
    vector = service.generate_embedding(text)
    end_time = time.time()

    print(f"\n✅ Embedding generated in {end_time - start_time:.4f} seconds")
    print(f"Dimensions: {len(vector)}")
    print(f"First 10 values: {vector[:10]}")
    # Embeddings are normalized by the service, so the L2 norm should be ~1.0
    print(f"Vector norm (should be ~1.0): {sum(x**2 for x in vector)**0.5:.4f}")

def test_batch_texts():
    """Test embedding generation for multiple texts."""
    service = LocalEmbeddingService(LOCAL_MODEL_PATH)

    texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is transforming technology.",
        "Embeddings capture semantic meaning of text."
    ]

    print(f"\n{'='*60}")
    print("Testing batch text embeddings")
    print(f"{'='*60}")
    print(f"Number of texts: {len(texts)}")

    start_time = time.time()
    vectors = service.generate_embedding(texts)
    end_time = time.time()

    print(f"\n✅ {len(vectors)} embeddings generated in {end_time - start_time:.4f} seconds")
    print(f"Average time per text: {(end_time - start_time) / len(texts):.4f} seconds")
    print(f"Each embedding dimension: {len(vectors[0])}")

    # Show first embedding sample
    print(f"\nFirst embedding (first 10 values): {vectors[0][:10]}")

def test_similarity():
    """Test cosine similarity between embeddings."""
    service = LocalEmbeddingService(LOCAL_MODEL_PATH)

    texts = [
        "The cat sits on the mat.",
        "A feline rests on the rug.",  # Similar meaning
        "Python is a programming language."  # Different meaning
    ]

    print(f"\n{'='*60}")
    print("Testing semantic similarity")
    print(f"{'='*60}")

    vectors = service.generate_embedding(texts)

    # Dot product == cosine similarity because the vectors are already normalized
    def cosine_sim(v1, v2):
        return sum(a * b for a, b in zip(v1, v2))

    sim_01 = cosine_sim(vectors[0], vectors[1])
    sim_02 = cosine_sim(vectors[0], vectors[2])

    print(f"\nText 1: '{texts[0]}'")
    print(f"Text 2: '{texts[1]}'")
    print(f"Similarity: {sim_01:.4f} (similar meaning)")

    print(f"\nText 1: '{texts[0]}'")
    print(f"Text 3: '{texts[2]}'")
    print(f"Similarity: {sim_02:.4f} (different meaning)")

    print("\n✅ As expected, similar texts have higher similarity!")

def main():
    """Run all tests."""
    try:
        test_single_text()
        test_batch_texts()
        test_similarity()

        print(f"\n{'='*60}")
        print("✅ All tests completed successfully!")
        print(f"{'='*60}\n")

    except FileNotFoundError:
        # FIX: point at the script that actually exists in this repo
        # (download_setup.py), not the nonexistent download_model.py.
        print("\n❌ Model not found. Please run download_setup.py first.")
    except Exception as e:
        print(f"\n❌ An error occurred: {e}")

if __name__ == "__main__":
    main()