import re
from typing import Optional, Dict, Any, List, Tuple

import torch

from .VisualScaffold import ProductionLogger, ProductionThinking


class ActionPlanValidator:
    """
    Advanced Action Plan Validator with Agentic Self-Correction.

    Ensures structural integrity and logical consistency of machine-readable
    outputs. ``validate_content`` produces a report describing any structural
    problems found in a generated text; ``self_correct`` feeds the risks from
    that report back to the model and returns its repaired output.
    """

    # Prompt keywords that suggest the user is asking for a code/file edit.
    _EDIT_KEYWORDS = ("modify", "edit", "change", "diff", "replace", "refactor", "patch")

    # Delimiters of a surgical patch block; compiled once for reuse.
    _DIFF_START_RE = re.compile(r"<\|diff_start\|>")
    _DIFF_END_RE = re.compile(r"<\|diff_end\|>")

    # Caller kwargs that must not be forwarded to the repair generation call:
    # the repair pass builds its own input ids from the repair prompt.
    _RESERVED_GEN_KWARGS = (
        "input_ids",
        "attention_mask",
        "position_ids",
        "past_key_values",
        "use_cache",
        "labels",
        "epistemic_strictness",
    )

    def __init__(self, tokenizer=None, workspace=None, file_index=None, show_thinking: bool = True):
        """
        Args:
            tokenizer: Tokenizer used to encode the repair prompt and decode the
                repaired output. When ``None``, ``self_correct`` falls back to
                ``model.tokenizer``.
            workspace: Optional object exposing ``read_objective()``; used for
                the brevity warning in ``validate_content``.
            file_index: Stored on the instance; not read by any method in this
                class.
            show_thinking: Forwarded to ``ProductionLogger``.
        """
        self.tokenizer = tokenizer
        self.workspace = workspace
        self.file_index = file_index
        self.logger = ProductionLogger(show_thinking=show_thinking)
        # Upper bound on the approximate token count of a candidate output.
        self._max_candidate_tokens = 20_000
        # NOTE(review): not read anywhere in this class; self_correct performs
        # a single repair pass. Presumably consumed by a caller - confirm.
        self._repair_loops = 2

    def get_tidbit(self, user_prompt: str = "") -> str:
        """
        Dynamic structural tidbit.

        Only triggers if the prompt suggests code changes or complex formatting.

        Args:
            user_prompt: Raw user prompt; ``None`` or empty is treated as "".

        Returns:
            A short formatting reminder when the prompt contains one of the
            edit keywords (case-insensitive substring match), else "".
        """
        p_lower = (user_prompt or "").lower()
        if any(word in p_lower for word in self._EDIT_KEYWORDS):
            return (
                "[STRUCTURAL TIDBIT]\n"
                "- Use surgical patches (<|diff_start|>) for file edits.\n"
                "- Maintain strict balance for markdown blocks (```).\n"
            )
        return ""

    def validate_content(self, text: str) -> Dict[str, Any]:
        """
        Validates the content of a generation, checking for schema and common failures.

        Args:
            text: The generated output to validate.

        Returns:
            A report dict with keys:
              - "is_valid": False when any hard check failed.
              - "risks": human-readable descriptions of problems and warnings.
              - "suggested_fixes": hints on how to repair specific problems.
              - "stats": collected metrics (currently "approx_tokens").
        """
        report = {
            "is_valid": True,
            "risks": [],
            "suggested_fixes": [],
            "stats": {},
        }

        if not text or not text.strip():
            report["is_valid"] = False
            report["risks"].append("Output is empty or null.")
            return report

        # 1. Structural Validation (Diffs - Support Multiple)
        diff_starts = list(self._DIFF_START_RE.finditer(text))
        if diff_starts:
            # Imported lazily so outputs without patches never need the module.
            import diff_utils

            for match in diff_starts:
                start_idx = match.end()
                # Search from start_idx in place rather than slicing the text,
                # which would copy the tail of the string for every block.
                end_match = self._DIFF_END_RE.search(text, start_idx)
                if not end_match:
                    report["is_valid"] = False
                    report["risks"].append("Unclosed <|diff_start|> tag.")
                    continue

                diff_content = text[start_idx:end_match.start()].strip()
                try:
                    diff_ok = diff_utils.validate_diff(diff_content)
                except Exception as e:
                    # Best effort: a crashing validator marks the output
                    # invalid instead of aborting the whole validation.
                    report["is_valid"] = False
                    report["risks"].append(f"Internal error validating diff: {e}")
                    continue

                if not diff_ok:
                    report["is_valid"] = False
                    report["risks"].append(
                        f"Malformed surgical patch in block starting at char {start_idx}."
                    )
                    report["suggested_fixes"].append(
                        "Ensure the diff contains valid hunk headers (@@ -L,C +L,C @@)."
                    )

        # 2. Holistic Coherence (Grounding)
        if self.workspace:
            # Guard against a workspace that has no objective recorded.
            objective = (self.workspace.read_objective() or "").lower()
            # If the output text is too short but objective is complex, warn
            if len(text) < 100 and len(objective) > 200:
                report["risks"].append("Output seems suspiciously brief given the complex objective.")
                # We don't invalidate completely, just warn or nudge in repair

        # 3. Heuristic Complexity Check
        # Rough estimate: four characters per token.
        approx_tokens = len(text) // 4
        report["stats"]["approx_tokens"] = approx_tokens
        if approx_tokens > self._max_candidate_tokens:
            report["risks"].append(f"Payload exceeds heuristic budget (~{approx_tokens} tokens).")
            report["is_valid"] = False

        # 4. Code Block Balance
        # An odd number of fences means some block was opened but never closed.
        if text.count("```") % 2 != 0:
            report["risks"].append("Unbalanced markdown code blocks detected.")
            report["is_valid"] = False

        return report

    @torch.inference_mode()
    def self_correct(self, model, input_ids: torch.LongTensor, draft_text: str, report: Dict[str, Any], **kwargs) -> str:
        """
        Agentic Repair Loop: Prompts the model to fix its own validation failures.

        Args:
            model: Object exposing ``generate_base`` and ``device``, and
                optionally ``tokenizer`` (used when the validator has none).
            input_ids: Accepted for interface compatibility; not used by the
                repair pass, which encodes its own prompt.
            draft_text: The output that failed validation. Returned unchanged
                when the report is valid or no tokenizer is available.
            report: Report produced by ``validate_content``.
            **kwargs: Generation options forwarded to ``model.generate_base``
                after the reserved keys are removed.

        Returns:
            The stripped repaired text, or ``draft_text`` when no repair ran.
        """
        if report["is_valid"]:
            return draft_text

        with ProductionThinking(self.logger, "VALIDATOR", "Initiating Neural Self-Correction"):
            # NOTE(review): only the risks reach the prompt; neither
            # report["suggested_fixes"] nor draft_text is included, so the
            # model is asked to correct text it is not shown here - confirm
            # that generate_base supplies that context some other way.
            fixes = "\n".join(f"- {r}" for r in report["risks"])
            repair_prompt = (
                f"<|im_start|>system\n[VALIDATION FAILURE]\nThe previous output failed structural checks:\n{fixes}\n"
                f"Please fix the formatting and return only the CORRECTED output.<|im_end|>\n"
                f"<|im_start|>assistant\n<|improve|>\n"
            )

            tokenizer = self.tokenizer or getattr(model, "tokenizer", None)
            if tokenizer is None:
                self.logger.error("VALIDATOR", "No tokenizer found for repair. Returning original draft.")
                return draft_text

            inputs = tokenizer(repair_prompt, return_tensors="pt").to(model.device)

            gen_kwargs = {
                key: value
                for key, value in kwargs.items()
                if key not in self._RESERVED_GEN_KWARGS
            }

            out_ids = model.generate_base(inputs.input_ids, **gen_kwargs)
            # Skip the prompt tokens so only the newly generated text is decoded.
            # Decode with the resolved tokenizer: self.tokenizer may be None
            # when the model's own tokenizer was used as the fallback.
            prompt_len = inputs.input_ids.shape[1]
            repaired_text = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

            self.logger.success("VALIDATOR", "Neural correction cycle complete.")
            return repaired_text.strip()