# emmd / gomain.py
# Mazenbs's picture
# Update gomain.py
# b6f9965 verified
import gc
import logging
import os
import re
import time
from functools import lru_cache
from typing import List, Optional

import numpy as np
import onnxruntime as ort
import psutil
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
# =====================================================
# General settings (CPU – HuggingFace Spaces)
# =====================================================
# Required variables: a missing one raises KeyError while the module loads.
MODEL_PATH = os.environ["MODEL_PATH"]
TOKENIZER_PATH = os.environ["TOKENIZER_PATH"]

# Optional variables, each converted to int, with a built-in default.
MAXTEXTLENGTH = int(os.getenv("MAXTEXTLENGTH", 512))
CACHESIZE = int(os.getenv("CACHESIZE", 512))
PORT = int(os.getenv("PORT", 7860))

DEFAULT_DIM = 256  # best speed / quality balance

# Keep logging quiet: only WARNING and above are emitted.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("embedding-api")
# =====================================================
# ONNX Runtime tuning (CPU)
# =====================================================
# NOTE(review): these variables are written after onnxruntime has already
# been imported at the top of the file — confirm they still take effect.
os.environ.update({"OMP_NUM_THREADS": "4", "OMP_WAIT_POLICY": "ACTIVE"})

sess_options = ort.SessionOptions()
# Threading: four threads inside an operator, one across operators,
# operators executed sequentially.
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.inter_op_num_threads = 1
sess_options.intra_op_num_threads = 4
# Graph and memory: all graph optimizations on, CPU memory arena enabled.
sess_options.enable_cpu_mem_arena = True
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# One inference session for the whole process, CPU provider only.
session = ort.InferenceSession(
    MODEL_PATH,
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)
# =====================================================
# Load the tokenizer once, at import time
# =====================================================
# local_files_only=True: only files already present under TOKENIZER_PATH are
# used; use_fast=True requests the fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_PATH, local_files_only=True, use_fast=True
)
# =====================================================
# Arabic text normalization (cached)
# =====================================================
@lru_cache(maxsize=1024)
def normalize_arabic(text: str) -> str:
    """Fold Arabic spelling variants in *text* into a single canonical form.

    The substitutions below are applied one after another, in order, and the
    result is stripped of leading/trailing whitespace. Results are memoized
    for the 1024 most recently used inputs.

    Args:
        text: Raw input string.

    Returns:
        The normalized string (possibly empty).
    """
    # (pattern, replacement) pairs; order matters because later rules see the
    # output of earlier ones.
    rewrite_rules = (
        (r"[ًٌٍَُِّْـ]", ""),   # drop diacritic marks and the tatweel
        (r"[إأآ]", "ا"),       # alef variants -> bare alef
        (r"ى", "ي"),           # alef maqsura -> ya
        (r"ؤ", "و"),           # waw with hamza -> waw
        (r"ئ", "ي"),           # ya with hamza -> ya
        (r"ة\b", "ه"),         # ta marbuta at a word boundary -> ha
        (r"[^\w\s]", " "),     # any other non-word, non-space char -> space
        (r"\s+", " "),         # collapse whitespace runs
    )
    normalized = text
    for pattern, replacement in rewrite_rules:
        normalized = re.sub(pattern, replacement, normalized)
    return normalized.strip()
