Spaces:
Paused
Paused
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import threading | |
| import time | |
| from typing import Any | |
| import httpx | |
| import uvicorn | |
| from trenches_env.models import AgentAction, Prediction | |
| from trenches_env.openenv_adapter import TrenchesOpenEnvAction, TrenchesOpenEnvObservation | |
| from trenches_env.openenv_client import TrenchesEnvClient | |
| from trenches_env.rl import AGENT_ALLOWED_ACTIONS | |
| from trenches_env.server import create_app | |
# Model identifier used as the default for --model-id; it is passed to the
# tokenizer/model `from_pretrained` calls and to the trainer in main().
DEFAULT_MODEL_ID = "Qwen/Qwen3-8B"
# Default for --replay-id; forwarded as `replay_id` on every environment reset.
DEFAULT_REPLAY_ID = "us_synthetic_seed_2025_2026"
# Default for --training-stage; forwarded as `training_stage` on every environment reset.
DEFAULT_TRAINING_STAGE = "stage_1_dense"
# Default for --lora-target-modules: a comma-separated list that
# _parse_lora_target_modules() splits into individual module names.
DEFAULT_LORA_TARGET_MODULES = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"
def _launch_backend(port: int) -> None:
    """Start the Trenches backend in a daemon thread and wait until it is healthy.

    The application returned by ``create_app()`` is served by uvicorn on
    ``127.0.0.1:<port>``. The ``/healthz`` endpoint is then polled for up to
    30 seconds.

    Args:
        port: TCP port the local backend should listen on.

    Raises:
        RuntimeError: If the server thread exits before becoming healthy, or
            if ``/healthz`` does not answer with HTTP 200 before the deadline.
    """
    app = create_app()
    thread = threading.Thread(
        target=lambda: uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning"),
        daemon=True,
    )
    thread.start()

    url = f"http://127.0.0.1:{port}/healthz"
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + 30.0
    while time.monotonic() < deadline:
        try:
            if httpx.get(url, timeout=1.0).status_code == 200:
                return
        except httpx.HTTPError:
            # The server is not accepting connections yet; retry below.
            pass
        if not thread.is_alive():
            # The server thread ended (for example the port could not be
            # bound); polling further would only delay the failure.
            raise RuntimeError(f"Backend thread exited before becoming healthy at {url}")
        # Back off after *every* unsuccessful probe. Sleeping only on
        # exceptions would busy-loop while the server answers with non-200.
        time.sleep(0.5)
    raise RuntimeError(f"Timed out waiting for backend health at {url}")
