Debug-XAI / backend /error_token_location.py
rongyuan
Bugs fixed.
4e252fa
import logging
# Process-wide logging for the backend: records at INFO and above, each line
# carrying a short timestamp, a padded level name and a padded logger name.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-6s | %(name)-40s || %(message)s",
    datefmt="%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
import os
import sys
from pathlib import Path
import yaml
import json
from typing import Optional
from jinja2 import Template
# OpenAI client built purely from environment variables (no hard-coded Azure
# endpoint). It is created lazily on first use and then reused.
_openai_client = None


def _get_openai_client():
    """Return the shared OpenAI client, building it from the environment on first use.

    Reads ``OPENAI_API_KEY`` (required) and ``OPENAI_BASE_URL`` (optional).
    When the ``openai`` package cannot be imported or no API key is set, a
    warning is logged and ``None`` is returned. Nothing is cached in that case,
    so a later call re-checks the environment.
    """
    global _openai_client
    if _openai_client is None:
        try:
            import openai
        except ImportError:
            logger.warning("openai package not installed. LLM-based error token localization will not be available.")
            return None
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            logger.warning("OPENAI_API_KEY not set. LLM-based error token localization will not be available.")
            return None
        client_kwargs = {"api_key": api_key}
        custom_base_url = os.environ.get("OPENAI_BASE_URL")
        if custom_base_url:
            client_kwargs["base_url"] = custom_base_url
        _openai_client = openai.OpenAI(**client_kwargs)
    return _openai_client
def _get_default_models():
    """Return the validator model names to query, in the order they are listed.

    Names come from the comma-separated ``LLM_MODELS`` environment variable
    (e.g. ``"gpt-4o-mini,gpt-4o"``). Surrounding whitespace and empty entries
    are dropped. If the variable is unset, or set but containing no usable name
    (``""``, ``" , "``), the default ``["gpt-4o-mini"]`` is returned, so callers
    never receive an empty validator list.
    """
    fallback = "gpt-4o-mini"
    raw = os.environ.get("LLM_MODELS", fallback)
    models = [name.strip() for name in raw.split(",") if name.strip()]
    return models or [fallback]
