Spaces:
Build error
Build error
import json
import re
from typing import Any, Dict, List, Optional
class ToolCallExtractor:
    """Extract tool / function calls from chat-model output.

    Two input shapes are supported:

    * raw text, optionally wrapped in ``<|python_tag|> ... <|eom_id|>``
      markers, containing one or more JSON documents
      (see :meth:`extract_tool_calls`);
    * a list of response "parts" exposing a ``function_call`` attribute with
      ``name`` and ``args`` (see :meth:`extract_function_call`).
    """

    # Sentinel meaning "no usable value could be extracted" (``None`` is a
    # legitimate argument value, so it cannot play that role).
    _MISSING = object()

    # Shared decoder; ``raw_decode`` lets us parse one JSON document at a
    # given offset and learn where it ends.
    _JSON_DECODER = json.JSONDecoder()

    def __init__(self):
        # Existing regex patterns (retain if needed for other formats)
        self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)

    @staticmethod
    def _iter_arg_entries(args):
        """Yield ``(key, raw_value)`` pairs from an argument container.

        Accepts an object with a ``fields`` attribute (either a mapping or an
        iterable of entries exposing ``.key`` / ``.value``) or a mapping-like
        object with ``items()``. Anything else yields nothing.
        """
        if args is None:
            return
        container = getattr(args, 'fields', args)
        if callable(getattr(container, 'items', None)):
            yield from container.items()
        elif container is not args:
            # ``fields`` given as a sequence of key/value entries.
            for entry in container:
                yield entry.key, entry.value

    def _unwrap_sequence(self, items) -> List[Any]:
        """Unwrap every element of *items*, dropping unextractable ones."""
        unwrapped = (self._unwrap_arg_value(item) for item in items)
        return [item for item in unwrapped if item is not self._MISSING]

    def _flatten_entries(self, container) -> Dict[str, Any]:
        """Flatten a (possibly nested) argument container into a plain dict."""
        flattened = {}
        for key, raw in self._iter_arg_entries(container):
            value = self._unwrap_arg_value(raw)
            if value is not self._MISSING:
                flattened[key] = value
        return flattened

    def _unwrap_arg_value(self, value):
        """Convert one raw argument value into a plain Python value.

        Returns ``_MISSING`` when nothing usable can be extracted.
        """
        # A protobuf ``Value`` exposes *all* of string_value / number_value /
        # bool_value at once, so ``hasattr`` cannot tell which one is set;
        # ask the message which member of the ``kind`` oneof is populated.
        which_oneof = getattr(value, 'WhichOneof', None)
        if callable(which_oneof):
            try:
                kind = which_oneof('kind')
            except ValueError:
                kind = None  # message without a ``kind`` oneof
            if kind == 'null_value':
                return None
            if kind == 'struct_value':
                return self._flatten_entries(value.struct_value)
            if kind == 'list_value':
                return self._unwrap_sequence(value.list_value.values)
            if kind is not None:
                return getattr(value, kind)

        # Duck-typed wrappers (original behaviour, same precedence).
        if hasattr(value, 'string_value'):
            return value.string_value
        if hasattr(value, 'number_value'):
            return value.number_value
        if hasattr(value, 'bool_value') and value.bool_value is not None:
            return value.bool_value

        # Already-native values, e.g. when ``args`` is a dict-like mapping.
        if value is None or isinstance(value, (str, bool, int, float)):
            return value
        if callable(getattr(value, 'items', None)):
            return self._flatten_entries(value)
        if isinstance(value, (list, tuple)):
            return self._unwrap_sequence(value)
        return self._MISSING

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """
        Flatten the nested function args structure for Google AI protobuf types.

        Args:
            args: The ``args`` of a function call; see
                :meth:`_iter_arg_entries` for the accepted shapes.

        Returns:
            dict: Argument names mapped to plain Python values. Entries whose
            value cannot be extracted are skipped. On an unexpected error the
            entries gathered so far are returned (best effort).
        """
        flattened_args = {}
        try:
            for key, value in self._iter_arg_entries(args):
                # Additional debugging
                print(f"Field key: {key}")
                print(f"Field value type: {type(value)}")
                print(f"Field value: {value}")
                extracted = self._unwrap_arg_value(value)
                if extracted is self._MISSING:
                    continue
                flattened_args[key] = extracted
                if isinstance(extracted, str):
                    print(f"Extracted string value: {extracted}")
            # Added additional debug information
            print(f"Final flattened args: {flattened_args}")
        except Exception as e:
            # Deliberate best-effort: malformed args must not break the caller.
            print(f"Error extracting function args: {e}")
        return flattened_args

    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """
        Extract tool calls from input string, handling various inconsistent formats.

        Extraction is attempted in order and stops at the first stage whose
        pattern matches (even if that stage yields no parsable JSON):

        1. text between ``<|python_tag|>`` and ``<|eom_id|>``;
        2. text preceding ``<|eom_id|>``;
        3. the whole string.

        Args:
            input_string (str): The input string containing tool calls.

        Returns:
            list: A list of dictionaries representing the parsed tool calls.
        """
        tool_calls = []
        # Existing tag-based extraction (retain if needed)
        for pattern in (self.complete_pattern, self.partial_pattern):
            matches = pattern.findall(input_string)
            if matches:
                for match in matches:
                    tool_calls.extend(self._extract_json_objects(match))
                return tool_calls
        # Fallback: Attempt to parse the entire string
        tool_calls.extend(self._extract_json_objects(input_string))
        return tool_calls

    def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract and parse multiple JSON objects from a string.

        Documents may be separated by ``;`` and/or whitespace. The text is
        scanned with a JSON decoder rather than split on ``;`` so that
        semicolons inside JSON string values do not break a document apart.
        A segment that does not start with ``{`` or ``[``, or that fails to
        parse, is skipped up to the next ``;``. Empty objects/arrays are
        dropped; a top-level array is appended as a single element.
        """
        json_objects = []
        pos, end = 0, len(text)
        while pos < end:
            char = text[pos]
            if char == ';' or char.isspace():
                pos += 1
                continue
            if char in '{[':
                try:
                    parsed_obj, pos = self._JSON_DECODER.raw_decode(text, pos)
                except json.JSONDecodeError:
                    pass  # fall through and skip this segment
                else:
                    if parsed_obj:
                        json_objects.append(parsed_obj)
                    continue
            # Unparsable segment: resume after the next separator, if any.
            separator = text.find(';', pos)
            if separator == -1:
                break
            pos = separator + 1
        return json_objects

    def _clean_and_parse_json(self, json_str: str) -> Optional[Dict[str, Any]]:
        """
        Clean and parse a JSON string, handling common formatting issues.

        Returns the parsed value when the stripped string starts with ``{`` or
        ``[`` and is valid JSON; otherwise ``None``.
        """
        json_str = json_str.strip()
        if not json_str.startswith(('{', '[')):
            return None
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Validate if a tool call has the required fields.

        A valid tool call is a dict with a string-valued ``'name'`` key.
        """
        return (
            isinstance(tool_call, dict) and
            'name' in tool_call and
            isinstance(tool_call['name'], str)
        )

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """
        Extract function call details from the response parts.

        Args:
            response_parts (list): The list of response parts from the chat model.

        Returns:
            dict: ``{"name": ..., "args": ...}`` for the first part carrying a
            named function call, with flattened arguments; ``{}`` when no part
            qualifies.
        """
        for part in response_parts or ():
            # Debug print
            print(f"Examining part: {part}")
            print(f"Part type: {type(part)}")
            # Check for function_call attribute
            function_call = getattr(part, 'function_call', None)
            if not function_call:
                continue
            # ``args`` may be absent on malformed calls; don't crash on it.
            raw_args = getattr(function_call, 'args', None)
            # Debug print
            print(f"Function call: {function_call}")
            print(f"Function call type: {type(function_call)}")
            print(f"Function args: {raw_args}")
            # Extract function name
            function_name = getattr(function_call, 'name', None)
            if not function_name:
                continue  # Skip if function name is missing
            # Extract function arguments
            return {
                "name": function_name,
                "args": self._extract_function_args(raw_args)
            }
        return {}