kumar-aditya commited on
Commit
25a40cd
·
verified ·
1 Parent(s): 7b3f919

Create graph.py

Browse files
Files changed (1) hide show
  1. graph.py +495 -0
graph.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any, Dict, List, TypedDict
5
+
6
+ from langgraph.graph import END, StateGraph
7
+
8
+ from agents import (
9
+ build_code_analysis_agent,
10
+ build_feedback_agent,
11
+ build_spec_agent,
12
+ build_test_generator_agent,
13
+ build_test_plan_agent,
14
+ )
15
+ from llms import build_llm
16
+ from schemas import (
17
+ CodeAnalysis,
18
+ FeedbackSignal,
19
+ FinalReport,
20
+ Spec,
21
+ StudentTestSuite,
22
+ TestCaseList,
23
+ TestCase,
24
+ TestPlan,
25
+ )
26
+
27
+
28
+ class GraphState(TypedDict):
29
+ problem: str
30
+ description: str
31
+ constraints: str
32
+ code: str
33
+ language: str
34
+ per_category: int
35
+ student_count: int
36
+ iteration: int
37
+ spec: Spec
38
+ analysis: CodeAnalysis
39
+ plan: TestPlan
40
+ suites: List[StudentTestSuite]
41
+ feedback: FeedbackSignal
42
+ issues: List[str]
43
+
44
+
45
+ def _category_targets(per_category: int) -> Dict[str, int]:
46
+ categories = [
47
+ "Basic cases",
48
+ "Boundary cases",
49
+ "Random cases",
50
+ "Stress cases",
51
+ "Invalid/robustness cases",
52
+ "Bug-targeted cases",
53
+ ]
54
+ return {category: per_category for category in categories}
55
+
56
+
57
+ def _normalize_category(label: str) -> str:
58
+ lower = label.strip().lower()
59
+ if "basic" in lower:
60
+ return "Basic cases"
61
+ if "boundary" in lower or "edge" in lower:
62
+ return "Boundary cases"
63
+ if "random" in lower:
64
+ return "Random cases"
65
+ if "stress" in lower:
66
+ return "Stress cases"
67
+ if "invalid" in lower or "robust" in lower:
68
+ return "Invalid/robustness cases"
69
+ if "bug" in lower:
70
+ return "Bug-targeted cases"
71
+ return label
72
+
73
+
74
+ def _enforce_targets(
75
+ cases: List[TestCase], targets: Dict[str, int]
76
+ ) -> tuple[List[TestCase], Dict[str, int]]:
77
+ by_category: Dict[str, List[TestCase]] = {category: [] for category in targets}
78
+ for case in cases:
79
+ normalized = _normalize_category(case.category)
80
+ case.category = normalized
81
+ if normalized in by_category:
82
+ by_category[normalized].append(case)
83
+
84
+ enforced: List[TestCase] = []
85
+ missing: Dict[str, int] = {}
86
+ for category, count in targets.items():
87
+ selected = by_category.get(category, [])[:count]
88
+ enforced.extend(selected)
89
+ remaining = count - len(selected)
90
+ if remaining > 0:
91
+ missing[category] = remaining
92
+ return enforced, missing
93
+
94
+
95
+ def _strip_markdown(text: str) -> str:
96
+ stripped = text.strip()
97
+ if stripped.startswith("```") and stripped.endswith("```"):
98
+ lines = stripped.splitlines()
99
+ if len(lines) >= 2:
100
+ return "\n".join(lines[1:-1]).strip()
101
+ return stripped
102
+
103
+
104
+ def _extract_json_blob(text: str) -> str:
105
+ start_obj = text.find("{")
106
+ end_obj = text.rfind("}")
107
+ if start_obj != -1 and end_obj != -1 and end_obj > start_obj:
108
+ return text[start_obj : end_obj + 1]
109
+ start_list = text.find("[")
110
+ end_list = text.rfind("]")
111
+ if start_list != -1 and end_list != -1 and end_list > start_list:
112
+ return text[start_list : end_list + 1]
113
+ return text
114
+
115
+
116
+ def _scan_string(text: str, start: int) -> int:
117
+ index = start + 1
118
+ escaped = False
119
+ while index < len(text):
120
+ char = text[index]
121
+ if escaped:
122
+ escaped = False
123
+ elif char == "\\":
124
+ escaped = True
125
+ elif char == '"':
126
+ return index + 1
127
+ index += 1
128
+ return len(text)
129
+
130
+
131
+ def _find_expr_end(text: str, start: int) -> int:
132
+ index = start
133
+ in_string = False
134
+ escaped = False
135
+ while index < len(text):
136
+ char = text[index]
137
+ if in_string:
138
+ if escaped:
139
+ escaped = False
140
+ elif char == "\\":
141
+ escaped = True
142
+ elif char == '"':
143
+ in_string = False
144
+ else:
145
+ if char == '"':
146
+ in_string = True
147
+ elif char in {",", "}", "]"}:
148
+ return index
149
+ index += 1
150
+ return len(text)
151
+
152
+
153
+ def _tokenize_expr(expr: str) -> List[tuple[str, Any]] | None:
154
+ tokens: List[tuple[str, Any]] = []
155
+ index = 0
156
+ while index < len(expr):
157
+ char = expr[index]
158
+ if char.isspace():
159
+ index += 1
160
+ continue
161
+ if char == '"':
162
+ end = _scan_string(expr, index)
163
+ literal = expr[index:end]
164
+ try:
165
+ value = json.loads(literal)
166
+ except json.JSONDecodeError:
167
+ return None
168
+ tokens.append(("str", value))
169
+ index = end
170
+ continue
171
+ if char.isdigit():
172
+ end = index
173
+ while end < len(expr) and expr[end].isdigit():
174
+ end += 1
175
+ tokens.append(("int", int(expr[index:end])))
176
+ index = end
177
+ continue
178
+ if char in {"+", "*"}:
179
+ tokens.append(("op", char))
180
+ index += 1
181
+ continue
182
+ return None
183
+ return tokens
184
+
185
+
186
+ def _eval_string_expression(expr: str) -> str | None:
187
+ tokens = _tokenize_expr(expr)
188
+ if not tokens:
189
+ return None
190
+
191
+ has_string = any(token[0] == "str" for token in tokens)
192
+ if not has_string:
193
+ return None
194
+
195
+ def parse_term(pos: int) -> tuple[str | None, int]:
196
+ if pos >= len(tokens):
197
+ return None, pos
198
+ if tokens[pos][0] == "str":
199
+ value = tokens[pos][1]
200
+ elif tokens[pos][0] == "int":
201
+ value = str(tokens[pos][1])
202
+ else:
203
+ return None, pos
204
+ pos += 1
205
+ while pos + 1 < len(tokens) and tokens[pos] == ("op", "*"):
206
+ if tokens[pos + 1][0] != "int":
207
+ return None, pos
208
+ repeat = tokens[pos + 1][1]
209
+ value = value * repeat
210
+ pos += 2
211
+ return value, pos
212
+
213
+ result, pos = parse_term(0)
214
+ if result is None:
215
+ return None
216
+ while pos < len(tokens):
217
+ if tokens[pos] != ("op", "+"):
218
+ return None
219
+ term, pos = parse_term(pos + 1)
220
+ if term is None:
221
+ return None
222
+ result += term
223
+ return result
224
+
225
+
226
+ def _cap_string(value: str, limit: int = 200) -> str:
227
+ if len(value) <= limit:
228
+ return value
229
+ return value[:limit]
230
+
231
+
232
+ def _rewrite_repeat_calls(text: str) -> str:
233
+ output: List[str] = []
234
+ index = 0
235
+ while index < len(text):
236
+ char = text[index]
237
+ if char == '"':
238
+ start = index
239
+ end = _scan_string(text, index)
240
+ output.append(text[start:end])
241
+ probe = end
242
+ while probe < len(text) and text[probe].isspace():
243
+ probe += 1
244
+ if text.startswith(".repeat", probe):
245
+ cursor = probe + len(".repeat")
246
+ while cursor < len(text) and text[cursor].isspace():
247
+ cursor += 1
248
+ if cursor < len(text) and text[cursor] == "(":
249
+ cursor += 1
250
+ while cursor < len(text) and text[cursor].isspace():
251
+ cursor += 1
252
+ number_start = cursor
253
+ while cursor < len(text) and text[cursor].isdigit():
254
+ cursor += 1
255
+ number = text[number_start:cursor]
256
+ while cursor < len(text) and text[cursor].isspace():
257
+ cursor += 1
258
+ if number and cursor < len(text) and text[cursor] == ")":
259
+ output.append(f" * {number}")
260
+ index = cursor + 1
261
+ continue
262
+ index = end
263
+ continue
264
+ output.append(char)
265
+ index += 1
266
+ return "".join(output)
267
+
268
+
269
+ def _replace_string_expressions(text: str) -> str:
270
+ output: List[str] = []
271
+ index = 0
272
+ while index < len(text):
273
+ char = text[index]
274
+ if char == '"':
275
+ start = index
276
+ end = _scan_string(text, index)
277
+ probe = end
278
+ while probe < len(text) and text[probe].isspace():
279
+ probe += 1
280
+ if probe < len(text) and text[probe] in {"+", "*"}:
281
+ expr_end = _find_expr_end(text, start)
282
+ expr_text = text[start:expr_end]
283
+ evaluated = _eval_string_expression(expr_text)
284
+ if evaluated is not None:
285
+ output.append(json.dumps(_cap_string(evaluated)))
286
+ index = expr_end
287
+ continue
288
+ output.append(text[start:end])
289
+ index = end
290
+ continue
291
+ if char.isdigit():
292
+ start = index
293
+ end = index
294
+ while end < len(text) and text[end].isdigit():
295
+ end += 1
296
+ probe = end
297
+ while probe < len(text) and text[probe].isspace():
298
+ probe += 1
299
+ if probe < len(text) and text[probe] == "*":
300
+ expr_end = _find_expr_end(text, start)
301
+ expr_text = text[start:expr_end]
302
+ evaluated = _eval_string_expression(expr_text)
303
+ if evaluated is not None:
304
+ output.append(json.dumps(_cap_string(evaluated)))
305
+ index = expr_end
306
+ continue
307
+ output.append(char)
308
+ index += 1
309
+ return "".join(output)
310
+
311
+
312
+ def _parse_case_list(raw_text: str) -> TestCaseList:
313
+ cleaned = _strip_markdown(raw_text)
314
+ rewritten = _rewrite_repeat_calls(cleaned)
315
+ repaired = _replace_string_expressions(rewritten)
316
+ blob = _extract_json_blob(repaired)
317
+ try:
318
+ data = json.loads(blob)
319
+ if isinstance(data, list):
320
+ data = {"cases": data}
321
+ return TestCaseList.model_validate(data)
322
+ except json.JSONDecodeError:
323
+ return TestCaseList(cases=[])
324
+
325
+
326
+ def node_spec(state: GraphState) -> Dict[str, Any]:
327
+ llm = build_llm("gemini-3-flash-preview", temperature=0.2)
328
+ prompt, parser = build_spec_agent(llm)
329
+ chain = prompt | llm | parser
330
+ spec = chain.invoke(
331
+ {
332
+ "problem": state["problem"],
333
+ "description": state["description"],
334
+ "constraints": state["constraints"],
335
+ "language": state["language"],
336
+ "format_instructions": parser.get_format_instructions(),
337
+ }
338
+ )
339
+ return {"spec": spec}
340
+
341
+
342
+ def node_analysis(state: GraphState) -> Dict[str, Any]:
343
+ if not state["code"].strip():
344
+ return {"analysis": CodeAnalysis()}
345
+ llm = build_llm("gemini-2.5-flash", temperature=0.2)
346
+ prompt, parser = build_code_analysis_agent(llm)
347
+ chain = prompt | llm | parser
348
+ analysis = chain.invoke(
349
+ {
350
+ "code": state["code"],
351
+ "language": state["language"],
352
+ "format_instructions": parser.get_format_instructions(),
353
+ }
354
+ )
355
+ return {"analysis": analysis}
356
+
357
+
358
+ def node_start(state: GraphState) -> Dict[str, Any]:
359
+ return {"iteration": 0}
360
+
361
+
362
+ def node_plan(state: GraphState) -> Dict[str, Any]:
363
+ llm = build_llm("gemini-3.1-flash-lite-preview", temperature=0.3)
364
+ prompt, parser = build_test_plan_agent(llm)
365
+ chain = prompt | llm | parser
366
+ per_category = max(2, min(3, state["per_category"]))
367
+ plan = chain.invoke(
368
+ {
369
+ "spec": state["spec"].model_dump(),
370
+ "analysis": state["analysis"].model_dump(),
371
+ "issues": state.get("issues", []),
372
+ "per_category": per_category,
373
+ "format_instructions": parser.get_format_instructions(),
374
+ }
375
+ )
376
+ plan.targets = _category_targets(per_category)
377
+ plan.categories = list(plan.targets.keys())
378
+ return {"plan": plan}
379
+
380
+
381
+ def node_generate(state: GraphState) -> Dict[str, Any]:
382
+ llm = build_llm("gemini-2.5-flash-lite", temperature=0.5)
383
+ prompt, parser = build_test_generator_agent(llm)
384
+ chain = prompt | llm
385
+ suites: List[StudentTestSuite] = []
386
+ issues: List[str] = []
387
+ for student_id in range(1, state["student_count"] + 1):
388
+ response = chain.invoke(
389
+ {
390
+ "spec": state["spec"].model_dump(),
391
+ "plan": state["plan"].model_dump(),
392
+ "student_id": student_id,
393
+ "format_instructions": parser.get_format_instructions(),
394
+ }
395
+ )
396
+ raw_text = response.content if hasattr(response, "content") else str(response)
397
+ case_list = _parse_case_list(raw_text)
398
+ if not case_list.cases:
399
+ issues.append(f"Student {student_id} output parsing failed")
400
+ continue
401
+ enforced, missing = _enforce_targets(case_list.cases, state["plan"].targets)
402
+ suites.append(StudentTestSuite(student_id=student_id, cases=enforced))
403
+ if missing:
404
+ issues.append(
405
+ f"Student {student_id} missing categories: {sorted(missing.keys())}"
406
+ )
407
+ return {"suites": suites, "issues": issues}
408
+
409
+
410
+ def node_feedback(state: GraphState) -> Dict[str, Any]:
411
+ llm = build_llm("gemini-3-flash-preview", temperature=0.2)
412
+ prompt, parser = build_feedback_agent(llm)
413
+ chain = prompt | llm | parser
414
+ issues = state.get("issues", [])
415
+ feedback = chain.invoke(
416
+ {
417
+ "spec": state["spec"].model_dump(),
418
+ "plan": state["plan"].model_dump(),
419
+ "issues": issues,
420
+ "format_instructions": parser.get_format_instructions(),
421
+ }
422
+ )
423
+ needs_refine = feedback.needs_refine or bool(issues)
424
+ iteration = state.get("iteration", 0) + (1 if needs_refine else 0)
425
+ return {"feedback": feedback, "iteration": iteration}
426
+
427
+
428
+ def should_refine(state: GraphState) -> str:
429
+ max_refines = 1
430
+ if state.get("iteration", 0) > max_refines:
431
+ return "final"
432
+ if state.get("issues"):
433
+ return "refine"
434
+ return "refine" if state["feedback"].needs_refine else "final"
435
+
436
+
437
+ def build_graph():
438
+ graph = StateGraph(GraphState)
439
+ graph.add_node("start", node_start)
440
+ graph.add_node("spec", node_spec)
441
+ graph.add_node("analysis", node_analysis)
442
+ graph.add_node("plan", node_plan)
443
+ graph.add_node("generate", node_generate)
444
+ graph.add_node("feedback", node_feedback)
445
+
446
+ graph.set_entry_point("start")
447
+ graph.add_edge("start", "spec")
448
+ graph.add_edge("start", "analysis")
449
+ graph.add_edge("spec", "plan")
450
+ graph.add_edge("analysis", "plan")
451
+ graph.add_edge("plan", "generate")
452
+ graph.add_edge("generate", "feedback")
453
+ graph.add_conditional_edges(
454
+ "feedback",
455
+ should_refine,
456
+ {
457
+ "refine": "plan",
458
+ "final": END,
459
+ },
460
+ )
461
+
462
+ return graph.compile()
463
+
464
+
465
+ def run_pipeline(
466
+ *,
467
+ problem: str,
468
+ description: str,
469
+ constraints: str,
470
+ code: str,
471
+ language: str,
472
+ student_count: int,
473
+ per_category: int,
474
+ issues: List[str] | None = None,
475
+ ) -> FinalReport:
476
+ app = build_graph()
477
+ state = app.invoke(
478
+ {
479
+ "problem": problem,
480
+ "description": description,
481
+ "constraints": constraints,
482
+ "code": code,
483
+ "language": language,
484
+ "student_count": student_count,
485
+ "per_category": per_category,
486
+ "issues": issues or [],
487
+ }
488
+ )
489
+ return FinalReport(
490
+ spec=state["spec"],
491
+ analysis=state["analysis"],
492
+ plan=state["plan"],
493
+ suites=state["suites"],
494
+ feedback=state["feedback"],
495
+ )