from __future__ import annotations import json import math import sys from pathlib import Path from typing import Any import torch import torch.nn as nn import torch.nn.functional as F from huggingface_hub import snapshot_download from safetensors.torch import load_model, save_file, save_model from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from transformers.modeling_outputs import BaseModelOutput SCRIPT_DIR = Path(__file__).resolve().parent if str(SCRIPT_DIR) not in sys.path: sys.path.insert(0, str(SCRIPT_DIR)) class ParameterlessRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.hidden_size = hidden_size self.eps = eps def forward(self, inputs: torch.Tensor) -> torch.Tensor: original_dtype = inputs.dtype normalized = inputs.float() normalized = normalized * torch.rsqrt(normalized.pow(2).mean(dim=-1, keepdim=True) + self.eps) return normalized.to(dtype=original_dtype) class TemporalFeatureProjector(nn.Module): def __init__(self, hidden_size: int, num_time_tokens: int, scalar_count: int = 4) -> None: super().__init__() self.hidden_size = hidden_size self.num_time_tokens = num_time_tokens self.mlp = nn.Sequential( nn.Linear(scalar_count, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size * num_time_tokens), ) self.norm = nn.LayerNorm(hidden_size) def forward(self, temporal_features: torch.Tensor) -> torch.Tensor: projected = self.mlp(temporal_features) projected = projected.view(temporal_features.size(0), self.num_time_tokens, self.hidden_size) return self.norm(projected) class ThoughtLoopT5Gemma(nn.Module): def __init__(self, config: dict[str, Any]) -> None: super().__init__() self.config = config model_cfg = config["model"] training_cfg = config.get("training", {}) backbone_dtype = getattr(torch, model_cfg.get("dtype", "bfloat16")) backbone_kwargs: dict[str, Any] = { # transformers currently warns that torch_dtype is deprecated in favor of dtype, # but torch_dtype remains compatible across more installed versions. "torch_dtype": backbone_dtype, } attn_implementation = model_cfg.get("attn_implementation") if attn_implementation: backbone_kwargs["attn_implementation"] = attn_implementation self.backbone = AutoModelForSeq2SeqLM.from_pretrained(model_cfg["base_model_name"], **backbone_kwargs) self.tokenizer = AutoTokenizer.from_pretrained(model_cfg["base_model_name"]) if training_cfg.get("gradient_checkpointing", True): self.backbone.gradient_checkpointing_enable() if getattr(self.backbone.config, "use_cache", None) is not None: self.backbone.config.use_cache = False self.encoder = self.backbone.get_encoder() self.decoder = self.backbone.get_decoder() self.input_embeddings = self.backbone.get_input_embeddings() hidden_size = int(self.backbone.config.decoder.hidden_size) self.hidden_size = hidden_size self.z_slots = int(model_cfg["z_slots"]) self.thought_loop_proposal_mode = str( model_cfg.get("thought_loop_proposal_mode", "latent_prefix") ).strip().lower() if self.thought_loop_proposal_mode not in {"latent_prefix", "observation_hidden_compression"}: raise ValueError(f"Unsupported thought_loop_proposal_mode: {self.thought_loop_proposal_mode}") self.preserve_observation_encoder_manifold = bool( model_cfg.get( "preserve_observation_encoder_manifold", self.thought_loop_proposal_mode == "observation_hidden_compression", ) ) self.observation_encoder_use_state_context = bool( model_cfg.get("observation_encoder_use_state_context", False) ) self.latent_attention_mask_mode = str( model_cfg.get("latent_attention_mask_mode", "observed") ).strip().lower() if self.latent_attention_mask_mode not in {"observed", "full"}: raise ValueError(f"Unsupported latent_attention_mask_mode: {self.latent_attention_mask_mode}") self.use_explicit_time_features = bool(model_cfg.get("use_explicit_time_features", False)) self.num_time_tokens = int(model_cfg["num_time_tokens"]) if self.use_explicit_time_features else 0 self.observation_role_count = 3 magicnorm_eps = float(model_cfg.get("magicnorm_eps", 1e-6)) # Keep the newly initialized recurrent/gating modules in fp32. We cast only the # tensors handed into the bf16 backbone. This avoids dtype crashes while preserving # stable optimizer state for the custom modules. self.z_init = nn.Parameter(torch.randn(self.z_slots, hidden_size) * 0.02) self.segment_embeddings = nn.Embedding(3, hidden_size) self.observation_role_embeddings = nn.Embedding(self.observation_role_count, hidden_size) self.temporal_projector = ( TemporalFeatureProjector(hidden_size, self.num_time_tokens) if self.use_explicit_time_features else None ) self.state_gate = nn.Sequential( nn.Linear(hidden_size * 2, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size), ) self._initialize_update_gate_bias(model_cfg) self.state_gate_input_norm = ParameterlessRMSNorm(hidden_size * 2, eps=magicnorm_eps) self.proposed_state_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps) self.recurrent_state_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps) self.gate_context_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps) self.gate_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) self.gate_pool = nn.MultiheadAttention( embed_dim=hidden_size, num_heads=int(model_cfg.get("gate_attention_heads", 4)), batch_first=True, ) self.gate_head = nn.Sequential( nn.LayerNorm(hidden_size), nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, 1), ) self._drop_unused_vision_modules() def _initialize_update_gate_bias(self, model_cfg: dict[str, Any]) -> None: if "initial_update_gate_bias" not in model_cfg: return final_gate_layer = self.state_gate[-1] if not isinstance(final_gate_layer, nn.Linear) or final_gate_layer.bias is None: raise RuntimeError("state_gate must end in a Linear layer with bias to use initial_update_gate_bias.") nn.init.constant_(final_gate_layer.bias, float(model_cfg["initial_update_gate_bias"])) def _drop_unused_vision_modules(self) -> None: for module_path in ( "encoder.vision_tower", "encoder.vision_model", "model.encoder.vision_tower", "model.encoder.vision_model", ): parent: Any = self.backbone path_parts = module_path.split(".") try: for part in path_parts[:-1]: parent = getattr(parent, part) except AttributeError: continue child_name = path_parts[-1] child = getattr(parent, child_name, None) if isinstance(child, nn.Module): setattr(parent, child_name, None) @property def device(self) -> torch.device: return next(self.parameters()).device @property def backbone_dtype(self) -> torch.dtype: return self.input_embeddings.weight.dtype @property def recurrent_dtype(self) -> torch.dtype: return self.z_init.dtype def trainable_parameter_count(self) -> int: return sum(parameter.numel() for parameter in self.parameters() if parameter.requires_grad) def _decoder_blocks(self) -> Any: for attribute_name in ("block", "layers", "h"): blocks = getattr(self.decoder, attribute_name, None) if blocks is not None: return blocks return None def decoder_layer_count(self) -> int: blocks = self._decoder_blocks() return len(blocks) if blocks is not None else 0 def set_gate_trainable(self, trainable: bool) -> None: self.gate_query.requires_grad_(trainable) self.gate_pool.requires_grad_(trainable) self.gate_head.requires_grad_(trainable) self.gate_context_norm.requires_grad_(trainable) def set_decoder_trainable_fraction(self, trainable_fraction: float) -> int: blocks = self._decoder_blocks() if blocks is None: return 0 clamped_fraction = min(max(trainable_fraction, 0.0), 1.0) total_blocks = len(blocks) trainable_blocks = min(total_blocks, math.ceil(total_blocks * clamped_fraction)) if clamped_fraction > 0 else 0 first_trainable_index = total_blocks - trainable_blocks for block_index, block in enumerate(blocks): block.requires_grad_(block_index >= first_trainable_index) for attribute_name in ("final_layer_norm", "norm", "layer_norm"): maybe_module = getattr(self.decoder, attribute_name, None) if isinstance(maybe_module, nn.Module): maybe_module.requires_grad_(trainable_blocks > 0) decoder_embeddings = getattr(self.decoder, "embed_tokens", None) if isinstance(decoder_embeddings, nn.Module): decoder_embeddings.requires_grad_(trainable_blocks > 0) for module_path in ( "lm_head", "model.lm_head", "backbone.lm_head", "shared", "model.shared", ): parent: Any = self path_parts = module_path.split(".") try: for part in path_parts[:-1]: parent = getattr(parent, part) except AttributeError: continue child_name = path_parts[-1] child = getattr(parent, child_name, None) if isinstance(child, nn.Module): child.requires_grad_(trainable_blocks > 0) return trainable_blocks def initial_state(self, batch_size: int, device: torch.device | None = None) -> torch.Tensor: target_device = device or self.device initial_state = self.z_init.unsqueeze(0).expand(batch_size, -1, -1).to(device=target_device) return self.recurrent_state_norm(initial_state) def initial_state_mask(self, batch_size: int, device: torch.device | None = None) -> torch.Tensor: target_device = device or self.device if self.latent_attention_mask_mode == "full": return torch.ones(batch_size, self.z_slots, dtype=torch.long, device=target_device) if self.thought_loop_proposal_mode == "observation_hidden_compression": return torch.zeros(batch_size, self.z_slots, dtype=torch.long, device=target_device) return torch.ones(batch_size, self.z_slots, dtype=torch.long, device=target_device) def _build_temporal_tensor( self, delta_seconds: torch.Tensor, elapsed_seconds: torch.Tensor, since_last_user_seconds: torch.Tensor, since_last_agent_seconds: torch.Tensor, ) -> torch.Tensor: if not self.use_explicit_time_features: raise RuntimeError("Temporal features were requested, but this model is configured to disable them.") stacked = torch.stack( [ torch.log1p(delta_seconds), torch.log1p(elapsed_seconds), torch.log1p(since_last_user_seconds), torch.log1p(since_last_agent_seconds), ], dim=-1, ) return stacked def compress_hidden_states_to_slots( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, target_slots: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: slots = int(target_slots if target_slots is not None else self.z_slots) batch_size, _, hidden_size = hidden_states.shape compressed_hidden = hidden_states.new_zeros((batch_size, slots, hidden_size)) compressed_mask = torch.zeros(batch_size, slots, dtype=torch.long, device=hidden_states.device) attention_mask = attention_mask.to(device=hidden_states.device, dtype=torch.bool) for batch_index in range(batch_size): valid_states = hidden_states[batch_index, attention_mask[batch_index]] valid_length = int(valid_states.shape[0]) if valid_length <= 0: continue if valid_length <= slots: compressed_hidden[batch_index, :valid_length] = valid_states compressed_mask[batch_index, :valid_length] = 1 continue for slot_index in range(slots): start = (slot_index * valid_length) // slots end = ((slot_index + 1) * valid_length) // slots if end <= start: end = min(valid_length, start + 1) compressed_hidden[batch_index, slot_index] = valid_states[start:end].mean(dim=0) compressed_mask[batch_index, slot_index] = 1 return compressed_hidden, compressed_mask def _rollout_step_impl( self, z_state: torch.Tensor, observation_input_ids: torch.Tensor, observation_attention_mask: torch.Tensor, observation_role_ids: torch.Tensor, delta_seconds: torch.Tensor, elapsed_seconds: torch.Tensor, since_last_user_seconds: torch.Tensor, since_last_agent_seconds: torch.Tensor, previous_state_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size = z_state.size(0) recurrent_dtype = z_state.dtype backbone_dtype = self.backbone_dtype if self.thought_loop_proposal_mode == "latent_prefix": aux_tokens: list[torch.Tensor] = [] aux_masks: list[torch.Tensor] = [] if self.use_explicit_time_features: temporal_features = self._build_temporal_tensor( delta_seconds=delta_seconds, elapsed_seconds=elapsed_seconds, since_last_user_seconds=since_last_user_seconds, since_last_agent_seconds=since_last_agent_seconds, ).to(dtype=recurrent_dtype) assert self.temporal_projector is not None time_tokens = self.temporal_projector(temporal_features) time_tokens = time_tokens + self.segment_embeddings.weight[1].view(1, 1, -1).to(dtype=recurrent_dtype) aux_tokens.append(time_tokens) aux_masks.append(torch.ones(batch_size, self.num_time_tokens, device=z_state.device, dtype=torch.long)) role_tokens = self.observation_role_embeddings(observation_role_ids).unsqueeze(1).to(dtype=recurrent_dtype) role_tokens = role_tokens + self.segment_embeddings.weight[1].view(1, 1, -1).to(dtype=recurrent_dtype) aux_tokens.append(role_tokens) aux_masks.append(torch.ones(batch_size, 1, device=z_state.device, dtype=torch.long)) observation_embeds = self.input_embeddings(observation_input_ids).to(dtype=recurrent_dtype) observation_embeds = observation_embeds + self.segment_embeddings.weight[2].view(1, 1, -1).to( dtype=recurrent_dtype ) observation_embeds = observation_embeds + self.observation_role_embeddings(observation_role_ids).unsqueeze(1).to( dtype=recurrent_dtype ) z_tokens = z_state + self.segment_embeddings.weight[0].view(1, 1, -1).to(dtype=recurrent_dtype) encoder_inputs = torch.cat([z_tokens, *aux_tokens, observation_embeds], dim=1) z_mask = torch.ones(batch_size, self.z_slots, device=z_state.device, dtype=torch.long) encoder_attention_mask = torch.cat([z_mask, *aux_masks, observation_attention_mask.long()], dim=1) # The pretrained T5Gemma backbone was loaded in bf16. Passing fp32 inputs_embeds # into its bf16 Linear layers causes: expected mat1 and mat2 to have same dtype. encoder_outputs = self.encoder( inputs_embeds=encoder_inputs.to(dtype=backbone_dtype), attention_mask=encoder_attention_mask, return_dict=True, ) latent_prefix_state = encoder_outputs.last_hidden_state[:, : self.z_slots, :].to(dtype=recurrent_dtype) proposed_state = latent_prefix_state proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=z_state.device) else: use_state_context = ( self.observation_encoder_use_state_context and previous_state_mask is not None and torch.any(previous_state_mask > 0) ) if use_state_context: observation_embeds = self.input_embeddings(observation_input_ids).to(dtype=recurrent_dtype) encoder_inputs = torch.cat([z_state, observation_embeds], dim=1) encoder_attention_mask = torch.cat( [ previous_state_mask.to(device=z_state.device, dtype=torch.long), observation_attention_mask.long(), ], dim=1, ) observation_encoder_outputs = self.encoder( inputs_embeds=encoder_inputs.to(dtype=backbone_dtype), attention_mask=encoder_attention_mask, return_dict=True, ) if self.latent_attention_mask_mode == "full": proposed_state = observation_encoder_outputs.last_hidden_state[:, : self.z_slots, :].to( dtype=recurrent_dtype ) proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=z_state.device) observation_outputs = None else: observation_outputs = observation_encoder_outputs.last_hidden_state[:, self.z_slots :, :].to( dtype=recurrent_dtype ) else: observation_encoder_outputs = self.encoder( input_ids=observation_input_ids, attention_mask=observation_attention_mask.long(), return_dict=True, ) observation_outputs = observation_encoder_outputs.last_hidden_state.to(dtype=recurrent_dtype) if observation_outputs is not None: proposed_state, proposal_mask = self.compress_hidden_states_to_slots( observation_outputs, observation_attention_mask, target_slots=self.z_slots, ) proposed_state = proposed_state.to(dtype=recurrent_dtype) if self.latent_attention_mask_mode == "full": proposed_state = proposed_state.clone() inactive_slot_mask = ~proposal_mask.bool() proposed_state[inactive_slot_mask] = z_state[inactive_slot_mask] proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=z_state.device) no_observation_mask = proposal_mask.sum(dim=1) <= 0 if torch.any(no_observation_mask): proposed_state = proposed_state.clone() proposal_mask = proposal_mask.clone() proposed_state[no_observation_mask] = z_state[no_observation_mask] if previous_state_mask is None: proposal_mask[no_observation_mask] = 1 else: proposal_mask[no_observation_mask] = previous_state_mask[no_observation_mask].to( device=z_state.device, dtype=torch.long, ) if not self.preserve_observation_encoder_manifold: proposed_state = self.proposed_state_norm(proposed_state) gate_inputs = self.state_gate_input_norm(torch.cat([z_state, proposed_state], dim=-1)) update_gate = torch.sigmoid(self.state_gate(gate_inputs)) raw_next_state = update_gate * proposed_state + (1.0 - update_gate) * z_state active_slot_mask = proposal_mask.bool().unsqueeze(-1) next_state = torch.where(active_slot_mask, raw_next_state, z_state) if not self.preserve_observation_encoder_manifold: next_state = self.recurrent_state_norm(next_state) if previous_state_mask is not None: next_state_mask = torch.maximum( previous_state_mask.to(device=z_state.device, dtype=torch.long), proposal_mask, ) else: next_state_mask = proposal_mask pooled_query = self.gate_query.to(dtype=next_state.dtype).expand(batch_size, -1, -1) gate_context = self.gate_context_norm(next_state) pooled_state, _ = self.gate_pool(pooled_query, gate_context, gate_context, need_weights=False) gate_logits = self.gate_head(pooled_state.squeeze(1)).squeeze(-1) return next_state, gate_logits, next_state_mask def rollout_step( self, z_state: torch.Tensor, observation_input_ids: torch.Tensor, observation_attention_mask: torch.Tensor, observation_role_ids: torch.Tensor, delta_seconds: torch.Tensor, elapsed_seconds: torch.Tensor, since_last_user_seconds: torch.Tensor, since_last_agent_seconds: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: next_state, gate_logits, _ = self._rollout_step_impl( z_state=z_state, observation_input_ids=observation_input_ids, observation_attention_mask=observation_attention_mask, observation_role_ids=observation_role_ids, delta_seconds=delta_seconds, elapsed_seconds=elapsed_seconds, since_last_user_seconds=since_last_user_seconds, since_last_agent_seconds=since_last_agent_seconds, previous_state_mask=None, ) return next_state, gate_logits def rollout_step_with_mask( self, z_state: torch.Tensor, state_mask: torch.Tensor, observation_input_ids: torch.Tensor, observation_attention_mask: torch.Tensor, observation_role_ids: torch.Tensor, delta_seconds: torch.Tensor, elapsed_seconds: torch.Tensor, since_last_user_seconds: torch.Tensor, since_last_agent_seconds: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self._rollout_step_impl( z_state=z_state, observation_input_ids=observation_input_ids, observation_attention_mask=observation_attention_mask, observation_role_ids=observation_role_ids, delta_seconds=delta_seconds, elapsed_seconds=elapsed_seconds, since_last_user_seconds=since_last_user_seconds, since_last_agent_seconds=since_last_agent_seconds, previous_state_mask=state_mask, ) def _resolve_generation_eos_token_id(self) -> int | list[int] | None: eos_token_id = getattr(self.backbone.generation_config, "eos_token_id", None) if eos_token_id is None: eos_token_id = getattr(self.backbone.config, "eos_token_id", None) return eos_token_id def _resolve_fill_token_id(self) -> int: eos_token_id = self._resolve_generation_eos_token_id() if isinstance(eos_token_id, list) and eos_token_id: return int(eos_token_id[0]) if isinstance(eos_token_id, int): return eos_token_id if self.tokenizer.eos_token_id is not None: return int(self.tokenizer.eos_token_id) if self.tokenizer.pad_token_id is not None: return int(self.tokenizer.pad_token_id) return 0 def _teacher_forced_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor: if labels.ndim != 2 or labels.size(0) <= 0: raise ValueError("labels must be a rank-2 tensor with non-zero batch size.") safe_labels = labels.clone() safe_labels[safe_labels == -100] = self._resolve_fill_token_id() if hasattr(self.backbone, "prepare_decoder_input_ids_from_labels"): return self.backbone.prepare_decoder_input_ids_from_labels(labels=safe_labels) return self.backbone._shift_right(safe_labels) def _build_self_generated_decoder_inputs( self, labels: torch.Tensor, generated: torch.Tensor, *, self_generated_prefix_tokens: int, ) -> torch.Tensor: batch_size, target_length = labels.shape teacher_inputs = self._teacher_forced_decoder_inputs(labels).to(device=labels.device) if target_length <= 0: return teacher_inputs prefix_token_count = min(max(self_generated_prefix_tokens, 0), max(target_length - 1, 0)) if prefix_token_count <= 0: return teacher_inputs generated_tokens = generated.to(device=labels.device, dtype=torch.long) if generated_tokens.ndim != 2: raise ValueError("generated tokens must be rank-2.") if generated_tokens.size(0) != batch_size: raise ValueError("generated batch size must match labels batch size.") if generated_tokens.size(1) > 0 and torch.equal(generated_tokens[:, :1], teacher_inputs[:, :1]): generated_tokens = generated_tokens[:, 1:] fill_token_id = self._resolve_fill_token_id() if generated_tokens.size(1) < prefix_token_count: padding = labels.new_full( (batch_size, prefix_token_count - generated_tokens.size(1)), fill_value=fill_token_id, ) generated_tokens = torch.cat([generated_tokens, padding], dim=1) else: generated_tokens = generated_tokens[:, :prefix_token_count] hybrid_inputs = teacher_inputs.clone() hybrid_inputs[:, 1 : 1 + prefix_token_count] = generated_tokens return hybrid_inputs def decoder_loss( self, z_state: torch.Tensor, labels: torch.Tensor, *, encoder_attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if z_state.numel() == 0: zero = torch.zeros((), device=self.device) return zero, zero encoder_mask = ( encoder_attention_mask.to(device=z_state.device, dtype=torch.long) if encoder_attention_mask is not None else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device) ) token_count = (labels != -100).sum() if token_count.item() == 0: zero = torch.zeros((), device=z_state.device) return zero, zero outputs = self.backbone( encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)), attention_mask=encoder_mask, labels=labels, return_dict=True, ) loss_sum = outputs.loss * token_count return loss_sum, token_count def decoder_self_generated_loss( self, z_state: torch.Tensor, labels: torch.Tensor, *, generation_kwargs: dict[str, Any] | None = None, self_generated_prefix_tokens: int | None = None, encoder_attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if z_state.numel() == 0: zero = torch.zeros((), device=self.device) return zero, zero encoder_mask = ( encoder_attention_mask.to(device=z_state.device, dtype=torch.long) if encoder_attention_mask is not None else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device) ) token_count = (labels != -100).sum() if token_count.item() == 0: zero = torch.zeros((), device=z_state.device) return zero, zero prefix_token_count = max(0, int(self_generated_prefix_tokens if self_generated_prefix_tokens is not None else labels.shape[1])) if prefix_token_count <= 0: return self.decoder_loss(z_state, labels, encoder_attention_mask=encoder_attention_mask) effective_generation_kwargs = { "max_new_tokens": prefix_token_count, "do_sample": False, "num_beams": 1, "return_dict_in_generate": False, } if generation_kwargs: effective_generation_kwargs.update(generation_kwargs) was_training = self.backbone.training if was_training: self.backbone.eval() try: with torch.no_grad(): generated = self.backbone.generate( encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)), attention_mask=encoder_mask, **effective_generation_kwargs, ) finally: if was_training: self.backbone.train() decoder_input_ids = self._build_self_generated_decoder_inputs( labels, generated, self_generated_prefix_tokens=prefix_token_count, ) outputs = self.backbone( encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)), attention_mask=encoder_mask, decoder_input_ids=decoder_input_ids, return_dict=True, ) logits = outputs.logits.float() loss_sum = F.cross_entropy( logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100, reduction="sum", ) return loss_sum, token_count @torch.no_grad() def first_token_exact_match_stats( self, z_state: torch.Tensor, labels: torch.Tensor, *, encoder_attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if z_state.numel() == 0: zero = torch.zeros((), device=self.device) return zero, zero first_logits, valid_mask = self.first_token_logits( z_state, labels, encoder_attention_mask=encoder_attention_mask, ) valid_count = valid_mask.sum() if valid_count.item() == 0: zero = torch.zeros((), device=self.device) return zero, zero first_targets = labels[:, 0] predicted_first_tokens = first_logits.argmax(dim=-1) correct = ((predicted_first_tokens == first_targets) & valid_mask).sum() return correct, valid_count def first_token_logits( self, z_state: torch.Tensor, labels: torch.Tensor, *, encoder_attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if z_state.numel() == 0: empty_logits = torch.zeros((0, 0), device=self.device) empty_mask = torch.zeros((0,), dtype=torch.bool, device=self.device) return empty_logits, empty_mask first_targets = labels[:, 0] valid_mask = first_targets != -100 encoder_mask = ( encoder_attention_mask.to(device=z_state.device, dtype=torch.long) if encoder_attention_mask is not None else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device) ) decoder_input_ids = self._teacher_forced_decoder_inputs(labels)[:, :1].to(device=z_state.device) outputs = self.backbone( encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)), attention_mask=encoder_mask, decoder_input_ids=decoder_input_ids, return_dict=True, ) return outputs.logits[:, 0, :], valid_mask @torch.no_grad() def generate_from_state( self, z_state: torch.Tensor, *, encoder_attention_mask: torch.Tensor | None = None, **generation_kwargs: Any, ) -> torch.Tensor: encoder_mask = ( encoder_attention_mask.to(device=z_state.device, dtype=torch.long) if encoder_attention_mask is not None else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device) ) return self.backbone.generate( encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)), attention_mask=encoder_mask, **generation_kwargs, ) def save_pretrained(self, output_dir: str | Path, tokenizer: Any | None = None) -> None: output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) (output_path / "sft_config.json").write_text( json.dumps(self.config, indent=2, ensure_ascii=False), encoding="utf-8", ) # The T5Gemma backbone exposes tied/shared weights, which raw save_file refuses. save_model(self, str(output_path / "model.safetensors")) save_file({"z_init": self.z_init.detach().cpu()}, str(output_path / "initial_latent_z.safetensors")) active_tokenizer = tokenizer or self.tokenizer active_tokenizer.save_pretrained(output_path) @classmethod def from_pretrained( cls, model_path_or_repo_id: str, device: str | torch.device = "cpu", map_location: str | torch.device = "cpu", ) -> "ThoughtLoopT5Gemma": local_path = Path(model_path_or_repo_id) if not local_path.exists(): local_path = Path(snapshot_download(repo_id=model_path_or_repo_id)) config = json.loads((local_path / "sft_config.json").read_text(encoding="utf-8")) model = cls(config) load_model(model, str(local_path / "model.safetensors"), device=str(map_location)) model.to(device) return model