"""ER-MAP automated CLI tester.

Runs a single TriageEnv episode in which a Groq-hosted chat model plays the
Doctor: each observation string returned by the environment is sent to the
model as a user message, and the model's raw JSON reply is passed straight
back to ``env.step``. Progress and rewards are printed to the terminal.
"""

import argparse
import json
import os
import sys
import time

from colorama import init, Fore, Style

# Make the repository root importable when this file is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ER_MAP.envs.triage_env import TriageEnv

try:
    from groq import Groq
    GROQ_AVAILABLE = True
except ImportError:
    GROQ_AVAILABLE = False

init(autoreset=True)

# Environment variables consulted for the API key, in priority order.
_GROQ_KEY_ENV_VARS = (
    "GROQ_API_KEY",
    "groq",
    "GROQ_NURSE_API_KEY",
    "nurse",
    "GROQ_PATIENT_API_KEY",
    "patient",
)

DOCTOR_MODEL = "llama-3.3-70b-versatile"
MAX_STEPS = 30

# Errors that can occur while pretty-printing a JSON payload: invalid JSON
# (json.JSONDecodeError is a ValueError), a None payload (TypeError), a payload
# that is not a JSON object (AttributeError on .get/.items), or a value that
# does not accept the requested format spec (ValueError/TypeError).
_PRETTY_PRINT_ERRORS = (ValueError, TypeError, AttributeError)


def print_header(title):
    """Print ``title`` framed by two 60-character cyan rules."""
    rule = "=" * 60
    print(f"\n{Fore.CYAN}{Style.BRIGHT}{rule}")
    print(f"{Fore.CYAN}{Style.BRIGHT} {title}")
    print(f"{Fore.CYAN}{Style.BRIGHT}{rule}")


def _resolve_groq_key():
    """Return the first non-empty API key among the known env vars, or None."""
    for name in _GROQ_KEY_ENV_VARS:
        value = os.environ.get(name)
        if value:
            return value
    return None


def _print_ground_truth(gt, phase):
    """Print the disease, patient and nurse details of the generated scenario."""
    disease = gt['disease']
    patient = gt['patient']
    nurse = gt['nurse']
    print_header(f"GROUND TRUTH GENERATED (PHASE {phase})")
    print(f"{Fore.MAGENTA}Disease: {disease['true_disease']} | Diff: {disease['difficulty']} | Emergency: {disease.get('is_emergency', False)}")
    print(f"Correct Tx: {disease['correct_treatment']}")
    print(f"\n{Fore.BLUE}Patient: {patient['communication']}, {patient['compliance']}, {patient['literacy']}")
    print(f"{Fore.GREEN}Nurse: {nurse['experience']}, {nurse['bandwidth']}, {nurse['empathy']}")


def _print_action(raw_action):
    """Pretty-print the model's action, falling back to the raw text."""
    try:
        action = json.loads(raw_action)
        extras = {k: v for k, v in action.items() if k not in ['thought', 'tool']}
        print(f"{Fore.CYAN}💭 Thought: {action.get('thought', '')}")
        print(f"{Fore.GREEN}🛠️ Action: {action.get('tool', '')} -> {json.dumps(extras)}")
    except _PRETTY_PRINT_ERRORS:
        # Narrow on purpose: a bare except would also swallow Ctrl-C.
        print(f"{Fore.GREEN}🛠️ Action (Raw): {raw_action}")


def _print_episode_end(obs_json):
    """Print the final observation's outcome fields, or the raw observation."""
    print(f"\n{Fore.RED}{Style.BRIGHT}=== EPISODE FINISHED ===")
    try:
        final_obs = json.loads(obs_json)
        print(f"Outcome: {final_obs.get('event')}")
        print(f"Message: {final_obs.get('message')}")
        if 'match_ratio' in final_obs:
            print(f"Match Ratio: {final_obs['match_ratio']:.0%}")
    except _PRETTY_PRINT_ERRORS:
        print(f"Final Obs: {obs_json}")


def run_automated_cli(phase: int) -> None:
    """Run one automated episode with the LLM acting as the Doctor.

    Args:
        phase: Curriculum phase, forwarded to ``env.reset`` via
            ``options={"phase": phase}``.

    Returns early with an error message when no API key is found in the
    environment or the ``groq`` package is not importable. The episode stops
    when the environment reports ``done``/``truncated``, when any exception is
    raised during a step, or after ``MAX_STEPS`` steps.
    """
    groq_key = _resolve_groq_key()
    if not groq_key or not GROQ_AVAILABLE:
        print(f"{Fore.RED}ERROR: GROQ_API_KEY environment variable required and 'groq' package must be installed.")
        return

    client = Groq(api_key=groq_key)

    print(f"{Fore.YELLOW}Initializing Environment (Phase {phase})...")
    env = TriageEnv(render_mode="human")
    obs_json, _ = env.reset(options={"phase": phase})

    _print_ground_truth(env.ground_truth, phase)

    print_header("AUTOMATED DOCTOR AGENT RUNNING (70B)")

    messages = [
        {
            "role": "system",
            "content": (
                "You are the Doctor. You must output ONLY valid JSON matching the exact schema requested in your prompt. "
                "Tools available: speak_to, order_lab, read_soap, update_soap, terminal_discharge."
            ),
        }
    ]

    cumulative_reward = 0.0

    for step_index in range(MAX_STEPS):
        # Feed the latest observation to the model as the next user turn.
        messages.append({"role": "user", "content": obs_json})

        print(f"\n{Fore.YELLOW}--- Step {step_index + 1} ---")
        print(f"{Fore.WHITE}Doctor 70B is thinking...")

        try:
            completion = client.chat.completions.create(
                model=DOCTOR_MODEL,
                messages=messages,
                temperature=0.7,
                response_format={"type": "json_object"},
            )
            raw_action = completion.choices[0].message.content
            messages.append({"role": "assistant", "content": raw_action})

            _print_action(raw_action)

            # The environment receives the model's reply unmodified.
            obs_json, reward, done, truncated, _ = env.step(raw_action)
            cumulative_reward += reward

            print(f"{Fore.MAGENTA}🪙 Step Reward: {reward:.2f} | Total: {cumulative_reward:.2f}")

            if done or truncated:
                _print_episode_end(obs_json)
                break

        except Exception as e:
            # Top-level boundary for the step: this also covers failures
            # raised by env.step, not only the LLM request itself.
            print(f"{Fore.RED}LLM Error: {e}")
            break

        time.sleep(1)
    else:
        # Reached only when the loop ran out of steps without a break.
        print(f"\n{Fore.RED}{Style.BRIGHT}=== MAX STEPS REACHED ===")


def main() -> None:
    """Parse command-line arguments and run one episode."""
    parser = argparse.ArgumentParser(description="ER-MAP Automated CLI Tester")
    parser.add_argument("--phase", type=int, default=1, choices=[1, 2, 3], help="Curriculum phase (1-3)")
    args = parser.parse_args()

    try:
        run_automated_cli(args.phase)
    except KeyboardInterrupt:
        print("\nExiting...")
        sys.exit(0)


if __name__ == "__main__":
    main()