# testCase / graph.py
# Author: kumar-aditya
# Commit 25a40cd: "Create graph.py" (verified)
from __future__ import annotations
import json
from typing import Any, Dict, List, TypedDict
from langgraph.graph import END, StateGraph
from agents import (
build_code_analysis_agent,
build_feedback_agent,
build_spec_agent,
build_test_generator_agent,
build_test_plan_agent,
)
from llms import build_llm
from schemas import (
CodeAnalysis,
FeedbackSignal,
FinalReport,
Spec,
StudentTestSuite,
TestCaseList,
TestCase,
TestPlan,
)
class GraphState(TypedDict):
    """Shared state passed between the nodes wired up in ``build_graph``.

    ``run_pipeline`` seeds the input keys; each node returns a partial dict
    containing only the keys it updates.
    """

    # --- Inputs seeded by run_pipeline ---
    problem: str
    description: str
    constraints: str
    code: str  # source to analyse; may be blank (node_analysis then skips the LLM)
    language: str
    per_category: int  # requested cases per category; node_plan clamps it to 2..3
    student_count: int  # number of generator runs performed by node_generate
    # --- Refinement-loop bookkeeping ---
    iteration: int  # set to 0 by node_start, incremented by node_feedback on refine
    # --- Node outputs ---
    spec: Spec  # written by node_spec
    analysis: CodeAnalysis  # written by node_analysis
    plan: TestPlan  # written by node_plan
    suites: List[StudentTestSuite]  # written by node_generate
    feedback: FeedbackSignal  # written by node_feedback
    # Seeded by run_pipeline; no reducer is declared, so node_generate's
    # returned list replaces any previous value.
    issues: List[str]
def _category_targets(per_category: int) -> Dict[str, int]:
categories = [
"Basic cases",
"Boundary cases",
"Random cases",
"Stress cases",
"Invalid/robustness cases",
"Bug-targeted cases",
]
return {category: per_category for category in categories}
def _normalize_category(label: str) -> str:
lower = label.strip().lower()
if "basic" in lower:
return "Basic cases"
if "boundary" in lower or "edge" in lower:
return "Boundary cases"
if "random" in lower:
return "Random cases"
if "stress" in lower:
return "Stress cases"
if "invalid" in lower or "robust" in lower:
return "Invalid/robustness cases"
if "bug" in lower:
return "Bug-targeted cases"
return label
def _enforce_targets(
    cases: List[TestCase], targets: Dict[str, int]
) -> tuple[List[TestCase], Dict[str, int]]:
    """Trim generated cases to the per-category targets.

    Each case's ``category`` is rewritten in place to its canonical name via
    ``_normalize_category``. Cases whose category is not a key of ``targets``
    are dropped.

    Args:
        cases: generated test cases (mutated: ``category`` is normalised).
        targets: wanted number of cases per canonical category.

    Returns:
        ``(kept, shortfall)``:
        - ``kept`` holds at most ``targets[c]`` cases per category, in
          ``targets`` order and then input order.
        - ``shortfall`` maps each under-filled category to how many cases
          are missing.
    """
    buckets: Dict[str, List[TestCase]] = {name: [] for name in targets}
    for case in cases:
        case.category = _normalize_category(case.category)
        bucket = buckets.get(case.category)
        if bucket is not None:
            bucket.append(case)

    kept: List[TestCase] = []
    shortfall: Dict[str, int] = {}
    for name, wanted in targets.items():
        chosen = buckets[name][:wanted]
        kept += chosen
        gap = wanted - len(chosen)
        if gap > 0:
            shortfall[name] = gap
    return kept, shortfall
