File size: 4,588 Bytes
2fd3ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
distil-home-assistant-functiongemma β€” OpenAI/OpenRouter-compatible inference server

Endpoints:
    GET  /                       -> health check / API info
    GET  /v1/models              -> list available models
    POST /v1/chat/completions    -> generate text (OpenAI format)
"""

import os
import time
import uuid

import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

# ─── Config ───────────────────────────────────────────────────────────────────

# Hugging Face repo to load, and the short model id reported by the API.
MODEL_ID = "LisaMegaWatts/distil-home-assistant-functiongemma"
MODEL_NAME = "functiongemma-270m"
# Prefer GPU when available; half precision only on CUDA, full on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# ─── Load model at startup ────────────────────────────────────────────────────
# NOTE: loading happens at import time, so the process blocks here until the
# tokenizer and model weights are downloaded/loaded.

print(f"Loading {MODEL_ID} on {DEVICE} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=DTYPE, trust_remote_code=True
).to(DEVICE)
model.eval()  # inference only — disables dropout etc.

# Some tokenizers ship without a pad token; generation needs one for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

MODEL_CREATED_AT = int(time.time())  # reported as "created" in /v1/models
print(f"Model ready on {DEVICE}")

# ─── FastAPI app ──────────────────────────────────────────────────────────────

app = FastAPI(title=f"{MODEL_NAME} API", version="1.0.0")
# Fully permissive CORS so browser clients on any origin can call the API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


@app.get("/")
def root():
    """Health-check endpoint: service metadata and the available routes."""
    return dict(
        name=MODEL_NAME,
        version="1.0.0",
        description="FunctionGemma 270m distilled for home assistant function calling",
        model=MODEL_ID,
        endpoints=["/v1/models", "/v1/chat/completions"],
        compatible_with=["OpenAI API", "OpenRouter"],
    )


@app.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing (this server serves a single model)."""
    model_entry = {
        "id": MODEL_NAME,
        "object": "model",
        "created": MODEL_CREATED_AT,
        "owned_by": "LisaMegaWatts",
    }
    return {"object": "list", "data": [model_entry]}


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint.

    Expects a JSON body with "messages" (list of {"role", "content"} dicts)
    plus optional sampling parameters ("temperature", "max_tokens", "top_p",
    "n"). Returns a chat.completion object in the OpenAI response format,
    or a 400 error payload for malformed requests.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": {"message": "Invalid JSON", "type": "invalid_request_error", "code": "invalid_json"}})

    messages = body.get("messages", [])

    # Parse and clamp sampling parameters. Malformed values (e.g. a string
    # temperature) are a client error — answer 400 instead of crashing 500.
    try:
        temperature = max(0.01, min(float(body.get("temperature", 0.7)), 2.0))
        max_tokens = max(1, min(int(body.get("max_tokens", 512)), 2048))
        # top_p must lie in (0, 1] for nucleus sampling; out-of-range values
        # would make model.generate raise.
        top_p = max(0.01, min(float(body.get("top_p", 0.9)), 1.0))
        # Clamp n so a single request cannot monopolize the device with an
        # unbounded number of generate() calls.
        n_completions = max(1, min(int(body.get("n", 1)), 8))
    except (TypeError, ValueError):
        return JSONResponse(status_code=400, content={"error": {"message": "Invalid sampling parameters", "type": "invalid_request_error", "code": "invalid_parameter"}})

    # Build the prompt: prefer the model's own chat template when it has one,
    # otherwise fall back to a plain "role: content" transcript.
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        parts = [f"{m.get('role','user')}: {m.get('content','')}" for m in messages]
        parts.append("assistant:")
        prompt = "\n".join(parts)

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_tokens = inputs["input_ids"].shape[1]

    choices = []
    total_completion_tokens = 0
    for i in range(n_completions):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                # Temperature is floored at 0.01, which we treat as "greedy".
                do_sample=temperature > 0.01,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Strip the echoed prompt; decode only the newly generated tokens.
        new_tokens = outputs[0][prompt_tokens:]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        completion_tokens = len(new_tokens)
        total_completion_tokens += completion_tokens
        choices.append({
            "index": i,
            "message": {"role": "assistant", "content": text},
            # "length" when generation used the full token budget, else "stop".
            "finish_reason": "length" if completion_tokens >= max_tokens else "stop",
        })

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_NAME,
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": total_completion_tokens,
            "total_tokens": prompt_tokens + total_completion_tokens,
        },
        "system_fingerprint": f"{MODEL_NAME}-v1",
    }