dina1 commited on
Commit
b4c62e2
·
verified ·
1 Parent(s): 796a18f

Create agents/graph_builder_langgraph.py

Browse files
Files changed (1) hide show
  1. agents/graph_builder_langgraph.py +118 -0
agents/graph_builder_langgraph.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, TypedDict, Optional
2
+ from langgraph.graph import StateGraph, START, END
3
+ from .document_parser import parse_documents
4
+ from .requirements_extractor import extract_requirements
5
+ from .ui_generator import generate_ui_html
6
+ from .ui_qa import run_ui_qa
7
+ from .ui_repair import repair_ui_html
8
+ from .report_generator import generate_pdf_report
9
+ import os
10
+ import uuid
11
+ import asyncio
12
+
13
+ # ----- shared state -----
14
+ class State(TypedDict, total=False):
15
+ file_paths: List[str]
16
+ texts: str
17
+ files_meta: List[Dict[str, Any]]
18
+ requirements: Dict[str, Any]
19
+ html: str
20
+ qa: Dict[str, Any]
21
+ attempts: int
22
+ pdf_path: str
23
+ public_url: str
24
+
25
+ # ----- nodes -----
26
+ async def node_parse(state: State, config=None, runtime=None) -> State:
27
+ res = await parse_documents(state["file_paths"])
28
+ return {"texts": res["texts"], "files_meta": res["files"]}
29
+
30
+ async def node_extract(state: State, config=None, runtime=None) -> State:
31
+ reqs = await extract_requirements(state["texts"])
32
+ return {"requirements": reqs}
33
+
34
+ async def node_generate(state: State, config=None, runtime=None) -> State:
35
+ reference_html = runtime.config.get("reference_html") if runtime else os.environ.get("REFERENCE_HTML", "")
36
+ if not reference_html and os.path.exists("templates/demo_qms_design.html"):
37
+ with open("templates/demo_qms_design.html", "r", encoding="utf-8") as f:
38
+ reference_html = f.read()
39
+ res = await generate_ui_html(state["requirements"], reference_html)
40
+ return {"html": res["html"], "attempts": state.get("attempts", 0) + 1}
41
+
42
+ async def node_qa(state: State, config=None, runtime=None) -> State:
43
+ qa = await run_ui_qa(state["html"])
44
+ return {"qa": qa}
45
+
46
+ async def node_repair(state: State, config=None, runtime=None) -> State:
47
+ repaired = await repair_ui_html(state["html"], state["qa"]["comment"])
48
+ return {"html": repaired["html"]}
49
+
50
+ async def node_report(state: State, config=None, runtime=None) -> State:
51
+ html = state["html"]
52
+ session = uuid.uuid4().hex
53
+ out_dir = os.path.join("static", "outputs")
54
+ os.makedirs(out_dir, exist_ok=True)
55
+ html_path = os.path.join(out_dir, f"{session}.html")
56
+ with open(html_path, "w", encoding="utf-8") as f:
57
+ f.write(html)
58
+
59
+ base_url = state.get("public_url") or os.environ.get("BASE_PUBLIC_URL", "")
60
+ if base_url:
61
+ base_url = base_url.rstrip("/")
62
+ public_url = f"{base_url}/static/outputs/{session}.html"
63
+ else:
64
+ public_url = f"/static/outputs/{session}.html"
65
+
66
+ pdf_info = await generate_pdf_report(public_url)
67
+ return {"pdf_path": pdf_info["pdf_path"], "public_url": public_url}
68
+
69
+ # ----- conditional guard -----
70
+ def qa_guard(state: State) -> str:
71
+ """Return next node key based on QA score."""
72
+ score = int(state.get("qa", {}).get("score", 0))
73
+ attempts = int(state.get("attempts", 0))
74
+ max_retries = int(os.environ.get("MAX_RETRIES", "2"))
75
+ if score >= int(os.environ.get("QA_THRESHOLD", "85")) or attempts >= max_retries:
76
+ return "report"
77
+ return "repair"
78
+
79
+ # ----- builder -----
80
+ def build_langgraph_runner(reference_html: Optional[str] = None) -> StateGraph:
81
+ g = StateGraph(State)
82
+ g.add_node("parse", node_parse)
83
+ g.add_node("extract", node_extract)
84
+ g.add_node("generate", node_generate)
85
+ g.add_node("qa", node_qa)
86
+ g.add_node("repair", node_repair)
87
+ g.add_node("report", node_report)
88
+
89
+ g.set_entry_point("parse")
90
+ g.set_finish_point("report")
91
+
92
+ # linear part
93
+ g.add_edge(START, "parse")
94
+ g.add_edge("parse", "extract")
95
+ g.add_edge("extract", "generate")
96
+ g.add_edge("generate", "qa")
97
+
98
+ # explicit conditional edges
99
+ g.add_conditional_edges(
100
+ "qa",
101
+ qa_guard,
102
+ {"report": "report", "repair": "repair"},
103
+ )
104
+
105
+ # loop
106
+ g.add_edge("repair", "generate")
107
+
108
+ graph = g.compile()
109
+ if reference_html:
110
+ graph.config = {"reference_html": reference_html}
111
+ return graph
112
+
113
+ # ----- runner -----
114
+ async def run_graph_once(graph, file_paths: List[str], base_public_url: str | None = None) -> Dict[str, Any]:
115
+ if base_public_url:
116
+ os.environ["BASE_PUBLIC_URL"] = base_public_url
117
+ state: State = {"file_paths": file_paths, "attempts": 0}
118
+ return await graph.ainvoke(state)