ex510 committed on
Commit
388960d
·
verified ·
1 Parent(s): 4e20d48

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -3
main.py CHANGED
@@ -4,35 +4,72 @@ from sentence_transformers import SentenceTransformer
4
  import uvicorn
5
  import asyncio
6
  from concurrent.futures import ThreadPoolExecutor
 
 
7
 
8
# FastAPI application object; routes below are registered against it.
app = FastAPI(title="Text Embedding API (Qwen/Qwen3-Embedding-0.6B)")
9
 
10
class TextRequest(BaseModel):
    """Request payload for /embed/text: one non-empty text, at most 10000 chars."""

    text: str = Field(..., min_length=1, max_length=10000, description="Text to embed")
12
 
13
# Globals
model = None  # SentenceTransformer instance; populated by load_model() at startup
model_id = 'Qwen/Qwen3-Embedding-0.6B'
executor = ThreadPoolExecutor(max_workers=4)  # offloads blocking encode() calls from the event loop
 
17
 
18
@app.on_event("startup")
async def load_model():
    """Load the embedding model once when the application starts.

    Assigns the module-level `model` global used by the request handlers.
    """
    # NOTE(review): @app.on_event is deprecated in newer FastAPI — consider
    # migrating to a lifespan handler; confirm the installed FastAPI version.
    global model
    print(f"Loading model: {model_id}...")
    model = SentenceTransformer(model_id)
    print("Model loaded successfully")
24
 
25
@app.get("/")
def home():
    """Health-check root: report service status, model id, and the embed endpoint."""
    return {"status": "online", "model": model_id, "endpoint": "/embed/text"}
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @app.post("/embed/text")
30
  async def embed_text(request: TextRequest):
31
  try:
32
  loop = asyncio.get_event_loop()
33
  embedding = await loop.run_in_executor(
34
  executor,
35
- lambda: model.encode(request.text, normalize_embeddings=True).tolist()
36
  )
37
 
38
  return {
 
4
  import uvicorn
5
  import asyncio
6
  from concurrent.futures import ThreadPoolExecutor
7
+ from typing import List # ← إضافة جديدة
8
+ import numpy as np # ← إضافة جديدة
9
 
10
# FastAPI application object; routes below are registered against it.
app = FastAPI(title="Text Embedding API (Qwen/Qwen3-Embedding-0.6B)")
11
 
12
class TextRequest(BaseModel):
    """Request payload for /embed/text.

    Only a lower bound is enforced; over-long inputs are chunked server-side
    (see chunk_and_embed) rather than rejected.
    """

    # max_length=10000 was removed: long texts are now chunked instead of rejected.
    text: str = Field(..., min_length=1, description="Text to embed")
14
 
15
# Globals
model = None  # SentenceTransformer instance; populated by load_model() at startup
tokenizer = None  # the model's own tokenizer; populated by load_model() at startup
model_id = 'Qwen/Qwen3-Embedding-0.6B'
executor = ThreadPoolExecutor(max_workers=4)  # offloads blocking encode() calls from the event loop
MAX_TOKENS = 512  # per-chunk token budget used by chunk_and_embed()
21
 
22
@app.on_event("startup")
async def load_model():
    """Load the embedding model and its tokenizer once at application startup.

    Assigns the module-level `model` and `tokenizer` globals used by the
    request handlers and by chunk_and_embed().
    """
    # NOTE(review): @app.on_event is deprecated in newer FastAPI — consider
    # migrating to a lifespan handler; confirm the installed FastAPI version.
    global model, tokenizer
    print(f"Loading model: {model_id}...")
    model = SentenceTransformer(model_id)
    # Reuse the model's own tokenizer so chunk boundaries match its vocabulary.
    tokenizer = model.tokenizer
    print("Model loaded successfully")
29
 
30
@app.get("/")
def home():
    """Health-check root: report service status, model id, and the embed endpoint."""
    payload = {
        "status": "online",
        "model": model_id,
        "endpoint": "/embed/text",
    }
    return payload
33
 
34
def chunk_and_embed(text: str) -> List[float]:
    """Embed *text*, chunking token-wise when it exceeds MAX_TOKENS.

    Short inputs are embedded directly with normalize_embeddings=True.
    Longer inputs are split into overlapping token windows, every window is
    embedded in a single batched encode() call, and the chunk embeddings are
    mean-pooled and then re-normalized to unit length — the mean of unit
    vectors is generally NOT unit-length, so without re-normalization the
    long-text path would return embeddings scaled differently from the
    short-text path.

    Returns the embedding as a plain list of floats. Relies on the
    module-level `model`, `tokenizer`, and `MAX_TOKENS` globals set at startup.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)

    # Short text: embed in one shot.
    if len(tokens) <= MAX_TOKENS:
        return model.encode(text, normalize_embeddings=True).tolist()

    # Split into overlapping token windows so context at chunk borders is kept.
    overlap = 50
    chunks = []
    start = 0
    while start < len(tokens):
        end = start + MAX_TOKENS
        chunks.append(tokenizer.decode(tokens[start:end], skip_special_tokens=True))
        if end >= len(tokens):
            break
        start = end - overlap  # MAX_TOKENS > overlap, so start always advances

    # Batch-encode all chunks in one call instead of one encode() per chunk;
    # SentenceTransformer.encode accepts a list of texts and returns a 2-D array.
    chunk_embeddings = model.encode(chunks, normalize_embeddings=True)

    # Mean-pool, then re-normalize so the result is unit-length like the
    # short-text path (guard against a pathological all-zero pool).
    pooled = np.mean(chunk_embeddings, axis=0)
    norm = np.linalg.norm(pooled)
    if norm > 0.0:
        pooled = pooled / norm
    return pooled.tolist()
65
+
66
  @app.post("/embed/text")
67
  async def embed_text(request: TextRequest):
68
  try:
69
  loop = asyncio.get_event_loop()
70
  embedding = await loop.run_in_executor(
71
  executor,
72
+ lambda: chunk_and_embed(request.text) # ← تم التعديل من model.encode إلى chunk_and_embed
73
  )
74
 
75
  return {