from __future__ import annotations

import typing as T
from dataclasses import dataclass, fields

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class TTTConfig:
    """Hyper-parameters for masked-token test-time training (TTT) with LoRA.

    Validation is assert-based (see ``verify``), so it is skipped when Python
    runs with ``-O``.
    """

    # Optimisation schedule.
    lr: float = 4e-4
    steps: int = 30  # number of optimizer steps
    ags: int = 16  # gradient accumulation steps per optimizer step
    batch_size: int = 2  # rows sampled (with replacement) per micro-batch
    # Masking: fraction of non-special tokens selected per row, and the
    # probabilities of leaving a selected token unchanged / replacing it with
    # a random non-special token (the remainder becomes the mask token).
    mask_ratio: float = 0.15
    crop_size: int = 1024  # maximum sequence length fed to the model
    bert_leave_prob: float = 0.1
    bert_replace_prob: float = 0.1
    # Optimizer selection ("sgd" or "adamw"); momentum is only used by SGD.
    optimizer: str = "sgd"
    momentum: float = 0.0
    weight_decay: float = 0.0
    seed: int | None = 0  # None leaves the sampling generator unseeded
    # LoRA layout. These cannot change once adapters have been injected.
    lora_rank: int = 8
    lora_alpha: float = 32.0  # used directly as the adapter output scale
    lora_target_replace_module: str | None = None
    lora_target_modules: tuple[str, ...] | None = None
    # Run-time behaviour of ``ttt``.
    initial_state_reset: bool = True
    automatic_best_state_reset: bool = False
    eval_each_step: bool = False
    gradient_clip: bool = False
    gradient_clip_max_norm: float = 1.0

    @classmethod
    def from_kwargs(cls, **kwargs: T.Any) -> "TTTConfig":
        """Build a config from keyword arguments, rejecting unknown names."""
        valid_names = {field.name for field in fields(cls)}
        unknown_names = set(kwargs) - valid_names
        assert len(unknown_names) == 0, f"Unknown TTTConfig fields: {sorted(unknown_names)}"
        return cls(**kwargs)

    def merged(self, overrides: T.Mapping[str, T.Any] | "TTTConfig" | None) -> "TTTConfig":
        """Return this config with ``overrides`` applied.

        ``None`` returns ``self``; a ``TTTConfig`` replaces this config
        wholesale; a mapping overrides individual fields on a new instance.
        The result is not verified here.
        """
        if overrides is None:
            return self
        if isinstance(overrides, TTTConfig):
            return overrides
        values = {field.name: self.__dict__[field.name] for field in fields(self)}
        for name, value in overrides.items():
            assert name in values, f"Unknown TTTConfig field: {name}"
            values[name] = value
        return TTTConfig(**values)

    def verify(self) -> None:
        """Assert that every field holds a usable value.

        Raises:
            AssertionError: if any field is out of range.
        """
        assert self.lr > 0.0, "TTT learning rate must be positive."
        assert self.steps >= 1, "TTT steps must be >= 1."
        assert self.ags >= 1, "TTT gradient accumulation steps must be >= 1."
        assert self.batch_size >= 1, "TTT batch_size must be >= 1."
        assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]."
        assert self.crop_size >= 1, "TTT crop_size must be >= 1."
        assert self.lora_rank >= 1, "TTT v1 is LoRA-only, so lora_rank must be >= 1."
        assert self.lora_alpha > 0.0, "TTT lora_alpha must be positive."
        assert self.optimizer in {"adamw", "sgd"}, "TTT optimizer must be 'adamw' or 'sgd'."
        assert 0.0 <= self.bert_leave_prob <= 1.0, "bert_leave_prob must be in [0, 1]."
        assert 0.0 <= self.bert_replace_prob <= 1.0, "bert_replace_prob must be in [0, 1]."
        assert self.bert_leave_prob + self.bert_replace_prob <= 1.0, (
            "bert_leave_prob + bert_replace_prob must be <= 1."
        )
        if self.gradient_clip:
            assert self.gradient_clip_max_norm > 0.0, "gradient_clip_max_norm must be positive."


