| """ |
| MASH Style-injection T5 Model (Multi-Style Version) — v3 |
| |
| Upgrade from BART-base (140M) to Flan-T5-XL (3B). |
| - Same style injection architecture (4 style vectors + fusion layer) |
| - T5 encoder-decoder is native seq2seq, ideal for rewriting |
| - Flan-T5 has instruction-following capability built in |
| - bf16 training on A100 80GB (NOT fp16 — must match autocast dtype) |
| |
| v3b fixes: |
| - Load model in bfloat16 (was fp16, causing NaN with bf16 autocast) |
| - Fusion layer stays in bf16 (no manual dtype casting needed) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config |
|
|
|
|
| class StyleFusionLayer(nn.Module): |
| """Fuses content representation with style embedding via linear projection + dropout.""" |
| |
| def __init__(self, content_dim: int, style_dim: int, dropout: float = 0.15): |
| super().__init__() |
| self.projection = nn.Linear(content_dim + style_dim, content_dim) |
| self.dropout = nn.Dropout(dropout) |
| self.layer_norm = nn.LayerNorm(content_dim) |
| |
| def forward(self, content_hidden: torch.Tensor, style_emb: torch.Tensor) -> torch.Tensor: |
| style_expanded = style_emb.unsqueeze(1).expand(-1, content_hidden.size(1), -1) |
| concat = torch.cat([content_hidden, style_expanded], dim=-1) |
| fused = self.projection(concat) |
| fused = self.dropout(fused) |
| |
| fused = self.layer_norm(fused + content_hidden) |
| return fused |
|
|
|
|
class StyleT5(nn.Module):
    """
    Flan-T5-XL with multi-style injection for AI→Human text style transfer.

    4 style vectors capture the cross-product of:
    - Voice: human vs AI
    - Genre: PS (personal narrative) vs Supp (argumentative/academic)

    The T5 encoder runs unchanged; a learned style vector is then fused into
    its hidden states (see StyleFusionLayer) before decoding, so the decoder
    conditions on both content and target style.
    """

    # Canonical style keys; the ParameterDict below is built from this list
    # so the two can never drift apart.
    STYLES = ['human_ps', 'human_supp', 'ai_ps', 'ai_supp']

    def __init__(self, model_name: str = 'google/flan-t5-xl', style_dim: int = 128,
                 dropout: float = 0.15):
        """
        Args:
            model_name: HF hub id or local path of the base T5 checkpoint.
            style_dim: Width of each learned style embedding.
            dropout: Dropout rate used both inside T5 and in the fusion layer.
        """
        super().__init__()

        # Propagate the requested dropout into T5 itself.
        config = T5Config.from_pretrained(model_name)
        config.dropout_rate = dropout

        # Load weights in bfloat16: must match the bf16 autocast used during
        # training (fp16 weights produced NaNs — see module docstring).
        self.t5 = T5ForConditionalGeneration.from_pretrained(
            model_name, config=config,
            torch_dtype=torch.bfloat16,
        )
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

        hidden_dim = self.t5.config.d_model
        self.style_dim = style_dim
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout
        self.model_name_str = model_name

        # One trainable vector per style; small-scale init (std 0.02).
        self.style_embeddings = nn.ParameterDict({
            style: nn.Parameter(torch.randn(style_dim) * 0.02)
            for style in self.STYLES
        })

        # Fuses a style vector into every encoder hidden state.
        self.fusion = StyleFusionLayer(hidden_dim, style_dim, dropout=dropout)

    def get_style_embedding(self, style_keys: list) -> torch.Tensor:
        """Stack the embeddings for a batch of style keys into (batch, style_dim)."""
        return torch.stack([self.style_embeddings[k] for k in style_keys], dim=0)

    def encode_with_style(self, input_ids, attention_mask, style_keys: list):
        """Run the T5 encoder, then inject the per-example style vectors.

        Returns the encoder output object with its ``last_hidden_state``
        replaced by the style-fused hidden states.
        """
        encoder = self.t5.get_encoder()
        encoder_output = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = encoder_output.last_hidden_state

        # Cast style vectors to the encoder's dtype (bf16) so the concat in
        # the fusion layer does not mix precisions.
        style_emb = self.get_style_embedding(style_keys).to(hidden_states.dtype)

        # Lazily cast the fusion layer to the encoder dtype; effectively a
        # no-op after the first call (Module.to converts in place).
        self.fusion = self.fusion.to(hidden_states.dtype)

        fused = self.fusion(hidden_states, style_emb)

        encoder_output.last_hidden_state = fused
        return encoder_output

    def forward(self, input_ids, attention_mask, labels, style_keys: list):
        """Teacher-forced training pass.

        Returns the standard HF seq2seq output; T5 computes the LM loss
        internally from ``labels``.
        """
        encoder_output = self.encode_with_style(input_ids, attention_mask, style_keys)

        outputs = self.t5(
            encoder_outputs=encoder_output,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs

    def generate_text(self, input_ids, attention_mask, style_keys: list,
                      max_length: int = 512, num_beams: int = 4, **kwargs):
        """Beam-search generation conditioned on the injected style.

        Note: ``max_length`` is forwarded as ``max_new_tokens`` — it bounds
        the number of generated tokens, not the total sequence length.
        """
        encoder_output = self.encode_with_style(input_ids, attention_mask, style_keys)

        generated = self.t5.generate(
            encoder_outputs=encoder_output,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
            **kwargs,
        )
        return generated

    def save_pretrained(self, path: str):
        """Save the T5 backbone + tokenizer under ``path/t5`` and the style
        components (embeddings, fusion weights, config) as one .pt file."""
        import os
        os.makedirs(path, exist_ok=True)

        self.t5.save_pretrained(os.path.join(path, 't5'))
        self.tokenizer.save_pretrained(os.path.join(path, 't5'))

        # Style tensors are stored in fp32 for a dtype-agnostic checkpoint;
        # they are cast back to the model dtype on the first forward pass.
        torch.save({
            'style_embeddings': {k: v.data.float() for k, v in self.style_embeddings.items()},
            'fusion_state_dict': {k: v.float() for k, v in self.fusion.state_dict().items()},
            'style_dim': self.style_dim,
            'hidden_dim': self.hidden_dim,
            'dropout_rate': self.dropout_rate,
            'model_name': self.model_name_str,
        }, os.path.join(path, 'style_components.pt'))

    @classmethod
    def load_pretrained(cls, path: str, device: str = 'cpu'):
        """Rebuild a StyleT5 saved by :meth:`save_pretrained`.

        The whole model (T5 backbone, fusion layer, style embeddings) is
        moved to ``device`` before returning.
        """
        import os

        # weights_only=True is safe: the checkpoint holds only tensors and
        # plain Python scalars/strings, no pickled code.
        style_data = torch.load(
            os.path.join(path, 'style_components.pt'),
            map_location=device,
            weights_only=True,
        )

        model = cls(
            model_name=os.path.join(path, 't5'),
            style_dim=style_data['style_dim'],
            dropout=style_data.get('dropout_rate', 0.15),
        )

        for k, v in style_data['style_embeddings'].items():
            model.style_embeddings[k].data = v.to(device)
        model.fusion.load_state_dict(style_data['fusion_state_dict'])

        # BUG FIX: previously only the style embeddings were moved to
        # ``device``, leaving the T5 backbone and fusion layer on CPU and
        # returning a mixed-device model for any device != 'cpu'.
        model.to(device)

        return model
|
|