| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import random |
| import secrets |
| import shutil |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from accelerate import Accelerator |
| from huggingface_hub import HfApi |
| from torch.utils.data import DataLoader |
| from tqdm.auto import tqdm |
| from transformers import get_cosine_schedule_with_warmup |
|
|
# Absolute directory containing this script. It is prepended to sys.path when not
# already listed there -- presumably so the sibling modules imported right below
# (config, data, model) resolve however the script is launched; confirm if relied on.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
|
|
| from config import ensure_dir, flatten_for_wandb, load_config, save_json |
| from data import ( |
| OBS_ROLE_NONE, |
| OBS_ROLE_USER, |
| ThoughtLoopConversationDataset, |
| build_chunk_batch, |
| identity_collate, |
| resolve_bucket_horizon_ticks, |
| ) |
| from model import ThoughtLoopT5Gemma |
|
|
# wandb is treated as an optional dependency: when the import fails the name is
# bound to None instead, so later code can test `wandb is None`.
try:
    import wandb
except ImportError:
    wandb = None
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Recognised flags are ``--config`` (defaults to ``configs/sft_b200.yaml``),
    ``--resume-from-checkpoint`` and ``--wandb-run-id`` (both default to ``None``).

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Train the recurrent T5Gemma thought-loop model.")

    resume_help = (
        "Local checkpoint path to resume from, or 'latest'/'auto' to use the newest checkpoint in the run dir."
    )
    wandb_help = "Optional W&B run id to resume. If omitted, train.py reuses a stored/local run id when available."

    cli.add_argument("--config", default="configs/sft_b200.yaml")
    cli.add_argument("--resume-from-checkpoint", default=None, help=resume_help)
    cli.add_argument("--wandb-run-id", default=None, help=wandb_help)
    return cli.parse_args()
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    and PyTorch via ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def build_optimizer(model: torch.nn.Module, training_cfg: dict[str, Any]) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with two parameter groups.

    Parameters whose name contains ``"bias"``, ``"norm"`` or ``"Norm"`` go into a
    group with zero weight decay; all other parameters use
    ``training_cfg["weight_decay"]``.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        training_cfg: Must provide ``weight_decay``, ``learning_rate``,
            ``adam_beta1``, ``adam_beta2`` and ``adam_epsilon``; ``fused_adamw``
            is optional and defaults to ``True``.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    exempt_markers = ("bias", "norm", "Norm")

    def _is_exempt(parameter_name: str) -> bool:
        return any(marker in parameter_name for marker in exempt_markers)

    named_parameters = list(model.named_parameters())
    with_decay = [parameter for name, parameter in named_parameters if not _is_exempt(name)]
    without_decay = [parameter for name, parameter in named_parameters if _is_exempt(name)]

    # The fused kernel is only requested when the config allows it and CUDA is available.
    use_fused = bool(training_cfg.get("fused_adamw", True)) and torch.cuda.is_available()

    parameter_groups = [
        {"params": with_decay, "weight_decay": float(training_cfg["weight_decay"])},
        {"params": without_decay, "weight_decay": 0.0},
    ]
    learning_rate = float(training_cfg["learning_rate"])
    beta1 = float(training_cfg["adam_beta1"])
    beta2 = float(training_cfg["adam_beta2"])
    epsilon = float(training_cfg["adam_epsilon"])
    return torch.optim.AdamW(
        parameter_groups,
        lr=learning_rate,
        betas=(beta1, beta2),
        eps=epsilon,
        fused=use_fused,
    )
|
|
|
|
def resolve_bucket_tbptt_steps(training_cfg: dict[str, Any], duration_bucket: str) -> int | None:
    """Return the truncated-BPTT window for a duration bucket, or ``None`` when disabled.

    A per-bucket override in ``training_cfg["tbptt_steps_by_bucket"]`` wins.
    Otherwise ``tbptt_end_steps`` is returned (falling back to
    ``tbptt_start_steps`` when the end value is unset), provided both the start
    and end values are positive.

    Config entries that are present but null (e.g. an empty YAML key) are
    treated as unset instead of raising ``TypeError``.

    Args:
        training_cfg: Training configuration mapping.
        duration_bucket: Bucket name looked up in ``tbptt_steps_by_bucket``.

    Returns:
        The number of steps between detaches, or ``None`` when truncation is off.
    """
    # `or {}` covers both a missing key and an explicit null value.
    bucket_map = training_cfg.get("tbptt_steps_by_bucket") or {}
    if duration_bucket in bucket_map:
        return int(bucket_map[duration_bucket])

    start_steps = int(training_cfg.get("tbptt_start_steps") or 0)
    raw_end_steps = training_cfg.get("tbptt_end_steps")
    end_steps = start_steps if raw_end_steps is None else int(raw_end_steps)
    if start_steps <= 0 or end_steps <= 0:
        return None
    return end_steps
|
|
|
|
def resolve_linear_schedule_value(
    *,
    enabled: bool,
    global_step: int,
    total_update_steps: int,
    start_fraction: float,
    end_fraction: float,
    start_value: float,
    end_value: float,
) -> float | None:
    """Evaluate a piecewise-linear schedule at ``global_step``.

    The ramp runs between the steps obtained by rounding
    ``total_update_steps * start_fraction`` and ``total_update_steps * end_fraction``
    (``total_update_steps`` is treated as at least 1, and the ramp is at least
    one step long). At or before the ramp start the result is ``start_value``;
    at or after the ramp end it is ``end_value``; in between it is interpolated
    linearly.

    Returns:
        ``None`` when ``enabled`` is false, otherwise the scheduled value.
    """
    if not enabled:
        return None

    horizon = max(total_update_steps, 1)
    ramp_begin = int(round(horizon * start_fraction))
    # Guarantee a ramp of at least one step so the interpolation never divides by zero.
    ramp_end = max(int(round(horizon * end_fraction)), ramp_begin + 1)

    if global_step <= ramp_begin:
        return start_value
    if global_step >= ramp_end:
        return end_value

    fraction_done = (global_step - ramp_begin) / (ramp_end - ramp_begin)
    return start_value + fraction_done * (end_value - start_value)
|
|
|
|
def should_freeze_gate_head(
    training_cfg: dict[str, Any],
    global_step: int,
    total_update_steps: int,
) -> bool:
    """Tell whether the gate head is still in its frozen phase at ``global_step``.

    The phase is only active when ``freeze_gate_head_first_half`` is truthy in
    ``training_cfg``. It lasts for the first ``gate_head_freeze_fraction``
    (default 0.5) of ``total_update_steps``, which is treated as at least 1.
    """
    if training_cfg.get("freeze_gate_head_first_half", False):
        freeze_fraction = float(training_cfg.get("gate_head_freeze_fraction", 0.5))
        first_unfrozen_step = int(round(max(total_update_steps, 1) * freeze_fraction))
        return global_step < first_unfrozen_step
    return False
|
|
|
|
def resolve_decoder_trainable_fraction(
    training_cfg: dict[str, Any],
    global_step: int,
    total_update_steps: int,
) -> float | None:
    """Return the scheduled trainable fraction of the decoder at ``global_step``.

    Delegates to ``resolve_linear_schedule_value``: the fraction ramps from
    ``decoder_initial_trainable_fraction`` (default 0.0) to 1.0 between
    ``decoder_unfreeze_start_fraction`` (default 0.0) and
    ``decoder_unfreeze_end_fraction`` (default 0.5) of the run.

    Returns:
        ``None`` when ``gradual_unfreeze_decoder`` is not enabled in
        ``training_cfg``, otherwise the scheduled fraction.
    """
    schedule_enabled = bool(training_cfg.get("gradual_unfreeze_decoder", False))
    ramp_start = float(training_cfg.get("decoder_unfreeze_start_fraction", 0.0))
    ramp_end = float(training_cfg.get("decoder_unfreeze_end_fraction", 0.5))
    initial_fraction = float(training_cfg.get("decoder_initial_trainable_fraction", 0.0))
    return resolve_linear_schedule_value(
        enabled=schedule_enabled,
        global_step=global_step,
        total_update_steps=total_update_steps,
        start_fraction=ramp_start,
        end_fraction=ramp_end,
        start_value=initial_fraction,
        end_value=1.0,
    )
|
|
|
|
def resolve_batch_duration_bucket(conversations: list[dict[str, Any]]) -> str:
    """Name the duration bucket shared by a batch of conversations.

    Returns ``"short"`` for an empty batch, the common bucket (as a string)
    when every conversation has the same ``duration_bucket``, and ``"mixed"``
    otherwise.
    """
    if not conversations:
        return "short"

    distinct_buckets: set[str] = set()
    for conversation in conversations:
        distinct_buckets.add(str(conversation["duration_bucket"]))

    return distinct_buckets.pop() if len(distinct_buckets) == 1 else "mixed"
|
|
|
|
def ensure_single_chunk_dataset(dataset: ThoughtLoopConversationDataset, split: str) -> None:
    """Verify that every example of ``dataset`` has exactly one chunk.

    Args:
        dataset: Object exposing an ``examples`` sequence of dict-like records.
        split: Split name, used only as a prefix of the error message.

    Raises:
        ValueError: If any example's ``chunk_count`` (0 when missing) is not 1.
    """
    offending_count = sum(
        1 for example in dataset.examples if int(example.get("chunk_count", 0)) != 1
    )
    if offending_count == 0:
        return

    raise ValueError(
        f"{split} dataset expected exactly one chunk per conversation, but found "
        f"{offending_count} multi-chunk examples. Check chunk_ticks/max_horizon_ticks."
    )
|
|
|
|
def maybe_sort_dataset_by_length(
    dataset: ThoughtLoopConversationDataset,
    *,
    enabled: bool,
) -> None:
    """Sort ``dataset.examples`` in place by length when ``enabled``.

    The sort key is ``(effective_total_ticks, total_ticks, row_id)``, where
    ``effective_total_ticks`` falls back to ``total_ticks``, ``total_ticks``
    falls back to 0 and ``row_id`` falls back to the empty string.
    """

    def _length_key(example: dict[str, Any]) -> tuple[int, int, str]:
        total_ticks = example.get("total_ticks", 0)
        effective_ticks = example.get("effective_total_ticks", total_ticks)
        return int(effective_ticks), int(total_ticks), str(example.get("row_id", ""))

    if enabled:
        dataset.examples.sort(key=_length_key)
