# Provenance: uploaded to the Hugging Face Hub as src/model.py via
# huggingface_hub by user catninja123 (commit a88f7e6, verified).
"""
MASH Style-injection T5 Model (Multi-Style Version) — v3
Upgrade from BART-base (140M) to Flan-T5-XL (3B).
- Same style injection architecture (4 style vectors + fusion layer)
- T5 encoder-decoder is native seq2seq, ideal for rewriting
- Flan-T5 has instruction-following capability built in
- bf16 training on A100 80GB (NOT fp16 — must match autocast dtype)
v3b fixes:
- Load model in bfloat16 (was fp16, causing NaN with bf16 autocast)
- Fusion layer stays in bf16 (no manual dtype casting needed)
"""
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config
class StyleFusionLayer(nn.Module):
"""Fuses content representation with style embedding via linear projection + dropout."""
def __init__(self, content_dim: int, style_dim: int, dropout: float = 0.15):
super().__init__()
self.projection = nn.Linear(content_dim + style_dim, content_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(content_dim)
def forward(self, content_hidden: torch.Tensor, style_emb: torch.Tensor) -> torch.Tensor:
style_expanded = style_emb.unsqueeze(1).expand(-1, content_hidden.size(1), -1)
concat = torch.cat([content_hidden, style_expanded], dim=-1)
fused = self.projection(concat)
fused = self.dropout(fused)
# Residual connection + layer norm
fused = self.layer_norm(fused + content_hidden)
return fused
class StyleT5(nn.Module):
"""
Flan-T5-XL with multi-style injection for AI→Human text style transfer.
4 style vectors capture the cross-product of:
- Voice: human vs AI
- Genre: PS (personal narrative) vs Supp (argumentative/academic)
"""
STYLES = ['human_ps', 'human_supp', 'ai_ps', 'ai_supp']
def __init__(self, model_name: str = 'google/flan-t5-xl', style_dim: int = 128,
dropout: float = 0.15):
super().__init__()
# Load T5 with adjusted dropout
config = T5Config.from_pretrained(model_name)
config.dropout_rate = dropout
# CRITICAL: Use bfloat16 to match autocast dtype (was float16 → caused NaN)
self.t5 = T5ForConditionalGeneration.from_pretrained(
model_name, config=config,
torch_dtype=torch.bfloat16,
)
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
hidden_dim = self.t5.config.d_model # 2048 for t5-xl
self.style_dim = style_dim
self.hidden_dim = hidden_dim
self.dropout_rate = dropout
self.model_name_str = model_name
# 4 trainable style embeddings
self.style_embeddings = nn.ParameterDict({
'human_ps': nn.Parameter(torch.randn(style_dim) * 0.02),
'human_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
'ai_ps': nn.Parameter(torch.randn(style_dim) * 0.02),
'ai_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
})
# Style fusion layer
self.fusion = StyleFusionLayer(hidden_dim, style_dim, dropout=dropout)
def get_style_embedding(self, style_keys: list) -> torch.Tensor:
embs = [self.style_embeddings[k] for k in style_keys]
return torch.stack(embs, dim=0)
def encode_with_style(self, input_ids, attention_mask, style_keys: list):
encoder = self.t5.get_encoder()
encoder_output = encoder(
input_ids=input_ids,
attention_mask=attention_mask,
)
hidden_states = encoder_output.last_hidden_state
# Get style embedding and cast to same dtype as hidden states
style_emb = self.get_style_embedding(style_keys).to(hidden_states.dtype)
# Cast fusion layer to same dtype (it may be fp32 from init)
self.fusion = self.fusion.to(hidden_states.dtype)
fused = self.fusion(hidden_states, style_emb)
encoder_output.last_hidden_state = fused
return encoder_output
def forward(self, input_ids, attention_mask, labels, style_keys: list):
encoder_output = self.encode_with_style(input_ids, attention_mask, style_keys)
outputs = self.t5(
encoder_outputs=encoder_output,
attention_mask=attention_mask,
labels=labels,
)
return outputs
def generate_text(self, input_ids, attention_mask, style_keys: list,
max_length: int = 512, num_beams: int = 4, **kwargs):
encoder_output = self.encode_with_style(input_ids, attention_mask, style_keys)
generated = self.t5.generate(
encoder_outputs=encoder_output,
attention_mask=attention_mask,
max_new_tokens=max_length,
num_beams=num_beams,
early_stopping=True,
no_repeat_ngram_size=3,
length_penalty=1.0,
**kwargs,
)
return generated
def save_pretrained(self, path: str):
import os
os.makedirs(path, exist_ok=True)
# Save T5 model
self.t5.save_pretrained(os.path.join(path, 't5'))
self.tokenizer.save_pretrained(os.path.join(path, 't5'))
# Save style components separately (small, always fp32)
torch.save({
'style_embeddings': {k: v.data.float() for k, v in self.style_embeddings.items()},
'fusion_state_dict': {k: v.float() for k, v in self.fusion.state_dict().items()},
'style_dim': self.style_dim,
'hidden_dim': self.hidden_dim,
'dropout_rate': self.dropout_rate,
'model_name': self.model_name_str,
}, os.path.join(path, 'style_components.pt'))
@classmethod
def load_pretrained(cls, path: str, device: str = 'cpu'):
import os
style_data = torch.load(
os.path.join(path, 'style_components.pt'),
map_location=device,
weights_only=True,
)
model = cls(
model_name=os.path.join(path, 't5'),
style_dim=style_data['style_dim'],
dropout=style_data.get('dropout_rate', 0.15),
)
for k, v in style_data['style_embeddings'].items():
model.style_embeddings[k].data = v.to(device)
model.fusion.load_state_dict(style_data['fusion_state_dict'])
return model