| from __future__ import annotations |
|
|
| import json |
| import math |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from huggingface_hub import snapshot_download |
| from safetensors.torch import load_model, save_file, save_model |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
# Absolute directory that contains this file. It is placed at the front of
# ``sys.path`` (only if not already present), presumably so that sibling
# modules stored next to this script can be imported by plain name.
SCRIPT_DIR = Path(__file__).resolve().parent
if not any(entry == str(SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(SCRIPT_DIR))
|
|
|
|
class ParameterlessRMSNorm(nn.Module):
    """RMS normalisation over the last dimension, without a learnable scale.

    The reduction is computed in float32 (via ``Tensor.float()``) and the
    result is cast back to the dtype of the input. ``hidden_size`` is only
    recorded on the module; the computation itself does not read it.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``inputs * rsqrt(mean(inputs**2, dim=-1) + eps)`` in the input dtype."""
        input_dtype = inputs.dtype
        as_float = inputs.float()
        mean_square = as_float.pow(2).mean(dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return (as_float * inverse_rms).to(dtype=input_dtype)
|
|
|
|
class TemporalFeatureProjector(nn.Module):
    """Project a small vector of scalar time features into a few hidden-size tokens.

    The ``scalar_count`` features go through a two-layer GELU MLP whose output
    is split into ``num_time_tokens`` tokens of width ``hidden_size``; each
    token is then layer-normalised.
    """

    def __init__(self, hidden_size: int, num_time_tokens: int, scalar_count: int = 4) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_time_tokens = num_time_tokens
        self.scalar_count = scalar_count
        self.mlp = nn.Sequential(
            nn.Linear(scalar_count, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size * num_time_tokens),
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, temporal_features: torch.Tensor) -> torch.Tensor:
        """Map ``(*batch, scalar_count)`` features to ``(*batch, num_time_tokens, hidden_size)``.

        Any number of leading batch dimensions is accepted, including none
        (a single ``(scalar_count,)`` vector yields ``(num_time_tokens, hidden_size)``).
        """
        projected = self.mlp(temporal_features)
        # Keep every leading dimension instead of assuming exactly one batch axis.
        leading_shape = temporal_features.shape[:-1]
        projected = projected.reshape(*leading_shape, self.num_time_tokens, self.hidden_size)
        return self.norm(projected)
|
|
|
|
class ThoughtLoopT5Gemma(nn.Module):
    """Recurrent latent "thought loop" wrapped around a seq2seq backbone.

    A fixed number of latent slots (``z_slots``) is carried across rollout
    steps. Each step encodes one observation with the backbone encoder, derives
    a proposed latent state, blends it into the current state through a learned
    per-dimension update gate, and pools the result into one scalar gate logit
    per batch row. Latent states are decoded by handing them to the backbone as
    precomputed encoder outputs.

    Keys read from ``config["model"]``: ``base_model_name``, ``z_slots`` and,
    optionally, ``dtype``, ``attn_implementation``,
    ``thought_loop_proposal_mode``, ``preserve_observation_encoder_manifold``,
    ``observation_encoder_use_state_context``, ``latent_attention_mask_mode``,
    ``use_explicit_time_features``, ``num_time_tokens``, ``magicnorm_eps``,
    ``observation_role_count``, ``initial_update_gate_bias`` and
    ``gate_attention_heads``. ``config["training"]["gradient_checkpointing"]``
    (default ``True``) enables gradient checkpointing on the backbone.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__()
        self.config = config
        model_cfg = config["model"]
        # ``training`` may be missing or explicitly null in the config file.
        training_cfg = config.get("training") or {}

        dtype_name = model_cfg.get("dtype", "bfloat16")
        backbone_dtype = getattr(torch, dtype_name, None)
        if not isinstance(backbone_dtype, torch.dtype):
            raise ValueError(f"Unsupported backbone dtype: {dtype_name!r}")
        backbone_kwargs: dict[str, Any] = {"torch_dtype": backbone_dtype}
        attn_implementation = model_cfg.get("attn_implementation")
        if attn_implementation:
            backbone_kwargs["attn_implementation"] = attn_implementation

        self.backbone = AutoModelForSeq2SeqLM.from_pretrained(model_cfg["base_model_name"], **backbone_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_cfg["base_model_name"])

        if training_cfg.get("gradient_checkpointing", True):
            self.backbone.gradient_checkpointing_enable()
            # The backbone's KV cache is switched off together with checkpointing.
            if getattr(self.backbone.config, "use_cache", None) is not None:
                self.backbone.config.use_cache = False

        self.encoder = self.backbone.get_encoder()
        self.decoder = self.backbone.get_decoder()
        self.input_embeddings = self.backbone.get_input_embeddings()
        # Latent slots are concatenated with ``input_embeddings`` outputs before
        # the encoder and passed to the backbone as encoder states, so their
        # width must equal the input-embedding width. Reading it from the
        # decoder config is only equivalent when encoder and decoder widths match.
        hidden_size = int(self.input_embeddings.weight.shape[-1])
        self.hidden_size = hidden_size
        self.z_slots = int(model_cfg["z_slots"])
        self.thought_loop_proposal_mode = str(
            model_cfg.get("thought_loop_proposal_mode", "latent_prefix")
        ).strip().lower()
        if self.thought_loop_proposal_mode not in {"latent_prefix", "observation_hidden_compression"}:
            raise ValueError(f"Unsupported thought_loop_proposal_mode: {self.thought_loop_proposal_mode}")
        # By default the extra RMS norms are skipped when proposals are taken
        # directly from encoder hidden states.
        self.preserve_observation_encoder_manifold = bool(
            model_cfg.get(
                "preserve_observation_encoder_manifold",
                self.thought_loop_proposal_mode == "observation_hidden_compression",
            )
        )
        self.observation_encoder_use_state_context = bool(
            model_cfg.get("observation_encoder_use_state_context", False)
        )
        self.latent_attention_mask_mode = str(
            model_cfg.get("latent_attention_mask_mode", "observed")
        ).strip().lower()
        if self.latent_attention_mask_mode not in {"observed", "full"}:
            raise ValueError(f"Unsupported latent_attention_mask_mode: {self.latent_attention_mask_mode}")
        self.use_explicit_time_features = bool(model_cfg.get("use_explicit_time_features", False))
        self.num_time_tokens = int(model_cfg["num_time_tokens"]) if self.use_explicit_time_features else 0
        self.observation_role_count = int(model_cfg.get("observation_role_count", 3))
        magicnorm_eps = float(model_cfg.get("magicnorm_eps", 1e-6))

        # NOTE: module creation order is kept fixed so that random
        # initialisation of fresh models stays reproducible.
        self.z_init = nn.Parameter(torch.randn(self.z_slots, hidden_size) * 0.02)
        # Segment ids used in the rollout: 0 = latent slots,
        # 1 = auxiliary (time / role) tokens, 2 = observation tokens.
        self.segment_embeddings = nn.Embedding(3, hidden_size)
        self.observation_role_embeddings = nn.Embedding(self.observation_role_count, hidden_size)
        self.temporal_projector = (
            TemporalFeatureProjector(hidden_size, self.num_time_tokens) if self.use_explicit_time_features else None
        )

        # Per-dimension update gate computed from [current state, proposal].
        self.state_gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self._initialize_update_gate_bias(model_cfg)
        self.state_gate_input_norm = ParameterlessRMSNorm(hidden_size * 2, eps=magicnorm_eps)
        self.proposed_state_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps)
        self.recurrent_state_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps)
        self.gate_context_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps)

        # Attention pooling of the latent state into one scalar gate logit.
        self.gate_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.gate_pool = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=int(model_cfg.get("gate_attention_heads", 4)),
            batch_first=True,
        )
        self.gate_head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1),
        )

        self._drop_unused_vision_modules()

    def _initialize_update_gate_bias(self, model_cfg: dict[str, Any]) -> None:
        """Set the final ``state_gate`` bias to ``initial_update_gate_bias`` when configured.

        Raises:
            RuntimeError: if ``state_gate`` does not end in a ``Linear`` layer with a bias.
        """
        if "initial_update_gate_bias" not in model_cfg:
            return
        final_gate_layer = self.state_gate[-1]
        if not isinstance(final_gate_layer, nn.Linear) or final_gate_layer.bias is None:
            raise RuntimeError("state_gate must end in a Linear layer with bias to use initial_update_gate_bias.")
        nn.init.constant_(final_gate_layer.bias, float(model_cfg["initial_update_gate_bias"]))

    @staticmethod
    def _resolve_module_parent(root: Any, module_path: str) -> tuple[Any, str] | None:
        """Walk the dotted ``module_path`` from ``root`` up to the last component.

        Returns ``(parent, child_name)``, or ``None`` when an intermediate
        attribute is missing. The final attribute itself is not looked up.
        """
        *parent_parts, child_name = module_path.split(".")
        parent = root
        for part in parent_parts:
            try:
                parent = getattr(parent, part)
            except AttributeError:
                return None
        return parent, child_name

    def _drop_unused_vision_modules(self) -> None:
        """Set any vision tower/model found at known encoder paths to ``None``."""
        for module_path in (
            "encoder.vision_tower",
            "encoder.vision_model",
            "model.encoder.vision_tower",
            "model.encoder.vision_model",
        ):
            resolved = self._resolve_module_parent(self.backbone, module_path)
            if resolved is None:
                continue
            parent, child_name = resolved
            if isinstance(getattr(parent, child_name, None), nn.Module):
                setattr(parent, child_name, None)

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def backbone_dtype(self) -> torch.dtype:
        """dtype of the backbone's input-embedding weights."""
        return self.input_embeddings.weight.dtype

    @property
    def recurrent_dtype(self) -> torch.dtype:
        """dtype of the learned initial latent ``z_init``."""
        return self.z_init.dtype

    def trainable_parameter_count(self) -> int:
        """Number of scalar parameters with ``requires_grad=True``."""
        return sum(parameter.numel() for parameter in self.parameters() if parameter.requires_grad)

    def _decoder_blocks(self) -> Any:
        """Return the decoder's layer container (``block``, ``layers`` or ``h``), or ``None``."""
        for attribute_name in ("block", "layers", "h"):
            blocks = getattr(self.decoder, attribute_name, None)
            if blocks is not None:
                return blocks
        return None

    def decoder_layer_count(self) -> int:
        """Number of decoder blocks, or 0 if none could be located."""
        blocks = self._decoder_blocks()
        return len(blocks) if blocks is not None else 0

    def set_gate_trainable(self, trainable: bool) -> None:
        """Toggle ``requires_grad`` on the gate-pooling query, attention, head and norm."""
        self.gate_query.requires_grad_(trainable)
        self.gate_pool.requires_grad_(trainable)
        self.gate_head.requires_grad_(trainable)
        self.gate_context_norm.requires_grad_(trainable)

    def set_decoder_trainable_fraction(self, trainable_fraction: float) -> int:
        """Unfreeze the last ``ceil(n * fraction)`` decoder blocks and freeze the rest.

        ``trainable_fraction`` is clamped to ``[0, 1]``. The decoder's final
        norm, its ``embed_tokens`` and the LM head are made trainable exactly
        when at least one block is trainable.

        Returns:
            The number of trainable decoder blocks (0 if no blocks were found).
        """
        blocks = self._decoder_blocks()
        if blocks is None:
            return 0

        clamped_fraction = min(max(trainable_fraction, 0.0), 1.0)
        total_blocks = len(blocks)
        trainable_blocks = min(total_blocks, math.ceil(total_blocks * clamped_fraction)) if clamped_fraction > 0 else 0
        first_trainable_index = total_blocks - trainable_blocks
        for block_index, block in enumerate(blocks):
            block.requires_grad_(block_index >= first_trainable_index)

        any_trainable = trainable_blocks > 0
        for attribute_name in ("final_layer_norm", "norm", "layer_norm"):
            maybe_module = getattr(self.decoder, attribute_name, None)
            if isinstance(maybe_module, nn.Module):
                maybe_module.requires_grad_(any_trainable)

        decoder_embeddings = getattr(self.decoder, "embed_tokens", None)
        if isinstance(decoder_embeddings, nn.Module):
            decoder_embeddings.requires_grad_(any_trainable)

        # NOTE(review): these paths are resolved relative to this wrapper, which
        # defines no ``model``/``shared``/``lm_head`` attribute in this class, so
        # only "backbone.lm_head" can match here — confirm the others are intended.
        for module_path in (
            "lm_head",
            "model.lm_head",
            "backbone.lm_head",
            "shared",
            "model.shared",
        ):
            resolved = self._resolve_module_parent(self, module_path)
            if resolved is None:
                continue
            parent, child_name = resolved
            child = getattr(parent, child_name, None)
            if isinstance(child, nn.Module):
                child.requires_grad_(any_trainable)

        return trainable_blocks

    def initial_state(self, batch_size: int, device: torch.device | None = None) -> torch.Tensor:
        """Return ``z_init`` broadcast to ``(batch_size, z_slots, hidden)`` and RMS-normalised."""
        target_device = device if device is not None else self.device
        initial_state = self.z_init.unsqueeze(0).expand(batch_size, -1, -1).to(device=target_device)
        return self.recurrent_state_norm(initial_state)

    def initial_state_mask(self, batch_size: int, device: torch.device | None = None) -> torch.Tensor:
        """Return the initial ``(batch_size, z_slots)`` long slot mask.

        All slots start masked (zeros) only in "observed" mask mode with
        ``observation_hidden_compression`` proposals; otherwise all start active.
        """
        target_device = device if device is not None else self.device
        starts_empty = (
            self.latent_attention_mask_mode != "full"
            and self.thought_loop_proposal_mode == "observation_hidden_compression"
        )
        fill_value = 0 if starts_empty else 1
        return torch.full((batch_size, self.z_slots), fill_value, dtype=torch.long, device=target_device)

    def _build_temporal_tensor(
        self,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> torch.Tensor:
        """Stack ``log1p`` of the four time features along a new last dimension.

        Negative durations are clamped to zero first: ``log1p`` is negative
        below 0 and NaN below -1, and a NaN here would propagate through the
        whole encoder pass.

        Raises:
            RuntimeError: if explicit time features are disabled for this model.
        """
        if not self.use_explicit_time_features:
            raise RuntimeError("Temporal features were requested, but this model is configured to disable them.")
        features = (delta_seconds, elapsed_seconds, since_last_user_seconds, since_last_agent_seconds)
        return torch.stack([torch.log1p(value.clamp_min(0)) for value in features], dim=-1)

    def compress_hidden_states_to_slots(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        target_slots: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress each row's unmasked hidden states into ``target_slots`` slots.

        For a row with ``L`` valid positions and ``S`` slots:
        - if ``L <= S``, the valid states are copied into the first ``L`` slots;
        - otherwise, slot ``s`` is the mean of valid states ``[s*L//S, (s+1)*L//S)``.
        Unused slots are zero and masked out.

        Args:
            hidden_states: ``(batch, seq, hidden)`` states.
            attention_mask: ``(batch, seq)``; nonzero marks a valid position.
            target_slots: slot count ``S``; defaults to ``self.z_slots``.

        Returns:
            ``(compressed_hidden, compressed_mask)``, shaped ``(batch, S, hidden)``
            and ``(batch, S)`` (long).
        """
        slots = int(target_slots if target_slots is not None else self.z_slots)
        batch_size, _, hidden_size = hidden_states.shape
        compressed_hidden = hidden_states.new_zeros((batch_size, slots, hidden_size))
        compressed_mask = torch.zeros(batch_size, slots, dtype=torch.long, device=hidden_states.device)
        valid_positions = attention_mask.to(device=hidden_states.device, dtype=torch.bool)

        for batch_index in range(batch_size):
            valid_states = hidden_states[batch_index, valid_positions[batch_index]]
            valid_length = int(valid_states.shape[0])
            if valid_length <= 0:
                continue

            if valid_length <= slots:
                compressed_hidden[batch_index, :valid_length] = valid_states
                compressed_mask[batch_index, :valid_length] = 1
                continue

            # valid_length > slots here, so every chunk holds at least one state.
            for slot_index in range(slots):
                start = (slot_index * valid_length) // slots
                end = ((slot_index + 1) * valid_length) // slots
                compressed_hidden[batch_index, slot_index] = valid_states[start:end].mean(dim=0)
                compressed_mask[batch_index, slot_index] = 1

        return compressed_hidden, compressed_mask

    def _segment_embedding(self, segment_index: int, dtype: torch.dtype) -> torch.Tensor:
        """Return segment embedding ``segment_index`` shaped ``(1, 1, hidden)`` in ``dtype``."""
        return self.segment_embeddings.weight[segment_index].view(1, 1, -1).to(dtype=dtype)

    def _propose_latent_prefix_state(
        self,
        z_state: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        *,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """"latent_prefix" proposal: encode ``[z, time?, role, observation]`` and read back the z positions.

        All proposal slots are active.
        """
        batch_size = z_state.size(0)
        device = z_state.device
        recurrent_dtype = z_state.dtype

        aux_tokens: list[torch.Tensor] = []
        aux_masks: list[torch.Tensor] = []
        if self.use_explicit_time_features:
            temporal_features = self._build_temporal_tensor(
                delta_seconds=delta_seconds,
                elapsed_seconds=elapsed_seconds,
                since_last_user_seconds=since_last_user_seconds,
                since_last_agent_seconds=since_last_agent_seconds,
            ).to(dtype=recurrent_dtype)
            if self.temporal_projector is None:
                raise RuntimeError("use_explicit_time_features is set but temporal_projector is missing.")
            time_tokens = self.temporal_projector(temporal_features)
            aux_tokens.append(time_tokens + self._segment_embedding(1, recurrent_dtype))
            aux_masks.append(torch.ones(batch_size, self.num_time_tokens, device=device, dtype=torch.long))

        # One role embedding per row, shaped (batch, 1, hidden): used both as a
        # standalone role token and added to every observation token.
        role_embeds = self.observation_role_embeddings(observation_role_ids).unsqueeze(1).to(dtype=recurrent_dtype)
        aux_tokens.append(role_embeds + self._segment_embedding(1, recurrent_dtype))
        aux_masks.append(torch.ones(batch_size, 1, device=device, dtype=torch.long))

        observation_embeds = self.input_embeddings(observation_input_ids).to(dtype=recurrent_dtype)
        observation_embeds = observation_embeds + self._segment_embedding(2, recurrent_dtype)
        observation_embeds = observation_embeds + role_embeds

        z_tokens = z_state + self._segment_embedding(0, recurrent_dtype)
        encoder_inputs = torch.cat([z_tokens, *aux_tokens, observation_embeds], dim=1)
        z_mask = torch.ones(batch_size, self.z_slots, device=device, dtype=torch.long)
        encoder_attention_mask = torch.cat(
            [z_mask, *aux_masks, observation_attention_mask.to(device=device, dtype=torch.long)],
            dim=1,
        )

        encoder_outputs = self.encoder(
            inputs_embeds=encoder_inputs.to(dtype=self.backbone_dtype),
            attention_mask=encoder_attention_mask,
            return_dict=True,
        )
        proposed_state = encoder_outputs.last_hidden_state[:, : self.z_slots, :].to(dtype=recurrent_dtype)
        proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=device)
        return proposed_state, proposal_mask

    def _propose_compressed_observation_state(
        self,
        z_state: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        previous_state_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """"observation_hidden_compression" proposal.

        The observation is encoded, optionally with the current state prepended
        as context (only when enabled and at least one previous slot is active).
        Its hidden states are then compressed into ``z_slots`` slots. In "full"
        mask mode every slot ends up active: with state context the encoded z
        positions are used directly; otherwise, slots left empty by compression
        are filled from ``z_state``.
        """
        batch_size = z_state.size(0)
        device = z_state.device
        recurrent_dtype = z_state.dtype

        use_state_context = (
            self.observation_encoder_use_state_context
            and previous_state_mask is not None
            and bool(torch.any(previous_state_mask > 0))
        )
        if use_state_context:
            observation_embeds = self.input_embeddings(observation_input_ids).to(dtype=recurrent_dtype)
            encoder_inputs = torch.cat([z_state, observation_embeds], dim=1)
            encoder_attention_mask = torch.cat(
                [
                    previous_state_mask.to(device=device, dtype=torch.long),
                    observation_attention_mask.to(device=device, dtype=torch.long),
                ],
                dim=1,
            )
            encoder_hidden = self.encoder(
                inputs_embeds=encoder_inputs.to(dtype=self.backbone_dtype),
                attention_mask=encoder_attention_mask,
                return_dict=True,
            ).last_hidden_state
            if self.latent_attention_mask_mode == "full":
                proposed_state = encoder_hidden[:, : self.z_slots, :].to(dtype=recurrent_dtype)
                return proposed_state, torch.ones(batch_size, self.z_slots, dtype=torch.long, device=device)
            observation_outputs = encoder_hidden[:, self.z_slots :, :].to(dtype=recurrent_dtype)
        else:
            observation_outputs = self.encoder(
                input_ids=observation_input_ids,
                attention_mask=observation_attention_mask.long(),
                return_dict=True,
            ).last_hidden_state.to(dtype=recurrent_dtype)

        proposed_state, proposal_mask = self.compress_hidden_states_to_slots(
            observation_outputs,
            observation_attention_mask,
            target_slots=self.z_slots,
        )
        proposed_state = proposed_state.to(dtype=recurrent_dtype)
        if self.latent_attention_mask_mode == "full":
            proposed_state = proposed_state.clone()
            inactive_slot_mask = ~proposal_mask.bool()
            proposed_state[inactive_slot_mask] = z_state[inactive_slot_mask]
            proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=device)
        return proposed_state, proposal_mask

    @staticmethod
    def _fallback_to_current_state(
        z_state: torch.Tensor,
        proposed_state: torch.Tensor,
        proposal_mask: torch.Tensor,
        previous_state_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Handle rows whose proposal has no active slot.

        Such rows take ``z_state`` as their proposal, and their proposal mask
        becomes the previous state mask (all ones when none was given).
        Inputs are cloned before modification; untouched inputs are returned as-is.
        """
        no_observation_mask = proposal_mask.sum(dim=1) <= 0
        if not torch.any(no_observation_mask):
            return proposed_state, proposal_mask
        proposed_state = proposed_state.clone()
        proposal_mask = proposal_mask.clone()
        proposed_state[no_observation_mask] = z_state[no_observation_mask]
        if previous_state_mask is None:
            proposal_mask[no_observation_mask] = 1
        else:
            # Move before indexing so the boolean index and the mask share a device.
            previous = previous_state_mask.to(device=z_state.device, dtype=torch.long)
            proposal_mask[no_observation_mask] = previous[no_observation_mask]
        return proposed_state, proposal_mask

    def _gate_logits(self, next_state: torch.Tensor) -> torch.Tensor:
        """Attention-pool ``next_state`` with the learned query into one logit per row, shape ``(batch,)``."""
        batch_size = next_state.size(0)
        pooled_query = self.gate_query.to(dtype=next_state.dtype).expand(batch_size, -1, -1)
        gate_context = self.gate_context_norm(next_state)
        # NOTE(review): no key_padding_mask is passed, so slots whose state mask
        # is 0 still take part in pooling — confirm this is intended.
        pooled_state, _ = self.gate_pool(pooled_query, gate_context, gate_context, need_weights=False)
        return self.gate_head(pooled_state.squeeze(1)).squeeze(-1)

    def _rollout_step_impl(
        self,
        z_state: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
        previous_state_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance the latent state by one observation.

        Args:
            z_state: current latent state ``(batch, z_slots, hidden)``.
            observation_input_ids / observation_attention_mask: ``(batch, seq)`` observation tokens.
            observation_role_ids: ``(batch,)`` role index per row.
            delta_seconds, elapsed_seconds, since_last_user_seconds, since_last_agent_seconds:
                time features, only used in "latent_prefix" mode with explicit time features.
            previous_state_mask: optional ``(batch, z_slots)`` mask of active slots.

        Returns:
            ``(next_state, gate_logits, next_state_mask)``. The mask is the
            element-wise max of the previous mask and this step's proposal mask.
        """
        if self.thought_loop_proposal_mode == "latent_prefix":
            proposed_state, proposal_mask = self._propose_latent_prefix_state(
                z_state,
                observation_input_ids,
                observation_attention_mask,
                observation_role_ids,
                delta_seconds=delta_seconds,
                elapsed_seconds=elapsed_seconds,
                since_last_user_seconds=since_last_user_seconds,
                since_last_agent_seconds=since_last_agent_seconds,
            )
        else:
            proposed_state, proposal_mask = self._propose_compressed_observation_state(
                z_state,
                observation_input_ids,
                observation_attention_mask,
                previous_state_mask,
            )

        proposed_state, proposal_mask = self._fallback_to_current_state(
            z_state, proposed_state, proposal_mask, previous_state_mask
        )

        if not self.preserve_observation_encoder_manifold:
            proposed_state = self.proposed_state_norm(proposed_state)

        gate_inputs = self.state_gate_input_norm(torch.cat([z_state, proposed_state], dim=-1))
        update_gate = torch.sigmoid(self.state_gate(gate_inputs))
        blended_state = update_gate * proposed_state + (1.0 - update_gate) * z_state
        # Slots without an active proposal carry z_state through unchanged
        # (before the optional renormalisation below).
        next_state = torch.where(proposal_mask.bool().unsqueeze(-1), blended_state, z_state)
        if not self.preserve_observation_encoder_manifold:
            next_state = self.recurrent_state_norm(next_state)

        if previous_state_mask is None:
            next_state_mask = proposal_mask
        else:
            next_state_mask = torch.maximum(
                previous_state_mask.to(device=z_state.device, dtype=torch.long),
                proposal_mask,
            )

        return next_state, self._gate_logits(next_state), next_state_mask

    def rollout_step(
        self,
        z_state: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one rollout step without a state mask; return ``(next_state, gate_logits)``."""
        next_state, gate_logits, _ = self._rollout_step_impl(
            z_state=z_state,
            observation_input_ids=observation_input_ids,
            observation_attention_mask=observation_attention_mask,
            observation_role_ids=observation_role_ids,
            delta_seconds=delta_seconds,
            elapsed_seconds=elapsed_seconds,
            since_last_user_seconds=since_last_user_seconds,
            since_last_agent_seconds=since_last_agent_seconds,
            previous_state_mask=None,
        )
        return next_state, gate_logits

    def rollout_step_with_mask(
        self,
        z_state: torch.Tensor,
        state_mask: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one rollout step with a slot mask; return ``(next_state, gate_logits, next_state_mask)``."""
        return self._rollout_step_impl(
            z_state=z_state,
            observation_input_ids=observation_input_ids,
            observation_attention_mask=observation_attention_mask,
            observation_role_ids=observation_role_ids,
            delta_seconds=delta_seconds,
            elapsed_seconds=elapsed_seconds,
            since_last_user_seconds=since_last_user_seconds,
            since_last_agent_seconds=since_last_agent_seconds,
            previous_state_mask=state_mask,
        )

    def _resolve_generation_eos_token_id(self) -> int | list[int] | None:
        """EOS id(s) from the backbone generation config, falling back to the model config."""
        eos_token_id = getattr(self.backbone.generation_config, "eos_token_id", None)
        if eos_token_id is None:
            eos_token_id = getattr(self.backbone.config, "eos_token_id", None)
        return eos_token_id

    def _resolve_fill_token_id(self) -> int:
        """Token id used in place of ``-100`` labels when building decoder inputs.

        Preference order: first configured EOS id (list/tuple or int), tokenizer
        EOS, tokenizer pad, then 0.
        """
        eos_token_id = self._resolve_generation_eos_token_id()
        if isinstance(eos_token_id, (list, tuple)) and eos_token_id:
            return int(eos_token_id[0])
        if isinstance(eos_token_id, int):
            return eos_token_id
        if self.tokenizer.eos_token_id is not None:
            return int(self.tokenizer.eos_token_id)
        if self.tokenizer.pad_token_id is not None:
            return int(self.tokenizer.pad_token_id)
        return 0

    def _teacher_forced_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor:
        """Right-shift ``labels`` into decoder inputs, replacing ``-100`` with the fill token.

        Raises:
            ValueError: if ``labels`` is not rank 2 or has an empty batch.
        """
        if labels.ndim != 2 or labels.size(0) <= 0:
            raise ValueError("labels must be a rank-2 tensor with non-zero batch size.")
        safe_labels = labels.clone()
        safe_labels[safe_labels == -100] = self._resolve_fill_token_id()
        if hasattr(self.backbone, "prepare_decoder_input_ids_from_labels"):
            return self.backbone.prepare_decoder_input_ids_from_labels(labels=safe_labels)
        return self.backbone._shift_right(safe_labels)

    def _build_self_generated_decoder_inputs(
        self,
        labels: torch.Tensor,
        generated: torch.Tensor,
        *,
        self_generated_prefix_tokens: int,
    ) -> torch.Tensor:
        """Teacher-forced decoder inputs whose first ``k`` shifted positions come from ``generated``.

        ``k`` is clamped to ``[0, target_length - 1]``. A leading column equal
        to the decoder start token is dropped from ``generated``; a short
        generation is right-padded with the fill token.

        Raises:
            ValueError: if ``generated`` is not rank 2 or its batch size differs from ``labels``.
        """
        batch_size, target_length = labels.shape
        teacher_inputs = self._teacher_forced_decoder_inputs(labels).to(device=labels.device)
        if target_length <= 0:
            return teacher_inputs

        prefix_token_count = min(max(self_generated_prefix_tokens, 0), max(target_length - 1, 0))
        if prefix_token_count <= 0:
            return teacher_inputs

        generated_tokens = generated.to(device=labels.device, dtype=torch.long)
        if generated_tokens.ndim != 2:
            raise ValueError("generated tokens must be rank-2.")
        if generated_tokens.size(0) != batch_size:
            raise ValueError("generated batch size must match labels batch size.")
        if generated_tokens.size(1) > 0 and torch.equal(generated_tokens[:, :1], teacher_inputs[:, :1]):
            generated_tokens = generated_tokens[:, 1:]

        fill_token_id = self._resolve_fill_token_id()
        if generated_tokens.size(1) < prefix_token_count:
            padding = labels.new_full(
                (batch_size, prefix_token_count - generated_tokens.size(1)),
                fill_value=fill_token_id,
            )
            generated_tokens = torch.cat([generated_tokens, padding], dim=1)
        else:
            generated_tokens = generated_tokens[:, :prefix_token_count]

        # Input position i predicts label i, so it receives the token the model
        # itself generated for position i - 1.
        hybrid_inputs = teacher_inputs.clone()
        hybrid_inputs[:, 1 : 1 + prefix_token_count] = generated_tokens
        return hybrid_inputs

    @staticmethod
    def _encoder_mask(z_state: torch.Tensor, encoder_attention_mask: torch.Tensor | None) -> torch.Tensor:
        """Long ``(batch, slots)`` mask for ``z_state``: the given mask, or all ones."""
        if encoder_attention_mask is not None:
            return encoder_attention_mask.to(device=z_state.device, dtype=torch.long)
        return torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device)

    def _as_encoder_outputs(self, z_state: torch.Tensor) -> Any:
        """Wrap ``z_state`` (cast to the backbone dtype) as precomputed encoder outputs."""
        return BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype))

    @staticmethod
    def _zero_stats(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """``(0.0 loss, 0 count)``; the count is long like the non-empty code paths."""
        return torch.zeros((), device=device), torch.zeros((), dtype=torch.long, device=device)

    def decoder_loss(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Teacher-forced decoding loss from the latent state.

        Returns:
            ``(loss_sum, token_count)``: the backbone's mean loss multiplied by the
            number of non-``-100`` labels, and that count. Both are zero for an
            empty state or when no label is supervised.
        """
        if z_state.numel() == 0:
            return self._zero_stats(self.device)
        labels = labels.to(device=z_state.device)
        encoder_mask = self._encoder_mask(z_state, encoder_attention_mask)
        token_count = (labels != -100).sum()
        if token_count.item() == 0:
            return self._zero_stats(z_state.device)
        outputs = self.backbone(
            encoder_outputs=self._as_encoder_outputs(z_state),
            attention_mask=encoder_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs.loss * token_count, token_count

    def decoder_self_generated_loss(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        generation_kwargs: dict[str, Any] | None = None,
        self_generated_prefix_tokens: int | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decoding loss where the first decoder inputs come from the model's own greedy output.

        The backbone generates up to ``self_generated_prefix_tokens`` tokens
        (default: the label length) without gradients and in eval mode, with
        every submodule's train/eval mode restored afterwards. Those tokens
        replace the leading teacher-forced inputs; the loss is the summed
        cross-entropy over non-``-100`` labels. With a non-positive prefix this
        falls back to :meth:`decoder_loss`.

        Returns:
            ``(loss_sum, token_count)``.
        """
        if z_state.numel() == 0:
            return self._zero_stats(self.device)
        labels = labels.to(device=z_state.device)
        encoder_mask = self._encoder_mask(z_state, encoder_attention_mask)
        token_count = (labels != -100).sum()
        if token_count.item() == 0:
            return self._zero_stats(z_state.device)

        prefix_token_count = max(
            0,
            int(self_generated_prefix_tokens if self_generated_prefix_tokens is not None else labels.shape[1]),
        )
        if prefix_token_count <= 0:
            return self.decoder_loss(z_state, labels, encoder_attention_mask=encoder_attention_mask)

        effective_generation_kwargs: dict[str, Any] = {
            "max_new_tokens": prefix_token_count,
            "do_sample": False,
            "num_beams": 1,
            "return_dict_in_generate": False,
        }
        if generation_kwargs:
            effective_generation_kwargs.update(generation_kwargs)

        # Record every submodule's mode so a partially-frozen (eval) submodule is
        # not flipped to train mode when restoring.
        saved_modes = [(module, module.training) for module in self.backbone.modules()]
        self.backbone.eval()
        try:
            with torch.no_grad():
                generated = self.backbone.generate(
                    encoder_outputs=self._as_encoder_outputs(z_state),
                    attention_mask=encoder_mask,
                    **effective_generation_kwargs,
                )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        if not isinstance(generated, torch.Tensor):
            # Caller asked for return_dict_in_generate=True.
            generated = generated.sequences

        decoder_input_ids = self._build_self_generated_decoder_inputs(
            labels,
            generated,
            self_generated_prefix_tokens=prefix_token_count,
        )
        outputs = self.backbone(
            encoder_outputs=self._as_encoder_outputs(z_state),
            attention_mask=encoder_mask,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        )
        logits = outputs.logits.float()
        loss_sum = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1).to(device=logits.device),
            ignore_index=-100,
            reduction="sum",
        )
        return loss_sum, token_count

    @torch.no_grad()
    def first_token_exact_match_stats(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Count rows whose argmax first token equals ``labels[:, 0]``.

        Returns:
            ``(correct, valid_count)`` long tensors; rows whose first label is
            ``-100`` are excluded from both.
        """
        if z_state.numel() == 0:
            zero = torch.zeros((), dtype=torch.long, device=self.device)
            return zero, zero.clone()

        first_logits, valid_mask = self.first_token_logits(
            z_state,
            labels,
            encoder_attention_mask=encoder_attention_mask,
        )
        valid_count = valid_mask.sum()
        if valid_count.item() == 0:
            zero = torch.zeros((), dtype=torch.long, device=self.device)
            return zero, zero.clone()

        first_targets = labels[:, 0].to(device=first_logits.device)
        predicted_first_tokens = first_logits.argmax(dim=-1)
        correct = ((predicted_first_tokens == first_targets) & valid_mask.to(device=first_logits.device)).sum()
        return correct, valid_count

    def first_token_logits(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Logits for the first decoder position, plus a mask of rows with a supervised first label.

        Returns:
            ``(logits, valid_mask)`` shaped ``(batch, vocab)`` and ``(batch,)``;
            empty tensors for an empty state.
        """
        if z_state.numel() == 0:
            empty_logits = torch.zeros((0, 0), device=self.device)
            empty_mask = torch.zeros((0,), dtype=torch.bool, device=self.device)
            return empty_logits, empty_mask

        labels = labels.to(device=z_state.device)
        valid_mask = labels[:, 0] != -100
        decoder_input_ids = self._teacher_forced_decoder_inputs(labels)[:, :1].to(device=z_state.device)
        outputs = self.backbone(
            encoder_outputs=self._as_encoder_outputs(z_state),
            attention_mask=self._encoder_mask(z_state, encoder_attention_mask),
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        )
        return outputs.logits[:, 0, :], valid_mask

    @torch.no_grad()
    def generate_from_state(
        self,
        z_state: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
        **generation_kwargs: Any,
    ) -> torch.Tensor:
        """Run ``backbone.generate`` with ``z_state`` as the encoder outputs."""
        return self.backbone.generate(
            encoder_outputs=self._as_encoder_outputs(z_state),
            attention_mask=self._encoder_mask(z_state, encoder_attention_mask),
            **generation_kwargs,
        )

    def save_pretrained(self, output_dir: str | Path, tokenizer: Any | None = None) -> None:
        """Write the config, all weights, ``z_init`` on its own and the tokenizer to ``output_dir``."""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        (output_path / "sft_config.json").write_text(
            json.dumps(self.config, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        save_model(self, str(output_path / "model.safetensors"))
        save_file({"z_init": self.z_init.detach().cpu()}, str(output_path / "initial_latent_z.safetensors"))
        active_tokenizer = tokenizer if tokenizer is not None else self.tokenizer
        active_tokenizer.save_pretrained(output_path)

    @classmethod
    def from_pretrained(
        cls,
        model_path_or_repo_id: str,
        device: str | torch.device = "cpu",
        map_location: str | torch.device = "cpu",
    ) -> "ThoughtLoopT5Gemma":
        """Rebuild a model from a local directory or Hub repo written by :meth:`save_pretrained`.

        The backbone is first constructed from ``base_model_name`` in the saved
        config, then all weights are loaded from ``model.safetensors``.
        """
        local_path = Path(model_path_or_repo_id)
        if not local_path.exists():
            local_path = Path(snapshot_download(repo_id=model_path_or_repo_id))

        config = json.loads((local_path / "sft_config.json").read_text(encoding="utf-8"))
        model = cls(config)
        load_model(model, str(local_path / "model.safetensors"), device=str(map_location))
        model.to(device)
        return model
|
|