ex510 commited on
Commit
03cc16d
·
verified ·
1 Parent(s): 2d14f9f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -68
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks
2
  from pydantic import BaseModel, Field
3
  from sentence_transformers import SentenceTransformer
4
  import uvicorn
@@ -9,23 +9,19 @@ from contextlib import asynccontextmanager
9
  import httpx
10
  import os
11
  import sqlite3
12
- from datetime import datetime
13
  import json
14
- import threading
15
 
16
  # Globals
17
  model = None
18
  tokenizer = None
19
  model_id = 'Qwen/Qwen3-Embedding-0.6B'
20
  MAX_TOKENS = 32000
21
- DB_PATH = "/data/embeddings.db" # هام: المسار ده في HuggingFace
22
- processing_lock = threading.Lock()
23
  is_processing = False
24
 
25
 
26
  def init_database():
27
  """Initialize the SQLite database"""
28
- # Create /data directory if it doesn't exist
29
  os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
30
 
31
  conn = sqlite3.connect(DB_PATH)
@@ -45,7 +41,6 @@ def init_database():
45
  )
46
  ''')
47
 
48
- # Create index for faster queries
49
  cursor.execute('''
50
  CREATE INDEX IF NOT EXISTS idx_status
51
  ON embedding_requests(status)
@@ -53,7 +48,7 @@ def init_database():
53
 
54
  conn.commit()
55
  conn.close()
56
- print("Database initialized successfully")
57
 
58
 
59
  def save_request_to_db(text: str, request_id: str = None) -> int:
@@ -125,34 +120,44 @@ def update_embedding_in_db(row_id: int, embedding: List[float]):
125
 
126
  conn.commit()
127
  conn.close()
128
- print(f"✅ Embedding updated for row ID: {row_id}")
129
 
130
 
131
- def mark_webhook_sent(row_id: int):
132
- """Mark that webhook was sent successfully"""
133
  conn = sqlite3.connect(DB_PATH)
134
  cursor = conn.cursor()
135
 
136
  cursor.execute('''
137
- UPDATE embedding_requests
138
- SET webhook_sent = 1
139
  WHERE id = ?
140
  ''', (row_id,))
141
 
142
- conn.commit()
143
  conn.close()
 
 
144
 
145
 
146
- def delete_from_db(row_id: int):
147
- """Delete the request from database after webhook is sent"""
148
  conn = sqlite3.connect(DB_PATH)
149
  cursor = conn.cursor()
150
 
 
 
 
 
 
 
 
 
151
  cursor.execute('DELETE FROM embedding_requests WHERE id = ?', (row_id,))
152
 
153
  conn.commit()
154
  conn.close()
155
- print(f"🗑️ Request deleted from DB with ID: {row_id}")
156
 
157
 
158
  def mark_request_failed(row_id: int, error_message: str):
@@ -177,42 +182,43 @@ async def lifespan(app: FastAPI):
177
  # Initialize database
178
  init_database()
179
 
180
- # Load the model and tokenizer at startup
181
  global model, tokenizer
182
  print(f"Loading model: {model_id}...")
183
  model = SentenceTransformer(model_id)
184
  tokenizer = model.tokenizer
185
- print("Model loaded successfully")
186
 
187
- # Start the background processor
188
  asyncio.create_task(process_queue())
189
 
190
  yield
191
 
192
- # Clean up
193
- print("Cleaning up resources...")
194
  model = None
195
  tokenizer = None
196
 
197
 
198
  app = FastAPI(
199
- title="Text Embedding API (Qwen/Qwen3-Embedding-0.6B)",
200
  lifespan=lifespan
201
  )
202
 
203
 
204
  class TextRequest(BaseModel):
205
  text: str = Field(..., min_length=1, description="Text to embed")
206
- request_id: str | None = Field(None, description="Optional unique identifier for the request")
207
 
208
 
209
  def chunk_and_embed(text: str) -> List[float]:
210
- """Split text into chunks if too long, then pool embeddings"""
211
  tokens = tokenizer.encode(text, add_special_tokens=False)
212
 
213
  if len(tokens) <= MAX_TOKENS:
214
  return model.encode(text, normalize_embeddings=True).tolist()
215
 
 
216
  chunks = []
217
  overlap = 50
218
  start = 0
@@ -232,74 +238,74 @@ def chunk_and_embed(text: str) -> List[float]:
232
  return final_embedding
