# emmd/kami.py — Hugging Face Space file (commit dab2837, "Update kami.py" by Mazenbs)
# main.py
import os
import time
import re
from functools import lru_cache
from typing import List
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ---------- 1. ONNX: maximum CPU speed-up ----------
MODEL_PATH = os.environ["MODEL_PATH"]  # required env var; fails fast at import if unset
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.intra_op_num_threads = 4 # suits most Spaces machines
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
# Flush denormal floats to zero to avoid slow denormal arithmetic on CPU.
opts.add_session_config_entry("session.set_denormal_as_zero", "1")
sess = ort.InferenceSession(MODEL_PATH, opts, providers=["CPUExecutionProvider"])
# ---------- 2. Tokenizer ----------
# Local tokenizer files only (no hub download); use_fast selects the Rust tokenizer.
tok = AutoTokenizer.from_pretrained("./lib", local_files_only=True, use_fast=True)
# ---------- 3. Normalisation (cache) ----------
@lru_cache(maxsize=20_000)
def _norm(text: str) -> str:
text = re.sub(r"[ًٌٍَُِّْـ]", "", text)
text = re.sub(r"[إأآ]", "ا", text)
text = re.sub(r"ى", "ي", text)
text = re.sub(r"ؤ", "و", text)
text = re.sub(r"ئ", "ي", text)
text = re.sub(r"ة\b", "ه", text)
text = re.sub(r"[^\w\s]", " ", text)
text = re.sub(r"\s+", " ", text)
return text.strip()
# ---------- 4. Core embedding (cache + bytes) ----------
@lru_cache(maxsize=5_000)
def _embed(text: str) -> bytes:
    """Return the L2-normalised embedding of *text* as raw float32 bytes.

    Bytes (rather than an ndarray) keep cached entries compact and
    hashable, and avoid re-allocating arrays for repeated queries.
    """
    prompt = "query: " + _norm(text)
    inputs = tok(prompt, return_tensors="np", truncation=True, max_length=128, padding=False)
    # NOTE(review): output index 1 is assumed to be the pooled sentence
    # embedding; we take the first (only) batch element — presumably
    # shape (768,). Confirm against the exported ONNX model's outputs.
    embedding = sess.run(None, dict(inputs))[1][0]
    magnitude = np.linalg.norm(embedding)
    if magnitude > 0:
        embedding = embedding / magnitude
    return embedding.astype(np.float32).tobytes()
# ---------- 5. FastAPI (no extras) ----------
app = FastAPI(title="AE", version="1")
# Request body schema: `q` is the raw query text to embed.
class In(BaseModel):
 q: str
@app.post("/query")
def ep(item: In) -> List[float]:
    """Embed the (stripped) query string and return the vector as a JSON list."""
    raw = _embed(item.q.strip())
    vec = np.frombuffer(raw, dtype=np.float32)
    return vec.tolist()
# ---------- 6. Warm-up ----------
# NOTE(review): FastAPI's @app.on_event is deprecated in favour of
# lifespan handlers; kept unchanged here.
@app.on_event("startup")
def _w():
    # Prime the tokenizer, ONNX session and lru_cache with one dummy query.
    _embed("مرحبا")
# ---------- 7. Run (workers=1 is mandatory on Spaces) ----------
if __name__ == "__main__":
    import uvicorn
    # Single worker so the in-process lru_cache stays warm and shared.
    # NOTE(review): "main:app" assumes this file is saved as main.py —
    # the repo header says kami.py; confirm the deployed filename.
    uvicorn.run("main:app", host="0.0.0.0", port=int(os.getenv("PORT", 7860)), workers=1, access_log=False)