Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from typing import Any, Dict, List, TypedDict | |
| from langgraph.graph import END, StateGraph | |
| from agents import ( | |
| build_code_analysis_agent, | |
| build_feedback_agent, | |
| build_spec_agent, | |
| build_test_generator_agent, | |
| build_test_plan_agent, | |
| ) | |
| from llms import build_llm | |
| from schemas import ( | |
| CodeAnalysis, | |
| FeedbackSignal, | |
| FinalReport, | |
| Spec, | |
| StudentTestSuite, | |
| TestCaseList, | |
| TestCase, | |
| TestPlan, | |
| ) | |
class GraphState(TypedDict):
    """Shared state threaded through the LangGraph nodes wired in ``build_graph``.

    ``run_pipeline`` seeds only the input fields and ``issues``; the other keys
    are written by the nodes noted below.  NOTE(review): this TypedDict is
    total, yet the initial state omits most keys -- presumably relying on
    LangGraph not enforcing totality at invoke time; confirm.
    """

    # --- Inputs supplied by run_pipeline ---
    problem: str
    description: str
    constraints: str
    code: str  # source under test; node_analysis skips the LLM when blank
    language: str
    per_category: int  # requested cases per category; node_plan clamps it to [2, 3]
    student_count: int  # node_generate runs once per student_id in 1..student_count
    # --- Loop bookkeeping ---
    iteration: int  # reset to 0 by node_start, incremented by node_feedback when a refine is needed
    # --- Node outputs ---
    spec: Spec  # written by node_spec
    analysis: CodeAnalysis  # written by node_analysis
    plan: TestPlan  # written by node_plan
    suites: List[StudentTestSuite]  # written by node_generate
    feedback: FeedbackSignal  # written by node_feedback
    issues: List[str]  # seeded by run_pipeline, rewritten by node_generate;
    # no reducer is declared, so presumably each write replaces the previous list
| def _category_targets(per_category: int) -> Dict[str, int]: | |
| categories = [ | |
| "Basic cases", | |
| "Boundary cases", | |
| "Random cases", | |
| "Stress cases", | |
| "Invalid/robustness cases", | |
| "Bug-targeted cases", | |
| ] | |
| return {category: per_category for category in categories} | |
| def _normalize_category(label: str) -> str: | |
| lower = label.strip().lower() | |
| if "basic" in lower: | |
| return "Basic cases" | |
| if "boundary" in lower or "edge" in lower: | |
| return "Boundary cases" | |
| if "random" in lower: | |
| return "Random cases" | |
| if "stress" in lower: | |
| return "Stress cases" | |
| if "invalid" in lower or "robust" in lower: | |
| return "Invalid/robustness cases" | |
| if "bug" in lower: | |
| return "Bug-targeted cases" | |
| return label | |
def _enforce_targets(
    cases: List[TestCase], targets: Dict[str, int]
) -> tuple[List[TestCase], Dict[str, int]]:
    """Trim generated cases down to the per-category quotas in ``targets``.

    Every case's ``category`` is normalized in place (the input objects are
    mutated).  Cases whose normalized category is not a key of ``targets``
    are dropped.  Within each category the first ``quota`` cases are kept,
    and the output follows the iteration order of ``targets``.

    Returns:
        ``(kept_cases, shortfall)``, where ``shortfall`` maps each category
        that had fewer cases than its quota to the number still missing.
    """
    buckets: Dict[str, List[TestCase]] = {name: [] for name in targets}
    for case in cases:
        label = _normalize_category(case.category)
        case.category = label
        bucket = buckets.get(label)
        if bucket is not None:
            bucket.append(case)

    kept: List[TestCase] = []
    shortfall: Dict[str, int] = {}
    for name, quota in targets.items():
        chosen = buckets[name][:quota]
        kept += chosen
        gap = quota - len(chosen)
        if gap > 0:
            shortfall[name] = gap
    return kept, shortfall
