from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import onnxruntime as ort
import numpy as np
import tiktoken
import json
import os
app = FastAPI()
# IMPORTANT: allows your external frontend to access this API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # you can restrict this to your own domain later
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model & tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
# Optimized session options for CPU
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 2  # HF Spaces usually have 2 cores
session = ort.InferenceSession(MODEL_PATH, sess_options, providers=['CPUExecutionProvider'])
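# Note (sketch, assumption): the input name 'input' and the fixed (1, 1024) shape used
# below are assumptions about how the ONNX graph was exported; they can be verified at
# startup with, e.g.:
#   print([(i.name, i.shape) for i in session.get_inputs()])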
def top_k_sample(logits, k=50, temp=0.7):
    logits = logits / max(temp, 1e-6)
    # Only consider the top-k values (avoids sorting the full vocabulary)
    top_k_indices = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_indices]
    # Numerically stable softmax
    exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
    probs = exp_logits / np.sum(exp_logits)
    return int(np.random.choice(top_k_indices, p=probs))
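# Quick sanity check for the sampler (sketch, not part of the request path):
#   fake_logits = np.random.randn(50304).astype(np.float32)
#   tok = top_k_sample(fake_logits, k=40, temp=0.7)
#   assert 0 <= tok < 50304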
@app.post("/chat")
async def chat(request: Request):
data = await request.json()
prompt = f"Instruction:\n{data['prompt']}\n\nResponse:\n"
tokens = tokenizer.encode(prompt)
max_len = int(data.get('maxLen', 100))
temp = float(data.get('temp', 0.7))
top_k = int(data.get('topK', 40))
async def generate():
nonlocal tokens
for _ in range(max_len):
# Kontext auf 1024 beschränken
ctx = tokens[-1024:]
# Padding (Rechtsbündig)
padded = np.zeros((1, 1024), dtype=np.int64)
padded[0, -len(ctx):] = ctx
# Inferenz
outputs = session.run(None, {'input': padded})
# Wir nehmen nur die Logits des letzten Tokens
logits = outputs[0][0, -1, :50304]
next_token = top_k_sample(logits, k=top_k, temp=temp)
if next_token == 50256: # EOS
break
tokens.append(next_token)
yield f"data: {json.dumps({'token': tokenizer.decode([next_token])})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
@app.get("/")
def health():
return {"status": "SmaLLMPro API is online"}