| def _build_base_prompt(training_agent: str) -> str: | |
| return ( | |
| f"You are training the {training_agent} policy in the Trenches OpenEnv historical replay environment. " | |
| "Return strict JSON only. " | |
| "Choose one legal action and forecast the next historical event from the visible timeline." | |
| ) | |
| def _render_observation_prompt( | |
| base_prompt: str, | |
| training_agent: str, | |
| observation: TrenchesOpenEnvObservation, | |
| ) -> str: | |
| agent_observation = observation.agent_observation | |
| public_brief = "\n".join(f"- {item.summary}" for item in agent_observation.public_brief[:4]) or "- None." | |
| private_brief = "\n".join(f"- {item.summary}" for item in agent_observation.private_brief[:4]) or "- None." | |
| historical_brief = "\n".join(f"- {line}" for line in agent_observation.historical_brief[:4]) or "- None." | |
| strategic_state = "\n".join( | |
| f"- {metric}: {value:.1f}" for metric, value in agent_observation.strategic_state.items() | |
| ) or "- None." | |
| available_actions = ", ".join(agent_observation.available_actions) | |
| return "\n".join( | |
| [ | |
| base_prompt, | |
| "", | |
| f"Training agent: {training_agent}", | |
| f"Turn: {observation.turn}", | |
| f"Decision prompt:\n{agent_observation.decision_prompt}", | |
| "Historical brief:", | |
| historical_brief, | |
| "Public brief:", | |
| public_brief, | |
| "Private brief:", | |
| private_brief, | |
| "Strategic state:", | |
| strategic_state, | |
| f"Allowed actions: {available_actions}", | |
| "Output schema:", | |
| "{", | |
| ' "action": {', | |
| f' "type": "{agent_observation.available_actions[0] if agent_observation.available_actions else "hold"}",', | |
| ' "target": "optional_target",', | |
| ' "summary": "one-sentence action rationale"', | |
| " },", | |
| ' "prediction": {', | |
| ' "topic": "shipping|border|corridor|domestic|cyber|market|humanitarian|diplomacy|commodities",', | |
| ' "predicted_actor": "actor or null",', | |
| ' "predicted_target": "target or null",', | |
| ' "time_horizon_turns": 1,', | |
| ' "expected_severity": "low|medium|high|critical",', | |
| ' "confidence": 0.0,', | |
| ' "summary": "one-sentence forecast",', | |
| ' "rationale": "why this next event is likely"', | |
| " }", | |
| "}", | |
| ] | |
| ) | |
| def _safe_json_loads(text: str) -> dict[str, Any]: | |
| text = text.strip() | |
| if not text: | |
| return {} | |
| try: | |
| return json.loads(text) | |
| except json.JSONDecodeError: | |
| start = text.find("{") | |
| end = text.rfind("}") | |
| if start == -1 or end == -1 or end <= start: | |
| return {} | |
| try: | |
| return json.loads(text[start : end + 1]) | |
| except json.JSONDecodeError: | |
| return {} | |
def _parse_turn_output(training_agent: str, completion: str) -> tuple[AgentAction, Prediction]:
    """Convert a model completion into an ``AgentAction`` and a ``Prediction``.

    Parsing is deliberately forgiving: every missing, malformed or
    out-of-range field is replaced by a fallback, so a bad completion never
    raises from this function's own conversions.

    Args:
        training_agent: Agent id used as the action's actor and the
            prediction's agent id, and to look up the legal action types in
            ``AGENT_ALLOWED_ACTIONS``.
        completion: Raw text produced by the model.

    Returns:
        ``(action, prediction)``. The action type falls back to ``"hold"``
        when it is missing or not allowed for the agent. Confidence is
        clamped to ``[0, 1]`` (fallback ``0.1``), the horizon is at least
        ``1`` (fallback ``1``), and severity falls back to ``"medium"``.
    """
    import math

    payload = _safe_json_loads(completion)
    if not isinstance(payload, dict):
        # Valid JSON whose top level is not an object carries no usable fields.
        payload = {}
    action_payload = payload.get("action") if isinstance(payload.get("action"), dict) else {}
    prediction_payload = payload.get("prediction") if isinstance(payload.get("prediction"), dict) else {}

    action_type = str(action_payload.get("type") or "hold")
    if action_type not in AGENT_ALLOWED_ACTIONS.get(training_agent, ()):
        action_type = "hold"
    target = action_payload.get("target")
    action_summary = str(
        action_payload.get("summary")
        or "Fallback hold after invalid or partial model completion."
    )

    prediction_topic = str(prediction_payload.get("topic") or "diplomacy")
    predicted_actor = prediction_payload.get("predicted_actor")
    predicted_target = prediction_payload.get("predicted_target")
    prediction_summary = str(
        prediction_payload.get("summary")
        or "Fallback low-confidence forecast after invalid or partial model completion."
    )
    rationale = str(prediction_payload.get("rationale") or "Parser fallback.")

    try:
        # OverflowError: float() of an integer literal too large for a double.
        confidence = float(prediction_payload.get("confidence", 0.1))
    except (TypeError, ValueError, OverflowError):
        confidence = 0.1
    else:
        if math.isnan(confidence):
            # json.loads accepts the literal NaN; clamping it with min/max
            # would silently turn it into full confidence.
            confidence = 0.1
        else:
            confidence = max(0.0, min(1.0, confidence))

    try:
        # OverflowError: int() of an infinite float (json.loads accepts
        # "Infinity" and overflowing literals such as 1e999).
        time_horizon_turns = max(1, int(prediction_payload.get("time_horizon_turns", 1)))
    except (TypeError, ValueError, OverflowError):
        time_horizon_turns = 1

    expected_severity = str(prediction_payload.get("expected_severity") or "medium")
    if expected_severity not in {"low", "medium", "high", "critical"}:
        expected_severity = "medium"

    return (
        AgentAction(
            actor=training_agent,
            type=action_type,  # type: ignore[arg-type]
            target=target if isinstance(target, str) else None,
            summary=action_summary,
        ),
        Prediction(
            agent_id=training_agent,
            topic=prediction_topic,
            predicted_actor=predicted_actor if isinstance(predicted_actor, str) else None,
            predicted_target=predicted_target if isinstance(predicted_target, str) else None,
            time_horizon_turns=time_horizon_turns,
            expected_severity=expected_severity,  # type: ignore[arg-type]
            confidence=confidence,
            summary=prediction_summary,
            rationale=rationale,
        ),
    )
