File size: 3,649 Bytes
29bb2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66be360
29bb2d7
66be360
29bb2d7
 
 
 
 
 
 
 
 
 
66be360
 
29bb2d7
 
 
 
 
 
 
 
 
 
66be360
29bb2d7
 
 
 
 
 
 
 
 
66be360
29bb2d7
 
66be360
29bb2d7
 
 
 
66be360
29bb2d7
 
 
 
 
 
 
 
 
 
 
66be360
29bb2d7
 
 
66be360
 
29bb2d7
 
 
 
66be360
29bb2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
66be360
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# main.py - SLM Inference Server
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import subprocess
import tiktoken
import os
import time

# FastAPI application instance; route handlers below attach to it.
app = FastAPI()

# Allow browser clients from any origin to call the API.
# NOTE(review): allow_origins=["*"] is wide open — fine for a local demo,
# but should be restricted before exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

class GenerateRequest(BaseModel):
    """Request body for POST /generate.

    Defaults mirror the sampling arguments passed straight through to the
    native inference binary.
    """

    prompt: str           # text to encode with the gpt2 tokenizer
    max_tokens: int = 100    # maximum number of tokens to generate
    temperature: float = 0.8  # sampling temperature forwarded to the binary
    top_k: int = 40          # top-k sampling cutoff forwarded to the binary

# Load the GPT-2 BPE tokenizer once at import time. On failure the server
# still starts, but /generate will reject requests (enc is None).
try:
    enc = tiktoken.get_encoding("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Warning: tiktoken not found. Error: {e}")
    enc = None

@app.get("/")
async def root():
    """Serve the single-page UI (index.html) that lives next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    page = os.path.join(here, "index.html")
    return FileResponse(page)

@app.get("/health")
async def health_check():
    """Report whether the inference binary and model weights are on disk."""
    here = os.path.dirname(os.path.abspath(__file__))
    binary = os.path.join(here, "inference")
    weights = os.path.join(here, "model.bin")
    return {
        "status": "ok",
        "inference_exe_found": os.path.exists(binary),
        "model_bin_found":     os.path.exists(weights),
        "working_directory":   here,
    }

@app.post("/generate")
async def generate_text(req: GenerateRequest):
    """Encode the prompt, run the native inference binary, and return the
    decoded completion plus timing statistics.

    Returns a dict with: prompt, generated_text, tokens_in, tokens_out,
    latency_ms, tokens_per_sec.

    Raises:
        HTTPException(500): tokenizer not loaded, binary/model missing,
            subprocess launch failure, binary failure with no output, or
            a decoding error.
    """
    if enc is None:
        raise HTTPException(status_code=500, detail="Tokenizer not loaded.")

    input_tokens = enc.encode(req.prompt)
    # The binary expects the prompt as a single comma-separated id list argv.
    token_str = ",".join(map(str, input_tokens))

    current_dir, exe_path, model_path = _resolve_paths()

    if not os.path.exists(exe_path):
        raise HTTPException(status_code=500, detail=f"inference binary not found: {exe_path}")

    if not os.path.exists(model_path):
        raise HTTPException(status_code=500, detail=f"model.bin not found: {model_path}")

    try:
        start_time = time.perf_counter()
        # NOTE(review): no timeout= here — a hung binary blocks this worker
        # indefinitely. Consider subprocess.run(..., timeout=N) plus a
        # TimeoutExpired handler if the binary can stall.
        process = subprocess.run(
            [exe_path, token_str, str(req.max_tokens), str(req.temperature), str(req.top_k)],
            capture_output=True,
            text=True,
            cwd=current_dir
        )
        elapsed_ms = (time.perf_counter() - start_time) * 1000
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}")

    # A nonzero exit is fatal only when the binary also produced no stdout;
    # it may exit nonzero after emitting a usable token stream.
    if process.returncode != 0 and not process.stdout.strip():
        stdout_msg = process.stdout.strip() if process.stdout else ""
        stderr_msg = process.stderr.strip() if process.stderr else ""
        raise HTTPException(status_code=500, detail=f"C++ Error | stdout: '{stdout_msg}' | stderr: '{stderr_msg}'")

    try:
        generated_ids = _parse_token_ids(process.stdout)
        generated_text = enc.decode(generated_ids) if generated_ids else ""
        tokens_out     = len(generated_ids)
        tokens_per_sec = round(tokens_out / (elapsed_ms / 1000), 2) if elapsed_ms > 0 else 0

        return {
            "prompt":         req.prompt,
            "generated_text": generated_text,
            "tokens_in":      len(input_tokens),
            "tokens_out":     tokens_out,
            "latency_ms":     round(elapsed_ms, 2),
            "tokens_per_sec": tokens_per_sec
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Decoding error: {str(e)}")


def _resolve_paths():
    """Return (module_dir, inference_binary_path, model_file_path).

    Shared path logic for the endpoints; everything is resolved relative
    to this module's directory, matching health_check.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    return (
        current_dir,
        os.path.join(current_dir, "inference"),
        os.path.join(current_dir, "model.bin"),
    )


def _parse_token_ids(stdout_text):
    """Parse whitespace-separated token ids from the binary's stdout.

    Non-integer tokens are skipped silently (the binary may interleave
    diagnostics with ids). An empty string yields an empty list.
    """
    ids = []
    for tok in stdout_text.strip().split():
        try:
            ids.append(int(tok))
        except ValueError:
            pass
    return ids