"""
MASH Style-injection T5 Model (Multi-Style Version) — v3

Upgrade from BART-base (140M) to Flan-T5-XL (3B).
- Same style injection architecture (4 style vectors + fusion layer)
- T5 encoder-decoder is native seq2seq, ideal for rewriting
- Flan-T5 has instruction-following capability built in
- bf16 training on A100 80GB (NOT fp16 — must match autocast dtype)

v3b fixes:
- Load model in bfloat16 (was fp16, causing NaN with bf16 autocast)
- Fusion layer is cast once, at construction, to the T5 dtype (bf16), so no
  module re-casting happens inside forward during training
"""

import os

import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config


class StyleFusionLayer(nn.Module):
    """Fuse content hidden states with a per-example style embedding.

    The style vector is broadcast over the sequence axis, concatenated to every
    position's hidden state, projected back to ``content_dim``, passed through
    dropout, then added residually to the input and layer-normalized.
    """

    def __init__(self, content_dim: int, style_dim: int, dropout: float = 0.15):
        super().__init__()
        self.projection = nn.Linear(content_dim + style_dim, content_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(content_dim)

    def forward(self, content_hidden: torch.Tensor, style_emb: torch.Tensor) -> torch.Tensor:
        """Return style-conditioned hidden states with the same shape as ``content_hidden``.

        Args:
            content_hidden: (batch, seq_len, content_dim) hidden states.
            style_emb: (batch, style_dim) style vectors, one per example.
        """
        # Repeat each example's style vector at every sequence position.
        style_expanded = style_emb.unsqueeze(1).expand(-1, content_hidden.size(1), -1)
        concat = torch.cat([content_hidden, style_expanded], dim=-1)
        fused = self.dropout(self.projection(concat))
        # Residual connection + layer norm keeps the output close to the
        # original encoder representation at initialization.
        return self.layer_norm(fused + content_hidden)


class StyleT5(nn.Module):
    """
    Flan-T5-XL with multi-style injection for AI→Human text style transfer.

    4 style vectors capture the cross-product of:
    - Voice: human vs AI
    - Genre: PS (personal narrative) vs Supp (argumentative/academic)

    The encoder output is fused with the selected style vector before being
    handed to the T5 decoder.
    """

    STYLES = ['human_ps', 'human_supp', 'ai_ps', 'ai_supp']

    def __init__(self, model_name: str = 'google/flan-t5-xl', style_dim: int = 128,
                 dropout: float = 0.15):
        super().__init__()

        # Load T5 with adjusted dropout
        config = T5Config.from_pretrained(model_name)
        config.dropout_rate = dropout
        # CRITICAL: Use bfloat16 to match autocast dtype (was float16 → caused NaN)
        self.t5 = T5ForConditionalGeneration.from_pretrained(
            model_name,
            config=config,
            torch_dtype=torch.bfloat16,
        )
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

        hidden_dim = self.t5.config.d_model  # 2048 for t5-xl
        self.style_dim = style_dim
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout
        self.model_name_str = model_name

        # 4 trainable style embeddings, kept in fp32 and cast to the hidden
        # dtype only when used. Built in STYLES order (same RNG draw order as
        # listing them explicitly).
        self.style_embeddings = nn.ParameterDict({
            key: nn.Parameter(torch.randn(style_dim) * 0.02) for key in self.STYLES
        })

        # Style fusion layer, cast ONCE here to the T5 dtype so forward never
        # has to mutate module parameters (unsafe under DDP/FSDP and with
        # optimizers that captured the original parameter objects).
        self.fusion = StyleFusionLayer(hidden_dim, style_dim, dropout=dropout).to(self.t5.dtype)

    def get_style_embedding(self, style_keys: list) -> torch.Tensor:
        """Stack the style vectors named in ``style_keys`` into a (batch, style_dim) tensor.

        Raises:
            KeyError: if any key is not one of the registered styles.
        """
        unknown = [key for key in style_keys if key not in self.style_embeddings]
        if unknown:
            raise KeyError(f"Unknown style key(s) {unknown}; expected one of {self.STYLES}")
        return torch.stack([self.style_embeddings[key] for key in style_keys], dim=0)

    def encode_with_style(self, input_ids, attention_mask, style_keys: list):
        """Run the T5 encoder and replace its last hidden state with the style-fused one.

        ``style_keys`` must contain one style name per example in the batch.
        Returns the encoder's output object with ``last_hidden_state`` replaced.
        """
        encoder = self.t5.get_encoder()
        encoder_output = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = encoder_output.last_hidden_state

        # Cast the (fp32) style embedding to the same dtype as the hidden states.
        style_emb = self.get_style_embedding(style_keys).to(hidden_states.dtype)

        # Defensive only: the fusion layer already matches the T5 dtype from
        # __init__; this fires just if T5 was re-cast on its own afterwards.
        if self.fusion.projection.weight.dtype != hidden_states.dtype:
            self.fusion.to(hidden_states.dtype)

        encoder_output.last_hidden_state = self.fusion(hidden_states, style_emb)
        return encoder_output

    def forward(self, input_ids, attention_mask, labels, style_keys: list):
        """Compute the T5 seq2seq output (including loss from ``labels``) on style-fused encodings."""
        encoder_output = self.encode_with_style(input_ids, attention_mask, style_keys)
        return self.t5(
            encoder_outputs=encoder_output,
            attention_mask=attention_mask,
            labels=labels,
        )

    def generate_text(self, input_ids, attention_mask, style_keys: list,
                      max_length: int = 512, num_beams: int = 4, **kwargs):
        """Generate token ids conditioned on the given styles.

        Args:
            max_length: forwarded to ``generate`` as ``max_new_tokens``.
            num_beams: beam width; ``early_stopping`` is enabled only for beam search.
            **kwargs: extra ``generate`` arguments; these override the defaults
                below (e.g. ``no_repeat_ngram_size``) instead of colliding with them.

        Note: dropout is active unless the caller has put the model in eval mode.
        """
        # generate() only returns integer ids, so gradients through the
        # encoder pass are never usable; skip building the autograd graph.
        with torch.no_grad():
            encoder_output = self.encode_with_style(input_ids, attention_mask, style_keys)

        gen_kwargs = {
            'max_new_tokens': max_length,
            'num_beams': num_beams,
            'no_repeat_ngram_size': 3,
            'length_penalty': 1.0,
        }
        if num_beams > 1:
            # Beam-only flag; setting it for greedy decoding only triggers warnings.
            gen_kwargs['early_stopping'] = True
        gen_kwargs.update(kwargs)

        return self.t5.generate(
            encoder_outputs=encoder_output,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    def save_pretrained(self, path: str):
        """Save T5 + tokenizer under ``path/t5`` and style components to ``path/style_components.pt``."""
        os.makedirs(path, exist_ok=True)

        # Save T5 model
        self.t5.save_pretrained(os.path.join(path, 't5'))
        self.tokenizer.save_pretrained(os.path.join(path, 't5'))

        # Save style components separately (small, always fp32, on CPU so the
        # checkpoint does not depend on the training device).
        torch.save({
            'style_embeddings': {
                k: v.detach().float().cpu() for k, v in self.style_embeddings.items()
            },
            'fusion_state_dict': {
                k: v.float().cpu() for k, v in self.fusion.state_dict().items()
            },
            'style_dim': self.style_dim,
            'hidden_dim': self.hidden_dim,
            'dropout_rate': self.dropout_rate,
            'model_name': self.model_name_str,
        }, os.path.join(path, 'style_components.pt'))

    @classmethod
    def load_pretrained(cls, path: str, device: str = 'cpu'):
        """Rebuild a model saved by ``save_pretrained`` and move ALL of it to ``device``."""
        style_data = torch.load(
            os.path.join(path, 'style_components.pt'),
            map_location=device,
            weights_only=True,
        )
        model = cls(
            model_name=os.path.join(path, 't5'),
            style_dim=style_data['style_dim'],
            dropout=style_data.get('dropout_rate', 0.15),
        )
        # copy_ validates shapes (unlike assigning .data) and converts
        # device/dtype into the existing parameters.
        with torch.no_grad():
            for k, v in style_data['style_embeddings'].items():
                model.style_embeddings[k].copy_(v)
        model.fusion.load_state_dict(style_data['fusion_state_dict'])
        # Move the whole model so T5, fusion and style vectors share one device.
        return model.to(device)