from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any
import os
import datetime

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# ==========================================
# 1. APP SETUP
# ==========================================
app = FastAPI(title="FunctionGemma Brain API", version="1.0.0")

MODEL_ID = "google/functiongemma-270m-it"
tokenizer = None
model = None

# ==========================================
# 2. FEW-SHOT EXAMPLES (The Teacher)
# ==========================================
# We teach the model the correct tool names here.
# This list simulates a previous conversation so the model knows what to do.
FEW_SHOT_MESSAGES = [
    # Example 1: Counting/Stats
    {"role": "user", "content": "How many regions are there?"},
    {"role": "model", "content": "<start_function_call>call:get_aggregate_stats{target_entity:revenue_region}<end_function_call>"},
    
    # Example 2: Specific Search
    {"role": "user", "content": "What is the water level in Aadale dam?"},
    {"role": "model", "content": "<start_function_call>call:search_specific_dam{dam_name:Aadale}<end_function_call>"},
    
    # Example 3: Filtering
    {"role": "user", "content": "Show me Major dams in Pune."},
    {"role": "model", "content": "<start_function_call>call:filter_dams{district:Pune,project_type:Major}<end_function_call>"},

    # Example 4: Irrelevant Question (Teach it to NOT call functions for random stuff)
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "model", "content": "I cannot answer that as it is not related to the dam database."}
]
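
# The <start_function_call>...<end_function_call> markers in the examples above
# use the same call syntax the model is expected to emit, so each few-shot turn
# looks exactly like real model output.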

# ==========================================
# 3. STARTUP
# ==========================================
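# Note: @app.on_event is deprecated in recent FastAPI releases in favour of
# lifespan handlers; it still works but may emit a deprecation warning.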
@app.on_event("startup")
async def startup():
    global tokenizer, model
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN missing: set it in the environment before starting the server")
    login(token=hf_token)
    
    print(f"🧠 Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
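    # CPU + float32 is the portable default; with a GPU available, device_map="auto"
    # and a half-precision dtype would load and generate noticeably faster.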
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cpu", torch_dtype=torch.float32)
    print("✅ Model Loaded.")

# ==========================================
# 4. API ENDPOINT
# ==========================================
class ChatRequest(BaseModel):
    query: str
    tools: List[Dict[str, Any]]
    include_date: bool = True

@app.post("/generate")
async def generate_function_call(request: ChatRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading")

    try:
        # 1. System Prompt
        system_content = "You are a model that can do function calling with the following functions."
        if request.include_date:
            today = datetime.date.today().isoformat()
            system_content += f" Today is {today}."

        # 2. Construct History: System -> Examples -> Current User Query
        messages = [{"role": "system", "content": system_content}]
        
        # Inject the examples!
        messages.extend(FEW_SHOT_MESSAGES)
        
        # Add the actual user query
        messages.append({"role": "user", "content": request.query})

        # 3. Tokenize
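        # apply_chat_template renders the tool schemas into the prompt text
        # together with the system, few-shot, and user messages.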
        inputs = tokenizer.apply_chat_template(
            messages,
            tools=request.tools,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        # 4. Generate
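        # Greedy decoding (do_sample=False) keeps function-call output deterministic.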
        outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
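        # Decode only the newly generated tokens, skipping the prompt portion.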
        generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

        return {"response": generated_text}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
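
# ==========================================
# 5. EXAMPLE USAGE (for reference)
# ==========================================
# A hypothetical client call against a locally running server. The tool schema
# below is illustrative only; it should match the functions your own database
# layer actually exposes.
#
#   import requests
#
#   payload = {
#       "query": "Show me Major dams in Pune.",
#       "tools": [{
#           "type": "function",
#           "function": {
#               "name": "filter_dams",
#               "description": "Filter dams by district and project type.",
#               "parameters": {
#                   "type": "object",
#                   "properties": {
#                       "district": {"type": "string"},
#                       "project_type": {"type": "string"},
#                   },
#               },
#           },
#       }],
#   }
#   print(requests.post("http://localhost:8000/generate", json=payload).json())
#
# The response JSON contains the model's generated text (typically a function
# call string) under the "response" key.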