|
|
|
|
def move_chunk_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    """Return a new dict holding every tensor of ``batch`` moved to ``device``."""
    moved: dict[str, torch.Tensor] = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved
|
|
|
|
def apply_runtime_model_config(model: ThoughtLoopT5Gemma, config: dict[str, Any]) -> None:
    """Overwrite runtime-tunable attributes of a loaded model from ``config``.

    Stores ``config`` on the model, checks that the model's ``z_slots`` matches
    ``config["model"]["z_slots"]``, then updates ``thought_loop_proposal_mode``,
    ``preserve_observation_encoder_manifold``,
    ``observation_encoder_use_state_context`` and ``latent_attention_mask_mode``
    from the ``model`` section of the config.

    Raises:
        ValueError: If ``z_slots`` disagrees with the loaded model, or if the
            resulting ``latent_attention_mask_mode`` is neither ``"observed"``
            nor ``"full"``.
    """
    model.config = config
    overrides = config["model"]

    requested_z_slots = int(overrides["z_slots"])
    if int(model.z_slots) != requested_z_slots:
        raise ValueError(
            f"Loaded model has z_slots={model.z_slots}, but config requested z_slots={requested_z_slots}."
        )

    proposal_mode = overrides.get("thought_loop_proposal_mode", model.thought_loop_proposal_mode)
    model.thought_loop_proposal_mode = str(proposal_mode).strip().lower()

    # Unless configured explicitly, this flag follows the (just updated) proposal mode.
    preserve_by_default = model.thought_loop_proposal_mode == "observation_hidden_compression"
    preserve_manifold = overrides.get("preserve_observation_encoder_manifold", preserve_by_default)
    model.preserve_observation_encoder_manifold = bool(preserve_manifold)

    use_state_context = overrides.get("observation_encoder_use_state_context", False)
    model.observation_encoder_use_state_context = bool(use_state_context)

    current_mask_mode = getattr(model, "latent_attention_mask_mode", "observed")
    requested_mask_mode = overrides.get("latent_attention_mask_mode", current_mask_mode)
    model.latent_attention_mask_mode = str(requested_mask_mode).strip().lower()
    if model.latent_attention_mask_mode not in {"observed", "full"}:
        raise ValueError(f"Unsupported latent_attention_mask_mode: {model.latent_attention_mask_mode}")
|
|
|
|
def build_initial_model(config: dict[str, Any]) -> ThoughtLoopT5Gemma:
    """Build the model to train, optionally starting from saved weights.

    When ``config["model"]["initial_model_path"]`` is set, the model is loaded
    from that path on CPU and the runtime settings of ``config`` are applied to
    it; otherwise a fresh ``ThoughtLoopT5Gemma`` is constructed from ``config``.
    """
    checkpoint_path = config["model"].get("initial_model_path")
    if not checkpoint_path:
        return ThoughtLoopT5Gemma(config)

    loaded_model = ThoughtLoopT5Gemma.from_pretrained(str(checkpoint_path), device="cpu", map_location="cpu")
    apply_runtime_model_config(loaded_model, config)
    return loaded_model
