# utils.py
"""Shared utilities for the OpenFOAM LLM-agent pipeline.

Provides:
  * eagerly-loaded FAISS vector databases (``FAISS_DB_CACHE``),
  * Pydantic schemas for structured LLM output,
  * ``LLMService`` — a provider-agnostic chat-model wrapper with retry and
    token-usage statistics,
  * ``GraphState`` — the workflow state schema,
  * assorted file/command/parsing helpers.
"""
import re
import subprocess
import os
from typing import Optional, Any, Type, TypedDict, List
from pydantic import BaseModel, Field
from langchain.chat_models import init_chat_model
from langchain_community.vectorstores import FAISS
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_aws import ChatBedrock, ChatBedrockConverse
from langchain_anthropic import ChatAnthropic
from pathlib import Path
import tracking_aws
import requests
import time
import random
from botocore.exceptions import ClientError
import shutil
from config import Config
from langchain_ollama import ChatOllama

# Directory containing the pre-built FAISS indexes: ../database/faiss
# relative to this file's parent directory.
DATABASE_DIR = f"{Path(__file__).resolve().parent.parent}/database/faiss"

# Global cache of loaded FAISS databases, keyed by database name.
# NOTE: the indexes are loaded eagerly at import time; a missing index
# directory or missing OpenAI credentials will make this import fail.
FAISS_DB_CACHE = {
    "openfoam_allrun_scripts": FAISS.load_local(
        f"{DATABASE_DIR}/openfoam_allrun_scripts",
        OpenAIEmbeddings(model="text-embedding-3-small"),
        allow_dangerous_deserialization=True),
    "openfoam_tutorials_structure": FAISS.load_local(
        f"{DATABASE_DIR}/openfoam_tutorials_structure",
        OpenAIEmbeddings(model="text-embedding-3-small"),
        allow_dangerous_deserialization=True),
    "openfoam_tutorials_details": FAISS.load_local(
        f"{DATABASE_DIR}/openfoam_tutorials_details",
        OpenAIEmbeddings(model="text-embedding-3-small"),
        allow_dangerous_deserialization=True),
    "openfoam_command_help": FAISS.load_local(
        f"{DATABASE_DIR}/openfoam_command_help",
        OpenAIEmbeddings(model="text-embedding-3-small"),
        allow_dangerous_deserialization=True),
}


class FoamfilePydantic(BaseModel):
    """A single OpenFOAM input file to be written into a case directory."""
    file_name: str = Field(description="Name of the OpenFOAM input file")
    folder_name: str = Field(description="Folder where the foamfile should be stored")
    content: str = Field(description="Content of the OpenFOAM file, written in OpenFOAM dictionary format")


class FoamPydantic(BaseModel):
    """A collection of OpenFOAM configuration files."""
    list_foamfile: List[FoamfilePydantic] = Field(description="List of OpenFOAM configuration files")


class ResponseWithThinkPydantic(BaseModel):
    """Structured output that separates chain-of-thought from the answer.

    Used for models (e.g. deepseek) that emit an explicit "think" section.
    """
    think: str = Field(description="Thought process of the LLM")
    response: str = Field(description="Response of the LLM")


