"""
distil-home-assistant-functiongemma β OpenAI/OpenRouter-compatible inference server
Endpoints:
GET / -> health check / API info
GET /v1/models -> list available models
POST /v1/chat/completions -> generate text (OpenAI format)
"""
import time
import uuid
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

# ─── Config ───────────────────────────────────────────────────────────────────
MODEL_ID = "LisaMegaWatts/distil-home-assistant-functiongemma"
MODEL_NAME = "functiongemma-270m"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# ─── Load model at startup ────────────────────────────────────────────────────
print(f"Loading {MODEL_ID} on {DEVICE} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=DTYPE, trust_remote_code=True
).to(DEVICE)
model.eval()
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
MODEL_CREATED_AT = int(time.time())
print(f"Model ready on {DEVICE}")

# ─── FastAPI app ──────────────────────────────────────────────────────────────
app = FastAPI(title=f"{MODEL_NAME} API", version="1.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
@app.get("/")
def root():
    return {
        "name": MODEL_NAME,
        "version": "1.0.0",
        "description": "FunctionGemma 270m distilled for home assistant function calling",
        "model": MODEL_ID,
        "endpoints": ["/v1/models", "/v1/chat/completions"],
        "compatible_with": ["OpenAI API", "OpenRouter"],
    }

@app.get("/v1/models")
def list_models():
    return {
        "object": "list",
        "data": [{"id": MODEL_NAME, "object": "model", "created": MODEL_CREATED_AT, "owned_by": "LisaMegaWatts"}],
    }

@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(
            status_code=400,
            content={"error": {"message": "Invalid JSON", "type": "invalid_request_error", "code": "invalid_json"}},
        )

    messages = body.get("messages", [])
    if not messages:
        return JSONResponse(
            status_code=400,
            content={"error": {"message": "messages must be a non-empty array", "type": "invalid_request_error", "code": "invalid_messages"}},
        )

    # Clamp sampling parameters to safe ranges.
    temperature = max(0.01, min(float(body.get("temperature", 0.7)), 2.0))
    max_tokens = max(1, min(int(body.get("max_tokens", 512)), 2048))
    top_p = float(body.get("top_p", 0.9))
    n_completions = max(1, min(int(body.get("n", 1)), 8))

    # Build the prompt with the model's chat template when available;
    # otherwise fall back to a plain "role: content" transcript.
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        parts = [f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages]
        parts.append("assistant:")
        prompt = "\n".join(parts)

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_tokens = inputs["input_ids"].shape[1]

    choices = []
    total_completion_tokens = 0
    for i in range(n_completions):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0.01,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = outputs[0][prompt_tokens:]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        completion_tokens = len(new_tokens)
        total_completion_tokens += completion_tokens
        choices.append({
            "index": i,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "length" if completion_tokens >= max_tokens else "stop",
        })

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_NAME,
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": total_completion_tokens,
            "total_tokens": prompt_tokens + total_completion_tokens,
        },
        "system_fingerprint": f"{MODEL_NAME}-v1",
    }
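
# Local entry point: a minimal sketch, assuming uvicorn is installed; the
# host and port below are assumptions, not part of the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)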