from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
import uvicorn
import asyncio
from typing import List
import numpy as np
from contextlib import asynccontextmanager, contextmanager, suppress
import httpx
import os
import sqlite3
import json

# Globals
model = None  # SentenceTransformer instance, assigned in lifespan()
tokenizer = None  # model.tokenizer, assigned in lifespan()
model_id = 'Qwen/Qwen3-Embedding-0.6B'
MAX_TOKENS = 32000  # texts longer than this (in tokens) are chunked
CHUNK_OVERLAP = 50  # tokens shared by consecutive chunks; must be < MAX_TOKENS
DB_PATH = "/data/embeddings.db"
is_processing = False  # True while process_queue() is handling a request


@contextmanager
def _db():
    """Yield a SQLite connection that is committed (or rolled back) and always closed.

    ``with conn:`` on a sqlite3 connection only manages the transaction; it
    does not close the connection. The ``finally`` is what prevents leaking
    connections when a statement raises.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        with conn:
            yield conn
    finally:
        conn.close()


def init_database():
    """Create the database directory, table and index, and requeue stale rows.

    The queue processor is started only after this function runs, so any row
    still marked 'processing' was left behind by a previous run that stopped
    mid-request. Such rows are put back to 'pending' instead of staying stuck.
    NOTE(review): assumes a single server process owns DB_PATH.
    """
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:  # a bare filename has no directory to create
        os.makedirs(db_dir, exist_ok=True)
    with _db() as conn:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS embedding_requests (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                request_id TEXT,
                text TEXT NOT NULL,
                embedding TEXT,
                status TEXT DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP,
                webhook_sent BOOLEAN DEFAULT 0,
                error_message TEXT
            )
        ''')
        conn.execute('''
            CREATE INDEX IF NOT EXISTS idx_status ON embedding_requests(status)
        ''')
        recovered = conn.execute(
            "UPDATE embedding_requests SET status = 'pending' WHERE status = 'processing'"
        ).rowcount
    if recovered:
        print(f"♻️ Requeued {recovered} interrupted request(s)")
    print("✅ Database initialized successfully")


def save_request_to_db(text: str, request_id: str | None = None) -> int:
    """Insert a new request with status 'pending' and return its row id."""
    with _db() as conn:
        cursor = conn.execute('''
            INSERT INTO embedding_requests (request_id, text, status)
            VALUES (?, ?, 'pending')
        ''', (request_id, text))
        row_id = cursor.lastrowid
    print(f"✅ Request saved to DB with ID: {row_id}")
    return row_id


def get_next_pending_request() -> tuple[int, str | None, str] | None:
    """Return ``(id, request_id, text)`` of the oldest pending request, or None."""
    with _db() as conn:
        return conn.execute('''
            SELECT id, request_id, text FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''').fetchone()


def update_request_processing(row_id: int):
    """Mark the request *row_id* as 'processing'."""
    with _db() as conn:
        conn.execute('''
            UPDATE embedding_requests
            SET status = 'processing'
            WHERE id = ?
        ''', (row_id,))


def update_embedding_in_db(row_id: int, embedding: List[float]):
    """Store *embedding* (as JSON text) and mark the request 'completed'."""
    embedding_json = json.dumps(embedding)
    with _db() as conn:
        conn.execute('''
            UPDATE embedding_requests
            SET embedding = ?, status = 'completed', processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (embedding_json, row_id))
    print(f"✅ Embedding saved for ID: {row_id}")


def get_request_data(row_id: int):
    """Return ``(id, request_id, text, embedding_json)`` for *row_id*, or None."""
    with _db() as conn:
        return conn.execute('''
            SELECT id, request_id, text, embedding
            FROM embedding_requests
            WHERE id = ?
        ''', (row_id,)).fetchone()


def mark_webhook_sent_and_delete(row_id: int):
    """Remove a finished request from the queue.

    The row is deleted outright. Setting ``webhook_sent`` first would be
    pointless because the row is gone in the same transaction.
    """
    with _db() as conn:
        conn.execute('DELETE FROM embedding_requests WHERE id = ?', (row_id,))
    print(f"🗑️ Request deleted from DB: {row_id}")


def mark_request_failed(row_id: int, error_message: str):
    """Mark the request 'failed' and record *error_message*."""
    with _db() as conn:
        conn.execute('''
            UPDATE embedding_requests
            SET status = 'failed', error_message = ?, processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (error_message, row_id))


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the DB, load the model, and run the queue processor for the app's lifetime."""
    global model, tokenizer
    init_database()

    print(f"Loading model: {model_id}...")
    model = SentenceTransformer(model_id)
    tokenizer = model.tokenizer
    print("✅ Model loaded successfully")

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-execution.
    processor = asyncio.create_task(process_queue())
    try:
        yield
    finally:
        print("Cleaning up...")
        processor.cancel()
        with suppress(asyncio.CancelledError):
            await processor
        model = None
        tokenizer = None


app = FastAPI(
    title="Text Embedding API with Queue",
    lifespan=lifespan
)


class TextRequest(BaseModel):
    text: str = Field(..., min_length=1, description="Text to embed")
    request_id: str | None = Field(None, description="Optional request identifier")


