# LongCat-Flash-Lite-4bit-DWQ / parse_model_response.py
import re
import json
import uuid
def parse_arguments(json_value):
    """
    Attempt to parse a value as JSON.

    Args:
        json_value: String (or any value) to parse
    Returns:
        tuple: (parsed_value, is_valid_json) — the decoded object and True on
        success, or the original input unchanged and False on failure
    """
    try:
        return json.loads(json_value), True
    # ValueError covers json.JSONDecodeError (malformed JSON); TypeError
    # covers non-str/bytes inputs. A bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit, hiding real problems.
    except (ValueError, TypeError):
        return json_value, False
def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
    """
    Look up the declared JSON-schema type of a tool parameter.

    Args:
        func_name: Name of the function/tool
        arg_key: Parameter key name
        defined_tools: List of tool definitions (bare tool dicts with "name",
            "parameters", etc.)
    Returns:
        str or None: Type of the parameter ('string', 'object', 'array',
        'integer', 'number', 'boolean'), or None when the tool, its schema,
        or the parameter is not defined
    """
    # Scan in reverse so that, on duplicate tool names, the later definition
    # wins — matching the last-wins behavior of building a name->tool dict.
    for tool in reversed(defined_tools):
        if tool["name"] != func_name:
            continue
        if "parameters" not in tool or "properties" not in tool["parameters"]:
            return None
        properties = tool["parameters"]["properties"]
        if arg_key not in properties:
            return None
        return properties[arg_key].get("type")
    return None
