catninja123 commited on
Commit
a88f7e6
·
verified ·
1 Parent(s): 1941b80

Upload src/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/model.py +17 -10
src/model.py CHANGED
@@ -5,7 +5,11 @@ Upgrade from BART-base (140M) to Flan-T5-XL (3B).
5
  - Same style injection architecture (4 style vectors + fusion layer)
6
  - T5 encoder-decoder is native seq2seq, ideal for rewriting
7
  - Flan-T5 has instruction-following capability built in
8
- - fp16 training on A100 80GB
 
 
 
 
9
  """
10
 
11
  import torch
@@ -51,9 +55,10 @@ class StyleT5(nn.Module):
51
  config = T5Config.from_pretrained(model_name)
52
  config.dropout_rate = dropout
53
 
 
54
  self.t5 = T5ForConditionalGeneration.from_pretrained(
55
  model_name, config=config,
56
- torch_dtype=torch.float16, # fp16 to fit in memory
57
  )
58
  self.tokenizer = T5Tokenizer.from_pretrained(model_name)
59
 
@@ -63,7 +68,7 @@ class StyleT5(nn.Module):
63
  self.dropout_rate = dropout
64
  self.model_name_str = model_name
65
 
66
- # 4 trainable style embeddings (larger dim for larger model)
67
  self.style_embeddings = nn.ParameterDict({
68
  'human_ps': nn.Parameter(torch.randn(style_dim) * 0.02),
69
  'human_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
@@ -71,7 +76,7 @@ class StyleT5(nn.Module):
71
  'ai_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
72
  })
73
 
74
- # Style fusion layer — fp32 for stability
75
  self.fusion = StyleFusionLayer(hidden_dim, style_dim, dropout=dropout)
76
 
77
  def get_style_embedding(self, style_keys: list) -> torch.Tensor:
@@ -86,11 +91,13 @@ class StyleT5(nn.Module):
86
  )
87
  hidden_states = encoder_output.last_hidden_state
88
 
89
- # Cast to fp32 for fusion layer stability, then back to fp16
90
- hidden_fp32 = hidden_states.float()
91
- style_emb = self.get_style_embedding(style_keys).float()
92
- fused = self.fusion(hidden_fp32, style_emb)
93
- fused = fused.to(hidden_states.dtype)
 
 
94
 
95
  encoder_output.last_hidden_state = fused
96
  return encoder_output
@@ -125,7 +132,7 @@ class StyleT5(nn.Module):
125
  import os
126
  os.makedirs(path, exist_ok=True)
127
 
128
- # Save T5 model in fp16
129
  self.t5.save_pretrained(os.path.join(path, 't5'))
130
  self.tokenizer.save_pretrained(os.path.join(path, 't5'))
131
 
 
5
  - Same style injection architecture (4 style vectors + fusion layer)
6
  - T5 encoder-decoder is native seq2seq, ideal for rewriting
7
  - Flan-T5 has instruction-following capability built in
8
+ - bf16 training on A100 80GB (NOT fp16 — must match autocast dtype)
9
+
10
+ v3b fixes:
11
+ - Load model in bfloat16 (was fp16, causing NaN with bf16 autocast)
12
+ - Fusion layer and style embeddings are cast to the encoder output's dtype at forward time (keeps fusion in bf16 under bf16 autocast)
13
  """
14
 
15
  import torch
 
55
  config = T5Config.from_pretrained(model_name)
56
  config.dropout_rate = dropout
57
 
58
+ # CRITICAL: Use bfloat16 to match autocast dtype (was float16 → caused NaN)
59
  self.t5 = T5ForConditionalGeneration.from_pretrained(
60
  model_name, config=config,
61
+ torch_dtype=torch.bfloat16,
62
  )
63
  self.tokenizer = T5Tokenizer.from_pretrained(model_name)
64
 
 
68
  self.dropout_rate = dropout
69
  self.model_name_str = model_name
70
 
71
+ # 4 trainable style embeddings
72
  self.style_embeddings = nn.ParameterDict({
73
  'human_ps': nn.Parameter(torch.randn(style_dim) * 0.02),
74
  'human_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
 
76
  'ai_supp': nn.Parameter(torch.randn(style_dim) * 0.02),
77
  })
78
 
79
+ # Style fusion layer
80
  self.fusion = StyleFusionLayer(hidden_dim, style_dim, dropout=dropout)
81
 
82
  def get_style_embedding(self, style_keys: list) -> torch.Tensor:
 
91
  )
92
  hidden_states = encoder_output.last_hidden_state
93
 
94
+ # Get style embedding and cast to same dtype as hidden states
95
+ style_emb = self.get_style_embedding(style_keys).to(hidden_states.dtype)
96
+
97
+ # Cast fusion layer to same dtype (it may be fp32 from init). NOTE(review): this reassigns self.fusion on every forward call — Module.to is idempotent once converted, but the cast would be cleaner done once in __init__; verify it doesn't interfere with optimizer param references.
98
+ self.fusion = self.fusion.to(hidden_states.dtype)
99
+
100
+ fused = self.fusion(hidden_states, style_emb)
101
 
102
  encoder_output.last_hidden_state = fused
103
  return encoder_output
 
132
  import os
133
  os.makedirs(path, exist_ok=True)
134
 
135
+ # Save T5 model
136
  self.t5.save_pretrained(os.path.join(path, 't5'))
137
  self.tokenizer.save_pretrained(os.path.join(path, 't5'))
138