class LLMService:
    """Provider-agnostic chat-model wrapper.

    Initializes an LLM for the configured provider (bedrock / anthropic /
    openai / ollama), exposes a retrying ``invoke`` with optional structured
    output, and accumulates token-usage statistics across calls.
    """

    def __init__(self, config: object):
        # Pull settings off the config object, falling back to defaults.
        self.model_version = getattr(config, "model_version", "gpt-4o")
        self.temperature = getattr(config, "temperature", 0)
        self.model_provider = getattr(config, "model_provider", "openai")

        # Initialize statistics
        self.total_calls = 0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_tokens = 0
        self.failed_calls = 0
        self.retry_count = 0

        # Initialize the LLM for the requested provider.
        if self.model_provider.lower() == "bedrock":
            bedrock_runtime = tracking_aws.new_default_client()
            self.llm = ChatBedrockConverse(
                client=bedrock_runtime,
                model_id=self.model_version,
                temperature=self.temperature,
                max_tokens=8192
            )
        elif self.model_provider.lower() == "anthropic":
            self.llm = ChatAnthropic(
                model=self.model_version,
                temperature=self.temperature
            )
        elif self.model_provider.lower() == "openai":
            self.llm = init_chat_model(
                self.model_version,
                model_provider=self.model_provider,
                temperature=self.temperature
            )
        elif self.model_provider.lower() == "ollama":
            # Probe the local Ollama HTTP endpoint; start the server if it
            # is not already running.
            try:
                response = requests.get("http://localhost:11434/api/version", timeout=2)
                # If request successful, service is running
            except requests.exceptions.RequestException:
                print("Ollama is not running, starting it...")
                subprocess.Popen(["ollama", "serve"],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
                # Give the freshly-spawned service time to initialize.
                time.sleep(5)
            self.llm = ChatOllama(
                model=self.model_version,
                temperature=self.temperature,
                num_predict=-1,       # no cap on generated tokens
                num_ctx=131072,       # large context window
                base_url="http://localhost:11434"
            )
        else:
            raise ValueError(f"{self.model_provider} is not a supported model_provider")

    def invoke(self,
               user_prompt: str,
               system_prompt: Optional[str] = None,
               pydantic_obj: Optional[Type[BaseModel]] = None,
               max_retries: int = 10) -> Any:
        """
        Invoke the LLM with the given prompts and return the response.

        Args:
            user_prompt: The user's prompt
            system_prompt: Optional system prompt
            pydantic_obj: Optional Pydantic model for structured output
            max_retries: Maximum number of retries for throttling errors

        Returns:
            The LLM response with token usage statistics

        Raises:
            Exception: when throttling persists beyond ``max_retries``, or on
                any other provider error (re-raised unchanged).
        """
        self.total_calls += 1

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        # Calculate prompt tokens
        prompt_tokens = 0
        for message in messages:
            prompt_tokens += self.llm.get_num_tokens(message["content"])

        retry_count = 0
        while True:
            try:
                if pydantic_obj:
                    structured_llm = self.llm.with_structured_output(pydantic_obj)
                    response = structured_llm.invoke(messages)
                else:
                    if self.model_version.startswith("deepseek"):
                        # deepseek models emit an explicit "think" section;
                        # use structured output to strip it.
                        structured_llm = self.llm.with_structured_output(ResponseWithThinkPydantic)
                        response = structured_llm.invoke(messages)
                        # Extract the response without the think
                        response = response.response
                    else:
                        response = self.llm.invoke(messages)
                        response = response.content

                # Calculate completion tokens
                response_content = str(response)
                completion_tokens = self.llm.get_num_tokens(response_content)
                total_tokens = prompt_tokens + completion_tokens

                # Update statistics
                self.total_prompt_tokens += prompt_tokens
                self.total_completion_tokens += completion_tokens
                self.total_tokens += total_tokens

                return response
            except ClientError as e:
                # AWS throttling: retry with exponential backoff + jitter.
                if e.response['Error']['Code'] == 'Throttling' or e.response['Error']['Code'] == 'TooManyRequestsException':
                    retry_count += 1
                    self.retry_count += 1
                    if retry_count > max_retries:
                        self.failed_calls += 1
                        raise Exception(f"Maximum retries ({max_retries}) exceeded: {str(e)}")
                    base_delay = 1.0
                    max_delay = 60.0
                    delay = min(max_delay, base_delay * (2 ** (retry_count - 1)))
                    jitter = random.uniform(0, 0.1 * delay)
                    sleep_time = delay + jitter
                    print(f"ThrottlingException occurred: {str(e)}. Retrying in {sleep_time:.2f} seconds (attempt {retry_count}/{max_retries})")
                    time.sleep(sleep_time)
                else:
                    self.failed_calls += 1
                    raise e
            except Exception as e:
                self.failed_calls += 1
                raise e

    def get_statistics(self) -> dict:
        """
        Get the current statistics of the LLM service.

        Returns:
            Dictionary containing various statistics
        """
        return {
            "total_calls": self.total_calls,
            "failed_calls": self.failed_calls,
            "retry_count": self.retry_count,
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_tokens": self.total_tokens,
            "average_prompt_tokens": self.total_prompt_tokens / self.total_calls if self.total_calls > 0 else 0,
            "average_completion_tokens": self.total_completion_tokens / self.total_calls if self.total_calls > 0 else 0,
            "average_tokens": self.total_tokens / self.total_calls if self.total_calls > 0 else 0
        }

    def print_statistics(self) -> None:
        """
        Print the current statistics of the LLM service.
        """
        stats = self.get_statistics()
        print("\n")
        print(f"Total calls: {stats['total_calls']}")
        print(f"Failed calls: {stats['failed_calls']}")
        print(f"Total retries: {stats['retry_count']}")
        print(f"Total prompt tokens: {stats['total_prompt_tokens']}")
        print(f"Total completion tokens: {stats['total_completion_tokens']}")
        print(f"Total tokens: {stats['total_tokens']}")
        print(f"Average prompt tokens per call: {stats['average_prompt_tokens']:.2f}")
        print(f"Average completion tokens per call: {stats['average_completion_tokens']:.2f}")
        print(f"Average tokens per call: {stats['average_tokens']:.2f}\n")
        print("")


class GraphState(TypedDict):
    """Workflow state threaded through the case-generation graph."""
    user_requirement: str
    config: Config
    case_dir: str
    tutorial: str
    case_name: str
    subtasks: List[dict]
    current_subtask_index: int
    error_command: Optional[str]
    error_content: Optional[str]
    loop_count: int
    # Additional state fields that will be added during execution
    llm_service: Optional['LLMService']
    case_stats: Optional[dict]
    tutorial_reference: Optional[str]
    case_path_reference: Optional[str]
    dir_structure_reference: Optional[str]
    case_info: Optional[str]
    allrun_reference: Optional[str]
    dir_structure: Optional[dict]
    commands: Optional[List[str]]
    foamfiles: Optional[dict]
    error_logs: Optional[List[str]]
    history_text: Optional[List[str]]
    case_domain: Optional[str]
    case_category: Optional[str]
    case_solver: Optional[str]
    # Mesh-related state fields
    mesh_info: Optional[dict]
    mesh_commands: Optional[List[str]]
    custom_mesh_used: Optional[bool]
    mesh_type: Optional[str]
    custom_mesh_path: Optional[str]
    # Review and rewrite related fields
    review_analysis: Optional[str]
    input_writer_mode: Optional[str]
    # HPC-related fields
    job_id: Optional[str]
    cluster_info: Optional[dict]
    slurm_script_path: Optional[str]


def tokenize(text: str) -> str:
    """Lower-case *text*, splitting snake_case and camelCase into words."""
    # Replace underscores with spaces
    text = text.replace('_', ' ')
    # Insert a space between a lowercase letter and an uppercase letter (global match)
    text = re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', text)
    return text.lower()


def save_file(path: str, content: str) -> None:
    """Write *content* to *path*, creating parent directories as needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        f.write(content)
    print(f"Saved file at {path}")


def read_file(path: str) -> str:
    """Return the contents of *path*, or "" if the file does not exist."""
    if os.path.exists(path):
        with open(path, 'r') as f:
            return f.read()
    return ""


def list_case_files(case_dir: str) -> str:
    """Return a comma-separated list of the regular files in *case_dir*."""
    files = [f for f in os.listdir(case_dir) if os.path.isfile(os.path.join(case_dir, f))]
    return ", ".join(files)


def remove_files(directory: str, prefix: str) -> None:
    """Delete every file in *directory* whose name starts with *prefix*."""
    for file in os.listdir(directory):
        if file.startswith(prefix):
            os.remove(os.path.join(directory, file))
    print(f"Removed files with prefix '{prefix}' in {directory}")


def remove_file(path: str) -> None:
    """Delete *path* if it exists."""
    if os.path.exists(path):
        os.remove(path)
    print(f"Removed file {path}")


def remove_numeric_folders(case_dir: str) -> None:
    """
    Remove all folders in case_dir that represent numeric values,
    including those with decimal points, except for the "0" folder.

    Args:
        case_dir (str): The directory path to process
    """
    for item in os.listdir(case_dir):
        item_path = os.path.join(case_dir, item)
        if os.path.isdir(item_path) and item != "0":
            try:
                # Try to convert to float to check if it's a numeric value
                float(item)
                # If conversion succeeds, it's a numeric folder
                try:
                    shutil.rmtree(item_path)
                    print(f"Removed numeric folder: {item_path}")
                except Exception as e:
                    print(f"Error removing folder {item_path}: {str(e)}")
            except ValueError:
                # Not a numeric value, so we keep this folder
                pass


def run_command(script_path: str, out_file: str, err_file: str, working_dir: str, config: Config) -> None:
    """Run an Allrun-style script inside the OpenFOAM environment.

    Sources ``$WM_PROJECT_DIR/etc/bashrc`` before executing *script_path*,
    capturing stdout/stderr into *out_file*/*err_file*.  Kills the process
    and records a timeout message if it exceeds ``config.max_time_limit``
    seconds.
    """
    print(f"Executing script {script_path} in {working_dir}")
    os.chmod(script_path, 0o777)
    openfoam_dir = os.getenv("WM_PROJECT_DIR")
    command = f"source {openfoam_dir}/etc/bashrc && bash {os.path.abspath(script_path)}"
    timeout_seconds = config.max_time_limit

    with open(out_file, 'w') as out, open(err_file, 'w') as err:
        process = subprocess.Popen(
            ['bash', "-c", command],
            cwd=working_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            stdin=subprocess.DEVNULL,
            text=True
        )
        try:
            stdout, stderr = process.communicate(timeout=timeout_seconds)
            out.write(stdout)
            err.write(stderr)
        except subprocess.TimeoutExpired:
            # Kill the runaway process, then drain whatever it produced.
            process.kill()
            stdout, stderr = process.communicate()
            timeout_message = (
                "OpenFOAM execution took too long. "
                "This case, if set up right, does not require such large execution times.\n"
            )
            out.write(timeout_message + stdout)
            err.write(timeout_message + stderr)
            print(f"Execution timed out: {script_path}")

    print(f"Executed script {script_path}")


def check_foam_errors(directory: str) -> list:
    """Scan ``log*`` files in *directory* for OpenFOAM error reports.

    Returns a list of ``{"file": ..., "error_content": ...}`` dicts for each
    log containing an ``ERROR:`` marker.
    """
    error_logs = []
    # DOTALL mode allows '.' to match newline characters
    pattern = re.compile(r"ERROR:(.*)", re.DOTALL)
    for file in os.listdir(directory):
        if file.startswith("log"):
            filepath = os.path.join(directory, file)
            with open(filepath, 'r') as f:
                content = f.read()
            match = pattern.search(content)
            if match:
                error_content = match.group(0).strip()
                error_logs.append({"file": file, "error_content": error_content})
            elif "error" in content.lower():
                print(f"Warning: file {file} contains 'error' but does not match expected format.")
    return error_logs


def extract_commands_from_allrun_out(out_file: str) -> list:
    """Return the command names announced by "Running <cmd>" lines in *out_file*."""
    commands = []
    if not os.path.exists(out_file):
        return commands
    with open(out_file, 'r') as f:
        for line in f:
            if line.startswith("Running "):
                parts = line.split(" ")
                if len(parts) > 1:
                    commands.append(parts[1].strip())
    return commands


def parse_case_name(text: str) -> str:
    """Extract the case name after "case name:"; default to "default_case"."""
    match = re.search(r'case name:\s*(.+)', text, re.IGNORECASE)
    return match.group(1).strip() if match else "default_case"


def split_subtasks(text: str) -> list:
    """Parse "splits into N subtasks:" responses into a list of subtask strings.

    Warns (but still returns what was found) when the declared count does not
    match the number of ``subtaskN:`` lines.
    """
    header_match = re.search(r'splits into (\d+) subtasks:', text, re.IGNORECASE)
    if not header_match:
        print("Warning: No subtasks header found in the response.")
        return []
    num_subtasks = int(header_match.group(1))
    subtasks = re.findall(r'subtask\d+:\s*(.*)', text, re.IGNORECASE)
    if len(subtasks) != num_subtasks:
        print(f"Warning: Expected {num_subtasks} subtasks but found {len(subtasks)}.")
    return subtasks


def parse_context(text: str) -> str:
    """Extract the FoamFile dictionary body from an LLM response.

    Matches from "FoamFile {" up to (but not including) a closing code fence
    or end of string; falls back to returning *text* unchanged.
    """
    match = re.search(r'FoamFile\s*\{.*?(?=```|$)', text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(0).strip()
    print("Warning: Could not parse context; returning original text.")
    return text


def parse_file_name(subtask: str) -> str:
    """Extract the file name from "openfoam <name> foamfile" phrasing."""
    match = re.search(r'openfoam\s+(.*?)\s+foamfile', subtask, re.IGNORECASE)
    return match.group(1).strip() if match else ""


def parse_folder_name(subtask: str) -> str:
    """Extract the folder name from "foamfile in <folder> folder" phrasing."""
    match = re.search(r'foamfile in\s+(.*?)\s+folder', subtask, re.IGNORECASE)
    return match.group(1).strip() if match else ""


def find_similar_file(description: str, tutorial: str) -> str:
    """Return the tutorial excerpt from *description* through "input_file_end.".

    Returns the string "None" when either marker cannot be found.
    """
    start_pos = tutorial.find(description)
    if start_pos == -1:
        return "None"
    end_marker = "input_file_end."
    end_pos = tutorial.find(end_marker, start_pos)
    if end_pos == -1:
        return "None"
    return tutorial[start_pos:end_pos + len(end_marker)]


def read_commands(file_path: str) -> str:
    """Read *file_path* and join its non-empty lines with ", ".

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Commands file not found: {file_path}")
    with open(file_path, 'r') as f:
        # join non-empty lines with a comma
        return ", ".join(line.strip() for line in f if line.strip())


def find_input_file(case_dir: str, command: str) -> str:
    """Return the first file under *case_dir* whose name contains *command*."""
    for root, _, files in os.walk(case_dir):
        for file in files:
            if command in file:
                return os.path.join(root, file)
    return ""


def retrieve_faiss(database_name: str, query: str, topk: int = 1) -> List[dict]:
    """
    Retrieve a similar case from a FAISS database.

    Args:
        database_name: Key into FAISS_DB_CACHE.
        query: Free-text query (tokenized before search).
        topk: Number of nearest documents to return.

    Returns:
        A list of dicts whose fields depend on *database_name*.

    Raises:
        ValueError: unknown database name, or no documents found.
    """
    if database_name not in FAISS_DB_CACHE:
        raise ValueError(f"Database '{database_name}' is not loaded.")

    # Tokenize the query
    query = tokenize(query)
    vectordb = FAISS_DB_CACHE[database_name]
    docs = vectordb.similarity_search(query, k=topk)
    if not docs:
        raise ValueError(f"No documents found for query: {query}")

    formatted_results = []
    for doc in docs:
        metadata = doc.metadata or {}
        if database_name == "openfoam_allrun_scripts":
            formatted_results.append({
                "index": doc.page_content,
                "full_content": metadata.get("full_content", "unknown"),
                "case_name": metadata.get("case_name", "unknown"),
                "case_domain": metadata.get("case_domain", "unknown"),
                "case_category": metadata.get("case_category", "unknown"),
                "case_solver": metadata.get("case_solver", "unknown"),
                "dir_structure": metadata.get("dir_structure", "unknown"),
                "allrun_script": metadata.get("allrun_script", "N/A")
            })
        elif database_name == "openfoam_command_help":
            formatted_results.append({
                "index": doc.page_content,
                "full_content": metadata.get("full_content", "unknown"),
                "command": metadata.get("command", "unknown"),
                "help_text": metadata.get("help_text", "unknown")
            })
        elif database_name == "openfoam_tutorials_structure":
            formatted_results.append({
                "index": doc.page_content,
                "full_content": metadata.get("full_content", "unknown"),
                "case_name": metadata.get("case_name", "unknown"),
                "case_domain": metadata.get("case_domain", "unknown"),
                "case_category": metadata.get("case_category", "unknown"),
                "case_solver": metadata.get("case_solver", "unknown"),
                "dir_structure": metadata.get("dir_structure", "unknown")
            })
        elif database_name == "openfoam_tutorials_details":
            formatted_results.append({
                "index": doc.page_content,
                "full_content": metadata.get("full_content", "unknown"),
                "case_name": metadata.get("case_name", "unknown"),
                "case_domain": metadata.get("case_domain", "unknown"),
                "case_category": metadata.get("case_category", "unknown"),
                "case_solver": metadata.get("case_solver", "unknown"),
                "dir_structure": metadata.get("dir_structure", "unknown"),
                "tutorials": metadata.get("tutorials", "N/A")
            })
        else:
            raise ValueError(f"Unknown database name: {database_name}")

    return formatted_results


def parse_directory_structure(data: str) -> dict:
    """
    Parses the directory structure string and returns a dictionary where:
      - Keys: directory names
      - Values: count of files in that directory.

    Expected input format (one entry per directory):
        directory name: <name>. ... File names in this directory: [a, b, c]

    NOTE(review): the previous implementation first split *data* into blocks
    with ``re.findall(r'(.*?)', data, re.DOTALL)`` — a non-greedy empty
    pattern that only ever matches empty strings (the original block
    delimiters appear to have been lost), so the function always returned
    ``{}``.  We now match each directory entry directly.
    """
    directory_file_counts = {}
    # Pair each "directory name:" with the next "File names in this
    # directory: [...]" list that follows it.
    entry_pattern = re.compile(
        r'directory name:\s*(.*?)\..*?File names in this directory:\s*\[(.*?)\]',
        re.DOTALL)
    for dir_name, files_str in entry_pattern.findall(data):
        # Split the file names by comma, removing any surrounding whitespace
        file_list = [filename.strip() for filename in files_str.split(',')]
        directory_file_counts[dir_name.strip()] = len(file_list)
    return directory_file_counts