Upload src/model.py with huggingface_hub
src/model.py  +17 -10  CHANGED
@@ -5,7 +5,11 @@ Upgrade from BART-base (140M) to Flan-T5-XL (3B).
 - Same style injection architecture (4 style vectors + fusion layer)
 - T5 encoder-decoder is native seq2seq, ideal for rewriting
 - Flan-T5 has instruction-following capability built in
-
+- bf16 training on A100 80GB (NOT fp16 — must match autocast dtype)
+
+v3b fixes:
+- Load model in bfloat16 (was fp16, causing NaN with bf16 autocast)
+- Fusion layer stays in bf16 (no manual dtype casting needed)
 """
 
 import torch
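The bf16-versus-fp16 point in this docstring is easiest to see in a minimal training-step sketch. The Hub ID google/flan-t5-xl, the AdamW optimizer, and the toy batch below are illustrative rather than taken from this repo; the point is that the weights are loaded in bfloat16 and the forward pass runs under a bfloat16 autocast context, so the two dtypes never disagree.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the backbone directly in bfloat16 so its parameters match the autocast dtype below.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl", torch_dtype=torch.bfloat16
).cuda()
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

batch = tokenizer(["rewrite: hello world"], return_tensors="pt").to("cuda")
labels = tokenizer(["hello there, world"], return_tensors="pt").input_ids.to("cuda")

# bf16 autocast: activations run in bfloat16, matching the bf16 weights (no fp16 anywhere).
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
loss.backward()
optimizer.step()

Unlike fp16, bf16 training does not need a GradScaler, which is why the loop stays this short.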
@@ -51,9 +55,10 @@ class StyleT5(nn.Module):
         config = T5Config.from_pretrained(model_name)
         config.dropout_rate = dropout
 
+        # CRITICAL: Use bfloat16 to match autocast dtype (was float16 → caused NaN)
         self.t5 = T5ForConditionalGeneration.from_pretrained(
             model_name, config=config,
-            torch_dtype=torch.float16,
+            torch_dtype=torch.bfloat16,
         )
         self.tokenizer = T5Tokenizer.from_pretrained(model_name)
 
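As a quick check of what the torch_dtype argument in this hunk changes (the Hub ID is assumed, as above): without it, from_pretrained materializes the T5 weights in float32.

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl", torch_dtype=torch.bfloat16
)
# Every loaded parameter is bfloat16; drop torch_dtype and this prints {torch.float32}.
print({p.dtype for p in model.parameters()})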
@@ -63,7 +68,7 @@ class StyleT5(nn.Module):
         self.dropout_rate = dropout
         self.model_name_str = model_name
 
-        # 4 trainable style embeddings
+        # 4 trainable style embeddings
         self.style_embeddings = nn.ParameterDict({
            'human_ps': nn.Parameter(torch.randn(style_dim) * 0.02),
            'human_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
@@ -71,7 +76,7 @@ class StyleT5(nn.Module):
            'ai_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
         })
 
-        # Style fusion layer
+        # Style fusion layer
         self.fusion = StyleFusionLayer(hidden_dim, style_dim, dropout=dropout)
 
     def get_style_embedding(self, style_keys: list) -> torch.Tensor:
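StyleFusionLayer itself is not part of this diff, so the sketch below is only a hypothetical stand-in with the same constructor and call signature (concatenation plus a linear projection is an assumption, not the repo's implementation). It also shows why the forward-pass hunk further down casts the fusion layer: a module built like this defaults to float32 parameters even when the T5 backbone was loaded in bfloat16.

import torch
import torch.nn as nn

class StyleFusionLayerSketch(nn.Module):
    """Hypothetical stand-in for StyleFusionLayer; the real implementation is not shown in this commit."""

    def __init__(self, hidden_dim: int, style_dim: int, dropout: float = 0.1):
        super().__init__()
        self.proj = nn.Linear(hidden_dim + style_dim, hidden_dim)  # float32 by default
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor, style_emb: torch.Tensor) -> torch.Tensor:
        # Broadcast the style vector over the sequence dimension, then fuse by concat + projection.
        batch, seq_len, _ = hidden_states.shape
        style = style_emb.expand(batch, -1).unsqueeze(1).expand(-1, seq_len, -1)
        return self.dropout(self.proj(torch.cat([hidden_states, style], dim=-1)))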
@@ -86,11 +91,13 @@ class StyleT5(nn.Module):
         )
         hidden_states = encoder_output.last_hidden_state
 
-        #
-
-
-
-
+        # Get style embedding and cast to same dtype as hidden states
+        style_emb = self.get_style_embedding(style_keys).to(hidden_states.dtype)
+
+        # Cast fusion layer to same dtype (it may be fp32 from init)
+        self.fusion = self.fusion.to(hidden_states.dtype)
+
+        fused = self.fusion(hidden_states, style_emb)
 
         encoder_output.last_hidden_state = fused
         return encoder_output
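The .to(hidden_states.dtype) casts added in this hunk guard against the usual mixed-dtype failure: applying a float32 module to bfloat16 activations raises a runtime dtype-mismatch error. A small standalone illustration:

import torch
import torch.nn as nn

hidden = torch.randn(2, 4, dtype=torch.bfloat16)
fusion = nn.Linear(4, 4)  # freshly built modules hold float32 parameters
try:
    fusion(hidden)
except RuntimeError as err:
    print(err)  # dtype mismatch between bf16 input and fp32 weight

# Casting the module first, as the forward pass above now does, makes the dtypes agree.
print(fusion.to(hidden.dtype)(hidden).dtype)  # torch.bfloat16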
@@ -125,7 +132,7 @@ class StyleT5(nn.Module):
         import os
         os.makedirs(path, exist_ok=True)
 
-        # Save T5 model
+        # Save T5 model
         self.t5.save_pretrained(os.path.join(path, 't5'))
         self.tokenizer.save_pretrained(os.path.join(path, 't5'))
 
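Because the backbone and tokenizer are written into a 't5' subdirectory, they can be reloaded straight from that folder with from_pretrained. The checkpoint path below is illustrative, and restoring the style embeddings and fusion layer is outside this snippet.

import os
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

ckpt = "checkpoints/style_t5_v3b"  # illustrative path, not from the repo
t5 = T5ForConditionalGeneration.from_pretrained(
    os.path.join(ckpt, "t5"), torch_dtype=torch.bfloat16
)
tokenizer = T5Tokenizer.from_pretrained(os.path.join(ckpt, "t5"))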