OliverPerrin committed on
Commit
baf3026
·
1 Parent(s): 4667b2a

Fixed compiling issue, added length penalty, and attempting freezing encoder layers 0-5 to lower parameters and preserve T5's language understanding.

Browse files
README.md CHANGED
@@ -18,9 +18,9 @@ This project is built with industry-standard MLOps practices, including configur
18
 
19
  ## Core Features
20
 
21
- * **Abstractive Summarization:** Generates concise, coherent summaries of long-form text using encoder-decoder attention. Trained on CNN/DailyMail (news) and BookSum (literary).
22
  * **Emotion Classification:** Identifies 28 emotions from Google's GoEmotions dataset (admiration, amusement, anger, joy, love, etc.).
23
- * **Topic Classification:** Classifies documents into 4 categories (World, Sports, Business, Sci/Tech) using AG News.
24
 
25
  ## Model Architecture
26
 
 
18
 
19
  ## Core Features
20
 
21
+ * **Abstractive Summarization:** Generates concise, coherent summaries of long-form text using encoder-decoder attention. Trained on BookSum (literary) and arXiv (academic papers).
22
  * **Emotion Classification:** Identifies 28 emotions from Google's GoEmotions dataset (admiration, amusement, anger, joy, love, etc.).
23
+ * **Topic Classification:** Classifies documents into 8 categories (Fiction, Science, Technology, Philosophy, History, Psychology, Business, Arts).
24
 
25
  ## Model Architecture
26
 
artifacts/labels.json CHANGED
@@ -30,9 +30,12 @@
30
  "surprise"
31
  ],
32
  "topic": [
 
33
  "Business",
34
- "Sci/Tech",
35
- "Sports",
36
- "World"
 
 
37
  ]
38
  }
 
30
  "surprise"
31
  ],
32
  "topic": [
33
+ "Arts",
34
  "Business",
35
+ "Fiction",
36
+ "History",
37
+ "Philosophy",
+ "Psychology",
38
+ "Science",
39
+ "Technology"
40
  ]
41
  }
configs/data/datasets.yaml CHANGED
@@ -2,9 +2,9 @@
2
  # Data is downloaded via: python scripts/download_data.py
3
 
4
  processed:
5
- summarization: data/processed/summarization # CNN/DailyMail + BookSum
6
  emotion: data/processed/emotion # GoEmotions (28 labels)
7
- topic: data/processed/topic # AG News (4 labels)
8
  books: data/processed/books # Gutenberg prose chunks
9
 
10
  tokenizer:
 
2
  # Data is downloaded via: python scripts/download_data.py
3
 
4
  processed:
5
+ summarization: data/processed/summarization # BookSum + arXiv
6
  emotion: data/processed/emotion # GoEmotions (28 labels)
7
+ topic: data/processed/topic # Books + Papers (8 labels)
8
  books: data/processed/books # Gutenberg prose chunks
9
 
10
  tokenizer:
configs/training/dev.yaml CHANGED
@@ -1,11 +1,11 @@
1
  # Development/Testing Configuration for FLAN-T5-base
2
- # Fast iteration for debugging and testing changes
3
- # VRAM Usage: ~8-9GB peak (12GB available)
4
- # Training time: ~10-15 minutes on RTX 4070 12GB
5
  # Use: python scripts/train.py training=dev
6
 
7
  dataloader:
8
- batch_size: 5 # Conservative for 12GB VRAM
9
  shuffle: true
10
  num_workers: 4
11
  pin_memory: true
@@ -14,32 +14,42 @@ dataloader:
14
 
15
  optimizer:
16
  name: adamw
17
- lr: 5.0e-5 # Higher LR for faster convergence in dev
18
  weight_decay: 0.01
19
  eps: 1.0e-8
20
  betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
- warmup_steps: 100 # ~2% of training steps for smoother start
25
 
26
  trainer:
27
  max_epochs: 3
28
  gradient_clip_norm: 1.0
29
- gradient_accumulation_steps: 12 # Effective batch: 60 (5*12)
30
  validation_max_length: 128
31
- label_smoothing: 0.1
32
  task_weights:
33
  summarization: 1.0
34
- emotion: 0.5
35
- topic: 0.5
36
- max_train_samples: 3000 # 3k samples for better validation
37
  max_val_samples: 300
38
- early_stopping_patience: 5 # Stop if no improvement
39
  log_grad_norm_frequency: 100
40
 
41
- # Disable compile for faster startup in dev
42
- compile_encoder: false
43
- compile_decoder: false
44
 
45
- tokenizer_max_length: 512
 
 
 
 
 
 
 
 
 
 
 
1
  # Development/Testing Configuration for FLAN-T5-base
2
+ # FAST iteration for debugging - optimized for speed
3
+ # VRAM Usage: ~9-10GB peak (12GB available)
4
+ # Training time: ~5 minutes on RTX 4070 12GB
5
  # Use: python scripts/train.py training=dev
6
 
7
  dataloader:
8
+ batch_size: 10 # Optimal with FlashAttention
9
  shuffle: true
10
  num_workers: 4
11
  pin_memory: true
 
14
 
15
  optimizer:
16
  name: adamw
17
+ lr: 5.0e-5
18
  weight_decay: 0.01
19
  eps: 1.0e-8
20
  betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
+ warmup_steps: 50 # Less warmup for short runs
25
 
26
  trainer:
27
  max_epochs: 3
28
  gradient_clip_norm: 1.0
29
+ gradient_accumulation_steps: 6 # Effective batch: 60 (10*6)
30
  validation_max_length: 128
31
+ label_smoothing: 0.0 # Simpler backward graph for dev
32
  task_weights:
33
  summarization: 1.0
34
+ emotion: 1.5
35
+ topic: 0.5 # Reduced - topic already saturated at 86%
36
+ max_train_samples: 3000
37
  max_val_samples: 300
38
+ early_stopping_patience: 5
39
  log_grad_norm_frequency: 100
40
 
41
+ # Enable compile for speed (worth the startup cost)
42
+ compile_encoder: true
43
+ compile_decoder: true
44
 
45
+ # Speed optimizations
46
+ tokenizer_max_length: 256
47
+ gradient_checkpointing: true
48
+
49
+ # FLAN-T5 has NO learned positional embeddings - only relative position bias
50
+ # Disabling this causes repetition loops (model can't track sequence position)
51
+ use_relative_position_bias: true
52
+
53
+ # Freeze lower encoder layers (0-5) to preserve pretrained knowledge
54
+ # Upper layers (6-11) adapt to summarization style
55
+ freeze_encoder_layers: 6
configs/training/full.yaml CHANGED
@@ -1,11 +1,11 @@
1
  # Full Training Configuration for FLAN-T5-base
2
- # Complete training run with capped samples for reasonable time
3
- # VRAM Usage: ~11GB peak (12GB available)
4
- # Training time: ~2 hours on RTX 4070 12GB with torch.compile
5
  # Use: python scripts/train.py training=full
6
 
7
  dataloader:
8
- batch_size: 6 # Keep at 6 to stay within 12GB VRAM
9
  shuffle: true
10
  num_workers: 4
11
  pin_memory: true
@@ -14,32 +14,41 @@ dataloader:
14
 
15
  optimizer:
16
  name: adamw
17
- lr: 5.0e-5 # Slightly higher LR for faster convergence
18
  weight_decay: 0.01
19
  eps: 1.0e-6
20
  betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
- warmup_steps: 500 # Less warmup needed
25
 
26
  trainer:
27
- max_epochs: 5 # Converges by epoch 4-5
28
  gradient_clip_norm: 1.0
29
- gradient_accumulation_steps: 10 # Effective batch: 60 (6*10)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
- summarization: 1.0 # Reduced from 1.2 to give emotion room
34
- emotion: 1.5 # Increased to prevent degradation
35
- topic: 0.8 # Reduced since topic already near SOTA
36
- max_train_samples: 50000 # Cap training for speed
37
- max_val_samples: 3000 # Faster validation
38
- early_stopping_patience: 3
39
  log_grad_norm_frequency: 100
40
 
41
- # Enable torch.compile for maximum speed
42
  compile_encoder: true
43
  compile_decoder: true
44
 
45
- tokenizer_max_length: 512
 
 
 
 
 
 
 
 
 
 
 
1
  # Full Training Configuration for FLAN-T5-base
2
+ # BEST QUALITY - use for final model training
3
+ # VRAM Usage: ~9-10GB (12GB available)
4
+ # Training time: ~1 hour on RTX 4070 12GB
5
  # Use: python scripts/train.py training=full
6
 
7
  dataloader:
8
+ batch_size: 10 # Optimal for RTX 4070 12GB
9
  shuffle: true
10
  num_workers: 4
11
  pin_memory: true
 
14
 
15
  optimizer:
16
  name: adamw
17
+ lr: 2.0e-5 # Lower LR for best convergence
18
  weight_decay: 0.01
19
  eps: 1.0e-6
20
  betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
+ warmup_steps: 500
25
 
26
  trainer:
27
+ max_epochs: 8 # More epochs for best results
28
  gradient_clip_norm: 1.0
29
+ gradient_accumulation_steps: 6 # Effective batch: 60 (10*6)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
+ summarization: 1.0
34
+ emotion: 1.5 # Boost emotion (tends to underfit)
35
+ topic: 0.5 # Reduced - topic already saturated at 86%
36
+ max_train_samples: 50000
37
+ max_val_samples: 3000
38
+ early_stopping_patience: 4
39
  log_grad_norm_frequency: 100
40
 
 
41
  compile_encoder: true
42
  compile_decoder: true
43
 
44
+ # FULL QUALITY SETTINGS
45
+ tokenizer_max_length: 512 # Full context for summarization
46
+ gradient_checkpointing: true
47
+
48
+ # FLAN-T5 has NO learned positional embeddings - only relative position bias
49
+ # Disabling this causes repetition loops (model can't track sequence position)
50
+ use_relative_position_bias: true
51
+
52
+ # Freeze lower encoder layers (0-5) to preserve pretrained knowledge
53
+ # Upper layers (6-11) adapt to summarization style
54
+ freeze_encoder_layers: 6
configs/training/medium.yaml CHANGED
@@ -1,11 +1,11 @@
1
  # Medium Configuration for FLAN-T5-base