def _build_dataset(training_agent: str, size: int):
    """Build a ``datasets.Dataset`` holding ``size`` copies of the base prompt.

    Args:
        training_agent: Agent id used to build the base prompt.
        size: Number of rows in the resulting dataset.

    Returns:
        A dataset with a single ``"prompt"`` column.
    """
    from datasets import Dataset

    prompt_text = _build_base_prompt(training_agent)
    rows = [prompt_text for _ in range(size)]
    return Dataset.from_dict({"prompt": rows})
def _required_training_imports() -> dict[str, Any]:
    """Import the heavy training dependencies lazily and return them by name.

    Returns:
        A mapping with the keys ``"torch"``, ``"AutoTokenizer"``,
        ``"GRPOConfig"``, ``"GRPOTrainer"`` and
        ``"generate_rollout_completions"``. The last value is ``None`` when
        the optional ``trl.experimental.openenv`` helper cannot be imported.

    Raises:
        RuntimeError: If torch, transformers or trl cannot be imported; the
            original import error is attached as the cause.
    """
    try:
        import torch
        from transformers import AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:
        # ImportError (the parent of ModuleNotFoundError) also covers an
        # installed package that is too old to export the requested name.
        raise RuntimeError(
            "Missing training dependencies. Install torch, transformers, trl, accelerate, and openenv-core first."
        ) from exc

    try:
        from trl.experimental.openenv import generate_rollout_completions
    except ImportError:
        # Optional helper: the module or just this name may be absent.
        generate_rollout_completions = None

    return {
        "torch": torch,
        "AutoTokenizer": AutoTokenizer,
        "GRPOConfig": GRPOConfig,
        "GRPOTrainer": GRPOTrainer,
        "generate_rollout_completions": generate_rollout_completions,
    }
