Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from sentence_transformers import SentenceTransformer | |
| import uvicorn | |
| import asyncio | |
| from typing import List | |
| import numpy as np | |
| from contextlib import asynccontextmanager | |
| import httpx | |
| import os | |
| import sqlite3 | |
| import json | |
# Globals
# Embedding model and its tokenizer; both stay None until the lifespan
# handler assigns them at startup.
model = None
tokenizer = None
# Model identifier passed to SentenceTransformer.
model_id = 'Qwen/Qwen3-Embedding-0.6B'
# Token count above which chunk_and_embed splits the text into chunks.
MAX_TOKENS = 32000
# SQLite file that stores the request queue.
DB_PATH = "/data/embeddings.db"
# Set to True by the queue processor while it is working on a request.
is_processing = False
def init_database():
    """Create the SQLite queue database, its table and its status index.

    The parent directory of ``DB_PATH`` is created when it is missing. Both
    ``CREATE`` statements use ``IF NOT EXISTS``, so calling this again on an
    existing database leaves it unchanged.

    Raises:
        OSError: If the parent directory cannot be created.
        sqlite3.Error: If the database cannot be opened or a schema
            statement fails.
    """
    db_dir = os.path.dirname(DB_PATH)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS embedding_requests (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                request_id TEXT,
                text TEXT NOT NULL,
                embedding TEXT,
                status TEXT DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP,
                webhook_sent BOOLEAN DEFAULT 0,
                error_message TEXT
            )
        ''')
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_status
            ON embedding_requests(status)
        ''')
        conn.commit()
    finally:
        # Close the connection even when a statement fails.
        conn.close()
    print("✅ Database initialized successfully")
| def save_request_to_db(text: str, request_id: str = None) -> int: | |
| """Save the incoming request to database""" | |
| conn = sqlite3.connect(DB_PATH) | |
| cursor = conn.cursor() | |
| cursor.execute(''' | |
| INSERT INTO embedding_requests (request_id, text, status) | |
| VALUES (?, ?, 'pending') | |
| ''', (request_id, text)) | |
| row_id = cursor.lastrowid | |
| conn.commit() | |
| conn.close() | |
| print(f"✅ Request saved to DB with ID: {row_id}") | |
| return row_id | |
def get_next_pending_request():
    """Return the oldest row whose status is ``'pending'``.

    Returns:
        An ``(id, request_id, text)`` tuple for the pending row with the
        smallest ``id``, or ``None`` when no row is pending.

    Raises:
        sqlite3.Error: If the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT id, request_id, text
            FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''')
        return cursor.fetchone()
    finally:
        # Close the connection even when the query fails.
        conn.close()
def update_request_processing(row_id: int):
    """Set the status of row *row_id* to ``'processing'``.

    Args:
        row_id: Primary key of the request row to update.

    Raises:
        sqlite3.Error: If the update or the commit fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE embedding_requests
            SET status = 'processing'
            WHERE id = ?
        ''', (row_id,))
        conn.commit()
    finally:
        # Close the connection even when the update fails.
        conn.close()
def update_embedding_in_db(row_id: int, embedding: List[float]):
    """Store *embedding* on row *row_id* and mark the row ``'completed'``.

    The vector is serialized with ``json.dumps`` into the ``embedding``
    column, and ``processed_at`` is set to the database's current timestamp.

    Args:
        row_id: Primary key of the request row to update.
        embedding: Embedding vector to persist.

    Raises:
        sqlite3.Error: If the update or the commit fails.
    """
    embedding_json = json.dumps(embedding)
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE embedding_requests
            SET embedding = ?,
                status = 'completed',
                processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (embedding_json, row_id))
        conn.commit()
    finally:
        # Close the connection even when the update fails.
        conn.close()
    print(f"✅ Embedding saved for ID: {row_id}")
def get_request_data(row_id: int):
    """Fetch the stored request row *row_id*, including its embedding.

    Args:
        row_id: Primary key of the request row to read.

    Returns:
        An ``(id, request_id, text, embedding)`` tuple, where ``embedding``
        is the raw column value, or ``None`` when no such row exists.

    Raises:
        sqlite3.Error: If the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT id, request_id, text, embedding
            FROM embedding_requests
            WHERE id = ?
        ''', (row_id,))
        return cursor.fetchone()
    finally:
        # Close the connection even when the query fails.
        conn.close()
