Spaces:
Runtime error
Runtime error
| import spaces | |
| import json | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# ---------------------------------------------------------------------------
# Model / tokenizer setup -- executed once, when the module is imported.
# ---------------------------------------------------------------------------
model_name = "Salesforce/xLAM-1b-fc-r"
title = "Eval Model: " + model_name
description = ""

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pin the global torch RNG. generate_response() currently decodes greedily
# (do_sample=False), so this only matters if sampling is ever turned on.
torch.random.manual_seed(0)
# Task and format instructions.
# Both literals are .strip()-ped, so the newline right after the opening
# triple quote (and any trailing newline) is not part of the prompt text.
# build_prompt() inserts them verbatim between its [BEGIN OF ...]/[END OF ...] tags.
task_instruction = """
Based on the previous context and API request history, generate an API request or a response as an AI assistant.""".strip()
# Tells the model to reply with one JSON object of the form
# {"thought": ..., "tool_calls": [{"name": ..., "arguments": {...}}]}.
# NOTE(review): the line break after "please make" is inside the string literal,
# so the prompt contains a newline mid-sentence -- probably unintended, but
# changing it alters the exact prompt the model sees; confirm before editing.
format_instruction = """
The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make
tool_calls an empty list "[]".
```
{"thought": "the thought process, or an empty string", "tool_calls": [{"name": "api_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}]}
```
""".strip()
# Example tool specs and query used to pre-fill the UI. The specs are written in
# JSON-schema style and serialized with indent=2 so they read well in the editor.
example_tools = json.dumps(
    [
        dict(
            name="get_weather",
            description="Get the current weather for a location",
            parameters=dict(
                type="object",
                properties=dict(
                    location=dict(
                        type="string",
                        description="The city and state, e.g. San Francisco, New York",
                    ),
                    unit=dict(
                        type="string",
                        enum=["celsius", "fahrenheit"],
                        description="The unit of temperature to return",
                    ),
                ),
                required=["location"],
            ),
        ),
        dict(
            name="search",
            description="Search for information on the internet",
            parameters=dict(
                type="object",
                properties=dict(
                    query=dict(
                        type="string",
                        description="The search query, e.g. 'latest news on AI'",
                    ),
                ),
                required=["query"],
            ),
        ),
    ],
    indent=2,
)
example_query = "What's the weather like in New York in fahrenheit?"
def convert_to_xlam_tool(tools):
    """Convert JSON-schema style tool specs into the compact xLAM tool format.

    A tool ``{"name", "description", "parameters": {"type": "object",
    "properties": {...}, ...}}`` becomes ``{"name", "description",
    "parameters": {<property name>: <property spec>, ...}}``. Only the
    ``properties`` mapping is kept; other schema keys such as ``required``
    are dropped. OpenAI-style entries ``{"type": "function", "function":
    {...}}`` are unwrapped first.

    Args:
        tools: A tool dict, a list of tool dicts (converted element-wise,
            recursively), or any other JSON value (returned unchanged).

    Returns:
        The converted dict or list, or ``tools`` itself for any other type.

    Raises:
        KeyError: if a tool dict (after unwrapping) has no ``"name"``.
        AttributeError: if ``"parameters"`` is present but is not an object.
    """
    if isinstance(tools, list):
        return [convert_to_xlam_tool(tool) for tool in tools]
    if not isinstance(tools, dict):
        return tools
    # OpenAI "tools" entries nest the actual function spec under "function".
    if "name" not in tools and isinstance(tools.get("function"), dict):
        tools = tools["function"]
    # "description" and "parameters" are optional in common tool-spec formats
    # (e.g. a zero-argument function), so tolerate them being absent or null.
    parameters = tools.get("parameters") or {}
    properties = parameters.get("properties") or {}
    return {
        "name": tools["name"],
        "description": tools.get("description", ""),
        "parameters": dict(properties),
    }
def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str):
    """Assemble the xLAM prompt from its four tagged sections.

    Each section is wrapped as ``[BEGIN OF <TAG>]\\n<body>\\n[END OF <TAG>]``
    followed by a blank line, in the order: task instruction, available tools
    (serialized with ``json.dumps``), format instruction, query.

    Args:
        task_instruction: Text for the TASK INSTRUCTION section.
        format_instruction: Text for the FORMAT INSTRUCTION section.
        tools: Tool specs; any JSON-serializable value is accepted.
        query: The user request for the QUERY section.

    Returns:
        The concatenated prompt string.
    """
    serialized_tools = json.dumps(tools)
    sections = (
        f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n",
        f"[BEGIN OF AVAILABLE TOOLS]\n{serialized_tools}\n[END OF AVAILABLE TOOLS]\n\n",
        f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n",
        f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n",
    )
    return "".join(sections)
# On ZeroGPU hardware, a Space must expose at least one @spaces.GPU function
# (GPU access is granted only while it runs); outside ZeroGPU the decorator
# has no effect.
@spaces.GPU
def generate_response(tools_input, query):
    """Run one xLAM function-calling generation for the Gradio UI.

    Args:
        tools_input: Tool definitions as a JSON string (a list of tool specs,
            or a single spec), as typed into the "Available Tools" editor.
        query: The user's natural-language request.

    Returns:
        The decoded model continuation (the prompt asks for a JSON object with
        "thought" and "tool_calls"), or an ``"Error: ..."`` string when the
        tools input is not valid JSON or contains a malformed tool spec.
    """
    try:
        tools = json.loads(tools_input)
    except json.JSONDecodeError:
        return "Error: Invalid JSON format for tools input."
    # A malformed spec (e.g. a dict without "name", or "parameters" that is not
    # an object) makes the conversion raise; report it in the UI the same way
    # as a JSON error instead of surfacing a traceback.
    try:
        xlam_format_tools = convert_to_xlam_tool(tools)
    except (KeyError, TypeError, AttributeError) as err:
        return f"Error: Invalid tool definition ({type(err).__name__}: {err})."

    content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query)
    messages = [{"role": "user", "content": content}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        # A single, unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=512,
        do_sample=False,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
# Gradio interface: the left column holds the editable inputs, the right column
# shows the model output. Layout follows the order and nesting in which
# components are created inside the Blocks/Row/Column context managers.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # Tool specs as JSON text, pre-filled with example_tools;
            # generate_response parses this string with json.loads.
            tools_input = gr.Code(
                label="Available Tools (JSON format)",
                lines=20,
                value=example_tools,
                language='json'
            )
            query_input = gr.Textbox(
                label="User Query",
                lines=2,
                value=example_query
            )
            submit_button = gr.Button("Generate Response")
        with gr.Column():
            # Shows whatever string generate_response returns (model text or an error message).
            output = gr.Code(label="🎬 xLam :", lines=10, language="json")
    # Button click: (tools JSON, query) -> generate_response -> output pane.
    submit_button.click(generate_response, inputs=[tools_input, query_input], outputs=output)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()