File size: 12,307 Bytes
89280a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e252fa
 
 
89280a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import logging
# Module-wide logging setup.
# NOTE(review): basicConfig configures the *root* logger at import time (and is
# a no-op if the root logger already has handlers), so it affects every program
# that imports this module.
logging.basicConfig(
    datefmt='%m-%d %H:%M:%S',
    format='%(asctime)s | %(levelname)-6s | %(name)-40s || %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

import os
import sys
from pathlib import Path

import yaml
import json
from typing import Optional
from jinja2 import Template

# Process-wide OpenAI client, configured from environment variables instead of
# an Azure-specific endpoint; created lazily by _get_openai_client().
_openai_client = None


def _get_openai_client():
    """Return the shared OpenAI client, creating it on first successful use.

    Reads ``OPENAI_API_KEY`` (required) and ``OPENAI_BASE_URL`` (optional, only
    passed on when non-empty). When the ``openai`` package cannot be imported
    or no API key is set, a warning is logged and ``None`` is returned. Nothing
    is cached in that case, so the next call checks again.
    """
    global _openai_client
    if _openai_client is None:
        try:
            import openai
        except ImportError:
            logger.warning("openai package not installed. LLM-based error token localization will not be available.")
            return None

        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            logger.warning("OPENAI_API_KEY not set. LLM-based error token localization will not be available.")
            return None

        client_kwargs = {"api_key": api_key}
        base_url = os.environ.get("OPENAI_BASE_URL")
        if base_url:
            client_kwargs["base_url"] = base_url
        _openai_client = openai.OpenAI(**client_kwargs)
    return _openai_client


def _get_default_models():
    """Return the validator model names listed in the ``LLM_MODELS`` env var.

    The variable is a comma-separated list and defaults to ``"gpt-4o-mini"``.
    Whitespace around each name is stripped and empty entries are dropped.
    """
    raw_models = os.environ.get("LLM_MODELS", "gpt-4o-mini")
    # filter(None, ...) discards the empty strings left by stray commas.
    return list(filter(None, map(str.strip, raw_models.split(","))))


class ErrorTokenLocator:
    """Locate the first erroneous token of a completion with LLM validators.

    The completion is tokenized with ``tokenizer``, rendered as an indexed
    token list, and sent to one or more OpenAI-compatible chat models
    ("validators"). Each validator answers with a JSON object containing
    ``token_index`` (and optionally ``explanation``). The token index with the
    most votes wins, and ties go to the index that was proposed first.
    """

    def __init__(self, model, tokenizer, prompt_template_yaml=None):
        """
        Initialize the error token locator

        Args:
            model: The language model to use (stored as ``self.model``; not
                read by any other method of this class)
            tokenizer: Tokenizer for the completion text. It must be callable,
                returning an object with ``input_ids``, and provide
                ``decode(list_of_ids)``.
            prompt_template_yaml (str, optional): Path to the prompt template YAML file, uses default template when None
        """
        self.model = model
        self.tokenizer = tokenizer

        # None when the openai package or OPENAI_API_KEY is unavailable. The
        # LLM path then reports an error instead of calling the API.
        self.client = _get_openai_client()

        # Default validator model names. A call may override them.
        self.endpoint_list = _get_default_models()

        # load_general_prompt_template itself falls back to the bundled
        # template when given None.
        self.system_prompt = self.load_general_prompt_template(prompt_template_yaml)

    def prompt_constructor(self, query, completion, ground_truth: Optional[str] = None):
        """
        Construct prompts for error token location

        Args:
            query (str): The input query/prompt
            completion (str): The completion text generated by the model
            ground_truth (str, optional): The correct answer/ground truth, defaults to None

        Returns:
            tuple: (msg, tokens)
                - msg (list): System message (the rendered template) plus a user
                  message whose content is a JSON object with ``prompt``,
                  ``completion``, ``indexed_completion`` and ``ground_truth``
                - tokens (list): List of token IDs for the completion
        """
        tokens = self.tokenizer(completion).input_ids
        # Each token is followed by its position, e.g. "Hello[0] world[1] ",
        # so a validator can answer with an exact index into ``tokens``.
        indexed_completion = "".join(
            f"{self.tokenizer.decode([tok])}[{i}] " for i, tok in enumerate(tokens)
        )

        user_msg_content = {
            "prompt": query,
            "completion": completion,
            "indexed_completion": indexed_completion,
            "ground_truth": ground_truth
        }

        msg = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": json.dumps(user_msg_content, indent=2)}
        ]
        return msg, tokens

    def load_general_prompt_template(self, prompt_template_yaml=None,
                                     dataset_description="No dataset description provided.",
                                     dataset_specific_instructions="-No dataset specific instructions provided."):
        """
        Load and render the system prompt template

        Args:
            prompt_template_yaml (str, optional): Path to the YAML template file.
                When None, uses ``token_locator_prompts/err_token_localization.yaml``
                next to this file.
            dataset_description (str, optional): Value for the template's
                ``dataset_description`` variable
            dataset_specific_instructions (str, optional): Value for the
                template's ``dataset_specific_instructions`` variable

        Returns:
            str: The rendered Jinja2 template stored under the YAML ``system`` key

        Raises:
            OSError: If the template file cannot be opened.
            KeyError: If the YAML has no ``system`` key.
        """
        if prompt_template_yaml is None:
            # Resolve relative to this file so the default works from any CWD.
            prompt_template_yaml = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "token_locator_prompts",
                "err_token_localization.yaml"
            )

        with open(prompt_template_yaml, "r", encoding='utf-8') as f:
            system_prompt_temp = yaml.safe_load(f)
        return Template(system_prompt_temp['system']).render(
            dataset_description=dataset_description,
            dataset_specific_instructions=dataset_specific_instructions
        )

    @staticmethod
    def _parse_vote(response, num_tokens):
        """Return ``(token_index, explanation)`` from one validator's chat response.

        Raises:
            Exception: If the reply is not JSON, has no ``token_index`` that
                converts with ``int()``, or the index is outside ``range(num_tokens)``.
        """
        res_json = json.loads(response.choices[0].message.content)
        token_index = int(res_json["token_index"])
        if not 0 <= token_index < num_tokens:
            # Negative indices would otherwise silently address tokens from the end.
            raise ValueError(f"token_index {token_index} out of range for {num_tokens} completion tokens")
        return token_index, res_json.get("explanation", "")

    def call_validator(self, msg, tokens, endpoint_list=None):
        """
        Ask every validator for the error token and majority-vote the answers

        Validators are queried one after another. A validator whose request
        fails, or whose reply is not a JSON object with a usable ``token_index``
        inside ``range(len(tokens))``, is recorded with ``token_index == -1``
        and does not vote.

        Args:
            msg (list): The constructed conversation message list
            tokens (list): List of token IDs of the completion
            endpoint_list (list, optional): Validator model names; defaults to ``self.endpoint_list``

        Returns:
            tuple: (completion_before_err, explanation, vote_details)
                - completion_before_err (str): Completion text truncated before the error token
                - explanation (str): Explanation from the first validator that voted for the winning token
                - vote_details (dict): ``{"validators": {model: vote}, "summary": {...}}``

        Raises:
            RuntimeError: If no OpenAI client is configured, or if no validator
                returned a usable token index.
        """
        if self.client is None:
            raise RuntimeError(
                "OpenAI client not available. Please set OPENAI_API_KEY environment variable "
                "or use manual truncation instead."
            )
        if endpoint_list is None:
            endpoint_list = self.endpoint_list

        votes = {}              # token_index -> vote count, in order of first vote
        first_explanation = {}  # token_index -> explanation of its first voter
        validator_votes = {}    # model name -> per-validator vote details

        for model_name in endpoint_list:
            try:
                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=msg,
                    temperature=0,
                    seed=42,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    response_format={"type": "json_object"}
                )
            except Exception as e:
                # One unreachable or misconfigured validator should not sink the vote.
                logger.error("Request to validator %s failed: %s", model_name, e)
                validator_votes[model_name] = {
                    "token_index": -1,
                    "error_token": "Error",
                    "explanation": f"Validator request failed: {e}"
                }
                continue

            try:
                token_index, explanation = self._parse_vote(response, len(tokens))
            except Exception as e:
                logger.error("Error processing response from %s: %s", model_name, e)
                validator_votes[model_name] = {
                    "token_index": -1,
                    "error_token": "Error",
                    "explanation": f"Failed to parse response: {str(e)}"
                }
                continue

            votes[token_index] = votes.get(token_index, 0) + 1
            first_explanation.setdefault(token_index, explanation)
            validator_votes[model_name] = {
                "token_index": token_index,
                "error_token": self.tokenizer.decode([tokens[token_index]]),
                "explanation": explanation
            }

        if not votes:
            details = "; ".join(
                f"{name}: {vote['explanation']}" for name, vote in validator_votes.items()
            )
            raise RuntimeError(
                "No validator returned a usable error token index"
                + (f" ({details})" if details else " (no validators configured)")
            )

        max_votes = max(votes.values())
        # Ties go to the token proposed earliest: ``votes`` keeps first-vote
        # order, so the first key with the maximal count is that token.
        winner_token = next(t for t, c in votes.items() if c == max_votes)

        # Decode the prefix in a single call. Concatenating per-token decodes
        # does not round-trip for every tokenizer (e.g. ones that drop a
        # word-leading space per decode call or split multi-byte characters).
        completion_before_err = self.tokenizer.decode(tokens[:winner_token]) if winner_token else ""

        vote_summary = {
            "winner_token_index": winner_token,
            "winner_votes": max_votes,
            "total_validators": len(endpoint_list),
            "vote_distribution": votes
        }

        return completion_before_err, first_explanation[winner_token], {
            "validators": validator_votes,
            "summary": vote_summary
        }

    def locate_error_token(self, prompt: str, completion: str, ground_truth: Optional[str] = None,
                           validators: Optional[list] = None,
                           use_llm: bool = True,
                           manual_chunks: Optional[list] = None):
        """
        Main method to locate the error token in a completion

        Args:
            prompt (str): The input prompt
            completion (str): The completion text to analyze
            ground_truth (str, optional): The correct answer/ground truth, defaults to None
            validators (list, optional): Validator model names for this call.
                When None or empty, uses ``self.endpoint_list``.
            use_llm (bool): When False, skip the validators and use ``manual_chunks[0]``
            manual_chunks (list, optional): Manually truncated texts, each already
                containing prompt + completion; only the first is used

        Returns:
            dict: Dictionary containing:
                - status (str): "success" or "error"
                - truncated_text (str): Prompt + completion truncated before error token
                - explanation (str): Explanation of the error
                - error_token_index (int): Index of the error token (success only)
                - vote_details (dict): Validator votes (LLM path, success only)
                - message (str): Error description (error only)
        """
        try:
            if not use_llm:
                if not manual_chunks:
                    return {
                        "status": "error",
                        "message": "LLM search disabled but no manual chunk provided.",
                        "truncated_text": "",
                        "explanation": ""
                    }
                # manual_chunk already contains prompt + completion (set by frontend),
                # so use it directly as truncated_text without prepending prompt again
                manual_chunk = manual_chunks[0]
                return {
                    "status": "success",
                    "truncated_text": manual_chunk,
                    "explanation": "Manual chunk provided by user (LLM search skipped).",
                    # NOTE(review): this counts the prompt's tokens as well, while
                    # the LLM path's index is relative to the completion. Confirm
                    # which one the frontend expects.
                    "error_token_index": len(self.tokenizer(manual_chunk).input_ids)
                }

            if self.client is None:
                return {
                    "status": "error",
                    "message": "OpenAI API key not configured. Please set OPENAI_API_KEY environment variable or use manual truncation (disable LLM search).",
                    "truncated_text": "",
                    "explanation": ""
                }

            msg, tokens = self.prompt_constructor(prompt, completion, ground_truth)

            # A non-empty list overrides the default validators for this call.
            endpoint_list = validators if isinstance(validators, list) and validators else None

            completion_before_err, explanation, vote_details = self.call_validator(msg, tokens, endpoint_list)

            return {
                "status": "success",
                "truncated_text": prompt + completion_before_err,
                "explanation": explanation,
                # Exact position in the completion's tokens; re-tokenizing the
                # truncated text can disagree (special tokens, merges at the cut).
                "error_token_index": vote_details["summary"]["winner_token_index"],
                "vote_details": vote_details
            }

        except Exception as e:
            logger.exception("Error in locate_error_token: %s", e)
            return {
                "status": "error",
                "message": str(e),
                "truncated_text": "",
                "explanation": ""
            }