def parse_model_response(response: str, defined_tools=None):
    """
    Parse a model response into reasoning_content, content, and tool_calls.

    Args:
        response: Raw response text from the model
        defined_tools: List of tool definitions — either OpenAI-style entries
            wrapped in a "function" key or bare tool dicts. Defaults to no
            tools.
    Returns:
        dict: Message containing 'role', 'content', plus 'reasoning_content'
        (when a non-empty <longcat_think> block is present) and 'tool_calls'
        (when at least one tool call is parsed)
    Raises:
        AssertionError: when think/tool-call/argument tags are malformed
    """
    # None sentinel instead of a mutable default argument.
    if defined_tools is None:
        defined_tools = []
    text = response
    reasoning_content = None
    content = None
    tool_calls = []
    # Unwrap OpenAI-style {"type": "function", "function": {...}} entries so
    # get_argument_type always receives bare tool schemas.
    formatted_tools = [
        tool["function"] if "function" in tool else tool
        for tool in defined_tools
    ]
    # Split off the optional <longcat_think>...</longcat_think> reasoning block.
    if '</longcat_think>' in text:
        text = text.replace('<longcat_think>', '')
        thinking_end = text.find('</longcat_think>')
        reasoning_content = text[: thinking_end].strip()
        text = text[thinking_end + len('</longcat_think>'):].lstrip()
    assert '<longcat_think>' not in text, "Unclosed <longcat_think> tag found in remaining text"
    assert '</longcat_think>' not in text, "Unexpected </longcat_think> tag found without opening tag"
    # Everything before the first tool call is plain assistant content.
    if '<longcat_tool_call>' in text:
        index = text.find('<longcat_tool_call>')
        content = text[:index]
        text = text[index:].strip()
    else:
        content = text
        text = ""
    open_tags = text.count('<longcat_tool_call>')
    close_tags = text.count('</longcat_tool_call>')
    assert open_tags == close_tags, \
        f"Mismatched tool_call tags: {open_tags} opening tags, {close_tags} closing tags"
    tool_call_strs = re.findall(
        r'<longcat_tool_call>(.*?)</longcat_tool_call>',
        text,
        re.DOTALL
    )
    for call in tool_call_strs:
        # The first non-tag line of the call body is the function name.
        func_name_match = re.match(r'([^\n<]+)', call.strip())
        assert func_name_match, f"Missing function name in tool call: {call[:100]}"
        func_name = func_name_match.group(1).strip()
        assert func_name, "Empty function name in tool call"
        # Verify argument tags are properly paired
        arg_key_count = call.count('<longcat_arg_key>')
        arg_key_close_count = call.count('</longcat_arg_key>')
        arg_value_count = call.count('<longcat_arg_value>')
        arg_value_close_count = call.count('</longcat_arg_value>')
        assert arg_key_count == arg_key_close_count, \
            f"Mismatched arg_key tags in function {func_name}: {arg_key_count} opening, {arg_key_close_count} closing"
        assert arg_value_count == arg_value_close_count, \
            f"Mismatched arg_value tags in function {func_name}: {arg_value_count} opening, {arg_value_close_count} closing"
        assert arg_key_count == arg_value_count, \
            f"Mismatched arg_key and arg_value count in function {func_name}: {arg_key_count} keys, {arg_value_count} values"
        pairs = re.findall(
            r'<longcat_arg_key>(.*?)</longcat_arg_key>\s*<longcat_arg_value>(.*?)</longcat_arg_value>',
            call,
            re.DOTALL
        )
        assert len(pairs) == arg_key_count, \
            f"Failed to parse all arguments in function {func_name}: expected {arg_key_count}, got {len(pairs)}"
        arguments = {}
        for arg_key, arg_value in pairs:
            arg_key = arg_key.strip()
            arg_value = arg_value.strip()
            assert arg_key, f"Empty argument key in function {func_name}"
            assert arg_key not in arguments, \
                f"Duplicate argument key '{arg_key}' in function {func_name}"
            # Decode non-string parameters (numbers, objects, arrays, booleans)
            # from their JSON text; keep the raw string if decoding fails.
            arg_type = get_argument_type(func_name, arg_key, formatted_tools)
            if arg_type and arg_type != 'string':
                arg_value, _ = parse_arguments(arg_value)
            arguments[arg_key] = arg_value
        tool_calls.append({
            'id': "tool-call-" + str(uuid.uuid4()),
            'type': "function",
            'function': {
                'name': func_name,
                'arguments': arguments
            }
        })
    message = {'role': 'assistant'}
    if reasoning_content:
        message['reasoning_content'] = reasoning_content
    message['content'] = content
    if tool_calls:
        message['tool_calls'] = tool_calls
    return message
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def generate_response(model, tokenizer, messages, tools=None):
        """Apply the chat template, generate up to 256 new tokens, and decode
        only the newly generated portion."""
        input_ids = tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
        generated_ids = model.generate(inputs=input_ids, max_new_tokens=256)
        output_ids = generated_ids[0][len(input_ids[0]):].tolist()
        return tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    model_name = "meituan-longcat/LongCat-Flash-Lite"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Example 1: plain chat without tools — expects only content (and
    # possibly reasoning_content) in the parsed message.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me a brief introduction to large language models."}
    ]
    response = generate_response(model, tokenizer, messages)
    print("Example 1: sample response.")
    print("\nRaw response:")
    print(response)
    print("\nParsed result:")
    parsed_message = parse_model_response(response)
    print(json.dumps(parsed_message, indent=2, ensure_ascii=False))

    # Example 2: a prompt that should trigger a tool call to func_add.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "func_add",
                "description": "Calculate the sum of two numbers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x1": {"type": "number", "description": "The first addend"},
                        "x2": {"type": "number", "description": "The second addend"}
                    },
                    "required": ["x1", "x2"]
                }
            }
        }
    ]
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Please tell me what is $$125679 + 234519$$?"},
        # Uncomment to continue the conversation with a tool result turn:
        # {
        #     "role": "assistant",
        #     "content": "I'll calculate the sum of 125679 and 234519 for you.",
        #     "tool_calls": [{"type": "function", "function": {"name": "func_add", "arguments": {"x1": 125679, "x2": 234519}}}]
        # },
        # {"role": "tool", "name": "func_add", "content": '{"ans": 360198}'}
    ]
    response = generate_response(model, tokenizer, messages, tools=tools)
    print("Example 2: tool call response.")
    print("\nRaw response:")
    print(response)
    print("\nParsed result:")
    parsed_message = parse_model_response(response, tools)
    print(json.dumps(parsed_message, indent=2, ensure_ascii=False))