import asyncio
import json
import os

import numpy as np
import onnxruntime as ort
import tiktoken
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
|
|
| app = FastAPI() |
|
|
| |
| app.add_middleware( |
| CORSMiddleware, |
| allow_origins=["*"], |
| allow_methods=["*"], |
| allow_headers=["*"], |
| ) |
|
|
| |
| tokenizer = tiktoken.get_encoding("gpt2") |
| MODEL_PATH = "SmaLLMPro_350M_int8.onnx" |
|
|
| |
| sess_options = ort.SessionOptions() |
| sess_options.intra_op_num_threads = 2 |
| session = ort.InferenceSession(MODEL_PATH, sess_options, providers=['CPUExecutionProvider']) |
|
|
| def top_k_sample(logits, k=50, temp=0.7): |
| logits = logits / max(temp, 1e-6) |
| |
| top_k_indices = np.argpartition(logits, -k)[-k:] |
| top_k_logits = logits[top_k_indices] |
| |
| |
| exp_logits = np.exp(top_k_logits - np.max(top_k_logits)) |
| probs = exp_logits / np.sum(exp_logits) |
| |
| return int(np.random.choice(top_k_indices, p=probs)) |
|
|
| @app.post("/chat") |
| async def chat(request: Request): |
| data = await request.json() |
| prompt = f"Instruction:\n{data['prompt']}\n\nResponse:\n" |
| tokens = tokenizer.encode(prompt) |
| |
| max_len = int(data.get('maxLen', 100)) |
| temp = float(data.get('temp', 0.7)) |
| top_k = int(data.get('topK', 40)) |
|
|
| async def generate(): |
| nonlocal tokens |
| for _ in range(max_len): |
| |
| ctx = tokens[-1024:] |
| |
| padded = np.zeros((1, 1024), dtype=np.int64) |
| padded[0, -len(ctx):] = ctx |
| |
| |
| outputs = session.run(None, {'input': padded}) |
| |
| logits = outputs[0][0, -1, :50304] |
| |
| next_token = top_k_sample(logits, k=top_k, temp=temp) |
| |
| if next_token == 50256: |
| break |
| |
| tokens.append(next_token) |
| yield f"data: {json.dumps({'token': tokenizer.decode([next_token])})}\n\n" |
|
|
| return StreamingResponse(generate(), media_type="text/event-stream") |
|
|
| @app.get("/") |
| def health(): |
| return {"status": "SmaLLMPro API is online"} |