import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-6s | %(name)-40s || %(message)s', datefmt='%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

import os
import sys
from pathlib import Path
import yaml
import json
from typing import Optional
from jinja2 import Template

# Use env-var based OpenAI client instead of Azure-specific endpoint
_openai_client = None


def _get_openai_client():
    """Return a lazily created, module-cached OpenAI client, or None if unavailable.

    The client is configured from the ``OPENAI_API_KEY`` and optional
    ``OPENAI_BASE_URL`` environment variables. ``None`` is returned (and a
    warning logged) when the ``openai`` package is not installed or no API key
    is set. A failed lookup is not cached, so a later call re-checks.
    """
    global _openai_client
    if _openai_client is not None:
        return _openai_client

    try:
        import openai
    except ImportError:
        logger.warning("openai package not installed. LLM-based error token localization will not be available.")
        return None

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        logger.warning("OPENAI_API_KEY not set. LLM-based error token localization will not be available.")
        return None

    base_url = os.environ.get("OPENAI_BASE_URL", None)
    kwargs = {"api_key": api_key}
    if base_url:
        kwargs["base_url"] = base_url
    _openai_client = openai.OpenAI(**kwargs)
    return _openai_client


def _get_default_models():
    """Return the validator model names from the comma-separated ``LLM_MODELS`` env var.

    Defaults to ``["gpt-4o-mini"]``. Surrounding whitespace and empty entries are dropped.
    """
    models_str = os.environ.get("LLM_MODELS", "gpt-4o-mini")
    return [m.strip() for m in models_str.split(",") if m.strip()]


