# main.py - SLM Inference Server
"""FastAPI front-end for a native SLM inference binary.

Endpoints:
  GET  /         - serves ``index.html`` from the app directory.
  GET  /health   - reports whether the inference binary and model are present.
  POST /generate - tokenizes a prompt with the GPT-2 BPE tokenizer, runs the
                   ``inference`` executable, and decodes the token ids it
                   prints back into text.
"""

import os
import subprocess
import time

import tiktoken
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

app = FastAPI()

# Wide-open CORS: fine for a local demo server, not for production exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Resolve all paths once, relative to this file, so the server behaves the
# same regardless of the process working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
EXE_PATH = os.path.join(BASE_DIR, "inference")
MODEL_PATH = os.path.join(BASE_DIR, "model.bin")

# Hard cap on one generation run; prevents a wedged child process from
# hanging the HTTP request forever.
SUBPROCESS_TIMEOUT_S = 120


class GenerateRequest(BaseModel):
    """Sampling parameters for a single generation request."""

    prompt: str
    max_tokens: int = 100
    temperature: float = 0.8
    top_k: int = 40


# Load the GPT-2 BPE tokenizer at startup. On failure we keep the server up
# and let /generate fail fast with a 500 instead of crashing at import time.
try:
    enc = tiktoken.get_encoding("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Warning: tiktoken not found. Error: {e}")
    enc = None


@app.get("/")
async def root():
    """Serve the bundled single-page UI."""
    return FileResponse(os.path.join(BASE_DIR, "index.html"))


@app.get("/health")
async def health_check():
    """Report whether the inference executable and model weights exist on disk."""
    return {
        "status": "ok",
        "inference_exe_found": os.path.exists(EXE_PATH),
        "model_bin_found": os.path.exists(MODEL_PATH),
        "working_directory": BASE_DIR,
    }


@app.post("/generate")
def generate_text(req: GenerateRequest):
    """Tokenize ``req.prompt``, run the native inference binary, decode output.

    Declared as a plain ``def`` (not ``async``) on purpose: FastAPI executes
    sync handlers in a worker thread, so the blocking ``subprocess.run`` call
    no longer stalls the event loop for concurrent requests.

    Raises:
        HTTPException(500): tokenizer not loaded, binary/model missing,
            child process failed with no usable output, or decoding failed.
        HTTPException(504): the child process exceeded ``SUBPROCESS_TIMEOUT_S``.
    """
    if enc is None:
        raise HTTPException(status_code=500, detail="Tokenizer not loaded.")

    input_tokens = enc.encode(req.prompt)
    # The binary expects the prompt as a single comma-separated id list argv.
    token_str = ",".join(map(str, input_tokens))

    if not os.path.exists(EXE_PATH):
        raise HTTPException(status_code=500, detail=f"inference binary not found: {EXE_PATH}")
    if not os.path.exists(MODEL_PATH):
        raise HTTPException(status_code=500, detail=f"model.bin not found: {MODEL_PATH}")

    try:
        start_time = time.perf_counter()
        process = subprocess.run(
            [EXE_PATH, token_str, str(req.max_tokens), str(req.temperature), str(req.top_k)],
            capture_output=True,
            text=True,
            cwd=BASE_DIR,
            timeout=SUBPROCESS_TIMEOUT_S,
        )
        elapsed_ms = (time.perf_counter() - start_time) * 1000
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=504, detail="Inference timed out.")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}")

    # A nonzero exit code is tolerated as long as the binary still printed
    # something to stdout (best-effort: treat partial output as usable).
    if process.returncode != 0 and not process.stdout.strip():
        stdout_msg = process.stdout.strip() if process.stdout else ""
        stderr_msg = process.stderr.strip() if process.stderr else ""
        raise HTTPException(
            status_code=500,
            detail=f"C++ Error | stdout: '{stdout_msg}' | stderr: '{stderr_msg}'",
        )

    try:
        # stdout is expected to be whitespace-separated token ids; any
        # non-numeric noise interleaved in the stream is silently skipped.
        output_str = process.stdout.strip()
        generated_ids = []
        if output_str:
            for x in output_str.split():
                try:
                    generated_ids.append(int(x))
                except ValueError:
                    pass
        generated_text = enc.decode(generated_ids) if generated_ids else ""
        tokens_out = len(generated_ids)
        tokens_per_sec = round(tokens_out / (elapsed_ms / 1000), 2) if elapsed_ms > 0 else 0
        return {
            "prompt": req.prompt,
            "generated_text": generated_text,
            "tokens_in": len(input_tokens),
            "tokens_out": tokens_out,
            "latency_ms": round(elapsed_ms, 2),
            "tokens_per_sec": tokens_per_sec,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Decoding error: {str(e)}")