ex510 committed on
Commit
2d14f9f
·
verified ·
1 Parent(s): 62821b0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +322 -41
main.py CHANGED
@@ -3,69 +3,216 @@ from pydantic import BaseModel, Field
3
  from sentence_transformers import SentenceTransformer
4
  import uvicorn
5
  import asyncio
6
- from concurrent.futures import ThreadPoolExecutor
7
  from typing import List
8
  import numpy as np
9
  from contextlib import asynccontextmanager
10
  import httpx
11
  import os
 
 
 
 
12
 
13
  # Globals
14
  model = None
15
  tokenizer = None
16
  model_id = 'Qwen/Qwen3-Embedding-0.6B'
17
- executor = ThreadPoolExecutor(max_workers=4)
18
  MAX_TOKENS = 32000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @asynccontextmanager
21
  async def lifespan(app: FastAPI):
 
 
 
22
  # Load the model and tokenizer at startup
23
  global model, tokenizer
24
  print(f"Loading model: {model_id}...")
25
  model = SentenceTransformer(model_id)
26
  tokenizer = model.tokenizer
27
  print("Model loaded successfully")
 
 
 
 
28
  yield
29
- # (Optional) Clean up resources at shutdown
 
30
  print("Cleaning up resources...")
31
  model = None
32
  tokenizer = None
33
 
 
34
  app = FastAPI(
35
  title="Text Embedding API (Qwen/Qwen3-Embedding-0.6B)",
36
  lifespan=lifespan
37
  )
38
 
 
39
  class TextRequest(BaseModel):
40
  text: str = Field(..., min_length=1, description="Text to embed")
41
  request_id: str | None = Field(None, description="Optional unique identifier for the request")
42
 
43
 
44
-
45
-
46
- async def send_to_webhook(url: str, data: dict):
47
- """Sends data to a webhook URL asynchronously."""
48
- try:
49
- async with httpx.AsyncClient() as client:
50
- response = await client.post(url, json=data)
51
- response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
52
- print(f"Successfully sent data to webhook: {url}")
53
- except httpx.RequestError as e:
54
- print(f"Error sending data to webhook {url}: {e}")
55
-
56
- @app.get("/")
57
- def home():
58
- return {"status": "online", "model": model_id, "endpoint": "/embed/text"}
59
-
60
  def chunk_and_embed(text: str) -> List[float]:
61
  """Split text into chunks if too long, then pool embeddings"""
62
  tokens = tokenizer.encode(text, add_special_tokens=False)
63
 
64
- # If text is short, embed directly
65
  if len(tokens) <= MAX_TOKENS:
66
  return model.encode(text, normalize_embeddings=True).tolist()
67
 
68
- # Split into chunks
69
  chunks = []
70
  overlap = 50
71
  start = 0
@@ -79,39 +226,173 @@ def chunk_and_embed(text: str) -> List[float]:
79
  break
80
  start = end - overlap
81
 
82
- # Embed all chunks
83
  chunk_embeddings = [model.encode(chunk, normalize_embeddings=True) for chunk in chunks]
84
-
85
- # Pool embeddings (mean)
86
  final_embedding = np.mean(chunk_embeddings, axis=0).tolist()
87
 
88
  return final_embedding
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  @app.post("/embed/text")
91
- async def embed_text(request: TextRequest, background_tasks: BackgroundTasks):
 
92
  try:
93
- loop = asyncio.get_event_loop()
94
- embedding = await loop.run_in_executor(
95
- executor,
96
- lambda: chunk_and_embed(request.text)
97
- )
98
 
99
- # Check for webhook URL and add the background task
100
- webhook_url = os.environ.get("WEBHOOK_URL")
101
- if webhook_url:
102
- payload = {
103
- "text": request.text,
104
- "embedding": embedding,
105
- "request_id": request.request_id
106
- }
107
- background_tasks.add_task(send_to_webhook, webhook_url, payload)
108
-
109
  return {
110
  "success": True,
111
- "model": model_id,
 
 
112
  }
 
113
  except Exception as e:
114
  raise HTTPException(status_code=500, detail=str(e))
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  if __name__ == "__main__":
117
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  from sentence_transformers import SentenceTransformer
4
  import uvicorn
5
  import asyncio
 
6
  from typing import List
7
  import numpy as np
8
  from contextlib import asynccontextmanager
9
  import httpx
10
  import os
11
+ import sqlite3
12
+ from datetime import datetime
13
+ import json
14
+ import threading
15
 
16
  # Globals
17
  model = None
18
  tokenizer = None
