""" distil-home-assistant-functiongemma — OpenAI/OpenRouter-compatible inference server Endpoints: GET / -> health check / API info GET /v1/models -> list available models POST /v1/chat/completions -> generate text (OpenAI format) """ import os import time import uuid import torch from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from transformers import AutoModelForCausalLM, AutoTokenizer # ─── Config ─────────────────────────────────────────────────────────────────── MODEL_ID = "LisaMegaWatts/distil-home-assistant-functiongemma" MODEL_NAME = "functiongemma-270m" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 # ─── Load model at startup ──────────────────────────────────────────────────── print(f"Loading {MODEL_ID} on {DEVICE} ...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=DTYPE, trust_remote_code=True ).to(DEVICE) model.eval() if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token MODEL_CREATED_AT = int(time.time()) print(f"Model ready on {DEVICE}") # ─── FastAPI app ────────────────────────────────────────────────────────────── app = FastAPI(title=f"{MODEL_NAME} API", version="1.0.0") app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) @app.get("/") def root(): return { "name": MODEL_NAME, "version": "1.0.0", "description": "FunctionGemma 270m distilled for home assistant function calling", "model": MODEL_ID, "endpoints": ["/v1/models", "/v1/chat/completions"], "compatible_with": ["OpenAI API", "OpenRouter"], } @app.get("/v1/models") def list_models(): return { "object": "list", "data": [{"id": MODEL_NAME, "object": "model", "created": MODEL_CREATED_AT, "owned_by": "LisaMegaWatts"}], } @app.post("/v1/chat/completions") async def chat_completions(request: Request): try: body = await request.json() except Exception: return JSONResponse(status_code=400, content={"error": {"message": "Invalid JSON", "type": "invalid_request_error", "code": "invalid_json"}}) messages = body.get("messages", []) temperature = max(0.01, min(float(body.get("temperature", 0.7)), 2.0)) max_tokens = max(1, min(int(body.get("max_tokens", 512)), 2048)) top_p = float(body.get("top_p", 0.9)) n_completions = int(body.get("n", 1)) if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) else: parts = [f"{m.get('role','user')}: {m.get('content','')}" for m in messages] parts.append("assistant:") prompt = "\n".join(parts) inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) prompt_tokens = inputs["input_ids"].shape[1] choices = [] total_completion_tokens = 0 for i in range(n_completions): with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=temperature > 0.01, pad_token_id=tokenizer.pad_token_id, ) new_tokens = outputs[0][prompt_tokens:] text = tokenizer.decode(new_tokens, skip_special_tokens=True) completion_tokens = len(new_tokens) total_completion_tokens += completion_tokens choices.append({"index": i, "message": {"role": "assistant", "content": text}, "finish_reason": "length" if completion_tokens >= max_tokens else "stop"}) return { "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion", "created": 
int(time.time()), "model": MODEL_NAME, "choices": choices, "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": total_completion_tokens, "total_tokens": prompt_tokens + total_completion_tokens}, "system_fingerprint": f"{MODEL_NAME}-v1", }
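

# ─── Entrypoint ───────────────────────────────────────────────────────────────
# Minimal sketch for running the server directly; the original file defines
# `app` but no entrypoint. Assumes uvicorn is installed alongside fastapi, and
# PORT is a hypothetical env var introduced here, defaulting to 8000.
#
# Example request once running (assumed local host/port):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Turn off the kitchen lights"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))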