# Exported from the Hugging Face Space "emmd" (owner: Mazenbs).
# Commit b7f088f ("Rename mainv2.py to main.py"), verified.
# main.py
import os
import time
import re
import gc
import logging
from functools import lru_cache
from typing import List
import numpy as np
import psutil
import onnxruntime as ort
from transformers import AutoTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# =====================================================
# General settings (CPU – HuggingFace Spaces)
# =====================================================
MODEL_PATH = os.environ["MODEL_PATH"]  # required: fail fast with KeyError if unset
TOKENIZER_PATH = os.environ["TOKENIZER_PATH"]  # required: local tokenizer directory
MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", 512))  # request body limit (chars)
CACHE_SIZE = int(os.environ.get("CACHE_SIZE", 512))  # embedding lru_cache capacity
PORT = int(os.environ.get("PORT", 7860))  # 7860 is the HF Spaces default port
# Keep logging at WARNING to reduce per-request overhead
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("embedding-api")
# =====================================================
# ONNX Runtime CPU tuning
# =====================================================
# OpenMP env vars must be set before the session is created to take effect.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OMP_WAIT_POLICY"] = "ACTIVE"  # spin-wait worker threads for lower latency
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 4  # parallelism inside a single operator
sess_options.inter_op_num_threads = 1  # sequential graph: no cross-operator parallelism
sess_options.enable_cpu_mem_arena = True
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
# One module-level session, shared by every request (CPU provider only).
session = ort.InferenceSession(
    MODEL_PATH,
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)
# =====================================================
# Load the tokenizer once at import time
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_PATH,
    local_files_only=True,  # never hit the network on Spaces
    use_fast=True
)
# =====================================================
# Arabic text normalization (cached)
# =====================================================
# Patterns are compiled once at import time so the hot path avoids the
# repeated pattern-cache lookup that module-level re.sub performs per call.
_DIACRITICS_RE = re.compile(r"[ًٌٍَُِّْـ]")  # harakat, tanween, shadda, sukun, tatweel
_ALEF_RE = re.compile(r"[إأآ]")              # hamza/madda alef variants -> bare alef
_ALEF_MAQSURA_RE = re.compile(r"ى")          # alef maqsura -> ya
_WAW_HAMZA_RE = re.compile(r"ؤ")             # hamza-on-waw -> waw
_YA_HAMZA_RE = re.compile(r"ئ")              # hamza-on-ya -> ya
_TA_MARBUTA_RE = re.compile(r"ة\b")          # word-final ta marbuta -> ha
_PUNCT_RE = re.compile(r"[^\w\s]")           # any non-word, non-space char -> space
_SPACE_RE = re.compile(r"\s+")               # collapse whitespace runs


@lru_cache(maxsize=1024)
def normalize_arabic(text: str) -> str:
    """Normalize Arabic text for embedding.

    Strips diacritics, unifies alef/waw/ya hamza variants, maps word-final
    ta marbuta to ha, replaces punctuation with spaces, and collapses
    whitespace. Results are memoized per exact input string.
    """
    text = _DIACRITICS_RE.sub("", text)
    text = _ALEF_RE.sub("ا", text)
    text = _ALEF_MAQSURA_RE.sub("ي", text)
    text = _WAW_HAMZA_RE.sub("و", text)
    text = _YA_HAMZA_RE.sub("ي", text)
    text = _TA_MARBUTA_RE.sub("ه", text)
    text = _PUNCT_RE.sub(" ", text)
    text = _SPACE_RE.sub(" ", text)
    return text.strip()
