|
|
import re |
|
|
import json |
|
|
import uuid |
|
|
|
|
|
def parse_arguments(json_value):
    """
    Attempt to parse a string as JSON.

    Args:
        json_value: String to parse.

    Returns:
        tuple: (parsed_value, is_valid_json) — the decoded Python object and
        True on success, or the original input unchanged and False on failure.
    """
    try:
        return json.loads(json_value), True
    except (TypeError, ValueError):
        # json.loads raises json.JSONDecodeError (a ValueError subclass) for
        # malformed JSON and TypeError for non-string input. A bare `except:`
        # would also swallow KeyboardInterrupt/SystemExit — keep it narrow.
        return json_value, False
|
|
|
|
|
def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
    """
    Get the type definition of a tool parameter.

    Args:
        func_name: Name of the function/tool.
        arg_key: Parameter key name.
        defined_tools: List of tool definitions.

    Returns:
        str or None: Type of the parameter ('string', 'object', 'array',
        'integer', 'number', 'boolean'), or None if the tool, its parameter
        schema, or the key cannot be found.
    """
    # Scan for the tool; a later definition with the same name shadows an
    # earlier one (same semantics as building a name -> tool dict).
    selected = None
    for candidate in defined_tools:
        if candidate["name"] == func_name:
            selected = candidate
    if selected is None:
        return None

    # EAFP: any missing level of the schema means the type is unknown.
    try:
        prop = selected["parameters"]["properties"][arg_key]
    except KeyError:
        return None
    return prop.get("type")
|
|
|
|
|
def parse_model_response(response: str, defined_tools: list = None):
    """
    Parse a raw model response into reasoning_content, content, and tool_calls.

    Args:
        response: Raw response text from the model.
        defined_tools: List of tool definitions, either bare tool dicts or
            OpenAI-style wrappers with a "function" key. Defaults to no tools.

    Returns:
        dict: Assistant message containing 'role', 'reasoning_content'
        (only when a thinking section was found), 'content', and
        'tool_calls' (only when tool calls were found).

    Raises:
        AssertionError: If the response contains malformed or mismatched
            <longcat_think> / <longcat_tool_call> / arg tags.
    """
    # Avoid the shared-mutable-default pitfall: create a fresh list per call.
    if defined_tools is None:
        defined_tools = []

    text = response
    reasoning_content = None
    content = None
    tool_calls = []

    # Normalize tool definitions: unwrap OpenAI-style
    # {"type": "function", "function": {...}} entries to the inner dict.
    formatted_tools = []
    for tool in defined_tools:
        if "function" in tool:
            formatted_tools.append(tool['function'])
        else:
            formatted_tools.append(tool)

    # Extract the reasoning section delimited by <longcat_think>...</longcat_think>.
    if '</longcat_think>' in text:
        text = text.replace('<longcat_think>', '')
        thinking_end = text.find('</longcat_think>')
        reasoning_content = text[: thinking_end].strip()
        text = text[thinking_end + len('</longcat_think>'):].lstrip()

    assert '<longcat_think>' not in text, "Unclosed <longcat_think> tag found in remaining text"
    assert '</longcat_think>' not in text, "Unexpected </longcat_think> tag found without opening tag"

    # Everything before the first tool call is the visible assistant content.
    if '<longcat_tool_call>' in text:
        index = text.find('<longcat_tool_call>')
        content = text[:index]
        text = text[index:].strip()
    else:
        content = text
        text = ""

    open_tags = text.count('<longcat_tool_call>')
    close_tags = text.count('</longcat_tool_call>')
    assert open_tags == close_tags, \
        f"Mismatched tool_call tags: {open_tags} opening tags, {close_tags} closing tags"

    tool_call_strs = re.findall(
        r'<longcat_tool_call>(.*?)</longcat_tool_call>',
        text,
        re.DOTALL
    )

    for call in tool_call_strs:
        # The function name is the first non-tag line of the call body.
        func_name_match = re.match(r'([^\n<]+)', call.strip())
        assert func_name_match, f"Missing function name in tool call: {call[:100]}"

        func_name = func_name_match.group(1).strip()
        assert func_name, "Empty function name in tool call"

        # Validate the key/value tag structure before pairing them up.
        arg_key_count = call.count('<longcat_arg_key>')
        arg_key_close_count = call.count('</longcat_arg_key>')
        arg_value_count = call.count('<longcat_arg_value>')
        arg_value_close_count = call.count('</longcat_arg_value>')

        assert arg_key_count == arg_key_close_count, \
            f"Mismatched arg_key tags in function {func_name}: {arg_key_count} opening, {arg_key_close_count} closing"
        assert arg_value_count == arg_value_close_count, \
            f"Mismatched arg_value tags in function {func_name}: {arg_value_count} opening, {arg_value_close_count} closing"
        assert arg_key_count == arg_value_count, \
            f"Mismatched arg_key and arg_value count in function {func_name}: {arg_key_count} keys, {arg_value_count} values"

        pairs = re.findall(
            r'<longcat_arg_key>(.*?)</longcat_arg_key>\s*<longcat_arg_value>(.*?)</longcat_arg_value>',
            call,
            re.DOTALL
        )

        assert len(pairs) == arg_key_count, \
            f"Failed to parse all arguments in function {func_name}: expected {arg_key_count}, got {len(pairs)}"

        arguments = {}
        for arg_key, arg_value in pairs:
            arg_key = arg_key.strip()
            arg_value = arg_value.strip()

            assert arg_key, f"Empty argument key in function {func_name}"
            assert arg_key not in arguments, \
                f"Duplicate argument key '{arg_key}' in function {func_name}"

            # Non-string parameters are serialized as JSON text in the tag
            # body; decode them back to Python values. On decode failure the
            # raw string is kept (parse_arguments returns the input as-is).
            arg_type = get_argument_type(func_name, arg_key, formatted_tools)
            if arg_type and arg_type != 'string':
                arg_value, _ = parse_arguments(arg_value)

            arguments[arg_key] = arg_value

        tool_calls.append({
            'id': "tool-call-" + str(uuid.uuid4()),
            'type': "function",
            'function': {
                'name': func_name,
                'arguments': arguments
            }
        })

    message = {'role': 'assistant'}

    if reasoning_content:
        message['reasoning_content'] = reasoning_content
    message['content'] = content
    if tool_calls:
        message['tool_calls'] = tool_calls

    return message
