|
|
import ast |
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import logging |
|
|
import datetime |
|
|
import xml.etree.ElementTree as ET |
|
|
from logger import logger |
|
|
|
|
|
from logging.handlers import RotatingFileHandler |
|
|
|
|
|
# Single definition of the log record layout, shared by the root configuration
# and the file handler below so the two outputs cannot drift apart.
_LOG_FORMAT = "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
_LOG_DATEFMT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    level=logging.INFO,
)

# Directory containing this module; asset paths below are resolved against it.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Captured once at import time; it names this process's log file.
now = datetime.datetime.now()

log_folder = os.path.join(script_dir, "inference_logs")
os.makedirs(log_folder, exist_ok=True)

log_file_path = os.path.join(
    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
)

# maxBytes=0 disables rollover, so this handler never rotates: one log file per
# import, growing without bound. The file is opened (created) immediately
# because delay defaults to False.
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATEFMT)
file_handler.setFormatter(formatter)
# NOTE(review): file_handler is not attached to any logger in this module;
# records reach the file only if some other module calls addHandler() with it.
|
|
|
|
|
def get_fewshot_examples(num_fewshot, example_path=None):
    """Return the first ``num_fewshot`` few-shot examples.

    Args:
        num_fewshot (int): How many examples to return. Must be between 0 and
            the number of examples stored in the file (inclusive).
        example_path (str, optional): JSON file to read. Defaults to
            ``prompt_assets/few_shot.json`` next to this module.

    Returns:
        list: The leading ``num_fewshot`` entries of the file's top-level value
        (assumed to be a JSON array).

    Raises:
        ValueError: If ``num_fewshot`` is negative or larger than the number
            of available examples.
    """
    # A negative count would otherwise slice from the end and silently return
    # "all but the last N" examples.
    if num_fewshot < 0:
        raise ValueError(f"num_fewshot must be non-negative (got {num_fewshot}).")

    if example_path is None:
        example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(example_path, 'r', encoding='utf-8') as file:
        examples = json.load(file)

    if num_fewshot > len(examples):
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")

    return examples[:num_fewshot]
|
|
|
|
|
def get_chat_template(chat_template):
    """Read a Jinja chat template from ``chat_templates/<chat_template>.j2``.

    The directory is resolved relative to this module (``script_dir``).

    Args:
        chat_template (str): Template name, without the ``.j2`` extension.

    Returns:
        str | None: The template source text, or ``None`` if the file does not
        exist or cannot be read. Either failure is logged via ``logger.error``.
    """
    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")

    # EAFP: open directly instead of checking existence first, which avoids a
    # check-then-open race and a second filesystem call.
    try:
        with open(template_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        logger.error(f"Template file not found: {chat_template}")
        return None
    except (OSError, UnicodeDecodeError) as e:
        # Permission problems, a directory at that path, or undecodable bytes;
        # reported through the logger like the not-found case.
        logger.error(f"Error loading template: {e}")
        return None
|
|
|
|
|
def _parse_tool_call_element(element):
    """Decode the text payload of a single ``<tool_call>`` element.

    Strict JSON is tried first. If that fails, ``ast.literal_eval`` is tried so
    that payloads written as Python literals (single-quoted strings,
    ``True``/``None``) are still accepted.

    Args:
        element (xml.etree.ElementTree.Element): A ``<tool_call>`` element.

    Returns:
        tuple: ``(data, error_message)``. On success ``error_message`` is
        ``None`` and ``data`` is the decoded value. On failure ``data`` is
        ``None`` and ``error_message`` describes the problem.
    """
    if element.text is None:
        # e.g. ``<tool_call/>``, or content that begins with a child element.
        return None, "Cannot strip text: <tool_call> element has no text content"

    json_text = element.text.strip()
    try:
        return json.loads(json_text), None
    except (json.JSONDecodeError, RecursionError) as json_err:
        try:
            # literal_eval does not execute code, but it can raise any of these
            # on malformed input. Examples: TypeError for an unhashable dict
            # key; MemoryError/RecursionError for huge or deeply nested input.
            return ast.literal_eval(json_text), None
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as eval_err:
            return None, (
                "JSON parsing failed with both json.loads and ast.literal_eval:\n"
                f"- JSON Decode Error: {json_err}\n"
                f"- Fallback Syntax/Value Error: {eval_err}\n"
                f"- Problematic JSON text: {json_text}"
            )


def validate_and_extract_tool_calls(assistant_content):
    """Extract and decode every ``<tool_call>`` payload in an assistant message.

    The content is wrapped in a synthetic ``<root>`` element so it can be
    parsed as XML. Each ``<tool_call>`` element, at any depth, is then decoded
    independently: a bad payload is logged and skipped without stopping the
    others.

    Args:
        assistant_content (str): Raw assistant message text.

    Returns:
        tuple: ``(validation_result, tool_calls, error_message)``.

        - ``validation_result`` is True if at least one payload decoded to a
          non-None value.
        - ``tool_calls`` lists those decoded payloads in document order.
        - ``error_message`` is the most recent error encountered, or ``None``.
    """
    validation_result = False
    tool_calls = []
    error_message = None

    try:
        # NOTE(review): xml.etree is not hardened against maliciously crafted
        # XML (see the stdlib docs). assistant_content is model output, so
        # confirm that trust level is acceptable. Any unescaped '<' or '&' in
        # a payload also makes the whole message fail to parse.
        root = ET.fromstring(f"<root>{assistant_content}</root>")
    except ET.ParseError as err:
        error_message = f"XML Parse Error: {err}"
        logger.error(error_message)
        return validation_result, tool_calls, error_message

    for element in root.findall(".//tool_call"):
        json_data, parse_error = _parse_tool_call_element(element)
        if parse_error is not None:
            error_message = parse_error
            logger.error(error_message)
            continue
        # A payload of literal null/None decodes successfully but is not
        # treated as a tool call.
        if json_data is not None:
            tool_calls.append(json_data)
            validation_result = True

    return validation_result, tool_calls, error_message
|
|
|
|
|
# First ```json fenced block; non-greedy with DOTALL so the body may span lines.
_JSON_FENCE_PATTERN = re.compile(r'```json\r?\n(.*?)\r?\n```', re.DOTALL)


def extract_json_from_markdown(text):
    """
    Extract and decode the first ```json fenced code block in ``text``.

    Args:
        text (str): The input text containing the fenced JSON block.

    Returns:
        The decoded JSON value (a dict for a JSON object, but any JSON type is
        possible). Returns None if no fenced block is found or its contents
        are not valid JSON; both cases are logged.
    """
    match = _JSON_FENCE_PATTERN.search(text)
    if not match:
        logger.warning("JSON string not found in the text.")
        return None

    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON string: {e}")
        return None
|
|
|