brainfuncall / main.py
triflix's picture
Update main.py
b3f0838 verified
import asyncio
import datetime
import logging
import os
import threading
from typing import List, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import login
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
# ==========================================
# 1. APP SETUP
# ==========================================
# ASGI application object; the startup hook and the /generate route below
# register themselves on it via decorators.
app = FastAPI(
    title="FunctionGemma Brain API",
    version="1.0.0",
)

# Identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = "google/functiongemma-270m-it"

# Module-level handles, empty until the startup hook assigns the loaded
# tokenizer and model to them.
tokenizer, model = None, None
# ==========================================
# 2. FEW-SHOT EXAMPLES (The Teacher)
# ==========================================
# We teach the model the correct tool names here.
# This list simulates a previous conversation so the model knows what to do.
# Each entry is one demonstration turn: (user question, reply the model
# should produce). The first three show a function call per tool; the last
# shows a plain-text refusal for a question unrelated to the dam database.
_FEW_SHOT_TURNS = (
    # Counting / stats
    (
        "How many regions are there?",
        "<start_function_call>call:get_aggregate_stats{target_entity:revenue_region}<end_function_call>",
    ),
    # Specific search
    (
        "What is the water level in Aadale dam?",
        "<start_function_call>call:search_specific_dam{dam_name:Aadale}<end_function_call>",
    ),
    # Filtering
    (
        "Show me Major dams in Pune.",
        "<start_function_call>call:filter_dams{district:Pune,project_type:Major}<end_function_call>",
    ),
    # Irrelevant question: answer in prose, do not call a function
    (
        "What is the capital of France?",
        "I cannot answer that as it is not related to the dam database.",
    ),
)

# Flatten the turns into the alternating user/model message list that is
# spliced into every prompt as simulated prior conversation.
FEW_SHOT_MESSAGES = [
    {"role": speaker, "content": text}
    for question, reply in _FEW_SHOT_TURNS
    for speaker, text in (("user", question), ("model", reply))
]
# ==========================================
# 3. STARTUP
# ==========================================
@app.on_event("startup")
async def startup():
    """Log in to the Hugging Face Hub and load the tokenizer and model.

    Reads the access token from the ``HF_TOKEN`` environment variable,
    authenticates with it, then loads ``MODEL_ID`` on CPU in float32 and
    stores the results in the module-level ``tokenizer`` and ``model``.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is unset or empty.
    """
    global tokenizer, model

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN missing")
    login(token=hf_token)

    print(f"🧠 Loading {MODEL_ID}...")
    # The tokenizer is assigned before the model, so a non-None `model`
    # implies the tokenizer is available as well.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    # Status message previously contained a mis-decoded check-mark emoji.
    print("✅ Model Loaded.")
# ==========================================
# 4. API ENDPOINT
# ==========================================
class ChatRequest(BaseModel):
    # Body accepted by the POST /generate endpoint.
    # Documented with comments rather than a class docstring so the schema
    # generated from this model is left exactly as it was.

    # Text placed in the final "user" message of the prompt.
    query: str
    # Tool definitions forwarded as `tools=` to the tokenizer's chat template.
    tools: List[Dict[str, Any]]
    # When true, today's ISO date is appended to the system prompt.
    include_date: bool = Field(default=True)
# Keeps inference one-request-at-a-time now that it runs on worker threads;
# the shared model and tokenizer were never exercised concurrently before.
_GENERATION_LOCK = threading.Lock()


def _build_messages(query: str, include_date: bool) -> List[Dict[str, str]]:
    """Return the prompt history: system prompt, few-shot examples, user query."""
    system_content = "You are a model that can do function calling with the following functions."
    if include_date:
        today = datetime.date.today().isoformat()
        system_content += f" Today is {today}."
    return [
        {"role": "system", "content": system_content},
        *FEW_SHOT_MESSAGES,
        {"role": "user", "content": query},
    ]


def _run_generation(messages: List[Dict[str, str]], tools: List[Dict[str, Any]]) -> str:
    """Tokenize the chat, greedily generate, and decode only the new tokens."""
    with _GENERATION_LOCK:
        inputs = tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        # No-op for the CPU-loaded model; keeps tensors and model co-located
        # if the load device is ever changed.
        inputs = inputs.to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]
        outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
        # generate() returns prompt + continuation; drop the prompt part.
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


@app.post("/generate")
async def generate_function_call(request: ChatRequest):
    """Generate a function call (or plain-text reply) for ``request.query``.

    Builds a prompt from the system instruction, the few-shot examples and
    the user's query, then runs greedy decoding (at most 128 new tokens)
    with ``request.tools`` exposed through the chat template.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 503 while the model or tokenizer is not loaded yet;
            500 (with the error text as detail) if generation fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model loading")

    try:
        messages = _build_messages(request.query, request.include_date)
        # Tokenization and generation are blocking CPU work; running them
        # inline in this coroutine would freeze the event loop for all
        # other requests, so hand them to the default thread pool.
        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(
            None, _run_generation, messages, request.tools
        )
    except Exception as e:
        # Record the traceback server-side; the client still gets the
        # same 500 response body as before.
        logging.getLogger(__name__).exception("Function-call generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"response": generated_text}