19
  model_id = 'Qwen/Qwen3-Embedding-0.6B'
 
20
  MAX_TOKENS = 32000
21
+ DB_PATH = "/data/embeddings.db" # Important: this is the path on HuggingFace
22
+ processing_lock = threading.Lock()
23
+ is_processing = False
24
+
25
+
26
def init_database():
    """Create the SQLite database, its table and its status index if missing.

    The parent directory of ``DB_PATH`` is created first when the path has
    one. The ``embedding_requests`` table and the ``idx_status`` index are
    created with ``IF NOT EXISTS``, so calling this repeatedly is harmless.

    Raises:
        sqlite3.Error: if the database cannot be opened or a statement fails.
        OSError: if the parent directory cannot be created.
    """
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS embedding_requests (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                request_id TEXT,
                text TEXT NOT NULL,
                embedding TEXT,
                status TEXT DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP,
                webhook_sent BOOLEAN DEFAULT 0,
                error_message TEXT
            )
        ''')

        # Index the status column: every queue lookup filters on it.
        cursor.execute('''
            CREATE INDEX IF NOT EXISTS idx_status
            ON embedding_requests(status)
        ''')

        conn.commit()
    finally:
        # Close even when a statement raises so the handle is not leaked.
        conn.close()

    print("Database initialized successfully")
57
+
58
+
59
def save_request_to_db(text: str, request_id: str | None = None) -> int:
    """Insert a new request row with status ``'pending'``.

    Args:
        text: The text to be embedded later by the queue processor.
        request_id: Optional caller-supplied identifier stored with the row.

    Returns:
        The auto-generated ``id`` of the inserted row.

    Raises:
        sqlite3.Error: if the insert fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            INSERT INTO embedding_requests (request_id, text, status)
            VALUES (?, ?, 'pending')
        ''', (request_id, text))

        row_id = cursor.lastrowid
        conn.commit()
    finally:
        # Close even when the insert raises so the handle is not leaked.
        conn.close()

    print(f"✅ Request saved to DB with ID: {row_id}")
    return row_id
75
+
76
+
77
def get_next_pending_request():
    """Fetch the oldest row whose status is ``'pending'``.

    Returns:
        A ``(id, request_id, text)`` tuple for the pending row with the
        smallest ``id``, or ``None`` when no row is pending.

    Raises:
        sqlite3.Error: if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT id, request_id, text
            FROM embedding_requests
            WHERE status = 'pending'
            ORDER BY id ASC
            LIMIT 1
        ''')
        return cursor.fetchone()
    finally:
        # Close even when the query raises so the handle is not leaked.
        conn.close()
94
+
95
+
96
def update_request_processing(row_id: int):
    """Set the status of the given row to ``'processing'``.

    Args:
        row_id: Primary key of the ``embedding_requests`` row to update.

    Raises:
        sqlite3.Error: if the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.cursor().execute('''
            UPDATE embedding_requests
            SET status = 'processing'
            WHERE id = ?
        ''', (row_id,))
        conn.commit()
    finally:
        # Close even when the update raises so the handle is not leaked.
        conn.close()
