"""Run the bio-experiment environment with a quantized Unsloth model.""" from __future__ import annotations import json import os import time from typing import Any, Dict, Optional from models import ActionType, ExperimentAction from server.hackathon_environment import BioExperimentEnvironment from training_unsloth import ( DEFAULT_MAX_SEQ_LENGTH, generate_action_with_model, load_model_artifacts, ) from training_script import DEFAULT_COMPLETION_TOKEN_BUDGET import run_agent as base MODEL_ID = os.getenv("RUN_AGENT_UNSLOTH_MODEL_ID", "unsloth/Qwen3.5-2B-GGUF") MAX_EPISODE_STEPS = int( os.getenv("RUN_AGENT_UNSLOTH_MAX_EPISODE_STEPS", str(base.MAX_EPISODE_STEPS)) ) MAX_NEW_TOKENS = int( os.getenv( "RUN_AGENT_UNSLOTH_MAX_NEW_TOKENS", str(DEFAULT_COMPLETION_TOKEN_BUDGET), ) ) MAX_SEQ_LENGTH = int( os.getenv("RUN_AGENT_UNSLOTH_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH)) ) TRUST_REMOTE_CODE = ( os.getenv("RUN_AGENT_UNSLOTH_TRUST_REMOTE_CODE", "1").strip().lower() not in {"0", "false", "off"} ) LOAD_IN_4BIT = ( os.getenv("RUN_AGENT_UNSLOTH_LOAD_IN_4BIT", "1").strip().lower() not in {"0", "false", "off"} ) FAST_INFERENCE = ( os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "0").strip().lower() not in {"0", "false", "off"} ) def check_dashboard_command() -> Optional[Dict[str, Any]]: try: raw = base.DASHBOARD_CMD_PATH.read_text(encoding="utf-8") base.DASHBOARD_CMD_PATH.unlink(missing_ok=True) return json.loads(raw) except (FileNotFoundError, json.JSONDecodeError): return None def run_episode( model: Any, tokenizer: Any, *, scenario_name: Optional[str] = None, custom_ground_truth: Optional[Dict[str, Any]] = None, ) -> None: env = BioExperimentEnvironment(scenario_name=scenario_name) obs = env.reset() if custom_ground_truth and env._latent: gt = custom_ground_truth bio = env._latent.biology if gt.get("true_markers"): bio.true_markers = gt["true_markers"] if gt.get("causal_mechanisms"): bio.causal_mechanisms = gt["causal_mechanisms"] if gt.get("true_pathways"): bio.true_pathways = { key: float(value) for key, value in gt["true_pathways"].items() } base.log("\n" + "=" * 70) base.log(f"TASK: {obs.task.problem_statement}") base.log(f"Conditions: {obs.task.conditions}") base.log( f"Budget: ${obs.task.budget_limit:,.0f} | " f"Time: {obs.task.time_limit_days:.0f} days" ) base.log("Runtime: Unsloth quantized generation") base.log("=" * 70) cumulative_reward = 0.0 base.write_dashboard_state(env, obs, step=0, cumulative_reward=0.0) for step in range(MAX_EPISODE_STEPS): cmd = check_dashboard_command() if cmd and cmd.get("action") == "restart": base.log("\n[DASHBOARD] Restart requested - ending episode early.") break t0 = time.time() result = generate_action_with_model( model, tokenizer, obs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.2, top_p=0.9, do_sample=True, ) response = result["response_text"] action = result["action"] gen_time = time.time() - t0 is_last_step = step == MAX_EPISODE_STEPS - 1 if action is None: if is_last_step: base.log("\n [!] Parse failed on final step - forcing synthesize_conclusion.") action = ExperimentAction( action_type=ActionType.SYNTHESIZE_CONCLUSION, justification="forced terminal conclusion", confidence=0.5, ) else: base.log( f"\n [!] Parse failed, skipping step. Raw: {response[:150]}" ) continue completed_types = { record.action_type for record in obs.pipeline_history if record.success } failed_types = { record.action_type for record in obs.pipeline_history if not record.success } if base.should_force_terminal_conclusion(action, completed_types): base.log( f"\n [!] repeated completed meta step {action.action_type.value} " f"- forcing synthesize_conclusion." ) action = ExperimentAction( action_type=ActionType.SYNTHESIZE_CONCLUSION, justification="repeated completed meta step forced terminal conclusion", confidence=action.confidence, ) completed_types = { record.action_type for record in obs.pipeline_history if record.success } skip_reason = None if action.action_type in completed_types: skip_reason = f"blocked repeat of completed step {action.action_type.value}" elif action.action_type in failed_types: if base.should_block_failed_reattempt(obs.pipeline_history, action.action_type): skip_reason = ( f"blocked re-attempt of failed step {action.action_type.value}" ) if skip_reason: if is_last_step: base.log( f"\n [!] {skip_reason} on final step - forcing synthesize_conclusion." ) action = ExperimentAction( action_type=ActionType.SYNTHESIZE_CONCLUSION, justification="forced terminal conclusion", confidence=0.5, ) else: base.log(f"\n [!] {skip_reason}, skipping step.") continue if is_last_step and action.action_type != ActionType.SYNTHESIZE_CONCLUSION: base.log( f"\n [!] Final step - overriding {action.action_type.value} " "with synthesize_conclusion." ) action = ExperimentAction( action_type=ActionType.SYNTHESIZE_CONCLUSION, justification="forced terminal conclusion", confidence=action.confidence, ) action = base.ensure_conclusion_claims(obs, action) base.log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)") if action.justification: base.log(f" Rationale: {action.justification}") else: base.log(" Rationale: [model did not provide one]") if action.parameters: base.log(f" Parameters: {base.compact_preview(action.parameters, 200)}") elif response: base.log( " Model response: " f"{base.compact_preview(response, base.MODEL_RESPONSE_PREVIEW_CHARS)}" ) obs = env.step(action) if obs.latest_output: latest_output = obs.latest_output status = "OK" if latest_output.success else "FAIL" base.log(f" [{status}] {latest_output.summary}") if latest_output.warnings: base.log(f" Warnings: {latest_output.warnings}") step_reward = obs.reward cumulative_reward += step_reward base.log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})") base.log( f" Budget: ${obs.resource_usage.budget_remaining:,.0f} | " f"Time: {obs.resource_usage.time_remaining_days:.0f}d" ) base.write_dashboard_state( env, obs, step=step + 1, cumulative_reward=cumulative_reward, model_response=response, action=action, gen_time=gen_time, episode_done=obs.done, ) if obs.rule_violations: base.log(f" Violations: {obs.rule_violations}") if obs.done: break base.log(f"\n{'=' * 70}") base.log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})") base.log(f" Steps: {obs.step_index}") base.log(f" Total reward: {cumulative_reward:+.3f}") base.log(f" Budget used: ${obs.resource_usage.budget_used:,.0f}") base.log(f" Time used: {obs.resource_usage.time_used_days:.0f} days") if obs.conclusions: base.log(" Conclusions:") for conclusion in obs.conclusions: base.log( f" [{conclusion.claim_type}, conf={conclusion.confidence:.2f}] " f"{conclusion.claim}" ) if conclusion.top_markers: base.log(f" Markers: {conclusion.top_markers}") if conclusion.causal_mechanisms: base.log(f" Mechanisms: {conclusion.causal_mechanisms}") if conclusion.predicted_pathways: base.log(f" Pathways: {conclusion.predicted_pathways}") base.log("=" * 70) def main() -> None: runtime = base.resolve_torch_runtime() base.log( f"Using Unsloth runtime: device={runtime['device']} " f"name={runtime['device_name']} dtype={runtime['dtype']}" ) tokenizer, model = load_model_artifacts( MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=LOAD_IN_4BIT, fast_inference=FAST_INFERENCE, prepare_for_inference=True, ) base.DASHBOARD_CMD_PATH.unlink(missing_ok=True) run_episode(model, tokenizer) while True: base.log("\nWaiting for dashboard command (restart / new task) ...") while True: cmd = check_dashboard_command() if cmd: break time.sleep(1.0) action_type = cmd.get("action", "restart") if action_type == "quit": base.log("Quit requested.") break scenario = cmd.get("scenario_name") ground_truth = cmd.get("ground_truth") base.log(f"\n[DASHBOARD] {action_type} - scenario={scenario}") run_episode( model, tokenizer, scenario_name=scenario, custom_ground_truth=ground_truth, ) if __name__ == "__main__": main()