| def _resolve_model_device(model: Any) -> Any: | |
| device = getattr(model, "device", None) | |
| if device is not None: | |
| return device | |
| return next(model.parameters()).device | |
| def _can_use_vllm(torch_module: Any) -> bool: | |
| if not hasattr(torch_module, "cuda") or not torch_module.cuda.is_available(): | |
| return False | |
| try: | |
| import vllm # noqa: F401 | |
| except ModuleNotFoundError: | |
| return False | |
| return True | |
def _generate_rollout_completions_transformers(
    *,
    trainer: Any,
    prompts: list[str],
    tokenizer: Any,
    max_prompt_length: int,
    max_completion_length: int,
) -> dict[str, list[Any]]:
    """Sample completions with ``model.generate`` and collect per-token log-probabilities.

    Args:
        trainer: Object exposing the policy as ``trainer.model``. When
            ``trainer.args`` carries usable ``temperature`` / ``top_p``
            values they are used for sampling; otherwise 0.9 / 0.95 apply.
        prompts: Prompt strings to complete.
        tokenizer: Tokenizer used to encode the prompts. Its
            ``padding_side`` is set to ``"left"`` as a side effect.
        max_prompt_length: Prompts are truncated to this many tokens.
        max_completion_length: Maximum number of new tokens per prompt.

    Returns:
        A dict with three lists, one entry per prompt:
        ``"prompt_ids"`` (unpadded prompt token ids), ``"completion_ids"``
        (generated token ids up to and including the first EOS token, with
        trailing padding removed) and ``"logprobs"`` (log-softmax of the
        generation scores at each kept completion token).
    """
    import torch

    tokenizer.padding_side = "left"
    model = trainer.model
    device = _resolve_model_device(model)

    # Keep sampling consistent with the trainer configuration when one is
    # available (presumably a GRPOConfig stored by the trainer - confirm);
    # the fallbacks are the previously hard-coded values.
    sampling_config = getattr(trainer, "args", None)
    temperature = getattr(sampling_config, "temperature", None)
    if not isinstance(temperature, (int, float)) or temperature <= 0:
        temperature = 0.9
    top_p = getattr(sampling_config, "top_p", None)
    if not isinstance(top_p, (int, float)) or not 0.0 < top_p <= 1.0:
        top_p = 0.95

    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_prompt_length,
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}
    generation = model.generate(
        **encoded,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    prompt_ids: list[list[int]] = []
    completion_ids: list[list[int]] = []
    logprobs: list[list[float]] = []
    # With left padding every prompt occupies the first `input_width` columns,
    # so generated tokens start at that column for all rows.
    input_width = encoded["input_ids"].shape[1]
    sequences = generation.sequences
    sequence_width = sequences.shape[1]
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    for batch_index in range(len(prompts)):
        attention = encoded["attention_mask"][batch_index].bool()
        prompt_ids.append(encoded["input_ids"][batch_index][attention].tolist())

        sample_completion_ids: list[int] = []
        sample_logprobs: list[float] = []
        for step_index, step_scores in enumerate(generation.scores):
            token_position = input_width + step_index
            if token_position >= sequence_width:
                break
            token_id = int(sequences[batch_index, token_position].item())
            is_eos = eos_id is not None and token_id == eos_id
            # Padding marks the end of this row. EOS is tested first so that a
            # tokenizer whose pad token *is* the EOS token still keeps the
            # terminating EOS in the completion.
            if token_id == pad_id and not is_eos:
                break
            sample_completion_ids.append(token_id)
            token_logprob = torch.log_softmax(step_scores[batch_index], dim=-1)[token_id].item()
            sample_logprobs.append(float(token_logprob))
            if is_eos:
                break
        completion_ids.append(sample_completion_ids)
        logprobs.append(sample_logprobs)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
    }