233
 
234
 
235
- async def send_to_webhook(url: str, data: dict, db_row_id: int):
236
- """Send data to webhook and delete from DB on success"""
237
  try:
 
 
 
 
 
 
 
 
238
  async with httpx.AsyncClient(timeout=60.0) as client:
239
- response = await client.post(url, json=data)
240
  response.raise_for_status()
241
- print(f"✅ Webhook sent successfully for ID: {db_row_id}")
242
 
243
- mark_webhook_sent(db_row_id)
244
- delete_from_db(db_row_id)
245
 
246
  except Exception as e:
247
- print(f"❌ Webhook error for ID {db_row_id}: {e}")
 
248
 
249
 
250
  async def process_queue():
251
- """Background task to process pending requests one by one"""
252
  global is_processing
253
 
254
  print("🚀 Queue processor started")
255
 
256
  while True:
257
  try:
258
- # Check if there's a pending request
259
  pending = get_next_pending_request()
260
 
261
  if pending:
262
  row_id, request_id, text = pending
263
-
264
- # Mark as processing
265
  is_processing = True
266
  update_request_processing(row_id)
267
 
268
  print(f"⚙️ Processing request ID: {row_id}")
269
 
270
  try:
271
- # Generate embedding (synchronous in async context)
272
  embedding = await asyncio.to_thread(chunk_and_embed, text)
273
 
274
- # Save embedding to DB
275
  update_embedding_in_db(row_id, embedding)
276
 
277
- # Send to webhook if URL exists
278
  webhook_url = os.environ.get("WEBHOOK_URL")
279
  if webhook_url:
280
- payload = {
281
- "db_id": row_id,
282
- "text": text,
283
- "embedding": embedding,
284
- "request_id": request_id
285
- }
286
- await send_to_webhook(webhook_url, row_id, payload)
287
  else:
288
  # No webhook, just delete
289
- delete_from_db(row_id)
290
 
291
  except Exception as e:
292
- print(f"❌ Error processing request {row_id}: {e}")
293
  mark_request_failed(row_id, str(e))
294
 
295
  is_processing = False
296
 
297
  else:
298
- # No pending requests, wait a bit
299
  await asyncio.sleep(2)
300
 
301
  except Exception as e:
302
- print(f"❌ Queue processor error: {e}")
303
  is_processing = False
304
  await asyncio.sleep(5)
305
 
@@ -308,24 +314,27 @@ async def process_queue():
308
  def home():
309
  return {
310
  "status": "online",
311
- "model": model_id,
312
- "endpoint": "/embed/text",
313
  "processing": is_processing
314
  }
315
 
316
 
317
  @app.post("/embed/text")
318
  async def embed_text(request: TextRequest):
319
- """Just save the request to database, processing happens in background"""
 
 
 
320
  try:
321
- # Simply save to database
322
  db_row_id = save_request_to_db(request.text, request.request_id)
323
 
 
324
  return {
325
  "success": True,
326
- "message": "Request queued for processing",
327
  "db_id": db_row_id,
328
- "model": model_id
329
  }
330
 
331
  except Exception as e:
@@ -334,7 +343,7 @@ async def embed_text(request: TextRequest):
334
 
335
  @app.get("/status")
336
  def get_status():
337
- """Get queue status"""
338
  conn = sqlite3.connect(DB_PATH)
339
  cursor = conn.cursor()
340
 
@@ -350,24 +359,36 @@ def get_status():
350
  cursor.execute('SELECT COUNT(*) FROM embedding_requests WHERE status = "failed"')
351
  failed = cursor.fetchone()[0]
352
 
353
- cursor.execute('SELECT COUNT(*) FROM embedding_requests')
354
- total = cursor.fetchone()[0]
 
 
 
 
 
 
 
355
 
356
  conn.close()
357
 
358
  return {
359
- "total": total,
360
- "pending": pending,
361
- "processing": processing,
362
- "completed": completed,
363
- "failed": failed,
364
- "is_processing": is_processing
 
 
 
 
 
365
  }
366
 
367
 
368
  @app.get("/request/{db_id}")
369
- def get_request_status(db_id: int):
370
- """Check status of a specific request"""
371
  conn = sqlite3.connect(DB_PATH)
372
  cursor = conn.cursor()
373
 
@@ -381,7 +402,7 @@ def get_request_status(db_id: int):
381
  conn.close()