def _strip_markdown(text: str) -> str:
stripped = text.strip()
if stripped.startswith("```") and stripped.endswith("```"):
lines = stripped.splitlines()
if len(lines) >= 2:
return "\n".join(lines[1:-1]).strip()
return stripped
def _extract_json_blob(text: str) -> str:
start_obj = text.find("{")
end_obj = text.rfind("}")
if start_obj != -1 and end_obj != -1 and end_obj > start_obj:
return text[start_obj : end_obj + 1]
start_list = text.find("[")
end_list = text.rfind("]")
if start_list != -1 and end_list != -1 and end_list > start_list:
return text[start_list : end_list + 1]
return text
def _scan_string(text: str, start: int) -> int:
index = start + 1
escaped = False
while index < len(text):
char = text[index]
if escaped:
escaped = False
elif char == "\\":
escaped = True
elif char == '"':
return index + 1
index += 1
return len(text)
def _find_expr_end(text: str, start: int) -> int:
index = start
in_string = False
escaped = False
while index < len(text):
char = text[index]
if in_string:
if escaped:
escaped = False
elif char == "\\":
escaped = True
elif char == '"':
in_string = False
else:
if char == '"':
in_string = True
elif char in {",", "}", "]"}:
return index
index += 1
return len(text)
def _tokenize_expr(expr: str) -> List[tuple[str, Any]] | None:
tokens: List[tuple[str, Any]] = []
index = 0
while index < len(expr):
char = expr[index]
if char.isspace():
index += 1
continue
if char == '"':
end = _scan_string(expr, index)
literal = expr[index:end]
try:
value = json.loads(literal)
except json.JSONDecodeError:
return None
tokens.append(("str", value))
index = end
continue
if char.isdigit():
end = index
while end < len(expr) and expr[end].isdigit():
end += 1
tokens.append(("int", int(expr[index:end])))
index = end
continue
if char in {"+", "*"}:
tokens.append(("op", char))
index += 1
continue
return None
return tokens
def _eval_string_expression(
    expr: str, *, max_length: int | None = 10_000
) -> str | None:
    """Evaluate a small string-building expression into a plain string.

    Grammar (tokens come from ``_tokenize_expr``)::

        expr   := term ('+' term)*
        term   := factor ('*' factor)*
        factor := string literal | non-negative integer

    Term evaluation:
      * A term containing one string literal repeats that string by its
        integer factors, so ``"ab" * 3`` and ``3 * "ab"`` both give
        ``"ababab"``.
      * A term made only of integers repeats the *text* of its first integer
        (``5 * 2`` -> ``"55"``), as the original implementation did.
      * Terms are joined with ``+``; an integer term contributes its decimal
        text.

    Args:
        expr: expression source, e.g. ``'"a" * 3 + "b"'``.
        max_length: bound on the length of every intermediate and final
            string; only a prefix of the true result is kept. This keeps an
            expression like ``"a" * 1000000000`` from allocating gigabytes or
            overflowing. ``None`` disables the bound.

    Returns:
        The evaluated string, or ``None`` when the expression cannot be
        tokenized, contains no string literal, or does not fit the grammar
        (e.g. two string literals multiplied together).
    """
    tokens = _tokenize_expr(expr)
    if not tokens:
        return None
    if not any(kind == "str" for kind, _ in tokens):
        return None

    def clip(text: str) -> str:
        return text if max_length is None else text[:max_length]

    def repeat(text: str, times: int) -> str:
        if max_length is not None and text:
            # Only the first max_length characters can survive, so never
            # materialise more repetitions than needed to reach that length.
            times = min(times, -(-max_length // len(text)))
        return clip(text * times)

    def parse_term(pos: int) -> tuple[str | None, int]:
        if pos >= len(tokens) or tokens[pos][0] not in {"str", "int"}:
            return None, pos
        factors = [tokens[pos]]
        pos += 1
        while pos + 1 < len(tokens) and tokens[pos] == ("op", "*"):
            if tokens[pos + 1][0] not in {"str", "int"}:
                return None, pos
            factors.append(tokens[pos + 1])
            pos += 2
        strings = [payload for kind, payload in factors if kind == "str"]
        counts = [payload for kind, payload in factors if kind == "int"]
        if len(strings) > 1:
            # "a" * "b" is meaningless.
            return None, pos
        if strings:
            base, repeats = strings[0], counts
        else:
            base, repeats = str(counts[0]), counts[1:]
        value = clip(base)
        for times in repeats:
            value = repeat(value, times)
        return value, pos

    result, pos = parse_term(0)
    if result is None:
        return None
    while pos < len(tokens):
        if tokens[pos] != ("op", "+"):
            return None
        term, pos = parse_term(pos + 1)
        if term is None:
            return None
        result = clip(result + term)
    return result
def _cap_string(value: str, limit: int = 200) -> str:
if len(value) <= limit:
return value
return value[:limit]
def _rewrite_repeat_calls(text: str) -> str:
    """Rewrite JavaScript-style ``"s".repeat(n)`` calls into ``"s" * n``.

    A double-quoted string literal followed (after optional whitespace) by
    ``.repeat(<digits>)`` is rewritten: the literal is copied verbatim and
    the call becomes `` * <digits>``. Every other part of ``text``, including
    literals not followed by such a call, is copied through unchanged.
    ``_parse_case_list`` runs this before ``_replace_string_expressions``,
    which then evaluates the ``*`` form.
    """
    output: List[str] = []
    index = 0
    while index < len(text):
        char = text[index]
        if char == '"':
            # Copy the whole literal so quotes/escapes inside it are never rewritten.
            start = index
            end = _scan_string(text, index)
            output.append(text[start:end])
            # Look past whitespace for a ".repeat" suffix.
            probe = end
            while probe < len(text) and text[probe].isspace():
                probe += 1
            if text.startswith(".repeat", probe):
                cursor = probe + len(".repeat")
                while cursor < len(text) and text[cursor].isspace():
                    cursor += 1
                if cursor < len(text) and text[cursor] == "(":
                    cursor += 1
                    while cursor < len(text) and text[cursor].isspace():
                        cursor += 1
                    number_start = cursor
                    while cursor < len(text) and text[cursor].isdigit():
                        cursor += 1
                    number = text[number_start:cursor]
                    while cursor < len(text) and text[cursor].isspace():
                        cursor += 1
                    # Only a bare integer argument closed by ")" is accepted;
                    # anything else leaves the text after the literal as-is.
                    if number and cursor < len(text) and text[cursor] == ")":
                        output.append(f" * {number}")
                        index = cursor + 1
                        continue
            index = end
            continue
        output.append(char)
        index += 1
    return "".join(output)
def _replace_string_expressions(text: str) -> str:
    """Replace string-building expressions with plain JSON string literals.

    Meant for generated output that writes e.g. ``"ab" * 3`` or
    ``"x" + "y"`` where JSON needs a literal. Evaluation is attempted when:
      * a double-quoted string is followed (after optional whitespace) by
        ``+`` or ``*``, or
      * a run of digits is followed by ``*``.

    In those cases the text up to the next top-level ``,``, ``}`` or ``]`` is
    evaluated with ``_eval_string_expression``. A successful result is
    truncated with ``_cap_string`` and emitted via ``json.dumps``. Otherwise
    the original text is copied through unchanged.
    """
    output: List[str] = []
    index = 0
    while index < len(text):
        char = text[index]
        if char == '"':
            start = index
            end = _scan_string(text, index)
            probe = end
            while probe < len(text) and text[probe].isspace():
                probe += 1
            if probe < len(text) and text[probe] in {"+", "*"}:
                # Candidate expression: runs to the next top-level delimiter.
                expr_end = _find_expr_end(text, start)
                expr_text = text[start:expr_end]
                evaluated = _eval_string_expression(expr_text)
                if evaluated is not None:
                    output.append(json.dumps(_cap_string(evaluated)))
                    index = expr_end
                    continue
            # Not an evaluable expression: copy the literal verbatim.
            output.append(text[start:end])
            index = end
            continue
        if char.isdigit():
            start = index
            end = index
            while end < len(text) and text[end].isdigit():
                end += 1
            probe = end
            while probe < len(text) and text[probe].isspace():
                probe += 1
            if probe < len(text) and text[probe] == "*":
                expr_end = _find_expr_end(text, start)
                expr_text = text[start:expr_end]
                evaluated = _eval_string_expression(expr_text)
                if evaluated is not None:
                    output.append(json.dumps(_cap_string(evaluated)))
                    index = expr_end
                    continue
        # Fall through: emit a single character. For a digit run, the
        # remaining digits are revisited on later iterations.
        output.append(char)
        index += 1
    return "".join(output)
def _parse_case_list(raw_text: str) -> TestCaseList:
    """Leniently parse raw generator output into a ``TestCaseList``.

    Steps:
      1. Strip a Markdown fence.
      2. Rewrite ``"s".repeat(n)`` calls.
      3. Fold string-building expressions into literals.
      4. Cut out the JSON blob.

    A bare JSON array is treated as the ``cases`` list.

    This function never raises on bad model output:
      * Unparseable JSON yields an empty list, which callers report as a
        parsing failure.
      * JSON that fails schema validation has each entry of ``cases``
        validated individually, and only the valid ones are kept. One
        malformed case therefore does not discard a whole suite or crash
        the pipeline.
    """
    cleaned = _strip_markdown(raw_text)
    rewritten = _rewrite_repeat_calls(cleaned)
    repaired = _replace_string_expressions(rewritten)
    blob = _extract_json_blob(repaired)
    try:
        data = json.loads(blob)
    except json.JSONDecodeError:
        return TestCaseList(cases=[])
    if isinstance(data, list):
        data = {"cases": data}
    try:
        return TestCaseList.model_validate(data)
    except ValueError:
        # pydantic's ValidationError subclasses ValueError; fall through and
        # salvage whichever individual cases are valid.
        pass
    raw_cases = data.get("cases") if isinstance(data, dict) else None
    if not isinstance(raw_cases, list):
        return TestCaseList(cases=[])
    valid: List[TestCase] = []
    for item in raw_cases:
        try:
            valid.append(TestCase.model_validate(item))
        except ValueError:
            continue
    return TestCaseList(cases=valid)
def node_spec(state: GraphState) -> Dict[str, Any]:
    """Run the spec agent over the problem statement.

    Returns:
        ``{"spec": <parser output>}`` for merging into the graph state.
    """
    llm = build_llm("gemini-3-flash-preview", temperature=0.2)
    prompt, parser = build_spec_agent(llm)
    inputs = {
        key: state[key]
        for key in ("problem", "description", "constraints", "language")
    }
    inputs["format_instructions"] = parser.get_format_instructions()
    spec = (prompt | llm | parser).invoke(inputs)
    return {"spec": spec}
def node_analysis(state: GraphState) -> Dict[str, Any]:
    """Analyse the supplied source code with the code-analysis agent.

    If ``state["code"]`` is blank, no LLM is built and a default
    ``CodeAnalysis()`` is returned instead.

    Returns:
        ``{"analysis": <CodeAnalysis>}``.
    """
    source = state["code"]
    if source.strip():
        llm = build_llm("gemini-2.5-flash", temperature=0.2)
        prompt, parser = build_code_analysis_agent(llm)
        payload = {
            "code": source,
            "language": state["language"],
            "format_instructions": parser.get_format_instructions(),
        }
        return {"analysis": (prompt | llm | parser).invoke(payload)}
    return {"analysis": CodeAnalysis()}
def node_start(state: GraphState) -> Dict[str, Any]:
    """Entry node: reset the refinement counter to zero.

    The incoming state is ignored.
    """
    return dict(iteration=0)
def node_plan(state: GraphState) -> Dict[str, Any]:
    """Ask the planning agent for a ``TestPlan``, then pin its category targets.

    ``per_category`` is clamped to the range [2, 3] before it is sent to the
    agent. The returned plan's ``targets`` and ``categories`` are then
    overwritten with the six fixed categories, so later enforcement does not
    depend on what the model produced.

    Returns:
        ``{"plan": <TestPlan>}``.
    """
    llm = build_llm("gemini-3.1-flash-lite-preview", temperature=0.3)
    prompt, parser = build_test_plan_agent(llm)
    chain = prompt | llm | parser
    requested = state["per_category"]
    per_category = max(2, min(3, requested))
    payload = {
        "spec": state["spec"].model_dump(),
        "analysis": state["analysis"].model_dump(),
        "issues": state.get("issues", []),
        "per_category": per_category,
        "format_instructions": parser.get_format_instructions(),
    }
    plan = chain.invoke(payload)
    targets = _category_targets(per_category)
    plan.targets = targets
    plan.categories = list(targets)
    return {"plan": plan}
def _response_text(response: Any) -> str:
    """Extract the text payload from a chat-model response.

    The ``content`` attribute may be either of:
      * a string, which is returned as-is;
      * a list of content blocks (plain strings, or dicts with a ``"text"``
        key), whose text parts are concatenated.

    Non-text blocks, such as those of another ``type``, are skipped. Objects
    without ``content`` are stringified, as before.
    """
    if not hasattr(response, "content"):
        return str(response)
    content = response.content
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type", "text") == "text"
                and isinstance(block.get("text"), str)
            ):
                parts.append(block["text"])
        return "".join(parts)
    return str(content)


def node_generate(state: GraphState) -> Dict[str, Any]:
    """Generate one test suite per simulated student and collect issues.

    For each student id from 1 to ``student_count``:
      * The generator chain is invoked once and its raw text is parsed
        leniently with ``_parse_case_list``.
      * An empty parse is recorded as an issue and the student is skipped.
      * Otherwise the cases are trimmed to the plan's per-category targets,
        and any shortfall is recorded as an issue.

    Returns:
        ``{"suites": [...], "issues": [...]}``.
    """
    llm = build_llm("gemini-2.5-flash-lite", temperature=0.5)
    prompt, parser = build_test_generator_agent(llm)
    # The parser only supplies format instructions; raw text is parsed by
    # _parse_case_list, which tolerates sloppy JSON.
    chain = prompt | llm
    plan = state["plan"]
    # These inputs do not change between students, so build them once.
    spec_dump = state["spec"].model_dump()
    plan_dump = plan.model_dump()
    format_instructions = parser.get_format_instructions()
    suites: List[StudentTestSuite] = []
    issues: List[str] = []
    for student_id in range(1, state["student_count"] + 1):
        response = chain.invoke(
            {
                "spec": spec_dump,
                "plan": plan_dump,
                "student_id": student_id,
                "format_instructions": format_instructions,
            }
        )
        case_list = _parse_case_list(_response_text(response))
        if not case_list.cases:
            issues.append(f"Student {student_id} output parsing failed")
            continue
        enforced, missing = _enforce_targets(case_list.cases, plan.targets)
        suites.append(StudentTestSuite(student_id=student_id, cases=enforced))
        if missing:
            issues.append(
                f"Student {student_id} missing categories: {sorted(missing.keys())}"
            )
    return {"suites": suites, "issues": issues}
def node_feedback(state: GraphState) -> Dict[str, Any]:
    """Review the spec, plan and recorded issues with the feedback agent.

    The iteration counter is incremented whenever a refinement is warranted,
    meaning the agent asks for one or there are outstanding issues.

    Returns:
        ``{"feedback": <FeedbackSignal>, "iteration": <int>}``.
    """
    llm = build_llm("gemini-3-flash-preview", temperature=0.2)
    prompt, parser = build_feedback_agent(llm)
    issues = state.get("issues", [])
    feedback = (prompt | llm | parser).invoke(
        {
            "spec": state["spec"].model_dump(),
            "plan": state["plan"].model_dump(),
            "issues": issues,
            "format_instructions": parser.get_format_instructions(),
        }
    )
    iteration = state.get("iteration", 0)
    if feedback.needs_refine or issues:
        iteration += 1
    return {"feedback": feedback, "iteration": iteration}
def should_refine(state: GraphState) -> str:
    """Choose the route after the feedback node.

    Returns:
        ``"final"`` once ``iteration`` exceeds the refine budget (1).
        Otherwise ``"refine"`` if issues are recorded or the feedback asks
        for refinement, and ``"final"`` if neither holds.
    """
    refine_budget = 1
    if state.get("iteration", 0) > refine_budget:
        return "final"
    wants_refine = bool(state.get("issues")) or state["feedback"].needs_refine
    return "refine" if wants_refine else "final"
def build_graph():
    """Assemble and compile the test-generation workflow.

    Node order and wiring:
      * ``start`` leads to both ``spec`` and ``analysis``.
      * Both of those lead to ``plan``, then ``generate``, then ``feedback``.
      * From ``feedback``, ``should_refine`` routes back to ``plan`` or
        to END.
    """
    workflow = StateGraph(GraphState)
    nodes = (
        ("start", node_start),
        ("spec", node_spec),
        ("analysis", node_analysis),
        ("plan", node_plan),
        ("generate", node_generate),
        ("feedback", node_feedback),
    )
    for name, handler in nodes:
        workflow.add_node(name, handler)
    workflow.set_entry_point("start")
    edges = (
        ("start", "spec"),
        ("start", "analysis"),
        ("spec", "plan"),
        ("analysis", "plan"),
        ("plan", "generate"),
        ("generate", "feedback"),
    )
    for source, target in edges:
        workflow.add_edge(source, target)
    routes = {"refine": "plan", "final": END}
    workflow.add_conditional_edges("feedback", should_refine, routes)
    return workflow.compile()
def run_pipeline(
    *,
    problem: str,
    description: str,
    constraints: str,
    code: str,
    language: str,
    student_count: int,
    per_category: int,
    issues: List[str] | None = None,
) -> FinalReport:
    """Run the compiled workflow once and package its results.

    All arguments are keyword-only and seed the initial graph state.
    ``issues`` (default: none) is visible to the first planning round.

    Returns:
        A ``FinalReport`` built from the final state's ``spec``,
        ``analysis``, ``plan``, ``suites`` and ``feedback``.
    """
    app = build_graph()
    seed_state = dict(
        problem=problem,
        description=description,
        constraints=constraints,
        code=code,
        language=language,
        student_count=student_count,
        per_category=per_category,
        issues=issues or [],
    )
    final_state = app.invoke(seed_state)
    report_fields = ("spec", "analysis", "plan", "suites", "feedback")
    return FinalReport(**{field: final_state[field] for field in report_fields})