Update app.py
Browse files
app.py
CHANGED
|
@@ -1,77 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI, Request
|
| 2 |
from fastapi.responses import StreamingResponse
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
-
import onnxruntime as ort
|
| 5 |
-
import numpy as np
|
| 6 |
-
import tiktoken
|
| 7 |
-
import json
|
| 8 |
-
import os
|
| 9 |
|
|
|
|
| 10 |
app = FastAPI()
|
| 11 |
|
| 12 |
-
#
|
| 13 |
app.add_middleware(
|
| 14 |
CORSMiddleware,
|
| 15 |
-
allow_origins=["*"],
|
| 16 |
allow_methods=["*"],
|
| 17 |
allow_headers=["*"],
|
| 18 |
)
|
| 19 |
|
| 20 |
-
# Modell & Tokenizer
|
| 21 |
-
|
| 22 |
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
|
|
|
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
logits = logits / max(temp, 1e-6)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
|
|
|
|
| 37 |
probs = exp_logits / np.sum(exp_logits)
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
@app.post("/chat")
|
| 42 |
async def chat(request: Request):
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
@app.get("/")
|
| 76 |
-
def health():
|
| 77 |
-
return {"status": "SmaLLMPro
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import numpy as np
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
import tiktoken
|
| 7 |
from fastapi import FastAPI, Request
|
| 8 |
from fastapi.responses import StreamingResponse
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
# Application initialisation.
app = FastAPI()

# CORS for the external frontend.
# NOTE(review): allow_origins=["*"] permits any origin — acceptable for a
# public demo, but tighten to the frontend's domain if credentials are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 21 |
|
| 22 |
+
# 1. Model & tokenizer setup.
TOKENIZER = tiktoken.get_encoding("gpt2")
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
# Logits are sliced to this many entries before sampling — presumably the
# model's padded GPT-2 vocabulary size; confirm against the export config.
VOCAB_SIZE = 50304

# 2. ONNX Runtime tuning.
# HF free Spaces provide 2 vCPUs. The thread pools are limited to match,
# to avoid context-switching overhead.
options = ort.SessionOptions()
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

# Load the session once at import time so all requests share one instance.
print(f"🚀 Lade Modell {MODEL_PATH} mit CPU-Optimierung...")
session = ort.InferenceSession(
    MODEL_PATH,
    sess_options=options,
    providers=['CPUExecutionProvider']
)
|
| 43 |
+
|
| 44 |
+
def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
    """Sample one token id from ``logits`` with NumPy only.

    Applies (in order) a CTRL-style repetition penalty, temperature
    scaling, top-k filtering via ``argpartition``, and a softmax draw.

    Args:
        logits: 1-D float array of raw logits over the vocabulary.
        k: number of highest-logit candidates to sample among.
        temp: softmax temperature; clamped to >= 1e-6 to avoid div-by-zero.
        penalty: repetition penalty; values > 1.0 discourage repeats.
        history: optional array of previously generated token ids.

    Returns:
        int: the sampled token id.
    """
    # Work on a copy so the caller's array is never mutated in place.
    logits = np.asarray(logits, dtype=np.float32).copy()

    # 1. Repetition penalty (CTRL-style). BUGFIX: dividing a NEGATIVE
    # logit by the penalty makes it LESS negative, i.e. it would *boost*
    # already-seen tokens. Negative logits must be multiplied instead.
    if history is not None and penalty != 1.0:
        seen = np.unique(history)
        seen = seen[seen < len(logits)]  # guard against out-of-vocab ids
        vals = logits[seen]
        logits[seen] = np.where(vals < 0, vals * penalty, vals / penalty)

    # 2. Temperature scaling.
    logits = logits / max(temp, 1e-6)

    # 3. Top-k via partition — finds the k largest values without fully
    # sorting the vocabulary (O(n) instead of O(n log n)).
    top_k_idx = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_idx]

    # 4. Numerically stable softmax over the k candidates.
    exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
    probs = exp_logits / np.sum(exp_logits)

    # 5. Sample one candidate.
    return int(np.random.choice(top_k_idx, p=probs))
|
| 70 |
|
| 71 |
@app.post("/chat")
async def chat(request: Request):
    """Stream generated tokens to the client as server-sent events.

    Expects a JSON body with ``prompt`` and optional ``maxLen``, ``temp``,
    ``topK`` and ``penalty``. Returns a ``text/event-stream`` response in
    which each event carries one decoded token, or a plain JSON error
    object if the request cannot be parsed.
    """
    try:
        data = await request.json()
        user_prompt = data.get('prompt', '')
        max_len = int(data.get('maxLen', 100))
        temp = float(data.get('temp', 0.7))
        top_k = int(data.get('topK', 25))
        repetition_penalty = float(data.get('penalty', 1.2))

        # Alpaca-style instruction format.
        full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
        tokens = TOKENIZER.encode(full_prompt)

        async def generate():
            nonlocal tokens
            # Token history drives the repetition penalty.
            history = np.array(tokens, dtype=np.int32)

            for _ in range(max_len):
                # 1. Context window: always exactly 1024 ids, left-padded
                # with zeros. NOTE(review): presumably the model tolerates
                # zero left-padding — confirm the ONNX export has no
                # separate attention-mask input.
                ctx = tokens[-1024:]
                input_array = np.zeros((1, 1024), dtype=np.int64)
                input_array[0, -len(ctx):] = ctx

                # 2. Inference. BUGFIX: session.run is a blocking,
                # CPU-bound call; running it directly inside this async
                # generator froze the event loop for the whole step (the
                # old sleep(0.01) only papered over that). Off-load it to
                # a worker thread so other requests stay responsive.
                outputs = await asyncio.to_thread(
                    session.run, None, {'input': input_array}
                )

                # 3. Logits of the last position, restricted to the vocab.
                logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

                # 4. Sampling.
                next_token = fast_top_k_sample(
                    logits,
                    k=top_k,
                    temp=temp,
                    penalty=repetition_penalty,
                    history=history
                )

                if next_token == 50256:  # GPT-2 end-of-text token
                    break

                # 5. Update generation state.
                tokens.append(next_token)
                history = np.append(history, next_token)

                # 6. Stream the decoded token as one SSE event.
                yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    except Exception as e:
        # Best-effort error report. NOTE: exceptions raised inside
        # generate() after streaming has begun are NOT caught here.
        print(f"Error: {e}")
        return {"error": str(e)}
|
| 130 |
|
| 131 |
@app.get("/")
async def health():
    """Liveness probe: report engine status and the configured CPU threads."""
    thread_count = options.intra_op_num_threads
    return {
        "status": "SmaLLMPro INT8 Engine Online",
        "threads": thread_count,
    }
|
| 134 |
+
|
| 135 |
+
# Local/dev entry point. On HF Spaces the server is typically launched by
# the platform, but running this file directly also works (port 7860 is
# the Spaces convention).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|