# Phillnet-2 / AgenticScaffold / validator.py
# Uploaded by ayjays132 in commit 101858b ("Upload 478 files").
import torch
import re
from typing import Optional, Dict, Any, List, Tuple
from .VisualScaffold import ProductionLogger, ProductionThinking
class ActionPlanValidator:
    """
    Advanced Action Plan Validator with Agentic Self-Correction.
    Ensures structural integrity and logical consistency of machine-readable outputs.

    Typical flow: ``validate_content`` produces a report dict; when that report
    is not valid, ``self_correct`` asks the model to repair the draft.
    """

    def __init__(self, tokenizer=None, workspace=None, file_index=None, show_thinking: bool = True):
        """
        Args:
            tokenizer: Optional tokenizer used by ``self_correct``; when None,
                ``self_correct`` falls back to ``model.tokenizer``.
            workspace: Optional object exposing ``read_objective()``; used by the
                brevity heuristic in ``validate_content``.
            file_index: Stored on the instance; not read by this class.
            show_thinking: Forwarded to ``ProductionLogger``.
        """
        self.tokenizer = tokenizer
        self.workspace = workspace
        self.file_index = file_index
        self.logger = ProductionLogger(show_thinking=show_thinking)
        # Heuristic output budget in approximate tokens (estimated as chars // 4).
        self._max_candidate_tokens = 20_000
        # NOTE(review): not read anywhere in this class; self_correct performs a
        # single repair pass.
        self._repair_loops = 2

    def get_tidbit(self, user_prompt: str = "") -> str:
        """
        Dynamic structural tidbit.
        Only triggers if the prompt suggests code changes or complex formatting.

        Args:
            user_prompt: The user's request; None is treated as empty.

        Returns:
            A short formatting reminder when the prompt contains an edit-related
            keyword (plain substring match), otherwise an empty string.
        """
        p_lower = (user_prompt or "").lower()
        if any(w in p_lower for w in ["modify", "edit", "change", "diff", "replace", "refactor", "patch"]):
            return (
                "[STRUCTURAL TIDBIT]\n"
                "- Use surgical patches (<|diff_start|>) for file edits.\n"
                "- Maintain strict balance for markdown blocks (```).\n"
            )
        return ""

    def validate_content(self, text: str) -> Dict[str, Any]:
        """
        Validates the content of a generation, checking for schema and common failures.

        Checks, in order:
        1. every ``<|diff_start|>`` has a ``<|diff_end|>`` before the next start tag,
           and each enclosed patch passes ``diff_utils.validate_diff``;
        2. brevity relative to the workspace objective (warning only);
        3. an approximate token budget (``len(text) // 4``);
        4. balanced markdown code fences.

        Args:
            text: The generated output to check.

        Returns:
            A report dict with keys ``is_valid`` (bool), ``risks`` (list of str),
            ``suggested_fixes`` (list of str) and ``stats`` (dict; holds
            ``approx_tokens`` for non-empty text).
        """
        report: Dict[str, Any] = {
            "is_valid": True,
            "risks": [],
            "suggested_fixes": [],
            "stats": {},
        }
        if not text or not text.strip():
            report["is_valid"] = False
            report["risks"].append("Output is empty or null.")
            return report

        # 1. Structural Validation (Diffs - Support Multiple)
        diff_starts = list(re.finditer(r"<\|diff_start\|>", text))
        if diff_starts:
            # NOTE(review): absolute import while this module's other imports are
            # package-relative; confirm diff_utils is importable at top level.
            import diff_utils

            for i, match in enumerate(diff_starts):
                try:
                    start_idx = match.end()
                    # Look for the closing tag only up to the next opening tag (or EOF),
                    # so an unclosed block cannot borrow a later block's <|diff_end|>.
                    stop_idx = diff_starts[i + 1].start() if i + 1 < len(diff_starts) else len(text)
                    end_match = re.search(r"<\|diff_end\|>", text[start_idx:stop_idx])
                    if not end_match:
                        report["is_valid"] = False
                        report["risks"].append("Unclosed <|diff_start|> tag.")
                        continue
                    diff_content = text[start_idx : start_idx + end_match.start()].strip()
                    if not diff_utils.validate_diff(diff_content):
                        report["is_valid"] = False
                        report["risks"].append(f"Malformed surgical patch in block starting at char {start_idx}.")
                        report["suggested_fixes"].append("Ensure the diff contains valid hunk headers (@@ -L,C +L,C @@).")
                except Exception as e:
                    # Best-effort: a failure inside the diff checker is reported, not raised.
                    report["is_valid"] = False
                    report["risks"].append(f"Internal error validating diff: {e}")

        # 2. Holistic Coherence (Grounding)
        if self.workspace:
            # Guard against a workspace that has no objective set yet.
            objective = (self.workspace.read_objective() or "").lower()
            # If the output text is too short but objective is complex, warn
            if len(text) < 100 and len(objective) > 200:
                report["risks"].append("Output seems suspiciously brief given the complex objective.")
                # We don't invalidate completely, just warn or nudge in repair

        # 3. Heuristic Complexity Check (~4 characters per token)
        approx_tokens = len(text) // 4
        report["stats"]["approx_tokens"] = approx_tokens
        if approx_tokens > self._max_candidate_tokens:
            report["risks"].append(f"Payload exceeds heuristic budget (~{approx_tokens} tokens).")
            report["is_valid"] = False

        # 4. Code Block Balance
        if text.count("```") % 2 != 0:
            report["risks"].append("Unbalanced markdown code blocks detected.")
            report["is_valid"] = False
        return report

    def _build_repair_prompt(self, draft_text: str, report: Dict[str, Any]) -> str:
        """Build the ChatML repair prompt from the failed draft and its validation report."""
        risk_lines = "\n".join(f"- {r}" for r in report.get("risks", []))
        hint_lines = "\n".join(f"- {s}" for s in report.get("suggested_fixes", []))
        hints = f"Suggested fixes:\n{hint_lines}\n" if hint_lines else ""
        # Clip the draft to the heuristic budget (~4 chars/token) so an
        # over-long draft (itself a validation failure) cannot bloat the prompt.
        max_chars = self._max_candidate_tokens * 4
        return (
            f"<|im_start|>system\n[VALIDATION FAILURE]\nThe previous output failed structural checks:\n{risk_lines}\n"
            f"{hints}"
            f"Please fix the formatting and return only the CORRECTED output.<|im_end|>\n"
            f"<|im_start|>user\n[PREVIOUS OUTPUT]\n{draft_text[:max_chars]}<|im_end|>\n"
            f"<|im_start|>assistant\n<|improve|>\n"
        )

    @torch.inference_mode()
    def self_correct(self, model, input_ids: torch.LongTensor, draft_text: str, report: Dict[str, Any], **kwargs) -> str:
        """
        Agentic Repair Loop: Prompts the model to fix its own validation failures.

        Args:
            model: Object exposing ``device``, ``generate_base(input_ids, **kwargs)``
                and, optionally, a ``tokenizer`` attribute.
            input_ids: Original prompt ids. NOTE(review): currently unused; the
                repair prompt is built only from the draft and the report.
            draft_text: The output that failed validation.
            report: A report as produced by ``validate_content``.
            **kwargs: Extra generation kwargs. Keys that would clash with the
                repair prompt's own inputs are dropped.

        Returns:
            The repaired text (stripped). Returns ``draft_text`` unchanged in three
            cases: the report is already valid, no tokenizer is available, or the
            model produced an empty repair.
        """
        if report["is_valid"]:
            return draft_text

        with ProductionThinking(self.logger, "VALIDATOR", "Initiating Neural Self-Correction"):
            tokenizer = self.tokenizer or getattr(model, "tokenizer", None)
            if tokenizer is None:
                self.logger.error("VALIDATOR", "No tokenizer found for repair. Returning original draft.")
                return draft_text

            repair_prompt = self._build_repair_prompt(draft_text, report)
            inputs = tokenizer(repair_prompt, return_tensors="pt").to(model.device)
            prompt_len = inputs.input_ids.shape[1]

            reserved = {"input_ids", "attention_mask", "position_ids", "past_key_values", "use_cache", "labels", "epistemic_strictness"}
            gen_kwargs = {k: v for k, v in kwargs.items() if k not in reserved}
            out_ids = model.generate_base(inputs.input_ids, **gen_kwargs)

            # Decode with the tokenizer actually used for encoding (it may be model.tokenizer).
            repaired_text = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True).strip()
            if not repaired_text:
                self.logger.error("VALIDATOR", "Repair produced empty output. Returning original draft.")
                return draft_text

            self.logger.success("VALIDATOR", "Neural correction cycle complete.")
            return repaired_text