import logging
# Configure logging at import time (timestamp | level | logger name || message)
# and create the logger used throughout this module.
logging.basicConfig(
    datefmt='%m-%d %H:%M:%S',
    format='%(asctime)s | %(levelname)-6s | %(name)-40s || %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
import os
import sys
from pathlib import Path
import yaml
import json
from typing import Optional
from jinja2 import Template
# The OpenAI client is configured from environment variables (not an
# Azure-specific endpoint) and cached here after the first successful build.
_openai_client = None


def _get_openai_client():
    """Return the shared OpenAI client, building it on first use.

    Configuration is read from the environment:
      - OPENAI_API_KEY: required; without it no client is created.
      - OPENAI_BASE_URL: optional; forwarded as ``base_url`` when non-empty.

    Returns:
        The cached ``openai.OpenAI`` instance, or None when the ``openai``
        package cannot be imported or OPENAI_API_KEY is unset/empty. A None
        result is not cached, so a later call retries.
    """
    global _openai_client
    if _openai_client is None:
        try:
            import openai
        except ImportError:
            logger.warning("openai package not installed. LLM-based error token localization will not be available.")
            return None

        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            logger.warning("OPENAI_API_KEY not set. LLM-based error token localization will not be available.")
            return None

        client_options = {"api_key": api_key}
        base_url = os.environ.get("OPENAI_BASE_URL")
        if base_url:
            client_options["base_url"] = base_url
        _openai_client = openai.OpenAI(**client_options)
    return _openai_client
def _get_default_models():
    """Return the validator model names configured for this process.

    Reads the comma-separated ``LLM_MODELS`` environment variable, trimming
    whitespace and dropping empty entries.

    Returns:
        list[str]: Model names in the order given. Falls back to
        ``["gpt-4o-mini"]`` when the variable is unset *or* contains no usable
        name (e.g. ``""`` or ``" , "``), so callers never receive an empty
        validator list.
    """
    default_models = "gpt-4o-mini"
    raw = os.environ.get("LLM_MODELS", default_models)
    models = [name.strip() for name in raw.split(",") if name.strip()]
    # An empty list would mean "ask nobody", which leaves voting with nothing
    # to decide on; treat a blank setting the same as an unset one.
    return models or [default_models]
class ErrorTokenLocator:
    """Locate the token at which a model completion goes wrong.

    The completion is tokenized, each token is tagged with its index, and one
    or more LLM validators are asked (via the OpenAI chat-completions API) to
    name the index of the erroneous token. The index with the most votes wins
    and the completion is truncated just before it.
    """

    def __init__(self, model, tokenizer, prompt_template_yaml=None):
        """
        Initialize the error token locator
        Args:
            model: The language model to use (stored on the instance as-is)
            tokenizer: The corresponding tokenizer for tokenizing text; it is
                called as ``tokenizer(text).input_ids`` and ``tokenizer.decode([id])``
            prompt_template_yaml (str, optional): Path to the prompt template YAML file, uses default template when None
        """
        self.model = model
        self.tokenizer = tokenizer
        self.client = _get_openai_client()
        self.endpoint_list = _get_default_models()
        # load_general_prompt_template already falls back to the bundled
        # template when given None, so no branching is needed here.
        self.system_prompt = self.load_general_prompt_template(prompt_template_yaml)

    def prompt_constructor(self, query, completion, ground_truth: Optional[str] = None):
        """
        Construct prompts for error token location
        Args:
            query (str): The input query/prompt
            completion (str): The completion text generated by the model
            ground_truth (str, optional): The correct answer/ground truth, defaults to None
        Returns:
            tuple: (msg, tokens)
                - msg (list): The constructed conversation message list containing system and user messages
                - tokens (list): List of token IDs for the completion
        """
        tokens = self.tokenizer(completion).input_ids
        # Each token is rendered as "<text>[<index>] " so a validator can refer
        # to a position by number.
        indexed_completion = "".join(
            f"{self.tokenizer.decode([tok])}[{i}] " for i, tok in enumerate(tokens)
        )
        user_msg_content = {
            "prompt": query,
            "completion": completion,
            "indexed_completion": indexed_completion,
            "ground_truth": ground_truth,
        }
        msg = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": json.dumps(user_msg_content, indent=2)},
        ]
        return msg, tokens

    def load_general_prompt_template(self, prompt_template_yaml=None):
        """
        Load the general prompt template
        Args:
            prompt_template_yaml (str, optional): Path to the YAML template file, uses default path when None
        Returns:
            str: The rendered system prompt template string
        """
        if prompt_template_yaml is None:
            # Use path relative to this file's location
            prompt_template_yaml = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "token_locator_prompts",
                "err_token_localization.yaml",
            )
        with open(prompt_template_yaml, "r", encoding='utf-8') as f:
            system_prompt_temp = yaml.safe_load(f)
        # The YAML's "system" entry is a Jinja2 template; it is rendered with
        # generic placeholder text for the two dataset fields.
        return Template(system_prompt_temp['system']).render(
            dataset_description="No dataset description provided.",
            dataset_specific_instructions="-No dataset specific instructions provided.",
        )

    def call_validator(self, msg, tokens, endpoint_list=None):
        """
        Call the validator for error token location validation
        Args:
            msg (list): The constructed conversation message list
            tokens (list): List of token IDs
            endpoint_list (list, optional): List of validator model names; uses
                ``self.endpoint_list`` when None
        Returns:
            tuple: (completion_before_err, explanation, vote_details)
                - completion_before_err (str): Completion text truncated before the error token
                - explanation (str): Explanation given by the first validator that voted for the winning token
                - vote_details (dict): ``{"validators": per-model vote, "summary": vote tally}``
        Raises:
            RuntimeError: If no OpenAI client is available, or if no validator
                returned a response with a parseable ``token_index``.
        """
        if self.client is None:
            raise RuntimeError(
                "OpenAI client not available. Please set OPENAI_API_KEY environment variable "
                "or use manual truncation instead."
            )
        if endpoint_list is None:
            endpoint_list = self.endpoint_list

        responses = [
            self.client.chat.completions.create(
                model=model_name,
                messages=msg,
                temperature=0,
                seed=42,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
                response_format={"type": "json_object"},
            )
            for model_name in endpoint_list
        ]

        votes = {}             # token_index -> number of validators that chose it
        first_vote = {}        # token_index -> (response position, explanation) of its first voter
        validator_votes = {}   # model name -> that validator's individual vote
        for position, (model_name, response) in enumerate(zip(endpoint_list, responses)):
            try:
                res_json = json.loads(response.choices[0].message.content)
                token_index = int(res_json["token_index"])
                explanation = res_json.get("explanation", "")
                # Guard both ends: a negative index would otherwise silently
                # address a token counted from the end of the list.
                if 0 <= token_index < len(tokens):
                    error_token = self.tokenizer.decode([tokens[token_index]])
                else:
                    error_token = "N/A"
            except Exception as e:
                # Best effort: one unusable reply must not discard the others.
                logger.error("Error processing response from %s: %s", model_name, e)
                validator_votes[model_name] = {
                    "token_index": -1,
                    "error_token": "Error",
                    "explanation": f"Failed to parse response: {str(e)}",
                }
                continue

            votes[token_index] = votes.get(token_index, 0) + 1
            first_vote.setdefault(token_index, (position, explanation))
            validator_votes[model_name] = {
                "token_index": token_index,
                "error_token": error_token,
                "explanation": explanation,
            }

        if not votes:
            # Without any vote there is nothing to truncate at; fail loudly
            # instead of handing back a value of the wrong type.
            raise RuntimeError(
                "No validator returned a usable token_index; cannot locate the error token."
            )

        max_votes = max(votes.values())
        candidates = [t for t, c in votes.items() if c == max_votes]
        # pick the candidate whose first corresponding response appeared earliest
        winner_token = min(candidates, key=lambda t: first_vote[t][0])
        explanation = first_vote[winner_token][1]

        # Truncate on the parsed integer index (the same value that was voted
        # on). An index outside [0, len(tokens)] keeps the whole completion.
        cut = winner_token if 0 <= winner_token <= len(tokens) else len(tokens)
        completion_before_err = "".join(
            self.tokenizer.decode([tok]) for tok in tokens[:cut]
        )

        # Add vote summary to vote_details
        vote_summary = {
            "winner_token_index": winner_token,
            "winner_votes": max_votes,
            "total_validators": len(endpoint_list),
            "vote_distribution": votes,
        }
        return completion_before_err, explanation, {
            "validators": validator_votes,
            "summary": vote_summary,
        }

    def locate_error_token(self, prompt: str, completion: str, ground_truth: Optional[str] = None,
                           validators: Optional[list] = None,
                           use_llm: bool = True,
                           manual_chunks: Optional[list] = None):
        """
        Main method to locate the error token in a completion
        Args:
            prompt (str): The input prompt
            completion (str): The completion text to analyze
            ground_truth (str, optional): The correct answer/ground truth, defaults to None
            validators (list, optional): Validator model names for this call only;
                the instance default list is used when None or empty
            use_llm (bool): When False, skip the LLM validators and use ``manual_chunks``
            manual_chunks (list, optional): Pre-truncated texts; only the first
                entry is used, and only when ``use_llm`` is False
        Returns:
            dict: On success contains:
                - status (str): "success"
                - truncated_text (str): Prompt + completion truncated before error token
                - explanation (str): Explanation of the error
                - error_token_index (int): Token count of the truncated text
                - vote_details (dict): Validator votes (LLM path only)
            On failure contains status "error", a "message", and empty
            "truncated_text" / "explanation". Exceptions are not propagated.
        """
        try:
            # If user requests to skip LLM search, use manual chunks if provided
            if not use_llm:
                if not manual_chunks:
                    return {
                        "status": "error",
                        "message": "LLM search disabled but no manual chunk provided.",
                        "truncated_text": "",
                        "explanation": "",
                    }
                # manual_chunk already contains prompt + completion (set by frontend),
                # so use it directly as truncated_text without prepending prompt again
                manual_chunk = manual_chunks[0]
                return {
                    "status": "success",
                    "truncated_text": manual_chunk,
                    "explanation": "Manual chunk provided by user (LLM search skipped).",
                    "error_token_index": len(self.tokenizer(manual_chunk).input_ids),
                }

            # Check if OpenAI client is available
            if self.client is None:
                return {
                    "status": "error",
                    "message": "OpenAI API key not configured. Please set OPENAI_API_KEY environment variable or use manual truncation (disable LLM search).",
                    "truncated_text": "",
                    "explanation": "",
                }

            # Tokenize and build the messages only now that they will be used.
            msg, tokens = self.prompt_constructor(prompt, completion, ground_truth)

            # If validators provided, use them for this call
            endpoint_list = validators if isinstance(validators, list) and validators else None

            # Call validator to get error token location with vote details
            completion_before_err, explanation, vote_details = self.call_validator(msg, tokens, endpoint_list)

            return {
                "status": "success",
                # Combine prompt with truncated completion
                "truncated_text": prompt + completion_before_err,
                "explanation": explanation,
                # NOTE(review): this re-tokenizes the truncated text, so it may
                # differ from the voted index if the tokenizer adds special
                # tokens or merges differently at the cut; confirm with callers.
                "error_token_index": len(self.tokenizer(completion_before_err).input_ids),
                "vote_details": vote_details,
            }
        except Exception as e:
            # Top-level boundary: report the failure in the result dict rather
            # than raising; logger.exception records the traceback.
            logger.exception("Error in locate_error_token: %s", e)
            return {
                "status": "error",
                "message": str(e),
                "truncated_text": "",
                "explanation": "",
            }
|