# test-true / train.py
# Uploaded by BRlkl via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d97bf05 (verified), 54.9 kB.
from __future__ import annotations
import argparse
import json
import math
import os
import random
import secrets
import shutil
import sys
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup
# Directory containing this script. It is put at the front of sys.path (once) so the
# sibling modules imported below (config, data, model) resolve regardless of the
# working directory the script is launched from.
SCRIPT_DIR = Path(__file__).resolve().parent
if all(entry != str(SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(SCRIPT_DIR))
from config import ensure_dir, flatten_for_wandb, load_config, save_json
from data import (
ThoughtLoopConversationDataset,
build_chunk_batch,
identity_collate,
resolve_bucket_horizon_ticks,
)
from model import ThoughtLoopT5Gemma
try:
    # Optional dependency: when the wandb package is not installed, the module-level
    # name ``wandb`` is bound to None instead of failing at import time.
    import wandb
except ImportError:  # pragma: no cover
    wandb = None
def parse_args() -> argparse.Namespace:
    """Parse the command line for a training run.

    Options:
        --config: path to the YAML config (default ``configs/sft_b200.yaml``).
        --resume-from-checkpoint: checkpoint path, or 'latest'/'auto' for the newest
            checkpoint in the run dir (default None).
        --wandb-run-id: W&B run id to resume (default None).
    """
    cli = argparse.ArgumentParser(description="Train the recurrent T5Gemma thought-loop model.")
    cli.add_argument("--config", default="configs/sft_b200.yaml")
    optional_flags = (
        (
            "--resume-from-checkpoint",
            "Local checkpoint path to resume from, or 'latest'/'auto' to use the newest checkpoint in the run dir.",
        ),
        (
            "--wandb-run-id",
            "Optional W&B run id to resume. If omitted, train.py reuses a stored/local run id when available.",
        ),
    )
    for flag, help_text in optional_flags:
        cli.add_argument(flag, default=None, help=help_text)
    return cli.parse_args()
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
def build_optimizer(model: torch.nn.Module, training_cfg: dict[str, Any]) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with two parameter groups.

    Parameters whose qualified name contains ``"bias"``, ``"norm"`` or ``"Norm"`` get
    ``weight_decay=0.0``; all others use ``training_cfg["weight_decay"]``. The fused
    AdamW kernel is requested only when ``training_cfg["fused_adamw"]`` (default True)
    is truthy and CUDA is available.
    """
    exempt_markers = ("bias", "norm", "Norm")
    decayed: list[torch.nn.Parameter] = []
    undecayed: list[torch.nn.Parameter] = []
    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        (undecayed if is_exempt else decayed).append(param)
    use_fused = bool(training_cfg.get("fused_adamw", True) and torch.cuda.is_available())
    param_groups = [
        {"params": decayed, "weight_decay": float(training_cfg["weight_decay"])},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(
        param_groups,
        lr=float(training_cfg["learning_rate"]),
        betas=(float(training_cfg["adam_beta1"]), float(training_cfg["adam_beta2"])),
        eps=float(training_cfg["adam_epsilon"]),
        fused=use_fused,
    )
def resolve_bucket_tbptt_steps(training_cfg: dict[str, Any], duration_bucket: str) -> int | None:
    """Return the truncated-BPTT window (in ticks) to use for ``duration_bucket``.

    A bucket-specific entry in ``tbptt_steps_by_bucket`` wins. Otherwise
    ``tbptt_end_steps`` (defaulting to ``tbptt_start_steps``) is returned, but only when
    both the start and end values are positive; in every other case None (no TBPTT).
    """
    per_bucket = training_cfg.get("tbptt_steps_by_bucket", {})
    if duration_bucket in per_bucket:
        return int(per_bucket[duration_bucket])
    start = int(training_cfg.get("tbptt_start_steps", 0))
    end = int(training_cfg.get("tbptt_end_steps", start))
    return end if min(start, end) > 0 else None
def resolve_linear_schedule_value(
*,
enabled: bool,
global_step: int,
total_update_steps: int,
start_fraction: float,
end_fraction: float,
start_value: float,
end_value: float,
) -> float | None:
if not enabled:
return None
clamped_total_steps = max(total_update_steps, 1)
start_step = int(round(clamped_total_steps * start_fraction))
end_step = int(round(clamped_total_steps * end_fraction))
if end_step <= start_step:
end_step = start_step + 1
if global_step <= start_step:
return start_value
if global_step >= end_step:
return end_value
progress = (global_step - start_step) / max(end_step - start_step, 1)
return start_value + progress * (end_value - start_value)
def should_freeze_gate_head(
    training_cfg: dict[str, Any],
    global_step: int,
    total_update_steps: int,
) -> bool:
    """Return True while the gate head should be frozen.

    Freezing is active only when ``freeze_gate_head_first_half`` is set, and lasts for
    the first ``gate_head_freeze_fraction`` (default 0.5) of ``total_update_steps``
    (clamped to at least 1).
    """
    freeze_enabled = bool(training_cfg.get("freeze_gate_head_first_half", False))
    if freeze_enabled:
        fraction = float(training_cfg.get("gate_head_freeze_fraction", 0.5))
        cutoff_step = int(round(max(total_update_steps, 1) * fraction))
        return global_step < cutoff_step
    return False
def resolve_decoder_trainable_fraction(
    training_cfg: dict[str, Any],
    global_step: int,
    total_update_steps: int,
) -> float | None:
    """Return the scheduled decoder trainable fraction for ``global_step``.

    The value is a linear ramp from ``decoder_initial_trainable_fraction`` (default 0.0)
    up to 1.0, between ``decoder_unfreeze_start_fraction`` (default 0.0) and
    ``decoder_unfreeze_end_fraction`` (default 0.5) of training. It is None when
    ``gradual_unfreeze_decoder`` is not enabled.
    """
    cfg = training_cfg
    schedule_options = {
        "enabled": bool(cfg.get("gradual_unfreeze_decoder", False)),
        "start_fraction": float(cfg.get("decoder_unfreeze_start_fraction", 0.0)),
        "end_fraction": float(cfg.get("decoder_unfreeze_end_fraction", 0.5)),
        "start_value": float(cfg.get("decoder_initial_trainable_fraction", 0.0)),
        "end_value": 1.0,
    }
    return resolve_linear_schedule_value(
        global_step=global_step,
        total_update_steps=total_update_steps,
        **schedule_options,
    )
def resolve_batch_duration_bucket(conversations: list[dict[str, Any]]) -> str:
    """Return the batch's shared ``duration_bucket``.

    Returns ``"short"`` for an empty batch and ``"mixed"`` when the conversations
    disagree. Bucket values are compared as strings.
    """
    distinct_buckets = {str(item["duration_bucket"]) for item in conversations}
    if not distinct_buckets:
        return "short"
    if len(distinct_buckets) > 1:
        return "mixed"
    (only_bucket,) = distinct_buckets
    return only_bucket
def ensure_single_chunk_dataset(dataset: ThoughtLoopConversationDataset, split: str) -> None:
    """Raise ValueError unless every example in ``dataset.examples`` has ``chunk_count == 1``.

    A missing ``chunk_count`` counts as 0 and is therefore rejected as well. ``split``
    is used only in the error message.
    """
    multi_chunk_total = sum(
        1 for example in dataset.examples if int(example.get("chunk_count", 0)) != 1
    )
    if not multi_chunk_total:
        return
    raise ValueError(
        f"{split} dataset expected exactly one chunk per conversation, but found "
        f"{multi_chunk_total} multi-chunk examples. Check chunk_ticks/max_horizon_ticks."
    )
def maybe_sort_dataset_by_length(
    dataset: ThoughtLoopConversationDataset,
    *,
    enabled: bool,
) -> None:
    """Sort ``dataset.examples`` in place by length when ``enabled``; otherwise do nothing.

    The sort key is ``effective_total_ticks`` (falling back to ``total_ticks``, then 0),
    then ``total_ticks``, then ``row_id`` as a string.
    """
    if enabled:
        def length_key(example: dict) -> tuple[int, int, str]:
            effective_ticks = int(example.get("effective_total_ticks", example.get("total_ticks", 0)))
            raw_ticks = int(example.get("total_ticks", 0))
            return (effective_ticks, raw_ticks, str(example.get("row_id", "")))

        dataset.examples.sort(key=length_key)
def move_chunk_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    """Return a new dict with every tensor of ``batch`` moved to ``device`` (keys unchanged)."""
    moved: dict[str, torch.Tensor] = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved
def apply_runtime_model_config(model: ThoughtLoopT5Gemma, config: dict[str, Any]) -> None:
    """Attach ``config`` to ``model`` and apply the runtime options from ``config["model"]``.

    Raises ValueError if the model's ``z_slots`` differs from the configured value, or if
    ``latent_attention_mask_mode`` is not ``"observed"`` or ``"full"``. String options are
    stripped and lower-cased. Options that are absent from the config fall back as follows:

    * ``thought_loop_proposal_mode`` and ``latent_attention_mask_mode`` keep the model's
      current values (``"observed"`` if the attribute is missing);
    * ``preserve_observation_encoder_manifold`` is True exactly when the proposal mode is
      ``"observation_hidden_compression"``;
    * ``observation_encoder_use_state_context`` becomes False.
    """

    def normalized(value: Any) -> str:
        return str(value).strip().lower()

    model.config = config
    model_section = config["model"]
    expected_z_slots = int(model_section["z_slots"])
    if int(model.z_slots) != expected_z_slots:
        raise ValueError(
            f"Loaded model has z_slots={model.z_slots}, but config requested z_slots={expected_z_slots}."
        )
    model.thought_loop_proposal_mode = normalized(
        model_section.get("thought_loop_proposal_mode", model.thought_loop_proposal_mode)
    )
    manifold_default = model.thought_loop_proposal_mode == "observation_hidden_compression"
    model.preserve_observation_encoder_manifold = bool(
        model_section.get("preserve_observation_encoder_manifold", manifold_default)
    )
    model.observation_encoder_use_state_context = bool(
        model_section.get("observation_encoder_use_state_context", False)
    )
    fallback_mask_mode = getattr(model, "latent_attention_mask_mode", "observed")
    model.latent_attention_mask_mode = normalized(
        model_section.get("latent_attention_mask_mode", fallback_mask_mode)
    )
    if model.latent_attention_mask_mode not in {"observed", "full"}:
        raise ValueError(f"Unsupported latent_attention_mask_mode: {model.latent_attention_mask_mode}")
def build_initial_model(config: dict[str, Any]) -> ThoughtLoopT5Gemma:
    """Create the model for a run.

    If ``config["model"]["initial_model_path"]`` is set, weights are loaded from that
    path onto the CPU and the current runtime options are applied. Otherwise a fresh
    model is built from ``config``.
    """
    pretrained_path = config["model"].get("initial_model_path")
    if not pretrained_path:
        return ThoughtLoopT5Gemma(config)
    loaded = ThoughtLoopT5Gemma.from_pretrained(str(pretrained_path), device="cpu", map_location="cpu")
    apply_runtime_model_config(loaded, config)
    return loaded
def rollout_active_rows_only(
    *,
    model: ThoughtLoopT5Gemma,
    batch: dict[str, torch.Tensor],
    z_state: torch.Tensor,
    state_mask: torch.Tensor,
    step_index: int,
    active_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Advance only the rows selected by ``active_mask`` by one rollout tick.

    The selected rows' state and step ``step_index`` observations are passed to
    ``model.rollout_step_with_mask``. ``since_last_user_seconds`` and
    ``since_last_agent_seconds`` are passed as zeros. Unselected rows keep their
    previous state and mask and receive a gate logit of 0.

    Returns:
        ``(next_z, next_state_mask, gate_logits)`` covering the full batch. The inputs
        are not modified in place; clones are updated instead.
    """
    rows = torch.nonzero(active_mask, as_tuple=False).flatten()

    def at_step(key: str) -> torch.Tensor:
        # Selected rows, current tick (remaining trailing dimensions kept).
        return batch[key][rows, step_index]

    elapsed = at_step("elapsed_seconds")
    stepped_z, stepped_logits, stepped_mask = model.rollout_step_with_mask(
        z_state=z_state[rows],
        state_mask=state_mask[rows],
        observation_input_ids=at_step("observation_input_ids"),
        observation_attention_mask=at_step("observation_attention_mask"),
        observation_role_ids=at_step("observation_role"),
        delta_seconds=at_step("delta_seconds"),
        elapsed_seconds=elapsed,
        since_last_user_seconds=torch.zeros_like(elapsed),
        since_last_agent_seconds=torch.zeros_like(elapsed),
    )
    merged_z = z_state.clone()
    merged_z[rows] = stepped_z
    merged_mask = state_mask.clone()
    merged_mask[rows] = stepped_mask
    full_logits = torch.zeros(batch["tick_mask"].shape[0], device=z_state.device, dtype=stepped_logits.dtype)
    full_logits[rows] = stepped_logits
    return merged_z, merged_mask, full_logits
def compute_gate_positive_weight(batch: dict[str, torch.Tensor], training_cfg: dict[str, Any]) -> torch.Tensor:
    """Return the BCE weight applied to positive gate targets, as a 0-d tensor.

    When ``use_dynamic_gate_positive_weight`` is enabled (the default), the weight is
    ``#negatives / #positives`` over the active ticks (``batch["tick_mask"]``), clamped to
    ``[dynamic_gate_positive_weight_min, dynamic_gate_positive_weight_max]`` (defaults
    1.0 and 256.0). The static ``gate_positive_weight`` is used when the dynamic mode is
    disabled or when the active ticks contain no positives.
    """
    target_device = batch["tick_mask"].device
    if not bool(training_cfg.get("use_dynamic_gate_positive_weight", True)):
        return torch.tensor(float(training_cfg["gate_positive_weight"]), device=target_device)
    observed_targets = batch["gate_target"][batch["tick_mask"]]
    n_positive = observed_targets.sum()
    n_negative = observed_targets.numel() - n_positive
    if n_positive.item() <= 0.0:
        # No positives to balance against: fall back to the static weight.
        return torch.tensor(float(training_cfg["gate_positive_weight"]), device=target_device)
    ratio = n_negative / n_positive.clamp_min(1.0)
    lower = float(training_cfg.get("dynamic_gate_positive_weight_min", 1.0))
    upper = float(training_cfg.get("dynamic_gate_positive_weight_max", 256.0))
    return ratio.clamp(min=lower, max=upper).to(device=target_device)
def compute_distance_aware_hazard_gate_loss(
    gate_logits: torch.Tensor,
    gate_targets: torch.Tensor,
    gate_loss_mask: torch.Tensor,
    training_cfg: dict[str, Any],
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Discrete-time hazard loss for the gate, using a distance-weighted soft target.

    ``gate_logits`` is indexed ``[batch_index, step]``; its shape is unpacked as 2-D
    below. ``sigmoid(logit)`` plays the role of the per-tick hazard ``h``, and
    ``logsigmoid(-logit) = log(1 - h)`` is the per-tick log survival.

    Each row is split at its positive ticks (``target > 0.5``). For a segment ending at a
    positive tick, ``log P(first event at tick i) = sum_{j<i} log(1 - h_j) + log h_i``,
    with the sum taken from the segment start. The segment loss is the negative of a
    weighted average of these log-probabilities:

    * only the last ``hazard_soft_target_window_ticks`` ticks of the segment are used
      (default 16; a value <= 0 uses the whole segment);
    * the weights are ``exp(-distance / hazard_soft_target_tau_ticks)`` (default 4.0),
      where ``distance`` counts ticks before the positive tick;
    * a tau <= 0 puts all weight on the positive tick.

    Ticks after the last positive of a row are treated as censored. They contribute
    ``-mean(log(1 - h))`` scaled by ``hazard_censor_weight`` (default 0.25).

    Returns:
        ``(gate_loss, metrics)``. Here ``gate_loss = gate_loss_sum / max(gate_count, 1)``,
        where ``gate_count`` is the number of event segments plus ``hazard_censor_weight``
        per censored tail. ``metrics`` holds the detached sum/count tensors.
    """
    device = gate_logits.device
    batch_size, max_steps = gate_logits.shape
    support_window_ticks = int(training_cfg.get("hazard_soft_target_window_ticks", 16))
    target_tau_ticks = float(training_cfg.get("hazard_soft_target_tau_ticks", 4.0))
    censor_weight = float(training_cfg.get("hazard_censor_weight", 0.25))
    log_hazard = F.logsigmoid(gate_logits)
    log_survival = F.logsigmoid(-gate_logits)
    gate_loss_sum = torch.zeros((), device=device)
    gate_loss_count = torch.zeros((), device=device)
    hazard_event_count = torch.zeros((), device=device)
    hazard_censor_count = torch.zeros((), device=device)
    for batch_index in range(batch_size):
        # NOTE(review): only the first `active_len` ticks of the row are used, which
        # assumes the unmasked ticks form a prefix of the row — confirm with the caller.
        active_len = int(gate_loss_mask[batch_index].sum().item())
        if active_len <= 0:
            continue
        active_targets = gate_targets[batch_index, :active_len]
        active_log_hazard = log_hazard[batch_index, :active_len]
        active_log_survival = log_survival[batch_index, :active_len]
        positive_indices = torch.nonzero(active_targets > 0.5, as_tuple=False).flatten().tolist()
        segment_start = 0
        for positive_index in positive_indices:
            # Segment = ticks since the previous positive, up to and including this one.
            segment_log_hazard = active_log_hazard[segment_start : positive_index + 1]
            segment_log_survival = active_log_survival[segment_start : positive_index + 1]
            # Defensive only: positive_indices are ascending, so positive_index >= segment_start.
            if segment_log_hazard.numel() <= 0:
                segment_start = positive_index + 1
                continue
            # Exclusive cumulative sum: log-survival over the segment ticks strictly before each tick.
            prefix_log_survival = torch.cumsum(segment_log_survival, dim=0) - segment_log_survival
            segment_log_event = prefix_log_survival + segment_log_hazard
            if support_window_ticks > 0:
                support_start = max(0, segment_log_event.numel() - support_window_ticks)
            else:
                support_start = 0
            support_log_event = segment_log_event[support_start:]
            support_indices = torch.arange(
                support_start,
                segment_log_event.numel(),
                device=device,
                dtype=torch.float32,
            )
            # Distance (in ticks) from each support tick to the positive tick (last one).
            distances = (segment_log_event.numel() - 1) - support_indices
            if target_tau_ticks > 0.0:
                target_weights = torch.exp(-distances / target_tau_ticks)
            else:
                # Hard target: all mass on the positive tick.
                target_weights = torch.zeros_like(distances)
                target_weights[-1] = 1.0
            target_weights = target_weights / target_weights.sum().clamp_min(1e-8)
            segment_loss = -(target_weights * support_log_event).sum()
            gate_loss_sum = gate_loss_sum + segment_loss
            gate_loss_count = gate_loss_count + 1.0
            hazard_event_count = hazard_event_count + 1.0
            segment_start = positive_index + 1
        # Trailing ticks with no positive are censored: only survival is rewarded.
        if segment_start < active_len:
            tail_log_survival = active_log_survival[segment_start:active_len]
            if tail_log_survival.numel() > 0:
                censored_loss = -tail_log_survival.mean()
                gate_loss_sum = gate_loss_sum + censor_weight * censored_loss
                gate_loss_count = gate_loss_count + censor_weight
                hazard_censor_count = hazard_censor_count + 1.0
    # The denominator is clamped to 1: when a batch contains only censored tails, their
    # loss therefore stays scaled down by censor_weight instead of being re-normalized.
    gate_loss = gate_loss_sum / gate_loss_count.clamp_min(1.0)
    metrics = {
        "gate_loss_sum": gate_loss_sum.detach(),
        "gate_count": gate_loss_count.detach(),
        "hazard_event_count": hazard_event_count.detach(),
        "hazard_censor_count": hazard_censor_count.detach(),
    }
    return gate_loss, metrics
def compute_gate_ranking_loss(
    gate_logits: torch.Tensor,
    gate_targets: torch.Tensor,
    gate_loss_mask: torch.Tensor,
    ranking_margin: float,
    ranking_lookback: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sum a pairwise hinge loss that ranks positive gate ticks above recent negatives.

    Consider every row ``b``, positive tick ``t`` and negative tick ``t - lag`` with
    ``1 <= lag <= ranking_lookback``, where both ticks are unmasked. Each such pair adds
    ``relu(ranking_margin - logit[b, t] + logit[b, t - lag])`` to the sum.

    A tick is unmasked when ``gate_loss_mask > 0``, positive when ``gate_targets > 0.5``,
    and negative when ``gate_targets < 0.5``. A target of exactly 0.5 is neither.

    Args:
        gate_logits: 2-D ``[batch, steps]`` gate logits.
        gate_targets: targets with the same shape.
        gate_loss_mask: mask with the same shape.
        ranking_margin: hinge margin.
        ranking_lookback: maximum lag; a value <= 0 produces no pairs.

    Returns:
        ``(loss_sum, pair_count)`` as 0-d tensors. The caller divides them to get a mean.
    """
    device = gate_logits.device
    ranking_loss_sum = torch.zeros((), device=device)
    ranking_pair_count = torch.zeros((), device=device)
    _, max_steps = gate_logits.shape
    valid = gate_loss_mask > 0
    positive = valid & (gate_targets > 0.5)
    negative = valid & (gate_targets < 0.5)
    # Loop over lags rather than (row, tick, previous tick) triples. Each lag is one
    # vectorized op over all rows, which avoids per-element .item()/float() host-device
    # syncs and the scalar-per-pair autograd graph. Memory stays O(batch * steps).
    for lag in range(1, min(ranking_lookback, max_steps - 1) + 1):
        pair_mask = positive[:, lag:] & negative[:, :-lag]
        hinge = F.relu(ranking_margin - gate_logits[:, lag:] + gate_logits[:, :-lag])
        # torch.where (not multiplication) so non-pairs contribute exactly zero.
        ranking_loss_sum = ranking_loss_sum + torch.where(pair_mask, hinge, torch.zeros_like(hinge)).sum()
        ranking_pair_count = ranking_pair_count + pair_mask.sum()
    return ranking_loss_sum, ranking_pair_count
def compute_chunk_losses(
    model: ThoughtLoopT5Gemma,
    batch: dict[str, torch.Tensor],
    training_cfg: dict[str, Any],
    z_state: torch.Tensor,
    state_mask: torch.Tensor,
    tbptt_steps: int | None = None,
    freeze_gate_head_to_targets: bool = False,
) -> tuple[torch.Tensor, dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Roll the latent state through one chunk and accumulate gate, decoder and state-delta losses.

    For each tick, the rows flagged in ``batch["tick_mask"][:, step_index]`` are advanced
    with ``rollout_active_rows_only``. The loop stops at the first tick with no active
    row. For every tick processed, the function:

    * records gate logits, targets and masks for the chunk-level gate loss and gate metrics;
    * in non-hazard mode, accumulates a per-tick BCE gate loss with positive targets
      weighted by ``compute_gate_positive_weight``;
    * accumulates a squared state-change penalty when ``state_delta_weight > 0``;
    * computes the decoder loss for active rows with at least one label != -100.

    Args:
        model: recurrent model used by ``rollout_active_rows_only`` and ``decoder_loss``.
        batch: chunk tensors ("tick_mask", "gate_target", "decoder_labels", plus the
            observation/time keys read by ``rollout_active_rows_only``).
        training_cfg: loss configuration ("gate_loss_mode", loss weights, etc.).
        z_state, state_mask: recurrent state carried in from the previous chunk.
        tbptt_steps: if not None, the carried state is detached every ``tbptt_steps``
            ticks (truncated backpropagation through time).
        freeze_gate_head_to_targets: replace the model's gate logits with
            ``+/- gate_teacher_logit_scale`` derived from the targets, and zero the gate loss.

    Returns:
        ``(total_loss, metrics, z_state, state_mask)``. These are the weighted total loss,
        a dict of detached sum/count tensors for ``reduce_metrics``, and the final
        recurrent state.
    """
    device = batch["tick_mask"].device
    batch_size, max_steps = batch["tick_mask"].shape
    gate_loss_mode = str(training_cfg.get("gate_loss_mode", "hazard")).strip().lower()
    use_hazard_gate_loss = gate_loss_mode == "hazard"
    # The hazard loss does not use a positive-class weight. The zero placeholder keeps
    # the "gate_positive_weight" metric defined in every mode.
    pos_weight = (
        torch.zeros((), device=device)
        if use_hazard_gate_loss
        else compute_gate_positive_weight(batch, training_cfg)
    )
    # Running sums and counts; the means are formed after the loop.
    gate_loss_sum = torch.zeros((), device=device)
    gate_count = torch.zeros((), device=device)
    decoder_loss_sum = torch.zeros((), device=device)
    decoder_token_count = torch.zeros((), device=device)
    state_penalty_sum = torch.zeros((), device=device)
    state_penalty_count = torch.zeros((), device=device)
    # Confusion-matrix counts for the gate metrics (prediction = sigmoid(logit) >= 0.5).
    tp = torch.zeros((), device=device)
    fp = torch.zeros((), device=device)
    fn = torch.zeros((), device=device)
    correct = torch.zeros((), device=device)
    counted = torch.zeros((), device=device)
    target_positive_count = torch.zeros((), device=device)
    predicted_positive_count = torch.zeros((), device=device)
    state_delta_weight = float(training_cfg["state_delta_weight"])
    steps_since_detach = 0
    # Per-tick histories, stacked after the loop for the chunk-level hazard/ranking losses.
    gate_logits_history: list[torch.Tensor] = []
    gate_target_history: list[torch.Tensor] = []
    gate_loss_mask_history: list[torch.Tensor] = []
    hazard_event_count = torch.zeros((), device=device)
    hazard_censor_count = torch.zeros((), device=device)
    teacher_logit_scale = float(training_cfg.get("gate_teacher_logit_scale", 20.0))
    for step_index in range(max_steps):
        active_mask = batch["tick_mask"][:, step_index]
        # Stop at the first tick where no row of the batch is active.
        if not torch.any(active_mask):
            break
        previous_z = z_state
        next_z, next_state_mask, gate_logits = rollout_active_rows_only(
            model=model,
            batch=batch,
            z_state=z_state,
            state_mask=state_mask,
            step_index=step_index,
            active_mask=active_mask,
        )
        z_state = next_z
        state_mask = next_state_mask
        # NOTE(review): step_loss_mask equals active_mask, and the loop has already broken
        # out when no row is active, so the `torch.any(step_loss_mask > 0)` checks below
        # are always True here.
        step_loss_mask = active_mask.float()
        step_targets = batch["gate_target"][:, step_index]
        if freeze_gate_head_to_targets:
            # Teacher-force the gate with saturated logits that match the targets.
            gate_logits = torch.where(
                step_targets > 0.5,
                torch.full_like(step_targets, teacher_logit_scale),
                torch.full_like(step_targets, -teacher_logit_scale),
            )
        gate_logits_history.append(gate_logits)
        gate_target_history.append(step_targets)
        gate_loss_mask_history.append(step_loss_mask)
        if torch.any(step_loss_mask > 0):
            step_predictions = (torch.sigmoid(gate_logits) >= 0.5).float()
            valid = step_loss_mask > 0
            correct = correct + ((step_predictions == step_targets) & valid).float().sum()
            counted = counted + valid.float().sum()
            tp = tp + ((step_predictions == 1) & (step_targets == 1) & valid).float().sum()
            fp = fp + ((step_predictions == 1) & (step_targets == 0) & valid).float().sum()
            fn = fn + ((step_predictions == 0) & (step_targets == 1) & valid).float().sum()
            target_positive_count = target_positive_count + ((step_targets == 1) & valid).float().sum()
            predicted_positive_count = predicted_positive_count + ((step_predictions == 1) & valid).float().sum()
        if torch.any(step_loss_mask > 0) and not use_hazard_gate_loss and not freeze_gate_head_to_targets:
            # Non-hazard mode: per-tick BCE, with positive targets weighted by pos_weight.
            gate_loss_raw = F.binary_cross_entropy_with_logits(gate_logits, step_targets, reduction="none")
            gate_weight = torch.where(step_targets > 0.5, pos_weight, torch.ones_like(step_targets))
            gate_loss_sum = gate_loss_sum + (gate_loss_raw * gate_weight * step_loss_mask).sum()
            gate_count = gate_count + step_loss_mask.sum()
            if state_delta_weight > 0.0:
                state_delta = (next_z - previous_z).pow(2).mean(dim=(1, 2))
                state_penalty_sum = state_penalty_sum + (state_delta * step_loss_mask).sum()
                state_penalty_count = state_penalty_count + step_loss_mask.sum()
        elif torch.any(step_loss_mask > 0) and state_delta_weight > 0.0:
            # Hazard or frozen-gate mode: only the state-change penalty is per tick.
            state_delta = (next_z - previous_z).pow(2).mean(dim=(1, 2))
            state_penalty_sum = state_penalty_sum + (state_delta * step_loss_mask).sum()
            state_penalty_count = state_penalty_count + step_loss_mask.sum()
        step_labels = batch["decoder_labels"][:, step_index, :]
        # Decode only for active rows that have at least one supervised label (-100 = ignore).
        decoder_rows = active_mask & torch.any(step_labels != -100, dim=-1)
        if torch.any(decoder_rows):
            step_loss_sum, step_token_count = model.decoder_loss(
                z_state[decoder_rows],
                step_labels[decoder_rows],
                encoder_attention_mask=state_mask[decoder_rows],
            )
            decoder_loss_sum = decoder_loss_sum + step_loss_sum
            decoder_token_count = decoder_token_count + step_token_count
        if tbptt_steps is not None:
            # Truncated BPTT: cut the autograd graph every `tbptt_steps` ticks.
            steps_since_detach += 1
            if steps_since_detach >= tbptt_steps:
                z_state = z_state.detach()
                state_mask = state_mask.detach()
                steps_since_detach = 0
    if gate_logits_history:
        stacked_gate_logits = torch.stack(gate_logits_history, dim=1)
        stacked_gate_targets = torch.stack(gate_target_history, dim=1)
        stacked_gate_loss_mask = torch.stack(gate_loss_mask_history, dim=1)
    else:
        stacked_gate_logits = torch.zeros((batch_size, 0), device=device)
        stacked_gate_targets = torch.zeros((batch_size, 0), device=device)
        stacked_gate_loss_mask = torch.zeros((batch_size, 0), device=device)
    gate_ranking_loss = torch.zeros((), device=device)
    gate_ranking_loss_sum = torch.zeros((), device=device)
    gate_ranking_pair_count = torch.zeros((), device=device)
    if freeze_gate_head_to_targets:
        # The gate head is teacher-forced, so it contributes no gate loss.
        gate_loss = torch.zeros((), device=device)
        gate_loss_sum = torch.zeros((), device=device)
        gate_count = torch.zeros((), device=device)
        hazard_event_count = torch.zeros((), device=device)
        hazard_censor_count = torch.zeros((), device=device)
    elif use_hazard_gate_loss and gate_logits_history:
        # The hazard loss is computed once over the whole chunk from the stacked history.
        gate_loss, hazard_metrics = compute_distance_aware_hazard_gate_loss(
            gate_logits=stacked_gate_logits,
            gate_targets=stacked_gate_targets,
            gate_loss_mask=stacked_gate_loss_mask,
            training_cfg=training_cfg,
        )
        gate_loss_sum = hazard_metrics["gate_loss_sum"]
        gate_count = hazard_metrics["gate_count"]
        hazard_event_count = hazard_metrics["hazard_event_count"]
        hazard_censor_count = hazard_metrics["hazard_censor_count"]
    else:
        gate_loss = gate_loss_sum / gate_count.clamp_min(1.0)
    # Optional pairwise ranking term, used only in non-hazard, non-frozen mode.
    if (
        not use_hazard_gate_loss
        and bool(training_cfg.get("use_ranked_gate_objective", False))
        and gate_logits_history
        and not freeze_gate_head_to_targets
    ):
        gate_ranking_loss_sum, gate_ranking_pair_count = compute_gate_ranking_loss(
            gate_logits=stacked_gate_logits,
            gate_targets=stacked_gate_targets,
            gate_loss_mask=stacked_gate_loss_mask,
            ranking_margin=float(training_cfg.get("gate_ranking_margin", 0.5)),
            ranking_lookback=int(training_cfg.get("gate_ranking_lookback", 64)),
        )
        gate_ranking_loss = gate_ranking_loss_sum / gate_ranking_pair_count.clamp_min(1.0)
        gate_loss = (
            float(training_cfg.get("gate_bce_weight", 1.0)) * gate_loss
            + float(training_cfg.get("gate_ranking_weight", 1.0)) * gate_ranking_loss
        )
    decoder_loss = decoder_loss_sum / decoder_token_count.clamp_min(1.0)
    state_delta_loss = state_penalty_sum / state_penalty_count.clamp_min(1.0)
    total_loss = (
        float(training_cfg["gate_loss_weight"]) * gate_loss
        + float(training_cfg["decoder_loss_weight"]) * decoder_loss
        + state_delta_weight * state_delta_loss
    )
    # Sums and counts (rather than means) so reduce_metrics can aggregate them across
    # batches and processes before dividing.
    metrics = {
        "loss_total": total_loss.detach(),
        "gate_loss_sum": gate_loss_sum.detach(),
        "gate_count": gate_count.detach(),
        "decoder_loss_sum": decoder_loss_sum.detach(),
        "decoder_token_count": decoder_token_count.detach(),
        "state_penalty_sum": state_penalty_sum.detach(),
        "state_penalty_count": state_penalty_count.detach(),
        "gate_tp": tp.detach(),
        "gate_fp": fp.detach(),
        "gate_fn": fn.detach(),
        "gate_correct": correct.detach(),
        "gate_eval_count": counted.detach(),
        "gate_target_positive_count": target_positive_count.detach(),
        "gate_pred_positive_count": predicted_positive_count.detach(),
        "gate_ranking_loss_sum": gate_ranking_loss_sum.detach(),
        "gate_ranking_pair_count": gate_ranking_pair_count.detach(),
        "hazard_event_count": hazard_event_count.detach(),
        "hazard_censor_count": hazard_censor_count.detach(),
        "active_ticks": batch["tick_mask"].float().sum().detach(),
        "gate_positive_weight": pos_weight.detach(),
        "metrics_count": torch.ones((), device=device),
    }
    return total_loss, metrics, z_state, state_mask
def reduce_metrics(
    accelerator: Accelerator,
    metric_list: list[dict[str, torch.Tensor]],
    training_cfg: dict[str, Any],
) -> dict[str, float]:
    """Aggregate per-batch metric sums into scalar losses and gate metrics.

    Each entry of ``metric_list`` is a dict of tensors of the kind produced by
    ``compute_chunk_losses``. Entries are summed locally and then sum-reduced with
    ``accelerator.reduce``. Ratios are formed only after the reduction, with every
    denominator clamped to at least 1. In non-hazard mode with
    ``use_ranked_gate_objective`` enabled, the gate loss is blended with the ranking
    loss, matching the training objective.
    """
    summed: dict[str, torch.Tensor] = {}
    for entry in metric_list:
        for name, tensor in entry.items():
            running = summed[name] if name in summed else torch.zeros_like(tensor)
            summed[name] = running + tensor
    global_sums = {name: accelerator.reduce(tensor, reduction="sum") for name, tensor in summed.items()}

    def mean_of(numerator: torch.Tensor, denominator: torch.Tensor) -> torch.Tensor:
        return numerator / denominator.clamp_min(1.0)

    gate_loss = mean_of(global_sums["gate_loss_sum"], global_sums["gate_count"])
    ranking_loss = mean_of(global_sums["gate_ranking_loss_sum"], global_sums["gate_ranking_pair_count"])
    decoder_loss = mean_of(global_sums["decoder_loss_sum"], global_sums["decoder_token_count"])
    state_delta_loss = mean_of(global_sums["state_penalty_sum"], global_sums["state_penalty_count"])
    true_positives = global_sums["gate_tp"]
    precision = mean_of(true_positives, true_positives + global_sums["gate_fp"])
    recall = mean_of(true_positives, true_positives + global_sums["gate_fn"])
    f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-8)
    mode = str(training_cfg.get("gate_loss_mode", "hazard")).strip().lower()
    if mode != "hazard" and bool(training_cfg.get("use_ranked_gate_objective", False)):
        gate_loss = (
            float(training_cfg.get("gate_bce_weight", 1.0)) * gate_loss
            + float(training_cfg.get("gate_ranking_weight", 1.0)) * ranking_loss
        )
    total_loss = (
        float(training_cfg["gate_loss_weight"]) * gate_loss
        + float(training_cfg["decoder_loss_weight"]) * decoder_loss
        + float(training_cfg["state_delta_weight"]) * state_delta_loss
    )
    return {
        "loss_total": float(total_loss),
        "loss_gate": float(gate_loss),
        "loss_gate_ranking": float(ranking_loss),
        "loss_decoder": float(decoder_loss),
        "loss_state_delta": float(state_delta_loss),
        "gate_accuracy": float(mean_of(global_sums["gate_correct"], global_sums["gate_eval_count"])),
        "gate_precision": float(precision),
        "gate_recall": float(recall),
        "gate_f1": float(f1),
        "gate_target_positive_count": float(global_sums["gate_target_positive_count"]),
        "gate_pred_positive_count": float(global_sums["gate_pred_positive_count"]),
        "gate_positive_weight": float(mean_of(global_sums["gate_positive_weight"], global_sums["metrics_count"])),
        "hazard_event_count": float(global_sums["hazard_event_count"]),
        "hazard_censor_count": float(global_sums["hazard_censor_count"]),
        "decoder_token_count": float(global_sums["decoder_token_count"]),
        "active_ticks": float(global_sums["active_ticks"]),
    }
@torch.no_grad()
def evaluate(
    accelerator: Accelerator,
    model: ThoughtLoopT5Gemma,
    dataloader: DataLoader,
    training_cfg: dict[str, Any],
    rollout_cfg: dict[str, Any],
    tokenizer_pad_token_id: int,
    max_batches: int | None = None,
) -> dict[str, float]:
    """Run the chunked rollout over ``dataloader`` without gradients and return reduced metrics.

    Each batch starts from the model's initial state. The recurrent state is carried
    chunk to chunk (``build_chunk_batch`` returning None ends the conversation batch).
    Per-chunk metrics from ``compute_chunk_losses`` are combined with
    ``reduce_metrics``.

    Args:
        max_batches: if given, stop after this many batches.

    Returns:
        The reduced metric dict, or ``{}`` when no chunk was evaluated.

    The model's train/eval mode is restored to what it was on entry. This also holds
    when an exception is raised, so calling this from an eval-mode context does not
    flip the model back to train mode.
    """
    was_training = model.training
    model.eval()
    try:
        all_metrics: list[dict[str, torch.Tensor]] = []
        chunk_ticks = int(rollout_cfg["chunk_ticks"])
        tick_seconds = float(rollout_cfg["tick_seconds"])
        for batch_index, conversations in enumerate(dataloader):
            if max_batches is not None and batch_index >= max_batches:
                break
            batch_size = len(conversations)
            if batch_size == 0:
                # Nothing to roll out; max() below would fail on an empty batch.
                continue
            z_state = model.initial_state(batch_size=batch_size, device=accelerator.device)
            state_mask = model.initial_state_mask(batch_size=batch_size, device=accelerator.device)
            max_chunks = max(int(conversation["chunk_count"]) for conversation in conversations)
            for chunk_index in range(max_chunks):
                chunk_batch = build_chunk_batch(
                    conversations=conversations,
                    chunk_index=chunk_index,
                    pad_token_id=tokenizer_pad_token_id,
                    chunk_ticks=chunk_ticks,
                    tick_seconds=tick_seconds,
                )
                if chunk_batch is None:
                    break
                chunk_batch = move_chunk_batch_to_device(chunk_batch, accelerator.device)
                _, metrics, z_state, state_mask = compute_chunk_losses(
                    model=model,
                    batch=chunk_batch,
                    training_cfg=training_cfg,
                    z_state=z_state,
                    state_mask=state_mask,
                    tbptt_steps=None,
                    freeze_gate_head_to_targets=False,
                )
                # Carry the state into the next chunk without keeping any graph.
                z_state = z_state.detach()
                state_mask = state_mask.detach()
                all_metrics.append(metrics)
        if not all_metrics:
            return {}
        return reduce_metrics(accelerator, all_metrics, training_cfg)
    finally:
        model.train(was_training)
def copy_supporting_files(output_dir: Path, config_path: Path) -> None:
    """Copy the project's source files and the run config into ``output_dir``.

    The files are ``config.py``, ``data.py``, ``model.py``, ``inference.py`` and
    ``train.py``, taken from ``SCRIPT_DIR``, followed by ``config_path``. Metadata is
    preserved via ``shutil.copy2``.
    """
    source_files = [
        SCRIPT_DIR / file_name
        for file_name in ("config.py", "data.py", "model.py", "inference.py", "train.py")
    ]
    for source_file in source_files:
        shutil.copy2(source_file, output_dir / source_file.name)
    shutil.copy2(config_path, output_dir / config_path.name)
def build_model_card(config: dict[str, Any], best_metrics: dict[str, float]) -> str:
    """Render the README model card: run settings, then best validation metrics sorted by name.

    Metric values are formatted with six decimals, so they must be numeric. The
    returned text ends with a newline.
    """
    settings_section = [
        "# Samantha Thought-Loop SFT",
        "",
        f"- Base model: `{config['model']['base_model_name']}`",
        f"- Cleaned dataset repo: `{config['dataset'].get('cleaned_repo_id')}`",
        f"- Raw dataset repo: `{config['dataset'].get('raw_repo_id')}`",
        f"- Latent slots: `{config['model']['z_slots']}`",
        f"- Fixed tick seconds: `{config['rollout']['tick_seconds']}`",
        f"- Chunk ticks: `{config['rollout']['chunk_ticks']}`",
        f"- Max horizon ticks: `{config['rollout'].get('max_horizon_ticks', 36000)}`",
        f"- Explicit time features: `{config['model'].get('use_explicit_time_features', False)}`",
        f"- Gate loss mode: `{config['training'].get('gate_loss_mode', 'hazard')}`",
        f"- Hazard soft target window: `{config['training'].get('hazard_soft_target_window_ticks', 16)}`",
        f"- Hazard soft target tau ticks: `{config['training'].get('hazard_soft_target_tau_ticks', 4.0)}`",
        f"- Hazard censor weight: `{config['training'].get('hazard_censor_weight', 0.25)}`",
        f"- TBPTT by bucket: `{config['training'].get('tbptt_steps_by_bucket', {})}`",
        "",
    ]
    metrics_section = ["## Best Validation Metrics", ""]
    metrics_section.extend(f"- `{key}`: `{value:.6f}`" for key, value in sorted(best_metrics.items()))
    metrics_section.append("")
    return "\n".join(settings_section + metrics_section)
def export_model(
    accelerator: Accelerator,
    model: ThoughtLoopT5Gemma,
    output_dir: Path,
    config: dict[str, Any],
    config_path: Path,
    best_metrics: dict[str, float],
) -> None:
    """Write the exported model to ``output_dir`` (main process only; other ranks return).

    Writes the unwrapped model via ``save_pretrained``, the supporting source files and
    config, a README model card, and ``best_metrics.json``.
    """
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(output_dir)
        copy_supporting_files(output_dir, config_path)
        readme_text = build_model_card(config, best_metrics)
        (output_dir / "README.md").write_text(readme_text, encoding="utf-8")
        save_json(output_dir / "best_metrics.json", best_metrics)
def maybe_push_to_hub(config: dict[str, Any], export_dir: Path) -> None:
    """Upload ``export_dir`` to the Hub model repo ``config["hub"]["model_repo_id"]``.

    This is a no-op when no repo id is configured. The repo is created if needed, with
    visibility taken from ``config["hub"]["private"]``.
    """
    hub_cfg = config["hub"]
    repo_id = hub_cfg.get("model_repo_id")
    if repo_id:
        client = HfApi()
        client.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True, private=bool(hub_cfg.get("private")))
        client.upload_folder(folder_path=str(export_dir), repo_id=repo_id, repo_type="model")
def maybe_push_checkpoint_to_hub(
    accelerator: Accelerator,
    config: dict[str, Any],
    checkpoint_dir: Path,
) -> None:
    """Upload ``checkpoint_dir`` into a subfolder of the Hub model repo.

    The upload runs only on the main process, when a repo id is configured and when
    ``hub.push_checkpoints`` (default True) is enabled. The target folder is
    ``<checkpoint_path_prefix>/<script dir name>/<wandb run_name>/<checkpoint dir name>``;
    empty components are dropped.
    """
    hub_cfg = config["hub"]
    repo_id = hub_cfg.get("model_repo_id")
    if not (accelerator.is_main_process and repo_id):
        return
    if not bool(hub_cfg.get("push_checkpoints", True)):
        return
    run_name = str(config["wandb"].get("run_name", "run"))
    prefix = str(hub_cfg.get("checkpoint_path_prefix", "checkpoints")).strip("/")
    path_in_repo = "/".join(filter(None, (prefix, SCRIPT_DIR.name, run_name, checkpoint_dir.name)))
    client = HfApi()
    client.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True, private=bool(hub_cfg.get("private")))
    client.upload_folder(
        folder_path=str(checkpoint_dir),
        repo_id=repo_id,
        repo_type="model",
        path_in_repo=path_in_repo,
        commit_message=f"Add {SCRIPT_DIR.name} {checkpoint_dir.name}",
    )
def read_json(path: Path) -> dict[str, Any]:
    """Load a JSON object from ``path`` (UTF-8).

    Returns ``{}`` when the file does not exist, and raises ValueError when the top-level
    value is not a JSON object.
    """
    if path.exists():
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, dict):
            return payload
        raise ValueError(f"Expected a JSON object in {path}.")
    return {}
def checkpoint_global_step(checkpoint_dir: Path) -> int | None:
    """Return the global step recorded for ``checkpoint_dir``, or None if it is unknown.

    The step comes from ``trainer_state.json`` if present. Otherwise it is parsed from a
    directory name of the form ``checkpoint-<digits>``.
    """
    stored_state = read_json(checkpoint_dir / "trainer_state.json")
    if "global_step" in stored_state:
        return int(stored_state["global_step"])
    head, separator, digits = checkpoint_dir.name.partition("checkpoint-")
    # head == "" means the name starts with the prefix.
    if not head and separator and digits.isdigit():
        return int(digits)
    return None
def find_latest_checkpoint(run_dir: Path) -> Path | None:
    """Return the ``checkpoint-*`` subdirectory of ``run_dir`` with the highest global step.

    Entries that are not directories, or whose step cannot be determined, are ignored.
    On ties the first one found wins. Returns None when no usable checkpoint exists.
    """
    newest_dir: Path | None = None
    newest_step: int | None = None
    for entry in run_dir.glob("checkpoint-*"):
        if not entry.is_dir():
            continue
        entry_step = checkpoint_global_step(entry)
        if entry_step is None:
            continue
        if newest_step is None or entry_step > newest_step:
            newest_step, newest_dir = entry_step, entry
    return newest_dir
def resolve_resume_checkpoint(requested: str | None, run_dir: Path) -> Path | None:
    """Turn the ``--resume-from-checkpoint`` value into a checkpoint directory.

    * None or blank: no resume (returns None).
    * ``1``/``true``/``yes``/``auto``/``latest`` (case-insensitive): the newest
      checkpoint under ``run_dir``; raises FileNotFoundError if there is none.
    * Anything else: a path, with ``~`` expanded. Relative paths are resolved against
      ``SCRIPT_DIR``, not the working directory. Raises FileNotFoundError if the
      result is not a directory.
    """
    text = "" if requested is None else str(requested).strip()
    if not text:
        return None
    if text.lower() in {"1", "true", "yes", "auto", "latest"}:
        newest = find_latest_checkpoint(run_dir)
        if newest is None:
            raise FileNotFoundError(f"No checkpoint-* directory found under {run_dir}.")
        return newest
    candidate = Path(text).expanduser()
    if not candidate.is_absolute():
        candidate = SCRIPT_DIR / candidate
    candidate = candidate.resolve()
    if not candidate.is_dir():
        raise FileNotFoundError(f"Resume checkpoint does not exist: {candidate}")
    return candidate
def parse_wandb_run_id_from_dir(path: Path) -> str | None:
    """Extract the run id from a W&B run directory name such as ``run-<timestamp>-<id>``.

    Symlinks (e.g. ``latest-run``) are resolved first. Names starting with ``run-`` or
    ``offline-run-`` yield their last ``-``-separated component. Any other name yields
    None.
    """
    if path.is_symlink():
        dir_name = path.resolve(strict=False).name
    else:
        dir_name = path.name
    if not dir_name.startswith(("offline-run-", "run-")):
        return None
    return dir_name.rsplit("-", 1)[-1]
def discover_local_wandb_run_id(run_name: str) -> str | None:
    """Return the id of a local W&B run under ``SCRIPT_DIR / "wandb"`` that mentions ``run_name``.

    Candidates are checked in this order: ``wandb/latest-run`` first, then every
    ``*run-*`` entry, newest first by modification time. A candidate matches when
    ``run_name`` occurs in its ``files/config.yaml`` or ``logs/debug.log``. Returns None
    when nothing matches.

    An empty or whitespace-only ``run_name`` returns None immediately. Such a string is
    a substring of every file, so without this guard it would "match" an arbitrary,
    unrelated previous run and silently resume it.
    """
    if not run_name or not run_name.strip():
        return None
    wandb_dir = SCRIPT_DIR / "wandb"
    if not wandb_dir.exists():
        return None

    def modified_time(path: Path) -> float:
        # stat() raises for a dangling symlink or an entry removed after the glob.
        # Sort such entries last instead of aborting the whole discovery.
        try:
            return path.stat().st_mtime
        except OSError:
            return float("-inf")

    candidates: list[Path] = []
    latest = wandb_dir / "latest-run"
    if latest.exists() or latest.is_symlink():
        candidates.append(latest)
    candidates.extend(sorted(wandb_dir.glob("*run-*"), key=modified_time, reverse=True))
    for candidate in candidates:
        run_id = parse_wandb_run_id_from_dir(candidate)
        if not run_id:
            continue
        resolved = candidate.resolve(strict=False)
        # NOTE(review): this is plain substring matching, so a run_name that is a
        # substring of another run's name can still match that run.
        for marker_file in (resolved / "files" / "config.yaml", resolved / "logs" / "debug.log"):
            if marker_file.exists() and run_name in marker_file.read_text(encoding="utf-8", errors="ignore"):
                return run_id
    return None
def resolve_wandb_run_id(args: argparse.Namespace, config: dict[str, Any], run_dir: Path) -> str:
    """Choose the W&B run id for this run and persist it in ``run_dir/wandb_run_id.txt``.

    Precedence: ``--wandb-run-id``, then the ``WANDB_RUN_ID`` env var, then
    ``config["wandb"]["run_id"]``, then the id stored in ``wandb_run_id.txt``, then a
    local W&B directory matching ``run_name``, and finally a fresh random hex id.

    Args:
        args: Parsed CLI arguments (reads ``args.wandb_run_id``).
        config: Full training config (reads the ``wandb`` section).
        run_dir: Directory where the chosen id is stored for later resumes.

    Returns:
        The non-empty run id.
    """
    run_id_path = run_dir / "wandb_run_id.txt"
    explicit = args.wandb_run_id or os.environ.get("WANDB_RUN_ID") or config["wandb"].get("run_id")
    if explicit:
        run_id = str(explicit).strip()
    elif run_id_path.exists():
        run_id = run_id_path.read_text(encoding="utf-8").strip()
    else:
        run_id = discover_local_wandb_run_id(str(config["wandb"].get("run_name", ""))) or secrets.token_hex(4)
    if not run_id:
        run_id = secrets.token_hex(4)
    # Every launched process calls this. Without a stored id each rank may draw a
    # different random id, so only global rank 0 (the process that initialises the
    # tracker) persists it; otherwise a later resume could pick up another rank's id.
    if os.environ.get("RANK", "0").strip() in {"", "0"}:
        run_id_path.write_text(run_id + "\n", encoding="utf-8")
    return run_id
def save_trainer_state(
    checkpoint_dir: Path,
    *,
    global_step: int,
    epoch: int,
    next_batch_index: int,
    best_validation_loss: float,
    best_metrics: dict[str, float],
    wandb_run_id: str | None,
) -> None:
    """Write ``trainer_state.json`` into ``checkpoint_dir`` alongside the accelerate state.

    The file records where the training loop stopped (``global_step``, ``epoch``,
    ``next_batch_index``), the best validation result so far, and the W&B run id,
    so a resumed run can continue from the same position.

    Args:
        checkpoint_dir: Checkpoint directory to write into.
        global_step: Optimizer steps completed so far.
        epoch: Epoch index the loop was in when saving.
        next_batch_index: Index of the next loader batch to process in ``epoch``.
        best_validation_loss: Best validation ``loss_total``; non-finite values are
            stored as ``null``.
        best_metrics: Validation metrics of the best evaluation so far.
        wandb_run_id: W&B run id, or ``None`` when tracking is disabled.
    """
    step_value = int(global_step)
    epoch_value = int(epoch)
    batch_value = int(next_batch_index)
    # Standard JSON has no inf/nan, so the "no best yet" sentinel becomes null.
    stored_best_loss = None
    if math.isfinite(best_validation_loss):
        stored_best_loss = best_validation_loss
    state_payload: dict[str, Any] = dict(
        global_step=step_value,
        epoch=epoch_value,
        next_batch_index=batch_value,
        best_validation_loss=stored_best_loss,
        best_metrics=best_metrics,
        wandb_run_id=wandb_run_id,
    )
    save_json(checkpoint_dir / "trainer_state.json", state_payload)
def load_resume_progress(
    checkpoint_dir: Path,
    *,
    total_chunk_microsteps: int,
    gradient_accumulation_steps: int,
    num_train_epochs: int,
) -> tuple[int, int, int, dict[str, Any]]:
    """Recover the training-loop position stored with a checkpoint.

    Reads ``checkpoint_dir/trainer_state.json``. When it records both ``epoch`` and
    ``next_batch_index`` those are used directly. Otherwise the position is derived
    from ``global_step``, assuming ``ceil(total_chunk_microsteps /
    gradient_accumulation_steps)`` optimizer updates per epoch. In both cases a skip
    count that reaches a whole epoch is rolled over into the following epoch(s).

    Args:
        checkpoint_dir: Checkpoint directory to read.
        total_chunk_microsteps: Loader batches per epoch, as seen by this process.
        gradient_accumulation_steps: Micro-batches per optimizer update.
        num_train_epochs: Total epochs configured for training.

    Returns:
        ``(global_step, start_epoch, batches_to_skip, state)``, where
        ``batches_to_skip`` counts loader batches to skip at the start of
        ``start_epoch`` and ``state`` is the parsed JSON object (or ``{}``).
    """
    state = read_json(checkpoint_dir / "trainer_state.json")
    if not isinstance(state, dict):
        # A missing/corrupt or non-object state file falls back to step-based recovery.
        state = {}
    accumulation = max(int(gradient_accumulation_steps), 1)
    global_step = int(state.get("global_step") or checkpoint_global_step(checkpoint_dir) or 0)
    if "epoch" in state and "next_batch_index" in state:
        start_epoch = int(state["epoch"])
        batches_to_skip = int(state["next_batch_index"])
    else:
        updates_per_epoch = max(1, math.ceil(total_chunk_microsteps / accumulation))
        start_epoch = min(global_step // updates_per_epoch, max(num_train_epochs - 1, 0))
        updates_into_epoch = global_step - start_epoch * updates_per_epoch
        batches_to_skip = updates_into_epoch * accumulation
    batches_to_skip = max(batches_to_skip, 0)
    # A checkpoint written on an epoch's last batch would otherwise skip an entire epoch.
    while batches_to_skip >= total_chunk_microsteps and start_epoch < num_train_epochs:
        batches_to_skip -= total_chunk_microsteps
        start_epoch += 1
    return global_step, start_epoch, batches_to_skip, state
def load_best_metrics(export_dir: Path, resume_state: dict[str, Any]) -> tuple[float, dict[str, float]]:
    """Restore the best validation loss and metrics seen so far.

    Metrics come from ``resume_state["best_metrics"]`` when it is a dict, otherwise
    from ``export_dir/best_metrics.json``. The loss is
    ``resume_state["best_validation_loss"]``, falling back to the metrics'
    ``loss_total``.

    Args:
        export_dir: Export directory that may hold ``best_metrics.json``.
        resume_state: Parsed trainer state from a checkpoint (``{}`` for fresh runs).

    Returns:
        ``(best_validation_loss, best_metrics)``. The loss is ``inf`` when unknown.
        Metric values that cannot be converted to ``float`` are dropped.
    """
    best_metrics = resume_state.get("best_metrics")
    if not isinstance(best_metrics, dict):
        best_metrics = read_json(export_dir / "best_metrics.json")
    if not isinstance(best_metrics, dict):
        # Missing or non-object JSON: treat as "no best yet" instead of crashing on .items().
        best_metrics = {}
    loss = resume_state.get("best_validation_loss")
    if loss is None:
        loss = best_metrics.get("loss_total")
    best_validation_loss = float(loss) if loss is not None else float("inf")
    numeric_metrics: dict[str, float] = {}
    for key, value in best_metrics.items():
        try:
            numeric_metrics[str(key)] = float(value)
        except (TypeError, ValueError):
            # Non-numeric entries (e.g. null or text) cannot be compared/logged as floats.
            continue
    return best_validation_loss, numeric_metrics
def main() -> None:
    """Train the recurrent T5Gemma thought-loop model end to end.

    Parses CLI args and loads the config, then builds the model, datasets, loaders,
    optimizer and cosine LR schedule, and optionally restores a local accelerate
    checkpoint. It then runs the chunked training loop with periodic logging,
    validation (exporting the model whenever ``loss_total`` improves) and
    checkpointing. Finally it exports a model if none was exported during training,
    writes ``training_summary.json`` and calls ``maybe_push_to_hub`` on the main
    process.
    """
    args = parse_args()
    config_path = Path(args.config)
    config = load_config(config_path)
    training_cfg = config["training"]
    rollout_cfg = config["rollout"]
    set_seed(int(training_cfg["seed"]))
    gradient_accumulation_steps = int(training_cfg["gradient_accumulation_steps"])
    accumulation = max(gradient_accumulation_steps, 1)
    num_train_epochs = int(training_cfg["num_train_epochs"])
    num_workers = int(training_cfg["num_workers"])
    wandb_enabled = bool(config["wandb"].get("enabled", True) and wandb is not None)
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=training_cfg.get("mixed_precision", "bf16"),
        log_with="wandb" if wandb_enabled else None,
        project_dir=config["paths"]["run_root"],
    )
    run_root = ensure_dir(config["paths"]["run_root"])
    export_root = ensure_dir(config["paths"]["export_root"])
    run_name = config["wandb"]["run_name"]
    run_dir = ensure_dir(run_root / run_name)
    export_dir = ensure_dir(export_root / run_name)
    resume_request = args.resume_from_checkpoint or training_cfg.get("resume_from_checkpoint")
    resume_checkpoint = resolve_resume_checkpoint(resume_request, run_dir)
    wandb_run_id = resolve_wandb_run_id(args, config, run_dir) if wandb_enabled else None

    model = build_initial_model(config)
    if model.use_explicit_time_features:
        raise ValueError("This trainer expects model.use_explicit_time_features=false for fixed-cadence training.")
    tokenizer = model.tokenizer
    train_dataset = ThoughtLoopConversationDataset(config=config, tokenizer=tokenizer, split="train")
    validation_dataset = ThoughtLoopConversationDataset(config=config, tokenizer=tokenizer, split="validation")
    sort_dataset_by_length = bool(training_cfg.get("sort_dataset_by_length", False))
    ensure_single_chunk_dataset(train_dataset, "train")
    ensure_single_chunk_dataset(validation_dataset, "validation")
    maybe_sort_dataset_by_length(train_dataset, enabled=sort_dataset_by_length)
    maybe_sort_dataset_by_length(validation_dataset, enabled=sort_dataset_by_length)
    train_loader = DataLoader(
        train_dataset,
        batch_size=int(training_cfg["micro_batch_size"]),
        # A length-sorted dataset is only useful if the loader keeps that order.
        shuffle=bool(training_cfg.get("shuffle_train", True)) and not sort_dataset_by_length,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
        collate_fn=identity_collate,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=int(training_cfg["eval_batch_size"]),
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
        collate_fn=identity_collate,
    )

    optimizer = build_optimizer(model, training_cfg)
    # Un-sharded loader length. accelerate's prepared scheduler advances once per
    # process for every optimizer step (split_batches off), so the LR schedule is
    # sized on this global count.
    total_chunk_microsteps = len(train_loader)
    total_update_steps = max(1, math.ceil(total_chunk_microsteps / accumulation) * num_train_epochs)
    warmup_steps = max(1, int(total_update_steps * float(training_cfg["warmup_ratio"])))
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max(1, total_update_steps),
    )
    model, optimizer, train_loader, validation_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, validation_loader, scheduler
    )
    # After prepare() the train loader is sharded per process, and global_step counts
    # this process's optimizer steps. The progress bar, decoder/gate schedules and
    # resume arithmetic must use the same per-process count.
    per_process_chunk_microsteps = len(train_loader)
    per_process_update_steps = max(1, math.ceil(per_process_chunk_microsteps / accumulation) * num_train_epochs)

    resume_state: dict[str, Any] = {}
    global_step = 0
    start_epoch = 0
    batches_to_skip = 0
    if resume_checkpoint is not None:
        accelerator.print(f"Loading local checkpoint from {resume_checkpoint}")
        accelerator.load_state(str(resume_checkpoint))
        global_step, start_epoch, batches_to_skip, resume_state = load_resume_progress(
            resume_checkpoint,
            total_chunk_microsteps=per_process_chunk_microsteps,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_train_epochs,
        )
        accelerator.print(
            f"Resuming at global_step={global_step}, epoch={start_epoch + 1}, "
            f"skipping {batches_to_skip} batch(es)."
        )
    if accelerator.is_main_process and wandb_enabled:
        accelerator.print(f"W&B run id: {wandb_run_id}")
        accelerator.init_trackers(
            project_name=config["wandb"]["project"],
            config=flatten_for_wandb(config),
            init_kwargs={
                "wandb": {
                    "name": run_name,
                    "id": wandb_run_id,
                    "resume": str(config["wandb"].get("resume", "allow")),
                }
            },
        )

    # The prepared model may be a DDP wrapper, which does not expose the custom
    # methods used below; call them on the underlying module instead.
    unwrapped_model = accelerator.unwrap_model(model)
    current_decoder_fraction = resolve_decoder_trainable_fraction(training_cfg, global_step, per_process_update_steps)
    if current_decoder_fraction is None:
        current_decoder_fraction = 1.0
    current_decoder_trainable_layers = unwrapped_model.set_decoder_trainable_fraction(current_decoder_fraction)
    current_gate_head_frozen = should_freeze_gate_head(training_cfg, global_step, per_process_update_steps)
    unwrapped_model.set_gate_trainable(not current_gate_head_frozen)
    progress_bar = tqdm(
        total=per_process_update_steps,
        initial=global_step,
        disable=not accelerator.is_local_main_process,
        desc="Training",
        dynamic_ncols=True,
    )
    best_validation_loss, best_metrics = load_best_metrics(export_dir, resume_state)
    model.train()
    chunk_ticks = int(rollout_cfg["chunk_ticks"])
    tick_seconds = float(rollout_cfg["tick_seconds"])
    max_grad_norm = float(training_cfg["max_grad_norm"])
    logging_steps = int(training_cfg["logging_steps"])
    eval_steps = int(training_cfg["eval_steps"])
    checkpoint_steps = int(training_cfg["checkpoint_steps"])

    def run_validation() -> dict[str, Any]:
        # Shared by the periodic and the final validation pass.
        return evaluate(
            accelerator=accelerator,
            model=model,
            dataloader=validation_loader,
            training_cfg=training_cfg,
            rollout_cfg=rollout_cfg,
            tokenizer_pad_token_id=tokenizer.pad_token_id,
            max_batches=training_cfg.get("eval_max_batches"),
        )

    for epoch in range(start_epoch, num_train_epochs):
        epoch_train_loader = train_loader
        batch_index_offset = 0
        if epoch == start_epoch and batches_to_skip > 0:
            # Fast-forward past batches consumed before the checkpoint was written.
            epoch_train_loader = accelerator.skip_first_batches(train_loader, batches_to_skip)
            batch_index_offset = batches_to_skip
        for batch_index, conversations in enumerate(epoch_train_loader, start=batch_index_offset):
            batch_bucket = resolve_batch_duration_bucket(conversations)
            current_tbptt_steps = resolve_bucket_tbptt_steps(training_cfg, batch_bucket)
            current_horizon_ticks = resolve_bucket_horizon_ticks(batch_bucket, rollout_cfg)
            batch_size = len(conversations)
            z_state = unwrapped_model.initial_state(batch_size=batch_size, device=accelerator.device)
            state_mask = unwrapped_model.initial_state_mask(batch_size=batch_size, device=accelerator.device)
            max_chunks = max(int(conversation["chunk_count"]) for conversation in conversations)
            for chunk_index in range(max_chunks):
                scheduled_decoder_fraction = resolve_decoder_trainable_fraction(
                    training_cfg, global_step, per_process_update_steps
                )
                if scheduled_decoder_fraction is None:
                    scheduled_decoder_fraction = 1.0
                if abs(scheduled_decoder_fraction - current_decoder_fraction) > 1e-6:
                    current_decoder_fraction = scheduled_decoder_fraction
                    current_decoder_trainable_layers = unwrapped_model.set_decoder_trainable_fraction(
                        current_decoder_fraction
                    )
                scheduled_gate_head_frozen = should_freeze_gate_head(training_cfg, global_step, per_process_update_steps)
                if scheduled_gate_head_frozen != current_gate_head_frozen:
                    current_gate_head_frozen = scheduled_gate_head_frozen
                    unwrapped_model.set_gate_trainable(not current_gate_head_frozen)
                chunk_batch = build_chunk_batch(
                    conversations=conversations,
                    chunk_index=chunk_index,
                    pad_token_id=tokenizer.pad_token_id,
                    chunk_ticks=chunk_ticks,
                    tick_seconds=tick_seconds,
                )
                if chunk_batch is None:
                    break
                chunk_batch = move_chunk_batch_to_device(chunk_batch, accelerator.device)
                with accelerator.accumulate(model):
                    total_loss, metrics, z_state, state_mask = compute_chunk_losses(
                        model=model,
                        batch=chunk_batch,
                        training_cfg=training_cfg,
                        z_state=z_state,
                        state_mask=state_mask,
                        tbptt_steps=current_tbptt_steps,
                        freeze_gate_head_to_targets=current_gate_head_frozen,
                    )
                    accelerator.backward(total_loss)
                    # Clip only the fully accumulated gradient; clipping on intermediate
                    # micro-steps would rescale a partial sum.
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                # Truncate backprop through time at the chunk boundary.
                z_state = z_state.detach()
                state_mask = state_mask.detach()
                if not accelerator.sync_gradients:
                    continue
                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix(
                    loss=f"{float(total_loss.detach()):.4f}",
                    bucket=batch_bucket,
                    tbptt=int(current_tbptt_steps or 0),
                    horizon=int(current_horizon_ticks),
                    decoder=current_decoder_trainable_layers,
                    gate="locked" if current_gate_head_frozen else "train",
                    refresh=False,
                )
                if global_step % logging_steps == 0:
                    reduced = reduce_metrics(accelerator, [metrics], training_cfg)
                    reduced["lr"] = scheduler.get_last_lr()[0]
                    reduced["tbptt_steps"] = float(current_tbptt_steps or 0)
                    reduced["horizon_ticks"] = float(current_horizon_ticks)
                    reduced["chunk_ticks"] = float(chunk_ticks)
                    reduced["decoder_trainable_layers"] = float(current_decoder_trainable_layers)
                    reduced["gate_head_frozen"] = float(current_gate_head_frozen)
                    reduced["epoch"] = epoch + 1
                    reduced["global_step"] = global_step
                    accelerator.log({f"train/{key}": value for key, value in reduced.items()}, step=global_step)
                if global_step % eval_steps == 0:
                    validation_metrics = run_validation()
                    # evaluate() may leave the model in eval mode; training continues in train mode.
                    model.train()
                    accelerator.log(
                        {f"validation/{key}": value for key, value in validation_metrics.items()},
                        step=global_step,
                    )
                    if validation_metrics and validation_metrics["loss_total"] < best_validation_loss:
                        best_validation_loss = validation_metrics["loss_total"]
                        best_metrics = validation_metrics
                        export_model(
                            accelerator=accelerator,
                            model=model,
                            output_dir=export_dir,
                            config=config,
                            config_path=config_path,
                            best_metrics=best_metrics,
                        )
                if global_step % checkpoint_steps == 0:
                    checkpoint_dir = ensure_dir(run_dir / f"checkpoint-{global_step}")
                    accelerator.save_state(str(checkpoint_dir))
                    if accelerator.is_main_process:
                        save_trainer_state(
                            checkpoint_dir,
                            global_step=global_step,
                            epoch=epoch,
                            next_batch_index=batch_index + 1,
                            best_validation_loss=best_validation_loss,
                            best_metrics=best_metrics,
                            wandb_run_id=wandb_run_id,
                        )
                    accelerator.wait_for_everyone()
                    maybe_push_checkpoint_to_hub(
                        accelerator=accelerator,
                        config=config,
                        checkpoint_dir=checkpoint_dir,
                    )

    accelerator.wait_for_everyone()
    if not best_metrics:
        # No validation pass produced (and exported) a best model, so evaluate and
        # export the final weights now. Exporting unconditionally would overwrite the
        # best export with the final weights while still labelling them as the best.
        best_metrics = run_validation()
        export_model(
            accelerator=accelerator,
            model=model,
            output_dir=export_dir,
            config=config,
            config_path=config_path,
            best_metrics=best_metrics,
        )
    if accelerator.is_main_process:
        training_summary = {
            "train_conversations": len(train_dataset),
            "validation_conversations": len(validation_dataset),
            "train_chunk_microsteps": total_chunk_microsteps,
            "trainable_parameters": unwrapped_model.trainable_parameter_count(),
            "best_metrics": best_metrics,
            "global_step": global_step,
            "resumed_from_checkpoint": str(resume_checkpoint) if resume_checkpoint is not None else None,
            "wandb_run_id": wandb_run_id,
            "decoder_layer_count": unwrapped_model.decoder_layer_count(),
            "tick_seconds": tick_seconds,
            "chunk_ticks": chunk_ticks,
            "tbptt_steps_by_bucket": training_cfg.get("tbptt_steps_by_bucket", {}),
            "horizon_ticks_by_bucket": rollout_cfg.get("horizon_ticks_by_bucket", {}),
            "shuffle_train": bool(training_cfg.get("shuffle_train", True)),
            "sort_dataset_by_length": sort_dataset_by_length,
        }
        save_json(run_dir / "training_summary.json", training_summary)
        maybe_push_to_hub(config, export_dir)
    progress_bar.close()
    accelerator.end_training()
# Script entry point: run training only when executed directly, not when imported.
if __name__ == "__main__":
    main()