class ErrorTokenLocator:
    """Find the first erroneous token of a model completion by LLM majority vote.

    Each validator model (an OpenAI-compatible chat model) receives:

    - the prompt and the completion;
    - an index-annotated copy of the tokenized completion;
    - optionally, the ground truth.

    It replies with a JSON object holding ``token_index`` and an optional
    ``explanation``. The index with the most votes wins; ties go to the index
    proposed first. The completion is then truncated just before that token.
    """

    def __init__(self, model, tokenizer, prompt_template_yaml=None):
        """
        Initialize the error token locator

        Args:
            model: The language model; stored on ``self.model`` but not used by
                this class itself
            tokenizer: Tokenizer for the completion; must support
                ``tokenizer(text).input_ids`` and ``tokenizer.decode(ids)``
            prompt_template_yaml (str, optional): Path to the prompt template YAML
                file; the bundled default template is used when None
        """
        self.model = model
        self.tokenizer = tokenizer
        # None when the openai package or OPENAI_API_KEY is missing; LLM-based
        # localization then reports an error instead of calling the API.
        self.client = _get_openai_client()
        self.endpoint_list = _get_default_models()
        # load_general_prompt_template already maps None to the default path.
        self.system_prompt = self.load_general_prompt_template(prompt_template_yaml)

    def prompt_constructor(self, query, completion, ground_truth: str = None):
        """
        Construct the validator chat messages for error token location

        Args:
            query (str): The input query/prompt
            completion (str): The completion text generated by the model
            ground_truth (str, optional): The correct answer/ground truth, defaults to None

        Returns:
            tuple: (msg, tokens)
                - msg (list): The system prompt plus one user message whose content
                  is a JSON object with ``prompt``, ``completion``,
                  ``indexed_completion`` and ``ground_truth``
                - tokens (list): Token IDs of the completion as produced by the tokenizer
        """
        tokens = self.tokenizer(completion).input_ids
        # Tag each token with its position ("text[i] ") so validators can answer
        # with an index into ``tokens``.
        indexed_completion = "".join(
            f"{self.tokenizer.decode([tok])}[{i}] " for i, tok in enumerate(tokens)
        )
        user_msg_content = {
            "prompt": query,
            "completion": completion,
            "indexed_completion": indexed_completion,
            "ground_truth": ground_truth
        }
        msg = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": json.dumps(user_msg_content, indent=2)}
        ]
        return msg, tokens

    def load_general_prompt_template(self, prompt_template_yaml=None):
        """
        Load and render the system prompt template

        Args:
            prompt_template_yaml (str, optional): Path to a YAML file whose ``system``
                entry is a Jinja2 template. Defaults to
                ``token_locator_prompts/err_token_localization.yaml`` next to this module.

        Returns:
            str: The rendered system prompt. The ``dataset_description`` and
            ``dataset_specific_instructions`` variables are filled with placeholder text.

        Raises:
            OSError: If the template file cannot be opened
            KeyError: If the YAML mapping has no ``system`` entry
        """
        if prompt_template_yaml is None:
            # Resolve relative to this module so the default works from any CWD.
            prompt_template_yaml = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "token_locator_prompts",
                "err_token_localization.yaml"
            )
        with open(prompt_template_yaml, "r", encoding='utf-8') as f:
            # safe_load: the template is plain data and must not construct objects.
            system_prompt_temp = yaml.safe_load(f)
        return Template(system_prompt_temp['system']).render(
            dataset_description="No dataset description provided.",
            dataset_specific_instructions="-No dataset specific instructions provided."
        )

    @staticmethod
    def _failed_vote(reason):
        """Build the vote-details entry for a validator that produced no usable answer."""
        return {"token_index": -1, "error_token": "Error", "explanation": reason}

    def call_validator(self, msg, tokens, endpoint_list=None):
        """
        Query every validator and pick the error token by majority vote

        Args:
            msg (list): Chat messages built by ``prompt_constructor``
            tokens (list): Token IDs of the completion (the list returned with ``msg``)
            endpoint_list (list, optional): Validator model names; defaults to
                ``self.endpoint_list``

        Returns:
            tuple: (completion_before_err, explanation, vote_details)
                - completion_before_err (str): The completion decoded up to, but
                  excluding, the winning token. If the winning index lies outside
                  ``[0, len(tokens)]``, the whole completion is returned.
                - explanation (str): Explanation from the first validator that
                  proposed the winning token
                - vote_details (dict): ``{"validators": {model_name: vote},
                  "summary": {...}}``, where each vote records ``token_index``,
                  ``error_token`` and ``explanation``

        Raises:
            RuntimeError: If no OpenAI client is configured, or if no validator
                returned a usable token index
        """
        if self.client is None:
            raise RuntimeError(
                "OpenAI client not available. Please set OPENAI_API_KEY environment variable "
                "or use manual truncation instead."
            )
        if endpoint_list is None:
            endpoint_list = self.endpoint_list

        votes = {}              # token index -> vote count; insertion order = first proposal
        first_explanation = {}  # token index -> explanation of the first validator proposing it
        validator_votes = {}    # model name -> that validator's vote details
        for model_name in endpoint_list:
            try:
                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=msg,
                    temperature=0,
                    seed=42,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    response_format={"type": "json_object"}
                )
            except Exception as e:
                # One unreachable or misconfigured validator must not abort the whole vote.
                logger.error(f"Request to validator {model_name} failed: {e}")
                validator_votes[model_name] = self._failed_vote(f"Request failed: {e}")
                continue
            try:
                res_json = json.loads(response.choices[0].message.content)
                token_index = int(res_json["token_index"])
                explanation = res_json.get("explanation", "")
            except (AttributeError, IndexError, KeyError, TypeError, ValueError) as e:
                logger.error(f"Error processing response from {model_name}: {e}")
                validator_votes[model_name] = self._failed_vote(f"Failed to parse response: {str(e)}")
                continue

            votes[token_index] = votes.get(token_index, 0) + 1
            first_explanation.setdefault(token_index, explanation)
            # Reject negative indices too: Python indexing would silently pick a
            # token counted from the end of the completion.
            in_range = 0 <= token_index < len(tokens)
            validator_votes[model_name] = {
                "token_index": token_index,
                "error_token": self.tokenizer.decode([tokens[token_index]]) if in_range else "N/A",
                "explanation": explanation
            }

        if not votes:
            reasons = "; ".join(
                f"{name}: {vote['explanation']}" for name, vote in validator_votes.items()
            )
            raise RuntimeError(
                "No validator returned a usable error token index"
                + (f" ({reasons})" if reasons else " (no validators configured)")
            )

        # max() returns the first maximal key, and dicts iterate in insertion
        # order, so a tie goes to the token that was proposed first.
        winner_token = max(votes, key=votes.get)
        max_votes = votes[winner_token]

        # Use the already-parsed integer index (the raw JSON value may be a string).
        # Decode the prefix in one call rather than concatenating per-token
        # decodes, which can differ for some tokenizers.
        cut = winner_token if 0 <= winner_token <= len(tokens) else len(tokens)
        completion_before_err = self.tokenizer.decode(tokens[:cut])
        explanation = first_explanation[winner_token]

        vote_summary = {
            "winner_token_index": winner_token,
            "winner_votes": max_votes,
            "total_validators": len(endpoint_list),
            "vote_distribution": votes
        }
        return completion_before_err, explanation, {
            "validators": validator_votes,
            "summary": vote_summary
        }

    def locate_error_token(self, prompt: str, completion: str, ground_truth: str = None,
                           validators: Optional[list] = None,
                           use_llm: bool = True,
                           manual_chunks: Optional[list] = None):
        """
        Main method to locate the error token in a completion

        Args:
            prompt (str): The input prompt
            completion (str): The completion text to analyze
            ground_truth (str, optional): The correct answer/ground truth, defaults to None
            validators (list, optional): Validator model names for this call; a
                non-empty list overrides ``self.endpoint_list``
            use_llm (bool): When False, skip the validators and use ``manual_chunks[0]``
            manual_chunks (list, optional): User-provided truncations, each already
                containing prompt + truncated completion; only the first is used

        Returns:
            dict: Always contains:
                - status (str): "success" or "error"
                - truncated_text (str): Prompt + completion truncated before the error
                  token (empty on error)
                - explanation (str): Explanation of the error (empty on error)
            On success it also contains ``error_token_index`` (int): the token count
            of the truncated completion, or of the manual chunk. The LLM path adds
            ``vote_details``. On error the dict has a ``message`` instead. Exceptions
            are logged and reported as an error dict rather than raised.
        """
        try:
            # Manual truncation path: no tokenization or LLM calls needed.
            if not use_llm:
                if manual_chunks:
                    # manual_chunk already contains prompt + completion (set by the
                    # frontend), so it is used as-is without prepending the prompt.
                    manual_chunk = manual_chunks[0]
                    return {
                        "status": "success",
                        "truncated_text": manual_chunk,
                        "explanation": "Manual chunk provided by user (LLM search skipped).",
                        "error_token_index": len(self.tokenizer(manual_chunk).input_ids)
                    }
                return {
                    "status": "error",
                    "message": "LLM search disabled but no manual chunk provided.",
                    "truncated_text": "",
                    "explanation": ""
                }

            if self.client is None:
                return {
                    "status": "error",
                    "message": "OpenAI API key not configured. Please set OPENAI_API_KEY environment variable or use manual truncation (disable LLM search).",
                    "truncated_text": "",
                    "explanation": ""
                }

            msg, tokens = self.prompt_constructor(prompt, completion, ground_truth)
            # Only a non-empty list overrides the default validators.
            endpoint_list = validators if isinstance(validators, list) and validators else None
            completion_before_err, explanation, vote_details = self.call_validator(msg, tokens, endpoint_list)
            return {
                "status": "success",
                "truncated_text": prompt + completion_before_err,
                "explanation": explanation,
                "error_token_index": len(self.tokenizer(completion_before_err).input_ids),
                "vote_details": vote_details
            }
        except Exception as e:
            # Top-level boundary for the API layer: log with traceback, report as data.
            logger.exception(f"Error in locate_error_token: {e}")
            return {
                "status": "error",
                "message": str(e),
                "truncated_text": "",
                "explanation": ""
            }