OliverPerrin committed on
Commit
3225a94
·
1 Parent(s): 944ac49

Added pretrained BART weights for the encoder/decoder while keeping custom architecture and task heads

Browse files
configs/model/base.yaml CHANGED
@@ -1,6 +1,8 @@
1
- d_model: 512
2
  num_encoder_layers: 6
3
  num_decoder_layers: 6
4
- num_attention_heads: 8
5
- ffn_dim: 2048
6
  dropout: 0.1
 
 
 
1
+ d_model: 768
2
  num_encoder_layers: 6
3
  num_decoder_layers: 6
4
+ num_attention_heads: 12
5
+ ffn_dim: 3072
6
  dropout: 0.1
7
+ use_pretrained: true
8
+ pretrained_model_name: facebook/bart-base
configs/training/default.yaml CHANGED
@@ -10,3 +10,5 @@ scheduler:
10
  trainer:
11
  max_epochs: 5
12
  gradient_clip_norm: 1.0
 
 
 
10
  trainer:
11
  max_epochs: 5
12
  gradient_clip_norm: 1.0
13
+ validation_samples: 3
14
+ validation_max_length: 128
scripts/inference.py CHANGED
@@ -3,9 +3,14 @@ from __future__ import annotations
3
 
4
  import argparse
5
  import json
 
6
  from pathlib import Path
7
  from typing import List, cast
8
 
 
 
 
 
9
  from src.data.tokenization import TokenizerConfig
10
  from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline
11
 
 
3
 
4
  import argparse
5
  import json
6
+ import sys
7
  from pathlib import Path
8
  from typing import List, cast
9
 
10
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
11
+ if str(PROJECT_ROOT) not in sys.path:
12
+ sys.path.insert(0, str(PROJECT_ROOT))
13
+
14
  from src.data.tokenization import TokenizerConfig
15
  from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline
16
 
src/models/attention.py CHANGED
@@ -83,8 +83,6 @@ class ScaledDotProductAttention(nn.Module):
83
  mask_bool = mask.to(dtype=torch.bool, device=scores.device)
84
  # masked_fill expects broadcastable mask: True means keep, False means mask out
85
  scores = scores.masked_fill(~mask_bool, float("-1e9"))
86
- # Applying Softmax to get attention weights
87
- attention_weights = F.softmax(scores, dim=-1)
88
 
89
  # Softmax to get attention probabilities
90
  p_attn = F.softmax(scores, dim=-1)
 
83
  mask_bool = mask.to(dtype=torch.bool, device=scores.device)
84
  # masked_fill expects broadcastable mask: True means keep, False means mask out
85
  scores = scores.masked_fill(~mask_bool, float("-1e9"))
 
 
86
 
87
  # Softmax to get attention probabilities
88
  p_attn = F.softmax(scores, dim=-1)
src/models/factory.py CHANGED
@@ -5,6 +5,9 @@ from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Optional
7
 
 
 
 
8
  from ..data.tokenization import Tokenizer
9
  from ..utils.config import load_yaml
10
  from .decoder import TransformerDecoder
@@ -23,6 +26,8 @@ class ModelConfig:
23
  num_attention_heads: int = 8
24
  ffn_dim: int = 2048
25
  dropout: float = 0.1
 
 
26
 
27
  def __post_init__(self):
28
  if self.d_model % self.num_attention_heads != 0:
@@ -51,9 +56,93 @@ def load_model_config(path: Optional[str | Path]) -> ModelConfig:
51
  num_attention_heads=int(data.get("num_attention_heads", 8)),
52
  ffn_dim=int(data.get("ffn_dim", 2048)),
53
  dropout=float(data.get("dropout", 0.1)),
 
 
