Mazenbs committed
Commit ff7efde · verified · 1 Parent(s): 913fabc

Create cop.py

Files changed (1)
  1. cop.py +142 -0
cop.py ADDED
@@ -0,0 +1,142 @@
# cop.py
# FastAPI app for ultra-low-latency Arabic query embeddings using ONNX (INT8) on CPU.
# Single-file, production-ready for Hugging Face Spaces (CPU, single worker).

import re
import time
import multiprocessing
from functools import lru_cache

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, Query, Response
from transformers import AutoTokenizer

# ==============================
# Config
# ==============================
MODEL_PATH = "lib/intfloat_multilingual-e5-small_merged_int8.onnx"
TOKENIZER_PATH = "lib"  # directory containing the tokenizer files
MAX_LENGTH = 64  # tuned for short queries (≤ ~15 words)

# ==============================
# ONNX Runtime session (maximum CPU acceleration)
# ==============================
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.enable_cpu_mem_arena = True
session_options.intra_op_num_threads = multiprocessing.cpu_count()  # use all available cores
session_options.inter_op_num_threads = 1
# Optional: write the optimized graph to disk once; harmless if it can't be written
session_options.optimized_model_filepath = "optimized_model.onnx"

session = ort.InferenceSession(
    MODEL_PATH,
    providers=["CPUExecutionProvider"],
    sess_options=session_options,
)

# ==============================
# Tokenizer: load once
# ==============================
tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_PATH,
    local_files_only=True,
    use_fast=True,
)

# ==============================
# Arabic normalization (cached)
# ==============================
@lru_cache(maxsize=4096)
def normalize_arabic(text: str) -> str:
    # Remove diacritics and tatweel
    text = re.sub(r'[ًٌٍَُِّْـ]', '', text)
    # Normalize hamza/aleph variants
    text = re.sub(r'[إأآ]', 'ا', text)
    text = re.sub(r'ى', 'ي', text)
    text = re.sub(r'ؤ', 'و', text)
    text = re.sub(r'ئ', 'ي', text)
    # Ta marbuta at word end -> ha (common retrieval normalization)
    text = re.sub(r'ة\b', 'ه', text)
    # Strip non-word characters, collapse whitespace
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

# ==============================
# Embedding function (cached, L2-normalized)
# ==============================
@lru_cache(maxsize=4096)
def embed_query_cached(query: str, do_normalize: bool) -> np.ndarray:
    if do_normalize:
        query = normalize_arabic(query)

    # Fixed-length tokenization for stable shapes and faster CPU execution;
    # E5 models expect the "query: " prefix on search queries.
    inputs = tokenizer(
        "query: " + query,
        return_tensors="np",
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_attention_mask=True,
        return_token_type_ids=False,
    )

    # ONNX inference (INT8 model)
    ort_outs = session.run(None, dict(inputs))

    # E5-style pooled embedding (second output); adjust if your model differs
    vector = ort_outs[1][0].astype(np.float32)

    # L2 normalization
    norm = np.linalg.norm(vector)
    if norm > 0.0:
        vector /= norm

    return vector

def query_to_embedding(query: str, normalize_text: bool = True) -> np.ndarray:
    # Route through the cached function to minimize latency on repeated queries
    return embed_query_cached(query.strip(), normalize_text)

# ==============================
# FastAPI app
# ==============================
app = FastAPI()

# Warm-up on startup: primes the tokenizer, ORT graph, caches, and memory arenas
@app.on_event("startup")
def warmup():
    try:
        _ = query_to_embedding("مرحبا بالعالم", normalize_text=True)
    except Exception:
        # Avoid heavy logging; fail silently to keep startup lightweight
        pass

# Ultra-low-latency GET endpoint (no extra middleware/gzip/logging)
@app.get("/query")
def query_endpoint(
    q: str = Query(..., min_length=1),
    normalize: bool = Query(True),
):
    # Minimal validation and fast path
    s = q.strip()
    if not s:
        return Response(status_code=400)

    start = time.perf_counter()
    vec = query_to_embedding(s, normalize_text=normalize)
    latency_ms = (time.perf_counter() - start) * 1000.0

    # Return only the essentials (embedding as a list); omit heavy metadata
    return {
        "embedding": vec.tolist(),
        "length": len(vec),
        "normalized": normalize,  # echo the text-normalization flag actually applied
        "latency_ms": round(latency_ms, 3),
    }

# Optional root for quick health checks without noise
@app.get("/")
def root():
    return {"status": "ok", "model": "onnx-int8", "max_length": MAX_LENGTH}