def mark_webhook_sent_and_delete(row_id: int):
    """Flag row *row_id* as delivered and remove it from the queue table.

    The flag update and the delete are committed together in one
    transaction.

    Args:
        row_id: Primary key of the request row to remove.

    Raises:
        sqlite3.Error: If a statement or the commit fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # First mark as sent
        cursor.execute('''
            UPDATE embedding_requests
            SET webhook_sent = 1
            WHERE id = ?
        ''', (row_id,))
        # Then delete
        cursor.execute('DELETE FROM embedding_requests WHERE id = ?', (row_id,))
        conn.commit()
    finally:
        # Close the connection even when a statement fails.
        conn.close()
    print(f"🗑️ Request deleted from DB: {row_id}")
def mark_request_failed(row_id: int, error_message: str):
    """Set row *row_id* to ``'failed'`` and record why.

    ``processed_at`` is set to the database's current timestamp.

    Args:
        row_id: Primary key of the request row to update.
        error_message: Text stored in the row's ``error_message`` column.

    Raises:
        sqlite3.Error: If the update or the commit fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE embedding_requests
            SET status = 'failed',
                error_message = ?,
                processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (error_message, row_id))
        conn.commit()
    finally:
        # Close the connection even when the update fails.
        conn.close()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    Startup: initialize the database, load the embedding model and its
    tokenizer into the module globals, and start the background queue
    processor. Shutdown: cancel the processor and clear the globals.

    The decorator turns this generator into the async context manager that
    the ``lifespan=`` argument of ``FastAPI`` is given.

    Args:
        app: The application being started (not used here).
    """
    global model, tokenizer

    # Initialize database
    init_database()

    # Load the model
    print(f"Loading model: {model_id}...")
    model = SentenceTransformer(model_id)
    tokenizer = model.tokenizer
    print("✅ Model loaded successfully")

    # Keep a reference to the task: the event loop holds tasks only weakly,
    # so an unreferenced task may be garbage-collected while still running.
    queue_task = asyncio.create_task(process_queue())
    try:
        yield
    finally:
        # Cleanup
        print("Cleaning up...")
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            # Expected: this is the cancellation requested just above.
            pass
        model = None
        tokenizer = None
# ASGI application object; startup and shutdown work is delegated to
# `lifespan`.
app = FastAPI(
    title="Text Embedding API with Queue",
    lifespan=lifespan
)
# Request body accepted by embed_text.
class TextRequest(BaseModel):
    # Text to embed; min_length=1 rejects the empty string.
    text: str = Field(..., min_length=1, description="Text to embed")
    # Caller-supplied identifier, stored with the queued row; defaults to None.
    request_id: str | None = Field(None, description="Optional request identifier")
def chunk_and_embed(text: str) -> List[float]:
    """Embed *text*, splitting it into overlapping chunks when it is too long.

    Text whose token count is at most ``MAX_TOKENS`` is encoded in a single
    call. Longer text is cut into windows of ``MAX_TOKENS`` tokens that
    overlap by up to 50 tokens. Each window is decoded back to text and
    encoded on its own, and the chunk vectors are averaged. The average is
    then rescaled to unit length, because the mean of normalized vectors is
    generally shorter than 1 while the single-call path asks the model for a
    normalized vector.

    Args:
        text: Text to embed.

    Returns:
        The embedding as a plain list of floats.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= MAX_TOKENS:
        return model.encode(text, normalize_embeddings=True).tolist()

    # Cap the overlap so each window starts after the previous one; with an
    # overlap >= MAX_TOKENS the start index would move backwards forever.
    overlap = min(50, MAX_TOKENS // 2)
    step = max(1, MAX_TOKENS - overlap)

    chunks = []
    for start in range(0, len(tokens), step):
        end = start + MAX_TOKENS
        chunks.append(tokenizer.decode(tokens[start:end], skip_special_tokens=True))
        if end >= len(tokens):
            # This window already reaches the last token.
            break

    chunk_embeddings = [model.encode(chunk, normalize_embeddings=True) for chunk in chunks]
    mean_embedding = np.mean(chunk_embeddings, axis=0)
    norm = np.linalg.norm(mean_embedding)
    # Skip the division for an all-zero mean to avoid producing NaNs.
    if norm > 0:
        mean_embedding = mean_embedding / norm
    return mean_embedding.tolist()
async def send_to_webhook(webhook_url: str, row_id: int, request_id: str, text: str, embedding: List[float]):
    """POST the finished embedding to the webhook, then drop the row.

    The JSON body carries the row id, the caller's request id, the original
    text, the embedding and a ``"completed"`` status. The row is deleted only
    after the webhook answers with a success status. Any exception is
    printed and swallowed, which leaves the row in the database.

    Args:
        webhook_url: URL that receives the POST.
        row_id: Primary key of the request row.
        request_id: Caller-supplied identifier of the request.
        text: Text that was embedded.
        embedding: Embedding vector to deliver.
    """
    body = dict(
        db_id=row_id,
        request_id=request_id,
        text=text,
        embedding=embedding,
        status="completed",
    )
    try:
        async with httpx.AsyncClient(timeout=60.0) as http:
            reply = await http.post(webhook_url, json=body)
            reply.raise_for_status()
        print(f"✅ Webhook sent successfully for ID: {row_id}")
        # Delivery confirmed, so the row is no longer needed.
        mark_webhook_sent_and_delete(row_id)
    except Exception as err:
        # The row is deliberately kept when delivery fails.
        print(f"❌ Webhook error for ID {row_id}: {err}")
async def process_queue():
    """Run forever, handling queued requests one at a time.

    Each pass takes the oldest pending row, marks it as processing, computes
    its embedding in a worker thread, stores the result, and then either
    sends it to the URL in the ``WEBHOOK_URL`` environment variable or, when
    that variable is unset or empty, deletes the row. A failure while
    handling a row marks that row as failed. With nothing pending the loop
    sleeps 2 seconds; after an unexpected queue error it sleeps 5 seconds.
    """
    global is_processing
    print("🚀 Queue processor started")
    while True:
        try:
            job = get_next_pending_request()
            if not job:
                # Nothing queued; poll again shortly.
                await asyncio.sleep(2)
                continue

            row_id, request_id, text = job
            is_processing = True
            update_request_processing(row_id)
            print(f"⚙️ Processing request ID: {row_id}")
            try:
                # Encoding runs in a thread so the event loop stays responsive.
                embedding = await asyncio.to_thread(chunk_and_embed, text)
                update_embedding_in_db(row_id, embedding)
                webhook_url = os.environ.get("WEBHOOK_URL")
                if not webhook_url:
                    # No webhook configured: nothing to deliver, drop the row.
                    mark_webhook_sent_and_delete(row_id)
                else:
                    await send_to_webhook(webhook_url, row_id, request_id, text, embedding)
            except Exception as job_error:
                print(f"❌ Error processing {row_id}: {job_error}")
                mark_request_failed(row_id, str(job_error))
            is_processing = False
        except Exception as queue_error:
            print(f"❌ Queue error: {queue_error}")
            is_processing = False
            await asyncio.sleep(5)
def home():
    """Return a small status document for the service.

    Returns:
        A dict with a fixed ``"online"`` status, the configured model id and
        the current value of the ``is_processing`` flag.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file; confirm how it is registered on `app`.
    return dict(status="online", model=model_id, processing=is_processing)
async def embed_text(request: TextRequest):
    """Queue *request* for background embedding and answer immediately.

    Only the database insert happens here; the embedding itself is produced
    later by the queue processor.

    Args:
        request: Validated request body with the text and an optional id.

    Returns:
        A dict with ``success``, a message, the new row's ``db_id`` and the
        status ``"pending"``.

    Raises:
        HTTPException: Status 500 carrying the error text when the request
            cannot be saved.
    """
    try:
        # Save to DB immediately
        db_row_id = save_request_to_db(request.text, request.request_id)
    except Exception as exc:
        # Chain the cause so the original traceback is kept on the exception.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Return immediately
    return {
        "success": True,
        "message": "Request queued successfully",
        "db_id": db_row_id,
        "status": "pending"
    }
def get_status():
    """Return queue statistics.

    Returns:
        A dict with three keys: ``"queue"`` (row counts for the pending,
        processing, completed and failed statuses), ``"is_processing"`` (the
        module-level flag) and ``"next_request"`` (``id`` and ``created_at``
        of the oldest pending row, or ``None`` when nothing is pending).

    Raises:
        sqlite3.Error: If a query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # One grouped query replaces four separate COUNT(*) queries.
        cursor.execute('''
            SELECT status, COUNT(*)
            FROM embedding_requests
            GROUP BY status
        ''')
        counts = dict(cursor.fetchall())

        # Get next in queue. The status is written as a single-quoted SQL
        # string literal; double quotes denote identifiers in SQL and only
        # work as strings through a legacy SQLite fallback.
        cursor.execute('''
            SELECT id, created_at
            FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''')
        next_request = cursor.fetchone()
    finally:
        # Close the connection even when a query fails.
        conn.close()

    statuses = ("pending", "processing", "completed", "failed")
    return {
        # Statuses with no rows are absent from the grouped result.
        "queue": {status: counts.get(status, 0) for status in statuses},
        "is_processing": is_processing,
        "next_request": {
            "id": next_request[0],
            "created_at": next_request[1]
        } if next_request else None
    }
def get_request_info(db_id: int):
    """Return the stored state of one request.

    Args:
        db_id: Primary key of the request row.

    Returns:
        A dict with ``db_id``, ``request_id``, ``status``, ``created_at``,
        ``processed_at``, ``webhook_sent`` (as a bool) and ``error_message``.

    Raises:
        HTTPException: Status 404 when no row has this id.
        sqlite3.Error: If the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT id, request_id, status, created_at, processed_at, webhook_sent, error_message
            FROM embedding_requests
            WHERE id = ?
        ''', (db_id,))
        result = cursor.fetchone()
    finally:
        # Close the connection even when the query fails.
        conn.close()

    if not result:
        raise HTTPException(status_code=404, detail="Request not found or already deleted")

    return {
        "db_id": result[0],
        "request_id": result[1],
        "status": result[2],
        "created_at": result[3],
        "processed_at": result[4],
        "webhook_sent": bool(result[5]),
        "error_message": result[6]
    }
# When executed as a script, serve the app with uvicorn on all interfaces,
# port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)