54
  )
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def build_multitask_model(
58
  tokenizer: Tokenizer,
59
  *,
@@ -68,6 +157,7 @@ def build_multitask_model(
68
  raise ValueError("num_emotions must be a positive integer")
69
  if not isinstance(num_topics, int) or num_topics <= 0:
70
  raise ValueError("num_topics must be a positive integer")
 
71
  encoder = TransformerEncoder(
72
  vocab_size=tokenizer.vocab_size,
73
  d_model=cfg.d_model,
@@ -88,7 +178,12 @@ def build_multitask_model(
88
  max_len=tokenizer.config.max_length,
89
  pad_token_id=tokenizer.pad_token_id,
90
  )
 
 
 
 
91
 
 
92
  model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
93
  model.add_head(
94
  "summarization",
 
5
  from pathlib import Path
6
  from typing import Optional
7
 
8
+ import torch
9
+ from transformers import BartModel
10
+
11
  from ..data.tokenization import Tokenizer
12
  from ..utils.config import load_yaml
13
  from .decoder import TransformerDecoder
 
26
  num_attention_heads: int = 8
27
  ffn_dim: int = 2048
28
  dropout: float = 0.1
29
+ use_pretrained: bool = False
30
+ pretrained_model_name: str = "facebook/bart-base"
31
 
32
  def __post_init__(self):
33
  if self.d_model % self.num_attention_heads != 0:
 
56
  num_attention_heads=int(data.get("num_attention_heads", 8)),
57
  ffn_dim=int(data.get("ffn_dim", 2048)),
58
  dropout=float(data.get("dropout", 0.1)),
59
+ use_pretrained=bool(data.get("use_pretrained", False)),
60
+ pretrained_model_name=str(data.get("pretrained_model_name", "facebook/bart-base")),
61
  )
62
 
63
 
64
def _load_pretrained_weights(
    encoder: TransformerEncoder, decoder: TransformerDecoder, model_name: str
) -> None:
    """Load pretrained BART weights into the custom encoder/decoder.

    Copies token/position embeddings, attention projections, layer norms and
    feed-forward weights from a Hugging Face ``BartModel`` into the custom
    ``TransformerEncoder``/``TransformerDecoder`` modules in place.

    Args:
        encoder: Custom encoder whose parameters are overwritten in place.
        decoder: Custom decoder whose parameters are overwritten in place.
        model_name: Hugging Face model identifier (e.g. ``facebook/bart-base``).

    Raises:
        ValueError: If a source/target parameter pair differs in shape,
            indicating a config mismatch (d_model / ffn_dim / num heads).
    """
    print(f"Loading pretrained weights from {model_name}...")
    bart = BartModel.from_pretrained(model_name)

    def copy_param(dst: torch.Tensor, src: torch.Tensor) -> None:
        # Tensor.copy_ silently broadcasts; an explicit shape check turns a
        # config mismatch into a clear error instead of corrupted weights.
        if dst.shape != src.shape:
            raise ValueError(
                f"Pretrained weight shape mismatch: expected {tuple(dst.shape)}, "
                f"got {tuple(src.shape)}. Check d_model/ffn_dim/num_attention_heads."
            )
        dst.copy_(src)

    def copy_attention(custom_attn, bart_attn) -> None:
        # Q/K/V/output projections map one-to-one between the two modules.
        for custom_proj, bart_proj in (
            (custom_attn.W_Q, bart_attn.q_proj),
            (custom_attn.W_K, bart_attn.k_proj),
            (custom_attn.W_V, bart_attn.v_proj),
            (custom_attn.W_O, bart_attn.out_proj),
        ):
            copy_param(custom_proj.weight, bart_proj.weight)
            copy_param(custom_proj.bias, bart_proj.bias)

    def copy_norm(custom_norm, bart_norm) -> None:
        # LayerNorm affine parameters (weight + bias).
        copy_param(custom_norm.weight, bart_norm.weight)
        copy_param(custom_norm.bias, bart_norm.bias)

    def copy_ffn(custom_ffn, bart_layer) -> None:
        # Two-layer feed-forward block (fc1 -> activation -> fc2).
        copy_param(custom_ffn.fc1.weight, bart_layer.fc1.weight)
        copy_param(custom_ffn.fc1.bias, bart_layer.fc1.bias)
        copy_param(custom_ffn.fc2.weight, bart_layer.fc2.weight)
        copy_param(custom_ffn.fc2.bias, bart_layer.fc2.bias)

    def warn_if_layer_mismatch(name: str, custom_layers, bart_layers) -> None:
        # zip() truncates silently; make a partial transfer visible.
        if len(custom_layers) != len(bart_layers):
            print(
                f"Warning: {name} layer count mismatch "
                f"(custom={len(custom_layers)}, bart={len(bart_layers)}); "
                f"only the overlapping layers receive pretrained weights."
            )

    # Weight copies are not part of any autograd graph.
    with torch.no_grad():
        # --- Encoder ---
        print("Transferring encoder weights...")
        copy_param(encoder.embedding.weight, bart.encoder.embed_tokens.weight)
        # NOTE(review): BART's learned positional embedding table carries an
        # internal offset of 2 rows; this assumes the custom `pe` buffer uses
        # the same layout — confirm against the PositionalEncoding definition.
        copy_param(encoder.pos_encoder.pe, bart.encoder.embed_positions.weight.unsqueeze(0))

        warn_if_layer_mismatch("encoder", encoder.layers, bart.encoder.layers)
        for custom_layer, bart_layer in zip(encoder.layers, bart.encoder.layers):
            copy_attention(custom_layer.self_attn, bart_layer.self_attn)
            copy_norm(custom_layer.norm1, bart_layer.self_attn_layer_norm)
            copy_norm(custom_layer.norm2, bart_layer.final_layer_norm)
            copy_ffn(custom_layer.ffn, bart_layer)

        copy_norm(encoder.final_norm, bart.encoder.layernorm_embedding)

        # --- Decoder ---
        print("Transferring decoder weights...")
        copy_param(decoder.embedding.weight, bart.decoder.embed_tokens.weight)
        copy_param(decoder.pos_encoder.pe, bart.decoder.embed_positions.weight.unsqueeze(0))

        warn_if_layer_mismatch("decoder", decoder.layers, bart.decoder.layers)
        for custom_layer, bart_layer in zip(decoder.layers, bart.decoder.layers):
            copy_attention(custom_layer.self_attn, bart_layer.self_attn)
            # Cross-attention maps to BART's encoder_attn.
            copy_attention(custom_layer.cross_attn, bart_layer.encoder_attn)
            copy_norm(custom_layer.norm1, bart_layer.self_attn_layer_norm)
            copy_norm(custom_layer.norm2, bart_layer.encoder_attn_layer_norm)
            copy_norm(custom_layer.norm3, bart_layer.final_layer_norm)
            copy_ffn(custom_layer.ffn, bart_layer)

        copy_norm(decoder.final_norm, bart.decoder.layernorm_embedding)

    print("Pretrained weights loaded successfully!")
146
  def build_multitask_model(
147
  tokenizer: Tokenizer,
148
  *,
 
157
  raise ValueError("num_emotions must be a positive integer")
158
  if not isinstance(num_topics, int) or num_topics <= 0:
159
  raise ValueError("num_topics must be a positive integer")
160
+
161
  encoder = TransformerEncoder(
162
  vocab_size=tokenizer.vocab_size,
163
  d_model=cfg.d_model,
 
178
  max_len=tokenizer.config.max_length,
179
  pad_token_id=tokenizer.pad_token_id,
180
  )
181
+
182
+ # Load pretrained weights if requested
183
+ if cfg.use_pretrained:
184
+ _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)
185
 
186
+
187
  model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
188
  model.add_head(
189
  "summarization",
src/training/trainer.py CHANGED
@@ -20,6 +20,8 @@ class TrainerConfig:
20
  gradient_clip_norm: float = 1.0
21
  logging_interval: int = 50
22
  task_weights: Dict[str, float] | None = None
 
 
23
 
24
 
25
  class Trainer:
@@ -63,6 +65,9 @@ class Trainer:
63
  if val_loaders:
64
  val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch)
65
  history[f"val_epoch_{epoch}"] = val_metrics
 
 
 
66
  epoch_duration = time.perf_counter() - epoch_start
67
  total_elapsed = time.perf_counter() - start_time
68
  self._print_epoch_progress(epoch, total_epochs, epoch_duration, total_elapsed)
@@ -223,6 +228,72 @@ class Trainer:
223
  valid[valid == -100] = self.tokenizer.pad_token_id
224
  return self.tokenizer.decode_batch(valid.tolist())
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  def _print_epoch_progress(
227
  self,
228
  epoch: int,
 
20
  gradient_clip_norm: float = 1.0
21
  logging_interval: int = 50
22
  task_weights: Dict[str, float] | None = None
23
+ validation_samples: int = 3
24
+ validation_max_length: int = 128
25
 
26
 
27
  class Trainer:
 
65
  if val_loaders:
66
  val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch)
67
  history[f"val_epoch_{epoch}"] = val_metrics
68
+ # Generate sample summaries for validation
69
+ if "summarization" in val_loaders:
70
+ self._validate_generation(val_loaders["summarization"], epoch)
71
  epoch_duration = time.perf_counter() - epoch_start
72
  total_elapsed = time.perf_counter() - start_time
73
  self._print_epoch_progress(epoch, total_epochs, epoch_duration, total_elapsed)
 
228
  valid[valid == -100] = self.tokenizer.pad_token_id
229
  return self.tokenizer.decode_batch(valid.tolist())
230
 
231
+ def _validate_generation(self, val_loader: DataLoader, epoch: int) -> None:
232
+ """Generate and print sample summaries to monitor quality during training."""
233
+ self.model.eval()
234
+ samples_generated = 0
235
+ print(f"\n{'='*80}")
236
+ print(f"[Validation Generation - Epoch {epoch}]")
237
+ print(f"{'='*80}")
238
+
239
+ with torch.no_grad():
240
+ for batch in val_loader:
241
+ if samples_generated >= self.config.validation_samples:
242
+ break
243
+
244
+ batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
245
+ src_ids = batch["src_ids"]
246
+ src_mask = batch.get("src_mask")
247
+ labels = batch["labels"]
248
+
249
+ # Only process first item from batch
250
+ src_ids = src_ids[:1]
251
+ if src_mask is not None:
252
+ src_mask = src_mask[:1]
253
+ labels = labels[:1]
254
+
255
+ # Encode source
256
+ encoder_mask = None
257
+ if src_mask is not None:
258
+ encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
259
+ memory = self.model.encoder(src_ids, mask=encoder_mask)
260
+
261
+ # Ban special tokens from generation
262
+ ban_token_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
263
+ unk_id = getattr(self.tokenizer._tokenizer, 'unk_token_id', None)
264
+ if isinstance(unk_id, int):
265
+ ban_token_ids.append(unk_id)
266
+ ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
267
+
268
+ # Generate
269
+ generated = self.model.decoder.greedy_decode(
270
+ memory=memory,
271
+ max_len=self.config.validation_max_length,
272
+ start_token_id=self.tokenizer.bos_token_id,
273
+ end_token_id=self.tokenizer.eos_token_id,
274
+ device=self.device,
275
+ min_len=10,
276
+ ban_token_ids=ban_token_ids,
277
+ no_repeat_ngram_size=3,
278
+ memory_mask=src_mask,
279
+ )
280
+
281
+ # Decode
282
+ source_text = self.tokenizer.decode(src_ids[0].tolist())
283
+ generated_text = self.tokenizer.decode(generated[0].tolist())
284
+ reference_text = self._decode_labels(labels)[0]
285
+
286
+ print(f"\nSample {samples_generated + 1}:")
287
+ print(f"Source: {source_text[:200]}..." if len(source_text) > 200 else f"Source: {source_text}")
288
+ print(f"Generated: {generated_text}")
289
+ print(f"Reference: {reference_text[:200]}..." if len(reference_text) > 200 else f"Reference: {reference_text}")
290
+ print("-" * 80)
291
+
292
+ samples_generated += 1
293
+
294
+ print(f"{'='*80}\n")
295
+ self.model.train()
296
+
297
  def _print_epoch_progress(
298
  self,
299
  epoch: int,