from dataclasses import dataclass, field from typing import List, Optional, TypedDict, Literal from langgraph.graph import StateGraph, START, END from langgraph.types import Command import argparse from pathlib import Path from utils import LLMService, GraphState from config import Config from nodes.architect_node import architect_node from nodes.meshing_node import meshing_node from nodes.input_writer_node import input_writer_node from nodes.local_runner_node import local_runner_node from nodes.reviewer_node import reviewer_node from nodes.visualization_node import visualization_node from nodes.hpc_runner_node import hpc_runner_node from router_func import ( route_after_architect, route_after_input_writer, route_after_runner, route_after_reviewer ) import json def create_foam_agent_graph() -> StateGraph: """Create the OpenFOAM agent workflow graph.""" # Create the graph workflow = StateGraph(GraphState) # Add nodes workflow.add_node("architect", architect_node) workflow.add_node("meshing", meshing_node) workflow.add_node("input_writer", input_writer_node) workflow.add_node("local_runner", local_runner_node) workflow.add_node("hpc_runner", hpc_runner_node) workflow.add_node("reviewer", reviewer_node) workflow.add_node("visualization", visualization_node) # Add edges workflow.add_edge(START, "architect") workflow.add_conditional_edges("architect", route_after_architect) workflow.add_edge("meshing", "input_writer") workflow.add_conditional_edges("input_writer", route_after_input_writer) workflow.add_conditional_edges("hpc_runner", route_after_runner) workflow.add_conditional_edges("local_runner", route_after_runner) workflow.add_conditional_edges("reviewer", route_after_reviewer) workflow.add_edge("visualization", END) return workflow def initialize_state(user_requirement: str, config: Config, custom_mesh_path: Optional[str] = None) -> GraphState: case_stats = json.load(open(f"{config.database_path}/raw/openfoam_case_stats.json", "r")) # mesh_type = "custom_mesh" if custom_mesh_path else "standard_mesh" state = GraphState( user_requirement=user_requirement, config=config, case_dir="", tutorial="", case_name="", subtasks=[], current_subtask_index=0, error_command=None, error_content=None, loop_count=0, llm_service=LLMService(config), case_stats=case_stats, tutorial_reference=None, case_path_reference=None, dir_structure_reference=None, case_info=None, allrun_reference=None, dir_structure=None, commands=None, foamfiles=None, error_logs=None, history_text=None, case_domain=None, case_category=None, case_solver=None, mesh_info=None, mesh_commands=None, custom_mesh_used=None, mesh_type=None, custom_mesh_path=custom_mesh_path, review_analysis=None, input_writer_mode="initial", job_id=None, cluster_info=None, slurm_script_path=None ) if custom_mesh_path: print(f"Custom mesh path: {custom_mesh_path}") else: print("No custom mesh path provided.") return state def main(user_requirement: str, config: Config, custom_mesh_path: Optional[str] = None): """Main function to run the OpenFOAM workflow.""" # Create and compile the graph workflow = create_foam_agent_graph() app = workflow.compile() # Initialize the state initial_state = initialize_state(user_requirement, config, custom_mesh_path) print("Starting Foam-Agent...") # Invoke the graph try: result = app.invoke(initial_state) print("Workflow completed successfully!") # Print final statistics if result.get("llm_service"): result["llm_service"].print_statistics() # print(f"Final state: {result}") except Exception as e: print(f"Workflow failed with error: {e}") raise if __name__ == "__main__": # python main.py parser = argparse.ArgumentParser( description="Run the OpenFOAM workflow" ) parser.add_argument( "--prompt_path", type=str, default=f"{Path(__file__).parent.parent}/user_requirement.txt", help="User requirement file path for the workflow.", ) parser.add_argument( "--output_dir", type=str, default="", help="Output directory for the workflow.", ) parser.add_argument( "--custom_mesh_path", type=str, default=None, help="Path to custom mesh file (e.g., .msh, .stl, .obj). If not provided, no custom mesh will be used.", ) args = parser.parse_args() print(args) # Initialize configuration. config = Config() if args.output_dir != "": config.case_dir = args.output_dir with open(args.prompt_path, 'r') as f: user_requirement = f.read() main(user_requirement, config, args.custom_mesh_path)