2
- # Balanced approach - good results in reasonable time
3
- # VRAM Usage: ~9-10GB peak (12GB available)
4
- # Training time: ~45-60 minutes on RTX 4070 12GB with torch.compile
5
  # Use: python scripts/train.py training=medium
6
 
7
  dataloader:
8
- batch_size: 6 # Conservative for 12GB VRAM with torch.compile
9
  shuffle: true
10
  num_workers: 4
11
  pin_memory: true
@@ -14,32 +14,41 @@ dataloader:
14
 
15
  optimizer:
16
  name: adamw
17
- lr: 3.0e-5 # Balanced LR for quality
18
  weight_decay: 0.01
19
  eps: 1.0e-6
20
  betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
- warmup_steps: 500 # ~2% warmup for 25k steps
25
 
26
  trainer:
27
- max_epochs: 5 # More epochs for better convergence
28
  gradient_clip_norm: 1.0
29
- gradient_accumulation_steps: 12 # Effective batch: 72 (6*12)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
- summarization: 1.2 # Slightly prioritize summarization
34
- emotion: 0.8
35
- topic: 0.8
36
- max_train_samples: 25000 # 25k samples - good balance
37
  max_val_samples: 2500
38
  early_stopping_patience: 3
39
  log_grad_norm_frequency: 100
40
 
41
- # Enable torch.compile for 1.5-2x speedup
42
  compile_encoder: true
43
  compile_decoder: true
44
 
45
- tokenizer_max_length: 512
 
 
 
 
 
 
 
 
 
 
 
1
  # Medium Configuration for FLAN-T5-base
2
+ # Balanced: good quality with reasonable speed
3
+ # VRAM Usage: ~8-9GB (12GB available)
4
+ # Training time: ~25-35 minutes on RTX 4070 12GB
5
  # Use: python scripts/train.py training=medium
6
 
7
  dataloader:
8
+ batch_size: 10 # Optimal for RTX 4070 12GB
9
  shuffle: true
10
  num_workers: 4
11
  pin_memory: true
 
14
 
15
  optimizer:
16
  name: adamw
17
+ lr: 3.0e-5 # Slightly lower LR for stability
18
  weight_decay: 0.01
19
  eps: 1.0e-6
20
  betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
+ warmup_steps: 300
25
 
26
  trainer:
27
+ max_epochs: 5
28
  gradient_clip_norm: 1.0
29
+ gradient_accumulation_steps: 6 # Effective batch: 60 (10*6)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
+ summarization: 1.0
34
+ emotion: 1.5
35
+ topic: 0.5 # Reduced - topic already saturated at 86%
36
+ max_train_samples: 25000
37
  max_val_samples: 2500
38
  early_stopping_patience: 3
39
  log_grad_norm_frequency: 100
40
 
 
41
  compile_encoder: true
42
  compile_decoder: true
43
 
44
+ # Balance: shorter sequences but keep T5's relative position bias for quality
45
+ tokenizer_max_length: 384
46
+ gradient_checkpointing: true
47
+
48
+ # FLAN-T5 has NO learned positional embeddings - only relative position bias
49
+ # Disabling this causes repetition loops (model can't track sequence position)
50
+ use_relative_position_bias: true
51
+
52
+ # Freeze lower encoder layers (0-5) to preserve pretrained knowledge
53
+ # Upper layers (6-11) adapt to summarization style
54
+ freeze_encoder_layers: 6
docs/architecture.md CHANGED
@@ -51,9 +51,9 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible
51
 
52
  | Task | Dataset | Size | Labels |
53
  | ---- | ------- | ---- | ------ |
54
- | Summarization | CNN/DailyMail + BookSum | ~110K | Text→Summary |
55
  | Emotion | GoEmotions | ~43K | 28 emotions (multi-label) |
56
- | Topic | AG News | ~120K | 4 categories |
57
  | Books | Gutenberg (prose chunks) | ~30K | Literary text |
58
 
59
  ### T5 Tokenizer Differences
 
51
 
52
  | Task | Dataset | Size | Labels |
53
  | ---- | ------- | ---- | ------ |
54
+ | Summarization | BookSum + arXiv | ~90K | Text→Summary |
55
  | Emotion | GoEmotions | ~43K | 28 emotions (multi-label) |
56
+ | Topic | Books + Papers | ~50K | 8 categories (Fiction, Science, Technology, etc.) |
57
  | Books | Gutenberg (prose chunks) | ~30K | Literary text |
58
 
59
  ### T5 Tokenizer Differences