class ErrorTokenLocator:
    """Locate the first erroneous token of a completion with an ensemble of LLM validators.

    Each validator model receives the prompt, the completion, a token-indexed
    rendering of the completion and (optionally) the ground truth. It is asked
    (``response_format`` JSON) for an object with ``token_index`` and
    ``explanation``. The most-voted token index wins.
    """

    def __init__(self, model, tokenizer, prompt_template_yaml=None):
        """
        Initialize the error token locator

        Args:
            model: The language model to use (stored as ``self.model``)
            tokenizer: The corresponding tokenizer for tokenizing text; it must be
                callable returning an object with ``input_ids`` and expose ``decode``
            prompt_template_yaml (str, optional): Path to the prompt template YAML file,
                uses default template when None
        """
        self.model = model
        self.tokenizer = tokenizer
        self.client = _get_openai_client()
        self.endpoint_list = _get_default_models()
        # load_general_prompt_template already falls back to the bundled template for None.
        self.system_prompt = self.load_general_prompt_template(prompt_template_yaml)

    def prompt_constructor(self, query, completion, ground_truth: Optional[str] = None):
        """
        Construct prompts for error token location

        Args:
            query (str): The input query/prompt
            completion (str): The completion text generated by the model
            ground_truth (str, optional): The correct answer/ground truth, defaults to None

        Returns:
            tuple: (msg, tokens)
                - msg (list): The conversation message list (system + user message; the user
                  message is the JSON-encoded prompt/completion/indexed_completion/ground_truth)
                - tokens (list): List of token IDs for the completion
        """
        tokens = self.tokenizer(completion).input_ids
        # Each token is rendered as "<decoded token>[<index>] " so validators can answer
        # with a position into `tokens`.
        indexed_completion = "".join(
            f"{self.tokenizer.decode([tok])}[{i}] " for i, tok in enumerate(tokens)
        )

        user_msg_content = {
            "prompt": query,
            "completion": completion,
            "indexed_completion": indexed_completion,
            "ground_truth": ground_truth,
        }
        msg = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": json.dumps(user_msg_content, indent=2)},
        ]
        return msg, tokens

    def load_general_prompt_template(self, prompt_template_yaml=None):
        """
        Load the general prompt template

        Args:
            prompt_template_yaml (str, optional): Path to a YAML file whose ``system`` key holds a
                Jinja2 template; uses ``token_locator_prompts/err_token_localization.yaml`` next to
                this file when None

        Returns:
            str: The rendered system prompt template string
        """
        if prompt_template_yaml is None:
            # Use path relative to this file's location
            prompt_template_yaml = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "token_locator_prompts",
                "err_token_localization.yaml"
            )
        with open(prompt_template_yaml, "r", encoding='utf-8') as f:
            system_prompt_temp = yaml.safe_load(f)
        system_prompt = Template(system_prompt_temp['system']).render(
            dataset_description="No dataset description provided.",
            dataset_specific_instructions="-No dataset specific instructions provided."
        )
        return system_prompt

    def call_validator(self, msg, tokens, endpoint_list=None):
        """
        Query every validator model and pick the error token by majority vote

        Args:
            msg (list): The constructed conversation message list
            tokens (list): List of token IDs of the completion
            endpoint_list (list, optional): Validator model names; defaults to ``self.endpoint_list``

        Returns:
            tuple: (completion_before_err, explanation, vote_details)
                - completion_before_err (str): The completion decoded from the tokens before the
                  winning index (the whole completion when the index is negative or past the end)
                - explanation (str): Explanation from the first validator that voted for the winner
                - vote_details (dict): ``{"validators": {model: vote}, "summary": {...}}``

        Raises:
            RuntimeError: If no OpenAI client is configured, or no validator produced a usable vote.
        """
        if self.client is None:
            raise RuntimeError(
                "OpenAI client not available. Please set OPENAI_API_KEY environment variable "
                "or use manual truncation instead."
            )

        if endpoint_list is None:
            endpoint_list = self.endpoint_list

        votes = {}                 # token_index -> number of validators that chose it
        first_vote_for_token = {}  # token_index -> (position in endpoint_list, explanation) of first vote
        validator_votes = {}       # model name -> details of that validator's vote

        for position, model_name in enumerate(endpoint_list):
            try:
                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=msg,
                    temperature=0,
                    seed=42,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    response_format={"type": "json_object"},
                )
            except Exception as e:
                # One unreachable/misconfigured validator should not sink the whole ensemble:
                # record the failure and let the remaining validators vote.
                logger.error(f"Error querying validator {model_name}: {e}")
                validator_votes[model_name] = {
                    "token_index": -1,
                    "error_token": "Error",
                    "explanation": f"Failed to query validator: {e}",
                }
                continue

            try:
                res_json = json.loads(response.choices[0].message.content)
                # int() normalises answers such as "12"; the winner index is reused as-is below.
                token_index = int(res_json["token_index"])
                explanation = res_json.get("explanation", "")
                # Guard negatives as well: Python would otherwise silently index from the end.
                if 0 <= token_index < len(tokens):
                    error_token = self.tokenizer.decode([tokens[token_index]])
                else:
                    error_token = "N/A"
            except Exception as e:
                logger.error(f"Error processing response from {model_name}: {e}")
                validator_votes[model_name] = {
                    "token_index": -1,
                    "error_token": "Error",
                    "explanation": f"Failed to parse response: {str(e)}"
                }
                continue

            votes[token_index] = votes.get(token_index, 0) + 1
            first_vote_for_token.setdefault(token_index, (position, explanation))
            validator_votes[model_name] = {
                "token_index": token_index,
                "error_token": error_token,
                "explanation": explanation,
            }

        if not votes:
            failures = "; ".join(
                f"{name}: {details['explanation']}" for name, details in validator_votes.items()
            ) or "no validators configured"
            raise RuntimeError(f"No validator returned a usable token index ({failures})")

        max_votes = max(votes.values())
        # Ties go to the index first proposed by the earliest validator in endpoint_list.
        winner_token = min(
            (t for t, c in votes.items() if c == max_votes),
            key=lambda t: first_vote_for_token[t][0],
        )
        explanation = first_vote_for_token[winner_token][1]

        # A negative index is not a real position, so the whole completion is kept;
        # slicing past the end likewise keeps everything.
        cut = winner_token if winner_token >= 0 else len(tokens)
        prefix_tokens = tokens[:cut]
        # Decode the prefix as one sequence rather than token-by-token: per-token decoding can
        # drop word-boundary spaces (SentencePiece) or split multi-byte characters (byte-level BPE).
        completion_before_err = self.tokenizer.decode(prefix_tokens) if len(prefix_tokens) > 0 else ""

        # Add vote summary to vote_details
        vote_summary = {
            "winner_token_index": winner_token,
            "winner_votes": max_votes,
            "total_validators": len(endpoint_list),
            "vote_distribution": votes
        }

        return completion_before_err, explanation, {
            "validators": validator_votes,
            "summary": vote_summary
        }

    def locate_error_token(self, prompt: str, completion: str, ground_truth: Optional[str] = None,
                           validators: Optional[list] = None, use_llm: bool = True,
                           manual_chunks: Optional[list] = None):
        """
        Main method to locate the error token in a completion

        Args:
            prompt (str): The input prompt
            completion (str): The completion text to analyze
            ground_truth (str, optional): The correct answer/ground truth, defaults to None
            validators (list, optional): Validator model names for this call; a non-empty list
                overrides ``self.endpoint_list``
            use_llm (bool): When False, skip the LLM search and use ``manual_chunks[0]``
            manual_chunks (list, optional): Pre-truncated texts (prompt + completion prefix);
                only the first one is used, and only when ``use_llm`` is False

        Returns:
            dict: Dictionary containing:
                - status (str): "success" or "error"
                - truncated_text (str): Prompt + completion truncated before error token
                - explanation (str): Explanation of the error
                - error_token_index (int): Token count of the truncated text (success only)
                - vote_details (dict): Per-validator votes and tally (LLM success only)
                - message (str): Error description (error only)
        """
        try:
            # If user requests to skip LLM search, use manual chunks if provided
            if not use_llm:
                if not manual_chunks:
                    return {
                        "status": "error",
                        "message": "LLM search disabled but no manual chunk provided.",
                        "truncated_text": "",
                        "explanation": ""
                    }
                # manual_chunk already contains prompt + completion (set by frontend),
                # so use it directly as truncated_text without prepending prompt again
                manual_chunk = manual_chunks[0]
                return {
                    "status": "success",
                    "truncated_text": manual_chunk,
                    "explanation": "Manual chunk provided by user (LLM search skipped).",
                    # NOTE(review): this counts prompt tokens too, whereas the LLM path counts only
                    # completion tokens — confirm which convention the frontend expects.
                    "error_token_index": len(self.tokenizer(manual_chunk).input_ids),
                }

            # Check if OpenAI client is available
            if self.client is None:
                return {
                    "status": "error",
                    "message": "OpenAI API key not configured. Please set OPENAI_API_KEY environment variable or use manual truncation (disable LLM search).",
                    "truncated_text": "",
                    "explanation": ""
                }

            # Construct prompt messages
            msg, tokens = self.prompt_constructor(prompt, completion, ground_truth)

            # If validators provided, use them for this call
            endpoint_list = validators if isinstance(validators, list) and validators else None

            # Call validator to get error token location with vote details
            completion_before_err, explanation, vote_details = self.call_validator(msg, tokens, endpoint_list)

            return {
                "status": "success",
                # Combine prompt with truncated completion
                "truncated_text": prompt + completion_before_err,
                "explanation": explanation,
                "error_token_index": len(self.tokenizer(completion_before_err).input_ids),
                "vote_details": vote_details
            }

        except Exception as e:
            # Top-level boundary: log with traceback and report the failure to the caller.
            logger.exception(f"Error in locate_error_token: {e}")
            return {
                "status": "error",
                "message": str(e),
                "truncated_text": "",
                "explanation": ""
            }