| def _parse_lora_target_modules(raw_value: str) -> str | list[str]: | |
| target_modules = [item.strip() for item in raw_value.split(",") if item.strip()] | |
| if not target_modules: | |
| raise ValueError("LoRA target modules cannot be empty.") | |
| if len(target_modules) == 1 and target_modules[0] == "all-linear": | |
| return "all-linear" | |
| return target_modules | |
def _preview_rollouts(
    *,
    model: Any,
    tokenizer: Any,
    training_agent: str,
    port: int,
    replay_id: str,
    training_stage: str,
    samples: int,
    max_prompt_length: int,
    max_completion_length: int,
) -> None:
    """Run greedy single-turn rollouts against the local backend and print the results.

    For each sample a fresh one-turn episode is reset, the observation is
    rendered into a prompt, the model decodes greedily, the parsed action and
    prediction are submitted, and a one-line summary is printed.

    Args:
        model: Model used for generation; ``model.device`` and
            ``model.generate`` are used.
        tokenizer: Tokenizer used to encode prompts and decode completions.
        training_agent: Agent id passed to the environment and the parser.
        port: Port of the local backend serving ``/openenv``.
        replay_id: Replay identifier passed to ``reset``.
        training_stage: Training stage passed to ``reset``.
        samples: Number of preview episodes to run.
        max_prompt_length: Upper bound on prompt tokens.
        max_completion_length: Maximum number of generated tokens.
    """
    import torch

    print("\nPreview rollouts")

    # These values do not depend on the sample, so compute them once.
    base_prompt = _build_base_prompt(training_agent)
    base_url = f"http://127.0.0.1:{port}/openenv"
    model_max_length = getattr(tokenizer, "model_max_length", None)
    # Missing, non-positive or implausibly large limits are replaced by the
    # requested prompt length.
    if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 1_000_000:
        model_max_length = max_prompt_length
    preview_prompt_length = min(max_prompt_length, model_max_length)

    for sample_index in range(samples):
        # The context manager releases the client after each episode, the
        # same way the training rollout uses it.
        with TrenchesEnvClient(base_url=base_url) as client:
            reset_result = client.reset(
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=1,
                replay_id=replay_id,
                episode_id=f"preview-{sample_index}-{int(time.time() * 1000)}",
            )
            prompt = _render_observation_prompt(
                base_prompt,
                training_agent,
                reset_result.observation,
            )
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=preview_prompt_length,
            ).to(model.device)
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_completion_length,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                )
            # Drop the prompt tokens; only the newly generated tail is decoded.
            completion_ids = output[0][inputs["input_ids"].shape[1] :]
            completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
            action, prediction = _parse_turn_output(training_agent, completion)
            step_result = client.step(
                TrenchesOpenEnvAction(
                    action=action,
                    prediction=prediction,
                    external_signals=[],
                )
            )

        step_obs = step_result.observation
        actual_event = step_obs.revealed_event.summary if step_obs.revealed_event is not None else "n/a"
        step_reward = step_result.reward if step_result.reward is not None else 0.0
        print(
            f"[sample {sample_index + 1}] reward={step_reward:.3f} "
            f"action={action.type} topic={prediction.topic} actual={actual_event}"
        )
class OpenEnvGRPOTrainer:
    """Mixin that routes GRPO single-turn generation through a custom rollout function.

    When the instance has a non-``None`` ``rollout_func`` attribute, that
    callable produces the rollout; otherwise generation is delegated to the
    next class in the method resolution order.
    """

    def _generate_single_turn(self, prompts: list[str]):  # type: ignore[override]
        """Generate one turn for ``prompts``.

        Returns:
            ``(prompt_ids, completion_ids, logprobs, extra_fields)`` where
            ``extra_fields`` holds every additional key the rollout returned.

        Raises:
            RuntimeError: If the rollout output lacks a required key.
        """
        rollout = getattr(self, "rollout_func", None)
        if rollout is None:
            return super()._generate_single_turn(prompts)

        result = rollout(prompts, self)
        core_keys = ("prompt_ids", "completion_ids", "logprobs")
        absent = sorted(key for key in core_keys if key not in result)
        if absent:
            raise RuntimeError(f"rollout_func is missing required keys: {absent}")

        extras = {}
        for key, value in result.items():
            if key not in core_keys:
                extras[key] = value
        return result["prompt_ids"], result["completion_ids"], result["logprobs"], extras