| def _strip_markdown(text: str) -> str: | |
| stripped = text.strip() | |
| if stripped.startswith("```") and stripped.endswith("```"): | |
| lines = stripped.splitlines() | |
| if len(lines) >= 2: | |
| return "\n".join(lines[1:-1]).strip() | |
| return stripped | |
| def _extract_json_blob(text: str) -> str: | |
| start_obj = text.find("{") | |
| end_obj = text.rfind("}") | |
| if start_obj != -1 and end_obj != -1 and end_obj > start_obj: | |
| return text[start_obj : end_obj + 1] | |
| start_list = text.find("[") | |
| end_list = text.rfind("]") | |
| if start_list != -1 and end_list != -1 and end_list > start_list: | |
| return text[start_list : end_list + 1] | |
| return text | |
| def _scan_string(text: str, start: int) -> int: | |
| index = start + 1 | |
| escaped = False | |
| while index < len(text): | |
| char = text[index] | |
| if escaped: | |
| escaped = False | |
| elif char == "\\": | |
| escaped = True | |
| elif char == '"': | |
| return index + 1 | |
| index += 1 | |
| return len(text) | |
| def _find_expr_end(text: str, start: int) -> int: | |
| index = start | |
| in_string = False | |
| escaped = False | |
| while index < len(text): | |
| char = text[index] | |
| if in_string: | |
| if escaped: | |
| escaped = False | |
| elif char == "\\": | |
| escaped = True | |
| elif char == '"': | |
| in_string = False | |
| else: | |
| if char == '"': | |
| in_string = True | |
| elif char in {",", "}", "]"}: | |
| return index | |
| index += 1 | |
| return len(text) | |
def _tokenize_expr(expr: str) -> List[tuple[str, Any]] | None:
    """Split a string-building expression into ``(kind, value)`` tokens.

    Token kinds:
        * ``"str"``: a double-quoted literal, decoded with ``json.loads``;
        * ``"int"``: a run of digits, converted with ``int``;
        * ``"op"``: ``"+"`` or ``"*"``.

    Whitespace is skipped.  Returns ``None`` for any other character, for a
    string literal that is not valid JSON, or for a digit run that ``int``
    cannot convert.  The last case covers two things: ``str.isdigit`` accepts
    characters such as ``"²"`` that ``int`` rejects, and ``int`` refuses
    overly long digit strings (``sys.get_int_max_str_digits``).  Previously
    that ValueError escaped and aborted parsing of the whole response.
    """
    tokens: List[tuple[str, Any]] = []
    index = 0
    while index < len(expr):
        char = expr[index]
        if char.isspace():
            index += 1
            continue
        if char == '"':
            end = _scan_string(expr, index)
            try:
                value = json.loads(expr[index:end])
            except json.JSONDecodeError:
                return None
            tokens.append(("str", value))
            index = end
            continue
        if char.isdigit():
            end = index
            while end < len(expr) and expr[end].isdigit():
                end += 1
            try:
                number = int(expr[index:end])
            except ValueError:
                return None
            tokens.append(("int", number))
            index = end
            continue
        if char in {"+", "*"}:
            tokens.append(("op", char))
            index += 1
            continue
        return None
    return tokens