outputs/evaluation_report.json DELETED
@@ -1,81 +0,0 @@
1
- {
2
- "split": "val",
3
- "summarization": {
4
- "rouge_like": 0.2817535277055523,
5
- "bleu": 0.06501593900536834
6
- },
7
- "emotion": {
8
- "f1_macro": 0.4053446650505066
9
- },
10
- "topic": {
11
- "accuracy": 0.7548042704626334,
12
- "classification_report": {
13
- "Business & Finance": {
14
- "precision": 0.6826859776168532,
15
- "recall": 0.5221550855991943,
16
- "f1-score": 0.5917261055634807,
17
- "support": 1986
18
- },
19
- "Computers & Internet": {
20
- "precision": 0.8468166586883676,
21
- "recall": 0.894790085988872,
22
- "f1-score": 0.8701426463354648,
23
- "support": 1977
24
- },
25
- "Education & Reference": {
26
- "precision": 0.6067106710671067,
27
- "recall": 0.5627551020408164,
28
- "f1-score": 0.5839068290100582,
29
- "support": 1960
30
- },
31
- "Entertainment & Music": {
32
- "precision": 0.732976653696498,
33
- "recall": 0.7708439897698209,
34
- "f1-score": 0.7514335577162802,
35
- "support": 1955
36
- },
37
- "Family & Relationships": {
38
- "precision": 0.7356746765249538,
39
- "recall": 0.8101781170483461,
40
- "f1-score": 0.7711310244611286,
41
- "support": 1965
42
- },
43
- "Health": {
44
- "precision": 0.7917267917267917,
45
- "recall": 0.8372329603255341,
46
- "f1-score": 0.8138442521631644,
47
- "support": 1966
48
- },
49
- "Politics & Government": {
50
- "precision": 0.7916459472899056,
51
- "recall": 0.8097660223804679,
52
- "f1-score": 0.8006034699522253,
53
- "support": 1966
54
- },
55
- "Science & Mathematics": {
56
- "precision": 0.749162278602202,
57
- "recall": 0.7972491085073866,
58
- "f1-score": 0.7724580454096742,
59
- "support": 1963
60
- },
61
- "Society & Culture": {
62
- "precision": 0.6588683351468988,
63
- "recall": 0.6181725370086779,
64
- "f1-score": 0.637872004213853,
65
- "support": 1959
66
- },
67
- "Sports": {
68
- "precision": 0.909317389138017,
69
- "recall": 0.9249873289406995,
70
- "f1-score": 0.9170854271356784,
71
- "support": 1973
72
- },
73
- "macro avg": {
74
- "precision": 0.7505585379497595,
75
- "recall": 0.7548130337609815,
76
- "f1-score": 0.7510203361961008,
77
- "support": 19670
78
- }
79
- }
80
- }
81
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
outputs/training_history.json CHANGED
@@ -1,92 +1,92 @@
1
  {
2
  "train_epoch_1": {
3
- "summarization_loss": 3.7986922026081054,
4
- "summarization_rouge_like": 0.38785950375542677,
5
- "emotion_loss": 0.6569146523665603,
6
- "emotion_f1": 0.0803471759769852,
7
- "topic_loss": 1.3537324049331485,
8
- "topic_accuracy": 0.4645228381729452,
9
- "total_loss": 6.166948288969483
10
  },
11
  "val_epoch_1": {
12
- "summarization_loss": 3.1010914066140884,
13
- "summarization_rouge_like": 0.4547831050626749,
14
- "emotion_loss": 0.47831222164831,
15
- "emotion_f1": 0.07989733061380237,
16
- "topic_loss": 1.1463579110962023,
17
- "topic_accuracy": 0.8397282174260592,
18
- "total_loss": 5.021045794132517
19
  },
20
  "train_epoch_2": {
21
- "summarization_loss": 3.519661677836342,
22
- "summarization_rouge_like": 0.40693338191007866,
23
- "emotion_loss": 0.2990482480142052,
24
- "emotion_f1": 0.25253565061903593,
25
- "topic_loss": 0.5421501434865632,
26
- "topic_accuracy": 0.8869290456763608,
27
- "total_loss": 4.896552726604225
28
  },
29
  "val_epoch_2": {
30
- "summarization_loss": 3.022662199944329,
31
- "summarization_rouge_like": 0.45815133655381807,
32
- "emotion_loss": 0.19708226060124037,
33
- "emotion_f1": 0.302215425453955,
34
- "topic_loss": 0.28093130860647425,
35
- "topic_accuracy": 0.9172661870503583,
36
- "total_loss": 4.009605495299369
37
  },
38
  "train_epoch_3": {
39
- "summarization_loss": 3.456413923878735,
40
- "summarization_rouge_like": 0.4113752870178118,
41
- "emotion_loss": 0.18330693083835614,
42
- "emotion_f1": 0.30698023489509907,
43
- "topic_loss": 0.2889783758940973,
44
- "topic_accuracy": 0.9169066474682156,
45
- "total_loss": 4.525524954040441
46
  },
47
  "val_epoch_3": {
48
- "summarization_loss": 3.0019707325265275,
49
- "summarization_rouge_like": 0.4592321986281997,
50
- "emotion_loss": 0.16639868924014575,
51
- "emotion_f1": 0.3015063897543531,
52
- "topic_loss": 0.23863075083072524,
53
- "topic_accuracy": 0.9280575539568332,
54
- "total_loss": 3.9263884310885304
55
  },
56
  "train_epoch_4": {
57
- "summarization_loss": 3.4258855361860663,
58
- "summarization_rouge_like": 0.4135803384924355,
59
- "emotion_loss": 0.16595664669032975,
60
- "emotion_f1": 0.31446844452103895,
61
- "topic_loss": 0.24658246585826152,
62
- "topic_accuracy": 0.9276857851372029,
63
- "total_loss": 4.441093933462159
64
  },
65
  "val_epoch_4": {
66
- "summarization_loss": 2.992023795628719,
67
- "summarization_rouge_like": 0.4595829821013028,
68
- "emotion_loss": 0.16106250848201253,
69
- "emotion_f1": 0.299241534820635,
70
- "topic_loss": 0.2258928704747765,
71
- "topic_accuracy": 0.9280575539568333,
72
- "total_loss": 3.8999928579198935
73
  },
74
  "train_epoch_5": {
75
- "summarization_loss": 3.4150345063421232,
76
- "summarization_rouge_like": 0.41468036090685273,
77
- "emotion_loss": 0.1624394242665394,
78
- "emotion_f1": 0.31033963250845154,
79
- "topic_loss": 0.2336994289211126,
80
- "topic_accuracy": 0.9319654427645914,
81
- "total_loss": 4.4149524901606805
82
  },
83
  "val_epoch_5": {
84
- "summarization_loss": 2.9899252604523436,
85
- "summarization_rouge_like": 0.45984993646884514,
86
- "emotion_loss": 0.15985918722207026,
87
- "emotion_f1": 0.2971099066666419,
88
- "topic_loss": 0.22285484572162303,
89
- "topic_accuracy": 0.9284572342126283,
90
- "total_loss": 3.894081538897767
91
  }
92
  }
 
1
  {
2
  "train_epoch_1": {
3
+ "summarization_loss": 4.343819652843475,
4
+ "summarization_rouge_like": 0.18423229737482247,
5
+ "emotion_loss": 0.4579651211887598,
6
+ "emotion_f1": 0.11036156222745776,
7
+ "topic_loss": 1.6671979689359664,
8
+ "topic_accuracy": 0.4339600000000011,
9
+ "total_loss": 6.364525709775378
10
  },
11
  "val_epoch_1": {
12
+ "summarization_loss": 4.259150079727172,
13
+ "summarization_rouge_like": 0.17393867024365672,
14
+ "emotion_loss": 0.15817135846614838,
15
+ "emotion_f1": 0.07330303180590272,
16
+ "topic_loss": 0.8542358543872833,
17
+ "topic_accuracy": 0.7782222222222225,
18
+ "total_loss": 5.179795800936217
19
  },
20
  "train_epoch_2": {
21
+ "summarization_loss": 4.103479218292236,
22
+ "summarization_rouge_like": 0.1940134706014566,
23
+ "emotion_loss": 0.15515640188455582,
24
+ "emotion_f1": 0.2337232402900234,
25
+ "topic_loss": 0.4888198138475418,
26
+ "topic_accuracy": 0.9067600000000139,
27
+ "total_loss": 4.7272696721971075
28
  },
29
  "val_epoch_2": {
30
+ "summarization_loss": 4.18841742515564,
31
+ "summarization_rouge_like": 0.17885957731314292,
32
+ "emotion_loss": 0.1549839802980423,
33
+ "emotion_f1": 0.3034666753411293,
34
+ "topic_loss": 0.4580393745005131,
35
+ "topic_accuracy": 0.852800000000002,
36
+ "total_loss": 4.787324895203119
37
  },
38
  "train_epoch_3": {
39
+ "summarization_loss": 4.041395119285584,
40
+ "summarization_rouge_like": 0.1970914256375089,
41
+ "emotion_loss": 0.15249912014603614,
42
+ "emotion_f1": 0.24187604461871087,
43
+ "topic_loss": 0.21472855980992317,
44
+ "topic_accuracy": 0.9627200000000102,
45
+ "total_loss": 4.441926647352566
46
  },
47
  "val_epoch_3": {
48
+ "summarization_loss": 4.16257409954071,
49
+ "summarization_rouge_like": 0.18115953723449993,
50
+ "emotion_loss": 0.15324361461400987,
51
+ "emotion_f1": 0.30253334194421766,
52
+ "topic_loss": 0.4939193711131811,
53
+ "topic_accuracy": 0.8632000000000015,
54
+ "total_loss": 4.7875750183522765
55
  },
56
  "train_epoch_4": {
57
+ "summarization_loss": 4.012135830116272,
58
+ "summarization_rouge_like": 0.19873380769300908,
59
+ "emotion_loss": 0.15166676665246487,
60
+ "emotion_f1": 0.24661330536156892,
61
+ "topic_loss": 0.14288409658223392,
62
+ "topic_accuracy": 0.9780000000000073,
63
+ "total_loss": 4.353943257360758
64
  },
65
  "val_epoch_4": {
66
+ "summarization_loss": 4.1532666339874265,
67
+ "summarization_rouge_like": 0.18147128191578765,
68
+ "emotion_loss": 0.15282477751374246,
69
+ "emotion_f1": 0.2984000087380409,
70
+ "topic_loss": 0.5214869263619184,
71
+ "topic_accuracy": 0.8580000000000017,
72
+ "total_loss": 4.799693341347577
73
  },
74
  "train_epoch_5": {
75
+ "summarization_loss": 4.002264401054382,
76
+ "summarization_rouge_like": 0.1992749810224614,
77
+ "emotion_loss": 0.15127245344221593,
78
+ "emotion_f1": 0.24676951464861632,
79
+ "topic_loss": 0.12673698243945836,
80
+ "topic_accuracy": 0.9796800000000072,
81
+ "total_loss": 4.330562667169272
82
  },
83
  "val_epoch_5": {
84
+ "summarization_loss": 4.149239055633545,
85
+ "summarization_rouge_like": 0.18202557571683906,
86
+ "emotion_loss": 0.15270190620422364,
87
+ "emotion_f1": 0.3021333419680595,
88
+ "topic_loss": 0.5217973904460669,
89
+ "topic_accuracy": 0.8580000000000011,
90
+ "total_loss": 4.795729827296732
91
  }
92
  }
scripts/demo_gradio.py CHANGED
@@ -107,20 +107,20 @@ def analyze_text(text: str) -> tuple[str, str, str]:
107
  # --------------- Sample Texts ---------------
108
 
109
  SAMPLES = {
110
- "business": """Global markets tumbled today as investors reacted to rising inflation concerns.
111
- The Federal Reserve hinted at potential interest rate hikes, sending shockwaves through technology
112
- and banking sectors. Analysts predict continued volatility as economic uncertainty persists.
113
- Major indices fell by over 2%, with tech stocks leading the decline.""",
114
 
115
- "science": """Scientists at MIT have developed a breakthrough quantum computing chip that
116
  operates at room temperature. This advancement could revolutionize drug discovery, cryptography,
117
  and artificial intelligence. The research team published their findings in Nature, demonstrating
118
  stable qubit operations for over 100 microseconds.""",
119
 
120
- "sports": """The championship game ended in dramatic fashion as the underdog team scored in
121
- the final seconds to secure victory. Fans rushed the field in celebration, marking the team's
122
- first title in 25 years. The winning goal came from a rookie player who had only joined the
123
- team this season.""",
124
  }
125
 
126
 
@@ -146,7 +146,7 @@ with gr.Blocks(title="LexiMind") as demo:
146
  text_input = gr.Textbox(
147
  label="Input Text",
148
  lines=6,
149
- placeholder="Paste a news article or any text to analyze...",
150
  )
151
  with gr.Row():
152
  analyze_btn = gr.Button("Analyze", variant="primary")
@@ -154,9 +154,9 @@ with gr.Blocks(title="LexiMind") as demo:
154
 
155
  gr.Markdown("**Quick samples:**")
156
  with gr.Row():
157
- btn_business = gr.Button("Business", size="sm")
158
  btn_science = gr.Button("Science", size="sm")
159
- btn_sports = gr.Button("Sports", size="sm")
160
 
161
  with gr.Column(scale=2):
162
  summary_output = gr.Textbox(label="Generated Summary", lines=4, interactive=False)
@@ -167,9 +167,9 @@ with gr.Blocks(title="LexiMind") as demo:
167
  # Event handlers
168
  analyze_btn.click(analyze_text, inputs=[text_input], outputs=[summary_output, emotions_output, topic_output])
169
  clear_btn.click(lambda: ("", "", "", ""), outputs=[text_input, summary_output, emotions_output, topic_output])
170
- btn_business.click(lambda: SAMPLES["business"], outputs=[text_input])
171
  btn_science.click(lambda: SAMPLES["science"], outputs=[text_input])
172
- btn_sports.click(lambda: SAMPLES["sports"], outputs=[text_input])
173
 
174
  # ===================== TAB 2: METRICS =====================
175
  with gr.Tab("Metrics"):
 
107
  # --------------- Sample Texts ---------------
108
 
109
  SAMPLES = {
110
+ "fiction": """The old lighthouse keeper had watched countless storms batter the rocky coast,
111
+ but nothing prepared him for what emerged from the fog that evening. A ship unlike any he'd
112
+ seen before - its hull seemingly made of living shadow - drifted silently toward the rocks.
113
+ He rang the warning bell, knowing somehow it wouldn't matter.""",
114
 
115
+ "science": """Researchers at MIT have developed a breakthrough quantum computing chip that
116
  operates at room temperature. This advancement could revolutionize drug discovery, cryptography,
117
  and artificial intelligence. The research team published their findings in Nature, demonstrating
118
  stable qubit operations for over 100 microseconds.""",
119
 
120
+ "technology": """The new large language model demonstrates unprecedented reasoning capabilities,
121
+ solving complex mathematical proofs and generating functional code across multiple programming
122
+ languages. Benchmarks show it outperforms previous systems by significant margins on tasks
123
+ requiring multi-step logical inference and long-context understanding.""",
124
  }
125
 
126
 
 
146
  text_input = gr.Textbox(
147
  label="Input Text",
148
  lines=6,
149
+ placeholder="Paste a book excerpt, research abstract, or any text to analyze...",
150
  )
151
  with gr.Row():
152
  analyze_btn = gr.Button("Analyze", variant="primary")
 
154
 
155
  gr.Markdown("**Quick samples:**")
156
  with gr.Row():
157
+ btn_fiction = gr.Button("Fiction", size="sm")
158
  btn_science = gr.Button("Science", size="sm")
159
+ btn_tech = gr.Button("Technology", size="sm")
160
 
161
  with gr.Column(scale=2):
162
  summary_output = gr.Textbox(label="Generated Summary", lines=4, interactive=False)
 
167
  # Event handlers
168
  analyze_btn.click(analyze_text, inputs=[text_input], outputs=[summary_output, emotions_output, topic_output])
169
  clear_btn.click(lambda: ("", "", "", ""), outputs=[text_input, summary_output, emotions_output, topic_output])
170
+ btn_fiction.click(lambda: SAMPLES["fiction"], outputs=[text_input])
171
  btn_science.click(lambda: SAMPLES["science"], outputs=[text_input])
172
+ btn_tech.click(lambda: SAMPLES["technology"], outputs=[text_input])
173
 
174
  # ===================== TAB 2: METRICS =====================
175
  with gr.Tab("Metrics"):
scripts/download_data.py CHANGED
@@ -5,19 +5,24 @@
5
  """
6
  Dataset download script for LexiMind.
7
 
8
- Downloads and prepares training datasets:
9
- - CNN/DailyMail + BookSum for summarization (news + literary)
10
- - Project Gutenberg books for additional literary training
 
 
 
 
 
11
  - GoEmotions for emotion classification (28 labels)
12
- - AG News for topic classification (4 labels: World, Sports, Business, Sci/Tech)
13
 
14
  Usage:
15
  python scripts/download_data.py # Download all
16
- python scripts/download_data.py --task topic # Download specific task
17
- python scripts/download_data.py --max-books 30000 --max-gutenberg 20000
18
 
19
  Author: Oliver Perrin
20
- Date: December 2025
21
  """
22
 
23
  from __future__ import annotations
@@ -35,7 +40,9 @@ from tqdm import tqdm
35
  # Output directory
36
  OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
37
 
38
- # Label definitions
 
 
39
  EMOTION_LABELS = [
40
  "admiration", "amusement", "anger", "annoyance", "approval", "caring",
41
  "confusion", "curiosity", "desire", "disappointment", "disapproval",
@@ -44,7 +51,57 @@ EMOTION_LABELS = [
44
  "relief", "remorse", "sadness", "surprise", "neutral",
45
  ]
46
 
47
- TOPIC_LABELS = ["World", "Sports", "Business", "Sci/Tech"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing") -> None:
@@ -56,225 +113,359 @@ def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing"
56
  print(f" ✓ {len(records):,} samples → {path}")
57
 
58
 
59
- def download_summarization(max_news: int = 80000, max_books: int = 30000) -> None:
60
- """Download CNN/DailyMail + BookSum for summarization."""
61
- print("\n📰 Downloading Summarization...")
62
- out_dir = OUTPUT_DIR / "summarization"
63
-
64
- all_train: list[dict[str, Any]] = []
65
- all_val: list[dict[str, Any]] = []
66
- all_test: list[dict[str, Any]] = []
67
-
68
- # CNN/DailyMail - great for news summarization
69
- print(" Loading CNN/DailyMail...")
70
- cnn = load_dataset("cnn_dailymail", "3.0.0")
71
-
72
- for split_name in cnn.keys():
73
- split = str(split_name)
74
- data = cnn[split_name]
75
- limit = max_news if "train" in split else max_news // 10
76
- indices = random.sample(range(len(data)), min(len(data), limit))
77
-
78
- records: list[dict[str, Any]] = []
79
- for i in indices:
80
- item = data[i]
81
- article = item["article"]
82
- highlights = item["highlights"]
83
- if article and highlights:
84
- records.append({"source": article, "summary": highlights})
85
-
86
- if "train" in split:
87
- all_train.extend(records)
88
- elif "val" in split:
89
- all_val.extend(records)
90
- else:
91
- all_test.extend(records)
92
- print(f" {split}: {len(records):,}")
93
 
94
- # BookSum - literary text summarization (chapters → summaries)
95
- print(" Loading BookSum...")
96
  booksum = load_dataset("kmfoda/booksum")
97
 
98
  for split_name in booksum.keys():
99
  split = str(split_name)
100
  data = booksum[split_name]
101
- limit = max_books if "train" in split else max_books // 10
102
  indices = random.sample(range(len(data)), min(len(data), limit))
103
 
104
  records = []
105
- for i in indices:
106
  item = data[i]
107
  chapter = item.get("chapter", "")
108
  summary = item.get("summary_text") or item.get("summary", "")
109
  if chapter and summary and len(chapter) > 300:
110
- # Truncate very long chapters to fit model context
111
- records.append({"source": chapter[:4000], "summary": summary})
112
-
113
- if "train" in split:
114
- all_train.extend(records)
115
- elif "val" in split:
116
- all_val.extend(records)
117
- else:
118
- all_test.extend(records)
119
  print(f" {split}: {len(records):,}")
120
 
121
- random.shuffle(all_train)
122
- write_jsonl(all_train, out_dir / "train.jsonl", "train")
123
- write_jsonl(all_val, out_dir / "validation.jsonl", "validation")
124
- write_jsonl(all_test, out_dir / "test.jsonl", "test")
125
 
126
 
127
- # Patterns to filter out Gutenberg boilerplate
128
- GUTENBERG_JUNK_PATTERNS = [
129
- r"Project Gutenberg",
130
- r"www\.gutenberg\.org",
131
- r"This ebook is for the use of",
132
- r"You may copy it, give it away",
133
- r"Gutenberg License",
134
- r"^\*\*\* START OF",
135
- r"^\*\*\* END OF",
136
- r"Produced by",
137
- r"Transcriber's Note",
138
- r"Editor's Note",
139
- r"TABLE OF CONTENTS",
140
- r"CONTENTS\s*$",
141
- r"^\s*CHAPTER\s+[IVXLC\d]+",
142
- r"^\s*Chapter\s+[IVXLC\d]+",
143
- r"^\s*BOOK\s+[IVXLC\d]+",
144
- r"^\s*PART\s+[IVXLC\d]+",
145
- r"^\s*PREFACE\s*$",
146
- r"^\s*INTRODUCTION\s*$",
147
- r"^\s*EPILOGUE\s*$",
148
- r"^\s*PROLOGUE\s*$",
149
- r"^\s*APPENDIX",
150
- r"^\s*INDEX\s*$",
151
- r"^\s*FOOTNOTES?\s*$",
152
- r"^\s*\[Illustration",
153
- r"^\s*\[Transcriber",
154
- r"E-text prepared by",
155
- r"Internet Archive",
156
- r"This file was produced",
157
- r"Distributed Proofreaders",
158
- r"^\s*_+\s*$", # Lines of underscores
159
- r"^\s*\*+\s*$", # Lines of asterisks
160
- ]
161
- GUTENBERG_JUNK_REGEX = re.compile("|".join(GUTENBERG_JUNK_PATTERNS), re.IGNORECASE)
162
 
163
 
164
- def is_clean_prose(text: str) -> bool:
165
- """Check if text is clean literary prose (not boilerplate/metadata)."""
166
- # Must be substantial
167
- if len(text) < 300 or len(text) > 3000:
168
- return False
169
 
170
- # Skip if contains Gutenberg boilerplate
171
- if GUTENBERG_JUNK_REGEX.search(text):
172
- return False
173
 
174
- # Must have actual sentences (prose check)
175
- # Good prose has periods, commas, and lowercase letters
176
- if text.count('.') < 2:
177
- return False
178
 
179
- # Skip if mostly uppercase (headers, titles)
180
- uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
181
- if uppercase_ratio > 0.3:
182
- return False
183
 
184
- # Skip if too many numbers (tables, dates, page numbers)
185
- digit_ratio = sum(1 for c in text if c.isdigit()) / max(len(text), 1)
186
- if digit_ratio > 0.1:
187
- return False
188
 
189
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
 
192
- def download_gutenberg(max_samples: int = 20000) -> None:
193
  """
194
- Download Project Gutenberg books for literary language modeling.
195
 
196
- Uses the standardized_gutenberg dataset which has clean, parsed books.
197
- Creates paragraph-level chunks for training diversity.
198
- Filters out boilerplate (headers, licenses, TOC, etc).
199
  """
200
- print("\n📚 Downloading Gutenberg Books...")
201
- out_dir = OUTPUT_DIR / "books"
202
- out_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
- # Load Gutenberg dataset - has ~60K books
205
- print(" Loading standardized_gutenberg dataset...")
206
  try:
207
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
208
  except Exception:
209
- # Fallback to alternative dataset
210
- print(" Trying alternative: pg19...")
211
  gutenberg = load_dataset("pg19", split="train")
212
 
213
  records: list[dict[str, Any]] = []
214
- books_processed = 0
215
- chunks_filtered = 0
216
 
217
- # Sample books randomly
218
  indices = list(range(len(gutenberg)))
219
  random.shuffle(indices)
220
 
221
- print(" Processing books into clean prose chunks...")
222
- for i in tqdm(indices, desc="Books", leave=False):
223
  if len(records) >= max_samples:
224
  break
225
-
226
  item = gutenberg[i]
227
- # Handle both uppercase (sedthh/gutenberg_english) and lowercase (pg19) keys
228
- text = item.get("TEXT", "") or item.get("text", "") or item.get("content", "")
229
  metadata = item.get("METADATA", {}) or {}
230
- title = metadata.get("title", "") if isinstance(metadata, dict) else ""
231
- if not title:
232
- title = item.get("title", f"Book_{i}")
233
 
234
  if not text or len(text) < 1000:
235
  continue
236
 
237
- # Split into paragraphs for diverse training samples
238
- paragraphs = re.split(r'\n\s*\n', text)
 
 
 
 
239
 
240
- for para in paragraphs:
241
- para = para.strip()
242
-
243
- # Use strict filtering for clean prose only
244
- if is_clean_prose(para):
245
- records.append({
246
- "text": para,
247
- "title": title,
248
- "type": "gutenberg"
249
- })
250
- if len(records) >= max_samples:
251
- break
252
- else:
253
- chunks_filtered += 1
254
 
255
- books_processed += 1
256
-
257
- # Split into train/val/test (90/5/5)
258
- random.shuffle(records)
259
- n = len(records)
260
- train_end = int(n * 0.9)
261
- val_end = int(n * 0.95)
262
-
263
- train_records = records[:train_end]
264
- val_records = records[train_end:val_end]
265
- test_records = records[val_end:]
266
-
267
- write_jsonl(train_records, out_dir / "train.jsonl", "train")
268
- write_jsonl(val_records, out_dir / "validation.jsonl", "validation")
269
- write_jsonl(test_records, out_dir / "test.jsonl", "test")
 
270
 
271
- print(f" {books_processed:,} books → {len(records):,} clean prose chunks")
272
- print(f" ✓ Filtered out {chunks_filtered:,} boilerplate/metadata chunks")
 
273
 
 
274
 
275
  def download_emotions() -> None:
276
  """Download GoEmotions for emotion classification."""
277
- print("\n😊 Downloading Emotions...")
278
  out_dir = OUTPUT_DIR / "emotion"
279
 
280
  ds = load_dataset("google-research-datasets/go_emotions", "simplified")
@@ -297,53 +488,94 @@ def download_emotions() -> None:
297
  print(f" ✓ {len(EMOTION_LABELS)} emotion labels saved")
298
 
299
 
300
- def download_topics(max_samples: int = 100000) -> None:
301
- """Download AG News for topic classification (4 clean categories)."""
302
- print("\n📂 Downloading Topics...")
303
- out_dir = OUTPUT_DIR / "topic"
304
-
305
- ds = load_dataset("fancyzhx/ag_news")
306
- train_data = ds["train"]
307
- test_data = ds["test"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
- # Split train into train/val
310
- all_idx = list(range(len(train_data)))
311
- random.shuffle(all_idx)
312
- train_idx = all_idx[:max_samples]
313
- val_idx = all_idx[max_samples:max_samples + max_samples // 10]
314
 
315
- splits_config = [
316
- ("train", train_idx, train_data),
317
- ("validation", val_idx, train_data),
318
- ("test", list(range(len(test_data))), test_data),
319
- ]
320
 
321
- for split_name, indices, data in splits_config:
322
- records: list[dict[str, Any]] = []
323
- for i in tqdm(indices, desc=split_name, leave=False):
324
- item = data[i]
325
- text = item.get("text", "")
326
- label = item.get("label", 0)
327
- if text and len(text) > 50:
328
- records.append({"text": text, "topic": TOPIC_LABELS[label]})
329
- write_jsonl(records, out_dir / f"{split_name}.jsonl", split_name)
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
- (out_dir / "labels.json").write_text(json.dumps(TOPIC_LABELS, indent=2))
332
- print(f" ✓ {len(TOPIC_LABELS)} topic labels saved")
 
 
 
 
333
 
 
334
 
335
  def main() -> None:
336
  parser = argparse.ArgumentParser(description="Download LexiMind datasets")
337
  parser.add_argument(
338
- "--task",
339
  choices=["all", "summarization", "emotion", "topic", "gutenberg"],
340
- default="all",
341
  help="Dataset to download"
342
  )
343
- parser.add_argument("--max-news", type=int, default=80000, help="Max news articles")
344
- parser.add_argument("--max-books", type=int, default=30000, help="Max BookSum chapters")
345
- parser.add_argument("--max-gutenberg", type=int, default=20000, help="Max Gutenberg chunks")
346
- parser.add_argument("--max-topics", type=int, default=100000, help="Max topic samples")
347
  parser.add_argument("--seed", type=int, default=42, help="Random seed")
348
  args = parser.parse_args()
349
 
@@ -351,16 +583,17 @@ def main() -> None:
351
 
352
  print("=" * 60)
353
  print("LexiMind Dataset Download")
 
354
  print("=" * 60)
355
 
356
  if args.task in ["all", "summarization"]:
357
- download_summarization(args.max_news, args.max_books)
358
- if args.task in ["all", "gutenberg"]:
359
- download_gutenberg(args.max_gutenberg)
360
  if args.task in ["all", "emotion"]:
361
  download_emotions()
362
  if args.task in ["all", "topic"]:
363
  download_topics(args.max_topics)
 
 
364
 
365
  print("\n" + "=" * 60)
366
  print("✅ Download complete!")
 
5
  """
6
  Dataset download script for LexiMind.
7
 
8
+ Focus: Books, Academic Papers, Technical Writing
9
+ - NO news articles (overdone, dated)
10
+ - YES literary text, research, technical writing
11
+
12
+ Datasets:
13
+ - BookSum for literary summarization
14
+ - arXiv for academic paper summarization
15
+ - Project Gutenberg for literary language
16
  - GoEmotions for emotion classification (28 labels)
17
+ - Custom topic classification: Fiction, Science, Technology, etc.
18
 
19
  Usage:
20
  python scripts/download_data.py # Download all
21
+     python scripts/download_data.py --task summarization  # Download specific task
22
+ python scripts/download_data.py --max-arxiv 50000
23
 
24
  Author: Oliver Perrin
25
+ Date: January 2026
26
  """
27
 
28
  from __future__ import annotations
 
40
  # Output directory
41
  OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
42
 
43
+ # ============== LABEL DEFINITIONS ==============
44
+
45
+ # 28 emotions from GoEmotions - works for all text types
46
  EMOTION_LABELS = [
47
  "admiration", "amusement", "anger", "annoyance", "approval", "caring",
48
  "confusion", "curiosity", "desire", "disappointment", "disapproval",
 
51
  "relief", "remorse", "sadness", "surprise", "neutral",
52
  ]
53
 
54
# New topic labels for books + papers + blogs.
# NOTE(review): artifacts/labels.json in this commit lists only 7 topics
# (no "Psychology") — confirm the artifact is regenerated after download.
TOPIC_LABELS = [
    "Fiction",      # Novels, short stories, literary fiction
    "Science",      # Physics, chemistry, biology, nature
    "Technology",   # CS, engineering, programming, AI/ML
    "Philosophy",   # Ethics, logic, metaphysics, epistemology
    "History",      # Historical texts, biographies, memoirs
    "Psychology",   # Mind, behavior, self-help, mental health
    "Business",     # Economics, finance, entrepreneurship
    "Arts",         # Music, visual arts, film, architecture
]

# arXiv category → our topic mapping.
# NOTE(review): not referenced anywhere in this script's visible code —
# download_arxiv_summarization() ignores categories; confirm before removing.
ARXIV_CATEGORY_MAP = {
    # Computer Science
    "cs.AI": "Technology", "cs.CL": "Technology", "cs.CV": "Technology",
    "cs.LG": "Technology", "cs.NE": "Technology", "cs.RO": "Technology",
    "cs.SE": "Technology", "cs.PL": "Technology", "cs.DB": "Technology",
    "cs.DS": "Technology", "cs.CR": "Technology", "cs.DC": "Technology",
    "cs.HC": "Technology", "cs.IR": "Technology", "cs.IT": "Technology",
    "cs.MA": "Technology", "cs.MM": "Technology", "cs.NI": "Technology",
    "cs.OS": "Technology", "cs.PF": "Technology", "cs.SY": "Technology",
    # Physics
    "physics": "Science", "astro-ph": "Science", "cond-mat": "Science",
    "gr-qc": "Science", "hep-ex": "Science", "hep-lat": "Science",
    "hep-ph": "Science", "hep-th": "Science", "math-ph": "Science",
    "nlin": "Science", "nucl-ex": "Science", "nucl-th": "Science",
    "quant-ph": "Science",
    # Math
    "math": "Science",
    # Biology/Medicine
    "q-bio": "Science", "stat": "Science",
    # Economics/Finance
    "econ": "Business", "q-fin": "Business",
    # Electrical Engineering
    "eess": "Technology",
}

# Gutenberg subject keyword → our topic mapping; matched by substring search
# over the book's metadata subjects in download_gutenberg_topics().
GUTENBERG_SUBJECT_MAP = {
    "fiction": "Fiction", "novel": "Fiction", "stories": "Fiction",
    "poetry": "Arts", "drama": "Arts", "plays": "Arts",
    "science": "Science", "physics": "Science", "chemistry": "Science",
    "biology": "Science", "nature": "Science", "astronomy": "Science",
    "philosophy": "Philosophy", "ethics": "Philosophy", "logic": "Philosophy",
    "history": "History", "biography": "History", "memoir": "History",
    "psychology": "Psychology", "mind": "Psychology",
    "economics": "Business", "business": "Business", "finance": "Business",
    "art": "Arts", "music": "Arts", "architecture": "Arts",
    "technology": "Technology", "engineering": "Technology",
}
105
 
106
 
107
  def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing") -> None:
 
113
  print(f" ✓ {len(records):,} samples → {path}")
114
 
115
 
116
+ # ============== SUMMARIZATION: BOOKS + ARXIV ==============
117
+
118
def download_booksum(max_samples: int = 40000) -> list[dict[str, Any]]:
    """Download BookSum - literary chapter summarization."""
    print("\n📖 Loading BookSum (literary summarization)...")

    collected: list[dict[str, Any]] = []
    booksum = load_dataset("kmfoda/booksum")

    for split_name in booksum.keys():
        split = str(split_name)
        data = booksum[split_name]
        # Full budget for train; a tenth of it for validation/test splits.
        limit = max_samples if "train" in split else max_samples // 10
        chosen = random.sample(range(len(data)), min(len(data), limit))

        split_records = []
        for idx in tqdm(chosen, desc=f"BookSum {split}", leave=False):
            row = data[idx]
            chapter = row.get("chapter", "")
            summary = row.get("summary_text") or row.get("summary", "")
            # Keep only substantial chapters that have a non-empty summary.
            if chapter and summary and len(chapter) > 300:
                split_records.append({
                    "source": chapter[:4000],  # truncate to fit model context
                    "summary": summary,
                    "type": "literary",
                    "split": split,  # routing metadata, consumed downstream
                })
        collected.extend(split_records)
        print(f"    {split}: {len(split_records):,}")

    return collected
 
 
 
147
 
148
 
149
def clean_arxiv_text(text: str) -> str:
    """Strip LaTeX artifacts from arXiv text so it reads as plain prose.

    Removes the dataset's ``@xmath``/``@xcite`` math and citation
    placeholders, then LaTeX commands (with or without a single brace
    argument), and finally collapses runs of whitespace. Collapsing
    whitespace *last* avoids the double spaces the previous ordering left
    behind wherever a command was removed mid-sentence.

    Args:
        text: Raw article or abstract text from ccdv/arxiv-summarization.

    Returns:
        Cleaned text with leading/trailing whitespace stripped.
    """
    # Dataset-specific math/citation placeholders.
    text = re.sub(r'@xmath\d+', '', text)
    text = re.sub(r'@xcite', '', text)
    # LaTeX commands such as \emph{...} or bare \alpha.
    text = re.sub(r'\\[a-zA-Z]+\{[^}]*\}', '', text)
    text = re.sub(r'\\[a-zA-Z]+', '', text)
    # Collapse whitespace last so the removals above leave no double spaces.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
 
163
def download_arxiv_summarization(max_samples: int = 50000) -> list[dict[str, Any]]:
    """
    Download arXiv papers for academic summarization only.
    Note: This dataset doesn't have categories, so can't be used for topic classification.

    Returns: summarization_records
    """
    print("\n🎓 Loading arXiv (academic papers for summarization)...")

    print("  Loading dataset (this may take a minute)...")
    arxiv = load_dataset("ccdv/arxiv-summarization", split="train")

    papers: list[dict[str, Any]] = []

    order = list(range(len(arxiv)))
    random.shuffle(order)

    print("  Processing papers...")
    # Scan at most 2x the target so heavy filtering can't stall the loop.
    for idx in tqdm(order[:max_samples * 2], desc="arXiv", leave=False):
        if len(papers) >= max_samples:
            break

        row = arxiv[idx]
        abstract = row.get("abstract", "")
        article = row.get("article", "")

        # Require a substantive abstract before paying for cleaning.
        if not abstract or len(abstract) < 100:
            continue

        abstract = clean_arxiv_text(abstract)
        article = clean_arxiv_text(article)

        # A leftover '@' means cleaning failed to remove all placeholders.
        if '@' in abstract or '@' in article[:500]:
            continue

        # Summarization pair: truncated article body → abstract.
        if article and len(article) > 500:
            papers.append({
                "source": article[:4000],
                "summary": abstract,
                "type": "academic",
            })

    print(f"  Summarization: {len(papers):,}")

    return papers
213
 
214
 
215
def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, Any]]:
    """
    Download topic classification data from multiple sources with real categories.

    Sources:
    - 20 Newsgroups (classic topic classification)
    - Wikipedia (article categories)
    """
    print("\n📂 Loading topic classification datasets...")

    records: list[dict[str, Any]] = []

    # 20 Newsgroups - classic topic dataset
    print("  Loading 20 Newsgroups...")
    try:
        newsgroups = load_dataset("SetFit/20_newsgroups", split="train")

        # Map 20 newsgroups categories to our 8 topics.
        # NOTE(review): rec.* groups land in "Arts" for lack of a sports
        # label — confirm this is intentional.
        newsgroup_map = {
            # Science
            "sci.crypt": "Science", "sci.electronics": "Science",
            "sci.med": "Science", "sci.space": "Science",
            # Technology
            "comp.graphics": "Technology", "comp.os.ms-windows.misc": "Technology",
            "comp.sys.ibm.pc.hardware": "Technology", "comp.sys.mac.hardware": "Technology",
            "comp.windows.x": "Technology",
            # Philosophy/Religion
            "alt.atheism": "Philosophy", "soc.religion.christian": "Philosophy",
            "talk.religion.misc": "Philosophy",
            # History/Politics
            "talk.politics.guns": "History", "talk.politics.mideast": "History",
            "talk.politics.misc": "History",
            # Business
            "misc.forsale": "Business",
            # Sports/Recreation
            "rec.autos": "Arts", "rec.motorcycles": "Arts",
            "rec.sport.baseball": "Arts", "rec.sport.hockey": "Arts",
        }

        for item in tqdm(newsgroups, desc="20 Newsgroups", leave=False):
            if len(records) >= max_samples:
                break
            label_name = item.get("label_text", "")
            text = item.get("text", "")
            if label_name in newsgroup_map and text and len(text) > 100:
                records.append({
                    "text": text[:1500],
                    "topic": newsgroup_map[label_name],
                    "source": "newsgroups",
                })

        print(f"  20 Newsgroups: {len(records):,}")
    except Exception as e:
        print(f"  20 Newsgroups failed: {e}")

    # Add from Gutenberg for Fiction
    records.extend(download_gutenberg_topics(max_samples // 4))

    # Add from scientific papers abstract dataset for more Science/Tech
    print("  Loading scientific papers...")
    try:
        sci_papers = load_dataset("scientific_papers", "arxiv", split="train", streaming=True)
        sci_count = 0
        for item in tqdm(sci_papers, desc="Scientific papers", leave=False, total=max_samples//4):
            if sci_count >= max_samples // 4:
                break
            abstract = item.get("abstract", "")
            if abstract and len(abstract) > 100:
                # NOTE(review): labels simply alternate Science/Technology
                # with no signal from the paper itself — weak supervision;
                # confirm this trade-off is acceptable.
                topic = "Science" if sci_count % 2 == 0 else "Technology"
                records.append({
                    "text": abstract[:1500],
                    "topic": topic,
                    "source": "scientific_papers",
                })
                sci_count += 1
        print(f"  Scientific papers: {sci_count:,}")
    except Exception as e:
        print(f"  Scientific papers failed: {e}")

    return records
298
+
299
+
300
def download_summarization(max_books: int = 40000, max_arxiv: int = 50000) -> None:
    """Download all summarization data (books + arxiv, NO news)."""
    print("\n📝 Downloading Summarization Data...")
    out_dir = OUTPUT_DIR / "summarization"

    pool: list[dict[str, Any]] = []

    # BookSum - literary
    pool.extend(download_booksum(max_books))

    # arXiv - academic (summarization only, no categories in this dataset)
    pool.extend(download_arxiv_summarization(max_arxiv))

    random.shuffle(pool)

    # Honor any original split markers; records without one (arXiv) default
    # to train via the .get() fallback.
    train_records = [r for r in pool if r.get("split", "train") == "train"]
    val_records = [r for r in pool if r.get("split") == "validation"]
    test_records = [r for r in pool if r.get("split") == "test"]

    # Without usable split info, carve a 90/5/5 split out of train.
    if len(val_records) < 100:
        n = len(train_records)
        random.shuffle(train_records)
        val_records = train_records[int(n*0.9):int(n*0.95)]
        test_records = train_records[int(n*0.95):]
        train_records = train_records[:int(n*0.9)]

    # "split" was routing metadata only; drop it before writing to disk.
    for rec in train_records + val_records + test_records:
        rec.pop("split", None)

    write_jsonl(train_records, out_dir / "train.jsonl", "train")
    write_jsonl(val_records, out_dir / "validation.jsonl", "val")
    write_jsonl(test_records, out_dir / "test.jsonl", "test")

    total = len(train_records) + len(val_records) + len(test_records)
    print(f"\n  ✓ Total summarization: {total:,}")
340
+
341
+
342
+ # ============== TOPIC CLASSIFICATION ==============
343
+
344
def download_topics(max_samples: int = 50000) -> None:
    """
    Download topic classification data from multiple sources.

    Sources:
    - 20 Newsgroups (classic topic dataset)
    - Gutenberg books (Fiction)
    - Scientific papers (Science, Technology)
    """
    print("\n📂 Downloading Topic Classification...")
    out_dir = OUTPUT_DIR / "topic"

    # Gather labeled samples from every source.
    all_records = download_topics_from_datasets(max_samples)

    # Bucket records per topic so the class distribution can be balanced.
    by_topic: dict[str, list] = {t: [] for t in TOPIC_LABELS}
    for rec in all_records:
        label = rec.get("topic")
        if label in by_topic:
            by_topic[label].append(rec)

    print("\n  Topic distribution (before balancing):")
    for label, bucket in by_topic.items():
        print(f"    {label}: {len(bucket):,}")

    # Balance to the smallest non-empty class, capped by per-class budget.
    nonempty_sizes = [len(bucket) for bucket in by_topic.values() if bucket]
    if not nonempty_sizes:
        print("  ⚠️ No topic data found!")
        return

    per_class = min(min(nonempty_sizes), max_samples // len(TOPIC_LABELS))

    balanced: list[dict[str, Any]] = []
    for bucket in by_topic.values():
        if bucket:
            random.shuffle(bucket)
            balanced.extend(bucket[:per_class])

    random.shuffle(balanced)

    # 90/5/5 split.
    n = len(balanced)
    train_records = balanced[:int(n*0.9)]
    val_records = balanced[int(n*0.9):int(n*0.95)]
    test_records = balanced[int(n*0.95):]

    write_jsonl(train_records, out_dir / "train.jsonl", "train")
    write_jsonl(val_records, out_dir / "validation.jsonl", "val")
    write_jsonl(test_records, out_dir / "test.jsonl", "test")

    # Persist only the labels that actually have data.
    used_labels = [t for t in TOPIC_LABELS if by_topic.get(t)]
    (out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
    print(f"\n  ✓ {len(used_labels)} topic labels with data: {used_labels}")
402
+
403
+
404
def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
    """Extract topic-labeled samples from Gutenberg books."""
    print("\n📚 Loading Gutenberg for topic classification...")

    try:
        gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
    except Exception:
        print("  Trying pg19...")
        gutenberg = load_dataset("pg19", split="train")

    samples: list[dict[str, Any]] = []

    order = list(range(len(gutenberg)))
    random.shuffle(order)

    for book_idx in tqdm(order, desc="Gutenberg topics", leave=False):
        if len(samples) >= max_samples:
            break

        book = gutenberg[book_idx]
        # sedthh/gutenberg_english uses uppercase keys; pg19 lowercase.
        text = book.get("TEXT", "") or book.get("text", "")
        metadata = book.get("METADATA", {}) or {}

        if not text or len(text) < 1000:
            continue

        # Build a lowercase subject string from whatever metadata exists.
        subjects = ""
        if isinstance(metadata, dict):
            subjects = str(metadata.get("subjects", "")).lower()
            subjects += " " + str(metadata.get("subject", "")).lower()
            subjects += " " + str(metadata.get("category", "")).lower()

        # First keyword hit wins (dict order is deterministic).
        topic = next(
            (mapped for keyword, mapped in GUTENBERG_SUBJECT_MAP.items()
             if keyword in subjects),
            None,
        )

        # Default fiction for novels without clear subject
        if not topic and ("novel" in subjects or not subjects.strip()):
            topic = "Fiction"

        if not topic:
            continue

        # Take the first clean-looking paragraph, skipping front matter.
        for para in (p.strip() for p in re.split(r'\n\s*\n', text)[5:]):
            if 200 < len(para) < 1500 and para.count('.') >= 2:
                samples.append({
                    "text": para,
                    "topic": topic,
                    "source": "gutenberg",
                })
                break

    print(f"  Gutenberg topics: {len(samples):,}")
    return samples
462
+
463
 
464
+ # ============== EMOTIONS (unchanged) ==============
465
 
466
  def download_emotions() -> None:
467
  """Download GoEmotions for emotion classification."""
468
+ print("\n😊 Downloading Emotions (GoEmotions)...")
469
  out_dir = OUTPUT_DIR / "emotion"
470
 
471
  ds = load_dataset("google-research-datasets/go_emotions", "simplified")
 
488
  print(f" ✓ {len(EMOTION_LABELS)} emotion labels saved")
489
 
490
 
491
+ # ============== GUTENBERG BOOKS (for language modeling) ==============
492
+
493
# Regexes marking a paragraph as Gutenberg boilerplate rather than prose:
# license/header text, transcriber notes, chapter/section headings, etc.
GUTENBERG_JUNK_PATTERNS = [
    r"Project Gutenberg", r"www\.gutenberg\.org", r"This ebook is for",
    r"Gutenberg License", r"^\*\*\* START OF", r"^\*\*\* END OF",
    r"Produced by", r"Transcriber's Note", r"TABLE OF CONTENTS",
    r"^\s*CHAPTER\s+[IVXLC\d]+", r"^\s*Chapter\s+[IVXLC\d]+",
    r"^\s*BOOK\s+[IVXLC\d]+", r"^\s*PREFACE\s*$", r"^\s*INTRODUCTION\s*$",
    r"E-text prepared by", r"Internet Archive", r"Distributed Proofreaders",
]
# Single case-insensitive alternation; any hit disqualifies a chunk
# in is_clean_prose().
GUTENBERG_JUNK_REGEX = re.compile("|".join(GUTENBERG_JUNK_PATTERNS), re.IGNORECASE)
502
+
503
+
504
def is_clean_prose(text: str) -> bool:
    """Return True if *text* looks like clean literary prose worth keeping.

    Rejects chunks that are too short/long, match known Gutenberg
    boilerplate, lack sentence structure, or are dominated by uppercase
    letters or digits (headers, tables, page numbers).
    """
    size = len(text)
    if not 300 <= size <= 3000:
        return False
    if GUTENBERG_JUNK_REGEX.search(text):
        return False
    # Real prose has multiple sentences.
    if text.count('.') < 2:
        return False
    # Mostly-uppercase chunks are headings/titles.
    upper = sum(c.isupper() for c in text)
    if upper / max(size, 1) > 0.3:
        return False
    # Digit-heavy chunks are tables, dates, or page numbers.
    digits = sum(c.isdigit() for c in text)
    return digits / max(size, 1) <= 0.1
519
+
520
+
521
def download_gutenberg(max_samples: int = 30000) -> None:
    """Download Gutenberg books for language modeling."""
    print("\n📚 Downloading Gutenberg Books...")
    out_dir = OUTPUT_DIR / "books"
    out_dir.mkdir(parents=True, exist_ok=True)

    try:
        corpus = load_dataset("sedthh/gutenberg_english", split="train")
    except Exception:
        corpus = load_dataset("pg19", split="train")

    chunks: list[dict[str, Any]] = []
    order = list(range(len(corpus)))
    random.shuffle(order)

    for book_idx in tqdm(order, desc="Books", leave=False):
        if len(chunks) >= max_samples:
            break

        book = corpus[book_idx]
        # sedthh/gutenberg_english uses uppercase keys; pg19 lowercase.
        text = book.get("TEXT", "") or book.get("text", "")
        metadata = book.get("METADATA", {}) or {}
        title = metadata.get("title", "") if isinstance(metadata, dict) else ""
        if not title:
            title = book.get("title", f"Book_{book_idx}")

        if not text or len(text) < 1000:
            continue

        # Keep only paragraphs that pass the prose filter.
        for para in (p.strip() for p in re.split(r'\n\s*\n', text)):
            if is_clean_prose(para):
                chunks.append({"text": para, "title": title, "type": "gutenberg"})
                if len(chunks) >= max_samples:
                    break

    # 90/5/5 split.
    random.shuffle(chunks)
    n = len(chunks)
    write_jsonl(chunks[:int(n*0.9)], out_dir / "train.jsonl", "train")
    write_jsonl(chunks[int(n*0.9):int(n*0.95)], out_dir / "validation.jsonl", "val")
    write_jsonl(chunks[int(n*0.95):], out_dir / "test.jsonl", "test")
563
+
564
 
565
+ # ============== MAIN ==============
566
 
567
  def main() -> None:
568
  parser = argparse.ArgumentParser(description="Download LexiMind datasets")
569
  parser.add_argument(
570
+ "--task",
571
  choices=["all", "summarization", "emotion", "topic", "gutenberg"],
572
+ default="all",
573
  help="Dataset to download"
574
  )
575
+ parser.add_argument("--max-books", type=int, default=40000, help="Max BookSum samples")
576
+ parser.add_argument("--max-arxiv", type=int, default=50000, help="Max arXiv samples")
577
+ parser.add_argument("--max-gutenberg", type=int, default=30000, help="Max Gutenberg chunks")
578
+ parser.add_argument("--max-topics", type=int, default=50000, help="Max topic samples")
579
  parser.add_argument("--seed", type=int, default=42, help="Random seed")
580
  args = parser.parse_args()
581
 
 
583
 
584
  print("=" * 60)
585
  print("LexiMind Dataset Download")
586
+ print("Books + Academic Papers + Topic Classification")
587
  print("=" * 60)
588
 
589
  if args.task in ["all", "summarization"]:
590
+ download_summarization(args.max_books, args.max_arxiv)
 
 
591
  if args.task in ["all", "emotion"]:
592
  download_emotions()
593
  if args.task in ["all", "topic"]:
594
  download_topics(args.max_topics)
595
+ if args.task in ["all", "gutenberg"]:
596
+ download_gutenberg(args.max_gutenberg)
597
 
598
  print("\n" + "=" * 60)
599
  print("✅ Download complete!")
scripts/train.py CHANGED
@@ -3,9 +3,9 @@
3
  Training script for LexiMind.
4
 
5
  Simple, clean training with multi-task learning across:
6
- - Summarization (CNN/DailyMail + BookSum)
7
  - Emotion classification (GoEmotions, 28 labels)
8
- - Topic classification (AG News, 4 labels)
9
 
10
  Usage:
11
  python scripts/train.py training=medium
@@ -89,11 +89,17 @@ def main(cfg: DictConfig) -> None:
89
  device = torch.device(cfg.device)
90
 
91
  # GPU optimizations for Ampere+
92
- if device.type == "cuda" and torch.cuda.get_device_capability()[0] >= 8:
93
- torch.set_float32_matmul_precision("high")
94
- torch.backends.cuda.matmul.allow_tf32 = True
95
- torch.backends.cudnn.allow_tf32 = True
96
- print("✓ TF32 enabled for Ampere GPU")
 
 
 
 
 
 
97
 
98
  # --------------- Load Data ---------------
99
 
@@ -187,6 +193,11 @@ def main(cfg: DictConfig) -> None:
187
  # --------------- Model ---------------
188
 
189
  print("\nBuilding model...")
 
 
 
 
 
190
  model_cfg = ModelConfig(
191
  d_model=cfg.model.d_model,
192
  vocab_size=getattr(cfg.model, "vocab_size", None),
@@ -198,9 +209,15 @@ def main(cfg: DictConfig) -> None:
198
  use_pretrained=cfg.model.use_pretrained,
199
  pretrained_model_name=cfg.model.pretrained_model_name,
200
  activation=getattr(cfg.model, "activation", "gelu"),
201
- use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
 
202
  )
203
 
 
 
 
 
 
204
  model = build_multitask_model(
205
  tokenizer,
206
  num_emotions=len(emot_train.emotion_classes),
@@ -211,6 +228,26 @@ def main(cfg: DictConfig) -> None:
211
  param_count = sum(p.numel() for p in model.parameters())
212
  print(f" Parameters: {param_count:,} ({param_count/1e6:.1f}M)")
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  # Resume from checkpoint?
215
  start_epoch = 1
216
  resume_path = cfg.get("resume_from")
@@ -223,12 +260,15 @@ def main(cfg: DictConfig) -> None:
223
  start_epoch = int(digits[-1]) + 1
224
 
225
  # Compile model for speed
 
 
 
226
  if cfg.training.get("compile_encoder", True):
227
- model.encoder = torch.compile(model.encoder, backend="inductor") # type: ignore[assignment]
228
- print(" ✓ Encoder compiled")
229
  if cfg.training.get("compile_decoder", True):
230
- model.decoder = torch.compile(model.decoder, backend="inductor") # type: ignore[assignment]
231
- print(" ✓ Decoder compiled")
232
 
233
  # --------------- Train ---------------
234
 
@@ -236,11 +276,16 @@ def main(cfg: DictConfig) -> None:
236
  opt_cfg = cfg.training.get("optimizer", {})
237
  sched_cfg = cfg.training.get("scheduler", {})
238
 
 
 
239
  optimizer = torch.optim.AdamW(
240
  model.parameters(),
241
  lr=float(opt_cfg.get("lr", 3e-5)),
242
  weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
 
243
  )
 
 
244
 
245
  trainer = Trainer(
246
  model=model,
 
3
  Training script for LexiMind.
4
 
5
  Simple, clean training with multi-task learning across:
6
+ - Summarization (BookSum + arXiv papers)
7
  - Emotion classification (GoEmotions, 28 labels)
8
+ - Topic classification (Books + Papers, 8 labels: Fiction, Science, Technology, etc.)
9
 
10
  Usage:
11
  python scripts/train.py training=medium
 
89
  device = torch.device(cfg.device)
90
 
91
  # GPU optimizations for Ampere+
92
+ if device.type == "cuda":
93
+ # Enable cudnn benchmark for fixed-size inputs (10-20% speedup)
94
+ torch.backends.cudnn.benchmark = True
95
+
96
+ if torch.cuda.get_device_capability()[0] >= 8:
97
+ torch.set_float32_matmul_precision("high")
98
+ torch.backends.cuda.matmul.allow_tf32 = True
99
+ torch.backends.cudnn.allow_tf32 = True
100
+ print("✓ TF32 + cudnn.benchmark enabled for Ampere GPU")
101
+ else:
102
+ print("✓ cudnn.benchmark enabled")
103
 
104
  # --------------- Load Data ---------------
105
 
 
193
  # --------------- Model ---------------
194
 
195
  print("\nBuilding model...")
196
+
197
+ # Check for overrides in training config
198
+ grad_ckpt = cfg.training.get("gradient_checkpointing", cfg.model.get("gradient_checkpointing", False))
199
+ use_rel_pos = cfg.training.get("use_relative_position_bias", cfg.model.get("use_relative_position_bias", False))
200
+
201
  model_cfg = ModelConfig(
202
  d_model=cfg.model.d_model,
203
  vocab_size=getattr(cfg.model, "vocab_size", None),
 
209
  use_pretrained=cfg.model.use_pretrained,
210
  pretrained_model_name=cfg.model.pretrained_model_name,
211
  activation=getattr(cfg.model, "activation", "gelu"),
212
+ use_relative_position_bias=use_rel_pos,
213
+ gradient_checkpointing=grad_ckpt,
214
  )
215
 
216
+ if grad_ckpt:
217
+ print(" ✓ Gradient checkpointing enabled")
218
+ if not use_rel_pos:
219
+ print(" ✓ FlashAttention enabled (no relative position bias)")
220
+
221
  model = build_multitask_model(
222
  tokenizer,
223
  num_emotions=len(emot_train.emotion_classes),
 
228
  param_count = sum(p.numel() for p in model.parameters())
229
  print(f" Parameters: {param_count:,} ({param_count/1e6:.1f}M)")
230
 
231
+ # Freeze lower encoder layers (keeps pretrained language understanding, adapts upper layers)
232
+ freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
233
+ if freeze_layers > 0:
234
+ frozen_params = 0
235
+ # Freeze embedding layer
236
+ if hasattr(model.encoder, 'embed_tokens'):
237
+ for p in model.encoder.embed_tokens.parameters():
238
+ p.requires_grad = False
239
+ frozen_params += p.numel()
240
+ # Freeze specified number of encoder layers
241
+ if hasattr(model.encoder, 'layers'):
242
+ for i, layer in enumerate(model.encoder.layers):
243
+ if i < freeze_layers:
244
+ for p in layer.parameters():
245
+ p.requires_grad = False
246
+ frozen_params += p.numel()
247
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
248
+ print(f" ✓ Frozen encoder layers 0-{freeze_layers-1} ({frozen_params/1e6:.1f}M params)")
249
+ print(f" Trainable: {trainable:,} ({trainable/1e6:.1f}M)")
250
+
251
  # Resume from checkpoint?
252
  start_epoch = 1
253
  resume_path = cfg.get("resume_from")
 
260
  start_epoch = int(digits[-1]) + 1
261
 
262
  # Compile model for speed
263
+ # Note: "reduce-overhead" mode uses CUDA graphs which conflicts with gradient checkpointing
264
+ # Use "default" mode when checkpointing is enabled
265
+ compile_mode = "default" if grad_ckpt else "reduce-overhead"
266
  if cfg.training.get("compile_encoder", True):
267
+ model.encoder = torch.compile(model.encoder, mode=compile_mode) # type: ignore[assignment]
268
+ print(f" ✓ Encoder compiled ({compile_mode})")
269
  if cfg.training.get("compile_decoder", True):
270
+ model.decoder = torch.compile(model.decoder, mode=compile_mode) # type: ignore[assignment]
271
+ print(f" ✓ Decoder compiled ({compile_mode})")
272
 
273
  # --------------- Train ---------------
274
 
 
276
  opt_cfg = cfg.training.get("optimizer", {})
277
  sched_cfg = cfg.training.get("scheduler", {})
278
 
279
+ # Use fused AdamW on CUDA for ~5-10% speedup
280
+ use_fused = device.type == "cuda" and "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
281
  optimizer = torch.optim.AdamW(
282
  model.parameters(),
283
  lr=float(opt_cfg.get("lr", 3e-5)),
284
  weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
285
+ fused=use_fused,
286
  )
287
+ if use_fused:
288
+ print(" ✓ Fused AdamW optimizer")
289
 
290
  trainer = Trainer(
291
  model=model,
src/inference/pipeline.py CHANGED
@@ -68,6 +68,7 @@ class InferenceConfig:
68
 
69
  summary_max_length: int = 128
70
  summary_repetition_penalty: float = 1.2 # Penalize repeated tokens
 
71
  summary_formatting: bool = True # Apply text cleanup/formatting to generated summaries
72
  emotion_threshold: float = 0.5
73
  device: str | None = None
@@ -157,6 +158,7 @@ class InferencePipeline:
157
  ban_token_ids=[i for i in ban_ids if i is not None],
158
  no_repeat_ngram_size=3,
159
  repetition_penalty=self.config.summary_repetition_penalty,
 
160
  memory_mask=src_mask,
161
  )
162
 
 
68
 
69
  summary_max_length: int = 128
70
  summary_repetition_penalty: float = 1.2 # Penalize repeated tokens
71
+ summary_length_penalty: float = 1.5 # Encourage EOS token as length increases (>1 = shorter)
72
  summary_formatting: bool = True # Apply text cleanup/formatting to generated summaries
73
  emotion_threshold: float = 0.5
74
  device: str | None = None
 
158
  ban_token_ids=[i for i in ban_ids if i is not None],
159
  no_repeat_ngram_size=3,
160
  repetition_penalty=self.config.summary_repetition_penalty,
161
+ length_penalty=self.config.summary_length_penalty,
162
  memory_mask=src_mask,
163
  )
164
 
src/models/decoder.py CHANGED
@@ -445,10 +445,15 @@ class TransformerDecoder(nn.Module):
445
  ban_token_ids: Optional[List[int]] = None,
446
  no_repeat_ngram_size: int = 0,
447
  repetition_penalty: float = 1.0,
 
448
  memory_mask: Optional[torch.Tensor] = None,
449
  ) -> torch.Tensor:
450
  """
451
  Greedy decoding with KV caching for O(N) complexity.
 
 
 
 
452
  """
453
  if device is None:
454
  device = memory.device
@@ -519,6 +524,13 @@ class TransformerDecoder(nn.Module):
519
  if banned_for_this_batch:
520
  next_step_logits[b, list(banned_for_this_batch)] = float("-inf")
521
 
 
 
 
 
 
 
 
522
  # Greedy selection
523
  next_token = next_step_logits.argmax(dim=-1, keepdim=True) # (B, 1)
524
 
 
445
  ban_token_ids: Optional[List[int]] = None,
446
  no_repeat_ngram_size: int = 0,
447
  repetition_penalty: float = 1.0,
448
+ length_penalty: float = 1.0,
449
  memory_mask: Optional[torch.Tensor] = None,
450
  ) -> torch.Tensor:
451
  """
452
  Greedy decoding with KV caching for O(N) complexity.
453
+
454
+ Args:
455
+ length_penalty: Values > 1.0 encourage shorter sequences by boosting EOS probability
456
+ as sequence length increases. Default 1.0 (no penalty).
457
  """
458
  if device is None:
459
  device = memory.device
 
524
  if banned_for_this_batch:
525
  next_step_logits[b, list(banned_for_this_batch)] = float("-inf")
526
 
527
+ # Length penalty to boost EOS probability as sequence grows (encourages shorter outputs)
528
+ if length_penalty != 1.0 and end_token_id is not None and generated.size(1) >= min_len:
529
+ # Scale EOS logit based on current length relative to max
530
+ length_ratio = generated.size(1) / max_len
531
+ eos_boost = length_penalty * length_ratio # Grows as we approach max_len
532
+ next_step_logits[:, end_token_id] = next_step_logits[:, end_token_id] + eos_boost
533
+
534
  # Greedy selection
535
  next_token = next_step_logits.argmax(dim=-1, keepdim=True) # (B, 1)
536
 
src/training/trainer.py CHANGED
@@ -369,17 +369,19 @@ class Trainer:
369
  if src_mask is not None:
370
  src_mask = src_mask[:1]
371
 
372
- # Generate
373
  model: Any = self.model
374
  enc_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
375
  memory = model.encoder(src_ids, mask=enc_mask)
376
- generated = model.decoder.greedy_decode_naive(
377
  memory=memory,
378
  max_len=self.config.validation_max_length,
379
  start_token_id=self.tokenizer.bos_token_id,
380
  end_token_id=self.tokenizer.eos_token_id,
381
  device=self.device,
382
  memory_mask=src_mask,
 
 
383
  )
384
 
385
  src = self.tokenizer.decode(src_ids[0].tolist())
 
369
  if src_mask is not None:
370
  src_mask = src_mask[:1]
371
 
372
+ # Generate with anti-repetition
373
  model: Any = self.model
374
  enc_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
375
  memory = model.encoder(src_ids, mask=enc_mask)
376
+ generated = model.decoder.greedy_decode(
377
  memory=memory,
378
  max_len=self.config.validation_max_length,
379
  start_token_id=self.tokenizer.bos_token_id,
380
  end_token_id=self.tokenizer.eos_token_id,
381
  device=self.device,
382
  memory_mask=src_mask,
383
+ no_repeat_ngram_size=3,
384
+ repetition_penalty=1.2,
385
  )
386
 
387
  src = self.tokenizer.decode(src_ids[0].tolist())