|
|
import ast |
|
|
import re, json |
|
|
from enum import Enum |
|
|
from typing import Any, Dict, List, Tuple, Optional, Callable |
|
|
from .utils import RawInput, img2base64, python_eval |
|
|
from .models import LLMModel, SamplingParams |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PIPSMode(Enum): |
|
|
AGENT = "AGENT" |
|
|
INTERACTIVE = "INTERACTIVE" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TokenCb = Callable[[str, int, str], None] |
|
|
CbMap = Dict[str, Callable[..., Any]] |
|
|
|
|
|
|
|
|
|
|
|
class PIPSSolver: |
|
|
"""Python Iterative Problem Solving (PIPS) solver — unified streaming & non-streaming.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model: LLMModel, |
|
|
*, |
|
|
max_iterations: int = 8, |
|
|
temperature: float = 0.0, |
|
|
max_tokens: int = 50000, |
|
|
top_p: float = 1.0, |
|
|
interactive: bool = False, |
|
|
critic_model: Optional[LLMModel] = None, |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
model: An object that implements .chat(...) and, optionally, .stream_chat(...). |
|
|
max_iterations: Maximum refinement loops for the code-generation mode. |
|
|
temperature: Sampling temperature passed to the LLM. |
|
|
max_tokens: Max tokens for each LLM response. |
|
|
top_p: Nucleus-sampling parameter. |
|
|
interactive: Whether to use interactive mode (wait for user feedback). |
|
|
critic_model: Optional separate model for criticism (defaults to main model). |
|
|
""" |
|
|
self.model = model |
|
|
self.critic_model = critic_model or model |
|
|
self.max_iterations = max_iterations |
|
|
self.temperature = temperature |
|
|
self.max_tokens = max_tokens |
|
|
self.top_p = top_p |
|
|
self.interactive = interactive |
|
|
self._mode_decision_summary: Optional[Dict[str, Any]] = None |
|
|
|
|
|
|
|
|
self._checkpoint = None |
|
|
self._current_conversation = None |
|
|
|
|
|
|
|
|
self.system_prompt = """You will be given a question and you must answer it by extracting relevant symbols in JSON format and then writing a Python program to calculate the final answer. |
|
|
|
|
|
You MUST always plan extensively before outputting any symbols or code. |
|
|
|
|
|
You MUST iterate and keep going until the problem is solved. |
|
|
|
|
|
# Workflow |
|
|
|
|
|
## Problem Solving Steps |
|
|
1. First extract relevant information from the input as JSON. Try to represent the relevant information in as much of a structured format as possible to help with further reasoning/processing. |
|
|
2. Using the information extracted, determine a reasonable approach to solving the problem using code, such that executing the code will return the final answer. |
|
|
3. Write a Python program to calculate and return the final answer. Use comments to explain the structure of the code and do not use a main() function. |
|
|
The JSON must be enclosed in a markdown code block and the Python function must be in a separate markdown code block and be called `solve` and accept a single input called `symbols` representing the JSON information extracted. Do not include any `if __name__ == "__main__"` statement and you can assume the JSON will be loaded into the variable called `symbols` by the user. |
|
|
The Python code should not just return the answer or perform all reasoning in comments and instead leverage the code itself to perform the reasoning. |
|
|
Be careful that the code returns the answer as expected by the question, for instance, if the question is multiple choice, the code must return the choice as described in the question. |
|
|
Be sure to always output a JSON code block and a Python code block. |
|
|
Make sure to follow these formatting requirements exactly. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
_MODE_SELECTION_LIST_RE = re.compile(r"\[([0-9eE+.\s,-]+)\]") |
|
|
|
|
|
def _parse_probability_scores(self, raw: str) -> Optional[List[float]]: |
|
|
"""Extract a list of 10 probability scores from raw LLM output.""" |
|
|
if not raw: |
|
|
return None |
|
|
|
|
|
candidates: List[Any] = [] |
|
|
|
|
|
try: |
|
|
parsed = ast.literal_eval(raw.strip()) |
|
|
candidates.append(parsed) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
for match in self._MODE_SELECTION_LIST_RE.finditer(raw): |
|
|
candidate_str = f"[{match.group(1)}]" |
|
|
try: |
|
|
candidates.append(ast.literal_eval(candidate_str)) |
|
|
except Exception: |
|
|
continue |
|
|
|
|
|
for candidate in candidates: |
|
|
if ( |
|
|
isinstance(candidate, list) |
|
|
and len(candidate) == 10 |
|
|
and all(isinstance(x, (int, float)) for x in candidate) |
|
|
): |
|
|
floats = [float(x) for x in candidate] |
|
|
if all(0.0 <= x <= 1.0 for x in floats): |
|
|
return floats |
|
|
return None |
|
|
|
|
|
def _build_mode_selection_prompt(self, sample: RawInput) -> List[dict[str, Any]]: |
|
|
"""Create the conversation for deciding between code and chain-of-thought.""" |
|
|
from .prompts import CHOOSE_CONSERVATIVE_COT_VS_CODE_PROMPT |
|
|
|
|
|
instructions = CHOOSE_CONSERVATIVE_COT_VS_CODE_PROMPT.strip() |
|
|
extra_instruction = ( |
|
|
"\nAt the end of your response, output only the list of 10 probabilities inside square brackets " |
|
|
"after the text 'FINAL ANSWER:'." |
|
|
) |
|
|
|
|
|
content: List[dict[str, Any]] = [{"type": "text", "text": f"{instructions}{extra_instruction}"}] |
|
|
|
|
|
if sample.image_input is not None: |
|
|
content.append( |
|
|
{ |
|
|
"type": "image_url", |
|
|
"image_url": { |
|
|
"url": f"data:image/jpeg;base64,{img2base64(sample.image_input)}", |
|
|
"detail": "high", |
|
|
}, |
|
|
} |
|
|
) |
|
|
if sample.text_input is not None: |
|
|
content.append({"type": "text", "text": f"TARGET QUESTION:\n{sample.text_input}"}) |
|
|
|
|
|
return [{"role": "user", "content": content}] |
|
|
|
|
|
def _summarise_messages_for_log(self, messages: List[dict[str, Any]]) -> List[dict[str, Any]]: |
|
|
"""Return a copy of the conversation with image payloads redacted for logging.""" |
|
|
summary: List[dict[str, Any]] = [] |
|
|
for message in messages: |
|
|
content = message.get("content") |
|
|
if isinstance(content, list): |
|
|
new_content: List[dict[str, Any]] = [] |
|
|
for item in content: |
|
|
if isinstance(item, dict) and item.get("type") == "image_url": |
|
|
new_content.append({"type": "text", "text": "[image content omitted]"}) |
|
|
else: |
|
|
new_content.append(item) |
|
|
summary.append({**message, "content": new_content}) |
|
|
else: |
|
|
summary.append(dict(message)) |
|
|
return summary |
|
|
|
|
|
def _decide_solving_mode( |
|
|
self, |
|
|
messages: List[dict[str, Any]], |
|
|
*, |
|
|
max_tokens: int, |
|
|
) -> Dict[str, Any]: |
|
|
"""Run the self-reflection prompt to choose between code and chain-of-thought.""" |
|
|
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, top_p=1.0) |
|
|
|
|
|
try: |
|
|
response = self.model.chat(messages, sampling_params=sampling_params, use_tqdm=False) |
|
|
except Exception as exc: |
|
|
print(f"[DEBUG] Mode selection prompt raised exception: {exc}. Falling back to chain-of-thought.") |
|
|
return { |
|
|
"use_code": False, |
|
|
"scores": None, |
|
|
"average": None, |
|
|
"raw_response": "", |
|
|
"error": str(exc), |
|
|
} |
|
|
|
|
|
raw_text = "" |
|
|
if response and getattr(response[0], "outputs", None): |
|
|
raw_text = response[0].outputs[0].text or "" |
|
|
|
|
|
scores = self._parse_probability_scores(raw_text) |
|
|
if scores is None: |
|
|
print("[DEBUG] Mode selection prompt failed to yield valid probability list; defaulting to chain-of-thought.") |
|
|
return { |
|
|
"use_code": False, |
|
|
"scores": None, |
|
|
"average": None, |
|
|
"raw_response": raw_text, |
|
|
"error": None, |
|
|
} |
|
|
|
|
|
average = sum(scores) / len(scores) |
|
|
use_code = average > 0.5 |
|
|
|
|
|
return { |
|
|
"use_code": use_code, |
|
|
"scores": scores, |
|
|
"average": average, |
|
|
"raw_response": raw_text, |
|
|
"error": None, |
|
|
} |
|
|
|
|
|
def _chat( |
|
|
self, |
|
|
conversation: List[Dict[str, Any]], |
|
|
sampling_params: SamplingParams, |
|
|
stream: bool, |
|
|
iteration: int, |
|
|
callbacks: Optional[CbMap] = None, |
|
|
) -> str: |
|
|
""" |
|
|
Wrapper around model.chat / model.stream_chat that: |
|
|
• chooses the right API based on `stream` |
|
|
• fires streaming callbacks if supplied |
|
|
• returns the full assistant text |
|
|
""" |
|
|
callbacks = callbacks or {} |
|
|
|
|
|
|
|
|
on_start = callbacks.get("on_llm_streaming_start", lambda *a, **k: None) |
|
|
on_token = callbacks.get("on_llm_streaming_token", lambda *a, **k: None) |
|
|
on_end = callbacks.get("on_llm_streaming_end", lambda *a, **k: None) |
|
|
interrupted = callbacks.get("check_interrupted", lambda: False) |
|
|
|
|
|
model_name = self.model.__class__.__name__ |
|
|
|
|
|
if not stream: |
|
|
|
|
|
resp = self.model.chat(conversation, sampling_params=sampling_params, use_tqdm=False) |
|
|
return resp[0].outputs[0].text |
|
|
|
|
|
|
|
|
on_start(iteration, model_name) |
|
|
|
|
|
def _emit(tok: str): |
|
|
if not interrupted(): |
|
|
on_token(tok, iteration, model_name) |
|
|
|
|
|
if hasattr(self.model, "stream_chat"): |
|
|
resp = self.model.stream_chat( |
|
|
conversation, |
|
|
sampling_params=sampling_params, |
|
|
emit_callback=_emit, |
|
|
interrupted_callback=interrupted, |
|
|
) |
|
|
else: |
|
|
resp = self.model.chat(conversation, sampling_params=sampling_params, use_tqdm=False) |
|
|
|
|
|
on_end(iteration, model_name) |
|
|
return resp[0].outputs[0].text |
|
|
|
|
|
|
|
|
|
|
|
def solve( |
|
|
self, |
|
|
sample: RawInput, |
|
|
*, |
|
|
stream: bool = False, |
|
|
callbacks: Optional[CbMap] = None, |
|
|
additional_rules: str = "", |
|
|
decision_max_tokens: int = 1024, |
|
|
interactive_requested: bool = False, |
|
|
) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: |
|
|
"""Automatically choose between chain-of-thought and code-based solving.""" |
|
|
callbacks = callbacks or {} |
|
|
step = callbacks.get("on_step_update", lambda *a, **k: None) |
|
|
|
|
|
decision_messages = self._build_mode_selection_prompt(sample) |
|
|
decision_prompt_details = { |
|
|
"description": "Choosing between chain-of-thought and iterative coding", |
|
|
"conversation": self._summarise_messages_for_log(decision_messages), |
|
|
} |
|
|
|
|
|
step( |
|
|
"mode_selection", |
|
|
"Choosing between chain-of-thought reasoning and iterative coding…", |
|
|
prompt_details=decision_prompt_details, |
|
|
) |
|
|
|
|
|
decision = self._decide_solving_mode(decision_messages, max_tokens=decision_max_tokens) |
|
|
use_code = decision.get("use_code", False) |
|
|
average = decision.get("average") |
|
|
scores = decision.get("scores") |
|
|
decision_error = decision.get("error") |
|
|
|
|
|
if scores is None: |
|
|
decision_message = "Could not parse confidence scores; defaulting to chain-of-thought reasoning." |
|
|
else: |
|
|
decision_message = ( |
|
|
f"Average code suitability score: {average:.2f}. " |
|
|
f"Proceeding with {'iterative code generation' if use_code else 'chain-of-thought reasoning'}." |
|
|
) |
|
|
|
|
|
step( |
|
|
"mode_selection", |
|
|
decision_message, |
|
|
prompt_details={**decision_prompt_details, "raw_response": decision.get("raw_response", ""), "error": decision_error}, |
|
|
) |
|
|
|
|
|
if interactive_requested and not use_code: |
|
|
step( |
|
|
"mode_selection", |
|
|
"Interactive mode requested, but chain-of-thought was selected; running without interactive checkpoints.", |
|
|
prompt_details=None, |
|
|
) |
|
|
|
|
|
mode_decision_summary = { |
|
|
"use_code": use_code, |
|
|
"scores": scores, |
|
|
"average_score": average, |
|
|
"raw_response": decision.get("raw_response", ""), |
|
|
"prompt": decision_prompt_details["conversation"], |
|
|
"error": decision_error, |
|
|
} |
|
|
self._mode_decision_summary = mode_decision_summary |
|
|
|
|
|
original_interactive = self.interactive |
|
|
if not use_code: |
|
|
self.interactive = False |
|
|
|
|
|
try: |
|
|
if use_code: |
|
|
answer, logs = self.solve_with_code( |
|
|
sample, |
|
|
stream=stream, |
|
|
callbacks=callbacks, |
|
|
additional_rules=additional_rules, |
|
|
) |
|
|
else: |
|
|
answer, logs = self.solve_chain_of_thought( |
|
|
sample, |
|
|
stream=stream, |
|
|
callbacks=callbacks, |
|
|
additional_rules=additional_rules, |
|
|
) |
|
|
finally: |
|
|
self.interactive = original_interactive |
|
|
|
|
|
if isinstance(logs, dict): |
|
|
logs.setdefault("mode_decision", mode_decision_summary) |
|
|
|
|
|
return answer, logs, mode_decision_summary |
|
|
|
|
|
def _extract_components(self, output: str) -> Tuple[Any, str, str]: |
|
|
"""(unchanged helper) extract JSON, code, and reasoning.""" |
|
|
json_obj, code_str, reasoning = "", "", "" |
|
|
try: |
|
|
if m := re.findall(r"```json(.*?)```", output, re.DOTALL): |
|
|
json_obj = json.loads(m[-1]) |
|
|
except Exception: |
|
|
pass |
|
|
try: |
|
|
j_end = output.index("```", output.index("```json") + 7) + 3 |
|
|
p_start = output.index("```python", j_end) |
|
|
reasoning = output[j_end:p_start].strip() |
|
|
except Exception: |
|
|
pass |
|
|
try: |
|
|
if m := re.findall(r"```python(.*?)```", output, re.DOTALL): |
|
|
code_str = m[-1] |
|
|
except Exception: |
|
|
pass |
|
|
return json_obj, code_str, reasoning |
|
|
|
|
|
|
|
|
|
|
|
def solve_chain_of_thought( |
|
|
self, |
|
|
sample: RawInput, |
|
|
*, |
|
|
stream: bool = False, |
|
|
callbacks: Optional[CbMap] = None, |
|
|
additional_rules: str = "", |
|
|
) -> Tuple[str, Dict[str, Any]]: |
|
|
""" |
|
|
One implementation covers both streaming & non-streaming. |
|
|
If `stream=True`, supply the standard streaming callbacks. |
|
|
""" |
|
|
callbacks = callbacks or {} |
|
|
step = callbacks.get("on_step_update", lambda *a, **k: None) |
|
|
logs: Dict[str, Any] = {} |
|
|
|
|
|
|
|
|
system_content = "" |
|
|
if additional_rules.strip(): |
|
|
system_content = f"Additional Requirements:\n{additional_rules.strip()}\n\nMake sure to follow these additional requirements when answering." |
|
|
print(f"[DEBUG] Added custom rules to chain of thought prompt: {repr(additional_rules)}") |
|
|
|
|
|
if sample.image_input is not None: |
|
|
img_b64 = img2base64(sample.image_input) |
|
|
user_content = [ |
|
|
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}}, |
|
|
{"type": "text", "text": f"Question: {sample.text_input}"}, |
|
|
{"type": "text", "text": "Answer step-by-step and finish with 'FINAL ANSWER:'"}, |
|
|
] |
|
|
else: |
|
|
user_content = f"Question: {sample.text_input}\nAnswer step-by-step and finish with 'FINAL ANSWER:'." |
|
|
|
|
|
prompt = [] |
|
|
if system_content: |
|
|
prompt.append({"role": "system", "content": system_content}) |
|
|
prompt.append({"role": "user", "content": user_content}) |
|
|
params = SamplingParams(self.temperature, self.max_tokens, self.top_p) |
|
|
|
|
|
|
|
|
cot_prompt_details = { |
|
|
"description": "Chain of thought reasoning", |
|
|
"conversation": prompt |
|
|
} |
|
|
|
|
|
step("reasoning", "Thinking step-by-step...", prompt_details=cot_prompt_details) |
|
|
|
|
|
|
|
|
output = self._chat(prompt, params, stream, iteration=0, callbacks=callbacks) |
|
|
logs["output"] = output |
|
|
|
|
|
|
|
|
ans = "" |
|
|
try: |
|
|
ans = re.findall(r"FINAL ANSWER:(.*)", output, re.DOTALL)[-1].strip() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
interrupted = callbacks.get("check_interrupted", lambda: False) |
|
|
if interrupted(): |
|
|
step("interrupted", "PIPS was interrupted by the user.", prompt_details=None) |
|
|
else: |
|
|
step("finished", "Chain of thought completed!", prompt_details=None) |
|
|
|
|
|
final = f"FINAL ANSWER: {ans}" if ans else output |
|
|
logs["final_answer"] = ans |
|
|
return final, logs |
|
|
|
|
|
|
|
|
|
|
|
def solve_with_code( |
|
|
self, |
|
|
sample: RawInput, |
|
|
*, |
|
|
stream: bool = False, |
|
|
callbacks: Optional[CbMap] = None, |
|
|
additional_rules: str = "", |
|
|
) -> Tuple[str, Dict[str, Any]]: |
|
|
""" |
|
|
Iterative code-generation solver (streaming or not). |
|
|
`callbacks` is optional; provide it only when you care about |
|
|
fine-grained streaming events. |
|
|
Args: |
|
|
sample: The raw input containing text and/or image. |
|
|
stream: Whether to stream tokens from the underlying LLM. |
|
|
callbacks: Optional callback map for streaming & execution events. |
|
|
additional_rules: Extra natural-language rules that will be forwarded to the internal code critic for more specialized checking. |
|
|
""" |
|
|
callbacks = callbacks or {} |
|
|
interrupted = callbacks.get("check_interrupted", lambda: False) |
|
|
step = callbacks.get("on_step_update", lambda *a, **k: None) |
|
|
|
|
|
logs = {"all_outputs": [], "all_symbols": [], "all_programs": [], "all_reasoning": []} |
|
|
|
|
|
|
|
|
if interrupted(): |
|
|
return "", logs |
|
|
|
|
|
|
|
|
|
|
|
system_content = self.system_prompt |
|
|
if additional_rules.strip(): |
|
|
system_content += f"\n\nAdditional Requirements: \n{additional_rules.strip()}\n\n Make sure to follow these additional requirements when generating your solution." |
|
|
print(f"[DEBUG] Added custom rules to initial code generation prompt: {repr(additional_rules)}") |
|
|
|
|
|
if sample.image_input is not None: |
|
|
img_b64 = img2base64(sample.image_input) |
|
|
content = [ |
|
|
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}}, |
|
|
{"type": "text", "text": f"Question: {sample.text_input}"}, |
|
|
] |
|
|
else: |
|
|
content = f"Question: {sample.text_input}" |
|
|
|
|
|
conv = [ |
|
|
{"role": "system", "content": system_content}, |
|
|
{"role": "user", "content": content}, |
|
|
] |
|
|
params = SamplingParams(self.temperature, self.max_tokens, self.top_p) |
|
|
|
|
|
|
|
|
initial_prompt_details = { |
|
|
"description": "Initial solution generation", |
|
|
"conversation": conv |
|
|
} |
|
|
|
|
|
step("initial_generation", "Generating first solution…", prompt_details=initial_prompt_details) |
|
|
raw = self._chat(conv, params, stream, iteration=0, callbacks=callbacks) |
|
|
logs["all_outputs"].append(raw) |
|
|
conv.append({"role": "assistant", "content": raw}) |
|
|
|
|
|
|
|
|
current_symbols, current_code, reasoning = self._extract_components(raw) |
|
|
logs["all_symbols"].append(current_symbols) |
|
|
logs["all_programs"].append(current_code) |
|
|
if reasoning: |
|
|
logs["all_reasoning"].append(reasoning) |
|
|
|
|
|
|
|
|
exec_out, stdout, err = self._run_code(current_symbols, current_code, 0, callbacks, logs) |
|
|
for i in range(1, self.max_iterations + 1): |
|
|
if interrupted(): |
|
|
break |
|
|
|
|
|
|
|
|
feedback = self._critic( |
|
|
question=sample.text_input, |
|
|
code=current_code, |
|
|
symbols=current_symbols, |
|
|
out=exec_out, |
|
|
stdout=stdout, |
|
|
err=err, |
|
|
params=params, |
|
|
additional_rules=additional_rules, |
|
|
stream=stream, |
|
|
iteration=i, |
|
|
callbacks=callbacks, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if self.interactive: |
|
|
print(f"[DEBUG] Interactive mode triggered at iteration {i}") |
|
|
|
|
|
on_waiting_for_user = callbacks.get("on_waiting_for_user", lambda *a, **k: None) |
|
|
on_waiting_for_user(i, feedback, current_code, current_symbols) |
|
|
print(f"[DEBUG] Emitted awaiting_user_feedback event") |
|
|
|
|
|
|
|
|
self._checkpoint = { |
|
|
"sample": sample, |
|
|
"logs": logs, |
|
|
"conv": conv, |
|
|
"symbols": current_symbols, |
|
|
"code": current_code, |
|
|
"exec_out": exec_out, |
|
|
"stdout": stdout, |
|
|
"err": err, |
|
|
"feedback": feedback, |
|
|
"iteration": i, |
|
|
"params": params, |
|
|
"additional_rules": additional_rules, |
|
|
"stream": stream, |
|
|
"callbacks": callbacks |
|
|
} |
|
|
|
|
|
|
|
|
return "", logs |
|
|
|
|
|
|
|
|
fix_prompt = self._fix_prompt(sample.text_input, current_code, current_symbols, exec_out, stdout, err, feedback) |
|
|
conv.append({"role": "user", "content": fix_prompt}) |
|
|
|
|
|
|
|
|
refinement_prompt_details = { |
|
|
"description": f"Solution refinement (iteration {i})", |
|
|
"conversation": conv |
|
|
} |
|
|
|
|
|
step("refinement", f"Refining solution (iteration {i})...", iteration=i, prompt_details=refinement_prompt_details) |
|
|
raw = self._chat(conv, params, stream, iteration=i, callbacks=callbacks) |
|
|
logs["all_outputs"].append(raw) |
|
|
conv.append({"role": "assistant", "content": raw}) |
|
|
|
|
|
if "FINISHED" in raw: |
|
|
break |
|
|
|
|
|
|
|
|
new_symbols, new_code, reasoning = self._extract_components(raw) |
|
|
if new_symbols: |
|
|
current_symbols = new_symbols |
|
|
logs["all_symbols"].append(new_symbols) |
|
|
if new_code: |
|
|
current_code = new_code |
|
|
logs["all_programs"].append(new_code) |
|
|
if reasoning: |
|
|
logs["all_reasoning"].append(reasoning) |
|
|
|
|
|
exec_out, stdout, err = self._run_code(current_symbols, current_code, i, callbacks, logs) |
|
|
|
|
|
|
|
|
if interrupted(): |
|
|
step("interrupted", "PIPS was interrupted by the user.", prompt_details=None) |
|
|
else: |
|
|
step("finished", "Solution completed successfully!", prompt_details=None) |
|
|
|
|
|
final = f"FINAL ANSWER: {exec_out}" |
|
|
return final, logs |
|
|
|
|
|
|
|
|
|
|
|
def continue_from_checkpoint(self, user_feedback: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: |
|
|
""" |
|
|
Continue solving from a saved checkpoint with user feedback. |
|
|
|
|
|
Args: |
|
|
user_feedback: Dictionary containing user feedback with keys: |
|
|
- accept_critic: bool - whether to accept critic's feedback |
|
|
- extra_comments: str - additional user comments |
|
|
- quoted_ranges: list - specific code snippets user highlighted |
|
|
- terminate: bool - whether user wants to terminate |
|
|
|
|
|
Returns: |
|
|
Final answer and logs |
|
|
""" |
|
|
if not self._checkpoint: |
|
|
raise ValueError("No checkpoint available - cannot continue interactive mode") |
|
|
|
|
|
checkpoint = self._checkpoint |
|
|
user_feedback = user_feedback or {} |
|
|
|
|
|
|
|
|
if user_feedback.get("terminate", False): |
|
|
final = f"FINAL ANSWER: {checkpoint['exec_out']}" |
|
|
return final, checkpoint["logs"] |
|
|
|
|
|
|
|
|
merged_feedback = self.merge_user_feedback( |
|
|
checkpoint["feedback"], |
|
|
user_feedback.get("accept_critic", True), |
|
|
user_feedback.get("quoted_ranges", []) |
|
|
) |
|
|
|
|
|
|
|
|
has_user_feedback = bool(user_feedback.get("quoted_ranges", [])) |
|
|
|
|
|
|
|
|
current_symbols = checkpoint["symbols"] |
|
|
current_code = checkpoint["code"] |
|
|
exec_out = checkpoint["exec_out"] |
|
|
stdout = checkpoint["stdout"] |
|
|
err = checkpoint["err"] |
|
|
|
|
|
fix_prompt = self._fix_prompt( |
|
|
checkpoint["sample"].text_input, |
|
|
current_code, |
|
|
current_symbols, |
|
|
exec_out, |
|
|
stdout, |
|
|
err, |
|
|
merged_feedback, |
|
|
has_user_feedback |
|
|
) |
|
|
|
|
|
checkpoint["conv"].append({"role": "user", "content": fix_prompt}) |
|
|
|
|
|
|
|
|
refinement_prompt_details = { |
|
|
"description": f"Solution refinement (iteration {checkpoint['iteration']})", |
|
|
"conversation": checkpoint["conv"] |
|
|
} |
|
|
|
|
|
step = checkpoint["callbacks"].get("on_step_update", lambda *a, **k: None) |
|
|
step("refinement", f"Refining solution (iteration {checkpoint['iteration']})...", |
|
|
iteration=checkpoint['iteration'], prompt_details=refinement_prompt_details) |
|
|
|
|
|
raw = self._chat(checkpoint["conv"], checkpoint["params"], checkpoint["stream"], |
|
|
iteration=checkpoint['iteration'], callbacks=checkpoint["callbacks"]) |
|
|
|
|
|
checkpoint["logs"]["all_outputs"].append(raw) |
|
|
checkpoint["conv"].append({"role": "assistant", "content": raw}) |
|
|
|
|
|
if "FINISHED" in raw: |
|
|
final = f"FINAL ANSWER: {checkpoint['exec_out']}" |
|
|
return final, checkpoint["logs"] |
|
|
|
|
|
|
|
|
new_symbols, new_code, reasoning = self._extract_components(raw) |
|
|
if new_symbols: |
|
|
current_symbols = new_symbols |
|
|
checkpoint["logs"]["all_symbols"].append(new_symbols) |
|
|
if new_code: |
|
|
current_code = new_code |
|
|
checkpoint["logs"]["all_programs"].append(new_code) |
|
|
if reasoning: |
|
|
checkpoint["logs"]["all_reasoning"].append(reasoning) |
|
|
|
|
|
exec_out, stdout, err = self._run_code(current_symbols, current_code, checkpoint['iteration'], |
|
|
checkpoint["callbacks"], checkpoint["logs"]) |
|
|
checkpoint["symbols"] = current_symbols |
|
|
checkpoint["code"] = current_code |
|
|
checkpoint["exec_out"] = exec_out |
|
|
checkpoint["stdout"] = stdout |
|
|
checkpoint["err"] = err |
|
|
|
|
|
|
|
|
original_interactive = self.interactive |
|
|
self.interactive = False |
|
|
|
|
|
|
|
|
remaining_iterations = self.max_iterations - checkpoint['iteration'] |
|
|
if remaining_iterations > 0: |
|
|
|
|
|
sample = checkpoint["sample"] |
|
|
|
|
|
|
|
|
for i in range(checkpoint['iteration'] + 1, self.max_iterations + 1): |
|
|
interrupted = checkpoint["callbacks"].get("check_interrupted", lambda: False) |
|
|
if interrupted(): |
|
|
break |
|
|
|
|
|
feedback = self._critic( |
|
|
question=sample.text_input, |
|
|
code=current_code, |
|
|
symbols=current_symbols, |
|
|
out=exec_out, |
|
|
stdout=stdout, |
|
|
err=err, |
|
|
params=checkpoint["params"], |
|
|
additional_rules=checkpoint["additional_rules"], |
|
|
stream=checkpoint["stream"], |
|
|
iteration=i, |
|
|
callbacks=checkpoint["callbacks"], |
|
|
) |
|
|
|
|
|
fix_prompt = self._fix_prompt(sample.text_input, current_code, current_symbols, exec_out, stdout, err, feedback) |
|
|
checkpoint["conv"].append({"role": "user", "content": fix_prompt}) |
|
|
|
|
|
refinement_prompt_details = { |
|
|
"description": f"Solution refinement (iteration {i})", |
|
|
"conversation": checkpoint["conv"] |
|
|
} |
|
|
|
|
|
step("refinement", f"Refining solution (iteration {i})...", |
|
|
iteration=i, prompt_details=refinement_prompt_details) |
|
|
|
|
|
raw = self._chat(checkpoint["conv"], checkpoint["params"], checkpoint["stream"], |
|
|
iteration=i, callbacks=checkpoint["callbacks"]) |
|
|
|
|
|
checkpoint["logs"]["all_outputs"].append(raw) |
|
|
checkpoint["conv"].append({"role": "assistant", "content": raw}) |
|
|
|
|
|
if "FINISHED" in raw: |
|
|
break |
|
|
|
|
|
new_symbols, new_code, reasoning = self._extract_components(raw) |
|
|
if new_symbols: |
|
|
current_symbols = new_symbols |
|
|
checkpoint["logs"]["all_symbols"].append(new_symbols) |
|
|
if new_code: |
|
|
current_code = new_code |
|
|
checkpoint["logs"]["all_programs"].append(new_code) |
|
|
if reasoning: |
|
|
checkpoint["logs"]["all_reasoning"].append(reasoning) |
|
|
|
|
|
exec_out, stdout, err = self._run_code(current_symbols, current_code, i, checkpoint["callbacks"], checkpoint["logs"]) |
|
|
checkpoint["symbols"] = current_symbols |
|
|
checkpoint["code"] = current_code |
|
|
checkpoint["exec_out"] = exec_out |
|
|
checkpoint["stdout"] = stdout |
|
|
checkpoint["err"] = err |
|
|
|
|
|
|
|
|
self.interactive = original_interactive |
|
|
|
|
|
|
|
|
self._checkpoint = None |
|
|
|
|
|
final = f"FINAL ANSWER: {exec_out}" |
|
|
return final, checkpoint["logs"] |
|
|
|
|
|
def merge_user_feedback(self, critic_feedback: str, accept_critic: bool, |
|
|
quoted_ranges: List[Dict]) -> str: |
|
|
""" |
|
|
Merge critic feedback with user feedback. |
|
|
|
|
|
Args: |
|
|
critic_feedback: Original feedback from the critic |
|
|
accept_critic: Whether to include critic's feedback |
|
|
quoted_ranges: List of user feedback items (general comments, code feedback, symbol feedback) |
|
|
|
|
|
Returns: |
|
|
Merged feedback string |
|
|
""" |
|
|
feedback_parts = [] |
|
|
|
|
|
if accept_critic and critic_feedback: |
|
|
feedback_parts.append("AI Critic's feedback:") |
|
|
feedback_parts.append(critic_feedback) |
|
|
|
|
|
if quoted_ranges: |
|
|
|
|
|
general_comments = [] |
|
|
specific_feedback = [] |
|
|
|
|
|
for item in quoted_ranges: |
|
|
if not item.get("comment"): |
|
|
continue |
|
|
|
|
|
if item.get("type") == "general" or not item.get("text"): |
|
|
general_comments.append(item["comment"]) |
|
|
else: |
|
|
specific_feedback.append(item) |
|
|
|
|
|
|
|
|
if general_comments: |
|
|
feedback_parts.append("User feedback:") |
|
|
feedback_parts.extend(general_comments) |
|
|
|
|
|
|
|
|
if specific_feedback: |
|
|
feedback_parts.append("Specific code feedback:") |
|
|
for item in specific_feedback: |
|
|
feedback_parts.append(f"Regarding: {item['text']}") |
|
|
feedback_parts.append(f"Comment: {item['comment']}") |
|
|
|
|
|
return "\n\n".join(feedback_parts) if feedback_parts else "No specific issues identified." |
|
|
|
|
|
|
|
|
|
|
|
def _run_code( |
|
|
self, |
|
|
symbols: Any, |
|
|
code: str, |
|
|
iteration: int, |
|
|
callbacks: CbMap, |
|
|
logs: Dict[str, Any], |
|
|
) -> Tuple[str, str, str]: |
|
|
"""Execute candidate code, emit callbacks, store logs, return (out, stdout, err).""" |
|
|
on_exec_start = callbacks.get("on_code_execution_start", lambda *a, **k: None) |
|
|
on_exec_end = callbacks.get("on_code_execution_end", lambda *a, **k: None) |
|
|
on_exec = callbacks.get("on_code_execution", lambda *a, **k: None) |
|
|
max_time = callbacks.get("get_max_execution_time", lambda: 10)() |
|
|
|
|
|
on_exec_start(iteration) |
|
|
try: |
|
|
out, std, err = python_eval( |
|
|
f"{code}\nsymbols = {str(symbols)}\nanswer = solve(symbols)", |
|
|
max_execution_time=max_time, |
|
|
) |
|
|
except Exception as e: |
|
|
out, std, err = "None", "", str(e) |
|
|
|
|
|
on_exec_end(iteration) |
|
|
on_exec(iteration, str(out), std, err) |
|
|
logs.setdefault("execution_results", []).append({"output": out, "stdout": std, "error": err}) |
|
|
return str(out), std, err |
|
|
|
|
|
|
|
|
|
|
|
def _critic( |
|
|
self, |
|
|
question: str, |
|
|
code: str, |
|
|
symbols: Any, |
|
|
out: str, |
|
|
stdout: str, |
|
|
err: str, |
|
|
params: SamplingParams, |
|
|
additional_rules: str = "", |
|
|
stream: bool = False, |
|
|
iteration: int = 1, |
|
|
callbacks: Optional[CbMap] = None, |
|
|
) -> str: |
|
|
"""Ask the model to critique the code once per iteration.""" |
|
|
system_content = f"""You will be given a question and a code solution and you must judge the quality of the code for solving the problem. |
|
|
|
|
|
Look for any of the following issues in the code: |
|
|
- The code should be input dependent, meaning it should use the input symbols to compute the answer. It is OK for the code to be specialized to the input (i.e. the reasoning itself may be hardcoded, like a decision tree where the branches are hardcoded). |
|
|
- The code should not return None unless "None" is the correct answer. |
|
|
- The code should return the answer, not just print it. If the question asks for a multiple choice answer, the code should return the choice as described in the question. |
|
|
- There should not be any example usage of the code. |
|
|
- If there is a simpler way to solve the problem, please describe it. |
|
|
- If there are any clear bugs in the code which impact the correctness of the answer, please describe them. |
|
|
- If there are any issues with the extracted symbols, please describe them as well, but separate these issues from the issues with the code. |
|
|
- If it is possible to sanity check the output of the code, please do so and describe if there are any obvious issues with the output and how the code could be fixed to avoid these issues. |
|
|
|
|
|
{"Additional issues and specifications to looks for: " if additional_rules else ""} |
|
|
{additional_rules} |
|
|
|
|
|
After analyzing the code in depth, output a concrete and concise summary of the issues that are present, do not include any code examples. Please order the issues by impact on answer correctness.""" |
|
|
|
|
|
user_content = f"""Question: {question} |
|
|
|
|
|
The following are extracted symbols from the question in JSON format followed by a Python program which takes the JSON as an argument called `symbols` and computes the answer. |
|
|
```json |
|
|
{json.dumps(symbols, indent=2)} |
|
|
``` |
|
|
|
|
|
```python |
|
|
{code} |
|
|
``` |
|
|
|
|
|
Code execution result: |
|
|
``` |
|
|
Return value: {out} |
|
|
Standard output: {stdout} |
|
|
Exceptions: {err} |
|
|
``` |
|
|
|
|
|
Output a concrete and concise summary of only the issues that are present, do not include any code examples. |
|
|
""" |
|
|
|
|
|
prompt = [ |
|
|
{"role": "system", "content": system_content}, |
|
|
{"role": "user", "content": user_content}, |
|
|
] |
|
|
|
|
|
|
|
|
critic_prompt_details = { |
|
|
"description": f"Code quality analysis and critique (iteration {iteration})", |
|
|
"conversation": prompt |
|
|
} |
|
|
|
|
|
|
|
|
callbacks = callbacks or {} |
|
|
step = callbacks.get("on_step_update", lambda *a, **k: None) |
|
|
step("code_checking", f"Running code critic (iteration {iteration})...", iteration=iteration, prompt_details=critic_prompt_details) |
|
|
|
|
|
if not stream: |
|
|
|
|
|
return self.critic_model.chat(prompt, sampling_params=params, use_tqdm=False)[0].outputs[0].text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _make_reviewer_callbacks(): |
|
|
on_start = callbacks.get("on_code_check_streaming_start", lambda *a, **k: None) |
|
|
on_token = callbacks.get("on_code_check_streaming_token", lambda *a, **k: None) |
|
|
on_end = callbacks.get("on_code_check_streaming_end", lambda *a, **k: None) |
|
|
interrupted = callbacks.get("check_interrupted", lambda: False) |
|
|
|
|
|
def _emit(tok: str): |
|
|
if not interrupted(): |
|
|
on_token(tok, iteration, "AI Code Reviewer") |
|
|
|
|
|
return on_start, on_token, on_end, _emit |
|
|
|
|
|
on_start, on_token, on_end, _emit = _make_reviewer_callbacks() |
|
|
|
|
|
|
|
|
model_name = "AI Code Reviewer" |
|
|
on_start(iteration, model_name) |
|
|
|
|
|
|
|
|
if hasattr(self.critic_model, "stream_chat"): |
|
|
resp = self.critic_model.stream_chat( |
|
|
prompt, |
|
|
sampling_params=params, |
|
|
emit_callback=_emit, |
|
|
) |
|
|
else: |
|
|
|
|
|
resp = self.critic_model.chat(prompt, sampling_params=params, use_tqdm=False) |
|
|
|
|
|
on_end(iteration, model_name) |
|
|
return resp[0].outputs[0].text |
|
|
|
|
|
|
|
|
|
|
|
def _fix_prompt( |
|
|
self, question, code, symbols, out, stdout, err, feedback, has_user_feedback=False |
|
|
) -> str: |
|
|
"""Return the prompt that asks the LLM to fix problems.""" |
|
|
base_prompt = f"""Please fix the issues with the code and symbols or output "FINISHED". |
|
|
The following is the result of evaluating the above code with the extracted symbols. |
|
|
``` |
|
|
Return value: {out} |
|
|
Standard output: {stdout} |
|
|
Exceptions: {err} |
|
|
``` |
|
|
|
|
|
The following is the summary of issues found with the code or the extracted symbols by another model: |
|
|
``` |
|
|
{feedback} |
|
|
``` |
|
|
""" |
|
|
|
|
|
if has_user_feedback: |
|
|
emphasis = """ |
|
|
IMPORTANT: The feedback above includes specific user input that you MUST prioritize and address. Pay special attention to any user comments and requirements, as they represent critical guidance from the human user that should take precedence in your solution. |
|
|
""" |
|
|
base_prompt += emphasis |
|
|
|
|
|
base_prompt += """ |
|
|
If there are any issues which impact the correctness of the answer, please output code which does not have the issues. Before outputting any code, plan how the code will solve the problem and avoid the issues. |
|
|
If stuck, try outputting different code to solve the problem in a different way. |
|
|
You may also revise the extracted symbols. To do this, output the revised symbols in a JSON code block. Only include information in the JSON which is present in the original input to keep the code grounded in the specific problem. Some examples of symbol revisions are changing the names of certain symbols, providing further granularity, and adding information which was originally missed. |
|
|
If everything is correct, output the word "FINISHED" and nothing else. |
|
|
""" |
|
|
return base_prompt |
|
|
|