from __future__ import annotations

import argparse
import json
import threading
import time
from typing import Any

import httpx
import uvicorn

from trenches_env.models import AgentAction, Prediction
from trenches_env.openenv_adapter import TrenchesOpenEnvAction, TrenchesOpenEnvObservation
from trenches_env.openenv_client import TrenchesEnvClient
from trenches_env.rl import AGENT_ALLOWED_ACTIONS
from trenches_env.server import create_app

DEFAULT_MODEL_ID = "Qwen/Qwen3-8B"
DEFAULT_REPLAY_ID = "us_synthetic_seed_2025_2026"
DEFAULT_TRAINING_STAGE = "stage_1_dense"
DEFAULT_LORA_TARGET_MODULES = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"


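# Example invocation (a sketch; the script filename below is hypothetical):
#   python train_openenv_grpo.py --training-agent us --max-steps 4 --num-generations 4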
def _launch_backend(port: int) -> None:
    """Start the Trenches backend in a daemon thread and block until it reports healthy."""
    app = create_app()
    thread = threading.Thread(
        target=lambda: uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning"),
        daemon=True,
    )
    thread.start()

    deadline = time.time() + 30.0
    url = f"http://127.0.0.1:{port}/healthz"
    while time.time() < deadline:
        try:
            response = httpx.get(url, timeout=1.0)
            if response.status_code == 200:
                return
        except httpx.HTTPError:
            pass
        # Sleep on non-200 responses as well as connection errors, so the
        # poll loop never busy-spins while the server is starting.
        time.sleep(0.5)
    raise RuntimeError(f"Timed out waiting for backend health at {url}")


def _build_base_prompt(training_agent: str) -> str:
    return (
        f"You are training the {training_agent} policy in the Trenches OpenEnv historical replay environment. "
        "Return strict JSON only. "
        "Choose one legal action and forecast the next historical event from the visible timeline."
    )


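# Renders one environment observation into the grounded prompt fed to the
# model. Each brief is truncated to its first four items to keep prompts
# within --max-prompt-length; the JSON schema at the end mirrors exactly what
# _parse_turn_output expects back from the completion.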
def _render_observation_prompt(
    base_prompt: str,
    training_agent: str,
    observation: TrenchesOpenEnvObservation,
) -> str:
    agent_observation = observation.agent_observation
    public_brief = "\n".join(f"- {item.summary}" for item in agent_observation.public_brief[:4]) or "- None."
    private_brief = "\n".join(f"- {item.summary}" for item in agent_observation.private_brief[:4]) or "- None."
    historical_brief = "\n".join(f"- {line}" for line in agent_observation.historical_brief[:4]) or "- None."
    strategic_state = "\n".join(
        f"- {metric}: {value:.1f}" for metric, value in agent_observation.strategic_state.items()
    ) or "- None."
    available_actions = ", ".join(agent_observation.available_actions)
    return "\n".join(
        [
            base_prompt,
            "",
            f"Training agent: {training_agent}",
            f"Turn: {observation.turn}",
            f"Decision prompt:\n{agent_observation.decision_prompt}",
            "Historical brief:",
            historical_brief,
            "Public brief:",
            public_brief,
            "Private brief:",
            private_brief,
            "Strategic state:",
            strategic_state,
            f"Allowed actions: {available_actions}",
            "Output schema:",
            "{",
            '  "action": {',
            f'    "type": "{agent_observation.available_actions[0] if agent_observation.available_actions else "hold"}",',
            '    "target": "optional_target",',
            '    "summary": "one-sentence action rationale"',
            "  },",
            '  "prediction": {',
            '    "topic": "shipping|border|corridor|domestic|cyber|market|humanitarian|diplomacy|commodities",',
            '    "predicted_actor": "actor or null",',
            '    "predicted_target": "target or null",',
            '    "time_horizon_turns": 1,',
            '    "expected_severity": "low|medium|high|critical",',
            '    "confidence": 0.0,',
            '    "summary": "one-sentence forecast",',
            '    "rationale": "why this next event is likely"',
            "  }",
            "}",
        ]
    )


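# Best-effort JSON extraction: models often wrap the payload in prose, e.g.
#   'Sure! {"action": {"type": "hold"}}'  ->  {"action": {"type": "hold"}}
# Anything unrecoverable yields {} so callers can fall back to safe defaults.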
def _safe_json_loads(text: str) -> dict[str, Any]:
    text = text.strip()
    if not text:
        return {}
    try:
        parsed = json.loads(text)
        # Guard against valid-but-non-object JSON (e.g. a bare list or string).
        return parsed if isinstance(parsed, dict) else {}
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end == -1 or end <= start:
            return {}
        try:
            parsed = json.loads(text[start : end + 1])
            return parsed if isinstance(parsed, dict) else {}
        except json.JSONDecodeError:
            return {}


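# Every field is validated defensively: an illegal action type degrades to
# "hold", confidence is clamped to [0.0, 1.0], the horizon to >= 1 turn, and
# unknown severities fall back to "medium", so a malformed completion can
# never crash a rollout.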
def _parse_turn_output(training_agent: str, completion: str) -> tuple[AgentAction, Prediction]:
    payload = _safe_json_loads(completion)
    action_payload = payload.get("action") if isinstance(payload.get("action"), dict) else {}
    prediction_payload = payload.get("prediction") if isinstance(payload.get("prediction"), dict) else {}

    action_type = str(action_payload.get("type") or "hold")
    if action_type not in AGENT_ALLOWED_ACTIONS.get(training_agent, ()):
        action_type = "hold"
    target = action_payload.get("target")
    action_summary = str(
        action_payload.get("summary")
        or "Fallback hold after invalid or partial model completion."
    )

    prediction_topic = str(prediction_payload.get("topic") or "diplomacy")
    predicted_actor = prediction_payload.get("predicted_actor")
    predicted_target = prediction_payload.get("predicted_target")
    prediction_summary = str(
        prediction_payload.get("summary")
        or "Fallback low-confidence forecast after invalid or partial model completion."
    )
    rationale = str(prediction_payload.get("rationale") or "Parser fallback.")
    confidence_raw = prediction_payload.get("confidence", 0.1)
    try:
        confidence = max(0.0, min(1.0, float(confidence_raw)))
    except (TypeError, ValueError):
        confidence = 0.1

    horizon_raw = prediction_payload.get("time_horizon_turns", 1)
    try:
        time_horizon_turns = max(1, int(horizon_raw))
    except (TypeError, ValueError):
        time_horizon_turns = 1

    expected_severity = str(prediction_payload.get("expected_severity") or "medium")
    if expected_severity not in {"low", "medium", "high", "critical"}:
        expected_severity = "medium"

    return (
        AgentAction(
            actor=training_agent,
            type=action_type,
            target=target if isinstance(target, str) else None,
            summary=action_summary,
        ),
        Prediction(
            agent_id=training_agent,
            topic=prediction_topic,
            predicted_actor=predicted_actor if isinstance(predicted_actor, str) else None,
            predicted_target=predicted_target if isinstance(predicted_target, str) else None,
            time_horizon_turns=time_horizon_turns,
            expected_severity=expected_severity,
            confidence=confidence,
            summary=prediction_summary,
            rationale=rationale,
        ),
    )


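# The dataset is only a batch of identical placeholder prompts: the real,
# observation-grounded prompt is rebuilt inside rollout_func from a fresh
# environment reset on every rollout.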
def _build_dataset(training_agent: str, size: int):
    from datasets import Dataset

    base_prompt = _build_base_prompt(training_agent)
    return Dataset.from_dict({"prompt": [base_prompt] * size})


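# Heavy dependencies are imported lazily here so that argument parsing and
# --help still work on machines without the training stack installed.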
def _required_training_imports() -> dict[str, Any]:
    try:
        import torch
        from transformers import AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "Missing training dependencies. Install torch, transformers, trl, accelerate, and openenv-core first."
        ) from exc
    try:
        from trl.experimental.openenv import generate_rollout_completions
    except ImportError:
        # ImportError covers both a missing trl.experimental.openenv module
        # and a trl build that ships the module without this symbol.
        generate_rollout_completions = None
    return {
        "torch": torch,
        "AutoTokenizer": AutoTokenizer,
        "GRPOConfig": GRPOConfig,
        "GRPOTrainer": GRPOTrainer,
        "generate_rollout_completions": generate_rollout_completions,
    }


def _resolve_model_device(model: Any) -> Any:
    device = getattr(model, "device", None)
    if device is not None:
        return device
    return next(model.parameters()).device


def _can_use_vllm(torch_module: Any) -> bool:
    if not hasattr(torch_module, "cuda") or not torch_module.cuda.is_available():
        return False
    try:
        import vllm  # noqa: F401  # imported only to probe availability
    except ModuleNotFoundError:
        return False
    return True


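# Fallback generation path for the "transformers" backend: runs model.generate
# with sampling and recovers per-token log-probabilities from the step scores,
# matching the (prompt_ids, completion_ids, logprobs) contract returned by the
# vLLM helper.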
def _generate_rollout_completions_transformers(
    *,
    trainer: Any,
    prompts: list[str],
    tokenizer: Any,
    max_prompt_length: int,
    max_completion_length: int,
) -> dict[str, list[Any]]:
    import torch

    tokenizer.padding_side = "left"
    model = trainer.model
    device = _resolve_model_device(model)
    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_prompt_length,
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}
    generation = model.generate(
        **encoded,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    prompt_ids: list[list[int]] = []
    completion_ids: list[list[int]] = []
    logprobs: list[list[float]] = []
    # With left padding, every completion starts at the shared prompt width.
    input_width = encoded["input_ids"].shape[1]

    for batch_index in range(len(prompts)):
        # Keep only the non-padding prompt tokens for this sample.
        prompt_ids.append(
            encoded["input_ids"][batch_index][encoded["attention_mask"][batch_index].bool()].tolist()
        )
        sample_completion_ids: list[int] = []
        sample_logprobs: list[float] = []

        for step_index, step_scores in enumerate(generation.scores):
            token_position = input_width + step_index
            if token_position >= generation.sequences.shape[1]:
                break
            token_id = int(generation.sequences[batch_index, token_position].item())
            # Note: when pad_token_id == eos_token_id (as set in main), this
            # check also stops at end-of-sequence, so eos is never appended.
            if token_id == tokenizer.pad_token_id:
                break
            sample_completion_ids.append(token_id)
            token_logprob = torch.log_softmax(step_scores[batch_index], dim=-1)[token_id].item()
            sample_logprobs.append(float(token_logprob))
            if tokenizer.eos_token_id is not None and token_id == tokenizer.eos_token_id:
                break

        completion_ids.append(sample_completion_ids)
        logprobs.append(sample_logprobs)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
    }


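# Examples (a sketch of the parsing behavior):
#   "q_proj, v_proj"  ->  ["q_proj", "v_proj"]
#   "all-linear"      ->  "all-linear"  (PEFT's keyword for all linear layers)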
def _parse_lora_target_modules(raw_value: str) -> str | list[str]:
    target_modules = [item.strip() for item in raw_value.split(",") if item.strip()]
    if not target_modules:
        raise ValueError("LoRA target modules cannot be empty.")
    if len(target_modules) == 1 and target_modules[0] == "all-linear":
        return "all-linear"
    return target_modules


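# Post-training sanity check: play a handful of greedy single-turn episodes
# against a fresh environment and print the reward, the chosen action, and the
# event the replay actually revealed.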
def _preview_rollouts(
    *,
    model: Any,
    tokenizer: Any,
    training_agent: str,
    port: int,
    replay_id: str,
    training_stage: str,
    samples: int,
    max_prompt_length: int,
    max_completion_length: int,
) -> None:
    import torch

    print("\nPreview rollouts")
    for sample_index in range(samples):
        # Close the client after each episode, mirroring rollout_func below.
        with TrenchesEnvClient(base_url=f"http://127.0.0.1:{port}/openenv") as client:
            reset_result = client.reset(
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=1,
                replay_id=replay_id,
                episode_id=f"preview-{sample_index}-{int(time.time() * 1000)}",
            )
            observation = reset_result.observation
            prompt = _render_observation_prompt(
                _build_base_prompt(training_agent),
                training_agent,
                observation,
            )
            model_max_length = getattr(tokenizer, "model_max_length", None)
            if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 1_000_000:
                model_max_length = max_prompt_length
            preview_prompt_length = min(max_prompt_length, model_max_length)
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=preview_prompt_length,
            ).to(model.device)
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_completion_length,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                )
            completion_ids = output[0][inputs["input_ids"].shape[1] :]
            completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
            action, prediction = _parse_turn_output(training_agent, completion)
            step_result = client.step(
                TrenchesOpenEnvAction(
                    action=action,
                    prediction=prediction,
                    external_signals=[],
                )
            )
        step_obs = step_result.observation
        actual_event = step_obs.revealed_event.summary if step_obs.revealed_event is not None else "n/a"
        step_reward = step_result.reward if step_result.reward is not None else 0.0
        print(
            f"[sample {sample_index + 1}] reward={step_reward:.3f} "
            f"action={action.type} topic={prediction.topic} actual={actual_event}"
        )


class OpenEnvGRPOTrainer:
    """Force GRPO to use the custom OpenEnv rollout path across generation backends.

    This class is a mixin: main() combines it with TRL's concrete GRPOTrainer
    via type(), so super() here resolves to the TRL implementation. It assumes
    a trl version that routes single-turn generation through the internal
    _generate_single_turn hook.
    """

    def _generate_single_turn(self, prompts: list[str]):
        if getattr(self, "rollout_func", None) is None:
            return super()._generate_single_turn(prompts)

        output = self.rollout_func(prompts, self)
        required_keys = {"prompt_ids", "completion_ids", "logprobs"}
        missing = required_keys.difference(output)
        if missing:
            raise RuntimeError(f"rollout_func is missing required keys: {sorted(missing)}")
        extra_fields = {key: value for key, value in output.items() if key not in required_keys}
        return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields


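# rollout_func contract (as consumed by OpenEnvGRPOTrainer above): it receives
# the batch prompts plus the trainer and must return "prompt_ids",
# "completion_ids", and "logprobs"; any extra keys (here "env_reward" and
# "forecast_reward") are forwarded to the reward functions as kwargs.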
def main() -> None:
    parser = argparse.ArgumentParser(description="Train a replay-aware OpenEnv policy for Trenches.")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--training-agent", default="us")
    parser.add_argument("--training-stage", default=DEFAULT_TRAINING_STAGE)
    parser.add_argument("--replay-id", default=DEFAULT_REPLAY_ID)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--train-size", type=int, default=32)
    parser.add_argument("--max-steps", type=int, default=4)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--generation-backend", choices=["auto", "vllm", "transformers"], default="auto")
    parser.add_argument("--max-prompt-length", type=int, default=1024)
    parser.add_argument("--max-completion-length", type=int, default=220)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-6)
    parser.add_argument("--output-dir", default="trl-openenv-historical-replay")
    parser.add_argument("--preview-samples", type=int, default=3)
    parser.add_argument("--no-preview", action="store_true")

    parser.add_argument(
        "--quantize-4bit",
        action="store_true",
        help="Load model with 4-bit NF4 quantization via bitsandbytes (requires CUDA)",
    )
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank used with quantized training")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha used with quantized training")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout used with quantized training")
    parser.add_argument(
        "--lora-target-modules",
        default=DEFAULT_LORA_TARGET_MODULES,
        help='Comma-separated LoRA target modules, or "all-linear"',
    )
    parser.add_argument("--beta", type=float, default=0.04, help="KL coefficient for GRPO")
    parser.add_argument("--warmup-steps", type=int, default=0, help="Number of warmup steps")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature for generation")
    parser.add_argument("--top-k", type=int, default=0, help="Top-k sampling (0 = disabled)")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling")
    parser.add_argument(
        "--save-strategy",
        default="no",
        choices=["no", "steps", "epoch"],
        help="Checkpoint save strategy",
    )
    parser.add_argument(
        "--save-steps",
        type=int,
        default=100,
        help="Save checkpoint every N steps (when save-strategy=steps)",
    )
    args = parser.parse_args()

    imports = _required_training_imports()
    torch = imports["torch"]
    AutoTokenizer = imports["AutoTokenizer"]
    GRPOConfig = imports["GRPOConfig"]
    # Mix the OpenEnv rollout override into the concrete TRL trainer class.
    GRPOTrainer = type("OpenEnvGRPOTrainer", (OpenEnvGRPOTrainer, imports["GRPOTrainer"]), {})
    generate_rollout_completions = imports["generate_rollout_completions"]
    generation_backend = args.generation_backend
    if generation_backend == "auto":
        generation_backend = (
            "vllm" if generate_rollout_completions is not None and _can_use_vllm(torch) else "transformers"
        )
    if generation_backend == "vllm" and generate_rollout_completions is None:
        raise RuntimeError(
            "The selected vLLM backend requires `trl.experimental.openenv.generate_rollout_completions`."
        )

    _launch_backend(args.port)

    model_ref = args.model_id
    peft_config = None
    if args.quantize_4bit:
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        try:
            from peft import LoraConfig, TaskType
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "Missing PEFT dependency. Install backend[train] so quantized training can attach LoRA adapters."
            ) from exc

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print(f"Loading {args.model_id} with 4-bit NF4 quantization")
        model_ref = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            quantization_config=bnb_config,
            device_map="auto",
        )
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            target_modules=_parse_lora_target_modules(args.lora_target_modules),
        )
        print(
            "Attaching LoRA adapters for quantized training "
            f"(r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}, "
            f"targets={args.lora_target_modules})"
        )

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset = _build_dataset(args.training_agent, args.train_size)
    base_prompt = _build_base_prompt(args.training_agent)

    def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list[Any]]:
        prompt_ids: list[list[int]] = []
        completion_ids: list[list[int]] = []
        logprobs: list[list[float]] = []
        env_rewards: list[float] = []
        forecast_rewards: list[float] = []

        for index, prompt in enumerate(prompts):
            # One fresh single-turn episode per prompt in the batch.
            with TrenchesEnvClient(base_url=f"http://127.0.0.1:{args.port}/openenv") as client:
                reset_result = client.reset(
                    training_agent=args.training_agent,
                    training_stage=args.training_stage,
                    max_turns=1,
                    replay_id=args.replay_id,
                    episode_id=f"train-{index}-{int(time.time() * 1000)}",
                )
                grounded_prompt = _render_observation_prompt(
                    prompt or base_prompt,
                    args.training_agent,
                    reset_result.observation,
                )
                if generation_backend == "vllm":
                    rollout_output = generate_rollout_completions(trainer, [grounded_prompt])[0]
                else:
                    rollout_output = {
                        key: value[0]
                        for key, value in _generate_rollout_completions_transformers(
                            trainer=trainer,
                            prompts=[grounded_prompt],
                            tokenizer=tokenizer,
                            max_prompt_length=args.max_prompt_length,
                            max_completion_length=args.max_completion_length,
                        ).items()
                    }

                completion_text = tokenizer.decode(rollout_output["completion_ids"], skip_special_tokens=True)
                action, prediction = _parse_turn_output(args.training_agent, completion_text)
                step_result = client.step(
                    TrenchesOpenEnvAction(
                        action=action,
                        prediction=prediction,
                        external_signals=[],
                    )
                )

            prompt_ids.append(list(rollout_output["prompt_ids"]))
            completion_ids.append(list(rollout_output["completion_ids"]))
            logprobs.append([float(value) for value in rollout_output["logprobs"]])
            step_reward = step_result.reward if step_result.reward is not None else 0.0
            step_obs = step_result.observation
            forecast_total = step_obs.reward_breakdown.forecast_total if step_obs.reward_breakdown is not None else 0.0
            env_rewards.append(float(step_reward))
            forecast_rewards.append(float(forecast_total))

        return {
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
            "logprobs": logprobs,
            "env_reward": env_rewards,
            "forecast_reward": forecast_rewards,
        }

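    # GRPO reward hook: the extra rollout fields arrive here as keyword
    # arguments, so the environment reward computed during the rollout is
    # simply read back rather than recomputed from the completion text.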
    def reward_from_env(completions: list[str], **kwargs: Any) -> list[float]:
        rewards = kwargs.get("env_reward")
        if rewards is None:
            return [0.0 for _ in completions]
        return [float(reward) for reward in rewards]

    training_kwargs = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "max_steps": args.max_steps,
        "num_train_epochs": 1,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "generation_batch_size": args.num_generations,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "logging_steps": 1,
        "report_to": [],
        "use_vllm": generation_backend == "vllm",
        "beta": args.beta,
        "warmup_steps": args.warmup_steps,
        "temperature": args.temperature,
        # Wire --top-k through to GRPOConfig; it was previously parsed but
        # silently ignored. 0 keeps top-k sampling disabled (None).
        "top_k": args.top_k if args.top_k > 0 else None,
        "top_p": args.top_p,
        "save_strategy": args.save_strategy,
        "save_steps": args.save_steps,
    }
    if not args.quantize_4bit:
        training_kwargs["bf16"] = True
        training_kwargs["model_init_kwargs"] = {"dtype": "bfloat16"}
    if generation_backend == "vllm":
        training_kwargs["vllm_mode"] = "colocate"

    training_args = GRPOConfig(**training_kwargs)

    trainer = GRPOTrainer(
        model=model_ref,
        processing_class=tokenizer,
        reward_funcs=reward_from_env,
        train_dataset=train_dataset,
        rollout_func=rollout_func,
        args=training_args,
        peft_config=peft_config,
    )

    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    print("\nTraining complete")
    print(f"Generation backend: {generation_backend}")
    print(json.dumps(train_result.metrics, indent=2, sort_keys=True))

    if not args.no_preview and args.preview_samples > 0:
        _preview_rollouts(
            model=trainer.model,
            tokenizer=tokenizer,
            training_agent=args.training_agent,
            port=args.port,
            replay_id=args.replay_id,
            training_stage=args.training_stage,
            samples=args.preview_samples,
            max_prompt_length=args.max_prompt_length,
            max_completion_length=args.max_completion_length,
        )


if __name__ == "__main__":
    main()