| import logging |
# Configure root logging at import time (per the logging docs, basicConfig does
# nothing if the root logger already has handlers), then create this module's logger.
logging.basicConfig(
    format='%(asctime)s | %(levelname)-6s | %(name)-40s || %(message)s',
    datefmt='%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-level logger used by the code below
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| import yaml |
| import json |
| from typing import Optional |
| from jinja2 import Template |
|
|
| |
_openai_client = None  # module-level cache; set only after a client is successfully built


def _get_openai_client():
    """Return the cached OpenAI client, building it on first successful use.

    The client is configured from the ``OPENAI_API_KEY`` environment variable
    and, when it is set to a non-empty value, ``OPENAI_BASE_URL``.

    Returns:
        The ``openai.OpenAI`` instance, or None (after logging a warning) when
        the ``openai`` package cannot be imported or no API key is set. A None
        result is not cached, so a later call tries again.
    """
    global _openai_client
    if _openai_client is None:
        try:
            import openai
        except ImportError:
            logger.warning("openai package not installed. LLM-based error token localization will not be available.")
            return None

        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            logger.warning("OPENAI_API_KEY not set. LLM-based error token localization will not be available.")
            return None

        client_kwargs = dict(api_key=api_key)
        base_url = os.environ.get("OPENAI_BASE_URL")
        if base_url:  # empty string means "use the SDK default endpoint"
            client_kwargs["base_url"] = base_url
        _openai_client = openai.OpenAI(**client_kwargs)
    return _openai_client
|
|
|
|
def _get_default_models():
    """Return the validator model names to use, read from ``LLM_MODELS``.

    ``LLM_MODELS`` is a comma-separated list of model names, e.g.
    ``"gpt-4o-mini, gpt-4o"``. Whitespace around names and empty entries are
    ignored.

    Returns:
        list[str]: The configured model names. Falls back to
        ``["gpt-4o-mini"]`` when the variable is unset OR contains no model
        names (e.g. ``""`` or ``" , "``). An empty validator list would
        otherwise make every LLM-based localization fail.
    """
    default_models = ["gpt-4o-mini"]
    models_str = os.environ.get("LLM_MODELS", "")
    models = [name.strip() for name in models_str.split(",") if name.strip()]
    return models or default_models
|
|
|
|
class ErrorTokenLocator:
    """Find the first erroneous token in a model completion using LLM validators.

    The completion is tokenized with the supplied tokenizer. Every token is shown,
    together with its index, to one or more chat-completion models
    ("validators"). Each validator's JSON reply is expected to contain
    ``token_index`` and optionally ``explanation``. The index with the most
    votes wins; ties go to the validator that appears first in the list.
    """

    def __init__(self, model, tokenizer, prompt_template_yaml=None):
        """
        Initialize the error token locator.

        Args:
            model: The language model (stored as ``self.model``; not referenced
                by this class's methods).
            tokenizer: Tokenizer used to split completions. It must be callable
                as ``tokenizer(text).input_ids`` and provide
                ``decode(list_of_ids) -> str``.
            prompt_template_yaml (str, optional): Path to the prompt template YAML
                file; the bundled default template is used when None.
        """
        self.model = model
        self.tokenizer = tokenizer

        # None when the openai package or OPENAI_API_KEY is unavailable;
        # call_validator and locate_error_token check for this.
        self.client = _get_openai_client()

        # Default validator model names (from the LLM_MODELS env var).
        self.endpoint_list = _get_default_models()

        # load_general_prompt_template substitutes the default path for None,
        # so the argument can be passed straight through.
        self.system_prompt = self.load_general_prompt_template(prompt_template_yaml)

    def prompt_constructor(self, query, completion, ground_truth: Optional[str] = None):
        """
        Construct the validator chat messages for one completion.

        Args:
            query (str): The input query/prompt.
            completion (str): The completion text generated by the model.
            ground_truth (str, optional): The correct answer, if known.

        Returns:
            tuple: (msg, tokens)
                - msg (list): ``[system message, user message]``. The user content
                  is a JSON object with keys ``prompt``, ``completion``,
                  ``indexed_completion`` and ``ground_truth``.
                - tokens (list): Token IDs of ``completion``.
        """
        tokens = self.tokenizer(completion).input_ids
        # NOTE(review): depending on the tokenizer, tokenizer(...) may prepend
        # special tokens (e.g. BOS), which would then occupy index 0 — confirm
        # this is intended.

        # Each token rendered as "<decoded token>[<index>] " so validators can
        # answer with an index.
        indexed_completion = "".join(
            f"{self.tokenizer.decode([tok])}[{i}] " for i, tok in enumerate(tokens)
        )

        user_msg_content = {
            "prompt": query,
            "completion": completion,
            "indexed_completion": indexed_completion,
            "ground_truth": ground_truth
        }

        msg = [
            {"role": "system", "content": self.system_prompt},
            # ensure_ascii=False keeps non-ASCII text readable for the validator
            # instead of sending it as \uXXXX escapes.
            {"role": "user", "content": json.dumps(user_msg_content, indent=2, ensure_ascii=False)}
        ]
        return msg, tokens

    def load_general_prompt_template(self, prompt_template_yaml=None):
        """
        Load and render the system prompt from a YAML template file.

        The YAML file must contain a ``system`` key holding a Jinja2 template.
        It is rendered with placeholder ``dataset_description`` and
        ``dataset_specific_instructions`` values.

        Args:
            prompt_template_yaml (str, optional): Path to the YAML template. When
                None, ``token_locator_prompts/err_token_localization.yaml`` next
                to this file is used.

        Returns:
            str: The rendered system prompt.
        """
        if prompt_template_yaml is None:
            prompt_template_yaml = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "token_locator_prompts",
                "err_token_localization.yaml"
            )

        with open(prompt_template_yaml, "r", encoding='utf-8') as f:
            system_prompt_temp = yaml.safe_load(f)
        return Template(system_prompt_temp['system']).render(
            dataset_description="No dataset description provided.",
            dataset_specific_instructions="-No dataset specific instructions provided."
        )

    @staticmethod
    def _failed_vote(reason):
        """Return the per-validator record used when a validator produced no vote."""
        return {
            "token_index": -1,
            "error_token": "Error",
            "explanation": reason
        }

    def _decode_prefix(self, tokens, end):
        """Decode ``tokens[:end]`` as one string (all tokens if ``end`` is out of range)."""
        # Out-of-range indices (negative or > len) mean "no truncation", matching
        # the original behavior of never finding the index while scanning.
        if not 0 <= end <= len(tokens):
            end = len(tokens)
        # One decode call: decoding token-by-token and concatenating can drop
        # word-boundary spaces or split multi-byte characters.
        return self.tokenizer.decode(list(tokens[:end]))

    def call_validator(self, msg, tokens, endpoint_list=None):
        """
        Ask each validator model for the error token and take a majority vote.

        Args:
            msg (list): Chat messages built by ``prompt_constructor``.
            tokens (list): Token IDs of the completion.
            endpoint_list (list, optional): Validator model names; defaults to
                ``self.endpoint_list``.

        Returns:
            tuple: (completion_before_err, explanation, vote_details)
                - completion_before_err (str): The completion decoded up to (not
                  including) the winning token index; the whole completion if
                  that index is outside ``[0, len(tokens)]``.
                - explanation (str): The winning validator's explanation ("" if
                  none was given).
                - vote_details (dict): ``{"validators": {...}, "summary": {...}}``
                  with each validator's vote and the overall tally.

        Raises:
            RuntimeError: If no OpenAI client is configured, or if no validator
                returned a usable ``token_index``.
        """
        if self.client is None:
            raise RuntimeError(
                "OpenAI client not available. Please set OPENAI_API_KEY environment variable "
                "or use manual truncation instead."
            )

        if endpoint_list is None:
            endpoint_list = self.endpoint_list

        votes = {}
        # token index -> (position of the first validator voting for it, its explanation)
        first_vote = {}
        validator_votes = {}

        for position, model_name in enumerate(endpoint_list):
            # A single failing validator (request error or malformed reply) is
            # recorded and skipped instead of aborting the whole vote.
            try:
                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=msg,
                    # temperature 0 + fixed seed: aim for reproducible votes (best effort).
                    temperature=0,
                    seed=42,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    response_format={"type": "json_object"}
                )
            except Exception as e:
                logger.error(f"Error calling validator {model_name}: {e}")
                validator_votes[model_name] = self._failed_vote(f"Request failed: {str(e)}")
                continue

            try:
                res_json = json.loads(response.choices[0].message.content)
                token_index = int(res_json["token_index"])
                explanation = res_json.get("explanation", "")
                # Bounds-check both ends: a negative index must not wrap around
                # to a token at the end of the list.
                error_token = (
                    self.tokenizer.decode([tokens[token_index]])
                    if 0 <= token_index < len(tokens) else "N/A"
                )
            except Exception as e:
                logger.error(f"Error processing response from {model_name}: {e}")
                validator_votes[model_name] = self._failed_vote(f"Failed to parse response: {str(e)}")
                continue

            votes[token_index] = votes.get(token_index, 0) + 1
            first_vote.setdefault(token_index, (position, explanation))
            validator_votes[model_name] = {
                "token_index": token_index,
                "error_token": error_token,
                "explanation": explanation
            }

        if not votes:
            failures = "; ".join(
                f"{name}: {info['explanation']}" for name, info in validator_votes.items()
            )
            raise RuntimeError(
                "No validator returned a usable token_index "
                f"({failures or 'no validators configured'})."
            )

        max_votes = max(votes.values())
        candidates = [t for t, c in votes.items() if c == max_votes]
        # Tie-break: the candidate first proposed by the earliest validator.
        winner_token = min(candidates, key=lambda t: first_vote[t][0])

        # Use the already-parsed int index and explanation. Re-reading the raw
        # JSON would compare against e.g. the string "5" and fail on a missing
        # explanation.
        explanation = first_vote[winner_token][1]
        completion_before_err = self._decode_prefix(tokens, winner_token)

        vote_summary = {
            "winner_token_index": winner_token,
            "winner_votes": max_votes,
            "total_validators": len(endpoint_list),
            "vote_distribution": votes
        }

        return completion_before_err, explanation, {
            "validators": validator_votes,
            "summary": vote_summary
        }

    def locate_error_token(self, prompt: str, completion: str, ground_truth: Optional[str] = None,
                           validators: Optional[list] = None,
                           use_llm: bool = True,
                           manual_chunks: Optional[list] = None):
        """
        Main method to locate the error token in a completion.

        Args:
            prompt (str): The input prompt.
            completion (str): The completion text to analyze.
            ground_truth (str, optional): The correct answer, if known.
            validators (list, optional): Validator model names overriding
                ``self.endpoint_list``; ignored when empty.
            use_llm (bool): When False, skip the validators and use
                ``manual_chunks[0]`` verbatim as the truncated text.
            manual_chunks (list, optional): User-supplied truncated texts; only
                the first one is used, and only when ``use_llm`` is False.

        Returns:
            dict: On success: ``status="success"``, ``truncated_text``,
            ``explanation``, and ``error_token_index``. The index is the token
            count of the re-tokenized prefix: the manual chunk, or the
            completion text before the error. The LLM path also returns
            ``vote_details``. On failure: ``status="error"``, ``message``, and
            empty ``truncated_text`` / ``explanation``. Exceptions are caught,
            logged, and reported this way rather than raised.
        """
        try:
            if not use_llm:
                if manual_chunks:
                    manual_chunk = manual_chunks[0]
                    return {
                        "status": "success",
                        "truncated_text": manual_chunk,
                        "explanation": "Manual chunk provided by user (LLM search skipped).",
                        "error_token_index": len(self.tokenizer(manual_chunk).input_ids)
                    }
                return {
                    "status": "error",
                    "message": "LLM search disabled but no manual chunk provided.",
                    "truncated_text": "",
                    "explanation": ""
                }

            if self.client is None:
                return {
                    "status": "error",
                    "message": "OpenAI API key not configured. Please set OPENAI_API_KEY environment variable or use manual truncation (disable LLM search).",
                    "truncated_text": "",
                    "explanation": ""
                }

            # Tokenize only on the LLM path; the branches above do not need it.
            msg, tokens = self.prompt_constructor(prompt, completion, ground_truth)

            # A non-empty list/tuple overrides the default validators.
            endpoint_list = (
                list(validators)
                if isinstance(validators, (list, tuple)) and validators
                else None
            )

            completion_before_err, explanation, vote_details = self.call_validator(msg, tokens, endpoint_list)

            return {
                "status": "success",
                "truncated_text": prompt + completion_before_err,
                "explanation": explanation,
                "error_token_index": len(self.tokenizer(completion_before_err).input_ids),
                "vote_details": vote_details
            }

        except Exception as e:
            # logger.exception records the traceback through logging instead of
            # printing it directly to stderr.
            logger.exception(f"Error in locate_error_token: {e}")
            return {
                "status": "error",
                "message": str(e),
                "truncated_text": "",
                "explanation": ""
            }
|
|