| """ |
| AI trajectory analysis agent β Claude, OpenAI, Gemini. |
| All three providers support reliable tool use / function calling. |
| """ |
|
|
| import os |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
| from .knowledge_base import CPPTrajKnowledgeBase |
| from .llm_backends import LLMBackend, create_backend |
| from .runner import CPPTrajRunner |
|
|
# Tool definitions advertised to the model, written in the Anthropic
# tool-use schema (name / description / input_schema as JSON Schema).
# NOTE(review): presumably each LLMBackend translates this shape into its
# provider's native function-calling format — confirm in llm_backends.
TOOLS = [
    # RAG lookup over the cpptraj manual (served by CPPTrajKnowledgeBase).
    {
        "name": "search_cpptraj_docs",
        "description": (
            "Search the cpptraj manual for exact command names, syntax, and options. "
            "ALWAYS call this before writing any cpptraj script to get the correct command name. "
            "Returns the most relevant manual sections with exact syntax."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "What to search for, e.g. 'radius of gyration', 'rmsd backbone', 'hydrogen bonds'"},
            },
            "required": ["query"],
        },
    },
    # Execute a cpptraj script via CPPTrajRunner; each call is a fresh process.
    {
        "name": "run_cpptraj_script",
        "description": (
            "Write and execute a cpptraj script to analyze the trajectory. "
            "Always include parm, trajin, analysis commands, and 'go'. "
            "Returns stdout, stderr, and output files generated."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "script": {"type": "string", "description": "Complete cpptraj script"},
                "description": {"type": "string", "description": "What this script does"},
            },
            "required": ["script", "description"],
        },
    },
    # Read back (the head of) a file produced in the working directory.
    {
        "name": "read_output_file",
        "description": "Read the content of an output file produced by a previous cpptraj run.",
        "input_schema": {
            "type": "object",
            "properties": {
                "filename": {"type": "string", "description": "Output file name (e.g. rmsd.dat)"},
            },
            "required": ["filename"],
        },
    },
    # Enumerate files currently in the working directory (no parameters).
    {
        "name": "list_output_files",
        "description": "List all output files in the working directory.",
        "input_schema": {"type": "object", "properties": {}},
    },
    # Run an arbitrary Python post-processing script in a subprocess.
    {
        "name": "run_python_script",
        "description": (
            "Write and execute a Python script for post-processing, plotting, or statistical "
            "analysis of cpptraj output files. Use matplotlib to save plots as PNG. "
            "All output files (PNG, CSV, etc.) are saved to the working directory. "
            "Returns stdout, stderr, and any new files created."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "script": {"type": "string", "description": "Complete Python script to execute"},
                "description": {"type": "string", "description": "What this script does"},
            },
            "required": ["script", "description"],
        },
    },
]
|
|
| SYSTEM_PROMPT = """\ |
| You are an expert computational biophysicist specializing in MD simulation analysis. |
| |
| ## EXECUTION RULES β NEVER VIOLATE |
| - NEVER write a script as text in your response. ALWAYS execute it immediately via run_cpptraj_script or run_python_script. |
| - NEVER describe what you are going to do. Just do it. No preamble, no step lists, no "Step 1 / Step 2". |
| - NEVER show the user a script and ask them to run it. You run it. |
| - If a previous script failed, fix it and call the tool again immediately. Do not explain the fix β just run it. |
| - After running any script: 1-2 sentence summary MAXIMUM. No markdown tables, no bullet lists, no interpretation sections, no headers. Plain text only. |
| - cpptraj task β run_cpptraj_script | plotting/stats β run_python_script | list files β list_output_files |
| - After cpptraj finishes: read output, report key numbers, then STOP. Never auto-run Python after cpptraj. |
| - run_python_script is ONLY for: plot, graph, chart, visualize, histogram, heatmap, statistics, stats, analyze further. |
| |
| cpptraj syntax (spaces, NOT colons): `parm file.prmtop` not `parm: file.prmtop`. Always end with `go`. |
| - Frame count: parm + trajin + go (stdout shows count). |
| - ALWAYS strip :WAT before autoimage and before any RMSD/distance/secstruct analysis. Order: strip β autoimage β analysis. Without stripping water first, autoimage anchors to water molecules causing artificially huge RMSD (20-40 Γ
). |
| - Output: `out rmsd.dat`. References: `first`, `refindex -1`. Masks: `@CA,C,N,O` `@CA` `:1-100` `!:WAT` |
| |
| ## cpptraj command names |
| Write scripts directly β you know cpptraj syntax well. |
| Only call search_cpptraj_docs when genuinely uncertain about an exact command name or obscure syntax. |
| After search_cpptraj_docs returns results, IMMEDIATELY call run_cpptraj_script β never stop to explain or summarize the search results. |
| |
| ## Multi-step workflows |
| Each run_cpptraj_script call is a fresh cpptraj process β in-memory datasets do NOT persist between calls. |
| - ALWAYS write every intermediate result to disk with `out filename` (matrix, diagmatrix, eigenvectors, etc.) |
| - If a subsequent script needs data from a previous run, reload it from disk using `readdata filename name datasetname` |
| - If unsure how many steps an analysis needs, call search_cpptraj_docs first to get the full workflow before writing the script. |
| |
| ## Python Environment |
| Available packages: pandas, numpy, matplotlib, scikit-learn, scipy. NOT available: MDAnalysis, parmed, pytraj, openmm. |
| NEVER use `delim_whitespace=True` (deprecated in pandas 2.x) β always use `sep=r'\s+'`. |
| |
| Python: `plt.savefig('f.png', dpi=150, bbox_inches='tight')` then `plt.close()`. Never plt.show(). |
| Read .dat files with pandas: `pd.read_csv('f.dat', sep=r'\\s+', comment='#')`. Print key stats to stdout. |
| - Before plotting any matrix/heatmap: always print `data.min().min(), data.max().max()` to stdout to validate the actual data range. Never assume or manually normalize β use the real range for vmin/vmax. |
| |
| ## Residue Classification (critical β never misclassify) |
| Protein residues (NOT ligands): ALA ARG ASN ASP CYS CYX GLN GLU GLY HIS HIE HID HIP ILE LEU LYS MET PHE PRO SER THR TRP TYR VAL |
| Capping groups (NOT ligands β part of the protein): ACE (N-terminal acetyl cap) NME (C-terminal methylamide cap) NHE NH2 |
| Water/solvent (NOT ligands): WAT HOH TIP3 TIP4 |
| Ions (NOT ligands): Na+ Cl- K+ MG CA ZN NA CL Mg2+ Ca2+ |
| Ligand = any residue that is NONE of the above. |
| |
| ## Ligand and Residue Information |
| The ## Topology Composition section at the top of every message already contains the ligand residue ID, name, atom count, and ready-to-use masks (e.g. ligand mask: :203, protein mask: :1-202). |
| NEVER run resinfo, parmed, or any identification script β the information is already provided. Use the masks directly. |
| """ |
|
|
|
|
class TrajectoryAgent:
    """Tool-using LLM agent for MD trajectory analysis.

    Drives a conversation loop in which the model may search the cpptraj
    manual (RAG), execute cpptraj scripts, run Python post-processing, and
    inspect output files.  Provider-specific message formats (Claude content
    blocks, OpenAI ``_multi`` wrappers, Gemini ``_fn_responses``) are
    abstracted behind :class:`LLMBackend`.
    """

    def __init__(self, runner: CPPTrajRunner, kb: CPPTrajKnowledgeBase,
                 provider: str = "", api_key: str = "", model: str = "", base_url: str = ""):
        """Bind the agent to a cpptraj runner and knowledge base.

        Empty ``provider``/``model``/``base_url`` fall back to the
        LLM_PROVIDER / LLM_MODEL / LLM_BASE_URL environment variables.
        """
        self.runner = runner
        self.kb = kb
        self.conversation_history: list[dict] = []
        self.parm_file: Path | None = None
        self.traj_files: list[Path] = []
        self._topology_info: dict = {}

        # Environment fallbacks for backend configuration.
        provider = provider or os.environ.get("LLM_PROVIDER", "claude")
        model = model or os.environ.get("LLM_MODEL", "")
        base_url = base_url or os.environ.get("LLM_BASE_URL", "")

        self._backend: LLMBackend = create_backend(provider, api_key, model, base_url)
        self._system_prompt = self._build_system_prompt(provider, model)

    def reconfigure(self, provider: str, api_key: str, model: str, base_url: str = ""):
        """Switch to a different LLM backend; clears the conversation history."""
        self._backend = create_backend(provider, api_key, model, base_url)
        self.conversation_history = []
        self._system_prompt = self._build_system_prompt(provider, model)

    @staticmethod
    def _build_system_prompt(provider: str, model: str) -> str:
        """Return the system prompt, hardened for local (ollama) models."""
        prompt = SYSTEM_PROMPT
        if provider == "ollama":
            # Qwen3 emits <think> blocks unless explicitly disabled.
            if "qwen3" in model.lower():
                prompt = "/no_think\n" + prompt
            # Local models misremember cpptraj command names more often, so
            # force a docs search before every script.  The first argument
            # must stay byte-identical to the SYSTEM_PROMPT lines it targets
            # or the replace silently becomes a no-op.
            prompt = prompt.replace(
                "## cpptraj command names\n"
                "Write scripts directly β you know cpptraj syntax well.\n"
                "Only call search_cpptraj_docs when genuinely uncertain about an exact command name or obscure syntax.",
                "## cpptraj command names\n"
                "ALWAYS call search_cpptraj_docs BEFORE writing any cpptraj script β even for common analyses.\n"
                "Use ONLY the exact command name returned by the search."
            )
        return prompt

    # Residue-name whitelists used by _scan_topology to classify topology
    # contents.  Anything matching none of these sets is treated as a ligand.
    _PROTEIN_RES = {
        "ALA","ARG","ASN","ASP","CYS","CYX","GLN","GLU","GLY",
        "HIS","HIE","HID","HIP","ILE","LEU","LYS","MET","PHE",
        "PRO","SER","THR","TRP","TYR","VAL",
        # Capping groups count as protein, not ligand.
        "ACE","NME","NHE","NH2",
    }
    _ION_RES = {
        "NA","CL","K","MG","CA","ZN","NA+","CL-","K+",
        "Na+","Cl-","Mg2+","Ca2+",
        "SOD","CLA","POT","CAL",
        "LI","RB","CS","F","BR","I",
    }
    _WATER_RES = {"WAT","HOH","TIP3","TIP4","SPC","SPCE"}

    _KNOWN_NON_LIGAND = _PROTEIN_RES | _ION_RES | _WATER_RES

    def set_files(self, parm_file: Path | None, traj_files: list[Path]):
        """Register uploaded topology/trajectory files and rescan the topology."""
        self.parm_file = parm_file
        self.traj_files = traj_files
        self._topology_info = {}
        if parm_file and parm_file.exists():
            self._scan_topology(parm_file)

    def _scan_topology(self, parm_file: Path):
        """Run resinfo once on upload and cache ligand/residue info."""
        import re
        script = f"parm {parm_file}\nresinfo *\ngo"
        res = self.runner.run_script(script)
        stdout = res.get("stdout", "")
        # Hoist the case-normalized lookup sets out of the per-line loop
        # (previously all three sets were rebuilt for every residue line).
        water_up = {r.upper() for r in self._WATER_RES}
        ion_up = {r.upper() for r in self._ION_RES}
        protein_up = {r.upper() for r in self._PROTEIN_RES}
        ligands, n_protein, n_water, n_ions = [], 0, 0, 0
        n_atoms_total = 0
        for line in stdout.splitlines():
            # Expected resinfo row: resid  name  first  last  natoms ...
            # NOTE(review): column layout inferred from this regex only —
            # confirm against the cpptraj version in use.
            m = re.match(r'\s*(\d+)\s+(\S+)\s+\d+\s+\d+\s+(\d+)\s+', line)
            if not m:
                continue
            resid, resname, natoms = int(m.group(1)), m.group(2), int(m.group(3))
            n_atoms_total += natoms
            rname_up = resname.upper()
            if rname_up in water_up:
                n_water += 1
            elif rname_up in ion_up:
                n_ions += 1
            elif rname_up in protein_up:
                n_protein += 1
            else:
                # Unknown residue name => classified as a ligand.
                ligands.append({"resid": resid, "name": resname, "natoms": natoms})
        n_residues_total = n_protein + n_water + n_ions + len(ligands)
        self._topology_info = {
            "n_atoms_total": n_atoms_total,
            "n_residues_total": n_residues_total,
            "n_protein_res": n_protein,
            "n_water": n_water,
            "n_ions": n_ions,
            "ligands": ligands,
        }

    def reset_conversation(self):
        """Clear the chat history (uploaded files and topology info are kept)."""
        self.conversation_history = []

    @property
    def provider(self):
        """Name of the active LLM provider."""
        return self._backend.provider

    @property
    def model(self):
        """Name of the active model."""
        return self._backend.model

    # Common analysis phrasings mapped to cpptraj command names.
    # NOTE(review): not referenced anywhere in this file — presumably used
    # by external callers; confirm before removing.
    _CMD_ALIASES = {
        "rg": "radgyr", "radius of gyration": "radgyr", "radgyr": "radgyr",
        "rmsf": "atomicfluct", "bfactor": "atomicfluct", "b-factor": "atomicfluct",
        "rmsd": "rmsd", "hbond": "hbond", "hydrogen bond": "hbond",
        "secondary structure": "secstruct", "dssp": "secstruct",
        "cluster": "cluster", "clustering": "cluster",
        "contact map": "nativecontacts", "native contact": "nativecontacts",
        "pca": "matrix", "principal component": "pca",
        "dihedral": "dihedral", "phi psi": "dihedral",
        "distance": "distance", "angle": "angle",
        "sasa": "surf", "surface area": "surf",
        "diffusion": "diffusion", "msd": "diffusion",
    }

    def _build_user_message_with_rag(self, query: str) -> str:
        """Prefix the user query with the current file/topology context."""
        fc = self._build_file_context()
        return f"{fc}\n\n## User Request\n{query}"

    def _trim_history(self, history: list) -> list:
        """Keep the last few turns, always cutting at a real user-text boundary.

        Must never start the window on a tool-result wrapper (Claude list
        content, OpenAI _multi, or Gemini _fn_responses) — that produces
        orphaned results the API rejects with a 400.
        """
        if len(history) <= 8:
            return history

        real_user_idx = []
        for i, msg in enumerate(history):
            if msg["role"] != "user":
                continue
            # Gemini function responses masquerade as user messages.
            if "_fn_responses" in msg:
                continue
            content = msg.get("content", "")
            if isinstance(content, str) and content.strip():
                real_user_idx.append(i)
            elif isinstance(content, list):
                # Claude-style content: a real user turn iff it contains
                # at least one non-tool_result block.
                if any(not (isinstance(b, dict) and b.get("type") == "tool_result")
                       for b in content):
                    real_user_idx.append(i)

        if len(real_user_idx) <= 2:
            return history
        # Keep everything from the third-to-last genuine user message on.
        return history[real_user_idx[-3]:]

    @staticmethod
    def _compress_result(result: str) -> str:
        """Trim a tool result stored in history to save tokens."""
        if len(result) <= 200:
            return result
        lines = result.splitlines()
        if len(lines) <= 8:
            # Nothing to cut line-wise; previously this still appended a
            # misleading "truncated" note even though nothing was dropped.
            return result
        head = "\n".join(lines[:8])
        return f"{head}\n… [{len(lines)} lines total, truncated]"

    def _safe_trim(self, history: list) -> list:
        """Emergency trim if total history exceeds ~120k chars (~30k tokens)."""
        total = sum(len(str(m.get("content", ""))) for m in history)
        if total <= 120_000:
            return history

        real_user_idx = []
        for i, msg in enumerate(history):
            if msg["role"] != "user":
                continue
            content = msg.get("content", "")
            if isinstance(content, str):
                real_user_idx.append(i)
            elif isinstance(content, list):
                if any(not (isinstance(b, dict) and b.get("type") == "tool_result")
                       for b in content):
                    real_user_idx.append(i)
        if len(real_user_idx) >= 2:
            return history[real_user_idx[-2]:]
        # Last resort: keep the final two exchanges.
        return history[-4:]

    def _build_file_context(self) -> str:
        """Assemble the per-message context: files, topology, existing outputs."""
        parts = ["## Available Files (use EXACTLY these names in every cpptraj script)"]
        if self.parm_file:
            parts.append(f"- TOPOLOGY → `parm {self.parm_file.name}` — use with parm command")
        else:
            parts.append("- Topology: *not uploaded yet*")
        if self.traj_files:
            for tf in self.traj_files:
                parts.append(f"- TRAJECTORY → `trajin {tf.name}` — use with trajin command")
        else:
            parts.append("- Trajectory: *not uploaded yet*")

        info = getattr(self, "_topology_info", {})
        if info:
            parts.append("\n## Topology Composition")
            if info.get("n_atoms_total"):
                parts.append(f"- Total atoms: {info['n_atoms_total']}")
            if info.get("n_residues_total"):
                parts.append(f"- Total residues: {info['n_residues_total']}")
                parts.append(f"- Protein residues: {info['n_protein_res']}")
            if info.get('n_ions'):
                parts.append(f"- Ions: {info['n_ions']} residues")
            parts.append(f"- Water molecules: {info['n_water']}")
            ligs = info.get("ligands", [])
            if ligs:
                parts.append(f"- Ligands ({len(ligs)} molecule{'s' if len(ligs)>1 else ''}):")
                for lig in ligs:
                    parts.append(f"  • {lig['name']} — residue :{lig['resid']} — {lig['natoms']} atoms")
                    # NOTE(review): assumes the protein occupies residues
                    # 1..(first ligand - 1) contiguously — confirm for
                    # topologies with waters/ions numbered before the ligand.
                    parts.append(f"    → protein mask: :1-{ligs[0]['resid']-1}  ligand mask: :{lig['resid']}")
            else:
                parts.append("- Ligands: none detected")

        existing = self.runner.list_output_files()
        if existing:
            parts.append("\n## Existing Output Files")
            for f in existing:
                parts.append(f"  - {f.name}")
        return "\n".join(parts)

    def _execute_tool(self, name: str, inp: dict) -> str:
        """Dispatch one tool call from the model; always returns a string."""
        if name == "search_cpptraj_docs":
            query = inp.get("query", "")
            return self.kb.get_context_for_llm(query, top_k=2, score_threshold=0.0)

        if name == "run_cpptraj_script":
            script = inp.get("script", "")
            if not script:
                return "Error: model did not provide a script."
            # Rewrite parm/trajin lines to point at the uploaded files.
            if self.parm_file or self.traj_files:
                script = self.runner.inject_paths_into_script(script, self.parm_file, self.traj_files)
            res = self.runner.run_script(script)
            out = [f"Success: {res['success']}", f"Elapsed: {res['elapsed']:.1f}s"]
            if res["stdout"]:
                out.append(f"\nSTDOUT:\n{res['stdout'][:1500]}")
            if res["stderr"]:
                out.append(f"\nSTDERR:\n{res['stderr'][:800]}")
            if res["output_files"]:
                out.append("Output files:")
                for f in res["output_files"]:
                    out.append(f"  - {f.name}")
            return "\n".join(out)

        if name == "read_output_file":
            # Previously inp["filename"] raised KeyError when the model
            # omitted the argument; report it as a tool error instead.
            fname = inp.get("filename", "")
            if not fname:
                return "Error: model did not provide a filename."
            path = self.runner.work_dir / fname
            if not path.exists():
                avail = [f.name for f in self.runner.list_output_files()]
                return f"File '{fname}' not found. Available: {avail}"
            content = self.runner.read_file(path)
            lines = content.splitlines()
            if len(lines) > 40:
                return "\n".join(lines[:40]) + f"\n\n[{len(lines)} lines total — first 40 shown]"
            return content

        if name == "list_output_files":
            files = self.runner.list_output_files()
            if not files:
                return "No output files yet."
            return "Output files:\n" + "\n".join(
                f"  - {f.name} ({f.stat().st_size} bytes)" for f in files)

        if name == "run_python_script":
            script = inp.get("script", "")
            if not script:
                return "Error: model did not provide a script."
            work_dir = self.runner.work_dir
            # Snapshot the directory so newly created files can be reported.
            before = set(work_dir.iterdir())
            try:
                proc = subprocess.run(
                    [sys.executable, "-c", script],
                    capture_output=True, text=True, timeout=60,
                    cwd=str(work_dir),
                )
                after = set(work_dir.iterdir())
                new_files = sorted(after - before, key=lambda f: f.name)
                out = [f"Success: {proc.returncode == 0}"]
                if proc.stdout:
                    out.append(f"\nSTDOUT:\n{proc.stdout[:1500]}")
                if proc.stderr:
                    out.append(f"\nSTDERR:\n{proc.stderr[:800]}")
                if new_files:
                    out.append("New files created:")
                    for f in new_files:
                        out.append(f"  - {f.name} ({f.stat().st_size} bytes)")
                return "\n".join(out)
            except subprocess.TimeoutExpired:
                return "Error: Python script timed out after 60 seconds."
            except Exception as e:
                return f"Error running Python script: {e}"

        return f"Unknown tool: {name}"

    def _sanitize_history(self):
        """Drop trailing assistant turns whose tool calls were never answered.

        An assistant message with unresolved tool_use blocks (or pending
        tool_calls/_fn_calls) at the tail of history makes the next API
        request invalid, so pop such messages before a new user turn.
        """
        while self.conversation_history:
            last = self.conversation_history[-1]
            role = last["role"]
            # Gemini uses role "model" instead of "assistant".
            if role not in ("assistant", "model"):
                break
            content = last.get("content") or []
            has_unresolved = (
                any(isinstance(b, dict) and b.get("type") == "tool_use" for b in content)
                if isinstance(content, list)
                else bool(last.get("tool_calls") or last.get("_fn_calls"))
            )
            if has_unresolved:
                self.conversation_history.pop()
            else:
                break

    def chat_stream(self, user_query: str):
        """Generator yielding SSE-style dicts for streaming chat.

        Event types emitted: ``text``, ``clear_text``, ``tool_start``,
        ``tool_done``, ``done``.
        """
        self._sanitize_history()
        self.conversation_history.append({
            "role": "user",
            "content": self._build_user_message_with_rag(user_query),
        })

        backend = self._backend
        max_iterations = 15  # hard cap on tool-call round trips
        iteration = 0

        while iteration < max_iterations:
            iteration += 1
            text_acc = []
            tool_calls = []
            stop_reason = "end_turn"

            for event_type, data in backend.stream_chat(
                    self._safe_trim(self._trim_history(self.conversation_history)), TOOLS, self._system_prompt):
                if event_type == "text":
                    text_acc.append(data)
                    yield {"type": "text", "chunk": data}
                elif event_type == "retract_text":
                    text_acc.clear()
                    yield {"type": "clear_text"}
                elif event_type == "tool_calls":
                    tool_calls = data
                elif event_type == "stop_reason":
                    stop_reason = data

            full_text = "".join(text_acc)

            # The system prompt forbids preamble before tool calls; if the
            # model emitted text anyway, retract it from UI and history.
            if tool_calls and full_text.strip():
                full_text = ""
                yield {"type": "clear_text"}

            self.conversation_history.append(
                backend.make_assistant_message(full_text, tool_calls))

            if stop_reason not in ("tool_use", "tool_calls") or not tool_calls:
                yield {"type": "done"}
                return

            results = []
            for tc in tool_calls:
                yield {"type": "tool_start", "tool": tc["name"],
                       "description": tc["input"].get("description", tc["name"])}
                try:
                    result = self._execute_tool(tc["name"], tc["input"])
                except Exception as e:
                    result = f"Error: {e}"
                yield {"type": "tool_done", "tool": tc["name"],
                       "input": tc["input"], "result": result}
                results.append(self._compress_result(result))

            # Some backends need one history message per tool result
            # (signalled via the "_multi" wrapper).
            tool_result_msg = backend.make_tool_result_message(tool_calls, results)
            if "_multi" in tool_result_msg:
                self.conversation_history.extend(tool_result_msg["_multi"])
            else:
                self.conversation_history.append(tool_result_msg)

        yield {"type": "text", "chunk": "\n\n⚠ Reached maximum tool iterations — stopping."}
        yield {"type": "done"}

    def chat(self, user_query: str) -> tuple[str, list[dict]]:
        """Blocking (non-streaming) chat.

        Returns ``(final_text, tool_calls_log)`` where the log holds one
        dict per executed tool call.
        """
        self._sanitize_history()
        self.conversation_history.append({
            "role": "user",
            "content": self._build_user_message_with_rag(user_query),
        })

        tool_calls_log = []
        final_text = ""
        backend = self._backend

        # Cap iterations like chat_stream (previously this loop was unbounded).
        for _ in range(15):
            try:
                text, tool_calls, has_tool_use = backend.chat(
                    self._safe_trim(self._trim_history(self.conversation_history)), TOOLS, self._system_prompt)
            except Exception as e:
                # Orphaned tool_use/tool_result pairing error: restart from
                # the latest user message and retry once.
                if "tool_use" in str(e) or "tool_result" in str(e):
                    last = self.conversation_history[-1]
                    self.conversation_history = [last]
                    text, tool_calls, has_tool_use = backend.chat(
                        self._safe_trim(self._trim_history(self.conversation_history)), TOOLS, self._system_prompt)
                else:
                    raise

            self.conversation_history.append(backend.make_assistant_message(text, tool_calls))

            if not has_tool_use or not tool_calls:
                final_text = text
                break

            results = []
            for tc in tool_calls:
                # Mirror chat_stream: a crashing tool must not abort the
                # turn and leave unresolved tool calls in history.
                try:
                    result = self._execute_tool(tc["name"], tc["input"])
                except Exception as e:
                    result = f"Error: {e}"
                tool_calls_log.append({"tool": tc["name"], "input": tc["input"], "result": result})
                results.append(self._compress_result(result))

            tool_result_msg = backend.make_tool_result_message(tool_calls, results)
            if "_multi" in tool_result_msg:
                self.conversation_history.extend(tool_result_msg["_multi"])
            else:
                self.conversation_history.append(tool_result_msg)
        else:
            final_text = "⚠ Reached maximum tool iterations — stopping."

        return final_text, tool_calls_log
|
|