LH-Tech-AI committed (verified)
Commit acdc12e · 1 Parent(s): 4de9275

Update app.py

Files changed (1):
  app.py (+113, -53)

app.py CHANGED
@@ -1,77 +1,137 @@
  from fastapi import FastAPI, Request
  from fastapi.responses import StreamingResponse
  from fastapi.middleware.cors import CORSMiddleware
- import onnxruntime as ort
- import numpy as np
- import tiktoken
- import json
- import os

  app = FastAPI()

- # IMPORTANT: allows your external frontend to access the API
  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"],  # You can add your own domain here later
      allow_methods=["*"],
      allow_headers=["*"],
  )

- # Load model & tokenizer
- tokenizer = tiktoken.get_encoding("gpt2")
  MODEL_PATH = "SmaLLMPro_350M_int8.onnx"

- # Optimized session options for CPU
- sess_options = ort.SessionOptions()
- sess_options.intra_op_num_threads = 2  # HF Spaces usually have 2 cores
- session = ort.InferenceSession(MODEL_PATH, sess_options, providers=['CPUExecutionProvider'])

- def top_k_sample(logits, k=50, temp=0.7):
      logits = logits / max(temp, 1e-6)
-     # Only look at the top-k values (saves a lot of time on sorting)
-     top_k_indices = np.argpartition(logits, -k)[-k:]
-     top_k_logits = logits[top_k_indices]

-     # Numerically stable softmax
-     exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
      probs = exp_logits / np.sum(exp_logits)

-     return int(np.random.choice(top_k_indices, p=probs))

  @app.post("/chat")
  async def chat(request: Request):
-     data = await request.json()
-     prompt = f"Instruction:\n{data['prompt']}\n\nResponse:\n"
-     tokens = tokenizer.encode(prompt)
-
-     max_len = int(data.get('maxLen', 100))
-     temp = float(data.get('temp', 0.7))
-     top_k = int(data.get('topK', 40))
-
-     async def generate():
-         nonlocal tokens
-         for _ in range(max_len):
-             # Limit the context to 1024 tokens
-             ctx = tokens[-1024:]
-             # Padding (right-aligned)
-             padded = np.zeros((1, 1024), dtype=np.int64)
-             padded[0, -len(ctx):] = ctx
-
-             # Inference
-             outputs = session.run(None, {'input': padded})
-             # We only take the logits of the last token
-             logits = outputs[0][0, -1, :50304]
-
-             next_token = top_k_sample(logits, k=top_k, temp=temp)
-
-             if next_token == 50256:  # EOS
-                 break
-
-             tokens.append(next_token)
-             yield f"data: {json.dumps({'token': tokenizer.decode([next_token])})}\n\n"
-
-     return StreamingResponse(generate(), media_type="text/event-stream")

  @app.get("/")
- def health():
-     return {"status": "SmaLLMPro API is online"}

+ import os
+ import json
+ import asyncio
+ import numpy as np
+ import onnxruntime as ort
+ import tiktoken
  from fastapi import FastAPI, Request
  from fastapi.responses import StreamingResponse
  from fastapi.middleware.cors import CORSMiddleware

+ # App initialization
  app = FastAPI()

+ # CORS for your external frontend
  app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"],
      allow_methods=["*"],
      allow_headers=["*"],
  )

+ # 1. Model & tokenizer setup
+ TOKENIZER = tiktoken.get_encoding("gpt2")
  MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
+ VOCAB_SIZE = 50304

+ # 2. ONNX Runtime optimization
+ # HF free Spaces have 2 vCPUs. We limit the threads
+ # to avoid context-switching overhead.
+ options = ort.SessionOptions()
+ options.intra_op_num_threads = 2
+ options.inter_op_num_threads = 2
+ options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+ options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

+ # Load the session
+ print(f"🚀 Loading model {MODEL_PATH} with CPU optimization...")
+ session = ort.InferenceSession(
+     MODEL_PATH,
+     sess_options=options,
+     providers=['CPUExecutionProvider']
+ )
+
+ def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
+     """Highly optimized sampling with NumPy."""
+     # 1. Repetition penalty (optional, if a history is provided)
+     if history is not None and penalty != 1.0:
+         # Penalize already-generated tokens directly in the logits
+         unique_history = np.unique(history)
+         # Only penalize valid token indices
+         valid_indices = unique_history[unique_history < len(logits)]
+         logits[valid_indices] /= penalty
+
+     # 2. Temperature scaling
      logits = logits / max(temp, 1e-6)
+
+     # 3. Top-k via partition (faster than a full sort)
+     # Finds the k largest values without sorting the rest
+     top_k_idx = np.argpartition(logits, -k)[-k:]
+     top_k_logits = logits[top_k_idx]

+     # 4. Softmax
+     shifted_logits = top_k_logits - np.max(top_k_logits)
+     exp_logits = np.exp(shifted_logits)
      probs = exp_logits / np.sum(exp_logits)

+     # 5. Sample
+     choice = np.random.choice(top_k_idx, p=probs)
+     return int(choice)

  @app.post("/chat")
  async def chat(request: Request):
+     try:
+         data = await request.json()
+         user_prompt = data.get('prompt', '')
+         max_len = int(data.get('maxLen', 100))
+         temp = float(data.get('temp', 0.7))
+         top_k = int(data.get('topK', 25))
+         repetition_penalty = float(data.get('penalty', 1.2))
+
+         # Alpaca instruction format
+         full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
+         tokens = TOKENIZER.encode(full_prompt)
+
+         async def generate():
+             nonlocal tokens
+             # Keep the history around for the repetition penalty
+             history = np.array(tokens, dtype=np.int32)
+
+             for _ in range(max_len):
+                 # 1. Context handling: always exactly 1024 tokens (right-aligned padding)
+                 ctx = tokens[-1024:]
+                 input_array = np.zeros((1, 1024), dtype=np.int64)
+                 input_array[0, -len(ctx):] = ctx
+
+                 # 2. Inference (synchronous call inside an async generator)
+                 # This is the bottleneck; this is where the CPU does its work
+                 outputs = session.run(None, {'input': input_array})
+
+                 # 3. Extract logits (last token, first VOCAB_SIZE entries)
+                 logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)
+
+                 # 4. Sampling
+                 next_token = fast_top_k_sample(
+                     logits,
+                     k=top_k,
+                     temp=temp,
+                     penalty=repetition_penalty,
+                     history=history
+                 )
+
+                 if next_token == 50256:  # EOS token
+                     break
+
+                 # 5. Update state
+                 tokens.append(next_token)
+                 history = np.append(history, next_token)
+
+                 # 6. Stream to the client
+                 yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"
+
+                 # Short pause to yield to the event loop
+                 await asyncio.sleep(0.01)
+
+         return StreamingResponse(generate(), media_type="text/event-stream")
+
+     except Exception as e:
+         print(f"Error: {e}")
+         return {"error": str(e)}

  @app.get("/")
+ async def health():
+     return {"status": "SmaLLMPro INT8 Engine Online", "threads": options.intra_op_num_threads}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
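
For reference, a minimal client sketch (not part of this commit) showing how the /chat endpoint is consumed: the server streams Server-Sent Events, one frame per token, each shaped like data: {"token": "..."}. The base URL and the requests dependency are assumptions for illustration only.

# Client sketch (illustration only; BASE_URL and the requests package are
# assumptions, not part of this commit).
import json
import requests

BASE_URL = "http://localhost:7860"  # hypothetical; replace with your Space URL

payload = {
    "prompt": "Explain ONNX quantization in one sentence.",
    "maxLen": 80,
    "temp": 0.7,
    "topK": 25,
    "penalty": 1.2,
}

with requests.post(f"{BASE_URL}/chat", json=payload, stream=True) as resp:
    # Each streamed line looks like: data: {"token": " hello"}
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event["token"], end="", flush=True)
print()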