Mazenbs committed on
Commit
913fabc
·
verified ·
1 Parent(s): 5653651

Create kami.py

Browse files
Files changed (1) hide show
  1. kami.py +120 -0
kami.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import os
3
+ import re
4
+ import time
5
+ import uuid
6
+ from functools import lru_cache
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel, Field
13
+ from transformers import AutoTokenizer
14
+
15
+ # ------------------------------------------------------------------
16
+ # 1. FastAPI App
17
+ # ------------------------------------------------------------------
18
+ app = FastAPI(
19
+ title="Arabic-ONNX-Embedding",
20
+ version="1.0.0",
21
+ docs_url=None, # disable docs to save memory & latency
22
+ redoc_url=None,
23
+ )
24
+
25
+ # ------------------------------------------------------------------
26
+ # 2. ONNX Runtime โ€“ CPU-optimised session
27
+ # ------------------------------------------------------------------
28
+ MODEL_PATH = "lib/intfloat_multilingual-e5-small_merged_int8.onnx"
29
+
30
+ sess_opts = ort.SessionOptions()
31
+ sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
32
+ sess_opts.intra_op_num_threads = os.cpu_count() or 1
33
+ sess_opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
34
+ sess_opts.add_session_config_entry("session.set_denormal_as_zero", "1")
35
+
36
+ providers = ["CPUExecutionProvider"]
37
+ session = ort.InferenceSession(
38
+ MODEL_PATH, providers=providers, sess_options=sess_opts
39
+ )
40
+
41
+ # ------------------------------------------------------------------
42
+ # 3. Tokenizer โ€“ load once
43
+ # ------------------------------------------------------------------
44
+ tokenizer = AutoTokenizer.from_pretrained(
45
+ "./lib", local_files_only=True, use_fast=True
46
+ )
47
+
48
+ # ------------------------------------------------------------------
49
+ # 4. Normalisation โ€“ fast & cached
50
+ # ------------------------------------------------------------------
51
+ @lru_cache(maxsize=20_000)
52
+ def _normalize(text: str) -> str:
53
+ text = re.sub(r"[ูŽู‹ููŒููู’ู€]", "", text)
54
+ text = re.sub(r"[ุฅุฃุข]", "ุง", text)
55
+ text = re.sub(r"ู‰", "ูŠ", text)
56
+ text = re.sub(r"ุค", "ูˆ", text)
57
+ text = re.sub(r"ุฆ", "ูŠ", text)
58
+ text = re.sub(r"ุฉ\b", "ู‡", text)
59
+ text = re.sub(r"[^\w\s]", " ", text)
60
+ text = re.sub(r"\s+", " ", text)
61
+ return text.strip()
62
+
63
# ------------------------------------------------------------------
# 5. Core embedding - no async, no locks, pure CPU
# ------------------------------------------------------------------
def text_to_embedding(text: str) -> List[float]:
    """Embed a single query string and return a unit-norm float32 vector.

    Raises:
        ValueError: if *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        raise ValueError("Empty text")

    # e5-family models expect a task prefix on query inputs.
    text = "query: " + _normalize(text.strip())

    encoded = tokenizer(
        text,
        return_tensors="np",
        truncation=True,
        padding=False,  # single query, so no padding needed
        max_length=128,
        return_attention_mask=True,
        return_token_type_ids=False,
    )

    # Output index 1 is taken as the pooled sentence embedding.
    # NOTE(review): dimensionality not verifiable here — e5-small is
    # usually 384-dim, not 768 as an earlier comment claimed; confirm
    # against the ONNX graph outputs.
    embedding = session.run(None, dict(encoded))[1][0]

    # L2-normalise so downstream cosine similarity is a plain dot product.
    magnitude = np.linalg.norm(embedding)
    if magnitude > 0:
        embedding /= magnitude
    return embedding.astype(np.float32).tolist()
87
+
88
# ------------------------------------------------------------------
# 6. Warm-up on startup
# ------------------------------------------------------------------
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favour of lifespan handlers; it still works, so behavior is unchanged.
@app.on_event("startup")
def _warm() -> None:
    """Run one throwaway inference so the first real request is fast."""
    text_to_embedding("ู…ุฑุญุจุง")
94
+
95
# ------------------------------------------------------------------
# 7. Pydantic models
# ------------------------------------------------------------------
class QueryIn(BaseModel):
    """Request body for POST /query."""

    # Query text; pydantic enforces 1..256 characters before the
    # endpoint runs.
    q: str = Field(..., min_length=1, max_length=256)
100
+
101
class EmbeddingOut(BaseModel):
    """Response body: the embedding vector for the query text."""

    # Float vector produced by text_to_embedding (L2-normalised there
    # when the norm is non-zero).
    embedding: List[float]
103
+
104
# ------------------------------------------------------------------
# 8. Endpoint - minimal, sync, no extra middleware
# ------------------------------------------------------------------
@app.post("/query", response_model=EmbeddingOut)
def query_endpoint(item: QueryIn):
    """Embed the posted query text and return its vector.

    Returns:
        EmbeddingOut wrapping the embedding list.
    Raises:
        HTTPException(400): when the text is invalid (e.g. whitespace-only,
        which passes pydantic's min_length but fails text_to_embedding).
    """
    try:
        emb = text_to_embedding(item.q)
    except ValueError as err:
        # Only input-validation failures map to 400. The original broad
        # `except Exception` mislabelled server-side bugs (model/tokenizer
        # failures) as client errors; those now surface as 500s.
        raise HTTPException(status_code=400, detail="Bad input") from err
    return EmbeddingOut(embedding=emb)
114
+
115
# ------------------------------------------------------------------
# 9. Health-check (optional but lightweight)
# ------------------------------------------------------------------
@app.get("/health")
def health() -> dict:
    """Liveness probe; reports OK without touching the model."""
    return {"status": "ok"}