382
 
383
  if not result:
384
- raise HTTPException(status_code=404, detail="Request not found")
385
 
386
  return {
387
  "db_id": result[0],
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel, Field
3
  from sentence_transformers import SentenceTransformer
4
  import uvicorn
 
9
  import httpx
10
  import os
11
  import sqlite3
 
12
  import json
 
13
 
14
  # Globals
15
  model = None
16
  tokenizer = None
17
  model_id = 'Qwen/Qwen3-Embedding-0.6B'
18
  MAX_TOKENS = 32000
19
+ DB_PATH = "/data/embeddings.db"
 
20
  is_processing = False
21
 
22
 
23
  def init_database():
24
  """Initialize the SQLite database"""
 
25
  os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
26
 
27
  conn = sqlite3.connect(DB_PATH)
 
41
  )
42
  ''')
43
 
 
44
  cursor.execute('''
45
  CREATE INDEX IF NOT EXISTS idx_status
46
  ON embedding_requests(status)
 
48
 
49
  conn.commit()
50
  conn.close()
51
+ print("Database initialized successfully")
52
 
53
 
54
  def save_request_to_db(text: str, request_id: str = None) -> int:
 
120
 
121
  conn.commit()
122
  conn.close()
123
+ print(f"✅ Embedding saved for ID: {row_id}")
124
 
125
 
126
+ def get_request_data(row_id: int):
127
+ """Get full request data including embedding"""
128
  conn = sqlite3.connect(DB_PATH)
129
  cursor = conn.cursor()
130
 
131
  cursor.execute('''
132
+ SELECT id, request_id, text, embedding
133
+ FROM embedding_requests
134
  WHERE id = ?
135
  ''', (row_id,))
136
 
137
+ result = cursor.fetchone()
138
  conn.close()
139
+
140
+ return result
141
 
142
 
143
+ def mark_webhook_sent_and_delete(row_id: int):
144
+ """Mark webhook as sent and delete from DB"""
145
  conn = sqlite3.connect(DB_PATH)
146
  cursor = conn.cursor()
147
 
148
+ # First mark as sent
149
+ cursor.execute('''
150
+ UPDATE embedding_requests
151
+ SET webhook_sent = 1
152
+ WHERE id = ?
153
+ ''', (row_id,))
154
+
155
+ # Then delete
156
  cursor.execute('DELETE FROM embedding_requests WHERE id = ?', (row_id,))
157
 
158
  conn.commit()
159
  conn.close()
160
+ print(f"🗑️ Request deleted from DB: {row_id}")
161
 
162
 
163
  def mark_request_failed(row_id: int, error_message: str):
 
182
  # Initialize database
183
  init_database()
184
 
185
+ # Load the model
186
  global model, tokenizer
187
  print(f"Loading model: {model_id}...")
188
  model = SentenceTransformer(model_id)
189
  tokenizer = model.tokenizer
190
+ print("Model loaded successfully")
191
 
192
+ # Start background processor
193
  asyncio.create_task(process_queue())
194
 
195
  yield
196
 
197
+ # Cleanup
198
+ print("Cleaning up...")
199
  model = None
200
  tokenizer = None
201
 
202
 
203
  app = FastAPI(
204
+ title="Text Embedding API with Queue",
205
  lifespan=lifespan
206
  )
207
 
208
 
209
  class TextRequest(BaseModel):
210
  text: str = Field(..., min_length=1, description="Text to embed")
211
+ request_id: str | None = Field(None, description="Optional request identifier")
212
 
213
 
214
  def chunk_and_embed(text: str) -> List[float]:
215
+ """Generate embedding with chunking if needed"""
216
  tokens = tokenizer.encode(text, add_special_tokens=False)
217
 
218
  if len(tokens) <= MAX_TOKENS:
219
  return model.encode(text, normalize_embeddings=True).tolist()
220
 
221
+ # Chunking
222
  chunks = []
223
  overlap = 50
224
  start = 0
 
238
  return final_embedding
239
 
240
 
241
+ async def send_to_webhook(webhook_url: str, row_id: int, request_id: str, text: str, embedding: List[float]):
242
+ """Send complete data to webhook after embedding is ready"""
243
  try:
244
+ payload = {
245
+ "db_id": row_id,
246
+ "request_id": request_id,
247
+ "text": text,
248
+ "embedding": embedding,
249
+ "status": "completed"
250
+ }
251
+
252
  async with httpx.AsyncClient(timeout=60.0) as client:
253
+ response = await client.post(webhook_url, json=payload)
254
  response.raise_for_status()
255
+ print(f"✅ Webhook sent successfully for ID: {row_id}")
256
 
257
+ # Delete from DB after successful webhook
258
+ mark_webhook_sent_and_delete(row_id)
259
 
260
  except Exception as e:
261
+ print(f"❌ Webhook error for ID {row_id}: {e}")
262
+ # Don't delete if webhook failed
263
 
264
 
265
  async def process_queue():
266
+ """Background processor - processes one request at a time"""
267
  global is_processing
268
 
269
  print("🚀 Queue processor started")
270
 
271
  while True:
272
  try:
 
273
  pending = get_next_pending_request()
274
 
275
  if pending:
276
  row_id, request_id, text = pending
 
 
277
  is_processing = True
278
  update_request_processing(row_id)
279
 
280
  print(f"⚙️ Processing request ID: {row_id}")
281
 
282
  try:
283
+ # Generate embedding
284
  embedding = await asyncio.to_thread(chunk_and_embed, text)
285
 
286
+ # Save to DB
287
  update_embedding_in_db(row_id, embedding)
288
 
289
+ # Send to webhook with ALL data
290
  webhook_url = os.environ.get("WEBHOOK_URL")
291
  if webhook_url:
292
+ await send_to_webhook(webhook_url, row_id, request_id, text, embedding)
 
 
 
 
 
 
293
  else:
294
  # No webhook, just delete
295
+ mark_webhook_sent_and_delete(row_id)
296
 
297
  except Exception as e:
298
+ print(f"❌ Error processing {row_id}: {e}")
299
  mark_request_failed(row_id, str(e))
300
 
301
  is_processing = False
302
 
303
  else:
304
+ # No pending requests
305
  await asyncio.sleep(2)
306
 
307
  except Exception as e:
308
+ print(f"❌ Queue error: {e}")
309
  is_processing = False
310
  await asyncio.sleep(5)
311
 
 
314
  def home():
315
  return {
316
  "status": "online",
317
+ "model": model_id,
 
318
  "processing": is_processing
319
  }
320
 
321
 
322
  @app.post("/embed/text")
323
  async def embed_text(request: TextRequest):
324
+ """
325
+ Fast response - just queue the request
326
+ Processing happens in background
327
+ """
328
  try:
329
+ # Save to DB immediately
330
  db_row_id = save_request_to_db(request.text, request.request_id)
331
 
332
+ # Return immediately
333
  return {
334
  "success": True,
335
+ "message": "Request queued successfully",
336
  "db_id": db_row_id,
337
+ "status": "pending"
338
  }
339
 
340
  except Exception as e:
 
343
 
344
  @app.get("/status")
345
  def get_status():
346
+ """Get queue statistics"""
347
  conn = sqlite3.connect(DB_PATH)
348
  cursor = conn.cursor()
349
 
 
359
  cursor.execute('SELECT COUNT(*) FROM embedding_requests WHERE status = "failed"')
360
  failed = cursor.fetchone()[0]
361
 
362
+ # Get next in queue
363
+ cursor.execute('''
364
+ SELECT id, created_at
365
+ FROM embedding_requests
366
+ WHERE status = "pending"
367
+ ORDER BY id ASC
368
+ LIMIT 1
369
+ ''')
370
+ next_request = cursor.fetchone()
371
 
372
  conn.close()
373
 
374
  return {
375
+ "queue": {
376
+ "pending": pending,
377
+ "processing": processing,
378
+ "completed": completed,
379
+ "failed": failed
380
+ },
381
+ "is_processing": is_processing,
382
+ "next_request": {
383
+ "id": next_request[0] if next_request else None,
384
+ "created_at": next_request[1] if next_request else None
385
+ } if next_request else None
386
  }
387
 
388
 
389
  @app.get("/request/{db_id}")
390
+ def get_request_info(db_id: int):
391
+ """Check specific request status"""
392
  conn = sqlite3.connect(DB_PATH)
393
  cursor = conn.cursor()
394
 
 
402
  conn.close()
403
 
404
  if not result:
405
+ raise HTTPException(status_code=404, detail="Request not found or already deleted")
406
 
407
  return {
408
  "db_id": result[0],