# Spaces: Running
import datetime
import logging
import os
from typing import List, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import login
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
# ==========================================
# 1. APP SETUP
# ==========================================
# Hub identifier of the checkpoint this service loads.
MODEL_ID = "google/functiongemma-270m-it"

# The ASGI application object.
app = FastAPI(version="1.0.0", title="FunctionGemma Brain API")

# Both stay None until startup() has loaded them; the request handler uses
# `model` being unset as its "not ready yet" signal.
model = None
tokenizer = None
# ==========================================
# 2. FEW-SHOT EXAMPLES (The Teacher)
# ==========================================
# These examples teach the model the correct tool names.
# The list simulates a previous conversation so the model knows what to do.
# (user question, expected model reply) pairs, in the order they are replayed.
# NOTE(review): argument values are written bare (e.g. dam_name:Aadale);
# confirm this matches the call syntax the model card documents and whatever
# parses the response downstream.
_FEW_SHOT_TURNS = (
    # Counting / aggregate statistics.
    (
        "How many regions are there?",
        "<start_function_call>call:get_aggregate_stats{target_entity:revenue_region}<end_function_call>",
    ),
    # Lookup of one specific dam.
    (
        "What is the water level in Aadale dam?",
        "<start_function_call>call:search_specific_dam{dam_name:Aadale}<end_function_call>",
    ),
    # Filtering on several attributes.
    (
        "Show me Major dams in Pune.",
        "<start_function_call>call:filter_dams{district:Pune,project_type:Major}<end_function_call>",
    ),
    # Off-topic question: the expected reply is a refusal, not a function call.
    (
        "What is the capital of France?",
        "I cannot answer that as it is not related to the dam database.",
    ),
)

# Flatten each pair into two chat messages: the user turn, then the model turn.
FEW_SHOT_MESSAGES = [
    {"role": speaker, "content": text}
    for question, reply in _FEW_SHOT_TURNS
    for speaker, text in (("user", question), ("model", reply))
]
# ==========================================
# 3. STARTUP
# ==========================================
@app.on_event("startup")
async def startup() -> None:
    """Log in to the Hugging Face Hub and load the tokenizer and model.

    Registered as a FastAPI startup handler so it actually runs when the
    server boots; without the registration the module-level ``model`` would
    stay ``None`` forever.

    Raises:
        RuntimeError: If the ``HF_TOKEN`` environment variable is unset or
            empty.
    """
    global tokenizer, model

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN missing")
    login(token=hf_token)

    print(f"🧠 Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # `model` is assigned last on purpose: a non-None `model` then implies the
    # tokenizer is ready as well.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    print("✅ Model Loaded.")
# ==========================================
# 4. API ENDPOINT
# ==========================================
class ChatRequest(BaseModel):
    """Request body for the function-calling endpoint."""

    # Natural-language question; sent to the model as the final user turn.
    query: str = Field(...)
    # Tool declarations, forwarded unchanged to the tokenizer's chat template.
    tools: List[Dict[str, Any]] = Field(...)
    # When true, today's ISO date is appended to the system prompt.
    include_date: bool = Field(default=True)
def _build_messages(query: str, include_date: bool) -> List[Dict[str, str]]:
    """Assemble the chat history: system prompt, few-shot turns, then *query*."""
    system_content = "You are a model that can do function calling with the following functions."
    if include_date:
        today = datetime.date.today().isoformat()
        system_content += f" Today is {today}."
    return [
        {"role": "system", "content": system_content},
        *FEW_SHOT_MESSAGES,
        {"role": "user", "content": query},
    ]


# NOTE(review): no route registration was visible for this handler, which
# left it unreachable. The "/generate" path is an assumption -- confirm it
# against existing clients before deploying.
@app.post("/generate")
async def generate_function_call(request: ChatRequest):
    """Run the model on the user's query and return its raw text reply.

    The prompt is the system instruction (optionally with today's date), the
    few-shot conversation, and finally ``request.query``; ``request.tools``
    is handed to the chat template alongside it. Decoding is greedy and
    capped at 128 new tokens.

    Args:
        request: Query text, tool declarations and the date toggle.

    Returns:
        ``{"response": <generated text>}`` containing only the newly
        generated tokens (the prompt is sliced off before decoding).

    Raises:
        HTTPException: 503 while the model or tokenizer is not loaded yet;
            500 if templating, generation or decoding fails.
    """
    # Explicit None checks: the "not loaded" sentinel is None, and object
    # truthiness is not a reliable stand-in for it.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model loading")

    try:
        inputs = tokenizer.apply_chat_template(
            _build_messages(request.query, request.include_date),
            tools=request.tools,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        # NOTE(review): generate() is a synchronous call, so this coroutine
        # does not yield while it runs; consider offloading to a worker
        # thread if concurrent requests matter.
        outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
        # The output holds prompt + completion; keep only the completion.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
    except Exception as exc:
        # Keep the traceback server-side; the client still receives the same
        # 500 + message as before.
        logging.getLogger(__name__).exception("Function-call generation failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"response": generated_text}