def chunk_and_embed(text: str) -> List[float]:
    """Return one normalized embedding for *text*, chunking long texts.

    Texts of at most MAX_TOKENS tokens are encoded directly. Longer texts are
    split into windows of MAX_TOKENS tokens that overlap by CHUNK_OVERLAP
    tokens. Each window is decoded back to text and embedded, and the chunk
    vectors are averaged. The average is re-normalized so that long texts,
    like short ones, get a unit-length vector. A mean of unit vectors is
    generally shorter than 1.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= MAX_TOKENS:
        return model.encode(text, normalize_embeddings=True).tolist()

    step = MAX_TOKENS - CHUNK_OVERLAP
    chunks = []
    for start in range(0, len(tokens), step):
        chunk_tokens = tokens[start:start + MAX_TOKENS]
        chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True))
        if start + MAX_TOKENS >= len(tokens):
            break  # this window already reaches the end of the text

    chunk_embeddings = [model.encode(chunk, normalize_embeddings=True) for chunk in chunks]
    mean_embedding = np.mean(chunk_embeddings, axis=0)
    norm = np.linalg.norm(mean_embedding)
    if norm > 0:
        mean_embedding = mean_embedding / norm
    return mean_embedding.tolist()


async def send_to_webhook(webhook_url: str, row_id: int, request_id: str | None,
                          text: str, embedding: List[float]):
    """POST the finished request to *webhook_url*, then delete it from the DB.

    This is a best-effort step and never raises. If delivery fails, the row
    is kept (status 'completed', webhook_sent 0) and is not retried by this
    service.
    """
    payload = {
        "db_id": row_id,
        "request_id": request_id,
        "text": text,
        "embedding": embedding,
        "status": "completed"
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(webhook_url, json=payload)
            response.raise_for_status()
    except Exception as e:
        print(f"❌ Webhook error for ID {row_id}: {e}")
        return  # Don't delete if webhook failed

    print(f"✅ Webhook sent successfully for ID: {row_id}")
    try:
        await asyncio.to_thread(mark_webhook_sent_and_delete, row_id)
    except sqlite3.Error as e:
        # Delivery succeeded; only the cleanup failed, so report it as such.
        print(f"❌ Could not delete delivered request {row_id}: {e}")


async def _process_request(row_id: int, request_id: str | None, text: str) -> None:
    """Embed one claimed request, store the result, then deliver or delete it.

    Any failure is recorded on the row via mark_request_failed().
    """
    try:
        embedding = await asyncio.to_thread(chunk_and_embed, text)
        await asyncio.to_thread(update_embedding_in_db, row_id, embedding)

        webhook_url = os.environ.get("WEBHOOK_URL")
        if webhook_url:
            await send_to_webhook(webhook_url, row_id, request_id, text, embedding)
        else:
            # No webhook, just delete
            await asyncio.to_thread(mark_webhook_sent_and_delete, row_id)
    except Exception as e:
        print(f"❌ Error processing {row_id}: {e}")
        await asyncio.to_thread(mark_request_failed, row_id, str(e))


async def process_queue():
    """Background loop: process pending requests one at a time, oldest first.

    Polls every 2 s while the queue is empty and waits 5 s after an unexpected
    error. Database calls run in worker threads so the event loop keeps
    serving HTTP requests. Runs until the task is cancelled.
    """
    global is_processing
    print("🚀 Queue processor started")
    while True:
        try:
            pending = await asyncio.to_thread(get_next_pending_request)
            if not pending:
                await asyncio.sleep(2)
                continue

            row_id, request_id, text = pending
            is_processing = True
            try:
                await asyncio.to_thread(update_request_processing, row_id)
                print(f"⚙️ Processing request ID: {row_id}")
                await _process_request(row_id, request_id, text)
            finally:
                is_processing = False
        except Exception as e:
            print(f"❌ Queue error: {e}")
            await asyncio.sleep(5)


@app.get("/")
def home():
    """Health check: service status, model id, and whether a request is in progress."""
    return {
        "status": "online",
        "model": model_id,
        "processing": is_processing
    }


@app.post("/embed/text")
async def embed_text(request: TextRequest):
    """
    Fast response - just queue the request
    Processing happens in background
    """
    try:
        # Insert in a worker thread so the event loop is not blocked on SQLite.
        db_row_id = await asyncio.to_thread(
            save_request_to_db, request.text, request.request_id
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "message": "Request queued successfully",
        "db_id": db_row_id,
        "status": "pending"
    }


@app.get("/status")
def get_status():
    """Return per-status row counts, the processing flag, and the oldest pending request.

    SQL string literals use single quotes. Double quotes denote identifiers
    in SQL and only act as strings through a SQLite compatibility fallback
    that can be disabled.
    """
    with _db() as conn:
        counts = dict(conn.execute(
            'SELECT status, COUNT(*) FROM embedding_requests GROUP BY status'
        ).fetchall())
        next_request = conn.execute('''
            SELECT id, created_at FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''').fetchone()

    return {
        "queue": {
            "pending": counts.get("pending", 0),
            "processing": counts.get("processing", 0),
            "completed": counts.get("completed", 0),
            "failed": counts.get("failed", 0)
        },
        "is_processing": is_processing,
        "next_request": (
            {"id": next_request[0], "created_at": next_request[1]}
            if next_request else None
        )
    }


@app.get("/request/{db_id}")
def get_request_info(db_id: int):
    """Return the stored state of request *db_id*.

    Raises:
        HTTPException(404): the row does not exist (never created, or already
            deleted after delivery).
    """
    with _db() as conn:
        result = conn.execute('''
            SELECT id, request_id, status, created_at, processed_at, webhook_sent, error_message
            FROM embedding_requests
            WHERE id = ?
        ''', (db_id,)).fetchone()

    if not result:
        raise HTTPException(status_code=404, detail="Request not found or already deleted")

    return {
        "db_id": result[0],
        "request_id": result[1],
        "status": result[2],
        "created_at": result[3],
        "processed_at": result[4],
        "webhook_sent": bool(result[5]),
        "error_message": result[6]
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)