|
|
|
|
def rollout_active_rows_only(
    *,
    model: ThoughtLoopT5Gemma,
    batch: dict[str, torch.Tensor],
    z_state: torch.Tensor,
    state_mask: torch.Tensor,
    step_index: int,
    active_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Advance the recurrent state by one tick for the active rows only.

    The rows selected by ``active_mask`` are gathered, passed through
    ``model.rollout_step_with_mask`` together with their tick-``step_index``
    inputs from ``batch``, and scattered back into copies of the full state.
    Both "seconds since last user/agent" inputs are passed as zeros.

    Returns:
        ``(next_z, next_state_mask, gate_logits)``. Inactive rows keep their
        previous state and mask and receive a gate logit of 0.
    """
    rows = torch.nonzero(active_mask, as_tuple=False).flatten()

    def _at_step(key: str) -> torch.Tensor:
        # Per-row value of a (row, tick)-indexed batch tensor at the current tick.
        return batch[key][rows, step_index]

    elapsed = _at_step("elapsed_seconds")
    stepped_z, stepped_gate_logits, stepped_mask = model.rollout_step_with_mask(
        z_state=z_state[rows],
        state_mask=state_mask[rows],
        observation_input_ids=batch["observation_input_ids"][rows, step_index, :],
        observation_attention_mask=batch["observation_attention_mask"][rows, step_index, :],
        observation_role_ids=_at_step("observation_role"),
        delta_seconds=_at_step("delta_seconds"),
        elapsed_seconds=elapsed,
        since_last_user_seconds=torch.zeros_like(elapsed),
        since_last_agent_seconds=torch.zeros_like(elapsed),
    )

    # Write the updated rows into clones so the caller's tensors are left untouched.
    next_z = z_state.clone()
    next_z[rows] = stepped_z

    next_state_mask = state_mask.clone()
    next_state_mask[rows] = stepped_mask

    total_rows = batch["tick_mask"].shape[0]
    gate_logits = torch.zeros(total_rows, device=z_state.device, dtype=stepped_gate_logits.dtype)
    gate_logits[rows] = stepped_gate_logits

    return next_z, next_state_mask, gate_logits
|
|
|
|
def compute_gate_positive_weight(batch: dict[str, torch.Tensor], training_cfg: dict[str, Any]) -> torch.Tensor:
    """Return the loss weight applied to positive gate targets for this batch.

    With ``use_dynamic_gate_positive_weight`` enabled (the default) the weight
    is the negative/positive ratio of ``batch["gate_target"]`` over the ticks
    selected by ``batch["tick_mask"]``, clamped to
    ``[dynamic_gate_positive_weight_min, dynamic_gate_positive_weight_max]``
    (defaults 1.0 and 256.0). When dynamic weighting is disabled, or the batch
    has no positive target, the configured ``gate_positive_weight`` is used.

    Returns:
        A scalar tensor on the device of ``batch["tick_mask"]``.
    """
    device = batch["tick_mask"].device

    def _configured_weight() -> torch.Tensor:
        return torch.tensor(float(training_cfg["gate_positive_weight"]), device=device)

    if not training_cfg.get("use_dynamic_gate_positive_weight", True):
        return _configured_weight()

    targets_on_active_ticks = batch["gate_target"][batch["tick_mask"]]
    num_positive = targets_on_active_ticks.sum()
    num_negative = targets_on_active_ticks.numel() - num_positive
    if num_positive.item() <= 0.0:
        return _configured_weight()

    ratio = num_negative / num_positive.clamp_min(1.0)
    lower_bound = float(training_cfg.get("dynamic_gate_positive_weight_min", 1.0))
    upper_bound = float(training_cfg.get("dynamic_gate_positive_weight_max", 256.0))
    return ratio.clamp(min=lower_bound, max=upper_bound).to(device=device)
|
|
|
|
def compute_distance_aware_hazard_gate_loss(
    gate_logits: torch.Tensor,
    gate_targets: torch.Tensor,
    gate_loss_mask: torch.Tensor,
    training_cfg: dict[str, Any],
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Discrete-time hazard loss for the gate, with distance-weighted soft targets.

    ``sigmoid(gate_logits)`` is read as the per-tick probability of firing. For
    each row, the first ``gate_loss_mask[row].sum()`` ticks are used
    (NOTE(review): this assumes the mask is a contiguous prefix of ones --
    confirm against the batch builder). Those ticks are split into segments
    that each end at a positive target (> 0.5):

    * For an event segment, the log-probability of firing first at tick ``i``
      is the sum of log-survival over the earlier ticks of the segment plus the
      log-hazard at ``i``. The loss is the negative weighted sum of these over
      the last ``hazard_soft_target_window_ticks`` ticks of the segment, with
      weights ``exp(-distance_to_event / hazard_soft_target_tau_ticks)``
      normalised to sum to 1 (a one-hot weight on the event tick when tau is
      not positive).
    * Ticks after the last event form a censored tail. It contributes the mean
      negative log-survival, scaled by ``hazard_censor_weight``.

    Returns:
        ``(gate_loss, metrics)``. ``gate_loss`` is the accumulated loss divided
        by the accumulated weight (1 per event segment, ``hazard_censor_weight``
        per censored tail), with the divisor clamped to at least 1. ``metrics``
        holds detached ``gate_loss_sum``, ``gate_count``, ``hazard_event_count``
        and ``hazard_censor_count``.
    """
    device = gate_logits.device
    num_rows, _ = gate_logits.shape
    window_ticks = int(training_cfg.get("hazard_soft_target_window_ticks", 16))
    tau_ticks = float(training_cfg.get("hazard_soft_target_tau_ticks", 4.0))
    censor_weight = float(training_cfg.get("hazard_censor_weight", 0.25))

    log_fire = F.logsigmoid(gate_logits)
    log_hold = F.logsigmoid(-gate_logits)

    loss_total = torch.zeros((), device=device)
    weight_total = torch.zeros((), device=device)
    event_total = torch.zeros((), device=device)
    censor_total = torch.zeros((), device=device)

    for row in range(num_rows):
        row_len = int(gate_loss_mask[row].sum().item())
        if row_len <= 0:
            continue

        row_log_fire = log_fire[row, :row_len]
        row_log_hold = log_hold[row, :row_len]
        event_ticks = torch.nonzero(gate_targets[row, :row_len] > 0.5, as_tuple=False).flatten().tolist()

        cursor = 0
        for event_tick in event_ticks:
            segment_log_fire = row_log_fire[cursor : event_tick + 1]
            segment_log_hold = row_log_hold[cursor : event_tick + 1]
            cursor = event_tick + 1

            segment_len = segment_log_fire.numel()
            if segment_len <= 0:
                continue

            # log P(first fire at tick i) = sum_{j<i} log-survival_j + log-hazard_i.
            survived_before = torch.cumsum(segment_log_hold, dim=0) - segment_log_hold
            segment_log_event = survived_before + segment_log_fire

            window_start = max(0, segment_len - window_ticks) if window_ticks > 0 else 0
            window_log_event = segment_log_event[window_start:]
            window_positions = torch.arange(
                window_start,
                segment_len,
                device=device,
                dtype=torch.float32,
            )
            distance_to_event = (segment_len - 1) - window_positions

            if tau_ticks > 0.0:
                weights = torch.exp(-distance_to_event / tau_ticks)
            else:
                # Degenerate temperature: all target mass sits on the event tick.
                weights = torch.zeros_like(distance_to_event)
                weights[-1] = 1.0
            weights = weights / weights.sum().clamp_min(1e-8)

            segment_loss = -(weights * window_log_event).sum()
            loss_total = loss_total + segment_loss
            weight_total = weight_total + 1.0
            event_total = event_total + 1.0

        if cursor < row_len:
            tail_log_hold = row_log_hold[cursor:row_len]
            if tail_log_hold.numel() > 0:
                censored_loss = -tail_log_hold.mean()
                loss_total = loss_total + censor_weight * censored_loss
                weight_total = weight_total + censor_weight
                censor_total = censor_total + 1.0

    gate_loss = loss_total / weight_total.clamp_min(1.0)
    metrics = {
        "gate_loss_sum": loss_total.detach(),
        "gate_count": weight_total.detach(),
        "hazard_event_count": event_total.detach(),
        "hazard_censor_count": censor_total.detach(),
    }
    return gate_loss, metrics
|
|
|
|
def build_temporally_soft_gate_targets(
    *,
    gate_targets: torch.Tensor,
    gate_loss_mask: torch.Tensor,
    observation_roles: torch.Tensor,
    training_cfg: dict[str, Any],
) -> torch.Tensor:
    """Spread each hard positive gate target over its neighbouring ticks.

    Around every positive target (> 0.5) inside a row's active length
    (``gate_loss_mask[row].sum()`` leading ticks), a bump is written covering
    ``gate_soft_target_early_window_ticks`` ticks before and
    ``gate_soft_target_late_window_ticks`` ticks after it. The bump shape is
    ``gate_soft_target_shape``: ``"exponential"``, ``"linear"``, or Gaussian for
    any other value. It is scaled by ``gate_soft_target_peak``, clamped to
    ``[0, 1]``, and positive values are optionally lifted to
    ``gate_soft_target_floor``. Overlapping bumps are combined with an
    element-wise maximum. Soft values are only written to ticks with a positive
    loss mask and, when ``gate_soft_target_block_user_ticks`` is enabled (the
    default), whose role is not ``OBS_ROLE_USER``.

    Returns:
        A tensor shaped like ``gate_targets`` holding the element-wise maximum
        of the soft and the hard targets, clamped to ``[0, 1]``.
    """
    option = training_cfg.get
    ticks_before = max(0, int(option("gate_soft_target_early_window_ticks", 3)))
    ticks_after = max(0, int(option("gate_soft_target_late_window_ticks", 3)))
    tau_ticks = max(1e-6, float(option("gate_soft_target_tau_ticks", 1.5)))
    peak = float(option("gate_soft_target_peak", 1.0))
    floor = max(0.0, float(option("gate_soft_target_floor", 0.0)))
    shape = str(option("gate_soft_target_shape", "gaussian")).strip().lower()
    block_user_ticks = bool(option("gate_soft_target_block_user_ticks", True))

    eligible = gate_loss_mask > 0
    if block_user_ticks:
        eligible = eligible & (observation_roles != OBS_ROLE_USER)

    softened = torch.zeros_like(gate_targets)
    num_rows, _ = gate_targets.shape
    for row in range(num_rows):
        row_len = int(gate_loss_mask[row].sum().item())
        if row_len <= 0:
            continue

        event_ticks = torch.nonzero(
            gate_targets[row, :row_len] > 0.5,
            as_tuple=False,
        ).flatten().tolist()

        for event_tick in event_ticks:
            first_tick = max(0, event_tick - ticks_before)
            stop_tick = min(row_len, event_tick + ticks_after + 1)
            if stop_tick <= first_tick:
                continue

            ticks = torch.arange(first_tick, stop_tick, device=gate_targets.device)
            offsets = (ticks - event_tick).abs().to(dtype=gate_targets.dtype)

            if shape == "exponential":
                bump = torch.exp(-offsets / tau_ticks)
            elif shape == "linear":
                # Each side decays over its own window length (treated as at least 1 tick).
                side_windows = torch.where(
                    ticks <= event_tick,
                    torch.full_like(ticks, max(ticks_before, 1)),
                    torch.full_like(ticks, max(ticks_after, 1)),
                ).to(dtype=gate_targets.dtype)
                bump = 1.0 - offsets / (side_windows + 1.0)
            else:
                bump = torch.exp(-(offsets.pow(2)) / (2.0 * tau_ticks * tau_ticks))

            bump = (bump * peak).clamp(0.0, 1.0)
            if floor > 0.0:
                bump = torch.where(bump > 0.0, bump.clamp_min(floor), bump)

            writable = eligible[row, ticks]
            if not torch.any(writable):
                continue

            writable_ticks = ticks[writable]
            softened[row, writable_ticks] = torch.maximum(
                softened[row, writable_ticks],
                bump[writable],
            )

    return torch.maximum(softened, gate_targets).clamp(0.0, 1.0)
|
|
|
|
def compute_soft_bce_gate_loss(
    *,
    gate_logits: torch.Tensor,
    hard_gate_targets: torch.Tensor,
    gate_loss_mask: torch.Tensor,
    observation_roles: torch.Tensor,
    positive_weight: torch.Tensor,
    training_cfg: dict[str, Any],
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Binary cross-entropy of the gate logits against temporally softened targets.

    The hard targets are softened by ``build_temporally_soft_gate_targets``.
    Each tick's loss is weighted by ``1 + (positive_weight - 1) * soft_target``
    and by ``gate_loss_mask``. The result is the weighted sum divided by the
    mask sum (clamped to at least 1).

    Returns:
        ``(gate_loss, metrics)`` where ``metrics`` carries detached
        ``gate_loss_sum``, ``gate_count``, ``soft_target_mass`` and
        ``soft_target_positive_count`` tensors.
    """
    softened_targets = build_temporally_soft_gate_targets(
        gate_targets=hard_gate_targets,
        gate_loss_mask=gate_loss_mask,
        observation_roles=observation_roles,
        training_cfg=training_cfg,
    )

    per_tick_bce = F.binary_cross_entropy_with_logits(gate_logits, softened_targets, reduction="none")
    weight_on_device = positive_weight.to(device=gate_logits.device)
    # Interpolates between weight 1 (target 0) and `positive_weight` (target 1).
    per_tick_weight = 1.0 + (weight_on_device - 1.0) * softened_targets

    loss_total = (per_tick_bce * per_tick_weight * gate_loss_mask).sum()
    tick_total = gate_loss_mask.sum()
    gate_loss = loss_total / tick_total.clamp_min(1.0)

    counted_ticks = gate_loss_mask > 0
    metrics = {
        "gate_loss_sum": loss_total.detach(),
        "gate_count": tick_total.detach(),
        "soft_target_mass": (softened_targets * gate_loss_mask).sum().detach(),
        "soft_target_positive_count": ((softened_targets >= 0.5) & counted_ticks).float().sum().detach(),
    }
    return gate_loss, metrics
|
|
|
|
def compute_gate_ranking_loss(
    gate_logits: torch.Tensor,
    gate_targets: torch.Tensor,
    gate_loss_mask: torch.Tensor,
    ranking_margin: float,
    ranking_lookback: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pairwise hinge loss ranking positive ticks above recent negative ticks.

    Every counted tick (``gate_loss_mask > 0``) with a positive target (> 0.5)
    is paired with each counted negative tick (target < 0.5) of the same row
    that lies at most ``ranking_lookback`` ticks earlier. Each pair contributes
    ``relu(ranking_margin - positive_logit + negative_logit)``.

    The pairs are enumerated one look-back offset at a time with whole-tensor
    operations instead of per-element Python loops, so no tensor value has to
    be read back to the host.

    Args:
        gate_logits: 2-D tensor of gate logits, rows by ticks.
        gate_targets: Hard gate targets with the same shape.
        gate_loss_mask: Mask with the same shape; positive entries are counted.
        ranking_margin: Hinge margin.
        ranking_lookback: Maximum distance, in ticks, between a positive tick
            and an earlier negative tick.

    Returns:
        ``(ranking_loss_sum, ranking_pair_count)`` as scalar tensors: the summed
        hinge loss and the number of pairs. Both are zero when no pair exists.
    """
    device = gate_logits.device
    ranking_loss_sum = torch.zeros((), device=device)
    ranking_pair_count = torch.zeros((), device=device)

    _, max_steps = gate_logits.shape
    counted = gate_loss_mask > 0
    is_positive = counted & (gate_targets > 0.5)
    is_negative = counted & (gate_targets < 0.5)

    # A positive at tick t pairs with the negative at tick t - offset.
    max_offset = min(int(ranking_lookback), max_steps - 1)
    for offset in range(1, max_offset + 1):
        pair_mask = is_positive[:, offset:] & is_negative[:, :-offset]
        hinge = F.relu(ranking_margin - gate_logits[:, offset:] + gate_logits[:, :-offset])
        # torch.where (rather than multiplying by the mask) keeps non-finite
        # logits at unpaired positions out of the sum.
        paired_hinge = torch.where(pair_mask, hinge, torch.zeros_like(hinge))
        ranking_loss_sum = ranking_loss_sum + paired_hinge.sum(dtype=ranking_loss_sum.dtype)
        ranking_pair_count = ranking_pair_count + pair_mask.sum().to(ranking_pair_count.dtype)

    return ranking_loss_sum, ranking_pair_count
|
|
|
|
def compute_chunk_losses(
    model: ThoughtLoopT5Gemma,
    batch: dict[str, torch.Tensor],
    training_cfg: dict[str, Any],
    z_state: torch.Tensor,
    state_mask: torch.Tensor,
    tbptt_steps: int | None = None,
    freeze_gate_head_to_targets: bool = False,
) -> tuple[torch.Tensor, dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Roll the model over one chunk and compute gate, decoder and state-delta losses.

    The chunk is processed tick by tick along the second axis of
    ``batch["tick_mask"]`` and stops at the first tick where no row is active.
    At each tick the active rows are advanced with ``rollout_active_rows_only``,
    gate statistics are accumulated, and the decoder loss is computed for
    active rows that have at least one label different from -100.

    The gate loss depends on ``training_cfg["gate_loss_mode"]``:

    * ``"hazard"`` (default): ``compute_distance_aware_hazard_gate_loss`` over
      the whole chunk.
    * ``"soft_bce"``: ``compute_soft_bce_gate_loss`` over the whole chunk.
    * anything else: per-tick weighted binary cross-entropy.

    For non-hazard modes an optional ranking term is mixed in when
    ``use_ranked_gate_objective`` is set.

    Args:
        model: Provides ``rollout_step_with_mask`` (through
            ``rollout_active_rows_only``) and ``decoder_loss``.
        batch: Chunk tensors. Reads ``tick_mask``, ``gate_target``,
            ``observation_role`` and ``decoder_labels`` here, plus the keys
            consumed by ``rollout_active_rows_only``. ``tick_mask`` is expected
            to be boolean (it is combined with ``&`` and used for selection).
        training_cfg: Loss weights and gate-loss options.
        z_state: Recurrent state entering the chunk.
        state_mask: Mask accompanying ``z_state``.
        tbptt_steps: When not ``None``, the state is detached every
            ``tbptt_steps`` ticks to truncate back-propagation through time.
        freeze_gate_head_to_targets: When true, the predicted gate logits are
            replaced with ``+/- gate_teacher_logit_scale`` derived from the
            targets and the gate loss is forced to zero.

    Returns:
        ``(total_loss, metrics, z_state, state_mask)``: the weighted total
        loss, a dict of detached scalar accumulators, and the state and mask
        after the last processed tick.
    """
    device = batch["tick_mask"].device
    batch_size, max_steps = batch["tick_mask"].shape
    gate_loss_mode = str(training_cfg.get("gate_loss_mode", "hazard")).strip().lower()
    use_hazard_gate_loss = gate_loss_mode == "hazard"
    use_soft_bce_gate_loss = gate_loss_mode == "soft_bce"
    # Per-tick BCE is only accumulated when no chunk-level gate loss replaces it.
    use_per_step_bce = not (use_hazard_gate_loss or use_soft_bce_gate_loss or freeze_gate_head_to_targets)
    pos_weight = (
        torch.zeros((), device=device)
        if use_hazard_gate_loss
        else compute_gate_positive_weight(batch, training_cfg)
    )

    gate_loss_sum = torch.zeros((), device=device)
    gate_count = torch.zeros((), device=device)
    decoder_loss_sum = torch.zeros((), device=device)
    decoder_token_count = torch.zeros((), device=device)
    state_penalty_sum = torch.zeros((), device=device)
    state_penalty_count = torch.zeros((), device=device)

    tp = torch.zeros((), device=device)
    fp = torch.zeros((), device=device)
    fn = torch.zeros((), device=device)
    correct = torch.zeros((), device=device)
    counted = torch.zeros((), device=device)
    target_positive_count = torch.zeros((), device=device)
    predicted_positive_count = torch.zeros((), device=device)

    state_delta_weight = float(training_cfg["state_delta_weight"])
    steps_since_detach = 0
    gate_logits_history: list[torch.Tensor] = []
    gate_target_history: list[torch.Tensor] = []
    gate_loss_mask_history: list[torch.Tensor] = []
    observation_role_history: list[torch.Tensor] = []
    hazard_event_count = torch.zeros((), device=device)
    hazard_censor_count = torch.zeros((), device=device)
    soft_target_mass = torch.zeros((), device=device)
    soft_target_positive_count = torch.zeros((), device=device)
    teacher_logit_scale = float(training_cfg.get("gate_teacher_logit_scale", 20.0))

    for step_index in range(max_steps):
        active_mask = batch["tick_mask"][:, step_index]
        if not torch.any(active_mask):
            break

        previous_z = z_state
        next_z, next_state_mask, gate_logits = rollout_active_rows_only(
            model=model,
            batch=batch,
            z_state=z_state,
            state_mask=state_mask,
            step_index=step_index,
            active_mask=active_mask,
        )
        z_state = next_z
        state_mask = next_state_mask

        step_loss_mask = active_mask.float()
        step_targets = batch["gate_target"][:, step_index]
        if freeze_gate_head_to_targets:
            # Teacher forcing: substitute confident logits derived from the targets.
            gate_logits = torch.where(
                step_targets > 0.5,
                torch.full_like(step_targets, teacher_logit_scale),
                torch.full_like(step_targets, -teacher_logit_scale),
            )

        gate_logits_history.append(gate_logits)
        gate_target_history.append(step_targets)
        gate_loss_mask_history.append(step_loss_mask)
        observation_role_history.append(batch["observation_role"][:, step_index])

        # At least one row is active here (the loop breaks otherwise), so the
        # per-tick statistics below never run on an empty tick.
        valid = step_loss_mask > 0
        step_predictions = (torch.sigmoid(gate_logits) >= 0.5).float()
        correct = correct + ((step_predictions == step_targets) & valid).float().sum()
        counted = counted + valid.float().sum()
        tp = tp + ((step_predictions == 1) & (step_targets == 1) & valid).float().sum()
        fp = fp + ((step_predictions == 1) & (step_targets == 0) & valid).float().sum()
        fn = fn + ((step_predictions == 0) & (step_targets == 1) & valid).float().sum()
        target_positive_count = target_positive_count + ((step_targets == 1) & valid).float().sum()
        predicted_positive_count = predicted_positive_count + ((step_predictions == 1) & valid).float().sum()

        if use_per_step_bce:
            gate_loss_raw = F.binary_cross_entropy_with_logits(gate_logits, step_targets, reduction="none")
            gate_weight = torch.where(step_targets > 0.5, pos_weight, torch.ones_like(step_targets))
            gate_loss_sum = gate_loss_sum + (gate_loss_raw * gate_weight * step_loss_mask).sum()
            gate_count = gate_count + step_loss_mask.sum()

        # The state-change penalty applies in every gate-loss mode.
        if state_delta_weight > 0.0:
            state_delta = (next_z - previous_z).pow(2).mean(dim=(1, 2))
            state_penalty_sum = state_penalty_sum + (state_delta * step_loss_mask).sum()
            state_penalty_count = state_penalty_count + step_loss_mask.sum()

        step_labels = batch["decoder_labels"][:, step_index, :]
        decoder_rows = active_mask & torch.any(step_labels != -100, dim=-1)
        if torch.any(decoder_rows):
            step_loss_sum, step_token_count = model.decoder_loss(
                z_state[decoder_rows],
                step_labels[decoder_rows],
                encoder_attention_mask=state_mask[decoder_rows],
            )
            decoder_loss_sum = decoder_loss_sum + step_loss_sum
            decoder_token_count = decoder_token_count + step_token_count

        if tbptt_steps is not None:
            steps_since_detach += 1
            if steps_since_detach >= tbptt_steps:
                # Cut the graph so gradients do not flow past this tick.
                z_state = z_state.detach()
                state_mask = state_mask.detach()
                steps_since_detach = 0

    if gate_logits_history:
        stacked_gate_logits = torch.stack(gate_logits_history, dim=1)
        stacked_gate_targets = torch.stack(gate_target_history, dim=1)
        stacked_gate_loss_mask = torch.stack(gate_loss_mask_history, dim=1)
        stacked_observation_roles = torch.stack(observation_role_history, dim=1)
    else:
        stacked_gate_logits = torch.zeros((batch_size, 0), device=device)
        stacked_gate_targets = torch.zeros((batch_size, 0), device=device)
        stacked_gate_loss_mask = torch.zeros((batch_size, 0), device=device)
        stacked_observation_roles = torch.zeros((batch_size, 0), device=device, dtype=torch.long)

    gate_ranking_loss = torch.zeros((), device=device)
    gate_ranking_loss_sum = torch.zeros((), device=device)
    gate_ranking_pair_count = torch.zeros((), device=device)
    if freeze_gate_head_to_targets:
        gate_loss = torch.zeros((), device=device)
        gate_loss_sum = torch.zeros((), device=device)
        gate_count = torch.zeros((), device=device)
        hazard_event_count = torch.zeros((), device=device)
        hazard_censor_count = torch.zeros((), device=device)
    elif use_hazard_gate_loss and gate_logits_history:
        gate_loss, hazard_metrics = compute_distance_aware_hazard_gate_loss(
            gate_logits=stacked_gate_logits,
            gate_targets=stacked_gate_targets,
            gate_loss_mask=stacked_gate_loss_mask,
            training_cfg=training_cfg,
        )
        gate_loss_sum = hazard_metrics["gate_loss_sum"]
        gate_count = hazard_metrics["gate_count"]
        hazard_event_count = hazard_metrics["hazard_event_count"]
        hazard_censor_count = hazard_metrics["hazard_censor_count"]
    elif use_soft_bce_gate_loss and gate_logits_history:
        gate_loss, soft_bce_metrics = compute_soft_bce_gate_loss(
            gate_logits=stacked_gate_logits,
            hard_gate_targets=stacked_gate_targets,
            gate_loss_mask=stacked_gate_loss_mask,
            observation_roles=stacked_observation_roles,
            positive_weight=pos_weight,
            training_cfg=training_cfg,
        )
        gate_loss_sum = soft_bce_metrics["gate_loss_sum"]
        gate_count = soft_bce_metrics["gate_count"]
        soft_target_mass = soft_bce_metrics["soft_target_mass"]
        soft_target_positive_count = soft_bce_metrics["soft_target_positive_count"]
    else:
        gate_loss = gate_loss_sum / gate_count.clamp_min(1.0)

    if (
        not use_hazard_gate_loss
        and bool(training_cfg.get("use_ranked_gate_objective", False))
        and gate_logits_history
        and not freeze_gate_head_to_targets
    ):
        gate_ranking_loss_sum, gate_ranking_pair_count = compute_gate_ranking_loss(
            gate_logits=stacked_gate_logits,
            gate_targets=stacked_gate_targets,
            gate_loss_mask=stacked_gate_loss_mask,
            ranking_margin=float(training_cfg.get("gate_ranking_margin", 0.5)),
            ranking_lookback=int(training_cfg.get("gate_ranking_lookback", 64)),
        )
        gate_ranking_loss = gate_ranking_loss_sum / gate_ranking_pair_count.clamp_min(1.0)
        gate_loss = (
            float(training_cfg.get("gate_bce_weight", 1.0)) * gate_loss
            + float(training_cfg.get("gate_ranking_weight", 1.0)) * gate_ranking_loss
        )

    decoder_loss = decoder_loss_sum / decoder_token_count.clamp_min(1.0)
    state_delta_loss = state_penalty_sum / state_penalty_count.clamp_min(1.0)
    total_loss = (
        float(training_cfg["gate_loss_weight"]) * gate_loss
        + float(training_cfg["decoder_loss_weight"]) * decoder_loss
        + state_delta_weight * state_delta_loss
    )

    metrics = {
        "loss_total": total_loss.detach(),
        "gate_loss_sum": gate_loss_sum.detach(),
        "gate_count": gate_count.detach(),
        "decoder_loss_sum": decoder_loss_sum.detach(),
        "decoder_token_count": decoder_token_count.detach(),
        "state_penalty_sum": state_penalty_sum.detach(),
        "state_penalty_count": state_penalty_count.detach(),
        "gate_tp": tp.detach(),
        "gate_fp": fp.detach(),
        "gate_fn": fn.detach(),
        "gate_correct": correct.detach(),
        "gate_eval_count": counted.detach(),
        "gate_target_positive_count": target_positive_count.detach(),
        "gate_pred_positive_count": predicted_positive_count.detach(),
        "gate_ranking_loss_sum": gate_ranking_loss_sum.detach(),
        "gate_ranking_pair_count": gate_ranking_pair_count.detach(),
        "hazard_event_count": hazard_event_count.detach(),
        "hazard_censor_count": hazard_censor_count.detach(),
        "soft_target_mass": soft_target_mass.detach(),
        "soft_target_positive_count": soft_target_positive_count.detach(),
        "active_ticks": batch["tick_mask"].float().sum().detach(),
        "gate_positive_weight": pos_weight.detach(),
        "metrics_count": torch.ones((), device=device),
    }
    return total_loss, metrics, z_state, state_mask
|
|
|
|
def reduce_metrics(
    accelerator: Accelerator,
    metric_list: list[dict[str, torch.Tensor]],
    training_cfg: dict[str, Any],
) -> dict[str, float]:
    """Aggregate per-chunk metric accumulators into scalar summary values.

    The entries of ``metric_list`` are summed key by key, sum-reduced through
    ``accelerator.reduce``, and turned into averaged losses, gate
    accuracy/precision/recall/F1 and raw counts. The total loss is recombined
    from the averaged components with the weights in ``training_cfg``,
    including the ranking term when the gate loss mode is not ``"hazard"`` and
    ``use_ranked_gate_objective`` is set.

    Returns:
        A mapping from metric name to Python float.
    """
    summed: dict[str, torch.Tensor] = {}
    for metrics in metric_list:
        for name, value in metrics.items():
            running = summed[name] if name in summed else torch.zeros_like(value)
            summed[name] = running + value

    reduced = {name: accelerator.reduce(value, reduction="sum") for name, value in summed.items()}

    def _average(total_key: str, count_key: str) -> torch.Tensor:
        # Mean of an accumulated sum; the divisor is clamped so empty counts give 0.
        return reduced[total_key] / reduced[count_key].clamp_min(1.0)

    gate_loss_mode = str(training_cfg.get("gate_loss_mode", "hazard")).strip().lower()
    gate_loss = _average("gate_loss_sum", "gate_count")
    gate_ranking_loss = _average("gate_ranking_loss_sum", "gate_ranking_pair_count")
    decoder_loss = _average("decoder_loss_sum", "decoder_token_count")
    state_delta_loss = _average("state_penalty_sum", "state_penalty_count")

    true_positives = reduced["gate_tp"]
    gate_precision = true_positives / (true_positives + reduced["gate_fp"]).clamp_min(1.0)
    gate_recall = true_positives / (true_positives + reduced["gate_fn"]).clamp_min(1.0)
    gate_f1 = 2 * gate_precision * gate_recall / (gate_precision + gate_recall).clamp_min(1e-8)

    ranked_objective = gate_loss_mode != "hazard" and bool(training_cfg.get("use_ranked_gate_objective", False))
    if ranked_objective:
        gate_loss = (
            float(training_cfg.get("gate_bce_weight", 1.0)) * gate_loss
            + float(training_cfg.get("gate_ranking_weight", 1.0)) * gate_ranking_loss
        )

    total_loss = (
        float(training_cfg["gate_loss_weight"]) * gate_loss
        + float(training_cfg["decoder_loss_weight"]) * decoder_loss
        + float(training_cfg["state_delta_weight"]) * state_delta_loss
    )

    summary: dict[str, float] = {
        "loss_total": float(total_loss),
        "loss_gate": float(gate_loss),
        "loss_gate_ranking": float(gate_ranking_loss),
        "loss_decoder": float(decoder_loss),
        "loss_state_delta": float(state_delta_loss),
        "gate_accuracy": float(_average("gate_correct", "gate_eval_count")),
        "gate_precision": float(gate_precision),
        "gate_recall": float(gate_recall),
        "gate_f1": float(gate_f1),
    }
    for count_name in ("gate_target_positive_count", "gate_pred_positive_count"):
        summary[count_name] = float(reduced[count_name])
    summary["gate_positive_weight"] = float(_average("gate_positive_weight", "metrics_count"))
    for count_name in (
        "hazard_event_count",
        "hazard_censor_count",
        "soft_target_mass",
        "soft_target_positive_count",
        "decoder_token_count",
        "active_ticks",
    ):
        summary[count_name] = float(reduced[count_name])
    return summary
|
|
|
|
@torch.no_grad()
def evaluate(
    accelerator: Accelerator,
    model: ThoughtLoopT5Gemma,
    dataloader: DataLoader,
    training_cfg: dict[str, Any],
    rollout_cfg: dict[str, Any],
    tokenizer_pad_token_id: int,
    max_batches: int | None = None,
) -> dict[str, float]:
    """Evaluate the model on ``dataloader`` without tracking gradients.

    The model is put in eval mode and switched back to train mode before
    returning. Every batch starts from the model's initial state and is
    processed chunk by chunk with ``compute_chunk_losses`` (no TBPTT
    truncation, gate head not frozen to targets). The per-chunk metrics are
    aggregated with ``reduce_metrics``.

    Args:
        accelerator: Supplies the target device and the cross-process reduction.
        model: Model under evaluation.
        dataloader: Yields lists of conversation records.
        training_cfg: Loss configuration forwarded to the loss functions.
        rollout_cfg: Must provide ``chunk_ticks`` and ``tick_seconds``.
        tokenizer_pad_token_id: Padding id used when building chunk batches.
        max_batches: Optional cap on the number of batches evaluated.

    Returns:
        The reduced metrics, or an empty dict when no chunk was evaluated.
    """
    model.eval()
    collected_metrics: list[dict[str, torch.Tensor]] = []
    ticks_per_chunk = int(rollout_cfg["chunk_ticks"])
    seconds_per_tick = float(rollout_cfg["tick_seconds"])

    for batch_number, conversations in enumerate(dataloader):
        if max_batches is not None and batch_number >= max_batches:
            break

        device = accelerator.device
        row_count = len(conversations)
        z_state = model.initial_state(batch_size=row_count, device=device)
        state_mask = model.initial_state_mask(batch_size=row_count, device=device)

        chunk_total = max(int(conversation["chunk_count"]) for conversation in conversations)
        for chunk_index in range(chunk_total):
            host_batch = build_chunk_batch(
                conversations=conversations,
                chunk_index=chunk_index,
                pad_token_id=tokenizer_pad_token_id,
                chunk_ticks=ticks_per_chunk,
                tick_seconds=seconds_per_tick,
            )
            if host_batch is None:
                break

            _, chunk_metrics, z_state, state_mask = compute_chunk_losses(
                model=model,
                batch=move_chunk_batch_to_device(host_batch, device),
                training_cfg=training_cfg,
                z_state=z_state,
                state_mask=state_mask,
                tbptt_steps=None,
                freeze_gate_head_to_targets=False,
            )
            # Carry the state into the next chunk without its history.
            z_state = z_state.detach()
            state_mask = state_mask.detach()
            collected_metrics.append(chunk_metrics)

    summary = reduce_metrics(accelerator, collected_metrics, training_cfg) if collected_metrics else {}
    model.train()
    return summary
|
|
|
|
def copy_supporting_files(output_dir: Path, config_path: Path) -> None:
    """Copy the project scripts and the run's config file into ``output_dir``.

    The scripts are read from ``SCRIPT_DIR``; file metadata is preserved via
    ``shutil.copy2``.
    """
    script_names = ("config.py", "data.py", "model.py", "inference.py", "train.py")
    for script_name in script_names:
        source = SCRIPT_DIR / script_name
        destination = output_dir / script_name
        shutil.copy2(source, destination)

    # The config keeps its own file name inside the export directory.
    shutil.copy2(config_path, output_dir / config_path.name)
|
|
|
|
def build_model_card(config: dict[str, Any], best_metrics: dict[str, float]) -> str:
    """Render the README text for an exported model.

    The card lists selected config values followed by the best validation
    metrics, sorted by name and formatted with six decimals.
    """
    # (label, value) pairs, in the order they appear in the card.
    settings = [
        ("Base model", config["model"]["base_model_name"]),
        ("Cleaned dataset repo", config["dataset"].get("cleaned_repo_id")),
        ("Raw dataset repo", config["dataset"].get("raw_repo_id")),
        ("Latent slots", config["model"]["z_slots"]),
        ("Fixed tick seconds", config["rollout"]["tick_seconds"]),
        ("Chunk ticks", config["rollout"]["chunk_ticks"]),
        ("Max horizon ticks", config["rollout"].get("max_horizon_ticks", 36000)),
        ("Explicit time features", config["model"].get("use_explicit_time_features", False)),
        ("Gate loss mode", config["training"].get("gate_loss_mode", "hazard")),
        ("Hazard soft target window", config["training"].get("hazard_soft_target_window_ticks", 16)),
        ("Hazard soft target tau ticks", config["training"].get("hazard_soft_target_tau_ticks", 4.0)),
        ("Hazard censor weight", config["training"].get("hazard_censor_weight", 0.25)),
        ("TBPTT by bucket", config["training"].get("tbptt_steps_by_bucket", {})),
    ]

    card_lines = ["# Samantha Thought-Loop SFT", ""]
    for label, setting in settings:
        card_lines.append(f"- {label}: `{setting}`")

    card_lines += ["", "## Best Validation Metrics", ""]
    for metric_name in sorted(best_metrics):
        card_lines.append(f"- `{metric_name}`: `{best_metrics[metric_name]:.6f}`")

    # The trailing empty entry makes the joined text end with a newline.
    card_lines.append("")
    return "\n".join(card_lines)
|
|
|
|
def export_model(
    accelerator: Accelerator,
    model: ThoughtLoopT5Gemma,
    output_dir: Path,
    config: dict[str, Any],
    config_path: Path,
    best_metrics: dict[str, float],
) -> None:
    """Write the model, supporting files, README and metrics to ``output_dir``.

    Only the main process writes anything; every other process returns
    immediately.
    """
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save the underlying module rather than the Accelerate wrapper.
        accelerator.unwrap_model(model).save_pretrained(output_dir)
        copy_supporting_files(output_dir, config_path)

        readme_path = output_dir / "README.md"
        readme_path.write_text(build_model_card(config, best_metrics), encoding="utf-8")
        save_json(output_dir / "best_metrics.json", best_metrics)
|
|
|
|
def maybe_push_to_hub(config: dict[str, Any], export_dir: Path) -> None:
    """Upload ``export_dir`` to the configured Hub model repo, if one is set.

    Does nothing when ``config["hub"]["model_repo_id"]`` is missing or empty.
    The repo is created first (no error if it already exists).
    """
    hub_cfg = config["hub"]
    repo_id = hub_cfg.get("model_repo_id")
    if repo_id:
        client = HfApi()
        make_private = bool(hub_cfg.get("private"))
        client.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True, private=make_private)
        client.upload_folder(folder_path=str(export_dir), repo_id=repo_id, repo_type="model")
|
|
|
|
def maybe_push_checkpoint_to_hub(
    accelerator: Accelerator,
    config: dict[str, Any],
    checkpoint_dir: Path,
) -> None:
    """Upload one checkpoint directory to the configured Hub model repo.

    Skipped on non-main processes, when no ``model_repo_id`` is configured, or
    when ``hub.push_checkpoints`` is falsy (it defaults to enabled). The
    checkpoint lands under
    ``<checkpoint_path_prefix>/<script dir name>/<run name>/<checkpoint dir name>``.
    """
    hub_cfg = config["hub"]
    repo_id = hub_cfg.get("model_repo_id")
    if not (accelerator.is_main_process and repo_id):
        return
    if not bool(hub_cfg.get("push_checkpoints", True)):
        return

    run_name = str(config["wandb"].get("run_name", "run"))
    prefix = str(hub_cfg.get("checkpoint_path_prefix", "checkpoints")).strip("/")
    # Empty components are dropped so the repo path never contains "//".
    segments = (prefix, SCRIPT_DIR.name, run_name, checkpoint_dir.name)
    repo_path = "/".join(filter(None, segments))

    client = HfApi()
    make_private = bool(hub_cfg.get("private"))
    client.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True, private=make_private)
    client.upload_folder(
        folder_path=str(checkpoint_dir),
        repo_id=repo_id,
        repo_type="model",
        path_in_repo=repo_path,
        commit_message=f"Add {SCRIPT_DIR.name} {checkpoint_dir.name}",
    )
|
|
|
|
def read_json(path: Path) -> dict[str, Any]:
    """Load a JSON object from ``path``.

    Returns:
        The decoded object, or an empty dict when ``path`` does not exist.

    Raises:
        ValueError: If the file's top-level JSON value is not an object.
    """
    if path.exists():
        with path.open("r", encoding="utf-8") as stream:
            decoded = json.load(stream)
        if isinstance(decoded, dict):
            return decoded
        raise ValueError(f"Expected a JSON object in {path}.")
    return {}
|
|
|
|
def checkpoint_global_step(checkpoint_dir: Path) -> int | None:
    """Return the global step recorded for a checkpoint directory.

    The value stored in ``trainer_state.json`` wins; otherwise the step is
    parsed from a ``checkpoint-<digits>`` directory name. Returns ``None`` when
    neither source yields a step.
    """
    trainer_state = read_json(checkpoint_dir / "trainer_state.json")
    if "global_step" in trainer_state:
        return int(trainer_state["global_step"])

    marker = "checkpoint-"
    dir_name = checkpoint_dir.name
    # An empty string is not all-digits, so unrelated names fall through to None.
    digits = dir_name[len(marker) :] if dir_name.startswith(marker) else ""
    return int(digits) if digits.isdigit() else None
|
|
|
|
def find_latest_checkpoint(run_dir: Path) -> Path | None:
    """Return the ``checkpoint-*`` directory under ``run_dir`` with the highest step.

    Entries that are not directories, or whose step cannot be determined, are
    ignored. Returns ``None`` when nothing qualifies.
    """
    newest_step: int | None = None
    newest_dir: Path | None = None
    for entry in run_dir.glob("checkpoint-*"):
        if not entry.is_dir():
            continue
        entry_step = checkpoint_global_step(entry)
        if entry_step is None:
            continue
        # Strict comparison keeps the first directory seen when steps tie.
        if newest_step is None or entry_step > newest_step:
            newest_step = entry_step
            newest_dir = entry
    return newest_dir
|
|
|
|
def resolve_resume_checkpoint(requested: str | None, run_dir: Path) -> Path | None:
    """Turn a resume request into a concrete checkpoint directory.

    Args:
        requested: ``None``/blank for no resume, one of ``1``, ``true``, ``yes``,
            ``auto`` or ``latest`` (case-insensitive) for the newest checkpoint in
            ``run_dir``, or a path. Relative paths are taken relative to
            ``SCRIPT_DIR``.
        run_dir: Directory searched for ``checkpoint-*`` folders.

    Returns:
        The checkpoint directory, or ``None`` when no resume was requested.

    Raises:
        FileNotFoundError: If the requested checkpoint cannot be found.
    """
    text = "" if requested is None else str(requested).strip()
    if not text:
        return None

    latest_aliases = {"1", "true", "yes", "auto", "latest"}
    if text.lower() in latest_aliases:
        newest = find_latest_checkpoint(run_dir)
        if newest is not None:
            return newest
        raise FileNotFoundError(f"No checkpoint-* directory found under {run_dir}.")

    location = Path(text).expanduser()
    if not location.is_absolute():
        location = SCRIPT_DIR / location
    location = location.resolve()
    if location.is_dir():
        return location
    raise FileNotFoundError(f"Resume checkpoint does not exist: {location}")
|
|
|
|
def parse_wandb_run_id_from_dir(path: Path) -> str | None:
    """Extract the run id from a W&B run directory name.

    Recognised names look like ``run-<timestamp>-<id>`` or
    ``offline-run-<timestamp>-<id>``. A symlink (such as ``latest-run``) is
    resolved first so the name of its target is parsed.

    Args:
        path: Candidate run directory or a symlink to one.

    Returns:
        The run id, or ``None`` when the name does not have the expected shape
        or carries no id segment.
    """
    name = path.resolve(strict=False).name if path.is_symlink() else path.name
    for prefix in ("offline-run-", "run-"):
        if not name.startswith(prefix):
            continue
        # Split only once after the timestamp: run ids may contain hyphens, and a
        # name without an id segment must not yield the timestamp as the id.
        _timestamp, separator, run_id = name[len(prefix) :].partition("-")
        return run_id if separator and run_id else None
    return None
|
|
|
|
def discover_local_wandb_run_id(run_name: str) -> str | None:
    """Look for an earlier local W&B run whose files mention ``run_name``.

    Searches ``SCRIPT_DIR / "wandb"``: the ``latest-run`` entry first, then every
    ``*run-*`` entry from newest to oldest modification time. A candidate matches
    when ``run_name`` occurs in its ``files/config.yaml`` or, failing that, in
    its ``logs/debug.log``. The match is a plain substring test, so it is a
    heuristic rather than an exact lookup.

    Args:
        run_name: Run name to search for. Empty or whitespace-only names never
            match, because they would be a substring of any file.

    Returns:
        The id of the first matching run, or ``None`` when nothing matches.
    """
    if not run_name or not run_name.strip():
        return None

    wandb_dir = SCRIPT_DIR / "wandb"
    if not wandb_dir.exists():
        return None

    def modified_time(entry: Path) -> float:
        # Entries that vanish or dangle sort last instead of aborting discovery.
        try:
            return entry.stat().st_mtime
        except OSError:
            return float("-inf")

    candidates: list[Path] = []
    latest = wandb_dir / "latest-run"
    if latest.exists() or latest.is_symlink():
        candidates.append(latest)
    candidates.extend(sorted(wandb_dir.glob("*run-*"), key=modified_time, reverse=True))

    for candidate in candidates:
        run_id = parse_wandb_run_id_from_dir(candidate)
        if not run_id:
            continue
        resolved = candidate.resolve(strict=False)
        evidence_files = (resolved / "files" / "config.yaml", resolved / "logs" / "debug.log")
        for evidence in evidence_files:
            if evidence.exists() and run_name in evidence.read_text(encoding="utf-8", errors="ignore"):
                return run_id
    return None
|
|
|
|
def resolve_wandb_run_id(args: argparse.Namespace, config: dict[str, Any], run_dir: Path) -> str:
    """Pick the W&B run id for this run and persist it in ``run_dir``.

    Sources, in priority order: the ``--wandb-run-id`` argument, the
    ``WANDB_RUN_ID`` environment variable, ``config["wandb"]["run_id"]``, the
    stored ``wandb_run_id.txt``, a locally discovered run with the same run name,
    and finally a fresh random id. The chosen id is written back to
    ``wandb_run_id.txt``.
    """
    id_file = run_dir / "wandb_run_id.txt"
    requested = args.wandb_run_id or os.environ.get("WANDB_RUN_ID") or config["wandb"].get("run_id")

    if requested:
        chosen = str(requested).strip()
    elif id_file.exists():
        chosen = id_file.read_text(encoding="utf-8").strip()
    else:
        chosen = discover_local_wandb_run_id(str(config["wandb"].get("run_name", "")))

    # Covers a blank explicit id, an empty id file and a failed discovery alike.
    chosen = chosen or secrets.token_hex(4)
    id_file.write_text(f"{chosen}\n", encoding="utf-8")
    return chosen
|
|
|
|
def save_trainer_state(
    checkpoint_dir: Path,
    *,
    global_step: int,
    epoch: int,
    next_batch_index: int,
    best_validation_loss: float,
    best_metrics: dict[str, float],
    wandb_run_id: str | None,
) -> None:
    """Write ``trainer_state.json`` with the resume bookkeeping for a checkpoint.

    A non-finite ``best_validation_loss`` (such as the initial ``inf``) is
    stored as ``None``.
    """
    trainer_state: dict[str, Any] = {}
    trainer_state["global_step"] = int(global_step)
    trainer_state["epoch"] = int(epoch)
    trainer_state["next_batch_index"] = int(next_batch_index)
    if math.isfinite(best_validation_loss):
        trainer_state["best_validation_loss"] = best_validation_loss
    else:
        trainer_state["best_validation_loss"] = None
    trainer_state["best_metrics"] = best_metrics
    trainer_state["wandb_run_id"] = wandb_run_id

    save_json(checkpoint_dir / "trainer_state.json", trainer_state)
|
|
|
|
def load_resume_progress(
    checkpoint_dir: Path,
    *,
    total_chunk_microsteps: int,
    gradient_accumulation_steps: int,
    num_train_epochs: int,
) -> tuple[int, int, int, dict[str, Any]]:
    """Work out where training should continue after loading a checkpoint.

    When ``trainer_state.json`` records both ``epoch`` and ``next_batch_index``
    they are used directly; otherwise the position is estimated from the global
    step and the number of optimizer updates per epoch.

    Returns:
        ``(global_step, start_epoch, batches_to_skip, state)`` where ``state`` is
        the raw trainer-state dict (empty when the file is missing).
    """
    state = read_json(checkpoint_dir / "trainer_state.json")
    step = int(state.get("global_step") or checkpoint_global_step(checkpoint_dir) or 0)

    has_exact_position = "epoch" in state and "next_batch_index" in state
    if has_exact_position:
        epoch_index = int(state["epoch"])
        skip_count = int(state["next_batch_index"])
    else:
        accumulation = max(gradient_accumulation_steps, 1)
        updates_per_epoch = max(1, math.ceil(total_chunk_microsteps / accumulation))
        last_epoch_index = max(num_train_epochs - 1, 0)
        epoch_index = min(step // updates_per_epoch, last_epoch_index)
        updates_done_in_epoch = step - epoch_index * updates_per_epoch
        skip_count = updates_done_in_epoch * accumulation

    # Roll a skip count that covers whole epochs over into later epochs.
    while skip_count >= total_chunk_microsteps and epoch_index < num_train_epochs:
        skip_count -= total_chunk_microsteps
        epoch_index += 1

    return step, epoch_index, skip_count, state
|
|
|
|
def load_best_metrics(export_dir: Path, resume_state: dict[str, Any]) -> tuple[float, dict[str, float]]:
    """Recover the best validation loss and metrics known so far.

    Metrics come from ``resume_state["best_metrics"]`` when that is a dict and
    from ``export_dir / "best_metrics.json"`` otherwise. The loss comes from
    ``resume_state["best_validation_loss"]``, then the metrics' ``loss_total``,
    and defaults to ``inf``.

    Returns:
        ``(best_validation_loss, best_metrics)`` with metric names as ``str`` and
        values as ``float``.
    """
    metrics = resume_state.get("best_metrics")
    if not isinstance(metrics, dict):
        metrics = read_json(export_dir / "best_metrics.json")

    recorded_loss = resume_state.get("best_validation_loss")
    if recorded_loss is None and isinstance(metrics, dict):
        recorded_loss = metrics.get("loss_total")

    if recorded_loss is None:
        best_loss = float("inf")
    else:
        best_loss = float(recorded_loss)

    normalized: dict[str, float] = {}
    for name, number in metrics.items():
        normalized[str(name)] = float(number)
    return best_loss, normalized
|
|
|
|
def main() -> None:
    """Train the thought-loop model end to end.

    Stages: load the config and seed the RNGs, create the Accelerator and the
    run/export directories, build the model, datasets, loaders, optimizer and
    scheduler, optionally restore a checkpoint, then run the chunked training
    loop with periodic logging, validation, best-model export and checkpointing.
    After training, a training summary is written and the export is pushed to the
    Hub when a model repo is configured.
    """
    args = parse_args()
    config_path = Path(args.config)
    config = load_config(config_path)
    training_cfg = config["training"]
    rollout_cfg = config["rollout"]

    set_seed(int(training_cfg["seed"]))
    wandb_enabled = bool(config["wandb"].get("enabled", True) and wandb is not None)

    accelerator = Accelerator(
        gradient_accumulation_steps=int(training_cfg["gradient_accumulation_steps"]),
        mixed_precision=training_cfg.get("mixed_precision", "bf16"),
        log_with="wandb" if wandb_enabled else None,
        project_dir=config["paths"]["run_root"],
    )

    run_root = ensure_dir(config["paths"]["run_root"])
    export_root = ensure_dir(config["paths"]["export_root"])
    run_name = config["wandb"]["run_name"]
    run_dir = ensure_dir(run_root / run_name)
    export_dir = ensure_dir(export_root / run_name)
    resume_request = args.resume_from_checkpoint or training_cfg.get("resume_from_checkpoint")
    resume_checkpoint = resolve_resume_checkpoint(resume_request, run_dir)
    # Only the main process talks to W&B and stores the id. Resolving it on every
    # rank would let each rank write its own random id into wandb_run_id.txt, so
    # the stored id could differ from the one the tracker actually uses.
    wandb_run_id = (
        resolve_wandb_run_id(args, config, run_dir)
        if wandb_enabled and accelerator.is_main_process
        else None
    )

    model = build_initial_model(config)
    if model.use_explicit_time_features:
        raise ValueError("This trainer expects model.use_explicit_time_features=false for fixed-cadence training.")
    tokenizer = model.tokenizer

    train_dataset = ThoughtLoopConversationDataset(config=config, tokenizer=tokenizer, split="train")
    validation_dataset = ThoughtLoopConversationDataset(config=config, tokenizer=tokenizer, split="validation")
    sort_dataset_by_length = bool(training_cfg.get("sort_dataset_by_length", False))

    ensure_single_chunk_dataset(train_dataset, "train")
    ensure_single_chunk_dataset(validation_dataset, "validation")
    maybe_sort_dataset_by_length(train_dataset, enabled=sort_dataset_by_length)
    maybe_sort_dataset_by_length(validation_dataset, enabled=sort_dataset_by_length)

    train_loader = DataLoader(
        train_dataset,
        batch_size=int(training_cfg["micro_batch_size"]),
        # Shuffling would undo a length-sorted order, so sorting takes precedence.
        shuffle=bool(training_cfg.get("shuffle_train", True)) and not sort_dataset_by_length,
        num_workers=int(training_cfg["num_workers"]),
        pin_memory=True,
        persistent_workers=bool(int(training_cfg["num_workers"]) > 0),
        collate_fn=identity_collate,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=int(training_cfg["eval_batch_size"]),
        shuffle=False,
        num_workers=int(training_cfg["num_workers"]),
        pin_memory=True,
        persistent_workers=bool(int(training_cfg["num_workers"]) > 0),
        collate_fn=identity_collate,
    )

    optimizer = build_optimizer(model, training_cfg)
    total_chunk_microsteps = len(train_loader)
    total_update_steps = max(
        1,
        math.ceil(total_chunk_microsteps / max(int(training_cfg["gradient_accumulation_steps"]), 1))
        * int(training_cfg["num_train_epochs"]),
    )
    warmup_steps = max(1, int(total_update_steps * float(training_cfg["warmup_ratio"])))
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max(1, total_update_steps),
    )

    model, optimizer, train_loader, validation_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, validation_loader, scheduler
    )

    resume_state: dict[str, Any] = {}
    global_step = 0
    start_epoch = 0
    batches_to_skip = 0
    if resume_checkpoint is not None:
        accelerator.print(f"Loading local checkpoint from {resume_checkpoint}")
        accelerator.load_state(str(resume_checkpoint))
        global_step, start_epoch, batches_to_skip, resume_state = load_resume_progress(
            resume_checkpoint,
            total_chunk_microsteps=total_chunk_microsteps,
            gradient_accumulation_steps=int(training_cfg["gradient_accumulation_steps"]),
            num_train_epochs=int(training_cfg["num_train_epochs"]),
        )
        accelerator.print(
            f"Resuming at global_step={global_step}, epoch={start_epoch + 1}, "
            f"skipping {batches_to_skip} batch(es)."
        )

    if accelerator.is_main_process and wandb_enabled:
        accelerator.print(f"W&B run id: {wandb_run_id}")
        accelerator.init_trackers(
            project_name=config["wandb"]["project"],
            config=flatten_for_wandb(config),
            init_kwargs={
                "wandb": {
                    "name": run_name,
                    "id": wandb_run_id,
                    "resume": str(config["wandb"].get("resume", "allow")),
                }
            },
        )

    # Custom model methods are called on the unwrapped module; the prepared
    # wrapper is only used for the forward/backward path.
    unwrapped_model = accelerator.unwrap_model(model)
    current_decoder_fraction = resolve_decoder_trainable_fraction(training_cfg, global_step, total_update_steps)
    if current_decoder_fraction is None:
        current_decoder_fraction = 1.0
    current_decoder_trainable_layers = unwrapped_model.set_decoder_trainable_fraction(current_decoder_fraction)

    current_gate_head_frozen = should_freeze_gate_head(training_cfg, global_step, total_update_steps)
    unwrapped_model.set_gate_trainable(not current_gate_head_frozen)

    progress_bar = tqdm(
        total=total_update_steps,
        initial=global_step,
        disable=not accelerator.is_local_main_process,
        desc="Training",
        dynamic_ncols=True,
    )

    best_validation_loss, best_metrics = load_best_metrics(export_dir, resume_state)
    model.train()

    chunk_ticks = int(rollout_cfg["chunk_ticks"])
    tick_seconds = float(rollout_cfg["tick_seconds"])

    for epoch in range(start_epoch, int(training_cfg["num_train_epochs"])):
        epoch_train_loader = train_loader
        batch_index_offset = 0
        # Only the first resumed epoch skips the batches that were already consumed.
        if epoch == start_epoch and batches_to_skip > 0:
            epoch_train_loader = accelerator.skip_first_batches(train_loader, batches_to_skip)
            batch_index_offset = batches_to_skip

        for batch_index, conversations in enumerate(epoch_train_loader, start=batch_index_offset):
            batch_bucket = resolve_batch_duration_bucket(conversations)
            current_tbptt_steps = resolve_bucket_tbptt_steps(training_cfg, batch_bucket)
            current_horizon_ticks = resolve_bucket_horizon_ticks(batch_bucket, rollout_cfg)

            # Each batch of conversations starts from a fresh recurrent state.
            z_state = unwrapped_model.initial_state(batch_size=len(conversations), device=accelerator.device)
            state_mask = unwrapped_model.initial_state_mask(
                batch_size=len(conversations), device=accelerator.device
            )
            max_chunks = max(int(conversation["chunk_count"]) for conversation in conversations)

            for chunk_index in range(max_chunks):
                # Re-apply the decoder/gate schedules whenever they change with global_step.
                scheduled_decoder_fraction = resolve_decoder_trainable_fraction(
                    training_cfg, global_step, total_update_steps
                )
                if scheduled_decoder_fraction is None:
                    scheduled_decoder_fraction = 1.0
                if abs(scheduled_decoder_fraction - current_decoder_fraction) > 1e-6:
                    current_decoder_fraction = scheduled_decoder_fraction
                    current_decoder_trainable_layers = unwrapped_model.set_decoder_trainable_fraction(
                        current_decoder_fraction
                    )

                scheduled_gate_head_frozen = should_freeze_gate_head(training_cfg, global_step, total_update_steps)
                if scheduled_gate_head_frozen != current_gate_head_frozen:
                    current_gate_head_frozen = scheduled_gate_head_frozen
                    unwrapped_model.set_gate_trainable(not current_gate_head_frozen)

                chunk_batch = build_chunk_batch(
                    conversations=conversations,
                    chunk_index=chunk_index,
                    pad_token_id=tokenizer.pad_token_id,
                    chunk_ticks=chunk_ticks,
                    tick_seconds=tick_seconds,
                )
                if chunk_batch is None:
                    break
                chunk_batch = move_chunk_batch_to_device(chunk_batch, accelerator.device)

                with accelerator.accumulate(model):
                    total_loss, metrics, z_state, state_mask = compute_chunk_losses(
                        model=model,
                        batch=chunk_batch,
                        training_cfg=training_cfg,
                        z_state=z_state,
                        state_mask=state_mask,
                        tbptt_steps=current_tbptt_steps,
                        freeze_gate_head_to_targets=current_gate_head_frozen,
                    )
                    accelerator.backward(total_loss)
                    # Clip only when the accumulated gradients are complete; clipping
                    # partial sums on every micro-step would distort the update.
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), float(training_cfg["max_grad_norm"]))
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                # Carry the state values into the next chunk without their graph.
                z_state = z_state.detach()
                state_mask = state_mask.detach()

                if accelerator.sync_gradients:
                    global_step += 1
                    progress_bar.update(1)
                    progress_bar.set_postfix(
                        loss=f"{float(total_loss.detach()):.4f}",
                        bucket=batch_bucket,
                        tbptt=int(current_tbptt_steps or 0),
                        horizon=int(current_horizon_ticks),
                        decoder=current_decoder_trainable_layers,
                        gate="locked" if current_gate_head_frozen else "train",
                        refresh=False,
                    )

                    if global_step % int(training_cfg["logging_steps"]) == 0:
                        reduced = reduce_metrics(accelerator, [metrics], training_cfg)
                        reduced["lr"] = scheduler.get_last_lr()[0]
                        reduced["tbptt_steps"] = float(current_tbptt_steps or 0)
                        reduced["horizon_ticks"] = float(current_horizon_ticks)
                        reduced["chunk_ticks"] = float(chunk_ticks)
                        reduced["decoder_trainable_layers"] = float(current_decoder_trainable_layers)
                        reduced["gate_head_frozen"] = float(current_gate_head_frozen)
                        reduced["epoch"] = epoch + 1
                        reduced["global_step"] = global_step
                        accelerator.log({f"train/{key}": value for key, value in reduced.items()}, step=global_step)

                    if global_step % int(training_cfg["eval_steps"]) == 0:
                        validation_metrics = evaluate(
                            accelerator=accelerator,
                            model=model,
                            dataloader=validation_loader,
                            training_cfg=training_cfg,
                            rollout_cfg=rollout_cfg,
                            tokenizer_pad_token_id=tokenizer.pad_token_id,
                            max_batches=training_cfg.get("eval_max_batches"),
                        )
                        accelerator.log(
                            {f"validation/{key}": value for key, value in validation_metrics.items()},
                            step=global_step,
                        )

                        # Export whenever the validation loss improves on the best so far.
                        if validation_metrics and validation_metrics["loss_total"] < best_validation_loss:
                            best_validation_loss = validation_metrics["loss_total"]
                            best_metrics = validation_metrics
                            export_model(
                                accelerator=accelerator,
                                model=model,
                                output_dir=export_dir,
                                config=config,
                                config_path=config_path,
                                best_metrics=best_metrics,
                            )

                    if global_step % int(training_cfg["checkpoint_steps"]) == 0:
                        checkpoint_dir = ensure_dir(run_dir / f"checkpoint-{global_step}")
                        accelerator.save_state(str(checkpoint_dir))
                        if accelerator.is_main_process:
                            save_trainer_state(
                                checkpoint_dir,
                                global_step=global_step,
                                epoch=epoch,
                                next_batch_index=batch_index + 1,
                                best_validation_loss=best_validation_loss,
                                best_metrics=best_metrics,
                                wandb_run_id=wandb_run_id,
                            )
                        # Let every rank finish writing before the upload starts.
                        accelerator.wait_for_everyone()
                        maybe_push_checkpoint_to_hub(
                            accelerator=accelerator,
                            config=config,
                            checkpoint_dir=checkpoint_dir,
                        )
                        accelerator.wait_for_everyone()

    # No best model was recorded (no successful validation and nothing restored):
    # evaluate once and export the final weights.
    if not best_metrics:
        best_metrics = evaluate(
            accelerator=accelerator,
            model=model,
            dataloader=validation_loader,
            training_cfg=training_cfg,
            rollout_cfg=rollout_cfg,
            tokenizer_pad_token_id=tokenizer.pad_token_id,
            max_batches=training_cfg.get("eval_max_batches"),
        )
        export_model(
            accelerator=accelerator,
            model=model,
            output_dir=export_dir,
            config=config,
            config_path=config_path,
            best_metrics=best_metrics,
        )

    if accelerator.is_main_process:
        training_summary = {
            "train_conversations": len(train_dataset),
            "validation_conversations": len(validation_dataset),
            "train_chunk_microsteps": total_chunk_microsteps,
            "trainable_parameters": accelerator.unwrap_model(model).trainable_parameter_count(),
            "best_metrics": best_metrics,
            "global_step": global_step,
            "resumed_from_checkpoint": str(resume_checkpoint) if resume_checkpoint is not None else None,
            "wandb_run_id": wandb_run_id,
            "decoder_layer_count": unwrapped_model.decoder_layer_count(),
            "tick_seconds": tick_seconds,
            "chunk_ticks": chunk_ticks,
            "tbptt_steps_by_bucket": training_cfg.get("tbptt_steps_by_bucket", {}),
            "horizon_ticks_by_bucket": rollout_cfg.get("horizon_ticks_by_bucket", {}),
            "shuffle_train": bool(training_cfg.get("shuffle_train", True)),
            "sort_dataset_by_length": sort_dataset_by_length,
        }
        save_json(run_dir / "training_summary.json", training_summary)
        maybe_push_to_hub(config, export_dir)

    progress_bar.close()
    accelerator.end_training()
|
|
|
|
# Run the trainer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|