from __future__ import annotations

import argparse
import json
import threading
import time
from typing import Any

import httpx
import uvicorn

from trenches_env.models import AgentAction, Prediction
from trenches_env.openenv_adapter import TrenchesOpenEnvAction, TrenchesOpenEnvObservation
from trenches_env.openenv_client import TrenchesEnvClient
from trenches_env.rl import AGENT_ALLOWED_ACTIONS
from trenches_env.server import create_app

DEFAULT_MODEL_ID = "Qwen/Qwen3-8B"
DEFAULT_REPLAY_ID = "us_synthetic_seed_2025_2026"
DEFAULT_TRAINING_STAGE = "stage_1_dense"
DEFAULT_LORA_TARGET_MODULES = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"


def _launch_backend(port: int) -> None:
    """Run the Trenches backend in a daemon thread and block until /healthz responds."""
    app = create_app()
    thread = threading.Thread(
        target=lambda: uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning"),
        daemon=True,
    )
    thread.start()
    deadline = time.time() + 30.0
    url = f"http://127.0.0.1:{port}/healthz"
    while time.time() < deadline:
        try:
            response = httpx.get(url, timeout=1.0)
            if response.status_code == 200:
                return
        except httpx.HTTPError:
            pass
        # Back off briefly between probes, including after non-200 responses.
        time.sleep(0.5)
    raise RuntimeError(f"Timed out waiting for backend health at {url}")


def _build_base_prompt(training_agent: str) -> str:
    return (
        f"You are training the {training_agent} policy in the Trenches OpenEnv historical replay environment. "
        "Return strict JSON only. "
        "Choose one legal action and forecast the next historical event from the visible timeline."
    )


def _render_observation_prompt(
    base_prompt: str,
    training_agent: str,
    observation: TrenchesOpenEnvObservation,
) -> str:
    agent_observation = observation.agent_observation
    public_brief = "\n".join(f"- {item.summary}" for item in agent_observation.public_brief[:4]) or "- None."
    private_brief = "\n".join(f"- {item.summary}" for item in agent_observation.private_brief[:4]) or "- None."
    historical_brief = "\n".join(f"- {line}" for line in agent_observation.historical_brief[:4]) or "- None."
    strategic_state = "\n".join(
        f"- {metric}: {value:.1f}" for metric, value in agent_observation.strategic_state.items()
    ) or "- None."
    available_actions = ", ".join(agent_observation.available_actions)
    return "\n".join(
        [
            base_prompt,
            "",
            f"Training agent: {training_agent}",
            f"Turn: {observation.turn}",
            f"Decision prompt:\n{agent_observation.decision_prompt}",
            "Historical brief:",
            historical_brief,
            "Public brief:",
            public_brief,
            "Private brief:",
            private_brief,
            "Strategic state:",
            strategic_state,
            f"Allowed actions: {available_actions}",
            "Output schema:",
            "{",
            '  "action": {',
            f'    "type": "{agent_observation.available_actions[0] if agent_observation.available_actions else "hold"}",',
            '    "target": "optional_target",',
            '    "summary": "one-sentence action rationale"',
            "  },",
            '  "prediction": {',
            '    "topic": "shipping|border|corridor|domestic|cyber|market|humanitarian|diplomacy|commodities",',
            '    "predicted_actor": "actor or null",',
            '    "predicted_target": "target or null",',
            '    "time_horizon_turns": 1,',
            '    "expected_severity": "low|medium|high|critical",',
            '    "confidence": 0.0,',
            '    "summary": "one-sentence forecast",',
            '    "rationale": "why this next event is likely"',
            "  }",
            "}",
        ]
    )


def _safe_json_loads(text: str) -> dict[str, Any]:
    text = text.strip()
    if not text:
        return {}
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost brace-delimited span, e.g. when the model
        # wraps its JSON in prose or code fences.
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end == -1 or end <= start:
            return {}
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            return {}


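# For reference, `_parse_turn_output` below accepts completions shaped like this
# (illustrative values; any missing or invalid field falls back to a safe default):
# {
#   "action": {"type": "hold", "target": null, "summary": "Hold and observe this turn."},
#   "prediction": {
#     "topic": "shipping", "predicted_actor": null, "predicted_target": null,
#     "time_horizon_turns": 1, "expected_severity": "medium",
#     "confidence": 0.4, "summary": "...", "rationale": "..."
#   }
# }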
def _parse_turn_output(training_agent: str, completion: str) -> tuple[AgentAction, Prediction]:
    payload = _safe_json_loads(completion)
    action_payload = payload.get("action") if isinstance(payload.get("action"), dict) else {}
    prediction_payload = payload.get("prediction") if isinstance(payload.get("prediction"), dict) else {}

    action_type = str(action_payload.get("type") or "hold")
    if action_type not in AGENT_ALLOWED_ACTIONS.get(training_agent, ()):
        action_type = "hold"
    target = action_payload.get("target")
    action_summary = str(
        action_payload.get("summary")
        or "Fallback hold after invalid or partial model completion."
    )

    prediction_topic = str(prediction_payload.get("topic") or "diplomacy")
    predicted_actor = prediction_payload.get("predicted_actor")
    predicted_target = prediction_payload.get("predicted_target")
    prediction_summary = str(
        prediction_payload.get("summary")
        or "Fallback low-confidence forecast after invalid or partial model completion."
    )
    rationale = str(prediction_payload.get("rationale") or "Parser fallback.")
    confidence_raw = prediction_payload.get("confidence", 0.1)
    try:
        confidence = max(0.0, min(1.0, float(confidence_raw)))
    except (TypeError, ValueError):
        confidence = 0.1
    horizon_raw = prediction_payload.get("time_horizon_turns", 1)
    try:
        time_horizon_turns = max(1, int(horizon_raw))
    except (TypeError, ValueError):
        time_horizon_turns = 1
    expected_severity = str(prediction_payload.get("expected_severity") or "medium")
    if expected_severity not in {"low", "medium", "high", "critical"}:
        expected_severity = "medium"
    return (
        AgentAction(
            actor=training_agent,
            type=action_type,  # type: ignore[arg-type]
            target=target if isinstance(target, str) else None,
            summary=action_summary,
        ),
        Prediction(
            agent_id=training_agent,
            topic=prediction_topic,
            predicted_actor=predicted_actor if isinstance(predicted_actor, str) else None,
            predicted_target=predicted_target if isinstance(predicted_target, str) else None,
            time_horizon_turns=time_horizon_turns,
            expected_severity=expected_severity,  # type: ignore[arg-type]
            confidence=confidence,
            summary=prediction_summary,
            rationale=rationale,
        ),
    )


def _build_dataset(training_agent: str, size: int):
    from datasets import Dataset

    base_prompt = _build_base_prompt(training_agent)
    return Dataset.from_dict({"prompt": [base_prompt] * size})


def _required_training_imports() -> dict[str, Any]:
    try:
        import torch
        from transformers import AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "Missing training dependencies. Install torch, transformers, trl, accelerate, and openenv-core first."
        ) from exc
    try:
        from trl.experimental.openenv import generate_rollout_completions
    except ModuleNotFoundError:
        generate_rollout_completions = None
    return {
        "torch": torch,
        "AutoTokenizer": AutoTokenizer,
        "GRPOConfig": GRPOConfig,
        "GRPOTrainer": GRPOTrainer,
        "generate_rollout_completions": generate_rollout_completions,
    }


def _resolve_model_device(model: Any) -> Any:
    device = getattr(model, "device", None)
    if device is not None:
        return device
    return next(model.parameters()).device


def _can_use_vllm(torch_module: Any) -> bool:
    if not hasattr(torch_module, "cuda") or not torch_module.cuda.is_available():
        return False
    try:
        import vllm  # noqa: F401
    except ModuleNotFoundError:
        return False
    return True


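# Fallback generation path for the pure-transformers backend. It returns the same
# mapping contract that OpenEnvGRPOTrainer._generate_single_turn validates below:
# "prompt_ids", "completion_ids", and per-token "logprobs", one entry per prompt.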
def _generate_rollout_completions_transformers(
    *,
    trainer: Any,
    prompts: list[str],
    tokenizer: Any,
    max_prompt_length: int,
    max_completion_length: int,
) -> dict[str, list[Any]]:
    import torch

    tokenizer.padding_side = "left"
    model = trainer.model
    device = _resolve_model_device(model)
    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_prompt_length,
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}
    generation = model.generate(
        **encoded,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    prompt_ids: list[list[int]] = []
    completion_ids: list[list[int]] = []
    logprobs: list[list[float]] = []
    input_width = encoded["input_ids"].shape[1]
    for batch_index in range(len(prompts)):
        prompt_ids.append(
            encoded["input_ids"][batch_index][encoded["attention_mask"][batch_index].bool()].tolist()
        )
        sample_completion_ids: list[int] = []
        sample_logprobs: list[float] = []
        for step_index, step_scores in enumerate(generation.scores):
            token_position = input_width + step_index
            if token_position >= generation.sequences.shape[1]:
                break
            token_id = int(generation.sequences[batch_index, token_position].item())
            # Trailing pads mark the end of shorter sequences in the batch; when the pad
            # token is aliased to EOS (see main()), keep the EOS token itself.
            if token_id == tokenizer.pad_token_id and token_id != tokenizer.eos_token_id:
                break
            sample_completion_ids.append(token_id)
            token_logprob = torch.log_softmax(step_scores[batch_index], dim=-1)[token_id].item()
            sample_logprobs.append(float(token_logprob))
            if tokenizer.eos_token_id is not None and token_id == tokenizer.eos_token_id:
                break
        completion_ids.append(sample_completion_ids)
        logprobs.append(sample_logprobs)
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
    }


def _parse_lora_target_modules(raw_value: str) -> str | list[str]:
    target_modules = [item.strip() for item in raw_value.split(",") if item.strip()]
    if not target_modules:
        raise ValueError("LoRA target modules cannot be empty.")
    if len(target_modules) == 1 and target_modules[0] == "all-linear":
        return "all-linear"
    return target_modules


def _preview_rollouts(
    *,
    model: Any,
    tokenizer: Any,
    training_agent: str,
    port: int,
    replay_id: str,
    training_stage: str,
    samples: int,
    max_prompt_length: int,
    max_completion_length: int,
) -> None:
    import torch

    print("\nPreview rollouts")
    for sample_index in range(samples):
        client = TrenchesEnvClient(base_url=f"http://127.0.0.1:{port}/openenv")
        reset_result = client.reset(
            training_agent=training_agent,
            training_stage=training_stage,
            max_turns=1,
            replay_id=replay_id,
            episode_id=f"preview-{sample_index}-{int(time.time() * 1000)}",
        )
        observation = reset_result.observation
        prompt = _render_observation_prompt(
            _build_base_prompt(training_agent),
            training_agent,
            observation,
        )
        model_max_length = getattr(tokenizer, "model_max_length", None)
        if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 1_000_000:
            model_max_length = max_prompt_length
        preview_prompt_length = min(max_prompt_length, model_max_length)
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=preview_prompt_length,
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=max_completion_length,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        completion_ids = output[0][inputs["input_ids"].shape[1] :]
        completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
        action, prediction = _parse_turn_output(training_agent, completion)
        step_result = client.step(
            TrenchesOpenEnvAction(
                action=action,
                prediction=prediction,
                external_signals=[],
            )
        )
        step_obs = step_result.observation
        actual_event = step_obs.revealed_event.summary if step_obs.revealed_event is not None else "n/a"
        step_reward = step_result.reward if step_result.reward is not None else 0.0
        print(
            f"[sample {sample_index + 1}] reward={step_reward:.3f} "
            f"action={action.type} topic={prediction.topic} actual={actual_event}"
        )


class OpenEnvGRPOTrainer:
    """Force GRPO to use the custom OpenEnv rollout path across generation backends."""

    def _generate_single_turn(self, prompts: list[str]):  # type: ignore[override]
        if getattr(self, "rollout_func", None) is None:
            return super()._generate_single_turn(prompts)
        output = self.rollout_func(prompts, self)
        required_keys = {"prompt_ids", "completion_ids", "logprobs"}
        missing = required_keys.difference(output)
        if missing:
            raise RuntimeError(f"rollout_func is missing required keys: {sorted(missing)}")
        extra_fields = {key: value for key, value in output.items() if key not in required_keys}
        return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields


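# main() composes the mixin above with TRL's GRPOTrainer at runtime via
# type("OpenEnvGRPOTrainer", (OpenEnvGRPOTrainer, GRPOTrainer), {}), so the
# `_generate_single_turn` override takes MRO precedence over TRL's implementation.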
def main() -> None:
    parser = argparse.ArgumentParser(description="Train a replay-aware OpenEnv policy for Trenches.")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--training-agent", default="us")
    parser.add_argument("--training-stage", default=DEFAULT_TRAINING_STAGE)
    parser.add_argument("--replay-id", default=DEFAULT_REPLAY_ID)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--train-size", type=int, default=32)
    parser.add_argument("--max-steps", type=int, default=4)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--generation-backend", choices=["auto", "vllm", "transformers"], default="auto")
    parser.add_argument("--max-prompt-length", type=int, default=1024)
    parser.add_argument("--max-completion-length", type=int, default=220)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-6)
    parser.add_argument("--output-dir", default="trl-openenv-historical-replay")
    parser.add_argument("--preview-samples", type=int, default=3)
    parser.add_argument("--no-preview", action="store_true")
    # Post-training plan args
    parser.add_argument(
        "--quantize-4bit",
        action="store_true",
        help="Load model with 4-bit NF4 quantization via bitsandbytes (requires CUDA)",
    )
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank used with quantized training")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha used with quantized training")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout used with quantized training")
    parser.add_argument(
        "--lora-target-modules",
        default=DEFAULT_LORA_TARGET_MODULES,
        help='Comma-separated LoRA target modules, or "all-linear"',
    )
    parser.add_argument("--beta", type=float, default=0.04, help="KL coefficient for GRPO")
    parser.add_argument("--warmup-steps", type=int, default=0, help="Number of warmup steps")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature for generation")
    parser.add_argument("--top-k", type=int, default=0, help="Top-k sampling (0 = disabled)")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling")
    parser.add_argument(
        "--save-strategy",
        default="no",
        choices=["no", "steps", "epoch"],
        help="Checkpoint save strategy",
    )
    parser.add_argument(
        "--save-steps",
        type=int,
        default=100,
        help="Save checkpoint every N steps (when save-strategy=steps)",
    )
    args = parser.parse_args()

    imports = _required_training_imports()
    torch = imports["torch"]
    AutoTokenizer = imports["AutoTokenizer"]
    GRPOConfig = imports["GRPOConfig"]
    GRPOTrainer = type("OpenEnvGRPOTrainer", (OpenEnvGRPOTrainer, imports["GRPOTrainer"]), {})
    generate_rollout_completions = imports["generate_rollout_completions"]

    generation_backend = args.generation_backend
    if generation_backend == "auto":
        generation_backend = (
            "vllm" if generate_rollout_completions is not None and _can_use_vllm(torch) else "transformers"
        )
    if generation_backend == "vllm" and generate_rollout_completions is None:
        raise RuntimeError(
            "The selected vLLM backend requires `trl.experimental.openenv.generate_rollout_completions`."
        )

    _launch_backend(args.port)

    # Model loading, optionally with 4-bit quantization
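    # With --quantize-4bit, the base weights are loaded as 4-bit NF4 via bitsandbytes and
    # only the LoRA adapter weights receive gradients (a QLoRA-style setup); otherwise
    # full-parameter bf16 training is configured further down.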
    model_ref = args.model_id
    peft_config = None
    if args.quantize_4bit:
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        try:
            from peft import LoraConfig, TaskType
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "Missing PEFT dependency. Install backend[train] so quantized training can attach LoRA adapters."
            ) from exc
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print(f"Loading {args.model_id} with 4-bit NF4 quantization")
        model_ref = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            quantization_config=bnb_config,
            device_map="auto",
        )
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            target_modules=_parse_lora_target_modules(args.lora_target_modules),
        )
        print(
            "Attaching LoRA adapters for quantized training "
            f"(r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}, "
            f"targets={args.lora_target_modules})"
        )

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset = _build_dataset(args.training_agent, args.train_size)
    base_prompt = _build_base_prompt(args.training_agent)

    def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list[Any]]:
        prompt_ids: list[list[int]] = []
        completion_ids: list[list[int]] = []
        logprobs: list[list[float]] = []
        env_rewards: list[float] = []
        forecast_rewards: list[float] = []
        for index, prompt in enumerate(prompts):
            # The client must stay open for the whole reset/generate/step round trip.
            with TrenchesEnvClient(base_url=f"http://127.0.0.1:{args.port}/openenv") as client:
                reset_result = client.reset(
                    training_agent=args.training_agent,
                    training_stage=args.training_stage,
                    max_turns=1,
                    replay_id=args.replay_id,
                    episode_id=f"train-{index}-{int(time.time() * 1000)}",
                )
                grounded_prompt = _render_observation_prompt(
                    prompt or base_prompt,
                    args.training_agent,
                    reset_result.observation,
                )
                if generation_backend == "vllm":
                    rollout_output = generate_rollout_completions(trainer, [grounded_prompt])[0]
                else:
                    rollout_output = {
                        key: value[0]
                        for key, value in _generate_rollout_completions_transformers(
                            trainer=trainer,
                            prompts=[grounded_prompt],
                            tokenizer=tokenizer,
                            max_prompt_length=args.max_prompt_length,
                            max_completion_length=args.max_completion_length,
                        ).items()
                    }
                completion_text = tokenizer.decode(rollout_output["completion_ids"], skip_special_tokens=True)
                action, prediction = _parse_turn_output(args.training_agent, completion_text)
                step_result = client.step(
                    TrenchesOpenEnvAction(
                        action=action,
                        prediction=prediction,
                        external_signals=[],
                    )
                )
                prompt_ids.append(list(rollout_output["prompt_ids"]))
                completion_ids.append(list(rollout_output["completion_ids"]))
                logprobs.append([float(value) for value in rollout_output["logprobs"]])
                step_reward = step_result.reward if step_result.reward is not None else 0.0
                step_obs = step_result.observation
                forecast_total = (
                    step_obs.reward_breakdown.forecast_total
                    if step_obs.reward_breakdown is not None
                    else 0.0
                )
                env_rewards.append(float(step_reward))
                forecast_rewards.append(float(forecast_total))
        return {
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
            "logprobs": logprobs,
            "env_reward": env_rewards,
            "forecast_reward": forecast_rewards,
        }

    def reward_from_env(completions: list[str], **kwargs: Any) -> list[float]:
        rewards = kwargs.get("env_reward")
        if rewards is None:
            return [0.0 for _ in completions]
        return [float(reward) for reward in rewards]

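    # The extra "env_reward"/"forecast_reward" fields returned by rollout_func are split
    # out as extra_fields in OpenEnvGRPOTrainer._generate_single_turn; GRPO is expected
    # to forward them to reward functions as keyword arguments, which is how
    # reward_from_env reads them.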
    training_kwargs = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "max_steps": args.max_steps,
        "num_train_epochs": 1,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "generation_batch_size": args.num_generations,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "logging_steps": 1,
        "report_to": [],
        "use_vllm": generation_backend == "vllm",
        "beta": args.beta,
        "warmup_steps": args.warmup_steps,
        "temperature": args.temperature,
        "top_k": args.top_k if args.top_k > 0 else None,  # 0 disables top-k sampling
        "top_p": args.top_p,
        "save_strategy": args.save_strategy,
        "save_steps": args.save_steps,
    }
    if not args.quantize_4bit:
        training_kwargs["bf16"] = True
        training_kwargs["model_init_kwargs"] = {"dtype": "bfloat16"}
    if generation_backend == "vllm":
        training_kwargs["vllm_mode"] = "colocate"
    training_args = GRPOConfig(**training_kwargs)
    trainer = GRPOTrainer(
        model=model_ref,
        processing_class=tokenizer,
        reward_funcs=reward_from_env,
        train_dataset=train_dataset,
        rollout_func=rollout_func,
        args=training_args,
        peft_config=peft_config,
    )
    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print("\nTraining complete")
    print(f"Generation backend: {generation_backend}")
    print(json.dumps(train_result.metrics, indent=2, sort_keys=True))

    if not args.no_preview and args.preview_samples > 0:
        _preview_rollouts(
            model=trainer.model,
            tokenizer=tokenizer,
            training_agent=args.training_agent,
            port=args.port,
            replay_id=args.replay_id,
            training_stage=args.training_stage,
            samples=args.preview_samples,
            max_prompt_length=args.max_prompt_length,
            max_completion_length=args.max_completion_length,
        )


if __name__ == "__main__":
    main()
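
# Example invocation (script name is illustrative; flags map to the argparse options above):
#   python train_trenches_grpo.py --model-id Qwen/Qwen3-8B --training-agent us \
#       --generation-backend transformers --max-steps 4 --num-generations 4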