import re
import json
import uuid

# NOTE(review): the original tag literals were lost in extraction (they appeared
# as empty strings ''). They are reconstructed below from the LongCat response
# format — <longcat_think>...</longcat_think> for reasoning and
# <longcat_tool_call>...</longcat_tool_call> with <arg_key>/<arg_value> pairs
# for tool calls. Confirm against the model's chat template before shipping.


def parse_arguments(json_value):
    """
    Attempt to parse a string as JSON.

    Args:
        json_value: String to parse.

    Returns:
        tuple: (parsed_value, is_valid_json) — the decoded object and True on
        success, otherwise the original input and False.
    """
    try:
        return json.loads(json_value), True
    except (json.JSONDecodeError, TypeError):
        # Narrowed from a bare `except:` so signals like KeyboardInterrupt are
        # not swallowed; json.loads raises TypeError for non-string input.
        return json_value, False


def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
    """
    Get the type definition of a tool parameter.

    Args:
        func_name: Name of the function/tool.
        arg_key: Parameter key name.
        defined_tools: List of (unwrapped) tool definitions.

    Returns:
        str or None: Type of the parameter ('string', 'object', 'array',
        'integer', 'number', 'boolean'), or None when the tool/parameter/type
        is not declared.
    """
    name2tool = {tool["name"]: tool for tool in defined_tools}
    if func_name not in name2tool:
        return None
    tool = name2tool[func_name]
    if "parameters" not in tool or "properties" not in tool["parameters"]:
        return None
    if arg_key not in tool["parameters"]["properties"]:
        return None
    return tool["parameters"]["properties"][arg_key].get("type")


def parse_model_response(response: str, defined_tools=None):
    """
    Parse a raw model response into an OpenAI-style assistant message.

    Splits the response into an optional reasoning section, the plain-text
    content, and any structured tool calls.

    Args:
        response: Raw response text from the model.
        defined_tools: Optional list of tool definitions; each entry may be a
            bare tool schema or wrapped as {"type": "function", "function": {...}}.
            Used to decide whether argument values should be JSON-decoded.

    Returns:
        dict: Message with 'role' and 'content', plus 'reasoning_content' and
        'tool_calls' when present.

    Raises:
        AssertionError: If the response contains malformed or unbalanced tags.
    """
    # Default changed from a mutable `[]` default argument; behavior for
    # callers passing a list is unchanged.
    text = response
    reasoning_content = None
    content = None
    tool_calls = []

    # Normalize tool definitions: accept both wrapped and bare tool schemas.
    formatted_tools = []
    for tool in (defined_tools or []):
        formatted_tools.append(tool["function"] if "function" in tool else tool)

    # --- Reasoning section -------------------------------------------------
    if '</longcat_think>' in text:
        text = text.replace('<longcat_think>', '')
        thinking_end = text.find('</longcat_think>')
        reasoning_content = text[:thinking_end].strip()
        text = text[thinking_end + len('</longcat_think>'):].lstrip()
    assert '<longcat_think>' not in text, \
        "Unclosed tag found in remaining text"
    assert '</longcat_think>' not in text, \
        "Unexpected tag found without opening tag"

    # --- Split plain content from the tool-call section --------------------
    if '<longcat_tool_call>' in text:
        index = text.find('<longcat_tool_call>')
        content = text[:index]
        text = text[index:].strip()
    else:
        content = text
        text = ""

    open_tags = text.count('<longcat_tool_call>')
    close_tags = text.count('</longcat_tool_call>')
    assert open_tags == close_tags, \
        f"Mismatched tool_call tags: {open_tags} opening tags, {close_tags} closing tags"

    tool_call_strs = re.findall(
        r'<longcat_tool_call>(.*?)</longcat_tool_call>', text, re.DOTALL
    )
    for call in tool_call_strs:
        # The first run of characters up to a newline or '<' is the function name.
        func_name_match = re.match(r'([^\n<]+)', call.strip())
        assert func_name_match, f"Missing function name in tool call: {call[:100]}"
        func_name = func_name_match.group(1).strip()
        assert func_name, "Empty function name in tool call"

        # Verify argument tags are properly paired.
        arg_key_count = call.count('<arg_key>')
        arg_key_close_count = call.count('</arg_key>')
        arg_value_count = call.count('<arg_value>')
        arg_value_close_count = call.count('</arg_value>')
        assert arg_key_count == arg_key_close_count, \
            f"Mismatched arg_key tags in function {func_name}: {arg_key_count} opening, {arg_key_close_count} closing"
        assert arg_value_count == arg_value_close_count, \
            f"Mismatched arg_value tags in function {func_name}: {arg_value_count} opening, {arg_value_close_count} closing"
        assert arg_key_count == arg_value_count, \
            f"Mismatched arg_key and arg_value count in function {func_name}: {arg_key_count} keys, {arg_value_count} values"

        pairs = re.findall(
            r'<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>',
            call, re.DOTALL
        )
        assert len(pairs) == arg_key_count, \
            f"Failed to parse all arguments in function {func_name}: expected {arg_key_count}, got {len(pairs)}"

        arguments = {}
        for arg_key, arg_value in pairs:
            arg_key = arg_key.strip()
            arg_value = arg_value.strip()
            assert arg_key, f"Empty argument key in function {func_name}"
            assert arg_key not in arguments, \
                f"Duplicate argument key '{arg_key}' in function {func_name}"
            # Non-string parameters are transported as JSON text; decode them
            # when the schema declares a non-string type. Invalid JSON falls
            # back to the raw string (best-effort, mirrors parse_arguments).
            arg_type = get_argument_type(func_name, arg_key, formatted_tools)
            if arg_type and arg_type != 'string':
                arg_value, _ = parse_arguments(arg_value)
            arguments[arg_key] = arg_value

        tool_calls.append({
            'id': "tool-call-" + str(uuid.uuid4()),
            'type': "function",
            'function': {
                'name': func_name,
                'arguments': arguments
            }
        })

    message = {'role': 'assistant'}
    if reasoning_content:
        message['reasoning_content'] = reasoning_content
    message['content'] = content
    if tool_calls:
        message['tool_calls'] = tool_calls
    return message