def main() -> None:
    """Command-line entry point: train a Trenches OpenEnv policy with GRPO.

    Steps, in order:

    1. Parse the CLI options.
    2. Import the training stack and pick the generation backend
       (``vllm`` or ``transformers``; ``auto`` prefers vLLM when available).
    3. Start the local backend and wait for it to become healthy.
    4. Optionally load the model in 4-bit NF4 with a LoRA configuration.
    5. Train with a rollout function that plays one-turn episodes against
       the backend and uses the environment reward.
    6. Save the model and tokenizer, print the metrics, and optionally run
       preview rollouts.

    Raises:
        RuntimeError: If training dependencies are missing, if the vLLM
            backend is selected without the TRL OpenEnv rollout helper, if
            PEFT is missing for quantized training, or if the backend does
            not become healthy.
    """
    parser = argparse.ArgumentParser(description="Train a replay-aware OpenEnv policy for Trenches.")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--training-agent", default="us")
    parser.add_argument("--training-stage", default=DEFAULT_TRAINING_STAGE)
    parser.add_argument("--replay-id", default=DEFAULT_REPLAY_ID)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--train-size", type=int, default=32)
    parser.add_argument("--max-steps", type=int, default=4)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--generation-backend", choices=["auto", "vllm", "transformers"], default="auto")
    parser.add_argument("--max-prompt-length", type=int, default=1024)
    parser.add_argument("--max-completion-length", type=int, default=220)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-6)
    parser.add_argument("--output-dir", default="trl-openenv-historical-replay")
    parser.add_argument("--preview-samples", type=int, default=3)
    parser.add_argument("--no-preview", action="store_true")
    # Post-training plan args
    parser.add_argument("--quantize-4bit", action="store_true", help="Load model with 4-bit NF4 quantization via bitsandbytes (requires CUDA)")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank used with quantized training")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha used with quantized training")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout used with quantized training")
    parser.add_argument(
        "--lora-target-modules",
        default=DEFAULT_LORA_TARGET_MODULES,
        help='Comma-separated LoRA target modules, or "all-linear"',
    )
    parser.add_argument("--beta", type=float, default=0.04, help="KL coefficient for GRPO")
    parser.add_argument("--warmup-steps", type=int, default=0, help="Number of warmup steps")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature for generation")
    parser.add_argument("--top-k", type=int, default=0, help="Top-k sampling (0 = disabled)")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling")
    parser.add_argument("--save-strategy", default="no", choices=["no", "steps", "epoch"], help="Checkpoint save strategy")
    parser.add_argument("--save-steps", type=int, default=100, help="Save checkpoint every N steps (when save-strategy=steps)")
    args = parser.parse_args()

    imports = _required_training_imports()
    torch = imports["torch"]
    AutoTokenizer = imports["AutoTokenizer"]
    GRPOConfig = imports["GRPOConfig"]
    # Put the mixin first in the MRO so its _generate_single_turn override wins.
    GRPOTrainer = type("OpenEnvGRPOTrainer", (OpenEnvGRPOTrainer, imports["GRPOTrainer"]), {})
    generate_rollout_completions = imports["generate_rollout_completions"]

    generation_backend = args.generation_backend
    if generation_backend == "auto":
        generation_backend = "vllm" if generate_rollout_completions is not None and _can_use_vllm(torch) else "transformers"
    if generation_backend == "vllm" and generate_rollout_completions is None:
        raise RuntimeError("The selected vLLM backend requires `trl.experimental.openenv.generate_rollout_completions`.")

    _launch_backend(args.port)

    # Model loading — optionally with 4-bit quantization
    # Without quantization the trainer receives the model id string.
    model_ref = args.model_id
    peft_config = None
    if args.quantize_4bit:
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        try:
            from peft import LoraConfig, TaskType
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "Missing PEFT dependency. Install backend[train] so quantized training can attach LoRA adapters."
            ) from exc

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print(f"Loading {args.model_id} with 4-bit NF4 quantization")
        model_ref = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            quantization_config=bnb_config,
            device_map="auto",
        )
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            target_modules=_parse_lora_target_modules(args.lora_target_modules),
        )
        print(
            "Attaching LoRA adapters for quantized training "
            f"(r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}, "
            f"targets={args.lora_target_modules})"
        )

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    if tokenizer.pad_token is None:
        # Batched generation needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset = _build_dataset(args.training_agent, args.train_size)
    base_prompt = _build_base_prompt(args.training_agent)

    def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list[Any]]:
        """Play one single-turn episode per prompt and collect tokens, log-probs and rewards."""
        prompt_ids: list[list[int]] = []
        completion_ids: list[list[int]] = []
        logprobs: list[list[float]] = []
        env_rewards: list[float] = []
        forecast_rewards: list[float] = []

        for index, prompt in enumerate(prompts):
            with TrenchesEnvClient(base_url=f"http://127.0.0.1:{args.port}/openenv") as client:
                reset_result = client.reset(
                    training_agent=args.training_agent,
                    training_stage=args.training_stage,
                    max_turns=1,
                    replay_id=args.replay_id,
                    episode_id=f"train-{index}-{int(time.time() * 1000)}",
                )
                # The dataset prompt is only the preamble; ground it in the
                # observation returned by this episode's reset.
                grounded_prompt = _render_observation_prompt(
                    prompt or base_prompt,
                    args.training_agent,
                    reset_result.observation,
                )
                if generation_backend == "vllm":
                    rollout_output = generate_rollout_completions(trainer, [grounded_prompt])[0]
                else:
                    # The transformers helper returns batched lists; take the
                    # single entry for this prompt.
                    rollout_output = {
                        key: value[0]
                        for key, value in _generate_rollout_completions_transformers(
                            trainer=trainer,
                            prompts=[grounded_prompt],
                            tokenizer=tokenizer,
                            max_prompt_length=args.max_prompt_length,
                            max_completion_length=args.max_completion_length,
                        ).items()
                    }

                completion_text = tokenizer.decode(rollout_output["completion_ids"], skip_special_tokens=True)
                action, prediction = _parse_turn_output(args.training_agent, completion_text)
                step_result = client.step(
                    TrenchesOpenEnvAction(
                        action=action,
                        prediction=prediction,
                        external_signals=[],
                    )
                )

                prompt_ids.append(list(rollout_output["prompt_ids"]))
                completion_ids.append(list(rollout_output["completion_ids"]))
                logprobs.append([float(value) for value in rollout_output["logprobs"]])

                step_reward = step_result.reward if step_result.reward is not None else 0.0
                step_obs = step_result.observation
                forecast_total = step_obs.reward_breakdown.forecast_total if step_obs.reward_breakdown is not None else 0.0
                env_rewards.append(float(step_reward))
                forecast_rewards.append(float(forecast_total))

        return {
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
            "logprobs": logprobs,
            "env_reward": env_rewards,
            "forecast_reward": forecast_rewards,
        }

    def reward_from_env(completions: list[str], **kwargs: Any) -> list[float]:
        """Return the environment rewards passed as ``env_reward``, or zeros when absent."""
        rewards = kwargs.get("env_reward")
        if rewards is None:
            return [0.0 for _ in completions]
        return [float(reward) for reward in rewards]

    training_kwargs = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "max_steps": args.max_steps,
        "num_train_epochs": 1,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "generation_batch_size": args.num_generations,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "logging_steps": 1,
        "report_to": [],
        "use_vllm": generation_backend == "vllm",
        "beta": args.beta,
        "warmup_steps": args.warmup_steps,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "save_strategy": args.save_strategy,
        "save_steps": args.save_steps,
    }
    if args.top_k > 0:
        # Forward --top-k only when enabled. The default of 0 means "disabled"
        # and leaves the trainer's own default untouched, so behaviour without
        # the flag is unchanged.
        training_kwargs["top_k"] = args.top_k
    if not args.quantize_4bit:
        # The model is still an id string here, so the trainer loads it and
        # needs the dtype; a pre-loaded quantized model must not get these.
        training_kwargs["bf16"] = True
        training_kwargs["model_init_kwargs"] = {"dtype": "bfloat16"}
    if generation_backend == "vllm":
        training_kwargs["vllm_mode"] = "colocate"

    training_args = GRPOConfig(**training_kwargs)
    trainer = GRPOTrainer(
        model=model_ref,
        processing_class=tokenizer,
        reward_funcs=reward_from_env,
        train_dataset=train_dataset,
        rollout_func=rollout_func,
        args=training_args,
        peft_config=peft_config,
    )
    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    print("\nTraining complete")
    print(f"Generation backend: {generation_backend}")
    print(json.dumps(train_result.metrics, indent=2, sort_keys=True))

    if not args.no_preview and args.preview_samples > 0:
        _preview_rollouts(
            model=trainer.model,
            tokenizer=tokenizer,
            training_agent=args.training_agent,
            port=args.port,
            replay_id=args.replay_id,
            training_stage=args.training_stage,
            samples=args.preview_samples,
            max_prompt_length=args.max_prompt_length,
            max_completion_length=args.max_completion_length,
        )
# Run the training CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()