# =====================================================
# Text -> embedding (cached)
# =====================================================
@lru_cache(maxsize=CACHESIZE)
def text_to_embedding(text: str) -> Optional[np.ndarray]:
    """Embed *text* with the ONNX model and return an L2-normalized vector.

    The text is normalized with ``normalize_arabic``, prefixed with
    ``"query: "``, tokenized (truncated to 96 tokens) and run through the
    module-level ONNX ``session``.

    Results are memoized per input string, so every caller asking for the
    same text receives the *same* array object. The array is therefore
    marked read-only: an in-place write by any caller would otherwise
    silently corrupt the cached entry for all later requests.

    Args:
        text: Raw input text.

    Returns:
        A read-only 1-D ``float32`` array with unit L2 norm, or ``None`` when
        *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        return None

    text = normalize_arabic(text)
    inputs = tokenizer(
        f"query: {text}",
        return_tensors="np",
        truncation=True,
        max_length=96,  # shorter = faster
        padding=False,
        return_token_type_ids=False,
    )
    outputs = session.run(
        None,
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        },
    )

    # Second model output, first batch item. The original comment calls this
    # the CLS embedding — assumes the ONNX graph exposes a pooled vector at
    # index 1; TODO confirm against the exported model.
    # astype() returns a fresh copy, so the in-place ops below never touch
    # memory owned by ONNX Runtime.
    vector = outputs[1][0].astype(np.float32)

    # L2-normalize (important for search); epsilon guards a zero vector.
    np.divide(vector, np.linalg.norm(vector) + 1e-12, out=vector)

    # Freeze the array: it is shared through the lru_cache.
    vector.flags.writeable = False
    return vector
# =====================================================
# API models
# =====================================================
class TextRequest(BaseModel):
    # Request body for POST /query.
    #
    # text: the string to embed; length must be between 1 and MAXTEXTLENGTH
    #       characters. The constraint keywords are spelled min_length /
    #       max_length — the un-underscored spellings are not validation
    #       constraints, so the limits were previously never enforced.
    # dim:  number of leading embedding components to return (at least 32).
    text: str = Field(..., min_length=1, max_length=MAXTEXTLENGTH)
    dim: int = Field(
        DEFAULT_DIM,
        ge=32,
        description="Embedding dimension (default=256)",
    )
class EmbeddingResponse(BaseModel):
    # Response body of POST /query.
    embedding: List[float]  # embedding components, as plain floats
    dimension: int  # number of components in `embedding`
    processing_time: float  # elapsed time reported by the handler, in seconds
class HealthResponse(BaseModel):
    # Response body of GET /health.
    status: str  # fixed status label set by the handler
    memory_usage: str  # system memory usage, formatted as "<percent>%"
    memory_available_gb: float  # available system memory in GiB
    uptime: float  # seconds since the startup hook ran
# =====================================================
# Application object
# =====================================================
# Title and version are what FastAPI is given for this service.
app = FastAPI(title="Fast Arabic Embedding API (CPU Optimized)", version="3.1.0")
# =====================================================
# Endpoints
# =====================================================
# GET / — static landing payload pointing at the docs and health routes.
@app.get("/")
def root():
    landing = dict(
        message="✅ Arabic Embedding API is running",
        docs="/docs",
        health="/health",
    )
    return landing
# GET /health — reports system memory figures from psutil and the time
# elapsed since the startup hook stored app.state.start_time.
@app.get("/health", response_model=HealthResponse)
def health():
    vm = psutil.virtual_memory()
    seconds_up = time.time() - app.state.start_time
    bytes_per_gib = 1024 ** 3
    available_gib = round(vm.available / bytes_per_gib, 2)
    return HealthResponse(
        status="healthy",
        memory_usage=f"{vm.percent}%",
        memory_available_gb=available_gib,
        uptime=seconds_up,
    )
# POST /query — embed request.text and return its first `dim` components.
#
# Responds with HTTP 400 when no embedding can be produced (text_to_embedding
# returns None for empty / whitespace-only text).
@app.post("/query", response_model=EmbeddingResponse)
def query_endpoint(request: TextRequest):
    # perf_counter is monotonic, so the reported duration cannot go negative
    # or jump if the wall clock is adjusted mid-request.
    start_time = time.perf_counter()

    vector = text_to_embedding(request.text)
    if vector is None:
        raise HTTPException(400, "فشل إنشاء embedding")

    # Never ask for more components than the model produces.
    dim = min(request.dim, vector.shape[0])
    if dim < vector.shape[0]:
        # A prefix of a unit vector is shorter than unit length, so it is
        # re-normalized to keep the L2 guarantee. The division allocates a
        # new array on purpose: `vector` is shared through the embedding
        # cache and must not be modified in place.
        head = vector[:dim]
        vector = head / (np.linalg.norm(head) + 1e-12)

    return EmbeddingResponse(
        embedding=vector.tolist(),
        dimension=dim,
        processing_time=time.perf_counter() - start_time,
    )
# =====================================================
# startup / shutdown
# =====================================================
@app.on_event("startup")
def startup():
    """Record the start timestamp and warm up the embedding pipeline."""
    # /health computes uptime from this value.
    app.state.start_time = time.time()

    # Warm-up (very important): run one embedding before serving requests;
    # the result is discarded.
    _ = text_to_embedding("warm up")

    # Logged at WARNING so it is still emitted under the WARNING log level.
    logger.warning("🚀 Embedding API started")
@app.on_event("shutdown")
def shutdown():
    """Run a garbage-collection pass and log that the API has stopped."""
    _ = gc.collect()
    logger.warning("🛑 Embedding API stopped")
# =====================================================
# Run the server
# =====================================================
if __name__ == "__main__":
    import uvicorn

    # Pass the application object itself instead of an import string such as
    # "main:app": the string form only works when this file is literally
    # named main.py, and even then it imports the module a second time,
    # repeating all the module-level loading. A direct object is valid
    # because a single worker is used and reload is off.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=PORT,
        workers=1,  # important for HuggingFace Spaces
        access_log=False,
    )