def _eval_string_expression(expr: str, max_length: int | None = 10_000) -> str | None:
    """Evaluate a tiny string-building expression such as ``"ab" * 3 + "c"``.

    Grammar (``*`` binds tighter than ``+``; both associate left)::

        expr := term ("+" term)*
        term := atom ("*" INT)*
        atom := STRING | INT "*" STRING | INT

    An integer that is not a repetition count is stringified, so ``"a" + 5``
    yields ``"a5"``.  The expression must contain at least one string literal.

    Args:
        expr: Expression text, tokenized by ``_tokenize_expr``.
        max_length: Non-negative bound on the result and on every
            intermediate value.  It prevents LLM output such as
            ``"a" * 1000000000`` from allocating gigabytes, and prevents
            ``"a" * 10**20``-sized counts from raising OverflowError.  Results
            are truncated prefixes, so callers that cap to a shorter length
            see identical output.  ``None`` disables the bound.

    Returns:
        The evaluated string, or ``None`` when the expression is empty, holds
        no string literal, or does not fit the grammar.
    """
    tokens = _tokenize_expr(expr)
    if not tokens:
        return None
    if not any(kind == "str" for kind, _ in tokens):
        return None

    def clip(text: str) -> str:
        return text if max_length is None else text[:max_length]

    def repeat(text: str, count: int) -> str:
        if not text:
            return ""  # avoids OverflowError for "" * <huge>
        if max_length is not None:
            # Only enough copies to fill the bound are ever materialized.
            count = min(count, -(-max_length // len(text)))
        return clip(text * count)

    def parse_term(pos: int) -> tuple[str | None, int]:
        if pos >= len(tokens):
            return None, pos
        kind, raw = tokens[pos]
        pos += 1
        if kind == "str":
            value = clip(raw)
        elif kind == "int":
            if (
                pos + 1 < len(tokens)
                and tokens[pos] == ("op", "*")
                and tokens[pos + 1][0] == "str"
            ):
                # ``N * "s"``: Python's int-times-string repetition.
                value = repeat(tokens[pos + 1][1], raw)
                pos += 2
            else:
                value = clip(str(raw))
        else:
            return None, pos - 1
        while pos + 1 < len(tokens) and tokens[pos] == ("op", "*"):
            if tokens[pos + 1][0] != "int":
                return None, pos
            value = repeat(value, tokens[pos + 1][1])
            pos += 2
        return value, pos

    result, pos = parse_term(0)
    if result is None:
        return None
    while pos < len(tokens):
        if tokens[pos] != ("op", "+"):
            return None
        term, pos = parse_term(pos + 1)
        if term is None:
            return None
        result = clip(result + term)
    return result
| def _cap_string(value: str, limit: int = 200) -> str: | |
| if len(value) <= limit: | |
| return value | |
| return value[:limit] | |
def _rewrite_repeat_calls(text: str) -> str:
    """Rewrite JavaScript-style ``"s".repeat(N)`` calls as ``"s" * N``.

    Each double-quoted literal is copied through verbatim.  If it is followed
    (whitespace allowed) by ``.repeat(<digits>)``, that call is replaced by
    `` * <digits>`` so ``_replace_string_expressions`` can fold it later.
    Anything that does not match exactly is left untouched.
    """
    size = len(text)

    def skip_spaces(pos: int) -> int:
        while pos < size and text[pos].isspace():
            pos += 1
        return pos

    pieces: List[str] = []
    pos = 0
    while pos < size:
        if text[pos] != '"':
            pieces.append(text[pos])
            pos += 1
            continue

        close = _scan_string(text, pos)
        pieces.append(text[pos:close])
        resume = close
        after = skip_spaces(close)
        if text.startswith(".repeat", after):
            cur = skip_spaces(after + len(".repeat"))
            if cur < size and text[cur] == "(":
                cur = skip_spaces(cur + 1)
                digits_from = cur
                while cur < size and text[cur].isdigit():
                    cur += 1
                count = text[digits_from:cur]
                cur = skip_spaces(cur)
                if count and cur < size and text[cur] == ")":
                    pieces.append(f" * {count}")
                    resume = cur + 1
        pos = resume
    return "".join(pieces)
def _replace_string_expressions(text: str) -> str:
    """Fold simple string-building expressions into plain JSON string literals.

    Two kinds of starting point are considered: a double-quoted literal
    followed by ``+`` or ``*``, and a run of digits followed by ``*``.  From
    there, the expression up to the next top-level ``,``, ``}`` or ``]`` is
    evaluated with ``_eval_string_expression``.  On success it is replaced
    by ``json.dumps`` of the result capped by ``_cap_string``.  Anything that
    fails to evaluate is copied through unchanged.
    """
    size = len(text)
    pieces: List[str] = []

    def next_non_space(pos: int) -> int:
        while pos < size and text[pos].isspace():
            pos += 1
        return pos

    def fold(start: int) -> tuple[str, int] | None:
        # Returns (replacement JSON literal, index to resume at) or None.
        stop = _find_expr_end(text, start)
        value = _eval_string_expression(text[start:stop])
        if value is None:
            return None
        return json.dumps(_cap_string(value)), stop

    pos = 0
    while pos < size:
        ch = text[pos]
        if ch == '"':
            close = _scan_string(text, pos)
            nxt = next_non_space(close)
            folded = fold(pos) if nxt < size and text[nxt] in "+*" else None
            if folded is None:
                pieces.append(text[pos:close])
                pos = close
            else:
                pieces.append(folded[0])
                pos = folded[1]
            continue
        if ch.isdigit():
            run_end = pos
            while run_end < size and text[run_end].isdigit():
                run_end += 1
            nxt = next_non_space(run_end)
            if nxt < size and text[nxt] == "*":
                folded = fold(pos)
                if folded is not None:
                    pieces.append(folded[0])
                    pos = folded[1]
                    continue
        pieces.append(ch)
        pos += 1
    return "".join(pieces)
def _parse_case_list(raw_text: str) -> TestCaseList:
    """Best-effort parse of one generator response into a ``TestCaseList``.

    Steps, in order:
        1. strip a surrounding Markdown fence;
        2. rewrite ``"s".repeat(N)`` as ``"s" * N``;
        3. fold simple string expressions into JSON literals;
        4. cut out the JSON payload;
        5. validate it.

    A bare JSON array is accepted as the ``cases`` list.

    Malformed model output never raises.  JSON errors, schema-validation
    errors (pydantic's ``ValidationError`` subclasses ``ValueError``) and
    overflow while expanding repetitions all degrade gracefully.  If the
    payload fails validation as a whole, each case is validated on its own
    and only the valid ones are kept.  Previously a single malformed case
    raised out of ``node_generate`` and aborted the whole run.

    Returns:
        The parsed cases, or an empty ``TestCaseList`` when nothing is usable
        (``node_generate`` reports that as a parsing failure).
    """
    cleaned = _strip_markdown(raw_text)
    try:
        repaired = _replace_string_expressions(_rewrite_repeat_calls(cleaned))
    except (ValueError, OverflowError):
        # Expression folding choked (e.g. an astronomically large repeat
        # count); fall back to the unfolded text.
        repaired = cleaned
    blob = _extract_json_blob(repaired)
    try:
        data = json.loads(blob)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return TestCaseList(cases=[])
    if isinstance(data, list):
        data = {"cases": data}
    try:
        return TestCaseList.model_validate(data)
    except ValueError:
        pass
    raw_cases = data.get("cases") if isinstance(data, dict) else None
    if not isinstance(raw_cases, list):
        return TestCaseList(cases=[])
    salvaged = []
    for item in raw_cases:
        try:
            salvaged.extend(TestCaseList.model_validate({"cases": [item]}).cases)
        except ValueError:
            continue  # drop only the malformed case
    return TestCaseList(cases=salvaged)
def node_spec(state: GraphState) -> Dict[str, Any]:
    """Graph node: ask the spec agent to turn the problem statement into a ``Spec``.

    Reads ``problem``, ``description``, ``constraints`` and ``language`` from
    the state and returns ``{"spec": <parsed Spec>}``.
    """
    llm = build_llm("gemini-3-flash-preview", temperature=0.2)
    prompt, parser = build_spec_agent(llm)
    inputs: Dict[str, Any] = {
        key: state[key] for key in ("problem", "description", "constraints", "language")
    }
    inputs["format_instructions"] = parser.get_format_instructions()
    return {"spec": (prompt | llm | parser).invoke(inputs)}
def node_analysis(state: GraphState) -> Dict[str, Any]:
    """Graph node: analyze the submitted code with the code-analysis agent.

    When ``state["code"]`` is blank, no LLM is called and an empty
    ``CodeAnalysis()`` is returned.  Either way the result is
    ``{"analysis": ...}``.
    """
    source = state["code"]
    if source.strip():
        llm = build_llm("gemini-2.5-flash", temperature=0.2)
        prompt, parser = build_code_analysis_agent(llm)
        pipeline = prompt | llm | parser
        result = pipeline.invoke(
            {
                "code": source,
                "language": state["language"],
                "format_instructions": parser.get_format_instructions(),
            }
        )
        return {"analysis": result}
    return {"analysis": CodeAnalysis()}
def node_start(state: GraphState) -> Dict[str, Any]:
    """Graph entry node: reset the refinement counter; the state is otherwise ignored."""
    return dict(iteration=0)
def node_plan(state: GraphState) -> Dict[str, Any]:
    """Graph node: build a ``TestPlan`` from the spec, the code analysis and any issues.

    The requested ``per_category`` is clamped to the range [2, 3].  After the
    agent responds, the plan's ``targets`` and ``categories`` are overwritten
    with the canonical quotas from ``_category_targets``, so the LLM cannot
    change them.
    """
    llm = build_llm("gemini-3.1-flash-lite-preview", temperature=0.3)
    prompt, parser = build_test_plan_agent(llm)
    pipeline = prompt | llm | parser
    quota = max(2, min(3, state["per_category"]))
    payload = {
        "spec": state["spec"].model_dump(),
        "analysis": state["analysis"].model_dump(),
        "issues": state.get("issues", []),
        "per_category": quota,
        "format_instructions": parser.get_format_instructions(),
    }
    plan = pipeline.invoke(payload)
    plan.targets = _category_targets(quota)
    plan.categories = [name for name in plan.targets]
    return {"plan": plan}
def node_generate(state: GraphState) -> Dict[str, Any]:
    """Graph node: generate one test suite per student from the current plan.

    The generator chain has no output parser; the raw text of each response
    is parsed by ``_parse_case_list`` and trimmed to the plan's quotas by
    ``_enforce_targets``.  A student whose output yields no cases gets no
    suite and a "parsing failed" issue.  A student whose suite falls short
    of any quota gets a "missing categories" issue.

    Returns:
        ``{"suites": [...], "issues": [...]}``.
    """
    llm = build_llm("gemini-2.5-flash-lite", temperature=0.5)
    prompt, parser = build_test_generator_agent(llm)
    chain = prompt | llm
    spec = state["spec"]
    plan = state["plan"]
    suites: List[StudentTestSuite] = []
    issues: List[str] = []

    for student_id in range(1, state["student_count"] + 1):
        response = chain.invoke(
            {
                "spec": spec.model_dump(),
                "plan": plan.model_dump(),
                "student_id": student_id,
                "format_instructions": parser.get_format_instructions(),
            }
        )
        if hasattr(response, "content"):
            raw_text = response.content
        else:
            raw_text = str(response)

        case_list = _parse_case_list(raw_text)
        if case_list.cases:
            enforced, missing = _enforce_targets(case_list.cases, plan.targets)
            suites.append(StudentTestSuite(student_id=student_id, cases=enforced))
            if missing:
                issues.append(
                    f"Student {student_id} missing categories: {sorted(missing.keys())}"
                )
        else:
            issues.append(f"Student {student_id} output parsing failed")

    return {"suites": suites, "issues": issues}
def node_feedback(state: GraphState) -> Dict[str, Any]:
    """Graph node: review the spec, the plan and the generation issues with the feedback agent.

    The refinement counter is incremented whenever the agent asks for a
    refine or the generation step reported any issues.

    Returns:
        ``{"feedback": <FeedbackSignal>, "iteration": <updated counter>}``.
    """
    llm = build_llm("gemini-3-flash-preview", temperature=0.2)
    prompt, parser = build_feedback_agent(llm)
    reviewer = prompt | llm | parser
    open_issues = state.get("issues", [])
    feedback = reviewer.invoke(
        {
            "spec": state["spec"].model_dump(),
            "plan": state["plan"].model_dump(),
            "issues": open_issues,
            "format_instructions": parser.get_format_instructions(),
        }
    )
    iteration = state.get("iteration", 0)
    if feedback.needs_refine or open_issues:
        iteration += 1
    return {"feedback": feedback, "iteration": iteration}
def should_refine(state: GraphState) -> str:
    """Routing function for the conditional edge leaving ``feedback``.

    Returns ``"final"`` once the refinement counter exceeds the budget of one
    refine.  Otherwise returns ``"refine"`` when generation reported issues
    or the feedback asked for a refine, and ``"final"`` in all other cases.
    ``state["feedback"]`` is only consulted when there are no issues.
    """
    max_refines = 1
    if state.get("iteration", 0) > max_refines:
        return "final"
    wants_refine = bool(state.get("issues")) or state["feedback"].needs_refine
    return "refine" if wants_refine else "final"
def build_graph():
    """Assemble and compile the pipeline graph.

    Topology: ``start`` fans out to ``spec`` and ``analysis``; both feed
    ``plan``, which flows to ``generate`` and then ``feedback``.
    ``should_refine`` then either loops back to ``plan`` or ends the run.
    """
    graph = StateGraph(GraphState)
    nodes = (
        ("start", node_start),
        ("spec", node_spec),
        ("analysis", node_analysis),
        ("plan", node_plan),
        ("generate", node_generate),
        ("feedback", node_feedback),
    )
    for name, handler in nodes:
        graph.add_node(name, handler)

    graph.set_entry_point("start")
    edges = (
        ("start", "spec"),
        ("start", "analysis"),
        ("spec", "plan"),
        ("analysis", "plan"),
        ("plan", "generate"),
        ("generate", "feedback"),
    )
    for source, target in edges:
        graph.add_edge(source, target)

    routes = {"refine": "plan", "final": END}
    graph.add_conditional_edges("feedback", should_refine, routes)
    return graph.compile()
def run_pipeline(
    *,
    problem: str,
    description: str,
    constraints: str,
    code: str,
    language: str,
    student_count: int,
    per_category: int,
    issues: List[str] | None = None,
) -> FinalReport:
    """Run the full spec → plan → generate → feedback graph and package the result.

    ``issues`` (default: none) seeds the state so the first planning pass
    can take earlier problems into account.  A fresh graph is compiled on
    every call.

    Returns:
        A ``FinalReport`` built from the final state's ``spec``, ``analysis``,
        ``plan``, ``suites`` and ``feedback``.
    """
    initial_state = {
        "problem": problem,
        "description": description,
        "constraints": constraints,
        "code": code,
        "language": language,
        "student_count": student_count,
        "per_category": per_category,
        "issues": issues or [],
    }
    final_state = build_graph().invoke(initial_state)
    report_fields = ("spec", "analysis", "plan", "suites", "feedback")
    return FinalReport(**{field: final_state[field] for field in report_fields})