|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from huggingface_hub import login |
|
|
import os |
|
|
import json |
|
|
|
|
|
|
|
|
# Optional Hugging Face Hub authentication: when the HF_TOKEN environment
# variable is set (and non-empty) the session is logged in with it; otherwise
# no login call is made and hf_token stays None.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
|
|
|
|
|
|
|
|
# Checkpoint served by this demo.
model_name = "google/functiongemma-270m-it"

# The same (possibly None) Hub token is forwarded to both loaders.
_hub_kwargs = {"token": hf_token}

tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_kwargs)

# Weights are loaded in float32; device_map="auto" delegates device placement
# to transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float32,
    **_hub_kwargs,
)
|
|
|
|
|
|
|
|
# Tool declarations passed to tokenizer.apply_chat_template(tools=...).
# Each entry wraps a JSON-Schema "parameters" object in the
# {"type": "function", "function": {...}} envelope.
tools = [
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Performs mathematical calculations",
            "parameters": {
                "type": "object",
                "properties": {
                    "operation": {
                        "type": "string",
                        # The allowed values were previously only listed in
                        # prose; "enum" makes the closed set explicit in the
                        # schema the model is shown.
                        "enum": ["add", "subtract", "multiply", "divide"],
                        "description": "The mathematical operation (add, subtract, multiply, divide)",
                    },
                    "a": {
                        "type": "number",
                        "description": "First number",
                    },
                    "b": {
                        "type": "number",
                        "description": "Second number",
                    },
                },
                "required": ["operation", "a", "b"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "send_email",
            "description": "Sends an email to a specified recipient",
            "parameters": {
                "type": "object",
                "properties": {
                    "to": {
                        "type": "string",
                        "description": "Email recipient address",
                    },
                    "subject": {
                        "type": "string",
                        "description": "Email subject",
                    },
                    "body": {
                        "type": "string",
                        "description": "Email body content",
                    },
                },
                "required": ["to", "subject", "body"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "set_alarm",
            "description": "Sets an alarm for a specified time",
            "parameters": {
                "type": "object",
                "properties": {
                    "time": {
                        "type": "string",
                        "description": "Time in HH:MM format (24-hour)",
                    },
                    "label": {
                        "type": "string",
                        "description": "Label or description for the alarm",
                    },
                },
                # "label" is optional.
                "required": ["time"],
            },
        },
    },
]
|
|
|
|
|
def test_function_calling(user_input: str, temperature: float = 0.7) -> str:
    """Run FunctionGemma on one user request and return only its reply.

    The request is wrapped in a two-message chat (a ``developer`` instruction
    followed by the ``user`` turn), rendered together with the module-level
    ``tools`` declarations via the tokenizer's chat template, and passed to
    ``model.generate`` with sampling enabled.

    Args:
        user_input: Free-text request typed into the UI.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded text the model generated after the prompt (the prompt
        itself is not included), or a string beginning with ``"Error: "``.
        Failures are returned rather than raised so that the Gradio output
        box displays them.
    """
    if not user_input or not user_input.strip():
        return "Error: please enter a request."

    try:
        messages = [
            {
                "role": "developer",
                "content": "You are a helpful assistant that can make function calls. When the user asks you to do something that matches one of the available functions, call that function.",
            },
            {
                "role": "user",
                "content": user_input,
            },
        ]

        formatted_input = tokenizer.apply_chat_template(
            messages,
            tools=tools,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The rendered chat template already contains the model's special
        # tokens (e.g. BOS). Tokenizing it with the default
        # add_special_tokens=True would prepend them a second time.
        inputs = tokenizer(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
            )

        # For a causal LM, generate() returns prompt + continuation; slice
        # off the prompt so only the model's answer is shown.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    except Exception as e:
        # UI boundary: surface any failure as text instead of crashing the app.
        return f"Error: {e}"
|
|
|
|
|
|
|
|
_APP_TITLE = "FunctionGemma-270M Function Calling Tester"

# Single-page UI: request + temperature on the left, raw model text on the right.
with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(f"# {_APP_TITLE}")
    gr.Markdown("Test the FunctionGemma model's ability to generate function calls.")

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="Input",
                placeholder="Describe a function call you want",
                lines=3,
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                label="Temperature",
            )
            run_button = gr.Button("Test Model", variant="primary")

        with gr.Column():
            output_box = gr.Textbox(label="Model Output", lines=10)

    # Clickable sample prompts, one aimed at each declared tool.
    example_prompts = (
        "Calculate 5 + 3",
        "Send email to user@example.com",
        "Set alarm for 6 AM",
    )
    gr.Examples(
        examples=[[prompt] for prompt in example_prompts],
        inputs=[input_box],
    )

    run_button.click(
        fn=test_function_calling,
        inputs=[input_box, temperature_slider],
        outputs=output_box,
    )
|
|
|
|
|
# Start the (blocking) web server only when this file is executed directly,
# so importing the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()