if __name__ == "__main__":
    # Demo: run the model twice (plain chat, then tool calling) and parse the
    # raw responses with parse_model_response defined above.
    # NOTE(review): removed the original self-import
    # (`from parse_model_response import parse_model_response`) — the function
    # is defined in this file.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meituan-longcat/LongCat-Flash-Lite"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Example 1: plain chat response (no tools).
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me a brief introduction to large language models."}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    generated_ids = model.generate(inputs=input_ids, max_new_tokens=256)
    output_ids = generated_ids[0][len(input_ids[0]):].tolist()
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    print("Example 1: sample response.")
    print("\nRaw response:")
    print(response)
    print("\nParsed result:")
    parsed_message = parse_model_response(response)
    print(json.dumps(parsed_message, indent=2, ensure_ascii=False))

    # Example 2: tool-calling response.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "func_add",
                "description": "Calculate the sum of two numbers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x1": {"type": "number", "description": "The first addend"},
                        "x2": {"type": "number", "description": "The second addend"}
                    },
                    "required": ["x1", "x2"]
                }
            }
        }
    ]
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Please tell me what is $$125679 + 234519$$?"},
        # Uncomment to exercise a full tool-call round trip:
        # {
        #     "role": "assistant",
        #     "content": "I'll calculate the sum of 125679 and 234519 for you.",
        #     "tool_calls": [{"type": "function", "function": {"name": "func_add", "arguments": {"x1": 125679, "x2": 234519}}}]
        # },
        # {"role": "tool", "name": "func_add", "content": '{"ans": 360198}'}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, tools=tools, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    generated_ids = model.generate(inputs=input_ids, max_new_tokens=256)
    output_ids = generated_ids[0][len(input_ids[0]):].tolist()
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    print("Example 2: tool call response.")
    print("\nRaw response:")
    print(response)
    print("\nParsed result:")
    parsed_message = parse_model_response(response, tools)
    print(json.dumps(parsed_message, indent=2, ensure_ascii=False))