from __future__ import annotations import argparse import json import math import os import random import secrets import shutil import sys from pathlib import Path from typing import Any import numpy as np import torch import torch.nn.functional as F from accelerate import Accelerator from huggingface_hub import HfApi from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import get_cosine_schedule_with_warmup SCRIPT_DIR = Path(__file__).resolve().parent if str(SCRIPT_DIR) not in sys.path: sys.path.insert(0, str(SCRIPT_DIR)) from config import ensure_dir, flatten_for_wandb, load_config, save_json from data import ( ThoughtLoopConversationDataset, build_chunk_batch, identity_collate, resolve_bucket_horizon_ticks, ) from model import ThoughtLoopT5Gemma try: import wandb except ImportError: # pragma: no cover wandb = None def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train the recurrent T5Gemma thought-loop model.") parser.add_argument("--config", default="configs/sft_b200.yaml") parser.add_argument( "--resume-from-checkpoint", default=None, help="Local checkpoint path to resume from, or 'latest'/'auto' to use the newest checkpoint in the run dir.", ) parser.add_argument( "--wandb-run-id", default=None, help="Optional W&B run id to resume. If omitted, train.py reuses a stored/local run id when available.", ) return parser.parse_args() def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def build_optimizer(model: torch.nn.Module, training_cfg: dict[str, Any]) -> torch.optim.Optimizer: no_decay_terms = ("bias", "norm", "Norm") decay_params = [] no_decay_params = [] for name, parameter in model.named_parameters(): target = no_decay_params if any(term in name for term in no_decay_terms) else decay_params target.append(parameter) fused = bool(training_cfg.get("fused_adamw", True) and torch.cuda.is_available()) return torch.optim.AdamW( [ {"params": decay_params, "weight_decay": float(training_cfg["weight_decay"])}, {"params": no_decay_params, "weight_decay": 0.0}, ], lr=float(training_cfg["learning_rate"]), betas=(float(training_cfg["adam_beta1"]), float(training_cfg["adam_beta2"])), eps=float(training_cfg["adam_epsilon"]), fused=fused, ) def resolve_bucket_tbptt_steps(training_cfg: dict[str, Any], duration_bucket: str) -> int | None: bucket_map = training_cfg.get("tbptt_steps_by_bucket", {}) if duration_bucket in bucket_map: return int(bucket_map[duration_bucket]) start_steps = int(training_cfg.get("tbptt_start_steps", 0)) end_steps = int(training_cfg.get("tbptt_end_steps", start_steps)) if start_steps <= 0 or end_steps <= 0: return None return end_steps def resolve_linear_schedule_value( *, enabled: bool, global_step: int, total_update_steps: int, start_fraction: float, end_fraction: float, start_value: float, end_value: float, ) -> float | None: if not enabled: return None clamped_total_steps = max(total_update_steps, 1) start_step = int(round(clamped_total_steps * start_fraction)) end_step = int(round(clamped_total_steps * end_fraction)) if end_step <= start_step: end_step = start_step + 1 if global_step <= start_step: return start_value if global_step >= end_step: return end_value progress = (global_step - start_step) / max(end_step - start_step, 1) return start_value + progress * (end_value - start_value) def should_freeze_gate_head( training_cfg: dict[str, Any], global_step: int, total_update_steps: int, ) -> bool: if not bool(training_cfg.get("freeze_gate_head_first_half", False)): return False freeze_fraction = float(training_cfg.get("gate_head_freeze_fraction", 0.5)) return global_step < int(round(max(total_update_steps, 1) * freeze_fraction)) def resolve_decoder_trainable_fraction( training_cfg: dict[str, Any], global_step: int, total_update_steps: int, ) -> float | None: return resolve_linear_schedule_value( enabled=bool(training_cfg.get("gradual_unfreeze_decoder", False)), global_step=global_step, total_update_steps=total_update_steps, start_fraction=float(training_cfg.get("decoder_unfreeze_start_fraction", 0.0)), end_fraction=float(training_cfg.get("decoder_unfreeze_end_fraction", 0.5)), start_value=float(training_cfg.get("decoder_initial_trainable_fraction", 0.0)), end_value=1.0, ) def resolve_batch_duration_bucket(conversations: list[dict[str, Any]]) -> str: if not conversations: return "short" buckets = {str(conversation["duration_bucket"]) for conversation in conversations} if len(buckets) == 1: return next(iter(buckets)) return "mixed" def ensure_single_chunk_dataset(dataset: ThoughtLoopConversationDataset, split: str) -> None: offending_examples = [ example for example in dataset.examples if int(example.get("chunk_count", 0)) != 1 ] if offending_examples: raise ValueError( f"{split} dataset expected exactly one chunk per conversation, but found " f"{len(offending_examples)} multi-chunk examples. Check chunk_ticks/max_horizon_ticks." ) def maybe_sort_dataset_by_length( dataset: ThoughtLoopConversationDataset, *, enabled: bool, ) -> None: if not enabled: return dataset.examples.sort( key=lambda item: ( int(item.get("effective_total_ticks", item.get("total_ticks", 0))), int(item.get("total_ticks", 0)), str(item.get("row_id", "")), ) ) def move_chunk_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]: return {key: value.to(device) for key, value in batch.items()} def apply_runtime_model_config(model: ThoughtLoopT5Gemma, config: dict[str, Any]) -> None: model.config = config model_cfg = config["model"] expected_z_slots = int(model_cfg["z_slots"]) if int(model.z_slots) != expected_z_slots: raise ValueError( f"Loaded model has z_slots={model.z_slots}, but config requested z_slots={expected_z_slots}." ) model.thought_loop_proposal_mode = str( model_cfg.get("thought_loop_proposal_mode", model.thought_loop_proposal_mode) ).strip().lower() model.preserve_observation_encoder_manifold = bool( model_cfg.get( "preserve_observation_encoder_manifold", model.thought_loop_proposal_mode == "observation_hidden_compression", ) ) model.observation_encoder_use_state_context = bool( model_cfg.get("observation_encoder_use_state_context", False) ) model.latent_attention_mask_mode = str( model_cfg.get("latent_attention_mask_mode", getattr(model, "latent_attention_mask_mode", "observed")) ).strip().lower() if model.latent_attention_mask_mode not in {"observed", "full"}: raise ValueError(f"Unsupported latent_attention_mask_mode: {model.latent_attention_mask_mode}") def build_initial_model(config: dict[str, Any]) -> ThoughtLoopT5Gemma: initial_model_path = config["model"].get("initial_model_path") if initial_model_path: model = ThoughtLoopT5Gemma.from_pretrained(str(initial_model_path), device="cpu", map_location="cpu") apply_runtime_model_config(model, config) return model return ThoughtLoopT5Gemma(config) def rollout_active_rows_only( *, model: ThoughtLoopT5Gemma, batch: dict[str, torch.Tensor], z_state: torch.Tensor, state_mask: torch.Tensor, step_index: int, active_mask: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: active_indices = torch.nonzero(active_mask, as_tuple=False).flatten() active_next_z, active_gate_logits, active_next_mask = model.rollout_step_with_mask( z_state=z_state[active_indices], state_mask=state_mask[active_indices], observation_input_ids=batch["observation_input_ids"][active_indices, step_index, :], observation_attention_mask=batch["observation_attention_mask"][active_indices, step_index, :], observation_role_ids=batch["observation_role"][active_indices, step_index], delta_seconds=batch["delta_seconds"][active_indices, step_index], elapsed_seconds=batch["elapsed_seconds"][active_indices, step_index], since_last_user_seconds=torch.zeros_like(batch["elapsed_seconds"][active_indices, step_index]), since_last_agent_seconds=torch.zeros_like(batch["elapsed_seconds"][active_indices, step_index]), ) next_z = z_state.clone() next_z[active_indices] = active_next_z next_state_mask = state_mask.clone() next_state_mask[active_indices] = active_next_mask gate_logits = torch.zeros(batch["tick_mask"].shape[0], device=z_state.device, dtype=active_gate_logits.dtype) gate_logits[active_indices] = active_gate_logits return next_z, next_state_mask, gate_logits def compute_gate_positive_weight(batch: dict[str, torch.Tensor], training_cfg: dict[str, Any]) -> torch.Tensor: if not bool(training_cfg.get("use_dynamic_gate_positive_weight", True)): return torch.tensor(float(training_cfg["gate_positive_weight"]), device=batch["tick_mask"].device) active_targets = batch["gate_target"][batch["tick_mask"]] positive_count = active_targets.sum() negative_count = active_targets.numel() - positive_count if positive_count.item() <= 0.0: fallback = float(training_cfg["gate_positive_weight"]) return torch.tensor(fallback, device=batch["tick_mask"].device) dynamic_weight = negative_count / positive_count.clamp_min(1.0) min_weight = float(training_cfg.get("dynamic_gate_positive_weight_min", 1.0)) max_weight = float(training_cfg.get("dynamic_gate_positive_weight_max", 256.0)) dynamic_weight = dynamic_weight.clamp(min=min_weight, max=max_weight) return dynamic_weight.to(device=batch["tick_mask"].device) def compute_distance_aware_hazard_gate_loss( gate_logits: torch.Tensor, gate_targets: torch.Tensor, gate_loss_mask: torch.Tensor, training_cfg: dict[str, Any], ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: device = gate_logits.device batch_size, max_steps = gate_logits.shape support_window_ticks = int(training_cfg.get("hazard_soft_target_window_ticks", 16)) target_tau_ticks = float(training_cfg.get("hazard_soft_target_tau_ticks", 4.0)) censor_weight = float(training_cfg.get("hazard_censor_weight", 0.25)) log_hazard = F.logsigmoid(gate_logits) log_survival = F.logsigmoid(-gate_logits) gate_loss_sum = torch.zeros((), device=device) gate_loss_count = torch.zeros((), device=device) hazard_event_count = torch.zeros((), device=device) hazard_censor_count = torch.zeros((), device=device) for batch_index in range(batch_size): active_len = int(gate_loss_mask[batch_index].sum().item()) if active_len <= 0: continue active_targets = gate_targets[batch_index, :active_len] active_log_hazard = log_hazard[batch_index, :active_len] active_log_survival = log_survival[batch_index, :active_len] positive_indices = torch.nonzero(active_targets > 0.5, as_tuple=False).flatten().tolist() segment_start = 0 for positive_index in positive_indices: segment_log_hazard = active_log_hazard[segment_start : positive_index + 1] segment_log_survival = active_log_survival[segment_start : positive_index + 1] if segment_log_hazard.numel() <= 0: segment_start = positive_index + 1 continue prefix_log_survival = torch.cumsum(segment_log_survival, dim=0) - segment_log_survival segment_log_event = prefix_log_survival + segment_log_hazard if support_window_ticks > 0: support_start = max(0, segment_log_event.numel() - support_window_ticks) else: support_start = 0 support_log_event = segment_log_event[support_start:] support_indices = torch.arange( support_start, segment_log_event.numel(), device=device, dtype=torch.float32, ) distances = (segment_log_event.numel() - 1) - support_indices if target_tau_ticks > 0.0: target_weights = torch.exp(-distances / target_tau_ticks) else: target_weights = torch.zeros_like(distances) target_weights[-1] = 1.0 target_weights = target_weights / target_weights.sum().clamp_min(1e-8) segment_loss = -(target_weights * support_log_event).sum() gate_loss_sum = gate_loss_sum + segment_loss gate_loss_count = gate_loss_count + 1.0 hazard_event_count = hazard_event_count + 1.0 segment_start = positive_index + 1 if segment_start < active_len: tail_log_survival = active_log_survival[segment_start:active_len] if tail_log_survival.numel() > 0: censored_loss = -tail_log_survival.mean() gate_loss_sum = gate_loss_sum + censor_weight * censored_loss gate_loss_count = gate_loss_count + censor_weight hazard_censor_count = hazard_censor_count + 1.0 gate_loss = gate_loss_sum / gate_loss_count.clamp_min(1.0) metrics = { "gate_loss_sum": gate_loss_sum.detach(), "gate_count": gate_loss_count.detach(), "hazard_event_count": hazard_event_count.detach(), "hazard_censor_count": hazard_censor_count.detach(), } return gate_loss, metrics def compute_gate_ranking_loss( gate_logits: torch.Tensor, gate_targets: torch.Tensor, gate_loss_mask: torch.Tensor, ranking_margin: float, ranking_lookback: int, ) -> tuple[torch.Tensor, torch.Tensor]: device = gate_logits.device ranking_loss_sum = torch.zeros((), device=device) ranking_pair_count = torch.zeros((), device=device) batch_size, max_steps = gate_logits.shape for batch_index in range(batch_size): for step_index in range(max_steps): if float(gate_loss_mask[batch_index, step_index]) <= 0.0: continue if float(gate_targets[batch_index, step_index]) <= 0.5: continue positive_logit = gate_logits[batch_index, step_index] search_start = max(0, step_index - ranking_lookback) for previous_index in range(search_start, step_index): if float(gate_loss_mask[batch_index, previous_index]) <= 0.0: continue if float(gate_targets[batch_index, previous_index]) >= 0.5: continue negative_logit = gate_logits[batch_index, previous_index] ranking_loss_sum = ranking_loss_sum + F.relu(ranking_margin - positive_logit + negative_logit) ranking_pair_count = ranking_pair_count + 1 return ranking_loss_sum, ranking_pair_count def compute_chunk_losses( model: ThoughtLoopT5Gemma, batch: dict[str, torch.Tensor], training_cfg: dict[str, Any], z_state: torch.Tensor, state_mask: torch.Tensor, tbptt_steps: int | None = None, freeze_gate_head_to_targets: bool = False, ) -> tuple[torch.Tensor, dict[str, torch.Tensor], torch.Tensor, torch.Tensor]: device = batch["tick_mask"].device batch_size, max_steps = batch["tick_mask"].shape gate_loss_mode = str(training_cfg.get("gate_loss_mode", "hazard")).strip().lower() use_hazard_gate_loss = gate_loss_mode == "hazard" pos_weight = ( torch.zeros((), device=device) if use_hazard_gate_loss else compute_gate_positive_weight(batch, training_cfg) ) gate_loss_sum = torch.zeros((), device=device) gate_count = torch.zeros((), device=device) decoder_loss_sum = torch.zeros((), device=device) decoder_token_count = torch.zeros((), device=device) state_penalty_sum = torch.zeros((), device=device) state_penalty_count = torch.zeros((), device=device) tp = torch.zeros((), device=device) fp = torch.zeros((), device=device) fn = torch.zeros((), device=device) correct = torch.zeros((), device=device) counted = torch.zeros((), device=device) target_positive_count = torch.zeros((), device=device) predicted_positive_count = torch.zeros((), device=device) state_delta_weight = float(training_cfg["state_delta_weight"]) steps_since_detach = 0 gate_logits_history: list[torch.Tensor] = [] gate_target_history: list[torch.Tensor] = [] gate_loss_mask_history: list[torch.Tensor] = [] hazard_event_count = torch.zeros((), device=device) hazard_censor_count = torch.zeros((), device=device) teacher_logit_scale = float(training_cfg.get("gate_teacher_logit_scale", 20.0)) for step_index in range(max_steps): active_mask = batch["tick_mask"][:, step_index] if not torch.any(active_mask): break previous_z = z_state next_z, next_state_mask, gate_logits = rollout_active_rows_only( model=model, batch=batch, z_state=z_state, state_mask=state_mask, step_index=step_index, active_mask=active_mask, ) z_state = next_z state_mask = next_state_mask step_loss_mask = active_mask.float() step_targets = batch["gate_target"][:, step_index] if freeze_gate_head_to_targets: gate_logits = torch.where( step_targets > 0.5, torch.full_like(step_targets, teacher_logit_scale), torch.full_like(step_targets, -teacher_logit_scale), ) gate_logits_history.append(gate_logits) gate_target_history.append(step_targets) gate_loss_mask_history.append(step_loss_mask) if torch.any(step_loss_mask > 0): step_predictions = (torch.sigmoid(gate_logits) >= 0.5).float() valid = step_loss_mask > 0 correct = correct + ((step_predictions == step_targets) & valid).float().sum() counted = counted + valid.float().sum() tp = tp + ((step_predictions == 1) & (step_targets == 1) & valid).float().sum() fp = fp + ((step_predictions == 1) & (step_targets == 0) & valid).float().sum() fn = fn + ((step_predictions == 0) & (step_targets == 1) & valid).float().sum() target_positive_count = target_positive_count + ((step_targets == 1) & valid).float().sum() predicted_positive_count = predicted_positive_count + ((step_predictions == 1) & valid).float().sum() if torch.any(step_loss_mask > 0) and not use_hazard_gate_loss and not freeze_gate_head_to_targets: gate_loss_raw = F.binary_cross_entropy_with_logits(gate_logits, step_targets, reduction="none") gate_weight = torch.where(step_targets > 0.5, pos_weight, torch.ones_like(step_targets)) gate_loss_sum = gate_loss_sum + (gate_loss_raw * gate_weight * step_loss_mask).sum() gate_count = gate_count + step_loss_mask.sum() if state_delta_weight > 0.0: state_delta = (next_z - previous_z).pow(2).mean(dim=(1, 2)) state_penalty_sum = state_penalty_sum + (state_delta * step_loss_mask).sum() state_penalty_count = state_penalty_count + step_loss_mask.sum() elif torch.any(step_loss_mask > 0) and state_delta_weight > 0.0: state_delta = (next_z - previous_z).pow(2).mean(dim=(1, 2)) state_penalty_sum = state_penalty_sum + (state_delta * step_loss_mask).sum() state_penalty_count = state_penalty_count + step_loss_mask.sum() step_labels = batch["decoder_labels"][:, step_index, :] decoder_rows = active_mask & torch.any(step_labels != -100, dim=-1) if torch.any(decoder_rows): step_loss_sum, step_token_count = model.decoder_loss( z_state[decoder_rows], step_labels[decoder_rows], encoder_attention_mask=state_mask[decoder_rows], ) decoder_loss_sum = decoder_loss_sum + step_loss_sum decoder_token_count = decoder_token_count + step_token_count if tbptt_steps is not None: steps_since_detach += 1 if steps_since_detach >= tbptt_steps: z_state = z_state.detach() state_mask = state_mask.detach() steps_since_detach = 0 if gate_logits_history: stacked_gate_logits = torch.stack(gate_logits_history, dim=1) stacked_gate_targets = torch.stack(gate_target_history, dim=1) stacked_gate_loss_mask = torch.stack(gate_loss_mask_history, dim=1) else: stacked_gate_logits = torch.zeros((batch_size, 0), device=device) stacked_gate_targets = torch.zeros((batch_size, 0), device=device) stacked_gate_loss_mask = torch.zeros((batch_size, 0), device=device) gate_ranking_loss = torch.zeros((), device=device) gate_ranking_loss_sum = torch.zeros((), device=device) gate_ranking_pair_count = torch.zeros((), device=device) if freeze_gate_head_to_targets: gate_loss = torch.zeros((), device=device) gate_loss_sum = torch.zeros((), device=device) gate_count = torch.zeros((), device=device) hazard_event_count = torch.zeros((), device=device) hazard_censor_count = torch.zeros((), device=device) elif use_hazard_gate_loss and gate_logits_history: gate_loss, hazard_metrics = compute_distance_aware_hazard_gate_loss( gate_logits=stacked_gate_logits, gate_targets=stacked_gate_targets, gate_loss_mask=stacked_gate_loss_mask, training_cfg=training_cfg, ) gate_loss_sum = hazard_metrics["gate_loss_sum"] gate_count = hazard_metrics["gate_count"] hazard_event_count = hazard_metrics["hazard_event_count"] hazard_censor_count = hazard_metrics["hazard_censor_count"] else: gate_loss = gate_loss_sum / gate_count.clamp_min(1.0) if ( not use_hazard_gate_loss and bool(training_cfg.get("use_ranked_gate_objective", False)) and gate_logits_history and not freeze_gate_head_to_targets ): gate_ranking_loss_sum, gate_ranking_pair_count = compute_gate_ranking_loss( gate_logits=stacked_gate_logits, gate_targets=stacked_gate_targets, gate_loss_mask=stacked_gate_loss_mask, ranking_margin=float(training_cfg.get("gate_ranking_margin", 0.5)), ranking_lookback=int(training_cfg.get("gate_ranking_lookback", 64)), ) gate_ranking_loss = gate_ranking_loss_sum / gate_ranking_pair_count.clamp_min(1.0) gate_loss = ( float(training_cfg.get("gate_bce_weight", 1.0)) * gate_loss + float(training_cfg.get("gate_ranking_weight", 1.0)) * gate_ranking_loss ) decoder_loss = decoder_loss_sum / decoder_token_count.clamp_min(1.0) state_delta_loss = state_penalty_sum / state_penalty_count.clamp_min(1.0) total_loss = ( float(training_cfg["gate_loss_weight"]) * gate_loss + float(training_cfg["decoder_loss_weight"]) * decoder_loss + state_delta_weight * state_delta_loss ) metrics = { "loss_total": total_loss.detach(), "gate_loss_sum": gate_loss_sum.detach(), "gate_count": gate_count.detach(), "decoder_loss_sum": decoder_loss_sum.detach(), "decoder_token_count": decoder_token_count.detach(), "state_penalty_sum": state_penalty_sum.detach(), "state_penalty_count": state_penalty_count.detach(), "gate_tp": tp.detach(), "gate_fp": fp.detach(), "gate_fn": fn.detach(), "gate_correct": correct.detach(), "gate_eval_count": counted.detach(), "gate_target_positive_count": target_positive_count.detach(), "gate_pred_positive_count": predicted_positive_count.detach(), "gate_ranking_loss_sum": gate_ranking_loss_sum.detach(), "gate_ranking_pair_count": gate_ranking_pair_count.detach(), "hazard_event_count": hazard_event_count.detach(), "hazard_censor_count": hazard_censor_count.detach(), "active_ticks": batch["tick_mask"].float().sum().detach(), "gate_positive_weight": pos_weight.detach(), "metrics_count": torch.ones((), device=device), } return total_loss, metrics, z_state, state_mask def reduce_metrics( accelerator: Accelerator, metric_list: list[dict[str, torch.Tensor]], training_cfg: dict[str, Any], ) -> dict[str, float]: totals: dict[str, torch.Tensor] = {} for metrics in metric_list: for key, value in metrics.items(): totals[key] = totals.get(key, torch.zeros_like(value)) + value reduced = {key: accelerator.reduce(value, reduction="sum") for key, value in totals.items()} gate_loss_mode = str(training_cfg.get("gate_loss_mode", "hazard")).strip().lower() gate_loss = reduced["gate_loss_sum"] / reduced["gate_count"].clamp_min(1.0) gate_ranking_loss = reduced["gate_ranking_loss_sum"] / reduced["gate_ranking_pair_count"].clamp_min(1.0) decoder_loss = reduced["decoder_loss_sum"] / reduced["decoder_token_count"].clamp_min(1.0) state_delta_loss = reduced["state_penalty_sum"] / reduced["state_penalty_count"].clamp_min(1.0) gate_precision = reduced["gate_tp"] / (reduced["gate_tp"] + reduced["gate_fp"]).clamp_min(1.0) gate_recall = reduced["gate_tp"] / (reduced["gate_tp"] + reduced["gate_fn"]).clamp_min(1.0) gate_f1 = 2 * gate_precision * gate_recall / (gate_precision + gate_recall).clamp_min(1e-8) if gate_loss_mode != "hazard" and bool(training_cfg.get("use_ranked_gate_objective", False)): gate_loss = ( float(training_cfg.get("gate_bce_weight", 1.0)) * gate_loss + float(training_cfg.get("gate_ranking_weight", 1.0)) * gate_ranking_loss ) total_loss = ( float(training_cfg["gate_loss_weight"]) * gate_loss + float(training_cfg["decoder_loss_weight"]) * decoder_loss + float(training_cfg["state_delta_weight"]) * state_delta_loss ) return { "loss_total": float(total_loss), "loss_gate": float(gate_loss), "loss_gate_ranking": float(gate_ranking_loss), "loss_decoder": float(decoder_loss), "loss_state_delta": float(state_delta_loss), "gate_accuracy": float(reduced["gate_correct"] / reduced["gate_eval_count"].clamp_min(1.0)), "gate_precision": float(gate_precision), "gate_recall": float(gate_recall), "gate_f1": float(gate_f1), "gate_target_positive_count": float(reduced["gate_target_positive_count"]), "gate_pred_positive_count": float(reduced["gate_pred_positive_count"]), "gate_positive_weight": float(reduced["gate_positive_weight"] / reduced["metrics_count"].clamp_min(1.0)), "hazard_event_count": float(reduced["hazard_event_count"]), "hazard_censor_count": float(reduced["hazard_censor_count"]), "decoder_token_count": float(reduced["decoder_token_count"]), "active_ticks": float(reduced["active_ticks"]), } @torch.no_grad() def evaluate( accelerator: Accelerator, model: ThoughtLoopT5Gemma, dataloader: DataLoader, training_cfg: dict[str, Any], rollout_cfg: dict[str, Any], tokenizer_pad_token_id: int, max_batches: int | None = None, ) -> dict[str, float]: model.eval() all_metrics: list[dict[str, torch.Tensor]] = [] chunk_ticks = int(rollout_cfg["chunk_ticks"]) tick_seconds = float(rollout_cfg["tick_seconds"]) for batch_index, conversations in enumerate(dataloader): if max_batches is not None and batch_index >= max_batches: break batch_size = len(conversations) z_state = model.initial_state(batch_size=batch_size, device=accelerator.device) state_mask = model.initial_state_mask(batch_size=batch_size, device=accelerator.device) max_chunks = max(int(conversation["chunk_count"]) for conversation in conversations) for chunk_index in range(max_chunks): chunk_batch = build_chunk_batch( conversations=conversations, chunk_index=chunk_index, pad_token_id=tokenizer_pad_token_id, chunk_ticks=chunk_ticks, tick_seconds=tick_seconds, ) if chunk_batch is None: break chunk_batch = move_chunk_batch_to_device(chunk_batch, accelerator.device) _, metrics, z_state, state_mask = compute_chunk_losses( model=model, batch=chunk_batch, training_cfg=training_cfg, z_state=z_state, state_mask=state_mask, tbptt_steps=None, freeze_gate_head_to_targets=False, ) z_state = z_state.detach() state_mask = state_mask.detach() all_metrics.append(metrics) if not all_metrics: model.train() return {} reduced = reduce_metrics(accelerator, all_metrics, training_cfg) model.train() return reduced def copy_supporting_files(output_dir: Path, config_path: Path) -> None: for filename in ("config.py", "data.py", "model.py", "inference.py", "train.py"): shutil.copy2(SCRIPT_DIR / filename, output_dir / filename) shutil.copy2(config_path, output_dir / config_path.name) def build_model_card(config: dict[str, Any], best_metrics: dict[str, float]) -> str: return "\n".join( [ "# Samantha Thought-Loop SFT", "", f"- Base model: `{config['model']['base_model_name']}`", f"- Cleaned dataset repo: `{config['dataset'].get('cleaned_repo_id')}`", f"- Raw dataset repo: `{config['dataset'].get('raw_repo_id')}`", f"- Latent slots: `{config['model']['z_slots']}`", f"- Fixed tick seconds: `{config['rollout']['tick_seconds']}`", f"- Chunk ticks: `{config['rollout']['chunk_ticks']}`", f"- Max horizon ticks: `{config['rollout'].get('max_horizon_ticks', 36000)}`", f"- Explicit time features: `{config['model'].get('use_explicit_time_features', False)}`", f"- Gate loss mode: `{config['training'].get('gate_loss_mode', 'hazard')}`", f"- Hazard soft target window: `{config['training'].get('hazard_soft_target_window_ticks', 16)}`", f"- Hazard soft target tau ticks: `{config['training'].get('hazard_soft_target_tau_ticks', 4.0)}`", f"- Hazard censor weight: `{config['training'].get('hazard_censor_weight', 0.25)}`", f"- TBPTT by bucket: `{config['training'].get('tbptt_steps_by_bucket', {})}`", "", "## Best Validation Metrics", "", *[f"- `{key}`: `{value:.6f}`" for key, value in sorted(best_metrics.items())], "", ] ) def export_model( accelerator: Accelerator, model: ThoughtLoopT5Gemma, output_dir: Path, config: dict[str, Any], config_path: Path, best_metrics: dict[str, float], ) -> None: if not accelerator.is_main_process: return output_dir.mkdir(parents=True, exist_ok=True) unwrapped = accelerator.unwrap_model(model) unwrapped.save_pretrained(output_dir) copy_supporting_files(output_dir, config_path) (output_dir / "README.md").write_text(build_model_card(config, best_metrics), encoding="utf-8") save_json(output_dir / "best_metrics.json", best_metrics) def maybe_push_to_hub(config: dict[str, Any], export_dir: Path) -> None: model_repo_id = config["hub"].get("model_repo_id") if not model_repo_id: return api = HfApi() api.create_repo(repo_id=model_repo_id, repo_type="model", exist_ok=True, private=bool(config["hub"].get("private"))) api.upload_folder(folder_path=str(export_dir), repo_id=model_repo_id, repo_type="model") def maybe_push_checkpoint_to_hub( accelerator: Accelerator, config: dict[str, Any], checkpoint_dir: Path, ) -> None: model_repo_id = config["hub"].get("model_repo_id") if not accelerator.is_main_process or not model_repo_id: return if not bool(config["hub"].get("push_checkpoints", True)): return run_name = str(config["wandb"].get("run_name", "run")) checkpoint_prefix = str(config["hub"].get("checkpoint_path_prefix", "checkpoints")).strip("/") path_parts = [part for part in (checkpoint_prefix, SCRIPT_DIR.name, run_name, checkpoint_dir.name) if part] path_in_repo = "/".join(path_parts) api = HfApi() api.create_repo(repo_id=model_repo_id, repo_type="model", exist_ok=True, private=bool(config["hub"].get("private"))) api.upload_folder( folder_path=str(checkpoint_dir), repo_id=model_repo_id, repo_type="model", path_in_repo=path_in_repo, commit_message=f"Add {SCRIPT_DIR.name} {checkpoint_dir.name}", ) def read_json(path: Path) -> dict[str, Any]: if not path.exists(): return {} with path.open("r", encoding="utf-8") as handle: payload = json.load(handle) if not isinstance(payload, dict): raise ValueError(f"Expected a JSON object in {path}.") return payload def checkpoint_global_step(checkpoint_dir: Path) -> int | None: state = read_json(checkpoint_dir / "trainer_state.json") if "global_step" in state: return int(state["global_step"]) prefix = "checkpoint-" if checkpoint_dir.name.startswith(prefix): suffix = checkpoint_dir.name[len(prefix) :] if suffix.isdigit(): return int(suffix) return None def find_latest_checkpoint(run_dir: Path) -> Path | None: checkpoints = [] for candidate in run_dir.glob("checkpoint-*"): if not candidate.is_dir(): continue step = checkpoint_global_step(candidate) if step is not None: checkpoints.append((step, candidate)) if not checkpoints: return None return max(checkpoints, key=lambda item: item[0])[1] def resolve_resume_checkpoint(requested: str | None, run_dir: Path) -> Path | None: if requested is None: return None value = str(requested).strip() if not value: return None if value.lower() in {"1", "true", "yes", "auto", "latest"}: checkpoint_dir = find_latest_checkpoint(run_dir) if checkpoint_dir is None: raise FileNotFoundError(f"No checkpoint-* directory found under {run_dir}.") return checkpoint_dir checkpoint_dir = Path(value).expanduser() if not checkpoint_dir.is_absolute(): checkpoint_dir = SCRIPT_DIR / checkpoint_dir checkpoint_dir = checkpoint_dir.resolve() if not checkpoint_dir.is_dir(): raise FileNotFoundError(f"Resume checkpoint does not exist: {checkpoint_dir}") return checkpoint_dir def parse_wandb_run_id_from_dir(path: Path) -> str | None: name = path.resolve(strict=False).name if path.is_symlink() else path.name for prefix in ("offline-run-", "run-"): if name.startswith(prefix): parts = name.split("-") if parts: return parts[-1] return None def discover_local_wandb_run_id(run_name: str) -> str | None: wandb_dir = SCRIPT_DIR / "wandb" if not wandb_dir.exists(): return None candidates: list[Path] = [] latest = wandb_dir / "latest-run" if latest.exists() or latest.is_symlink(): candidates.append(latest) candidates.extend(sorted(wandb_dir.glob("*run-*"), key=lambda path: path.stat().st_mtime, reverse=True)) for candidate in candidates: run_id = parse_wandb_run_id_from_dir(candidate) if not run_id: continue resolved = candidate.resolve(strict=False) config_path = resolved / "files" / "config.yaml" debug_log_path = resolved / "logs" / "debug.log" if config_path.exists() and run_name in config_path.read_text(encoding="utf-8", errors="ignore"): return run_id if debug_log_path.exists() and run_name in debug_log_path.read_text(encoding="utf-8", errors="ignore"): return run_id return None def resolve_wandb_run_id(args: argparse.Namespace, config: dict[str, Any], run_dir: Path) -> str: run_id_path = run_dir / "wandb_run_id.txt" explicit = args.wandb_run_id or os.environ.get("WANDB_RUN_ID") or config["wandb"].get("run_id") if explicit: run_id = str(explicit).strip() elif run_id_path.exists(): run_id = run_id_path.read_text(encoding="utf-8").strip() else: run_id = discover_local_wandb_run_id(str(config["wandb"].get("run_name", ""))) or secrets.token_hex(4) if not run_id: run_id = secrets.token_hex(4) run_id_path.write_text(run_id + "\n", encoding="utf-8") return run_id def save_trainer_state( checkpoint_dir: Path, *, global_step: int, epoch: int, next_batch_index: int, best_validation_loss: float, best_metrics: dict[str, float], wandb_run_id: str | None, ) -> None: payload: dict[str, Any] = { "global_step": int(global_step), "epoch": int(epoch), "next_batch_index": int(next_batch_index), "best_validation_loss": best_validation_loss if math.isfinite(best_validation_loss) else None, "best_metrics": best_metrics, "wandb_run_id": wandb_run_id, } save_json(checkpoint_dir / "trainer_state.json", payload) def load_resume_progress( checkpoint_dir: Path, *, total_chunk_microsteps: int, gradient_accumulation_steps: int, num_train_epochs: int, ) -> tuple[int, int, int, dict[str, Any]]: state = read_json(checkpoint_dir / "trainer_state.json") global_step = int(state.get("global_step") or checkpoint_global_step(checkpoint_dir) or 0) if "epoch" in state and "next_batch_index" in state: start_epoch = int(state["epoch"]) batches_to_skip = int(state["next_batch_index"]) else: updates_per_epoch = max(1, math.ceil(total_chunk_microsteps / max(gradient_accumulation_steps, 1))) start_epoch = min(global_step // updates_per_epoch, max(num_train_epochs - 1, 0)) updates_into_epoch = global_step - start_epoch * updates_per_epoch batches_to_skip = updates_into_epoch * max(gradient_accumulation_steps, 1) while batches_to_skip >= total_chunk_microsteps and start_epoch < num_train_epochs: batches_to_skip -= total_chunk_microsteps start_epoch += 1 return global_step, start_epoch, batches_to_skip, state def load_best_metrics(export_dir: Path, resume_state: dict[str, Any]) -> tuple[float, dict[str, float]]: best_metrics = resume_state.get("best_metrics") if not isinstance(best_metrics, dict): best_metrics = read_json(export_dir / "best_metrics.json") loss = resume_state.get("best_validation_loss") if loss is None and isinstance(best_metrics, dict): loss = best_metrics.get("loss_total") best_validation_loss = float(loss) if loss is not None else float("inf") return best_validation_loss, {str(key): float(value) for key, value in best_metrics.items()} def main() -> None: args = parse_args() config_path = Path(args.config) config = load_config(config_path) training_cfg = config["training"] rollout_cfg = config["rollout"] set_seed(int(training_cfg["seed"])) wandb_enabled = bool(config["wandb"].get("enabled", True) and wandb is not None) accelerator = Accelerator( gradient_accumulation_steps=int(training_cfg["gradient_accumulation_steps"]), mixed_precision=training_cfg.get("mixed_precision", "bf16"), log_with="wandb" if wandb_enabled else None, project_dir=config["paths"]["run_root"], ) run_root = ensure_dir(config["paths"]["run_root"]) export_root = ensure_dir(config["paths"]["export_root"]) run_name = config["wandb"]["run_name"] run_dir = ensure_dir(run_root / run_name) export_dir = ensure_dir(export_root / run_name) resume_request = args.resume_from_checkpoint or training_cfg.get("resume_from_checkpoint") resume_checkpoint = resolve_resume_checkpoint(resume_request, run_dir) wandb_run_id = resolve_wandb_run_id(args, config, run_dir) if wandb_enabled else None model = build_initial_model(config) if model.use_explicit_time_features: raise ValueError("This trainer expects model.use_explicit_time_features=false for fixed-cadence training.") tokenizer = model.tokenizer train_dataset = ThoughtLoopConversationDataset(config=config, tokenizer=tokenizer, split="train") validation_dataset = ThoughtLoopConversationDataset(config=config, tokenizer=tokenizer, split="validation") sort_dataset_by_length = bool(training_cfg.get("sort_dataset_by_length", False)) ensure_single_chunk_dataset(train_dataset, "train") ensure_single_chunk_dataset(validation_dataset, "validation") maybe_sort_dataset_by_length(train_dataset, enabled=sort_dataset_by_length) maybe_sort_dataset_by_length(validation_dataset, enabled=sort_dataset_by_length) train_loader = DataLoader( train_dataset, batch_size=int(training_cfg["micro_batch_size"]), shuffle=bool(training_cfg.get("shuffle_train", True)) and not sort_dataset_by_length, num_workers=int(training_cfg["num_workers"]), pin_memory=True, persistent_workers=bool(int(training_cfg["num_workers"]) > 0), collate_fn=identity_collate, ) validation_loader = DataLoader( validation_dataset, batch_size=int(training_cfg["eval_batch_size"]), shuffle=False, num_workers=int(training_cfg["num_workers"]), pin_memory=True, persistent_workers=bool(int(training_cfg["num_workers"]) > 0), collate_fn=identity_collate, ) optimizer = build_optimizer(model, training_cfg) total_chunk_microsteps = len(train_loader) total_update_steps = max( 1, math.ceil(total_chunk_microsteps / max(int(training_cfg["gradient_accumulation_steps"]), 1)) * int(training_cfg["num_train_epochs"]), ) warmup_steps = max(1, int(total_update_steps * float(training_cfg["warmup_ratio"]))) scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=max(1, total_update_steps), ) model, optimizer, train_loader, validation_loader, scheduler = accelerator.prepare( model, optimizer, train_loader, validation_loader, scheduler ) resume_state: dict[str, Any] = {} global_step = 0 start_epoch = 0 batches_to_skip = 0 if resume_checkpoint is not None: accelerator.print(f"Loading local checkpoint from {resume_checkpoint}") accelerator.load_state(str(resume_checkpoint)) global_step, start_epoch, batches_to_skip, resume_state = load_resume_progress( resume_checkpoint, total_chunk_microsteps=total_chunk_microsteps, gradient_accumulation_steps=int(training_cfg["gradient_accumulation_steps"]), num_train_epochs=int(training_cfg["num_train_epochs"]), ) accelerator.print( f"Resuming at global_step={global_step}, epoch={start_epoch + 1}, " f"skipping {batches_to_skip} batch(es)." ) if accelerator.is_main_process and wandb_enabled: accelerator.print(f"W&B run id: {wandb_run_id}") accelerator.init_trackers( project_name=config["wandb"]["project"], config=flatten_for_wandb(config), init_kwargs={ "wandb": { "name": run_name, "id": wandb_run_id, "resume": str(config["wandb"].get("resume", "allow")), } }, ) unwrapped_model = accelerator.unwrap_model(model) current_decoder_fraction = resolve_decoder_trainable_fraction(training_cfg, global_step, total_update_steps) if current_decoder_fraction is None: current_decoder_fraction = 1.0 current_decoder_trainable_layers = unwrapped_model.set_decoder_trainable_fraction(current_decoder_fraction) current_gate_head_frozen = should_freeze_gate_head(training_cfg, global_step, total_update_steps) unwrapped_model.set_gate_trainable(not current_gate_head_frozen) progress_bar = tqdm( total=total_update_steps, initial=global_step, disable=not accelerator.is_local_main_process, desc="Training", dynamic_ncols=True, ) best_validation_loss, best_metrics = load_best_metrics(export_dir, resume_state) model.train() chunk_ticks = int(rollout_cfg["chunk_ticks"]) tick_seconds = float(rollout_cfg["tick_seconds"]) for epoch in range(start_epoch, int(training_cfg["num_train_epochs"])): epoch_train_loader = train_loader batch_index_offset = 0 if epoch == start_epoch and batches_to_skip > 0: epoch_train_loader = accelerator.skip_first_batches(train_loader, batches_to_skip) batch_index_offset = batches_to_skip for batch_index, conversations in enumerate(epoch_train_loader, start=batch_index_offset): batch_bucket = resolve_batch_duration_bucket(conversations) current_tbptt_steps = resolve_bucket_tbptt_steps(training_cfg, batch_bucket) current_horizon_ticks = resolve_bucket_horizon_ticks(batch_bucket, rollout_cfg) z_state = model.initial_state(batch_size=len(conversations), device=accelerator.device) state_mask = model.initial_state_mask(batch_size=len(conversations), device=accelerator.device) max_chunks = max(int(conversation["chunk_count"]) for conversation in conversations) for chunk_index in range(max_chunks): scheduled_decoder_fraction = resolve_decoder_trainable_fraction(training_cfg, global_step, total_update_steps) if scheduled_decoder_fraction is None: scheduled_decoder_fraction = 1.0 if abs(scheduled_decoder_fraction - current_decoder_fraction) > 1e-6: current_decoder_fraction = scheduled_decoder_fraction current_decoder_trainable_layers = unwrapped_model.set_decoder_trainable_fraction( current_decoder_fraction ) scheduled_gate_head_frozen = should_freeze_gate_head(training_cfg, global_step, total_update_steps) if scheduled_gate_head_frozen != current_gate_head_frozen: current_gate_head_frozen = scheduled_gate_head_frozen unwrapped_model.set_gate_trainable(not current_gate_head_frozen) chunk_batch = build_chunk_batch( conversations=conversations, chunk_index=chunk_index, pad_token_id=tokenizer.pad_token_id, chunk_ticks=chunk_ticks, tick_seconds=tick_seconds, ) if chunk_batch is None: break chunk_batch = move_chunk_batch_to_device(chunk_batch, accelerator.device) with accelerator.accumulate(model): total_loss, metrics, z_state, state_mask = compute_chunk_losses( model=model, batch=chunk_batch, training_cfg=training_cfg, z_state=z_state, state_mask=state_mask, tbptt_steps=current_tbptt_steps, freeze_gate_head_to_targets=current_gate_head_frozen, ) accelerator.backward(total_loss) accelerator.clip_grad_norm_(model.parameters(), float(training_cfg["max_grad_norm"])) optimizer.step() scheduler.step() optimizer.zero_grad(set_to_none=True) z_state = z_state.detach() state_mask = state_mask.detach() if accelerator.sync_gradients: global_step += 1 progress_bar.update(1) progress_bar.set_postfix( loss=f"{float(total_loss.detach()):.4f}", bucket=batch_bucket, tbptt=int(current_tbptt_steps or 0), horizon=int(current_horizon_ticks), decoder=current_decoder_trainable_layers, gate="locked" if current_gate_head_frozen else "train", refresh=False, ) if global_step % int(training_cfg["logging_steps"]) == 0: reduced = reduce_metrics(accelerator, [metrics], training_cfg) reduced["lr"] = scheduler.get_last_lr()[0] reduced["tbptt_steps"] = float(current_tbptt_steps or 0) reduced["horizon_ticks"] = float(current_horizon_ticks) reduced["chunk_ticks"] = float(chunk_ticks) reduced["decoder_trainable_layers"] = float(current_decoder_trainable_layers) reduced["gate_head_frozen"] = float(current_gate_head_frozen) reduced["epoch"] = epoch + 1 reduced["global_step"] = global_step accelerator.log({f"train/{key}": value for key, value in reduced.items()}, step=global_step) if global_step % int(training_cfg["eval_steps"]) == 0: validation_metrics = evaluate( accelerator=accelerator, model=model, dataloader=validation_loader, training_cfg=training_cfg, rollout_cfg=rollout_cfg, tokenizer_pad_token_id=tokenizer.pad_token_id, max_batches=training_cfg.get("eval_max_batches"), ) accelerator.log( {f"validation/{key}": value for key, value in validation_metrics.items()}, step=global_step, ) if validation_metrics and validation_metrics["loss_total"] < best_validation_loss: best_validation_loss = validation_metrics["loss_total"] best_metrics = validation_metrics export_model( accelerator=accelerator, model=model, output_dir=export_dir, config=config, config_path=config_path, best_metrics=best_metrics, ) if global_step % int(training_cfg["checkpoint_steps"]) == 0: checkpoint_dir = ensure_dir(run_dir / f"checkpoint-{global_step}") accelerator.save_state(str(checkpoint_dir)) if accelerator.is_main_process: save_trainer_state( checkpoint_dir, global_step=global_step, epoch=epoch, next_batch_index=batch_index + 1, best_validation_loss=best_validation_loss, best_metrics=best_metrics, wandb_run_id=wandb_run_id, ) accelerator.wait_for_everyone() maybe_push_checkpoint_to_hub( accelerator=accelerator, config=config, checkpoint_dir=checkpoint_dir, ) accelerator.wait_for_everyone() if not best_metrics: best_metrics = evaluate( accelerator=accelerator, model=model, dataloader=validation_loader, training_cfg=training_cfg, rollout_cfg=rollout_cfg, tokenizer_pad_token_id=tokenizer.pad_token_id, max_batches=training_cfg.get("eval_max_batches"), ) export_model( accelerator=accelerator, model=model, output_dir=export_dir, config=config, config_path=config_path, best_metrics=best_metrics, ) if accelerator.is_main_process: training_summary = { "train_conversations": len(train_dataset), "validation_conversations": len(validation_dataset), "train_chunk_microsteps": total_chunk_microsteps, "trainable_parameters": accelerator.unwrap_model(model).trainable_parameter_count(), "best_metrics": best_metrics, "global_step": global_step, "resumed_from_checkpoint": str(resume_checkpoint) if resume_checkpoint is not None else None, "wandb_run_id": wandb_run_id, "decoder_layer_count": unwrapped_model.decoder_layer_count(), "tick_seconds": tick_seconds, "chunk_ticks": chunk_ticks, "tbptt_steps_by_bucket": training_cfg.get("tbptt_steps_by_bucket", {}), "horizon_ticks_by_bucket": rollout_cfg.get("horizon_ticks_by_bucket", {}), "shuffle_train": bool(training_cfg.get("shuffle_train", True)), "sort_dataset_by_length": sort_dataset_by_length, } save_json(run_dir / "training_summary.json", training_summary) maybe_push_to_hub(config, export_dir) progress_bar.close() accelerator.end_training() if __name__ == "__main__": main()