|
|
|
|
|
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Demo: run the model twice (plain chat, then tool calling) and parse each
    # raw completion with parse_model_response, which is defined above — the
    # previous self-import (`from parse_model_response import ...`) was
    # redundant and has been removed.
    model_name = "meituan-longcat/LongCat-Flash-Lite"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Example 1: plain chat completion without tools.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me a brief introduction to large language models."}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    generated_ids = model.generate(inputs=input_ids, max_new_tokens=256)
    output_ids = generated_ids[0][len(input_ids[0]):].tolist()
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    print("Example 1: sample response.")
    print("\nRaw response:")
    print(response)
    print("\nParsed result:")
    # NOTE: the duplicated second decode of output_ids was removed; `response`
    # already holds the decoded text.
    parsed_message = parse_model_response(response)
    print(json.dumps(parsed_message, indent=2, ensure_ascii=False))

    # Example 2: completion with a tool definition, exercising tool-call parsing.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "func_add",
                "description": "Calculate the sum of two numbers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x1": {"type": "number", "description": "The first addend"},
                        "x2": {"type": "number", "description": "The second addend"}
                    },
                    "required": ["x1", "x2"]
                }
            }
        }
    ]
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Please tell me what is $$125679 + 234519$$?"},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    generated_ids = model.generate(inputs=input_ids, max_new_tokens=256)
    output_ids = generated_ids[0][len(input_ids[0]):].tolist()
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    print("Example 2: tool call response.")
    print("\nRaw response:")
    print(response)
    print("\nParsed result:")
    parsed_message = parse_model_response(response, tools)
    print(json.dumps(parsed_message, indent=2, ensure_ascii=False))
|