def _assert_lora_layout_unchanged(current: TTTConfig, proposed: TTTConfig) -> None:
    """Assert that ``proposed`` keeps the LoRA layout of ``current``.

    Adapters are injected once and never rebuilt, so the fields that shaped
    them must stay fixed afterwards.
    """
    assert proposed.lora_rank == current.lora_rank, (
        "Changing lora_rank after TTT initialization is not supported."
    )
    assert proposed.lora_alpha == current.lora_alpha, (
        "Changing lora_alpha after TTT initialization is not supported."
    )
    assert (
        proposed.lora_target_replace_module == current.lora_target_replace_module
    ), "Changing LoRA target class after TTT initialization is not supported."
    assert proposed.lora_target_modules == current.lora_target_modules, (
        "Changing LoRA target modules after TTT initialization is not supported."
    )


class LoraInjectedLinear(nn.Module):
    """Wrap a frozen linear-like module and add a trainable low-rank delta.

    The output is ``linear(x) + lora_up(lora_down(x)) * alpha``. The adapter
    weights are float32 regardless of the wrapped module's dtype, and
    ``lora_up`` starts at zero so the wrapper is initially a no-op.
    """

    def __init__(self, linear: nn.Module, rank: int, alpha: float) -> None:
        super().__init__()
        weight = linear._parameters["weight"]
        assert weight.ndim == 2, "LoRA can only wrap 2D linear weights."
        self.linear = linear
        self.linear.requires_grad_(False)
        self.rank = rank
        # NOTE: alpha is applied as-is (not divided by rank).
        self.scale = alpha
        in_features = weight.shape[1]
        out_features = weight.shape[0]
        self.lora_down = nn.Linear(in_features, rank, bias=False, dtype=torch.float32)
        self.lora_up = nn.Linear(rank, out_features, bias=False, dtype=torch.float32)
        self.lora_down.to(device=weight.device)
        self.lora_up.to(device=weight.device)
        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
        nn.init.zeros_(self.lora_up.weight)

    @property
    def weight(self) -> torch.Tensor:
        """Weight of the wrapped module, so callers can treat this like a linear."""
        return self.linear._parameters["weight"]

    @property
    def bias(self) -> torch.Tensor | None:
        """Bias of the wrapped module, or ``None`` when it has none.

        Linear-like modules that never registered a ``bias`` entry are
        reported as bias-free instead of raising ``KeyError``.
        """
        return self.linear._parameters.get("bias")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen base output plus the scaled low-rank update."""
        base = self.linear(x)
        delta = self.lora_up(self.lora_down(x.to(dtype=torch.float32))) * self.scale
        return base + delta.to(dtype=base.dtype)


class FastPLMTestTimeTrainingMixin:
    """Mixin adding masked-token test-time training via LoRA adapters.

    The host class is expected to be an ``nn.Module`` exposing ``tokenizer``
    (with ``mask_token_id``, ``pad_token_id``, ``all_special_ids``),
    ``config.vocab_size`` and a forward pass returning an object with
    ``.logits``; these are the attributes the default hooks below read.
    Hooks prefixed ``_ttt_`` can be overridden by subclasses.
    """

    def init_ttt(self, ttt_config: TTTConfig | T.Mapping[str, T.Any] | None = None) -> None:
        """Set the TTT configuration (defaults merged with ``ttt_config``).

        The config is verified before it is stored, so a failed call leaves
        the previous configuration in place. When adapters have already been
        injected they and the pristine snapshot are kept (injection cannot be
        undone), and the LoRA layout fields must therefore be unchanged.

        Raises:
            AssertionError: if the config is invalid, or if it changes the
                LoRA layout after adapters were injected.
        """
        next_cfg = TTTConfig().merged(ttt_config)
        next_cfg.verify()
        already_injected = (
            self.__dict__.get("_ttt_initialized", False) and "_ttt_cfg" in self.__dict__
        )
        if already_injected:
            # Re-running injection would wrap the adapters themselves and the
            # pristine snapshot would be lost, so only swap the config.
            _assert_lora_layout_unchanged(self._ttt_cfg, next_cfg)
            self._ttt_cfg = next_cfg
            return
        self._ttt_cfg = next_cfg
        self._ttt_initialized = False
        self._ttt_initial_state: list[dict[str, torch.Tensor]] | None = None

    @property
    def ttt_config(self) -> TTTConfig:
        """Active TTT config, lazily initialised with defaults."""
        if "_ttt_cfg" not in self.__dict__:
            self.init_ttt()
        return self._ttt_cfg

    def _ttt_get_trainable_modules(self) -> list[nn.Module]:
        """Roots searched for LoRA targets (default: the whole model)."""
        return [self]

    def _ttt_get_frozen_modules(self) -> list[nn.Module]:
        """Hook for modules to keep frozen; unused by the code in this file."""
        return []

    def _ttt_tokenize(
        self,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs: T.Any,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Return token ids: ``input_ids`` if given, else tokenized ``seq``."""
        del kwargs
        if input_ids is not None:
            return input_ids
        assert seq is not None, "Pass either seq or input_ids for TTT."
        tokenized = self.tokenizer(seq, return_tensors="pt", padding=True)
        return tokenized["input_ids"]

    def _ttt_mask_token(self) -> int:
        """Id written at positions chosen for masking."""
        return int(self.tokenizer.mask_token_id)

    def _ttt_padding_token(self) -> int:
        """Id treated as padding when building attention masks."""
        return int(self.tokenizer.pad_token_id)

    def _ttt_replacement_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return all non-special vocabulary ids on ``input_ids``' device/dtype."""
        tokenizer = self.tokenizer
        special_ids = set(tokenizer.all_special_ids)
        vocab_size = int(self.config.vocab_size)
        ids = [idx for idx in range(vocab_size) if idx not in special_ids]
        assert len(ids) > 0, "TTT replacement token set is empty."
        return torch.tensor(ids, device=input_ids.device, dtype=input_ids.dtype)

    def _ttt_predict_logits(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        **kwargs: T.Any,
    ) -> torch.Tensor:
        """Run the model on ``batch`` and return its logits.

        A dict batch is splatted into the forward call; a tensor batch is
        passed as ``input_ids`` with an attention mask of non-padding tokens.
        """
        del kwargs
        if isinstance(batch, dict):
            output = self(**batch)
            return output.logits
        attention_mask = batch.ne(self._ttt_padding_token())
        output = self(input_ids=batch, attention_mask=attention_mask)
        return output.logits

    def _ttt_eval_step(
        self,
        step: int,
        loss: float,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs: T.Any,
    ) -> tuple[dict[str, T.Any], float | None]:
        """Per-step evaluation hook; returns ``(metrics, score)``.

        The default reports nothing. ``ttt`` treats a larger score as better
        when tracking the best state.
        """
        del step, loss, seq, input_ids, kwargs
        return {}, None

    def _ttt_is_lora_target(
        self,
        name: str,
        full_name: str,
        module: nn.Module,
        active: bool,
        target_modules: tuple[str, ...] | None,
    ) -> bool:
        """Decide whether ``module`` should be wrapped with a LoRA adapter.

        A target must be inside an active subtree, not already an adapter,
        match ``target_modules`` (by short or dotted name) when that filter is
        set, and be either an ``nn.Linear`` or a module with a 2D ``weight``
        whose class name contains ``"Linear"``.
        """
        if not active:
            return False
        if isinstance(module, LoraInjectedLinear):
            return False
        if (
            target_modules is not None
            and name not in target_modules
            and full_name not in target_modules
        ):
            return False
        if isinstance(module, nn.Linear):
            return True
        if "weight" not in module._parameters:
            return False
        weight = module._parameters["weight"]
        if weight is None or weight.ndim != 2:
            return False
        return "Linear" in module.__class__.__name__

    def _ttt_inject_lora(self) -> int:
        """Replace every target module with a ``LoraInjectedLinear`` wrapper.

        When ``lora_target_replace_module`` is set, only modules inside a
        subtree whose root has that class name are eligible.

        Returns:
            Number of modules newly wrapped.

        Raises:
            AssertionError: if nothing was wrapped.
        """
        cfg = self.ttt_config
        cfg.verify()
        target_class = cfg.lora_target_replace_module
        target_modules = cfg.lora_target_modules
        wrapped = 0

        def inject(module: nn.Module, prefix: str, active: bool) -> None:
            nonlocal wrapped
            # Snapshot the children: setattr below mutates the registry.
            for name, child in list(module.named_children()):
                if isinstance(child, LoraInjectedLinear):
                    # Never descend into an adapter: its inner ``linear``,
                    # ``lora_down`` and ``lora_up`` are themselves nn.Linear
                    # and would otherwise be wrapped again.
                    continue
                full_name = f"{prefix}.{name}" if prefix else name
                child_active = active
                if target_class is not None:
                    child_active = active or child.__class__.__name__ == target_class
                if self._ttt_is_lora_target(name, full_name, child, child_active, target_modules):
                    setattr(
                        module,
                        name,
                        LoraInjectedLinear(child, rank=cfg.lora_rank, alpha=cfg.lora_alpha),
                    )
                    wrapped += 1
                    continue
                inject(child, full_name, child_active)

        for trainable_module in self._ttt_get_trainable_modules():
            inject(trainable_module, "", target_class is None)
        assert wrapped > 0, "TTT LoRA injection did not find any target modules."
        return wrapped

    def _ttt_lora_modules(self) -> list[LoraInjectedLinear]:
        """All adapters in the model, in ``self.modules()`` order."""
        return [module for module in self.modules() if isinstance(module, LoraInjectedLinear)]

    def _ttt_lora_parameters(self) -> list[nn.Parameter]:
        """Trainable adapter parameters (down then up, per adapter)."""
        params: list[nn.Parameter] = []
        for module in self._ttt_lora_modules():
            params.extend(module.lora_down.parameters())
            params.extend(module.lora_up.parameters())
        assert len(params) > 0, "TTT has no LoRA parameters."
        return params

    def _ttt_snapshot_lora_state(self) -> list[dict[str, torch.Tensor]]:
        """Return detached copies of every adapter's weights."""
        snapshot = []
        for module in self._ttt_lora_modules():
            snapshot.append(
                {
                    "lora_down.weight": module.lora_down.weight.detach().clone(),
                    "lora_up.weight": module.lora_up.weight.detach().clone(),
                }
            )
        assert len(snapshot) > 0, "TTT has no LoRA state to snapshot."
        return snapshot

    def _ttt_restore_lora_state(self, state: list[dict[str, torch.Tensor]]) -> None:
        """Copy a snapshot back into the adapters, matched by position."""
        modules = self._ttt_lora_modules()
        assert len(modules) == len(state), "TTT LoRA state/module count mismatch."
        with torch.no_grad():
            for module, module_state in zip(modules, state):
                module.lora_down.weight.copy_(module_state["lora_down.weight"])
                module.lora_up.weight.copy_(module_state["lora_up.weight"])

    def _ttt_ensure_initialized(self) -> None:
        """Inject adapters and record their pristine state, once."""
        if "_ttt_cfg" not in self.__dict__:
            self.init_ttt()
        if self._ttt_initialized:
            return
        self._ttt_inject_lora()
        self._ttt_initial_state = self._ttt_snapshot_lora_state()
        self._ttt_initialized = True

    def ttt_reset(self) -> None:
        """Restore the adapters to the state captured at injection time."""
        self._ttt_ensure_initialized()
        assert self._ttt_initial_state is not None, "TTT initial state is not available."
        self._ttt_restore_lora_state(self._ttt_initial_state)

    def _ttt_make_optimizer(self) -> torch.optim.Optimizer:
        """Build the configured optimizer over the adapter parameters."""
        cfg = self.ttt_config
        params = self._ttt_lora_parameters()
        if cfg.optimizer == "sgd":
            return torch.optim.SGD(
                params,
                lr=cfg.lr,
                momentum=cfg.momentum,
                weight_decay=cfg.weight_decay,
            )
        return torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

    def _ttt_to_device(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        device: torch.device,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Move a tensor batch, or every tensor of a dict batch, to ``device``."""
        if isinstance(batch, dict):
            return {name: tensor.to(device) for name, tensor in batch.items()}
        return batch.to(device)

    def _ttt_input_ids_from_batch(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Extract the token ids from either batch representation."""
        if isinstance(batch, dict):
            return batch["input_ids"]
        return batch

    def _ttt_set_input_ids(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        input_ids: torch.Tensor,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Return ``batch`` with its token ids replaced (dicts are shallow-copied)."""
        if isinstance(batch, dict):
            updated = dict(batch)
            updated["input_ids"] = input_ids
            return updated
        return input_ids

    def _ttt_non_special_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Boolean mask of positions holding neither padding nor a special token."""
        pad_token = self._ttt_padding_token()
        mask = input_ids.ne(pad_token)
        special_ids = set(self.tokenizer.all_special_ids)
        for special_id in special_ids:
            mask = mask & input_ids.ne(int(special_id))
        return mask

    def _ttt_sample_crop(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        generator: torch.Generator,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Crop the sequence axis to ``crop_size`` at a random offset.

        Batches no longer than ``crop_size`` are returned untouched (and no
        random number is drawn). In a dict batch, every tensor whose second
        dimension matches the sequence length is cropped with the same window.
        """
        input_ids = self._ttt_input_ids_from_batch(batch)
        cfg = self.ttt_config
        if input_ids.shape[1] <= cfg.crop_size:
            return batch
        high = input_ids.shape[1] - cfg.crop_size + 1
        start = int(
            torch.randint(
                high,
                (1,),
                generator=generator,
                device=input_ids.device,
            ).item()
        )
        end = start + cfg.crop_size
        if isinstance(batch, dict):
            cropped = {}
            for name, tensor in batch.items():
                if tensor.ndim >= 2 and tensor.shape[1] == input_ids.shape[1]:
                    cropped[name] = tensor[:, start:end]
                else:
                    cropped[name] = tensor
            return cropped
        return input_ids[:, start:end]

    def _ttt_sample_batch(
        self,
        tokenized: torch.Tensor | dict[str, torch.Tensor],
        generator: torch.Generator,
    ) -> tuple[torch.Tensor | dict[str, torch.Tensor], torch.Tensor]:
        """Draw one masked-language-modelling micro-batch.

        Steps: crop, sample ``batch_size`` rows with replacement, select
        ``mask_ratio`` of each row's non-special positions (at least one when
        any exist), then corrupt the selected positions BERT-style.

        Returns:
            ``(batch, labels)`` where ``labels`` holds the original token at
            selected positions and ``-100`` everywhere else.
        """
        cfg = self.ttt_config
        batch = self._ttt_sample_crop(tokenized, generator)
        input_ids = self._ttt_input_ids_from_batch(batch)
        rows = torch.randint(
            input_ids.shape[0],
            (cfg.batch_size,),
            generator=generator,
            device=input_ids.device,
        )
        if isinstance(batch, dict):
            sampled: torch.Tensor | dict[str, torch.Tensor] = {}
            for name, tensor in batch.items():
                # Only tensors aligned with the batch axis are re-indexed.
                if tensor.ndim >= 1 and tensor.shape[0] == input_ids.shape[0]:
                    sampled[name] = tensor.index_select(0, rows)
                else:
                    sampled[name] = tensor
        else:
            sampled = input_ids.index_select(0, rows)

        sampled_ids = self._ttt_input_ids_from_batch(sampled)
        labels = sampled_ids.clone()
        non_special = self._ttt_non_special_mask(sampled_ids)
        label_mask = torch.zeros_like(non_special)
        for row_idx in range(sampled_ids.shape[0]):
            candidate_positions = torch.where(non_special[row_idx])[0]
            if candidate_positions.numel() == 0:
                continue
            num_mask = max(1, int(round(candidate_positions.numel() * cfg.mask_ratio)))
            order = torch.randperm(
                candidate_positions.numel(),
                generator=generator,
                device=sampled_ids.device,
            )
            chosen = candidate_positions[order[:num_mask]]
            label_mask[row_idx, chosen] = True

        # -100 is the ignore_index used by the loss in ``ttt``.
        labels = labels.masked_fill(~label_mask, -100)
        masked_ids = sampled_ids.clone()
        chosen_positions = torch.where(label_mask)
        if chosen_positions[0].numel() > 0:
            # One uniform draw per selected position partitions them into
            # leave / replace / mask according to the configured probabilities.
            random_values = torch.rand(
                chosen_positions[0].shape,
                generator=generator,
                device=sampled_ids.device,
            )
            leave = random_values < cfg.bert_leave_prob
            replace = (random_values >= cfg.bert_leave_prob) & (
                random_values < cfg.bert_leave_prob + cfg.bert_replace_prob
            )
            mask = ~(leave | replace)
            if mask.any():
                masked_ids[
                    chosen_positions[0][mask],
                    chosen_positions[1][mask],
                ] = self._ttt_mask_token()
            if replace.any():
                replacement_tokens = self._ttt_replacement_tokens(sampled_ids)
                replacement_idx = torch.randint(
                    replacement_tokens.shape[0],
                    (int(replace.sum().item()),),
                    generator=generator,
                    device=sampled_ids.device,
                )
                masked_ids[
                    chosen_positions[0][replace],
                    chosen_positions[1][replace],
                ] = replacement_tokens[replacement_idx]
        return self._ttt_set_input_ids(sampled, masked_ids), labels

    def ttt(
        self,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        ttt_config: TTTConfig | T.Mapping[str, T.Any] | None = None,
        **kwargs: T.Any,
    ) -> dict[str, T.Any]:
        """Fine-tune the LoRA adapters on the given input with masked-token loss.

        Args:
            seq: Sequence(s) to tokenize; ignored when ``input_ids`` is given.
            input_ids: Pre-tokenized ids.
            ttt_config: Optional overrides. Once adapters exist they are merged
                into the current config and may not change the LoRA layout;
                before that they are passed to ``init_ttt``.
            **kwargs: Forwarded to the tokenize, predict and eval hooks.

        Returns:
            Dict with ``losses`` (loss of the last micro-batch of each
            optimizer step), ``step_metrics``, ``best_step`` and
            ``best_metric``.

        Raises:
            AssertionError: on an invalid or layout-changing config.

        Module train/eval modes and parameter ``requires_grad`` flags are
        restored on exit, even when training fails.
        """
        if ttt_config is not None:
            if self.__dict__.get("_ttt_initialized", False):
                next_cfg = self.ttt_config.merged(ttt_config)
                _assert_lora_layout_unchanged(self.ttt_config, next_cfg)
                # Validate before storing so bad overrides (e.g. steps=0)
                # are rejected rather than silently producing a no-op run.
                next_cfg.verify()
                self._ttt_cfg = next_cfg
            else:
                self.init_ttt(ttt_config)

        self._ttt_ensure_initialized()
        cfg = self.ttt_config
        if cfg.initial_state_reset:
            self.ttt_reset()

        device = next(self.parameters()).device
        tokenized = self._ttt_tokenize(seq=seq, input_ids=input_ids, **kwargs)
        tokenized = self._ttt_to_device(tokenized, device)
        generator_device = device if device.type == "cuda" else torch.device("cpu")
        generator = torch.Generator(device=generator_device)
        if cfg.seed is not None:
            generator.manual_seed(cfg.seed)

        # Remember what to restore in ``finally``.
        module_modes = {module: module.training for module in self.modules()}
        requires_grad = {param: param.requires_grad for param in self.parameters()}
        losses: list[float] = []
        step_metrics: list[dict[str, T.Any]] = []
        best_state: list[dict[str, torch.Tensor]] | None = None
        best_metric: float | None = None
        best_step = 0

        try:
            self.train()
            # Train the adapters only.
            for param in self.parameters():
                param.requires_grad_(False)
            for param in self._ttt_lora_parameters():
                param.requires_grad_(True)

            optimizer = self._ttt_make_optimizer()
            optimizer.zero_grad(set_to_none=True)
            total_micro_steps = cfg.steps * cfg.ags
            for micro_step in range(total_micro_steps):
                batch, labels = self._ttt_sample_batch(tokenized, generator)
                logits = self._ttt_predict_logits(batch, **kwargs)
                labels = labels.to(device=logits.device)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.shape[-1]),
                    labels.reshape(-1),
                    ignore_index=-100,
                )
                (loss / cfg.ags).backward()
                # Keep accumulating until a full optimizer step is due.
                if (micro_step + 1) % cfg.ags != 0:
                    continue

                if cfg.gradient_clip:
                    torch.nn.utils.clip_grad_norm_(
                        self._ttt_lora_parameters(),
                        cfg.gradient_clip_max_norm,
                    )
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                step = (micro_step + 1) // cfg.ags
                loss_value = float(loss.detach().item())
                losses.append(loss_value)
                if cfg.eval_each_step:
                    metrics, metric = self._ttt_eval_step(
                        step=step,
                        loss=loss_value,
                        seq=seq,
                        input_ids=input_ids,
                        **kwargs,
                    )
                    if len(metrics) > 0:
                        step_metrics.append(metrics)
                    if metric is not None and (
                        best_metric is None or metric > best_metric
                    ):
                        best_metric = metric
                        best_step = step
                        best_state = self._ttt_snapshot_lora_state()

            if cfg.automatic_best_state_reset and best_state is not None:
                self._ttt_restore_lora_state(best_state)
        finally:
            for param, value in requires_grad.items():
                param.requires_grad_(value)
            for module, training in module_modes.items():
                module.train(training)

        return {
            "losses": losses,
            "step_metrics": step_metrics,
            "best_step": best_step,
            "best_metric": best_metric,
        }