# test-true2 / model.py
# Uploaded by BRlkl via huggingface_hub ("Upload folder using huggingface_hub"),
# commit bc7437c (verified), 34.7 kB.
from __future__ import annotations
import json
import math
import sys
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors.torch import load_model, save_file, save_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput
SCRIPT_DIR = Path(__file__).resolve().parent
# Put this file's own directory at the front of the import search path so that
# modules stored next to it can be imported by bare name. The membership check
# keeps repeated imports of this module from adding duplicate entries.
if str(SCRIPT_DIR) not in sys.path:
    sys.path[:0] = [str(SCRIPT_DIR)]
class ParameterlessRMSNorm(nn.Module):
    """RMS normalization over the last dimension with no learnable scale or bias.

    The statistics and the scaling are computed in float32; the result is cast
    back to the dtype of the input. ``hidden_size`` is stored on the module but
    ``forward`` does not read it.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``inputs * rsqrt(mean(inputs ** 2, dim=-1) + eps)`` in the input dtype."""
        input_dtype = inputs.dtype
        as_float = inputs.float()
        mean_square = as_float.pow(2).mean(dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return (as_float * inverse_rms).to(dtype=input_dtype)
class TemporalFeatureProjector(nn.Module):
    """Project a small vector of scalar time features into ``num_time_tokens`` embeddings.

    A two-layer MLP maps ``scalar_count`` input features to
    ``hidden_size * num_time_tokens`` values, which are split into
    ``num_time_tokens`` vectors of size ``hidden_size`` and layer-normalized.
    """

    def __init__(self, hidden_size: int, num_time_tokens: int, scalar_count: int = 4) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_time_tokens = num_time_tokens
        self.mlp = nn.Sequential(
            nn.Linear(scalar_count, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size * num_time_tokens),
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, temporal_features: torch.Tensor) -> torch.Tensor:
        """Map ``(..., scalar_count)`` features to ``(..., num_time_tokens, hidden_size)`` tokens.

        Args:
            temporal_features: Feature tensor whose last dimension is
                ``scalar_count``; any number of leading dimensions (including none)
                is accepted.

        Returns:
            Layer-normalized time tokens with the leading dimensions of the input
            followed by ``(num_time_tokens, hidden_size)``.
        """
        projected = self.mlp(temporal_features)
        # Split only the last dimension into tokens and keep every leading dimension,
        # so (batch, scalar_count) and e.g. (batch, steps, scalar_count) inputs both work.
        projected = projected.view(*temporal_features.shape[:-1], self.num_time_tokens, self.hidden_size)
        return self.norm(projected)
class ThoughtLoopT5Gemma(nn.Module):
    """Recurrent latent-state ("thought loop") wrapper around a seq2seq backbone.

    The model carries ``z_slots`` latent vectors across steps. Each rollout step
    runs a new observation through the backbone encoder, proposes an updated
    latent state, blends it with the previous state through a learned sigmoid
    gate, and scores the blended state with an attention-pooled gate head. The
    latent state is handed to the backbone decoder in place of encoder outputs
    for the teacher-forced loss, the self-generated-prefix loss, and generation.

    Newly added modules are created in the default dtype (normally fp32); tensors
    are cast to the backbone dtype only when they are passed into the backbone.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        """Build the backbone, tokenizer, and recurrent/gating modules from ``config``.

        Args:
            config: Mapping with a required ``"model"`` section (``base_model_name``,
                ``z_slots``, plus optional ``dtype``, ``attn_implementation``,
                ``thought_loop_proposal_mode``, ``preserve_observation_encoder_manifold``,
                ``observation_encoder_use_state_context``, ``latent_attention_mask_mode``,
                ``use_explicit_time_features`` / ``num_time_tokens``, ``magicnorm_eps``,
                ``gate_attention_heads``, ``initial_update_gate_bias``) and an optional
                ``"training"`` section (``gradient_checkpointing``, default ``True``).

        Raises:
            KeyError: If a required config key is missing.
            ValueError: If ``thought_loop_proposal_mode`` or ``latent_attention_mask_mode``
                has an unsupported value.
        """
        super().__init__()
        self.config = config
        model_cfg = config["model"]
        training_cfg = config.get("training", {})
        backbone_dtype = getattr(torch, model_cfg.get("dtype", "bfloat16"))
        backbone_kwargs: dict[str, Any] = {
            # transformers currently warns that torch_dtype is deprecated in favor of dtype,
            # but torch_dtype remains compatible across more installed versions.
            "torch_dtype": backbone_dtype,
        }
        attn_implementation = model_cfg.get("attn_implementation")
        if attn_implementation:
            backbone_kwargs["attn_implementation"] = attn_implementation
        self.backbone = AutoModelForSeq2SeqLM.from_pretrained(model_cfg["base_model_name"], **backbone_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_cfg["base_model_name"])
        if training_cfg.get("gradient_checkpointing", True):
            self.backbone.gradient_checkpointing_enable()
            # The KV cache is disabled alongside gradient checkpointing.
            if getattr(self.backbone.config, "use_cache", None) is not None:
                self.backbone.config.use_cache = False
        self.encoder = self.backbone.get_encoder()
        self.decoder = self.backbone.get_decoder()
        self.input_embeddings = self.backbone.get_input_embeddings()
        hidden_size = int(self.backbone.config.decoder.hidden_size)
        self.hidden_size = hidden_size
        self.z_slots = int(model_cfg["z_slots"])
        self.thought_loop_proposal_mode = str(
            model_cfg.get("thought_loop_proposal_mode", "latent_prefix")
        ).strip().lower()
        if self.thought_loop_proposal_mode not in {"latent_prefix", "observation_hidden_compression"}:
            raise ValueError(f"Unsupported thought_loop_proposal_mode: {self.thought_loop_proposal_mode}")
        # When true, the proposed and next states skip the RMS normalizations in
        # _rollout_step_impl. Defaults to true only for observation_hidden_compression.
        self.preserve_observation_encoder_manifold = bool(
            model_cfg.get(
                "preserve_observation_encoder_manifold",
                self.thought_loop_proposal_mode == "observation_hidden_compression",
            )
        )
        self.observation_encoder_use_state_context = bool(
            model_cfg.get("observation_encoder_use_state_context", False)
        )
        self.latent_attention_mask_mode = str(
            model_cfg.get("latent_attention_mask_mode", "observed")
        ).strip().lower()
        if self.latent_attention_mask_mode not in {"observed", "full"}:
            raise ValueError(f"Unsupported latent_attention_mask_mode: {self.latent_attention_mask_mode}")
        self.use_explicit_time_features = bool(model_cfg.get("use_explicit_time_features", False))
        self.num_time_tokens = int(model_cfg["num_time_tokens"]) if self.use_explicit_time_features else 0
        self.observation_role_count = 3
        magicnorm_eps = float(model_cfg.get("magicnorm_eps", 1e-6))
        # Keep the newly initialized recurrent/gating modules in fp32. We cast only the
        # tensors handed into the bf16 backbone. This avoids dtype crashes while preserving
        # stable optimizer state for the custom modules.
        self.z_init = nn.Parameter(torch.randn(self.z_slots, hidden_size) * 0.02)
        # Segment ids used in _rollout_step_impl: 0 = latent slot, 1 = auxiliary
        # (time/role) token, 2 = observation token.
        self.segment_embeddings = nn.Embedding(3, hidden_size)
        self.observation_role_embeddings = nn.Embedding(self.observation_role_count, hidden_size)
        self.temporal_projector = (
            TemporalFeatureProjector(hidden_size, self.num_time_tokens) if self.use_explicit_time_features else None
        )
        # Per-dimension update gate computed from [previous state, proposed state].
        self.state_gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self._initialize_update_gate_bias(model_cfg)
        self.state_gate_input_norm = ParameterlessRMSNorm(hidden_size * 2, eps=magicnorm_eps)
        self.proposed_state_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps)
        self.recurrent_state_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps)
        self.gate_context_norm = ParameterlessRMSNorm(hidden_size, eps=magicnorm_eps)
        # A single learned query attends over the latent slots to produce one
        # pooled vector per sample, which gate_head turns into a scalar logit.
        self.gate_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.gate_pool = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=int(model_cfg.get("gate_attention_heads", 4)),
            batch_first=True,
        )
        self.gate_head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1),
        )
        self._drop_unused_vision_modules()

    def _initialize_update_gate_bias(self, model_cfg: dict[str, Any]) -> None:
        """Set the bias of the final ``state_gate`` layer to ``initial_update_gate_bias`` if configured.

        Raises:
            RuntimeError: If ``state_gate`` does not end in a ``nn.Linear`` with a bias.
        """
        if "initial_update_gate_bias" not in model_cfg:
            return
        final_gate_layer = self.state_gate[-1]
        if not isinstance(final_gate_layer, nn.Linear) or final_gate_layer.bias is None:
            raise RuntimeError("state_gate must end in a Linear layer with bias to use initial_update_gate_bias.")
        nn.init.constant_(final_gate_layer.bias, float(model_cfg["initial_update_gate_bias"]))

    def _drop_unused_vision_modules(self) -> None:
        """Replace any vision tower/model found under the backbone encoder with ``None``.

        Setting a registered submodule to ``None`` removes its parameters from this
        model's parameters and state dict. Paths that do not exist are skipped.
        """
        for module_path in (
            "encoder.vision_tower",
            "encoder.vision_model",
            "model.encoder.vision_tower",
            "model.encoder.vision_model",
        ):
            parent: Any = self.backbone
            path_parts = module_path.split(".")
            try:
                for part in path_parts[:-1]:
                    parent = getattr(parent, part)
            except AttributeError:
                continue
            child_name = path_parts[-1]
            child = getattr(parent, child_name, None)
            if isinstance(child, nn.Module):
                setattr(parent, child_name, None)

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def backbone_dtype(self) -> torch.dtype:
        """Dtype of the backbone input embedding weights (the dtype fed into the backbone)."""
        return self.input_embeddings.weight.dtype

    @property
    def recurrent_dtype(self) -> torch.dtype:
        """Dtype of the learned initial latent state ``z_init``."""
        return self.z_init.dtype

    def trainable_parameter_count(self) -> int:
        """Return the number of parameter elements with ``requires_grad`` set."""
        return sum(parameter.numel() for parameter in self.parameters() if parameter.requires_grad)

    def _decoder_blocks(self) -> Any:
        """Return the decoder's layer list (``block``, ``layers`` or ``h``), or ``None`` if not found."""
        for attribute_name in ("block", "layers", "h"):
            blocks = getattr(self.decoder, attribute_name, None)
            if blocks is not None:
                return blocks
        return None

    def decoder_layer_count(self) -> int:
        """Return the number of decoder blocks, or 0 if the layer list cannot be located."""
        blocks = self._decoder_blocks()
        return len(blocks) if blocks is not None else 0

    def set_gate_trainable(self, trainable: bool) -> None:
        """Enable or disable gradients for the gate-scoring path (query, pool, head, context norm)."""
        self.gate_query.requires_grad_(trainable)
        self.gate_pool.requires_grad_(trainable)
        self.gate_head.requires_grad_(trainable)
        self.gate_context_norm.requires_grad_(trainable)

    def set_decoder_trainable_fraction(self, trainable_fraction: float) -> int:
        """Unfreeze only the last ``ceil(fraction * num_blocks)`` decoder blocks.

        The decoder's final norm, its token embeddings, and any LM head / shared
        embedding found are made trainable iff at least one block is trainable.

        Args:
            trainable_fraction: Fraction of decoder blocks to train; clamped to ``[0, 1]``.

        Returns:
            The number of decoder blocks left trainable (0 if the blocks cannot be found).
        """
        blocks = self._decoder_blocks()
        if blocks is None:
            return 0
        clamped_fraction = min(max(trainable_fraction, 0.0), 1.0)
        total_blocks = len(blocks)
        trainable_blocks = min(total_blocks, math.ceil(total_blocks * clamped_fraction)) if clamped_fraction > 0 else 0
        first_trainable_index = total_blocks - trainable_blocks
        for block_index, block in enumerate(blocks):
            block.requires_grad_(block_index >= first_trainable_index)
        for attribute_name in ("final_layer_norm", "norm", "layer_norm"):
            maybe_module = getattr(self.decoder, attribute_name, None)
            if isinstance(maybe_module, nn.Module):
                maybe_module.requires_grad_(trainable_blocks > 0)
        # NOTE(review): if the backbone ties these embeddings to the encoder's input
        # embeddings, this toggle also affects the encoder side — confirm intended.
        decoder_embeddings = getattr(self.decoder, "embed_tokens", None)
        if isinstance(decoder_embeddings, nn.Module):
            decoder_embeddings.requires_grad_(trainable_blocks > 0)
        # NOTE(review): these paths are resolved on this wrapper, which registers no
        # "lm_head", "model" or "shared" attribute itself, so only "backbone.lm_head"
        # can match here — confirm whether backbone-relative paths were intended.
        for module_path in (
            "lm_head",
            "model.lm_head",
            "backbone.lm_head",
            "shared",
            "model.shared",
        ):
            parent: Any = self
            path_parts = module_path.split(".")
            try:
                for part in path_parts[:-1]:
                    parent = getattr(parent, part)
            except AttributeError:
                continue
            child_name = path_parts[-1]
            child = getattr(parent, child_name, None)
            if isinstance(child, nn.Module):
                child.requires_grad_(trainable_blocks > 0)
        return trainable_blocks

    def initial_state(self, batch_size: int, device: torch.device | None = None) -> torch.Tensor:
        """Return ``z_init`` broadcast to ``(batch_size, z_slots, hidden)`` and RMS-normalized."""
        target_device = device or self.device
        initial_state = self.z_init.unsqueeze(0).expand(batch_size, -1, -1).to(device=target_device)
        return self.recurrent_state_norm(initial_state)

    def initial_state_mask(self, batch_size: int, device: torch.device | None = None) -> torch.Tensor:
        """Return the ``(batch_size, z_slots)`` long mask of initially active latent slots.

        All slots start active, except in ``observation_hidden_compression`` mode with
        the ``observed`` mask mode, where every slot starts inactive until an
        observation fills it.
        """
        target_device = device or self.device
        if self.latent_attention_mask_mode == "full":
            return torch.ones(batch_size, self.z_slots, dtype=torch.long, device=target_device)
        if self.thought_loop_proposal_mode == "observation_hidden_compression":
            return torch.zeros(batch_size, self.z_slots, dtype=torch.long, device=target_device)
        return torch.ones(batch_size, self.z_slots, dtype=torch.long, device=target_device)

    def _build_temporal_tensor(
        self,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> torch.Tensor:
        """Stack ``log1p`` of the four duration tensors along a new last dimension.

        Durations are clamped at 0 before ``log1p``: a negative duration (e.g. from
        clock skew) would otherwise produce NaN/-inf and corrupt the latent state.

        Raises:
            RuntimeError: If explicit time features are disabled for this model.
        """
        if not self.use_explicit_time_features:
            raise RuntimeError("Temporal features were requested, but this model is configured to disable them.")
        durations = (delta_seconds, elapsed_seconds, since_last_user_seconds, since_last_agent_seconds)
        return torch.stack([torch.log1p(duration.clamp_min(0)) for duration in durations], dim=-1)

    def compress_hidden_states_to_slots(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        target_slots: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress each sample's valid hidden states into at most ``target_slots`` vectors.

        Per sample, positions where ``attention_mask`` is true are gathered. If there
        are no more of them than slots, they are copied into the leading slots;
        otherwise they are split into ``slots`` contiguous, near-equal chunks and each
        chunk is mean-pooled.

        Args:
            hidden_states: ``(batch, seq, hidden)`` encoder outputs.
            attention_mask: ``(batch, seq)`` mask of valid positions.
            target_slots: Number of output slots; defaults to ``z_slots``.

        Returns:
            ``(compressed_hidden, compressed_mask)`` of shapes ``(batch, slots, hidden)``
            and ``(batch, slots)``; unfilled slots are zero with mask 0.
        """
        slots = int(target_slots if target_slots is not None else self.z_slots)
        batch_size, _, hidden_size = hidden_states.shape
        compressed_hidden = hidden_states.new_zeros((batch_size, slots, hidden_size))
        compressed_mask = torch.zeros(batch_size, slots, dtype=torch.long, device=hidden_states.device)
        attention_mask = attention_mask.to(device=hidden_states.device, dtype=torch.bool)
        for batch_index in range(batch_size):
            valid_states = hidden_states[batch_index, attention_mask[batch_index]]
            valid_length = int(valid_states.shape[0])
            if valid_length <= 0:
                continue
            if valid_length <= slots:
                compressed_hidden[batch_index, :valid_length] = valid_states
                compressed_mask[batch_index, :valid_length] = 1
                continue
            for slot_index in range(slots):
                start = (slot_index * valid_length) // slots
                end = ((slot_index + 1) * valid_length) // slots
                if end <= start:
                    end = min(valid_length, start + 1)
                compressed_hidden[batch_index, slot_index] = valid_states[start:end].mean(dim=0)
                compressed_mask[batch_index, slot_index] = 1
        return compressed_hidden, compressed_mask

    def _rollout_step_impl(
        self,
        z_state: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
        previous_state_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance the latent state by one observation.

        Args:
            z_state: Current latent state, concatenated with token embeddings along
                dim 1 below, so expected as ``(batch, z_slots, hidden)``.
            observation_input_ids: ``(batch, seq)`` observation token ids.
            observation_attention_mask: ``(batch, seq)`` mask for the observation tokens.
            observation_role_ids: One role id per sample (``observation_role_count`` roles).
            delta_seconds, elapsed_seconds, since_last_user_seconds, since_last_agent_seconds:
                Per-sample durations; only used in ``latent_prefix`` mode with
                explicit time features enabled.
            previous_state_mask: Optional ``(batch, z_slots)`` mask of slots already
                populated; ``None`` treats the previous state as fully active.

        Returns:
            ``(next_state, gate_logits, next_state_mask)`` where ``gate_logits`` has one
            logit per sample and ``next_state_mask`` is the element-wise max of the
            previous mask (if any) and this step's proposal mask.
        """
        batch_size = z_state.size(0)
        recurrent_dtype = z_state.dtype
        backbone_dtype = self.backbone_dtype
        if self.thought_loop_proposal_mode == "latent_prefix":
            # Encoder input layout: [z slots | time tokens (optional) | role token | observation tokens].
            # The proposed state is read back from the z-slot positions of the encoder output.
            aux_tokens: list[torch.Tensor] = []
            aux_masks: list[torch.Tensor] = []
            if self.use_explicit_time_features:
                temporal_features = self._build_temporal_tensor(
                    delta_seconds=delta_seconds,
                    elapsed_seconds=elapsed_seconds,
                    since_last_user_seconds=since_last_user_seconds,
                    since_last_agent_seconds=since_last_agent_seconds,
                ).to(dtype=recurrent_dtype)
                assert self.temporal_projector is not None
                time_tokens = self.temporal_projector(temporal_features)
                time_tokens = time_tokens + self.segment_embeddings.weight[1].view(1, 1, -1).to(dtype=recurrent_dtype)
                aux_tokens.append(time_tokens)
                aux_masks.append(torch.ones(batch_size, self.num_time_tokens, device=z_state.device, dtype=torch.long))
            role_tokens = self.observation_role_embeddings(observation_role_ids).unsqueeze(1).to(dtype=recurrent_dtype)
            role_tokens = role_tokens + self.segment_embeddings.weight[1].view(1, 1, -1).to(dtype=recurrent_dtype)
            aux_tokens.append(role_tokens)
            aux_masks.append(torch.ones(batch_size, 1, device=z_state.device, dtype=torch.long))
            observation_embeds = self.input_embeddings(observation_input_ids).to(dtype=recurrent_dtype)
            observation_embeds = observation_embeds + self.segment_embeddings.weight[2].view(1, 1, -1).to(
                dtype=recurrent_dtype
            )
            observation_embeds = observation_embeds + self.observation_role_embeddings(observation_role_ids).unsqueeze(1).to(
                dtype=recurrent_dtype
            )
            z_tokens = z_state + self.segment_embeddings.weight[0].view(1, 1, -1).to(dtype=recurrent_dtype)
            encoder_inputs = torch.cat([z_tokens, *aux_tokens, observation_embeds], dim=1)
            z_mask = torch.ones(batch_size, self.z_slots, device=z_state.device, dtype=torch.long)
            encoder_attention_mask = torch.cat([z_mask, *aux_masks, observation_attention_mask.long()], dim=1)
            # The pretrained T5Gemma backbone was loaded in bf16. Passing fp32 inputs_embeds
            # into its bf16 Linear layers causes: expected mat1 and mat2 to have same dtype.
            encoder_outputs = self.encoder(
                inputs_embeds=encoder_inputs.to(dtype=backbone_dtype),
                attention_mask=encoder_attention_mask,
                return_dict=True,
            )
            latent_prefix_state = encoder_outputs.last_hidden_state[:, : self.z_slots, :].to(dtype=recurrent_dtype)
            proposed_state = latent_prefix_state
            proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=z_state.device)
        else:
            # observation_hidden_compression: encode the observation (optionally with the
            # current latent slots prepended as context) and compress the encoder hidden
            # states into z_slots vectors.
            use_state_context = (
                self.observation_encoder_use_state_context
                and previous_state_mask is not None
                and torch.any(previous_state_mask > 0)
            )
            if use_state_context:
                observation_embeds = self.input_embeddings(observation_input_ids).to(dtype=recurrent_dtype)
                encoder_inputs = torch.cat([z_state, observation_embeds], dim=1)
                encoder_attention_mask = torch.cat(
                    [
                        previous_state_mask.to(device=z_state.device, dtype=torch.long),
                        observation_attention_mask.long(),
                    ],
                    dim=1,
                )
                observation_encoder_outputs = self.encoder(
                    inputs_embeds=encoder_inputs.to(dtype=backbone_dtype),
                    attention_mask=encoder_attention_mask,
                    return_dict=True,
                )
                if self.latent_attention_mask_mode == "full":
                    # Read the proposal directly from the latent-slot positions.
                    proposed_state = observation_encoder_outputs.last_hidden_state[:, : self.z_slots, :].to(
                        dtype=recurrent_dtype
                    )
                    proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=z_state.device)
                    observation_outputs = None
                else:
                    # Keep only the observation positions; they are compressed below.
                    observation_outputs = observation_encoder_outputs.last_hidden_state[:, self.z_slots :, :].to(
                        dtype=recurrent_dtype
                    )
            else:
                observation_encoder_outputs = self.encoder(
                    input_ids=observation_input_ids,
                    attention_mask=observation_attention_mask.long(),
                    return_dict=True,
                )
                observation_outputs = observation_encoder_outputs.last_hidden_state.to(dtype=recurrent_dtype)
            if observation_outputs is not None:
                proposed_state, proposal_mask = self.compress_hidden_states_to_slots(
                    observation_outputs,
                    observation_attention_mask,
                    target_slots=self.z_slots,
                )
                proposed_state = proposed_state.to(dtype=recurrent_dtype)
                if self.latent_attention_mask_mode == "full":
                    # Slots the observation could not fill keep their previous value,
                    # and every slot is then treated as active.
                    proposed_state = proposed_state.clone()
                    inactive_slot_mask = ~proposal_mask.bool()
                    proposed_state[inactive_slot_mask] = z_state[inactive_slot_mask]
                    proposal_mask = torch.ones(batch_size, self.z_slots, dtype=torch.long, device=z_state.device)
        # Samples whose proposal has no active slot (e.g. an empty observation) keep
        # their previous state, and their mask falls back to the previous mask.
        no_observation_mask = proposal_mask.sum(dim=1) <= 0
        if torch.any(no_observation_mask):
            proposed_state = proposed_state.clone()
            proposal_mask = proposal_mask.clone()
            proposed_state[no_observation_mask] = z_state[no_observation_mask]
            if previous_state_mask is None:
                proposal_mask[no_observation_mask] = 1
            else:
                proposal_mask[no_observation_mask] = previous_state_mask[no_observation_mask].to(
                    device=z_state.device,
                    dtype=torch.long,
                )
        if not self.preserve_observation_encoder_manifold:
            proposed_state = self.proposed_state_norm(proposed_state)
        # Gated interpolation between the previous and proposed state; inactive
        # proposal slots leave the previous state untouched.
        gate_inputs = self.state_gate_input_norm(torch.cat([z_state, proposed_state], dim=-1))
        update_gate = torch.sigmoid(self.state_gate(gate_inputs))
        raw_next_state = update_gate * proposed_state + (1.0 - update_gate) * z_state
        active_slot_mask = proposal_mask.bool().unsqueeze(-1)
        next_state = torch.where(active_slot_mask, raw_next_state, z_state)
        if not self.preserve_observation_encoder_manifold:
            next_state = self.recurrent_state_norm(next_state)
        if previous_state_mask is not None:
            next_state_mask = torch.maximum(
                previous_state_mask.to(device=z_state.device, dtype=torch.long),
                proposal_mask,
            )
        else:
            next_state_mask = proposal_mask
        # NOTE(review): the gate pool attends over all z_slots, including slots that
        # next_state_mask marks inactive (no key_padding_mask) — confirm intended.
        pooled_query = self.gate_query.to(dtype=next_state.dtype).expand(batch_size, -1, -1)
        gate_context = self.gate_context_norm(next_state)
        pooled_state, _ = self.gate_pool(pooled_query, gate_context, gate_context, need_weights=False)
        gate_logits = self.gate_head(pooled_state.squeeze(1)).squeeze(-1)
        return next_state, gate_logits, next_state_mask

    def rollout_step(
        self,
        z_state: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one rollout step without a state mask; return ``(next_state, gate_logits)``."""
        next_state, gate_logits, _ = self._rollout_step_impl(
            z_state=z_state,
            observation_input_ids=observation_input_ids,
            observation_attention_mask=observation_attention_mask,
            observation_role_ids=observation_role_ids,
            delta_seconds=delta_seconds,
            elapsed_seconds=elapsed_seconds,
            since_last_user_seconds=since_last_user_seconds,
            since_last_agent_seconds=since_last_agent_seconds,
            previous_state_mask=None,
        )
        return next_state, gate_logits

    def rollout_step_with_mask(
        self,
        z_state: torch.Tensor,
        state_mask: torch.Tensor,
        observation_input_ids: torch.Tensor,
        observation_attention_mask: torch.Tensor,
        observation_role_ids: torch.Tensor,
        delta_seconds: torch.Tensor,
        elapsed_seconds: torch.Tensor,
        since_last_user_seconds: torch.Tensor,
        since_last_agent_seconds: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one rollout step tracking slot activity; return ``(next_state, gate_logits, next_state_mask)``."""
        return self._rollout_step_impl(
            z_state=z_state,
            observation_input_ids=observation_input_ids,
            observation_attention_mask=observation_attention_mask,
            observation_role_ids=observation_role_ids,
            delta_seconds=delta_seconds,
            elapsed_seconds=elapsed_seconds,
            since_last_user_seconds=since_last_user_seconds,
            since_last_agent_seconds=since_last_agent_seconds,
            previous_state_mask=state_mask,
        )

    def _resolve_generation_eos_token_id(self) -> int | list[int] | None:
        """Return the EOS id(s) from the generation config, falling back to the model config."""
        eos_token_id = getattr(self.backbone.generation_config, "eos_token_id", None)
        if eos_token_id is None:
            eos_token_id = getattr(self.backbone.config, "eos_token_id", None)
        return eos_token_id

    def _resolve_fill_token_id(self) -> int:
        """Pick a token id for filling ignored/missing positions.

        The first configured EOS id is preferred, then the tokenizer's EOS, then its
        PAD id, then 0.
        """
        eos_token_id = self._resolve_generation_eos_token_id()
        if isinstance(eos_token_id, list) and eos_token_id:
            return int(eos_token_id[0])
        if isinstance(eos_token_id, int):
            return eos_token_id
        if self.tokenizer.eos_token_id is not None:
            return int(self.tokenizer.eos_token_id)
        if self.tokenizer.pad_token_id is not None:
            return int(self.tokenizer.pad_token_id)
        return 0

    def _teacher_forced_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor:
        """Build teacher-forced decoder input ids by shifting ``labels`` right.

        ``-100`` (ignored) positions are replaced with the fill token first so the
        shifted ids are valid token ids.

        Raises:
            ValueError: If ``labels`` is not rank 2 or has an empty batch.
        """
        if labels.ndim != 2 or labels.size(0) <= 0:
            raise ValueError("labels must be a rank-2 tensor with non-zero batch size.")
        safe_labels = labels.clone()
        safe_labels[safe_labels == -100] = self._resolve_fill_token_id()
        if hasattr(self.backbone, "prepare_decoder_input_ids_from_labels"):
            return self.backbone.prepare_decoder_input_ids_from_labels(labels=safe_labels)
        return self.backbone._shift_right(safe_labels)

    def _build_self_generated_decoder_inputs(
        self,
        labels: torch.Tensor,
        generated: torch.Tensor,
        *,
        self_generated_prefix_tokens: int,
    ) -> torch.Tensor:
        """Replace the first decoder inputs after the start token with model-generated tokens.

        Positions ``1 .. 1 + k`` of the teacher-forced inputs are overwritten with the
        first ``k`` generated tokens, where ``k`` is ``self_generated_prefix_tokens``
        clamped to ``[0, target_length - 1]``. A leading generated token equal to the
        decoder start token is dropped; short generations are padded with the fill token.

        Raises:
            ValueError: If ``generated`` is not rank 2 or its batch size differs from ``labels``.
        """
        batch_size, target_length = labels.shape
        teacher_inputs = self._teacher_forced_decoder_inputs(labels).to(device=labels.device)
        if target_length <= 0:
            return teacher_inputs
        prefix_token_count = min(max(self_generated_prefix_tokens, 0), max(target_length - 1, 0))
        if prefix_token_count <= 0:
            return teacher_inputs
        generated_tokens = generated.to(device=labels.device, dtype=torch.long)
        if generated_tokens.ndim != 2:
            raise ValueError("generated tokens must be rank-2.")
        if generated_tokens.size(0) != batch_size:
            raise ValueError("generated batch size must match labels batch size.")
        if generated_tokens.size(1) > 0 and torch.equal(generated_tokens[:, :1], teacher_inputs[:, :1]):
            generated_tokens = generated_tokens[:, 1:]
        fill_token_id = self._resolve_fill_token_id()
        if generated_tokens.size(1) < prefix_token_count:
            padding = labels.new_full(
                (batch_size, prefix_token_count - generated_tokens.size(1)),
                fill_value=fill_token_id,
            )
            generated_tokens = torch.cat([generated_tokens, padding], dim=1)
        else:
            generated_tokens = generated_tokens[:, :prefix_token_count]
        hybrid_inputs = teacher_inputs.clone()
        hybrid_inputs[:, 1 : 1 + prefix_token_count] = generated_tokens
        return hybrid_inputs

    def decoder_loss(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Teacher-forced decoder loss with ``z_state`` used as the encoder output.

        Returns:
            ``(loss_sum, token_count)``: the backbone's loss multiplied by the number
            of non-``-100`` labels, and that count. Both are zero when ``z_state`` is
            empty or no label is a target.
        """
        if z_state.numel() == 0:
            zero = torch.zeros((), device=self.device)
            return zero, zero
        encoder_mask = (
            encoder_attention_mask.to(device=z_state.device, dtype=torch.long)
            if encoder_attention_mask is not None
            else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device)
        )
        token_count = (labels != -100).sum()
        if token_count.item() == 0:
            zero = torch.zeros((), device=z_state.device)
            return zero, zero
        outputs = self.backbone(
            encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)),
            attention_mask=encoder_mask,
            labels=labels,
            return_dict=True,
        )
        loss_sum = outputs.loss * token_count
        return loss_sum, token_count

    @staticmethod
    def _generated_sequences(generated: Any) -> torch.Tensor:
        """Return the token-id tensor from a ``generate()`` result.

        ``generate()`` returns a plain tensor when ``return_dict_in_generate`` is false
        and an output object with a ``sequences`` field when it is true; callers can
        override that flag through ``generation_kwargs``.

        Raises:
            TypeError: If ``generated`` is neither a tensor nor has a tensor ``sequences``.
        """
        if isinstance(generated, torch.Tensor):
            return generated
        sequences = getattr(generated, "sequences", None)
        if not isinstance(sequences, torch.Tensor):
            raise TypeError(f"Unsupported generate() output type: {type(generated).__name__}")
        return sequences

    def decoder_self_generated_loss(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        generation_kwargs: dict[str, Any] | None = None,
        self_generated_prefix_tokens: int | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decoder loss where the first decoder inputs come from the model's own greedy generation.

        The backbone first generates (in eval mode, without gradients) up to
        ``self_generated_prefix_tokens`` tokens (default: the label length). Those
        tokens replace the leading teacher-forced decoder inputs, and a summed
        float32 cross-entropy against ``labels`` is computed.

        Returns:
            ``(loss_sum, token_count)``; both zero when ``z_state`` is empty or no label
            is a target. Falls back to ``decoder_loss`` when the prefix length is 0.
        """
        if z_state.numel() == 0:
            zero = torch.zeros((), device=self.device)
            return zero, zero
        encoder_mask = (
            encoder_attention_mask.to(device=z_state.device, dtype=torch.long)
            if encoder_attention_mask is not None
            else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device)
        )
        token_count = (labels != -100).sum()
        if token_count.item() == 0:
            zero = torch.zeros((), device=z_state.device)
            return zero, zero
        prefix_token_count = max(0, int(self_generated_prefix_tokens if self_generated_prefix_tokens is not None else labels.shape[1]))
        if prefix_token_count <= 0:
            return self.decoder_loss(z_state, labels, encoder_attention_mask=encoder_attention_mask)
        effective_generation_kwargs = {
            "max_new_tokens": prefix_token_count,
            "do_sample": False,
            "num_beams": 1,
            "return_dict_in_generate": False,
        }
        if generation_kwargs:
            effective_generation_kwargs.update(generation_kwargs)
        # Generate in eval mode (e.g. no dropout) and restore the training flag afterwards.
        was_training = self.backbone.training
        if was_training:
            self.backbone.eval()
        try:
            with torch.no_grad():
                generated = self.backbone.generate(
                    encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)),
                    attention_mask=encoder_mask,
                    **effective_generation_kwargs,
                )
        finally:
            if was_training:
                self.backbone.train()
        # generation_kwargs may have turned on return_dict_in_generate.
        generated = self._generated_sequences(generated)
        decoder_input_ids = self._build_self_generated_decoder_inputs(
            labels,
            generated,
            self_generated_prefix_tokens=prefix_token_count,
        )
        outputs = self.backbone(
            encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)),
            attention_mask=encoder_mask,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        )
        logits = outputs.logits.float()
        loss_sum = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        return loss_sum, token_count

    @torch.no_grad()
    def first_token_exact_match_stats(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Count samples whose argmax first decoded token equals ``labels[:, 0]``.

        Returns:
            ``(correct, valid_count)`` over samples whose first label is not ``-100``;
            both zero when ``z_state`` is empty or no sample is valid.
        """
        if z_state.numel() == 0:
            zero = torch.zeros((), device=self.device)
            return zero, zero
        first_logits, valid_mask = self.first_token_logits(
            z_state,
            labels,
            encoder_attention_mask=encoder_attention_mask,
        )
        valid_count = valid_mask.sum()
        if valid_count.item() == 0:
            zero = torch.zeros((), device=self.device)
            return zero, zero
        first_targets = labels[:, 0]
        predicted_first_tokens = first_logits.argmax(dim=-1)
        correct = ((predicted_first_tokens == first_targets) & valid_mask).sum()
        return correct, valid_count

    def first_token_logits(
        self,
        z_state: torch.Tensor,
        labels: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return decoder logits for the first output position and a mask of samples with a first target.

        Only the decoder start token is fed, so the logits predict ``labels[:, 0]``.
        For an empty ``z_state`` an empty ``(0, 0)`` logits tensor and ``(0,)`` mask are returned.
        """
        if z_state.numel() == 0:
            empty_logits = torch.zeros((0, 0), device=self.device)
            empty_mask = torch.zeros((0,), dtype=torch.bool, device=self.device)
            return empty_logits, empty_mask
        first_targets = labels[:, 0]
        valid_mask = first_targets != -100
        encoder_mask = (
            encoder_attention_mask.to(device=z_state.device, dtype=torch.long)
            if encoder_attention_mask is not None
            else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device)
        )
        decoder_input_ids = self._teacher_forced_decoder_inputs(labels)[:, :1].to(device=z_state.device)
        outputs = self.backbone(
            encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)),
            attention_mask=encoder_mask,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        )
        return outputs.logits[:, 0, :], valid_mask

    @torch.no_grad()
    def generate_from_state(
        self,
        z_state: torch.Tensor,
        *,
        encoder_attention_mask: torch.Tensor | None = None,
        **generation_kwargs: Any,
    ) -> torch.Tensor:
        """Run ``backbone.generate`` with ``z_state`` as the encoder output (all slots attended by default)."""
        encoder_mask = (
            encoder_attention_mask.to(device=z_state.device, dtype=torch.long)
            if encoder_attention_mask is not None
            else torch.ones(z_state.shape[:2], dtype=torch.long, device=z_state.device)
        )
        return self.backbone.generate(
            encoder_outputs=BaseModelOutput(last_hidden_state=z_state.to(dtype=self.backbone_dtype)),
            attention_mask=encoder_mask,
            **generation_kwargs,
        )

    def save_pretrained(self, output_dir: str | Path, tokenizer: Any | None = None) -> None:
        """Write ``sft_config.json``, ``model.safetensors``, ``initial_latent_z.safetensors`` and tokenizer files.

        Args:
            output_dir: Destination directory (created if missing).
            tokenizer: Tokenizer to save; defaults to ``self.tokenizer``.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        (output_path / "sft_config.json").write_text(
            json.dumps(self.config, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        # The T5Gemma backbone exposes tied/shared weights, which raw save_file refuses.
        save_model(self, str(output_path / "model.safetensors"))
        save_file({"z_init": self.z_init.detach().cpu()}, str(output_path / "initial_latent_z.safetensors"))
        active_tokenizer = tokenizer or self.tokenizer
        active_tokenizer.save_pretrained(output_path)

    @classmethod
    def from_pretrained(
        cls,
        model_path_or_repo_id: str,
        device: str | torch.device = "cpu",
        map_location: str | torch.device = "cpu",
    ) -> "ThoughtLoopT5Gemma":
        """Load a model saved by ``save_pretrained`` from a local directory or a Hub repo id.

        The model is first built from ``sft_config.json`` (which loads the base
        backbone named in the config), then ``model.safetensors`` is loaded onto
        ``map_location`` and the model is moved to ``device``.
        """
        local_path = Path(model_path_or_repo_id)
        if not local_path.exists():
            local_path = Path(snapshot_download(repo_id=model_path_or_repo_id))
        config = json.loads((local_path / "sft_config.json").read_text(encoding="utf-8"))
        model = cls(config)
        load_model(model, str(local_path / "model.safetensors"), device=str(map_location))
        model.to(device)
        return model