109
+
110
+
111
def update_embedding_in_db(row_id: int, embedding: List[float]):
    """Store an embedding on a row and mark the row ``'completed'``.

    The embedding is serialised to a JSON string, and ``processed_at`` is set
    to the database's current timestamp.

    Args:
        row_id: Primary key of the ``embedding_requests`` row to update.
        embedding: The embedding vector to store.

    Raises:
        TypeError: if ``embedding`` is not JSON serialisable.
        sqlite3.Error: if the update fails.
    """
    # Serialise before connecting so a bad payload never opens a connection.
    embedding_json = json.dumps(embedding)

    conn = sqlite3.connect(DB_PATH)
    try:
        conn.cursor().execute('''
            UPDATE embedding_requests
            SET embedding = ?,
                status = 'completed',
                processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (embedding_json, row_id))
        conn.commit()
    finally:
        # Close even when the update raises so the handle is not leaked.
        conn.close()

    print(f"✅ Embedding updated for row ID: {row_id}")
129
+
130
+
131
def mark_webhook_sent(row_id: int):
    """Set ``webhook_sent = 1`` on the given row.

    Args:
        row_id: Primary key of the ``embedding_requests`` row to update.

    Raises:
        sqlite3.Error: if the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.cursor().execute('''
            UPDATE embedding_requests
            SET webhook_sent = 1
            WHERE id = ?
        ''', (row_id,))
        conn.commit()
    finally:
        # Close even when the update raises so the handle is not leaked.
        conn.close()
144
+
145
+
146
def delete_from_db(row_id: int):
    """Delete the given row from ``embedding_requests``.

    Args:
        row_id: Primary key of the row to delete.

    Raises:
        sqlite3.Error: if the delete fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.cursor().execute(
            'DELETE FROM embedding_requests WHERE id = ?', (row_id,)
        )
        conn.commit()
    finally:
        # Close even when the delete raises so the handle is not leaked.
        conn.close()

    print(f"🗑️ Request deleted from DB with ID: {row_id}")
156
+
157
+
158
def mark_request_failed(row_id: int, error_message: str):
    """Mark a row ``'failed'`` and record why.

    ``processed_at`` is set to the database's current timestamp.

    Args:
        row_id: Primary key of the ``embedding_requests`` row to update.
        error_message: Text stored in the row's ``error_message`` column.

    Raises:
        sqlite3.Error: if the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.cursor().execute('''
            UPDATE embedding_requests
            SET status = 'failed',
                error_message = ?,
                processed_at = CURRENT_TIMESTAMP
            WHERE id = ?
        ''', (error_message, row_id))
        conn.commit()
    finally:
        # Close even when the update raises so the handle is not leaked.
        conn.close()
173
+
174
 
175
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    Startup: initialise the database, load the embedding model and its
    tokenizer into the module globals, and start the queue processor task.
    Shutdown: cancel the queue processor and clear the model globals.

    Args:
        app: The FastAPI application this lifespan is attached to.
    """
    global model, tokenizer

    init_database()

    print(f"Loading model: {model_id}...")
    model = SentenceTransformer(model_id)
    tokenizer = model.tokenizer
    print("Model loaded successfully")

    # Hold a strong reference: the event loop keeps only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    queue_task = asyncio.create_task(process_queue())

    try:
        yield
    finally:
        # Stop the background loop before the model it depends on is dropped.
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            pass

        print("Cleaning up resources...")
        model = None
        tokenizer = None
196
 
197
+
198
  # Application object; startup and shutdown work is delegated to `lifespan`.
  app = FastAPI(
199
  title="Text Embedding API (Qwen/Qwen3-Embedding-0.6B)",
200
  lifespan=lifespan
201
  )
202
 
203
+
204
class TextRequest(BaseModel):
    """Request body accepted by the text-embedding endpoint."""

    # Required, non-empty text to embed.
    text: str = Field(
        min_length=1,
        description="Text to embed",
    )
    # Optional identifier supplied by the caller.
    request_id: str | None = Field(
        default=None,
        description="Optional unique identifier for the request",
    )
207
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  def chunk_and_embed(text: str) -> List[float]:
210
  """Split text into chunks if too long, then pool embeddings"""
211
  tokens = tokenizer.encode(text, add_special_tokens=False)
212
 
 
213
  if len(tokens) <= MAX_TOKENS:
214
  return model.encode(text, normalize_embeddings=True).tolist()
215
 
 
216
  chunks = []
217
  overlap = 50
218
  start = 0
 
226
  break
227
  start = end - overlap
228
 
 
229
  chunk_embeddings = [model.encode(chunk, normalize_embeddings=True) for chunk in chunks]
 
 
230
  final_embedding = np.mean(chunk_embeddings, axis=0).tolist()
231
 
232
  return final_embedding
233
 
234
+
235
async def send_to_webhook(url: str, data: dict, db_row_id: int):
    """POST ``data`` as JSON to ``url`` and remove the row once delivered.

    On a successful response the row is flagged as sent and then deleted.
    Any exception (HTTP or database) is printed and swallowed, leaving the
    row in place.

    Args:
        url: Webhook endpoint to POST to.
        data: JSON-serialisable payload.
        db_row_id: Primary key of the row this payload belongs to.
    """
    try:
        async with httpx.AsyncClient(timeout=60.0) as http:
            reply = await http.post(url, json=data)
            # Treat 4xx/5xx responses as failures.
            reply.raise_for_status()
            print(f"✅ Webhook sent successfully for ID: {db_row_id}")

            mark_webhook_sent(db_row_id)
            delete_from_db(db_row_id)
    except Exception as exc:
        print(f"❌ Webhook error for ID {db_row_id}: {exc}")
248
+
249
+
250
async def process_queue():
    """Process pending requests one at a time, forever.

    Each iteration takes the oldest pending row, marks it as processing,
    computes its embedding in a worker thread and stores the result. If the
    ``WEBHOOK_URL`` environment variable is set, the result is sent to the
    webhook; otherwise the row is deleted. A failure while handling a row
    marks that row as failed. When nothing is pending the loop sleeps for
    2 seconds; after an unexpected error it sleeps for 5 seconds.

    The module-level ``is_processing`` flag is ``True`` only while a row is
    being handled.
    """
    global is_processing

    print("🚀 Queue processor started")

    while True:
        try:
            pending = get_next_pending_request()

            if not pending:
                # Nothing queued; poll again shortly.
                await asyncio.sleep(2)
                continue

            row_id, request_id, text = pending

            is_processing = True
            try:
                update_request_processing(row_id)
                print(f"⚙️ Processing request ID: {row_id}")

                try:
                    # The model call is blocking, so run it off the event loop.
                    embedding = await asyncio.to_thread(chunk_and_embed, text)

                    update_embedding_in_db(row_id, embedding)

                    webhook_url = os.environ.get("WEBHOOK_URL")
                    if webhook_url:
                        payload = {
                            "db_id": row_id,
                            "text": text,
                            "embedding": embedding,
                            "request_id": request_id
                        }
                        # Signature is (url, data, db_row_id). Pass the row id
                        # by keyword so payload and id cannot be swapped.
                        await send_to_webhook(
                            webhook_url, payload, db_row_id=row_id
                        )
                    else:
                        # No webhook configured, so just delete the row.
                        delete_from_db(row_id)

                except Exception as e:
                    print(f"❌ Error processing request {row_id}: {e}")
                    mark_request_failed(row_id, str(e))
            finally:
                # Reset on every exit path, including errors and cancellation.
                is_processing = False

        except Exception as e:
            print(f"❌ Queue processor error: {e}")
            await asyncio.sleep(5)
305
+
306
+
307
@app.get("/")
def home():
    """Report that the service is up, with the model and queue activity."""
    info = {"status": "online"}
    info["model"] = model_id
    info["endpoint"] = "/embed/text"
    # True while the background processor is handling a request.
    info["processing"] = is_processing
    return info
315
+
316
+
317
@app.post("/embed/text")
async def embed_text(request: TextRequest):
    """Queue a text for embedding and return its database id.

    The request is only stored here; the embedding is computed later by the
    background queue processor.

    Raises:
        HTTPException: status 500 carrying the error text if storing fails.
    """
    try:
        queued_id = save_request_to_db(request.text, request.request_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {
        "success": True,
        "message": "Request queued for processing",
        "db_id": queued_id,
        "model": model_id
    }
333
 
334
+
335
@app.get("/status")
def get_status():
    """Return how many queued requests are in each status.

    Returns:
        A dict with the total row count, the counts for the ``pending``,
        ``processing``, ``completed`` and ``failed`` statuses, and the
        current ``is_processing`` flag.

    Raises:
        sqlite3.Error: if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # One grouped query gives a consistent snapshot. It also avoids
        # double-quoted string literals, which SQLite parses as identifiers
        # and accepts as strings only through a legacy fallback.
        rows = conn.execute(
            'SELECT status, COUNT(*) FROM embedding_requests GROUP BY status'
        ).fetchall()
    finally:
        # Close even when the query raises so the handle is not leaked.
        conn.close()

    counts = dict(rows)

    return {
        "total": sum(counts.values()),
        "pending": counts.get("pending", 0),
        "processing": counts.get("processing", 0),
        "completed": counts.get("completed", 0),
        "failed": counts.get("failed", 0),
        "is_processing": is_processing
    }
366
+
367
+
368
@app.get("/request/{db_id}")
def get_request_status(db_id: int):
    """Return the stored status fields of a single request.

    Args:
        db_id: Primary key of the ``embedding_requests`` row to look up.

    Returns:
        A dict with the row's id, request id, status, timestamps, webhook
        flag and error message.

    Raises:
        HTTPException: status 404 if no row has this id.
        sqlite3.Error: if the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        result = conn.execute('''
            SELECT id, request_id, status, created_at, processed_at, webhook_sent, error_message
            FROM embedding_requests
            WHERE id = ?
        ''', (db_id,)).fetchone()
    finally:
        # Close even when the query raises so the handle is not leaked.
        conn.close()

    if result is None:
        raise HTTPException(status_code=404, detail="Request not found")

    return {
        "db_id": result[0],
        "request_id": result[1],
        "status": result[2],
        "created_at": result[3],
        "processed_at": result[4],
        "webhook_sent": bool(result[5]),
        "error_message": result[6]
    }
395
+
396
+
397
  # When run as a script, serve the app on all interfaces, port 7860.
  if __name__ == "__main__":
398
  uvicorn.run(app, host="0.0.0.0", port=7860)