| | import ast |
| | import os |
| | import re |
| | import json |
| | import logging |
| | import datetime |
| | import xml.etree.ElementTree as ET |
| |
|
| | from logging.handlers import RotatingFileHandler |
| |
|
# Record layout shared by the root (console) logger and the per-run log file,
# so both outputs stay identical.
_LOG_FORMAT = "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)

# Directory containing this module; prompt assets and chat templates are
# resolved relative to it by the helpers below.
script_dir = os.path.dirname(os.path.abspath(__file__))
now = datetime.datetime.now()

# One log file per process start, named after the start timestamp.
log_folder = os.path.join(script_dir, "inference_logs")
os.makedirs(log_folder, exist_ok=True)
log_file_path = os.path.join(
    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
)

# maxBytes=0 / backupCount=0 means the file is never rolled over.
# The encoding is pinned to UTF-8 so that non-ASCII text in logged messages
# does not depend on the platform's default locale encoding.
file_handler = RotatingFileHandler(
    log_file_path, maxBytes=0, backupCount=0, encoding="utf-8"
)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
file_handler.setFormatter(formatter)

# Named logger used throughout this module; it writes to the per-run file in
# addition to whatever handlers the root logger has.
inference_logger = logging.getLogger("function-calling-inference")
inference_logger.addHandler(file_handler)
| |
|
def get_fewshot_examples(num_fewshot):
    """Return the first ``num_fewshot`` few-shot examples.

    The examples are loaded from ``prompt_assets/few_shot.json`` located next
    to this module.

    Args:
        num_fewshot: Number of examples to return; must be between 0 and the
            number of examples stored in the file.

    Returns:
        The leading ``num_fewshot`` entries of the loaded JSON document.

    Raises:
        ValueError: If ``num_fewshot`` is negative or exceeds the number of
            available examples.
    """
    # A negative count would silently slice from the end of the list
    # (e.g. ``examples[:-1]``) and return the wrong number of examples.
    if num_fewshot < 0:
        raise ValueError(f"num_fewshot must be non-negative (got {num_fewshot}).")

    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    # JSON text is UTF-8; do not rely on the platform default encoding.
    with open(example_path, 'r', encoding='utf-8') as file:
        examples = json.load(file)

    if num_fewshot > len(examples):
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
    return examples[:num_fewshot]
| |
|
def get_chat_template(chat_template):
    """Read a chat template from its Jinja file.

    The file is looked up as ``chat_templates/<chat_template>.j2`` next to
    this module.

    Args:
        chat_template: Base name of the template file, without the ``.j2``
            extension.

    Returns:
        The template text, or ``None`` if the file does not exist or cannot
        be read. Failures are reported through ``inference_logger``.
    """
    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")

    # Open directly and react to the failure instead of checking for
    # existence first; this avoids a race between the check and the open.
    try:
        with open(template_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        inference_logger.error(f"Template file not found: {chat_template}")
    except (OSError, UnicodeDecodeError) as e:
        inference_logger.error(f"Error loading template: {e}")
    return None
| |
|
def get_assistant_message(completion, chat_template, eos_token):
    """Extract the last assistant turn from a raw completion.

    Args:
        completion: Full generated text, including prompt turns.
        chat_template: One of ``"zephyr"``, ``"chatml"`` or ``"vicuna"``;
            selects the marker that introduces an assistant turn.
        eos_token: End-of-sequence marker to remove from the extracted turn.
            For ``"vicuna"`` the string removed is ``"</s>"`` followed by
            ``eos_token``.

    Returns:
        The assistant text with the marker removed, or ``None`` when no
        assistant turn is found (an info message is logged in that case).

    Raises:
        NotImplementedError: If ``chat_template`` is not a supported name.
    """
    text = completion.strip()

    # Every pattern is anchored to the end of the text and refuses to cross
    # another assistant marker, so only the final assistant turn is captured.
    turn_patterns = {
        "zephyr": r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$',
        "chatml": r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$',
        "vicuna": r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$',
    }
    if chat_template not in turn_patterns:
        raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")

    found = re.search(turn_patterns[chat_template], text, re.DOTALL)
    if found is None:
        inference_logger.info("No match found for the assistant pattern")
        return None

    turn_text = found.group(1).strip()
    marker = f"</s>{eos_token}" if chat_template == "vicuna" else eos_token
    return turn_text.replace(marker, "")
| |
|
def validate_and_extract_tool_calls(assistant_content):
    """Extract the payloads of ``<tool_call>`` elements from assistant output.

    The content is wrapped in a synthetic ``<root>`` element and parsed as
    XML. The text of every ``<tool_call>`` descendant is decoded with
    ``json.loads``; if that fails, ``ast.literal_eval`` is tried as a fallback
    (e.g. for single-quoted Python-style dicts).

    Args:
        assistant_content: Assistant message text that may contain
            ``<tool_call>...</tool_call>`` elements.

    Returns:
        A tuple ``(validation_result, tool_calls, error_message)``:
        ``validation_result`` is ``True`` when at least one payload was
        decoded, ``tool_calls`` is the list of decoded payloads in document
        order, and ``error_message`` describes the most recent failure
        (``None`` if nothing failed). A failure in one element does not
        prevent other elements from being extracted.
    """
    validation_result = False
    tool_calls = []
    error_message = None

    try:
        # The synthetic root lets free text and several sibling tool calls
        # form one parseable document.
        root = ET.fromstring(f"<root>{assistant_content}</root>")
    except ET.ParseError as err:
        error_message = f"XML Parse Error: {err}"
        inference_logger.error(error_message)
        return validation_result, tool_calls, error_message

    for element in root.findall(".//tool_call"):
        try:
            json_text = element.text.strip()
        except AttributeError as e:
            # element.text is None when the element has no leading text.
            error_message = f"Cannot strip text: {e}"
            inference_logger.error(error_message)
            continue

        try:
            json_data = json.loads(json_text)
        except (json.JSONDecodeError, RecursionError) as json_err:
            try:
                # literal_eval only evaluates Python literals, but it can
                # still exhaust memory or the stack on hostile input, hence
                # the wide (documented) exception list.
                json_data = ast.literal_eval(json_text)
            except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as eval_err:
                error_message = "JSON parsing failed with both json.loads and ast.literal_eval:\n"\
                                f"- JSON Decode Error: {json_err}\n"\
                                f"- Fallback Syntax/Value Error: {eval_err}\n"\
                                f"- Problematic JSON text: {json_text}"
                inference_logger.error(error_message)
                continue

        # A payload that decodes to None (JSON ``null``) is not a tool call.
        if json_data is not None:
            tool_calls.append(json_data)
            validation_result = True

    return validation_result, tool_calls, error_message
| |
|
def extract_json_from_markdown(text):
    """
    Decode the first fenced ```json block found in the given text.

    Args:
        text (str): Text that may contain a Markdown code fence tagged ``json``.

    Returns:
        The decoded JSON value, or None when no such fence exists or its
        content is not valid JSON. A short notice is printed in both
        failure cases.
    """
    fenced = re.search(r'```json\r?\n(.*?)\r?\n```', text, re.DOTALL)
    if fenced is None:
        print("JSON string not found in the text.")
        return None

    try:
        return json.loads(fenced.group(1))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON string: {e}")
        return None
| |
|