import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from globe import title, description, customtool, presentation1, presentation2, joinus, examples
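# torchao is required to deserialize the GemLite/torchao-quantized weights used below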
import torchao
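# Quantized Llama-3.1-8B-Instruct; per the checkpoint name: 16-bit activations,
# 4-bit weights (a16w4), group size 128, weights packed into 32-bit words.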
model_path = "mobiuslabsgmbh/Llama-3.1-8B-Instruct_gemlite-ao_a16w4_gs_128_pack_32bit"
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
device_map="cuda",
attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
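# Llama 3.1 tokenizers ship without a pad token; reuse EOS so padding works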
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
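# Alternative generation path: a text-generation pipeline wrapping the same
# model and tokenizer, used when "Use Pipeline" is checked in the UI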
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
def create_prompt(system_message, user_message, tool_definition="", context=""):
    # Hand-rolled Llama 3.1 prompt. The official template is
    # <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>
    # with no newline after <|begin_of_text|>; the "context" and "tool"
    # headers are non-standard roles this demo layers on top of it.
    prompt = "<|begin_of_text|>"
    if system_message:
        prompt += "<|start_header_id|>system<|end_header_id|>\n\n" + system_message.strip() + "<|eot_id|>"
    if context:
        prompt += "<|start_header_id|>context<|end_header_id|>\n\n" + context.strip() + "<|eot_id|>"
    if tool_definition:
        prompt += "<|start_header_id|>tool<|end_header_id|>\n\n" + tool_definition.strip() + "<|eot_id|>"
    prompt += "<|start_header_id|>user<|end_header_id|>\n\n" + user_message.strip() + "<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt
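# For reference, with just a system and a user message create_prompt() yields:
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   {system}<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   {user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# tokenizer.apply_chat_template would render the canonical template; the manual
# builder is kept to support the non-standard context/tool headers above.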
# spaces.GPU() requests a ZeroGPU worker for the duration of each call on HF Spaces
@spaces.GPU()
def generate_response(message, history, system_message, max_tokens, temperature, top_p, do_sample, use_pipeline=False, tool_definition="", context=""):
    full_prompt = create_prompt(system_message, message, tool_definition, context)
    if use_pipeline:
        # return_full_text=False returns only the completion instead of echoing the prompt
        assistant_response = pipe(full_prompt, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, return_full_text=False)[0]['generated_text'].strip()
    else:
        # Truncate the prompt so it leaves room for max_tokens of generation
        max_model_length = getattr(model.config, 'max_position_embeddings', 8192)
        max_length = max_model_length - max_tokens
        inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        input_ids = inputs['input_ids'].to(model.device)
        attention_mask = inputs['attention_mask'].to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens; splitting the full decode on
        # "Assistant: " never matched, since that string does not occur in the template
        assistant_response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
    if tool_definition and "<toolcall>" in assistant_response:
        tool_call = assistant_response.split("<toolcall>")[1].split("</toolcall>")[0]
        assistant_response += f"\n\nTool Call: {tool_call}\n\nNote: This is a simulated tool call. In a real scenario, the tool would be executed and its output would be used to generate a final response."
    return assistant_response
def user(user_message, history):
    # Stage the turn: clear the input box and append a (user, pending-bot) pair
    return "", history + [[user_message, None]]
def bot(history, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, use_tool, tool_definition):
user_message = history[-1][0]
do_sample = advanced_checkbox
    # Augment the system prompt to encourage tool use when enabled
sys_msg = system_prompt
if use_tool:
tool_instruction = "If a tool is defined, use it to answer the user's question by calling the tool with the appropriate arguments."
sys_msg = (system_prompt + "\n" + tool_instruction) if system_prompt else tool_instruction
tool_def = tool_definition if use_tool else ""
bot_message = generate_response(user_message, history, sys_msg, max_length, temperature, top_p, do_sample, use_pipeline, tool_def)
history[-1][1] = bot_message
return history
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown(title)
with gr.Row():
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=1):
with gr.Group():
gr.Markdown(presentation1)
with gr.Column(scale=1):
with gr.Group():
gr.Markdown(joinus)
with gr.Row():
with gr.Column(scale=2):
            system_prompt = gr.TextArea(label="📑 System Prompt", placeholder="Add a system prompt here...", lines=5)
            user_input = gr.TextArea(label="🦙 User Input", placeholder="Hi there, my name is Tonic!", lines=2)
advanced_checkbox = gr.Checkbox(label="🧪 Advanced Settings", value=False)
with gr.Column(visible=False) as advanced_settings:
max_length = gr.Slider(label="📏Max Length", minimum=12, maximum=4096, value=2048, step=1)
temperature = gr.Slider(label="🌡️Temperature", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
top_p = gr.Slider(label="⚛️Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
use_pipeline = gr.Checkbox(label="Use Pipeline", value=False)
use_tool = gr.Checkbox(label="Use Function Calling", value=False)
with gr.Column(visible=False) as tool_options:
tool_definition = gr.Code(
label="Tool Definition (JSON)",
value=customtool,
lines=15,
language="json"
)
generate_button = gr.Button(value="🦙 Llama-3.1-8B-Instruct GemLite")
with gr.Column(scale=2):
chatbot = gr.Chatbot(label="🦙 Llama-3.1-8B-Instruct GemLite")
gr.Examples(
examples=examples,
inputs=[user_input, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_tool, tool_definition],
label="Try these examples:"
)
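    # Two-step event chain: user() echoes the message into the chat history,
    # then bot() generates the assistant reply for the staged turn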
generate_button.click(
user,
[user_input, chatbot],
[user_input, chatbot],
queue=False
).then(
bot,
[chatbot, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, use_tool, tool_definition],
chatbot
)
advanced_checkbox.change(
fn=lambda x: gr.update(visible=x),
inputs=[advanced_checkbox],
outputs=[advanced_settings]
)
use_tool.change(
fn=lambda x: gr.update(visible=x),
inputs=[use_tool],
outputs=[tool_options]
)
if __name__ == "__main__":
demo.queue()
demo.launch(ssr_mode=False, mcp_server=True)