# =====================================================
# Text -> embedding (fast + cached)
# =====================================================
@lru_cache(maxsize=CACHE_SIZE)
def text_to_embedding(text: str) -> np.ndarray:
    """Return an L2-normalized float32 embedding for *text*, or None for
    empty/whitespace input.

    Results are memoized per exact input string (bounded by CACHE_SIZE).
    NOTE(review): the cached value is a mutable ndarray shared across
    callers — callers must not modify it in place.
    """
    if not text or not text.strip():
        return None
    text = normalize_arabic(text)
    inputs = tokenizer(
        f"query: {text}",  # "query:" prefix — presumably matches the model's training scheme; verify
        return_tensors="np",
        truncation=True,
        max_length=128,  # queries are short — smaller sequence = faster inference
        padding=False,
        return_token_type_ids=False,
    )
    outputs = session.run(
        None,
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
    )
    # assumes the model's second output is the pooled sentence embedding
    # for a batch of 1 — TODO confirm against the exported ONNX graph
    vector = outputs[1][0]
    # L2 normalize; epsilon guards against division by zero
    vector /= np.linalg.norm(vector) + 1e-12
    return vector.astype(np.float32)
# =====================================================
# نماذج API
# =====================================================
class TextRequest(BaseModel):
    """Request body for /query: a single non-empty text string."""
    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
class EmbeddingResponse(BaseModel):
    """Response for /query: the embedding vector plus timing metadata."""
    embedding: List[float]  # L2-normalized vector
    dimension: int  # len(embedding)
    processing_time: float  # seconds spent serving this request
class HealthResponse(BaseModel):
    """Response for /health: process memory and uptime snapshot."""
    status: str  # always "healthy" when the endpoint answers at all
    memory_usage: str  # e.g. "42.0%"
    memory_available_gb: float  # rounded to 2 decimal places
    uptime: float  # seconds since the startup event
# =====================================================
# Application instance
# =====================================================
app = FastAPI(
    title="Fast Arabic Embedding API (CPU Optimized)",
    version="3.0.0"
)
# =====================================================
# Endpoints
# =====================================================
@app.get("/")
def root():
    """Landing endpoint: confirms the service is up and links to docs."""
    links = {"docs": "/docs", "health": "/health"}
    return {"message": "✅ Arabic Embedding API is running", **links}
@app.get("/health", response_model=HealthResponse)
def health():
    """Liveness/metrics probe: system memory usage and service uptime."""
    memory = psutil.virtual_memory()
    # app.state.start_time is set by the startup event handler
    uptime = time.time() - app.state.start_time
    return HealthResponse(
        status="healthy",
        memory_usage=f"{memory.percent}%",
        memory_available_gb=round(memory.available / (1024 ** 3), 2),
        uptime=uptime,
    )
@app.post("/query", response_model=EmbeddingResponse)
def query_endpoint(request: TextRequest):
    """Embed a single query string.

    Returns the L2-normalized embedding, its dimension, and the wall-clock
    time spent. Raises HTTP 400 when no embedding can be produced (e.g.
    whitespace-only input).
    """
    start_time = time.time()
    # Leftover per-request debug prints removed: unconditional stdout
    # writes defeat the logging.WARNING configuration chosen above for
    # speed. Lazy %-style logging is a near no-op at the configured level.
    logger.info("New search query: %s", request.text)
    vector = text_to_embedding(request.text)
    if vector is None:
        raise HTTPException(400, "فشل إنشاء embedding")
    return EmbeddingResponse(
        embedding=vector.tolist(),
        dimension=vector.shape[0],
        processing_time=time.time() - start_time
    )
# =====================================================
# startup / shutdown
# =====================================================
@app.on_event("startup")
def startup():
    """Record start time (used by /health) and warm the inference path."""
    app.state.start_time = time.time()
    # Warm-up: the first inference pays one-time tokenizer/ONNX costs;
    # doing it here keeps the first real request fast.
    text_to_embedding("warm up")
    logger.warning("🚀 Embedding API started")
@app.on_event("shutdown")
def shutdown():
    """Best-effort memory cleanup when the process stops."""
    gc.collect()
    logger.warning("🛑 Embedding API stopped")
# =====================================================
# Server entry point
# =====================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=PORT,
        workers=1,  # single worker: model + lru caches are per-process (HF Spaces)
        access_log=False
    )