|
|
""" |
|
|
distil-home-assistant-functiongemma — OpenAI/OpenRouter-compatible inference server
|
|
|
|
|
Endpoints: |
|
|
GET / -> health check / API info |
|
|
GET /v1/models -> list available models |
|
|
POST /v1/chat/completions -> generate text (OpenAI format) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
import uuid |
|
|
|
|
|
import torch |
|
|
from fastapi import FastAPI, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face repo ID for the distilled model weights and tokenizer.
MODEL_ID = "LisaMegaWatts/distil-home-assistant-functiongemma"


# Short model name reported to API clients (in /v1/models and completion responses).
MODEL_NAME = "functiongemma-270m"


# Prefer GPU when available; fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# fp16 on GPU for speed/memory; fp32 on CPU where fp16 inference is poorly supported.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
|
|
|
|
|
|
|
|
|
|
# --- Model loading: runs once at import time, before the server starts ---
print(f"Loading {MODEL_ID} on {DEVICE} ...")


# NOTE(review): trust_remote_code executes code shipped in the model repo —
# presumably required by this checkpoint; confirm the repo is trusted.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)


model = AutoModelForCausalLM.from_pretrained(

    MODEL_ID, torch_dtype=DTYPE, trust_remote_code=True

).to(DEVICE)


# Inference-only server: switch off dropout / training-mode behavior.
model.eval()


# Some causal-LM tokenizers ship without a pad token; generate() needs a
# pad_token_id, so reuse EOS as padding in that case.
if tokenizer.pad_token is None:

    tokenizer.pad_token = tokenizer.eos_token


# Startup timestamp, reported as the model's "created" field in /v1/models.
MODEL_CREATED_AT = int(time.time())

print(f"Model ready on {DEVICE}")
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application with fully permissive CORS so browser-based clients
# on any origin can call the API directly.
app = FastAPI(title=f"{MODEL_NAME} API", version="1.0.0")


app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Health check / API info: report service metadata and available routes."""
    payload = dict(
        name=MODEL_NAME,
        version="1.0.0",
        description="FunctionGemma 270m distilled for home assistant function calling",
        model=MODEL_ID,
        endpoints=["/v1/models", "/v1/chat/completions"],
        compatible_with=["OpenAI API", "OpenRouter"],
    )
    return payload
|
|
|
|
|
|
|
|
@app.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing: a single entry for this server's model."""
    entry = {
        "id": MODEL_NAME,
        "object": "model",
        "created": MODEL_CREATED_AT,
        "owned_by": "LisaMegaWatts",
    }
    return {"object": "list", "data": [entry]}
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint.

    Reads the standard request body (``messages``, ``temperature``,
    ``max_tokens``, ``top_p``, ``n``) and returns a ``chat.completion``
    object. Malformed JSON, a missing/empty ``messages`` list, or
    non-numeric sampling parameters yield a 400 in OpenAI's error format
    instead of an unhandled 500.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": {"message": "Invalid JSON", "type": "invalid_request_error", "code": "invalid_json"}})

    messages = body.get("messages", [])
    # OpenAI semantics: messages must be a non-empty list. Previously an
    # empty/missing list silently produced a bare "assistant:" prompt, and a
    # non-list crashed inside apply_chat_template with a 500.
    if not isinstance(messages, list) or not messages:
        return JSONResponse(status_code=400, content={"error": {"message": "'messages' must be a non-empty list", "type": "invalid_request_error", "code": "invalid_messages"}})

    # Parse and clamp sampling parameters. Non-numeric values used to raise
    # ValueError/TypeError out of the handler as a 500; n was unclamped, so
    # n=0 returned an invalid empty choices list and a huge n was an
    # unbounded latency amplifier.
    try:
        temperature = max(0.01, min(float(body.get("temperature", 0.7)), 2.0))
        max_tokens = max(1, min(int(body.get("max_tokens", 512)), 2048))
        top_p = max(0.01, min(float(body.get("top_p", 0.9)), 1.0))
        n_completions = max(1, min(int(body.get("n", 1)), 8))
    except (TypeError, ValueError):
        return JSONResponse(status_code=400, content={"error": {"message": "Invalid sampling parameters", "type": "invalid_request_error", "code": "invalid_parameters"}})

    # Prefer the model's own chat template; fall back to a plain
    # role-prefixed transcript ending with an assistant cue.
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        parts = [f"{m.get('role','user')}: {m.get('content','')}" for m in messages]
        parts.append("assistant:")
        prompt = "\n".join(parts)

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_tokens = inputs["input_ids"].shape[1]

    choices = []
    total_completion_tokens = 0
    for i in range(n_completions):
        with torch.no_grad():
            # temperature is clamped to >= 0.01, so do_sample is False only
            # when the caller asked for (effectively) greedy decoding.
            outputs = model.generate(
                **inputs, max_new_tokens=max_tokens, temperature=temperature,
                top_p=top_p, do_sample=temperature > 0.01, pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = outputs[0][prompt_tokens:]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        completion_tokens = len(new_tokens)
        total_completion_tokens += completion_tokens
        choices.append({"index": i, "message": {"role": "assistant", "content": text}, "finish_reason": "length" if completion_tokens >= max_tokens else "stop"})

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_NAME,
        "choices": choices,
        "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": total_completion_tokens, "total_tokens": prompt_tokens + total_completion_tokens},
        "system_fingerprint": f"{MODEL_NAME}-v1",
    }
|
|
|