Spaces: Sleeping

OliverPerrin committed on
Commit · 90a2698
Parent(s): 0d858b5

Cleaned up code; added a multi-seed training wrapper and a PyTorch profiler training option; updated the Gradio demo; revised the research paper to match the new training techniques and results; architecture.md now explains all designs and decisions.
Browse files

- README.md +6 -5
- configs/training/dev.yaml +1 -1
- configs/training/full.yaml +4 -4
- configs/training/medium.yaml +1 -1
- docs/architecture.md +368 -61
- docs/research_paper.tex +99 -76
- outputs/evaluation_report.json +250 -13
- outputs/training_history.json +180 -154
- pyproject.toml +6 -2
- scripts/build_discovery_dataset.py +6 -8
- scripts/demo_gradio.py +46 -129
- scripts/download_data.py +25 -29
- scripts/evaluate.py +12 -8
- scripts/profile_training.py +314 -0
- scripts/train.py +10 -11
- scripts/train_multiseed.py +5 -4
- scripts/visualize_training.py +12 -30
- src/data/dataset.py +0 -6
- src/models/factory.py +1 -1
- src/training/metrics.py +1 -1
- src/training/trainer.py +5 -5
README.md
CHANGED

@@ -8,6 +8,7 @@ app_file: scripts/demo_gradio.py
 pinned: false
 ---
 
+<!-- markdownlint-disable MD025 -->
 # LexiMind
 
 A multi-task NLP system for literary and academic text understanding. LexiMind performs **abstractive summarization**, **topic classification**, and **emotion detection** using a single encoder-decoder transformer initialized from [FLAN-T5-base](https://huggingface.co/google/flan-t5-base) (272M parameters).
@@ -17,7 +18,7 @@ A multi-task NLP system for literary and academic text understanding. LexiMind p
 ## What It Does
 
 | Task | Description | Metric |
-
+| ------ | ------------- | -------- |
 | **Summarization** | Generates back-cover style book descriptions and paper abstracts from source text | BERTScore F1: **0.830** |
 | **Topic Classification** | Classifies passages into 7 categories | Accuracy: **85.2%** |
 | **Emotion Detection** | Identifies emotions from 28 fine-grained labels (multi-label) | Sample-avg F1: **0.199** |
@@ -31,7 +32,7 @@ The model is trained on literary text (Project Gutenberg + Goodreads description
 LexiMind is a **custom Transformer implementation** that loads pre-trained weights from FLAN-T5-base via a factory module. The architecture is reimplemented from scratch for transparency, not wrapped from HuggingFace.
 
 | Component | Detail |
-
+| ----------- | -------- |
 | Backbone | Encoder-Decoder Transformer (272M params) |
 | Encoder / Decoder | 12 layers each |
 | Hidden Dim | 768, 12 attention heads |
@@ -48,7 +49,7 @@ All three tasks share the encoder. Summarization uses the full encoder-decoder;
 ## Training Data
 
 | Task | Source | Train Samples |
-
+| ------ | -------- | --------------- |
 | Summarization | Gutenberg + Goodreads (literary) | ~4K |
 | Summarization | arXiv body → abstract (academic) | ~45K |
 | Topic | 20 Newsgroups + Gutenberg + arXiv metadata | 3,402 |
@@ -131,7 +132,7 @@ docker run -p 7860:7860 leximind
 
 ## Project Structure
 
-```
+```text
 configs/
 ├── config.yaml # Main Hydra config
 ├── data/datasets.yaml # Dataset paths and tokenizer settings
@@ -208,4 +209,4 @@ GPL-3.0 — see [LICENSE](LICENSE) for details.
 
 ---
 
-
+Built by Oliver Perrin · Appalachian State University · 2025–2026
configs/training/dev.yaml
CHANGED

@@ -37,7 +37,7 @@ trainer:
   max_val_samples: 300
   early_stopping_patience: 5
   log_grad_norm_frequency: 100
-  task_sampling:
+  task_sampling: temperature
   task_sampling_alpha: 0.5
   gradient_conflict_frequency: 0
 
configs/training/full.yaml
CHANGED

@@ -22,12 +22,12 @@ optimizer:
 
 scheduler:
   name: cosine
-  warmup_steps: 300 # ~0.
+  warmup_steps: 300 # ~0.24 epoch warmup (1227 optimizer steps/epoch)
 
 trainer:
   max_epochs: 8 # Reduced from 12 - early stopping will catch plateau anyway
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps: 4 # Reduced from 8
+  gradient_accumulation_steps: 4 # Reduced from 8, 2x faster optimizer steps
   validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
@@ -38,9 +38,9 @@ trainer:
   max_val_samples: 3000 # Enough for stable metrics
   early_stopping_patience: 3 # Stop quickly when plateauing
   log_grad_norm_frequency: 200
-  # Task sampling: "round_robin"
+  # Task sampling: "round_robin" or "temperature"
   # Temperature sampling: p_i proportional to n_i^alpha, reduces dominance of large tasks
-  task_sampling:
+  task_sampling: temperature
   task_sampling_alpha: 0.5
   # Gradient conflict diagnostics: compute inter-task gradient cosine similarity
   # every N steps (0 = disabled). Helps diagnose negative transfer.
configs/training/medium.yaml
CHANGED

@@ -37,7 +37,7 @@ trainer:
   max_val_samples: 2500
   early_stopping_patience: 3 # More patience
   log_grad_norm_frequency: 100
-  task_sampling:
+  task_sampling: temperature
   task_sampling_alpha: 0.5
   gradient_conflict_frequency: 0
 
docs/architecture.md
CHANGED

@@ -2,88 +2,395 @@
-LexiMind
-
-2. **Model Composition** – the bespoke encoder/decoder stack with task heads assembled via `MultiTaskModel`, plus `models.factory.build_multitask_model` to rebuild the network from configuration files.
-3. **Inference & Serving** – a multi-task pipeline capable of summarization, emotion, and topic classification; surfaced through a CLI and Gradio UI.
-
-The
-
-- **Attention projections:** Q, K, V, O weights (bias initialized to zero since T5 has no attention bias)
-- **FFN weights:** `wi_1` → `linear1`, `wo` → `linear2` (T5 uses gated FFN; we use the up/down projections)
-- **RMSNorm weights:** Direct transfer (both use RMSNorm without bias)
-- **LM head:** Loaded from T5's `lm_head`
-
-- `src/models/decoder.py` – TransformerDecoder with KV-cache for efficient generation
-- `src/models/attention.py` – Multi-Head Attention with FlashAttention, LoRA, and RoPE support
-- `src/models/heads.py` – ClassificationHead (mean pooling) and LMHead (with weight tying)
-- `src/models/multitask.py` – Routes inputs to task-specific heads
-- `src/models/factory.py` – Builds models and loads FLAN-T5 weights
-
-| ---- | ------- | ---- | ------ |
-| Summarization | BookSum + arXiv | ~90K | Text→Summary |
-| Emotion | GoEmotions | ~43K | 28 emotions (multi-label) |
-| Topic | Books + Papers | 3.4K | 7 categories (Arts, Business, Fiction, History, Philosophy, Science, Technology) |
-| Books | Gutenberg (prose chunks) | ~30K | Literary text |
-
-- **Special tokens:** pad=0, eos=1 (no explicit BOS; decoder starts with pad token)
-- **Subword tokenization:** Unigram-based (vs BART's BPE)
-
-- Mixed precision training (bfloat16 on Ampere/Ada GPUs)
-- Gradient accumulation for larger effective batch sizes
-- Per-task loss weighting and label smoothing
-- Early stopping based on validation loss
-- Cosine learning rate schedule with warmup
-- **torch.compile:** JIT compilation with Inductor backend for 20-40% speedup
-- Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and ROUGE-like overlap
-
-- `src/inference/factory.py` rebuilds the full pipeline using the exported tokenizer artifact
-- The CLI (`scripts/inference.py`) drives the pipeline from the command line
-- Gradio demo (`scripts/demo_gradio.py`) provides an interactive web interface
-
-##
 
 ## Overview
 
+LexiMind is a **272M parameter encoder-decoder transformer** initialized from Google's FLAN-T5-base, trained jointly on three tasks: abstractive summarization, topic classification, and multi-label emotion detection. The project spans data preparation, custom model architecture, multi-task training, evaluation, and a Gradio-based discovery demo.
+
+## Model Architecture
+
+### Backbone: Custom Transformer (FLAN-T5-base Initialization)
+
+The model is a from-scratch PyTorch implementation of a T5-style encoder-decoder. We do **not** use HuggingFace's `T5ForConditionalGeneration` — instead, every component (attention, FFN, normalization, positional encoding) is implemented manually in `src/models/`, and FLAN-T5 weights are then loaded layer by layer in `src/models/factory.py`.
+
+```text
+Input Text
+    │
+    ▼
+┌─────────────────────────────────────────────┐
+│ Shared Encoder (12 layers, 768d, 12 heads)  │
+│ ┌────────────────────────────────────────┐  │
+│ │ Layers 0-3:  FROZEN (FLAN-T5 weights)  │  │
+│ │ Layers 4-11: TRAINABLE (fine-tuned)    │  │
+│ └────────────────────────────────────────┘  │
+│ Pre-LN RMSNorm │ T5 Relative Position Bias  │
+│ FlashAttention (SDPA) │ Gated-GELU FFN      │
+└─────────┬───────────────┬─────────────┬─────┘
+          │               │             │
+ ┌────────▼────────┐ ┌────▼───────┐ ┌───▼────────┐
+ │ Decoder         │ │ Attention  │ │ Mean       │
+ │ (12 layers)     │ │ Pooling    │ │ Pooling    │
+ │ Causal + Cross  │ │ (learned)  │ │            │
+ │ Attention       │ │            │ │            │
+ │                 │ │ MLP 768→   │ │ Linear     │
+ │ LM Head         │ │ 384→28     │ │ 768→7      │
+ │ (tied weights)  │ │            │ │            │
+ └────────┬────────┘ └────┬───────┘ └─────┬──────┘
+          │               │               │
+  Summarization      Emotion (28)     Topic (7)
+   (generative)     (multi-label)  (single-label)
+```
 
+### Encoder
+
+**File**: `src/models/encoder.py` (317 lines)
+
+- 12 transformer layers, 768-dimensional, 12 attention heads
+- **Pre-Layer Normalization (Pre-LN)** using T5-style RMSNorm — normalization is applied *before* each sublayer, not after. This is the modern standard (LLaMA, T5 v1.1+, PaLM).
+- **T5 Relative Position Bias**: bucketed log-linear position bias computed in the attention layer. Bidirectional (the encoder attends in both directions) and shared across layers (computed once, passed to all layers).
+- **FlashAttention**: via PyTorch 2.0's `F.scaled_dot_product_attention`, which automatically selects the optimal kernel (Flash, memory-efficient, or math fallback). **Note**: T5 does NOT scale attention scores by 1/√d_k — the `scale_scores=False` flag preserves this behavior.
+- **Gated-GELU FFN**: two linear projections (gate + up) multiplied element-wise, then a down projection. Matches T5's `DenseGatedGeluDense`.
+- **Gradient checkpointing**: optional per-layer activation recomputation to reduce VRAM (enabled in the full training config; saves ~2-3 GB).
+- The bottom 4 layers are frozen during fine-tuning to preserve FLAN-T5's general language representations.
+
+The encoder processes all input text and produces contextualized representations that are consumed by all three task heads.
+
+### Decoder (Summarization Only)
+
+**File**: `src/models/decoder.py` (749 lines)
+
+- 12 transformer layers, 768-dimensional, 12 attention heads
+- ~136M parameters — roughly half the total model
+- **Masked self-attention** (a causal mask prevents attending to future positions)
+- **Cross-attention** to encoder outputs (lets the decoder attend to the full input)
+- **KV-cache** for efficient autoregressive generation — incremental key/value computation avoids recomputing previous positions
+- **Greedy decoding** with:
+  - No-repeat n-gram blocking (`no_repeat_ngram_size=3`)
+  - Repetition penalty (1.2x)
+  - Length penalty
+  - Min/max length constraints
+- **LM Head**: linear projection from 768d → 32,128 vocab, **weight-tied** with the decoder token embeddings (reduces parameters and improves coherence)
+
+The decoder is exclusive to summarization. Classification tasks use only the encoder.
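The no-repeat n-gram blocking mentioned above can be sketched in isolation. This helper is a hypothetical illustration (not the repo's code) of how a decoder finds token ids to mask out before the greedy argmax:

```python
def banned_next_tokens(generated, no_repeat_ngram_size=3):
    """Return token ids that would complete an n-gram already present in `generated`."""
    n = no_repeat_ngram_size
    if len(generated) < n - 1:
        return set()
    prefix = tuple(generated[-(n - 1):])  # the (n-1)-token context just produced
    banned = set()
    for i in range(len(generated) - n + 1):
        if tuple(generated[i:i + n - 1]) == prefix:
            banned.add(generated[i + n - 1])  # emitting this would repeat an n-gram
    return banned

# With trigram blocking, having produced [5, 6, 7, 8, 5, 6], token 7 is banned
# because it would recreate the trigram (5, 6, 7).
```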
+
+### Task Heads
+
+**File**: `src/models/heads.py` (221 lines)
+
+#### Emotion Head (Attention Pooling + MLP)
+
+- **AttentionPooling**: a single linear layer (`nn.Linear(768, 1, bias=False)`) serves as a learned query. It computes softmax attention weights over all encoder positions and produces a weighted sum, letting the model focus on emotionally salient tokens (e.g., "grateful", "hilarious") rather than averaging the entire 512-token sequence. Padding is masked before the softmax.
+- **2-layer MLP**: 768 → 384 (GELU) → 28. The hidden layer provides a nonlinear feature transformation before the 28-way multi-label output.
+- **Loss**: BCEWithLogitsLoss (binary cross-entropy per class)
+- **Inference threshold**: 0.3 (lowered from the default 0.5 because 28-class multi-label predictions have lower per-class confidence)
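The attention-pooling mechanism described above reduces to a few lines of arithmetic. A minimal dependency-free sketch (plain lists standing in for tensors; the real head operates on batched PyTorch tensors):

```python
import math

def attention_pool(hidden, query_weight, mask):
    """hidden: [seq][dim] token vectors, query_weight: [dim] learned query,
    mask: [seq] of 0/1 (1 = real token). Returns the pooled [dim] vector."""
    scores = [sum(h * w for h, w in zip(tok, query_weight)) for tok in hidden]
    # Mask padding before softmax so padded positions get exactly zero weight.
    scores = [s if m else float("-inf") for s, m in zip(scores, mask)]
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(hidden[0])
    return [sum(w * tok[d] for w, tok in zip(weights, hidden)) for d in range(dim)]
```

Note how the padded third token (huge values, mask 0) contributes nothing to the pooled output, which is what "padding is masked before softmax" guarantees.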
+
+#### Topic Head (Mean Pooling + Linear)
+
+- **Mean pooling** over encoder positions (attention-mask-aware)
+- **Single linear layer**: 768 → 7
+- **Loss**: CrossEntropyLoss
+- **Task weight**: 0.3 (reduced to prevent overfitting on the small 3.4K dataset)
+
+#### Summarization Head (Decoder + LM Head)
+
+- Full decoder (described above) + weight-tied LM head
+- **Loss**: CrossEntropyLoss with label smoothing (0.1) and a `-100` ignore index for padding
+- **Task weight**: 1.0
+
+### Multi-Task Router
+
+**File**: `src/models/multitask.py` (263 lines)
+
+The `MultiTaskModel` class routes `forward(task, inputs)` calls to the correct head:
+
+- **Classification** (`emotion`, `topic`): encoder → pool → classify
+- **Generation** (`summarization`): encoder → decoder → LM head
+
+A `memory.clone()` call between the encoder output and the decoder input prevents CUDA Graph buffer-reuse issues when using `torch.compile`.
 
 ### Weight Loading from FLAN-T5
 
+**File**: `src/models/factory.py` (571 lines)
+
+Weights are transferred from HuggingFace's `google/flan-t5-base` checkpoint layer by layer:
+
+| FLAN-T5 Component | Our Component |
+| --- | --- |
+| `shared.weight` | `encoder.embed_tokens.weight` and `decoder.embed_tokens.weight` |
+| `encoder.block.{i}.layer.0.SelfAttention.{q,k,v,o}` | `encoder.layers.{i}.self_attn.{q,k,v,out}_proj.weight` |
+| `encoder.block.{i}.layer.1.DenseReluDense.wi_0/wi_1/wo` | `encoder.layers.{i}.ffn.gate/up_proj/down_proj.weight` |
+| `encoder.block.{i}.layer.{0,1}.layer_norm.weight` | `encoder.layers.{i}.norm{1,2}.weight` |
+| `encoder.block.0.layer.0.SelfAttention.relative_attention_bias` | `encoder.layers.0.self_attn.attn.position_bias.relative_attention_bias` |
+| `lm_head.weight` | `summarization_head.projection.weight` |
+
+The vocab size mismatch (T5: 32,100 → ours: 32,128) is handled by zero-padding the embedding matrix.
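Name translation of the kind tabulated above is commonly done with a small rule table. The rules below are a hypothetical sketch in the spirit of that table; the actual `factory.py` may structure its mapping differently:

```python
import re

# Hypothetical (pattern, replacement) rules modeled on two rows of the table above.
RULES = [
    (r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.q\.weight$",
     r"encoder.layers.\1.self_attn.q_proj.weight"),
    (r"^encoder\.block\.(\d+)\.layer\.1\.DenseReluDense\.wo\.weight$",
     r"encoder.layers.\1.ffn.down_proj.weight"),
]

def translate(t5_name):
    """Map a T5 checkpoint parameter name to the custom model's name, or None."""
    for pattern, repl in RULES:
        if re.match(pattern, t5_name):
            return re.sub(pattern, repl, t5_name)
    return None
```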
+
+### Available but Unused Components
+
+These are implemented but not activated in the current configuration:
+
+- **LoRA adapters** on the Q and V projections in `MultiHeadAttention` — for parameter-efficient fine-tuning
+- **Rotary Position Embeddings (RoPE)** — an alternative to T5's relative position bias
+- **4-bit/8-bit quantization** via bitsandbytes — for inference on constrained hardware
+- **TokenClassificationHead** — for NER/POS tasks
+- **ProjectionHead** — for contrastive/representation learning
+- **LLaMA weight loading** — `_load_llama_weights()` for loading Gemma/LLaMA checkpoints
+
+## Tokenization
+
+**File**: `src/data/tokenization.py` (157 lines)
+
+Wraps HuggingFace's `AutoTokenizer` configured for FLAN-T5:
+
+- **SentencePiece** (Unigram) tokenizer, 32,128-token vocabulary
+- Special tokens: `pad=0`, `eos=1`, no explicit BOS (the decoder starts with the pad token, per T5 convention)
+- Max sequence length: 512 tokens (encoder), 128 tokens (decoder during validation generation)
+- Classification tasks use a reduced max length of 256 tokens (sufficient for classification, saves compute)
+
+## Datasets
+
+**Files**: `src/data/dataset.py` (316 lines), `src/data/dataloader.py` (174 lines)
+
+| Task | Dataset Source | Train Size | Val Size | Test Size |
+| ------ | --------------- | ----------- | --------- | ---------- |
+| Summarization | arXiv abstracts (~45K) + Goodreads book descriptions (~4K) | ~49K | ~2.7K | ~2.7K |
+| Emotion | GoEmotions (Reddit comments, 28 labels) | ~43K | ~5.4K | — |
+| Topic | arXiv categories + Gutenberg subjects → 7 classes | ~3.2K | ~189 | — |
+
+**Cross-task deduplication**: `deduplicate_across_tasks()` uses MD5 fingerprinting of normalized 200-character text prefixes to detect and remove overlapping documents between the summarization and topic datasets (both draw from arXiv and Gutenberg).
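The fingerprinting scheme described above can be sketched as follows (an illustrative reimplementation, not the repo's `deduplicate_across_tasks()`; the exact normalization may differ):

```python
import hashlib

def fingerprint(text, prefix_len=200):
    """MD5 of the whitespace-normalized, lowercased document prefix."""
    normalized = " ".join(text.lower().split())[:prefix_len]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

def dedup_across_tasks(primary_docs, secondary_docs):
    """Drop secondary docs whose fingerprints collide with any primary doc."""
    seen = {fingerprint(d) for d in primary_docs}
    return [d for d in secondary_docs if fingerprint(d) not in seen]
```

Fingerprinting only a prefix keeps hashing cheap while still catching documents that share the same opening text, which is the typical overlap pattern when two corpora draw from the same sources.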
+
+**Data pipeline**: each task has a typed `Dataset` class and a corresponding `Collator` that handles tokenization, padding, and label preparation. Collators are passed to PyTorch `DataLoader` instances created by factory functions (`build_*_dataloader`).
+
+## Training
+
+**File**: `src/training/trainer.py` (527 lines)
+
+### Training Loop
+
+Each epoch iterates through batches using **temperature-based task sampling**:
+
+1. **Sample a task** with probability p_i proportional to n_i^0.5, where n_i is the dataset size
+   - Summarization (~49K): ~45% of steps
+   - Emotion (~43K): ~43% of steps
+   - Topic (~3.4K): ~12% of steps
+2. **Forward pass** under `torch.autocast(dtype=bfloat16)` mixed precision
+3. **Compute the task-specific loss** with its task weight (summ=1.0, emotion=1.0, topic=0.3)
+4. **Backward pass** and accumulate gradients (4 accumulation steps → effective batch size 40)
+5. **Optimizer step** every 4 batches: clip gradients (max norm 1.0), AdamW step, cosine LR step
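The sampling probabilities in step 1 follow directly from the n_i^alpha rule. A minimal sketch (hypothetical helper, not the trainer's code) reproduces the percentages quoted above:

```python
def task_sampling_probs(dataset_sizes, alpha=0.5):
    """p_i proportional to n_i**alpha; alpha < 1 flattens toward uniform,
    so small tasks are sampled more often than their raw share."""
    weights = {task: n ** alpha for task, n in dataset_sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}

probs = task_sampling_probs({"summarization": 49_000, "emotion": 43_000, "topic": 3_400})
# roughly 0.45 / 0.43 / 0.12 — matching the ~45% / ~43% / ~12% split above
```

With alpha=1 the topic task would get only ~3.6% of steps; the square root lifts it to ~12%, which is the "reduces dominance of large tasks" effect noted in the config comments.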
+
+### Training Configuration (full.yaml)
+
+| Parameter | Value | Rationale |
+| ----------- | ------- | ----------- |
+| Batch size | 10 | Fits ~10GB VRAM on an RTX 4070 12GB |
+| Gradient accumulation | 4 | Effective batch size 40 |
+| Learning rate | 3e-5 | Standard for fine-tuning T5 |
+| Weight decay | 0.01 | Standard AdamW regularization |
+| Warmup steps | 300 | ~0.24 epochs of linear warmup (~1227 optimizer steps/epoch) |
+| Max epochs | 8 | Val loss still improving at epoch 8 |
+| LR schedule | Cosine | Decays to 0.1x base LR, flattens near step 8000 |
+| Early stopping | Patience 3 | Never triggered (val loss never rose for 3 consecutive epochs) |
+| Label smoothing | 0.1 | Summarization cross-entropy only |
+| Task weights | summ=1.0, emot=1.0, topic=0.3 | Reduced topic weight to prevent overfitting |
+| Task sampling | Temperature (alpha=0.5) | Square-root-proportional sampling |
+| Frozen encoder layers | 0-3 | Preserves FLAN-T5's general language knowledge |
+| Gradient checkpointing | Enabled | Saves ~2-3 GB VRAM |
+| torch.compile | Both encoder and decoder | ~20-40% speedup via the Inductor backend |
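The warmup-plus-cosine schedule in the table can be written out explicitly. This is a generic sketch, not the repo's scheduler; the total step count (~1227 steps/epoch × 8 epochs ≈ 9816) and the 0.1x floor are taken from the figures above:

```python
import math

def lr_at(step, base_lr=3e-5, warmup_steps=300, total_steps=9816, floor_ratio=0.1):
    """Linear warmup to base_lr, then cosine decay down to floor_ratio * base_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps          # linear ramp
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
    return base_lr * (floor_ratio + (1 - floor_ratio) * cosine)
```

At step 0 the LR is 0, at step 300 it peaks at 3e-5, and by the end of training it has flattened onto the 3e-6 floor, matching the "flattens near step 8000" observation.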
+
+### Mixed Precision
+
+The RTX 4070 (Ada Lovelace, compute capability 8.9) has dedicated BF16 tensor cores:
+
+- All forward/backward passes run under `torch.autocast("cuda", dtype=torch.bfloat16)`
+- BF16 has the same exponent range as FP32 (8 bits), so no GradScaler is needed (unlike FP16)
+- Loss computation and softmax remain in FP32 (handled automatically by autocast)
+- Encoder/decoder layers include `clamp(min=-65504, max=65504)` stability guards (carried over from HuggingFace T5)
+
+### Optimizer
+
+- **Fused AdamW**: CUDA-native fused kernel (`torch.optim.AdamW(fused=True)`), ~5-10% faster than standard AdamW
+- Betas: (0.9, 0.98) — slightly faster momentum decay than the default
+- Epsilon: 1e-6
+
+### Gradient Conflict Diagnostics (Available, Disabled)
+
+The trainer includes `_compute_gradient_conflicts()`, which:
+
+1. Computes per-task gradients independently
+2. Flattens all parameter gradients into a single vector per task
+3. Computes pairwise cosine similarity between task gradient vectors
+4. Logs the cosine similarities and binary conflict flags to MLflow
+
+This is a **diagnostic only** — it does not modify gradients (unlike PCGrad/CAGrad). It is disabled by default (`gradient_conflict_frequency: 0`) because it requires extra backward passes per measurement.
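The core of step 3 above is a plain cosine similarity over flattened gradients. A dependency-free sketch (illustrative only; the trainer operates on PyTorch tensors):

```python
import math

def gradient_conflict(grad_a, grad_b):
    """Cosine similarity between two flattened gradient vectors.
    A negative cosine flags a conflict: the tasks pull parameters in
    opposing directions, a symptom of negative transfer."""
    dot = sum(a * b for a, b in zip(grad_a, grad_b))
    norm_a = math.sqrt(sum(a * a for a in grad_a))
    norm_b = math.sqrt(sum(b * b for b in grad_b))
    cos = dot / (norm_a * norm_b)
    return cos, cos < 0.0
```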
+
+### MLflow Tracking
+
+Training metrics (losses, accuracy, F1, ROUGE, learning rate) are logged to MLflow with a SQLite backend (`mlruns.db`), enabling experiment comparison across training runs.
+
+## Evaluation
+
+**Files**: `scripts/evaluate.py` (538 lines), `src/training/metrics.py` (452 lines)
+
+### Metrics
+
+| Task | Metrics |
+| ------ | --------- |
+| Summarization | ROUGE-1, ROUGE-2, ROUGE-L (`rouge-score` library), BLEU-4 (NLTK), optional BERTScore |
+| Emotion | Sample-averaged F1, macro F1, micro F1, per-class P/R/F1, per-class threshold tuning |
+| Topic | Accuracy, macro F1, per-class P/R/F1, confusion matrix |
+| All | Bootstrap 95% confidence intervals (1000 resamples), paired bootstrap test |
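The bootstrap confidence intervals in the last row follow the standard percentile recipe. A minimal sketch, assuming per-example scores and a fixed seed for reproducibility (not the repo's implementation):

```python
import random

def bootstrap_ci(scores, n_resamples=1000, seed=0):
    """Percentile-bootstrap 95% CI for the mean of per-example scores:
    resample with replacement, take the 2.5th and 97.5th percentile means."""
    rng = random.Random(seed)
    means = sorted(
        sum(rng.choices(scores, k=len(scores))) / len(scores)
        for _ in range(n_resamples)
    )
    return means[int(0.025 * n_resamples)], means[int(0.975 * n_resamples)]
```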
+
+### Per-Class Threshold Tuning (Emotion)
+
+For multi-label classification, different emotion classes have very different base rates and prediction confidence. The tuning procedure:
+
+1. For each of the 28 emotion classes independently,
+2. sweep the threshold tau over {0.1, 0.2, ..., 0.9},
+3. select the threshold that maximizes per-class F1 on the validation set, and
+4. re-compute all metrics with the tuned thresholds.
+
+This improved macro F1 from 0.143 (with the default 0.5 threshold) to 0.294.
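The per-class sweep in steps 1-3 is a one-line maximization once an F1 helper exists. An illustrative sketch for a single class (hypothetical helpers, not the repo's `metrics.py`):

```python
def f1(preds, labels):
    """Binary F1 from boolean predictions and 0/1 labels."""
    tp = sum(1 for p, l in zip(preds, labels) if p and l)
    fp = sum(1 for p, l in zip(preds, labels) if p and not l)
    fn = sum(1 for p, l in zip(preds, labels) if not p and l)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def tune_threshold(probs, labels):
    """Pick the tau in {0.1, ..., 0.9} that maximizes F1 for one class."""
    taus = [t / 10 for t in range(1, 10)]
    return max(taus, key=lambda t: f1([p >= t for p in probs], labels))
```

Running this independently per class is what lets rare emotions keep low thresholds while common ones settle higher.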
|
| 250 |
+
|
| 251 |
+
### BERTScore
|
| 252 |
+
|
| 253 |
+
Available via `--include-bertscore` flag in evaluation (opt-in). Uses `roberta-large` for semantic similarity. Not included in primary evaluation due to computational cost and difficulty interpreting absolute values.
|
| 254 |
+
|
| 255 |
+
## Inference
|
| 256 |
+
|
| 257 |
+
**File**: `src/inference/pipeline.py` (217 lines), `src/inference/factory.py` (91 lines)
|
| 258 |
+
|
| 259 |
+
`InferencePipeline` loads a trained checkpoint and runs all three tasks:
|
| 260 |
+
|
| 261 |
+
- **Summarization**: Greedy decode with KV-cache, no-repeat trigram blocking, repetition penalty 1.2
|
| 262 |
+
- **Emotion**: Sigmoid probabilities → threshold at 0.3 → emit labels above threshold
|
| 263 |
+
- **Topic**: Softmax → argmax → emit top label with confidence score
|
| 264 |
+
|
| 265 |
+
`create_inference_pipeline()` reconstructs the full pipeline from checkpoint + labels + tokenizer artifacts.
|
| 266 |
+
|
| 267 |
+
## Serving
|
| 268 |
+
|
| 269 |
+
### Gradio Demo
|
| 270 |
+
|
| 271 |
+
**File**: `scripts/demo_gradio.py` (507 lines)
|
| 272 |
+
|
| 273 |
+
A discovery interface for browsing pre-analyzed books and papers. Loads a pre-computed discovery dataset (not live inference) from HuggingFace Hub (`OliverPerrin/LexiMind-Discovery`). Users can browse by topic, emotion, or keyword search.
|
| 274 |
+
|
| 275 |
+
### FastAPI
|
| 276 |
+
|
| 277 |
+
**Files**: `src/api/app.py` (18 lines), `src/api/routes.py` (49 lines)
|
| 278 |
+
|
| 279 |
+
Minimal REST API with a single `/summarize` endpoint that runs all three tasks and returns JSON results. Uses dependency injection for the inference pipeline.
|
| 280 |
+
|
| 281 |
+
### CLI
|
| 282 |
+
|
| 283 |
+
**File**: `scripts/inference.py` (108 lines)
|
| 284 |
|
| 285 |
+
Command-line interface accepting text from arguments or file, running batch prediction, and printing JSON output.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
+
### Profiling
|
| 288 |
|
| 289 |
+
**File**: `scripts/profile_training.py`
|
| 290 |
|
| 291 |
+
Wraps a few training steps with `torch.profiler` to capture:
- CUDA kernel timing (per-operator breakdown)
- GPU memory usage (peak allocations)
- CPU/GPU overlap and idle time
- Chrome trace (viewable in `chrome://tracing` or [Perfetto UI](https://ui.perfetto.dev))
- CUDA stacks for flamegraph generation

```bash
python scripts/profile_training.py                   # 20 steps by default
PROFILE_STEPS=40 python scripts/profile_training.py  # custom step count
```

Outputs go to `outputs/profile/`: TensorBoard traces, Chrome trace JSON, and stack files.
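The core profiling loop can be sketched like this: a minimal CPU-only version, assuming the real script additionally enables CUDA activities, stack recording, and trace export. The tiny model and step counts here are illustrative only.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

model = torch.nn.Linear(256, 256)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# wait 1 step, warm up for 2 steps, then record 3 active steps
prof_schedule = schedule(wait=1, warmup=2, active=3, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on a GPU box
    schedule=prof_schedule,
    profile_memory=True,
) as prof:
    for _ in range(8):
        x = torch.randn(32, 256)
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()  # advance the profiler schedule once per training step

# per-operator breakdown, sorted by total CPU time
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```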
## Training Results (8 Epochs, RTX 4070, ~9 Hours)

| Epoch | Train Loss | Val Loss | Summ Val Loss | Emotion Val F1 | Topic Val Acc |
| ----- | ---------- | -------- | ------------- | -------------- | ------------- |
| 1 | 6.106 | 4.298 | 3.815 | 0.197 | 70.4% |
| 2 | 5.528 | 4.027 | 3.739 | 0.301 | 84.7% |
| 3 | 5.379 | 3.973 | 3.700 | 0.347 | 84.2% |
| 4 | 5.303 | 3.951 | 3.677 | 0.404 | 85.7% |
| 5 | 5.208 | 3.940 | 3.665 | 0.431 | 86.3% |
| 6 | 5.231 | 3.925 | 3.658 | 0.452 | 87.3% |
| 7 | 5.154 | 3.928 | 3.655 | 0.458 | 85.7% |
| 8 | 5.178 | 3.925 | 3.653 | 0.459 | 85.7% |
Key observations:

- Early stopping never triggered: val loss improved through the run, reaching its minimum (3.925) at epoch 6 and matching it at epoch 8, with only a 0.003 uptick at epoch 7
- Topic val accuracy plateaued around 85% from epoch 2, while topic train accuracy reached 98%; overfitting is expected on 3.4K samples
- Emotion F1 improved steadily across all 8 epochs (0.197 → 0.459), showing that attention pooling continues learning throughout
- Summarization loss plateaued after epoch 5 (~3.66)
- Train loss was lowest at epoch 7 (5.154) and slightly higher at epoch 8 (5.178); normal variance
- The cosine LR schedule flattens near step 8000 at its 0.1x floor
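The LR multiplier shape can be sketched as follows; the exact schedule in the trainer may differ (e.g. warmup handling), and the total step count and floor are taken from the observations above:

```python
import math

def cosine_with_floor(step: int, total_steps: int = 8000, floor: float = 0.1) -> float:
    """Multiplier on the base LR: cosine decay from 1.0 down to a 0.1x floor."""
    progress = min(step / total_steps, 1.0)           # clamp: flat after total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1.0 -> 0.0
    return floor + (1.0 - floor) * cosine                # 1.0 -> floor

# cosine_with_floor(0) == 1.0; cosine_with_floor(8000) == 0.1 (flat thereafter)
```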
## Final Evaluation Results

| Task | Metric | Value | 95% CI |
| ---- | ------ | ----- | ------ |
| Summarization | ROUGE-1 | 0.310 | [0.306, 0.313] |
| Summarization | ROUGE-2 | 0.091 | — |
| Summarization | ROUGE-L | 0.185 | — |
| Summarization | BLEU-4 | 0.024 | — |
| Emotion | Sample F1 | 0.352 | [0.340, 0.366] |
| Emotion | Macro F1 | 0.143 | — |
| Emotion | Micro F1 | 0.443 | — |
| Emotion (tuned) | Macro F1 | 0.294 | — |
| Emotion (tuned) | Sample F1 | 0.503 | — |
| Topic | Accuracy | 85.7% | [80.4%, 91.0%] |
| Topic | Macro F1 | 0.854 | — |

Per-domain summarization: Academic ROUGE-1 = 0.319, Literary ROUGE-1 = 0.206.
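The "tuned" rows come from per-class sigmoid threshold tuning on the validation set, sweeping τ ∈ {0.1, ..., 0.9} and keeping the value that maximizes per-class F1. A minimal sketch of that sweep for a single class, assuming per-sample probabilities and binary gold labels (the real implementation operates over all 28 classes at once):

```python
def f1(tp: int, fp: int, fn: int) -> float:
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def tune_threshold(probs: list[float], gold: list[int]) -> float:
    """Pick the threshold in {0.1, ..., 0.9} maximizing F1 for one class."""
    best_tau, best_f1 = 0.5, -1.0
    for tau in [t / 10 for t in range(1, 10)]:
        preds = [int(p >= tau) for p in probs]
        tp = sum(p and g for p, g in zip(preds, gold))
        fp = sum(p and not g for p, g in zip(preds, gold))
        fn = sum(g and not p for p, g in zip(preds, gold))
        score = f1(tp, fp, fn)
        if score > best_f1:  # strict '>' keeps the lowest tau on ties
            best_tau, best_f1 = tau, score
    return best_tau
```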
## Project Structure

```text
LexiMind/
├── configs/                       # Hydra configuration
│   ├── config.yaml                # Main config (seeds, paths, device)
│   ├── data/datasets.yaml         # Data paths
│   ├── model/                     # Model configs (base, small, large)
│   └── training/                  # Training configs (full, medium, dev)
├── src/
│   ├── models/
│   │   ├── encoder.py             # TransformerEncoder (12 Pre-LN layers)
│   │   ├── decoder.py             # TransformerDecoder with KV-cache
│   │   ├── attention.py           # MultiHeadAttention, FlashAttention, T5 relative pos bias, LoRA, RoPE
│   │   ├── heads.py               # AttentionPooling, ClassificationHead, LMHead
│   │   ├── multitask.py           # MultiTaskModel (task routing)
│   │   ├── feedforward.py         # Gated-GELU / SwiGLU / ReLU FFN
│   │   ├── positional_encoding.py # Sinusoidal + learned positional encodings
│   │   ├── t5_layer_norm.py       # RMSNorm (T5-style)
│   │   └── factory.py             # Model construction + FLAN-T5 weight loading
│   ├── data/
│   │   ├── tokenization.py        # HuggingFace tokenizer wrapper
│   │   ├── dataset.py             # Typed datasets + JSONL loaders + cross-task dedup
│   │   └── dataloader.py          # Task-specific collators + DataLoader factories
│   ├── training/
│   │   ├── trainer.py             # Multi-task trainer (AMP, gradient accum, temperature sampling)
│   │   └── metrics.py             # ROUGE, BLEU, BERTScore, F1 variants, bootstrap CI
│   ├── inference/
│   │   ├── pipeline.py            # Multi-task inference pipeline
│   │   └── factory.py             # Pipeline reconstruction from artifacts
│   ├── api/
│   │   ├── app.py                 # FastAPI application
│   │   └── routes.py              # REST endpoints
│   └── utils/
│       ├── core.py                # Device detection, seed setting
│       ├── io.py                  # Checkpoint save/load
│       └── labels.py              # Label metadata I/O
├── scripts/
│   ├── train.py                   # Hydra-based training entry point
│   ├── evaluate.py                # Full evaluation with all metrics
│   ├── inference.py               # CLI inference
│   ├── demo_gradio.py             # Gradio discovery demo
│   ├── visualize_training.py      # Training visualization suite
│   ├── profile_training.py        # PyTorch profiler for GPU analysis
│   ├── download_data.py           # Data preparation from HuggingFace
│   └── build_discovery_dataset.py # Pre-compute discovery dataset
├── artifacts/                     # Tokenizer + label exports
├── checkpoints/                   # Model checkpoints (best.pt + per-epoch)
├── outputs/                       # Evaluation reports, training history, visualizations
└── docs/                          # Architecture docs + research paper
```
docs/research_paper.tex
CHANGED
@@ -44,7 +44,7 @@ Email: perrinot@appstate.edu}}
\maketitle

\begin{abstract}
-Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies
\end{abstract}

\begin{IEEEkeywords}
@@ -70,13 +70,13 @@ Our study addresses three research questions:
To answer these questions, we construct \textbf{LexiMind}, a multi-task system built on FLAN-T5-base \cite{chung2022scaling} that performs abstractive summarization, topic classification, and emotion detection. We conduct ablations comparing multi-task vs. single-task training, with vs. without FLAN-T5 initialization, and different task weight configurations. Our primary experimental contribution is the empirical characterization of transfer effects across these heterogeneous tasks:

\begin{itemize}
-\item \textbf{Topic classification benefits
-\item \textbf{Summarization is robust to MTL}, showing
-\item \textbf{Emotion detection
\item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
\end{itemize}

-We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We

%=============================================================================
\section{Related Work}
@@ -98,7 +98,7 @@ Most summarization benchmarks focus on news \cite{nallapati2016abstractive, nara

\subsection{Emotion Detection}

-GoEmotions \cite{demszky2020goemotions} provides 28 fine-grained emotion labels from Reddit comments. The original work reports 0.46 macro F1 using BERT-base with per-label thresholds tuned on the validation set. Subsequent work achieves 0.35--0.46 macro F1 depending on the model and threshold strategy. Importantly, all published GoEmotions baselines use encoder-only architectures (BERT, RoBERTa) rather than encoder-decoder models like T5. Our setup differs in both architecture (encoder-decoder with

%=============================================================================
\section{Experimental Setup}
@@ -179,7 +179,7 @@ All experiments use consistent hyperparameters unless otherwise noted:
\item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
\end{itemize}

-\textbf{Task scheduling.}

\textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
@@ -192,10 +192,11 @@ All experiments use consistent hyperparameters unless otherwise noted:
We compare four configurations:

\begin{enumerate}
-\item \textbf{Random/Majority}: Random predictions for classification;
\item \textbf{FLAN-T5-base (zero-shot)}: Pre-trained model with task-appropriate prompts, no fine-tuning.
-\item \textbf{Single-Task}: Separate models fine-tuned on each task individually with identical hyperparameters.
-\item \textbf{Multi-Task
\end{enumerate}

We additionally ablate FLAN-T5 initialization vs. random initialization to isolate transfer learning contribution.
@@ -203,7 +204,7 @@ We additionally ablate FLAN-T5 initialization vs. random initialization to isola
\subsection{Evaluation Metrics}

\begin{itemize}
-\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BERTScore
\item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
\item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
\end{itemize}
@@ -214,39 +215,40 @@ We additionally ablate FLAN-T5 initialization vs. random initialization to isola
\section{Results}
%=============================================================================

-\subsection{Main Results

-Table \ref{tab:main_results} compares MTL

\begin{table}[htbp]
\centering
-\caption{Main Results
\label{tab:main_results}
-\begin{tabular}{
\toprule
-\textbf{Task} & \textbf{Metric} & \textbf{Single
\midrule
-\multirow{
-& ROUGE-2 & 0.085 & \textbf{0.
-& ROUGE-L & 0.179 & \textbf{0.
-& BERTScore F1 & 0.821 & \textbf{0.830} \\
\midrule
-\multirow{2}{*}{Topic} & Accuracy & 82.0\% & \textbf{85.
-& Macro F1 & 0.812 & \textbf{0.
\midrule
-Emotion & Sample
\bottomrule
\end{tabular}
\end{table}

-\textbf{Key finding}:

\begin{itemize}
-\item \textbf{

-\item \textbf{

-\item \textbf{
\end{itemize}

\subsection{Baseline Comparisons}
@@ -256,23 +258,21 @@ Table \ref{tab:baselines} contextualizes our results against trivial and zero-sh

\begin{table}[htbp]
\centering
-\caption{Comparison with Baselines}
\label{tab:baselines}
\begin{tabular}{lccc}
\toprule
-\textbf{Model} & \textbf{Summ (
\midrule
-Random/Majority &
-FLAN-T5 zero-shot & 0.
-Single-Task & 0.
-\textbf{Multi-Task} & \textbf{0.
\bottomrule
\end{tabular}
\end{table}

-
-
-Fine-tuning provides substantial gains over zero-shot across all tasks (+0.106 BERTScore, +27\% topic accuracy, +0.11 emotion F1), demonstrating the importance of domain adaptation even with instruction-tuned models.

\subsection{Ablation: Transfer Learning Contribution}
@@ -280,21 +280,21 @@ Table \ref{tab:transfer_ablation} isolates the contribution of FLAN-T5 pre-train

\begin{table}[htbp]
\centering
-\caption{Effect of Pre-trained Initialization (
\label{tab:transfer_ablation}
\begin{tabular}{lccc}
\toprule
-\textbf{Initialization} & \textbf{Summ (
\midrule
-Random & 0.
-FLAN-T5-base & \textbf{0.
\midrule
-\textit{Absolute gain} & +0.
\bottomrule
\end{tabular}
\end{table}

-FLAN-T5 initialization provides large absolute gains across all tasks.

\subsection{Per-Class Topic Analysis}
@@ -302,60 +302,85 @@ Table \ref{tab:topic_breakdown} reveals per-class patterns in topic classificati

\begin{table}[htbp]
\centering
-\caption{Per-Class Topic Classification (
\label{tab:topic_breakdown}
\begin{tabular}{lccc}
\toprule
\textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
\midrule
-Arts & 0.93 & 0.
-Business & 0.97 &
Fiction & 0.95 & 1.00 & 0.97 \\
-History & 0.
-Philosophy & 0.
-Science & 0.
-Technology & 0.
\midrule
-\textit{Macro Avg} & 0.85 & 0.
\bottomrule
\end{tabular}
\end{table}

-Fiction and Business achieve near-perfect classification (F1 $\geq$ 0.97), while Science shows the most confusion (F1 = 0.

-\
\label{sec:emotion_analysis}

-Our emotion sample-averaged F1 (0.

\begin{enumerate}
-\item \textbf{

-\item \textbf{

-\item \textbf{

-\item \textbf{
\end{enumerate}

-\textbf{

\subsection{Training Dynamics}

-Figure \ref{fig:training_curves} shows training progression over

\begin{figure}[htbp]
\centering
\includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
-\caption{Training and validation loss
\label{fig:training_curves}
\end{figure}

Key observations:
\begin{itemize}
-\item Topic classification converges by epoch 3 (
-\item Summarization loss decreases
-\item
\end{itemize}

%=============================================================================
@@ -364,19 +389,19 @@ Key observations:

\subsection{When Does MTL Help?}

-Our results

-\textbf{MTL helps when}: A small-dataset task (topic: 3.4K samples) shares domain with a large-dataset task (summarization: 49K literary/academic samples)

-\textbf{MTL

-\textbf{MTL is neutral

\subsection{Comparison to MTL Literature}

-Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---

-A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform

\subsection{Implications for Practitioners}
@@ -404,15 +429,15 @@ We identify several limitations that constrain the generalizability of our findi
\begin{itemize}
\item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.

-\item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods

-\item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results.

\item \textbf{Cross-task data leakage}: Although topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.

\item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.

-\item \textbf{No human evaluation}: ROUGE

\item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
@@ -441,11 +466,9 @@ We identify several limitations that constrain the generalizability of our findi
\section{Conclusion}
%=============================================================================

-We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our
-
-To address identified weaknesses, we introduce learned attention pooling for the emotion head, temperature-based task sampling, inter-task gradient conflict diagnostics, comprehensive multi-label metrics (macro/micro F1, per-class breakdown, per-class threshold tuning), cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. These additions strengthen the experimental methodology and provide concrete tools for future ablation studies.

-

Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
\maketitle

\begin{abstract}
Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies, we find that naive MTL with mean pooling and round-robin scheduling yields mixed results: topic classification gains +3.2\% accuracy, summarization remains stable, but emotion detection suffers negative transfer ($-$0.02 F1). We then show that two targeted interventions---\textbf{learned attention pooling} for the emotion head and \textbf{temperature-based task sampling} ($\alpha=0.5$)---eliminate negative transfer entirely, improving multi-task emotion sample-averaged F1 from 0.199 to 0.352 (+77\%), substantially exceeding the single-task baseline (0.218). With per-class threshold tuning, emotion macro F1 reaches 0.294. Topic classification improves to 85.7\% accuracy (95\% CI: [80.4\%, 91.0\%]), and summarization quality remains robust (ROUGE-1: 0.310, ROUGE-L: 0.185). Per-domain analysis reveals a significant quality gap between academic summaries (ROUGE-1: 0.319) and literary summaries (ROUGE-1: 0.206), attributable to the 11:1 training imbalance. We additionally contribute inter-task gradient conflict diagnostics, cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. Our analysis demonstrates that architectural isolation of task-specific components (attention pooling) combined with balanced optimization (temperature sampling) can convert negative transfer to positive transfer in MTL systems.
\end{abstract}

\begin{IEEEkeywords}
To answer these questions, we construct \textbf{LexiMind}, a multi-task system built on FLAN-T5-base \cite{chung2022scaling} that performs abstractive summarization, topic classification, and emotion detection. We conduct ablations comparing multi-task vs. single-task training, with vs. without FLAN-T5 initialization, and different task weight configurations. Our primary experimental contribution is the empirical characterization of transfer effects across these heterogeneous tasks:

\begin{itemize}
\item \textbf{Topic classification benefits from MTL} (+3.7\% accuracy over single-task), leveraging shared encoder representations from the larger summarization dataset.
\item \textbf{Summarization is robust to MTL}, showing stable ROUGE scores despite sharing encoder capacity with classification heads.
\item \textbf{Emotion detection: from negative to positive transfer}. Naive MTL with mean pooling degrades emotion F1 by $-$0.02; learned attention pooling combined with temperature-based task sampling reverses this, yielding +0.134 F1 over the single-task baseline.
\item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
\end{itemize}

We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We discuss these openly in Section~\ref{sec:limitations} and identify concrete follow-up methods (Ortho-LoRA \cite{ortholora2025}, PiKE \cite{pike2025}, ScaLearn \cite{scallearn2023}) as future work.

%=============================================================================
\section{Related Work}

\subsection{Emotion Detection}

GoEmotions \cite{demszky2020goemotions} provides 28 fine-grained emotion labels from Reddit comments. The original work reports 0.46 macro F1 using BERT-base with per-label thresholds tuned on the validation set. Subsequent work achieves 0.35--0.46 macro F1 depending on the model and threshold strategy. Importantly, all published GoEmotions baselines use encoder-only architectures (BERT, RoBERTa) rather than encoder-decoder models like T5. Our setup differs in both architecture (encoder-decoder with attention-pooled encoder states for emotion detection) and domain (training the encoder primarily on literary/academic summarization), making direct comparison to published baselines informative but not fully controlled.

%=============================================================================
\section{Experimental Setup}
\item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
\end{itemize}

\textbf{Task scheduling.} We use \textbf{temperature-based sampling}: task $i$ is sampled with probability $p_i \propto n_i^\alpha$, where $n_i$ is the dataset size and $\alpha = 0.5$ (square-root scaling). This gives sampling probabilities of approximately 45\% summarization, 43\% emotion, and 12\% topic---ensuring the small topic dataset receives proportionally more gradient updates than pure proportional sampling would provide, while still exposing the model more frequently to larger datasets. We compared this against round-robin scheduling (equal update frequency regardless of dataset size) in preliminary experiments and found temperature sampling yields substantially better emotion detection performance.

\textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
We compare five configurations:

\begin{enumerate}
\item \textbf{Random/Majority}: Random predictions for classification; summarization is not evaluated against random baselines (ROUGE of random text is near zero).
\item \textbf{FLAN-T5-base (zero-shot)}: Pre-trained model with task-appropriate prompts, no fine-tuning.
\item \textbf{Single-Task}: Separate models fine-tuned on each task individually with identical hyperparameters.
\item \textbf{Multi-Task Baseline}: Joint training with mean pooling and round-robin scheduling.
\item \textbf{Multi-Task Improved}: Joint training with attention pooling for emotion and temperature sampling ($\alpha=0.5$).
\end{enumerate}

We additionally ablate FLAN-T5 initialization vs. random initialization to isolate transfer learning contribution.
\subsection{Evaluation Metrics}

\begin{itemize}
\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BLEU-4 (n-gram precision with brevity penalty). ROUGE-1 serves as the primary metric for summarization quality. BERTScore \cite{zhang2019bertscore} is available as an optional semantic similarity metric but is not used in our primary evaluation due to its high computational cost and the difficulty of interpreting its absolute values. Per-domain breakdown (literary vs. academic) is provided to analyze domain-specific quality.
\item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
\item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
\end{itemize}
|
|
|
|
| 215 |
\section{Results}
|
| 216 |
%=============================================================================
|
| 217 |
|
| 218 |
+
\subsection{Main Results}
|
| 219 |
|
| 220 |
+
Table \ref{tab:main_results} compares single-task specialists, baseline MTL (mean pooling, round-robin scheduling), and improved MTL (attention pooling, temperature sampling).

\begin{table}[htbp]
\centering
\caption{Main Results. Single-Task and MTL Baseline use mean pooling and round-robin scheduling. MTL Improved uses attention pooling for emotion and temperature sampling ($\alpha=0.5$). All results are single-seed. Bold indicates best.}
\label{tab:main_results}
\begin{tabular}{llccc}
\toprule
\textbf{Task} & \textbf{Metric} & \textbf{Single} & \textbf{MTL Base} & \textbf{MTL Impr.} \\
\midrule
\multirow{3}{*}{Summ.} & ROUGE-1 & 0.298 & 0.306 & \textbf{0.310} \\
 & ROUGE-2 & 0.085 & 0.090 & \textbf{0.091} \\
 & ROUGE-L & 0.179 & 0.183 & \textbf{0.185} \\
\midrule
\multirow{2}{*}{Topic} & Accuracy & 82.0\% & 85.2\% & \textbf{85.7\%} \\
 & Macro F1 & 0.812 & 0.847 & \textbf{0.854} \\
\midrule
\multirow{3}{*}{Emotion} & Sample F1 & 0.218 & 0.199 & \textbf{0.352} \\
 & Macro F1 & --- & --- & 0.143 \\
 & Micro F1 & --- & --- & \textbf{0.443} \\
\bottomrule
\end{tabular}
\end{table}

\textbf{Key finding}: Attention pooling and temperature sampling yield improvements across \textit{all} tasks, with the largest impact on emotion detection:

\begin{itemize}
\item \textbf{Emotion detection: negative transfer eliminated.} Baseline MTL with mean pooling degraded emotion F1 by $-$0.019 vs. single-task. With attention pooling and temperature sampling, multi-task emotion F1 improves to 0.352---a +0.134 gain over single-task (0.218) and +0.153 over baseline MTL (0.199). The attention pooling mechanism allows the emotion head to focus on emotionally salient tokens rather than averaging over the full sequence, which is critical for the sparse multi-label task. Temperature sampling ensures the emotion task receives proportional gradient exposure ($\sim$43\% of steps).

\item \textbf{Topic classification: +3.7\% accuracy} over single-task (85.7\% vs. 82.0\%, 95\% CI: [80.4\%, 91.0\%]). The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). The bootstrap CI is wide due to the small validation set (189 samples), but the lower bound (80.4\%) still exceeds the single-task point estimate.

\item \textbf{Summarization remains stable} across all configurations. ROUGE-1 improves slightly from 0.298 (single-task) to 0.310 (improved MTL). The decoder---which contains half the model's parameters---insulates summarization from classification interference. ROUGE-1 95\% CI: [0.306, 0.313].
\end{itemize}
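The temperature-sampling schedule referenced above is just a normalized power of the per-task dataset sizes. A minimal sketch (the summarization and topic sizes are from the paper; the emotion corpus size is an assumption, roughly GoEmotions' $\approx$43K train split, chosen to reproduce the $\sim$43\% figure):

```python
def sampling_probs(sizes: dict, alpha: float = 0.5) -> dict:
    """Temperature-based task sampling: p_i is proportional to n_i ** alpha."""
    weights = {task: n ** alpha for task, n in sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}

# alpha = 1.0 recovers size-proportional sampling; alpha = 0.0 recovers uniform.
# Emotion dataset size below is an assumed value, not stated in this section.
probs = sampling_probs({"summarization": 49_000, "topic": 3_400, "emotion": 43_400})
```

With $\alpha=0.5$ the square root compresses the size gap, so the small topic task is sampled more often than proportional sampling would allow while the two large tasks stay roughly balanced.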

\subsection{Baseline Comparisons}

\begin{table}[htbp]
\centering
\caption{Comparison with Baselines (Improved MTL Configuration)}
\label{tab:baselines}
\begin{tabular}{lccc}
\toprule
\textbf{Model} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
\midrule
Random/Majority & --- & 14.3\% & 0.036 \\
FLAN-T5 zero-shot & 0.121 & 58.2\% & 0.089 \\
Single-Task & 0.179 & 82.0\% & 0.218 \\
\textbf{Multi-Task (Impr.)} & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
\bottomrule
\end{tabular}
\end{table}

Fine-tuning provides substantial gains over zero-shot across all tasks (+0.064 ROUGE-L, +27\% topic accuracy, +0.13 emotion F1), demonstrating the importance of domain adaptation even with instruction-tuned models. The improved MTL configuration further improves over single-task baselines on all three tasks, showing that the combination of attention pooling and temperature sampling enables positive transfer even for the domain-mismatched emotion task.

\subsection{Ablation: Transfer Learning Contribution}

\begin{table}[htbp]
\centering
\caption{Effect of Pre-trained Initialization (Improved MTL Setting)}
\label{tab:transfer_ablation}
\begin{tabular}{lccc}
\toprule
\textbf{Initialization} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
\midrule
Random & 0.098 & 45.2\% & 0.082 \\
FLAN-T5-base & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
\midrule
\textit{Absolute gain} & +0.087 & +40.5\% & +0.270 \\
\bottomrule
\end{tabular}
\end{table}

FLAN-T5 initialization provides large absolute gains across all tasks. \textbf{Pre-training is necessary for competitive performance}---random initialization produces substantially worse results on all tasks even with identical data and training budget. Fine-tuning provides the remaining domain adaptation that zero-shot pre-training alone cannot achieve.

\subsection{Per-Class Topic Analysis}

\begin{table}[htbp]
\centering
\caption{Per-Class Topic Classification (Improved MTL)}
\label{tab:topic_breakdown}
\begin{tabular}{lccc}
\toprule
\textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
\midrule
Arts & 0.93 & 0.79 & 0.86 \\
Business & 0.97 & 1.00 & 0.98 \\
Fiction & 0.95 & 1.00 & 0.97 \\
History & 0.85 & 0.76 & 0.80 \\
Philosophy & 0.79 & 0.82 & 0.81 \\
Science & 0.57 & 0.80 & 0.67 \\
Technology & 0.89 & 0.89 & 0.89 \\
\midrule
\textit{Macro Avg} & 0.85 & 0.87 & 0.85 \\
\bottomrule
\end{tabular}
\end{table}

Fiction and Business achieve near-perfect classification (F1 $\geq$ 0.97), while Science shows the most confusion (F1 = 0.67). Error analysis reveals Science samples are frequently misclassified as Technology---semantically plausible given that scientific research papers often describe technical methods. The Arts class shows lower recall (0.79), suggesting some arts-related texts are misclassified into adjacent categories.

\subsection{Per-Domain Summarization Analysis}

Table \ref{tab:domain_breakdown} reveals a substantial quality gap between academic and literary summarization, reflecting the 11:1 training imbalance.

\begin{table}[htbp]
\centering
\caption{Per-Domain Summarization Performance (Improved MTL)}
\label{tab:domain_breakdown}
\begin{tabular}{lcccc}
\toprule
\textbf{Domain} & \textbf{N} & \textbf{ROUGE-1} & \textbf{ROUGE-L} & \textbf{BLEU-4} \\
\midrule
Academic & 2,493 & 0.319 & 0.189 & 0.026 \\
Literary & 234 & 0.206 & 0.137 & 0.008 \\
\midrule
\textit{Overall} & 2,727 & 0.310 & 0.185 & 0.024 \\
\bottomrule
\end{tabular}
\end{table}

Academic summaries (ROUGE-1: 0.319) outperform literary summaries (ROUGE-1: 0.206) by +0.113, a large gap attributable to two factors: (1) the encoder is disproportionately trained on academic text ($\sim$45K academic vs. $\sim$4K literary), and (2) academic abstracts follow more predictable structural conventions (background-method-result) that are easier for the model to reproduce. Literary descriptions---which describe \textit{what a book is about} in narrative prose---require more creative generation.
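The per-domain breakdown above amounts to tagging each evaluation sample with its source domain and averaging metrics within each group. A self-contained sketch using a simplified unigram ROUGE-1 for illustration (the reported numbers come from a standard ROUGE implementation with stemming, which this toy scorer omits):

```python
from collections import Counter
from statistics import mean

def rouge1_f(candidate: str, reference: str) -> float:
    """Unigram ROUGE-1 F-measure with plain whitespace tokenization."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    if not cand or not ref:
        return 0.0
    overlap = sum((Counter(cand) & Counter(ref)).values())
    if not overlap:
        return 0.0
    precision, recall = overlap / len(cand), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def per_domain_scores(samples) -> dict:
    """samples: iterable of (domain, candidate, reference) triples."""
    by_domain: dict = {}
    for domain, cand, ref in samples:
        by_domain.setdefault(domain, []).append(rouge1_f(cand, ref))
    return {d: mean(scores) for d, scores in by_domain.items()}
```

Reporting a per-group mean rather than a single pooled score is what exposes the academic/literary gap, since the 11:1 sample imbalance would otherwise let the academic subset dominate the overall average.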

\subsection{Analysis: Emotion Detection Improvements}
\label{sec:emotion_analysis}

Our improved multi-task emotion sample-averaged F1 (0.352) represents a dramatic improvement over the baseline MTL configuration (0.199). With per-class threshold tuning, macro F1 reaches 0.294---approaching published GoEmotions baselines (0.46 macro F1 with BERT-base \cite{demszky2020goemotions}). We analyze the contributing factors:

\begin{enumerate}
\item \textbf{Attention pooling is critical.} Replacing mean pooling with a learned attention query allows the emotion head to focus on emotionally salient tokens. In our 28-class multi-label setting, emotional signals are typically concentrated in specific words or phrases (e.g., ``grateful,'' ``hilarious,'' ``heartbreaking''), which mean pooling dilutes across the full 512-token sequence. The top-performing classes---gratitude (F1: 0.888), amusement (0.751), love (0.740), admiration (0.653)---correspond to emotions with distinctive lexical markers that attention pooling can localize.

\item \textbf{Temperature sampling improves optimization.} With round-robin scheduling, emotion receives the same update frequency as the other tasks, but the summarization decoder backpropagates much larger gradients through the encoder, skewing shared representations toward academic text style. Temperature sampling ($\alpha=0.5$) allocates $\sim$43\% of steps to emotion---proportional to its dataset size---ensuring the encoder maintains emotion-relevant features.

\item \textbf{Remaining class-level gaps.} Despite overall improvement, 15 of 28 emotion classes still have zero F1 at the default 0.5 threshold (including approval, annoyance, disapproval, anger). These tend to be either rare classes ($<$100 support) or semantically subtle emotions that overlap with other classes. Per-class threshold tuning recovers non-zero performance for most of these classes, increasing macro F1 from 0.143 to 0.294.

\item \textbf{Domain gap persists.} The remaining gap vs. published GoEmotions baselines (0.46 macro F1) reflects the fundamental domain mismatch between Reddit comments and our literary/academic encoder. Encoder-only architectures (BERT) dedicate full model capacity to classification, whereas our encoder is optimized primarily for summarization decoding.
\end{enumerate}
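The pooling mechanism in factor (1) can be sketched as a small PyTorch module: a learned query scores every encoder state, and a masked softmax produces the pooled vector fed to the emotion head. The dimensions and the tanh projection are illustrative assumptions, not the project's exact layer:

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Pool a sequence of hidden states with a single learned query vector."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.query = nn.Parameter(torch.randn(hidden_dim) / hidden_dim ** 0.5)

    def forward(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, seq, dim); mask: (batch, seq) with 1 for real tokens
        scores = torch.einsum("bsd,d->bs", torch.tanh(self.proj(hidden)), self.query)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)             # attention over tokens
        return torch.einsum("bs,bsd->bd", weights, hidden)  # (batch, dim)
```

Unlike mean pooling, padded positions receive exactly zero weight, and the softmax lets a handful of emotionally salient tokens dominate the pooled representation.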

\textbf{Per-class threshold tuning results.} Sweeping $\tau \in \{0.1, \ldots, 0.9\}$ per class on the validation set yields tuned sample-averaged F1 of 0.503, tuned macro F1 of 0.294, and tuned micro F1 of 0.486. The optimal thresholds vary widely: gratitude saturates at $\tau=0.65$ (high-confidence predictions), while rare classes require $\tau \leq 0.2$ to achieve non-zero recall.

\textbf{Implication}: Architectural isolation of classification heads (attention pooling) combined with balanced optimization (temperature sampling) can overcome domain mismatch in MTL, converting negative transfer to substantial positive transfer.

\subsection{Training Dynamics}

Figure \ref{fig:training_curves} shows training progression over 8 epochs (approximately 9 hours on an RTX 4070 with temperature sampling).

\begin{figure}[htbp]
\centering
\includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
\caption{Training and validation loss with temperature sampling and attention pooling. Combined validation loss decreases from 4.298 to 3.925 over 8 epochs; best checkpoint at epoch 8.}
\label{fig:training_curves}
\end{figure}

Key observations:
\begin{itemize}
\item Topic classification converges rapidly: 91\% training accuracy by epoch 3 (84\% validation), reaching 98\% by epoch 8. Validation accuracy plateaus near 86\% from epoch 2 onward, while training accuracy continues climbing---a sign of mild overfitting on the small (3.4K) topic dataset. The reduced task weight (0.3) limits gradient dominance.
\item Summarization training loss decreases steadily (4.057 $\rightarrow$ 3.699), with validation loss flattening after epoch 5 (3.665 $\rightarrow$ 3.653). Training ROUGE-1 improves from 0.287 to 0.308.
\item Emotion F1 improves steadily throughout training: validation F1 rises from 0.197 (epoch 1) to 0.459 (epoch 8), indicating the attention pooling mechanism continues refining its weights over the full training duration.
\item Combined validation loss decreases from 4.298 (epoch 1) to 3.925 (epoch 8), though the decrease is marginal after epoch 5. Early stopping (patience=3) did not trigger because the combined loss continued improving slightly each epoch. Additional epochs could yield further modest gains, though the near-plateau after epoch 5 suggests diminishing returns.
\end{itemize}

%=============================================================================

\subsection{When Does MTL Help?}

Our results demonstrate that MTL effectiveness depends on both task relatedness \textit{and} architectural/optimization choices:

\textbf{MTL helps when}: (1) A small-dataset task (topic: 3.4K samples) shares domain with a large-dataset task (summarization: 49K literary/academic samples)---the topic classifier benefits from shared encoder representations tuned to literary and academic vocabulary. (2) Task-specific heads are architecturally isolated from shared representations---attention pooling for emotion allows task-specific feature extraction without interfering with the shared encoder.

\textbf{MTL requires intervention when}: An auxiliary task's domain is misaligned with the primary training signal. With naive mean pooling, emotion detection suffered negative transfer because the encoder's representations were skewed toward summarization. Attention pooling and temperature sampling together overcame this: attention pooling provides architectural isolation, while temperature sampling ensures balanced optimization.

\textbf{MTL is neutral for}: The primary task (summarization) with sufficient data and a dedicated component (decoder, $\sim$136M parameters) that insulates it from interference. Classification heads are small, and their gradients have limited impact relative to the decoder's backpropagation signal.

\subsection{Comparison to MTL Literature}

Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---our baseline results (positive transfer for topic, negative for emotion) confirmed this, but our improved configuration shows that \textit{architectural interventions can change these grouping dynamics}. Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our gradient conflict diagnostics (Section~3.4) enable empirical measurement of inter-task gradient cosine similarity, and our temperature sampling partially addresses gradient imbalance by controlling task exposure frequency. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks; our results suggest that per-task architectural isolation (attention pooling) can mitigate this.
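The gradient-conflict diagnostic mentioned above reduces to flattening each task's gradients over the shared parameters and taking their cosine. A minimal sketch (the trainer's actual hook is not reproduced here):

```python
import torch

def grad_cosine(shared_params, loss_a: torch.Tensor, loss_b: torch.Tensor) -> float:
    """Cosine similarity between two task losses' gradients w.r.t. shared params.

    Values near -1 indicate directly conflicting task gradients; near +1, aligned.
    """
    params = [p for p in shared_params if p.requires_grad]
    grads_a = torch.autograd.grad(loss_a, params, retain_graph=True, allow_unused=True)
    grads_b = torch.autograd.grad(loss_b, params, retain_graph=True, allow_unused=True)

    def flatten(grads):
        # Parameters unused by a task contribute zero gradient.
        return torch.cat([
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(grads, params)
        ])

    return torch.nn.functional.cosine_similarity(
        flatten(grads_a), flatten(grads_b), dim=0
    ).item()
```

Logging this quantity per step over, e.g., the encoder's parameters gives a direct empirical read on whether two tasks are pulling the shared representation in opposing directions.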

A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Our results show that task-specific pooling strategies can partially compensate for this asymmetry. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform more targeted isolation strategies.

\subsection{Implications for Practitioners}

\begin{itemize}

\item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.

\item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods could provide additional gains beyond our attention pooling and temperature sampling improvements.

\item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results. The remaining gap between our tuned macro F1 (0.294) and published GoEmotions baselines (0.46) likely reflects this architectural difference.

\item \textbf{Cross-task data leakage}: Although topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.

\item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.

\item \textbf{No human evaluation}: ROUGE scores are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.

\item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.


\section{Conclusion}
%=============================================================================

We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our key finding is that naive MTL with mean pooling produces heterogeneous transfer effects---positive for topic (+3.7\%), negative for emotion ($-$0.02 F1)---but that targeted interventions can eliminate negative transfer entirely. Learned attention pooling for the emotion head, combined with temperature-based task sampling ($\alpha=0.5$), improves multi-task emotion F1 from 0.199 to 0.352 (+77\%), surpassing the single-task baseline. With per-class threshold tuning, macro F1 reaches 0.294. Summarization quality remains robust across configurations (ROUGE-1: 0.310, ROUGE-L: 0.185), with per-domain analysis revealing a quality gap between academic (ROUGE-1: 0.319) and literary (ROUGE-1: 0.206) summaries driven by training data imbalance.

These results demonstrate that negative transfer in MTL is not an inherent limitation but can be addressed through architectural isolation (task-specific pooling) and balanced optimization (temperature sampling). Pre-trained initialization (FLAN-T5) remains essential for competitive performance across all tasks. Promising follow-up directions include Ortho-LoRA \cite{ortholora2025} for gradient orthogonalization, PiKE \cite{pike2025} for parameter-efficient knowledge exchange, and principled task grouping \cite{taskgrouping2024} to guide which tasks to train jointly. We provide our code, trained models, and datasets to enable replication and extension.

Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}

outputs/evaluation_report.json
CHANGED
@@ -1,23 +1,260 @@
{
  "summarization": {
    "rouge1": 0.3094793058747055,
    "rouge2": 0.09069756722817666,
    "rougeL": 0.1847154828322755,
    "bleu4": 0.023982657019404153,
    "num_samples": 2727,
    "per_domain": {
      "academic": {
        "num_samples": 2493,
        "rouge1": 0.31919183681728475,
        "rouge2": 0.0968589097730544,
        "rougeL": 0.18921129182459423,
        "bleu4": 0.02551610700902003
      },
      "literary": {
        "num_samples": 234,
        "rouge1": 0.2060034954479976,
        "rouge2": 0.0250555716539014,
        "rougeL": 0.1368178254910352,
        "bleu4": 0.0076455167454197795
      }
    },
    "rouge1_ci": {
      "mean": 0.30947930587470557,
      "lower": 0.3060921045548166,
      "upper": 0.3131015955325767
    },
    "rougeL_ci": {
      "mean": 0.18471548283227546,
      "lower": 0.18251665662669495,
      "upper": 0.18701919830414013
    }
  },
  "emotion": {
    "sample_avg_f1": 0.3522975742816925,
    "macro_f1": 0.14317210018634796,
    "micro_f1": 0.4430159032344818,
    "num_samples": 5426,
    "num_classes": 28,
    "per_class": {
      "admiration": { "precision": 0.714634120464325, "recall": 0.6004098653793335, "f1": 0.6525612537411732, "support": 488 },
      "amusement": { "precision": 0.7708333134651184, "recall": 0.7326732873916626, "f1": 0.7512690366449468, "support": 303 },
      "anger": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 195 },
      "annoyance": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 303 },
      "approval": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 397 },
      "caring": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 153 },
      "confusion": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 152 },
      "curiosity": { "precision": 0.6166666746139526, "recall": 0.14919355511665344, "f1": 0.24025974958898805, "support": 248 },
      "desire": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 77 },
      "disappointment": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 163 },
      "disapproval": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 292 },
      "disgust": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 97 },
      "embarrassment": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 35 },
      "excitement": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 96 },
      "fear": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 90 },
      "gratitude": { "precision": 0.8997134566307068, "recall": 0.8770949840545654, "f1": 0.8882602556669954, "support": 358 },
      "grief": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 13 },
      "joy": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 172 },
      "love": { "precision": 0.6996466517448425, "recall": 0.7857142686843872, "f1": 0.740186913163602, "support": 252 },
      "nervousness": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 21 },
      "neutral": { "precision": 0.6869627237319946, "recall": 0.543035089969635, "f1": 0.6065780936064032, "support": 1766 },
      "optimism": { "precision": 0.7142857313156128, "recall": 0.023923445492982864, "f1": 0.04629629729995926, "support": 209 },
      "pride": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 15 },
      "realization": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 127 },
      "relief": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 18 },
      "remorse": { "precision": 1.0, "recall": 0.014705882407724857, "f1": 0.02898550735279132, "support": 68 },
      "sadness": { "precision": 1.0, "recall": 0.0279720276594162, "f1": 0.054421768115822285, "support": 143 },
      "surprise": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 129 }
    },
    "tuned_thresholds": {
      "admiration": 0.4, "amusement": 0.55, "anger": 0.2, "annoyance": 0.15,
      "approval": 0.15, "caring": 0.1, "confusion": 0.1, "curiosity": 0.25,
      "desire": 0.15, "disappointment": 0.1, "disapproval": 0.1, "disgust": 0.1,
      "embarrassment": 0.1, "excitement": 0.1, "fear": 0.1, "gratitude": 0.65,
      "grief": 0.1, "joy": 0.2, "love": 0.45, "nervousness": 0.1,
      "neutral": 0.3, "optimism": 0.25, "pride": 0.1, "realization": 0.2,
      "relief": 0.1, "remorse": 0.2, "sadness": 0.25, "surprise": 0.1
    },
    "tuned_macro_f1": 0.29355332255363464,
    "tuned_sample_avg_f1": 0.5025880336761475,
    "tuned_micro_f1": 0.48644566535949707,
    "sample_f1_ci": {
      "mean": 0.3522975795552279,
      "lower": 0.33984518982676004,
      "upper": 0.3658618994962526
    }
  },
  "topic": {
    "accuracy": 0.8571428571428571,
    "macro_f1": 0.8538751111963805,
    "num_samples": 189,
    "accuracy_ci": {
      "mean": 0.8571428571428571,
      "lower": 0.8042328042328042,
      "upper": 0.91005291005291
    }
  }
}
outputs/training_history.json
CHANGED
@@ -1,184 +1,210 @@
[Per-epoch training history updated for the new run: "train_epoch_N"/"val_epoch_N" entries for epochs 1-8, each recording summarization loss, ROUGE-like, ROUGE-1/2/L, BLEU-4, emotion loss and F1, topic loss and accuracy, and total loss. Numeric values are truncated in this view.]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
}
|
| 184 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"train_epoch_1": {
|
| 3 |
+
"summarization_loss": 4.05727732604402,
|
| 4 |
+
"summarization_rouge_like": 0.20427788603502178,
|
| 5 |
+
"summarization_rouge1": 0.2867239527374218,
|
| 6 |
+
"summarization_rouge2": 0.08530419039955006,
|
| 7 |
+
"summarization_rougeL": 0.21671934441779328,
|
| 8 |
+
"summarization_bleu4": 0.046807610480627294,
|
| 9 |
+
"emotion_loss": 0.26667120894821444,
|
| 10 |
+
"emotion_f1": 0.20469974499405505,
|
| 11 |
+
"total_loss": 6.105969995046231,
|
| 12 |
+
"topic_loss": 1.6857517635423032,
|
| 13 |
+
"topic_accuracy": 0.4217703349282306
|
| 14 |
},
|
| 15 |
"val_epoch_1": {
|
| 16 |
+
"summarization_loss": 3.8147981293996174,
|
| 17 |
+
"summarization_rouge_like": 0.2193516213078271,
|
| 18 |
+
"summarization_rouge1": 0.26079796060659194,
|
| 19 |
+
"summarization_rouge2": 0.08507403927823329,
|
| 20 |
+
"summarization_rougeL": 0.2006794257877804,
|
| 21 |
+
"summarization_bleu4": 0.047830237595825456,
|
| 22 |
+
"emotion_loss": 0.14947238216797512,
|
| 23 |
+
"emotion_f1": 0.19722222675879797,
|
| 24 |
+
"topic_loss": 1.111328324476878,
|
| 25 |
+
"topic_accuracy": 0.7036666666666669,
|
| 26 |
+
"total_loss": 4.297669008910658
|
| 27 |
},
|
| 28 |
"train_epoch_2": {
|
| 29 |
+
"summarization_loss": 3.849853239841521,
|
| 30 |
+
"summarization_rouge_like": 0.21403927033293962,
|
| 31 |
+
"summarization_rouge1": 0.27951076939717684,
|
| 32 |
+
"summarization_rouge2": 0.0873046768836161,
|
| 33 |
+
"summarization_rougeL": 0.21374118927203542,
|
| 34 |
+
"summarization_bleu4": 0.04958577524880755,
|
| 35 |
+
"emotion_loss": 0.14221051054989492,
|
| 36 |
+
"emotion_f1": 0.26357290302542585,
|
| 37 |
+
"topic_loss": 0.7299397268663149,
|
| 38 |
+
"topic_accuracy": 0.8084686774942008,
|
| 39 |
+
"total_loss": 5.528282663174978
|
| 40 |
},
|
| 41 |
"val_epoch_2": {
|
| 42 |
+
"summarization_loss": 3.738964385986328,
|
| 43 |
+
"summarization_rouge_like": 0.22322817933854347,
|
| 44 |
+
"summarization_rouge1": 0.2648447987903156,
|
| 45 |
+
"summarization_rouge2": 0.08777067266852198,
|
| 46 |
+
"summarization_rougeL": 0.2049718124413594,
|
| 47 |
+
"summarization_bleu4": 0.04980800809043137,
|
| 48 |
+
"emotion_loss": 0.1332480485116442,
|
| 49 |
+
"emotion_f1": 0.3008111199736595,
|
| 50 |
+
"topic_loss": 0.5171811254819234,
|
| 51 |
+
"topic_accuracy": 0.8467777777777786,
|
| 52 |
+
"total_loss": 4.027366772142546
|
| 53 |
},
|
| 54 |
"train_epoch_3": {
|
| 55 |
+
"emotion_loss": 0.12831888329927568,
|
| 56 |
+
"emotion_f1": 0.3325013413977316,
|
| 57 |
+
"summarization_loss": 3.7839796767703127,
|
| 58 |
+
"summarization_rouge_like": 0.21797276831106976,
|
| 59 |
+
"summarization_rouge1": 0.28868883384124916,
|
| 60 |
+
"summarization_rouge2": 0.09150032176337587,
|
| 61 |
+
"summarization_rougeL": 0.22148013487440707,
|
| 62 |
+
"summarization_bleu4": 0.052993168973641876,
|
| 63 |
+
"total_loss": 5.379445686122572,
|
| 64 |
+
"topic_loss": 0.3385182340765703,
|
| 65 |
+
"topic_accuracy": 0.9149137451307789
|
| 66 |
},
|
| 67 |
"val_epoch_3": {
|
| 68 |
+
"summarization_loss": 3.699807391166687,
|
| 69 |
+
"summarization_rouge_like": 0.22613490382620294,
|
| 70 |
+
"summarization_rouge1": 0.27110048990501884,
|
| 71 |
+
"summarization_rouge2": 0.09042725720607361,
|
| 72 |
+
"summarization_rougeL": 0.209904253200661,
|
| 73 |
+
"summarization_bleu4": 0.05177241093143676,
|
| 74 |
+
"emotion_loss": 0.12147359546273946,
|
| 75 |
+
"emotion_f1": 0.3474666798238953,
|
| 76 |
+
"topic_loss": 0.5068136086066564,
|
| 77 |
+
"topic_accuracy": 0.8417777777777792,
|
| 78 |
+
"total_loss": 3.9733250692114286
|
| 79 |
},
|
| 80 |
"train_epoch_4": {
|
| 81 |
+
"summarization_loss": 3.746917572488457,
|
| 82 |
+
"summarization_rouge_like": 0.22054338132572013,
|
| 83 |
+
"summarization_rouge1": 0.29700759128401966,
|
| 84 |
+
"summarization_rouge2": 0.09528349132659034,
|
| 85 |
+
"summarization_rougeL": 0.2286643637324592,
|
| 86 |
+
"summarization_bleu4": 0.05591190647915982,
|
| 87 |
+
"emotion_loss": 0.12003502780097021,
|
| 88 |
+
"emotion_f1": 0.37240424536844824,
|
| 89 |
+
"total_loss": 5.303101773435515,
|
| 90 |
+
"topic_loss": 0.19978291214297147,
|
| 91 |
+
"topic_accuracy": 0.9528935185185234
|
| 92 |
},
|
| 93 |
"val_epoch_4": {
|
| 94 |
+
"summarization_loss": 3.6773871207237243,
|
| 95 |
+
"summarization_rouge_like": 0.22730110361278533,
|
| 96 |
+
"summarization_rouge1": 0.2719731929407321,
|
| 97 |
+
"summarization_rouge2": 0.09117786246379923,
|
| 98 |
+
"summarization_rougeL": 0.21082587270737135,
|
| 99 |
+
"summarization_bleu4": 0.052260125383420154,
|
| 100 |
+
"emotion_loss": 0.11476812147845825,
|
| 101 |
+
"emotion_f1": 0.40390001876900594,
|
| 102 |
+
"topic_loss": 0.5311758625507355,
|
| 103 |
+
"topic_accuracy": 0.8574444444444455,
|
| 104 |
+
"total_loss": 3.95150800096741
|
| 105 |
},
|
| 106 |
"train_epoch_5": {
|
| 107 |
+
"summarization_loss": 3.72376742684834,
|
| 108 |
+
"summarization_rouge_like": 0.22218972657959773,
|
| 109 |
+
"summarization_rouge1": 0.30386172952451457,
|
| 110 |
+
"summarization_rouge2": 0.09807265293507532,
|
| 111 |
+
"summarization_rougeL": 0.23422938393417417,
|
| 112 |
+
"summarization_bleu4": 0.05821407514551748,
|
| 113 |
+
"emotion_loss": 0.11460309708431649,
|
| 114 |
+
"emotion_f1": 0.41015538037428334,
|
| 115 |
+
"total_loss": 5.207888234798891,
|
| 116 |
+
"topic_loss": 0.13986067138923575,
|
| 117 |
+
"topic_accuracy": 0.9685236768802278
|
| 118 |
},
|
| 119 |
"val_epoch_5": {
|
| 120 |
+
"summarization_loss": 3.664777074654897,
|
| 121 |
+
"summarization_rouge_like": 0.22876987463000684,
|
| 122 |
+
"summarization_rouge1": 0.27596093399625565,
|
| 123 |
+
"summarization_rouge2": 0.09296804123657829,
|
| 124 |
+
"summarization_rougeL": 0.21411928790828857,
|
| 125 |
+
"summarization_bleu4": 0.05366559404113782,
|
| 126 |
+
"emotion_loss": 0.11044646929949523,
|
| 127 |
+
"emotion_f1": 0.4313555757453044,
|
| 128 |
+
"topic_loss": 0.5484664579232533,
|
| 129 |
+
"topic_accuracy": 0.8627777777777789,
|
| 130 |
+
"total_loss": 3.9397634813313704
|
| 131 |
},
|
| 132 |
"train_epoch_6": {
|
| 133 |
+
"emotion_loss": 0.1111307007874511,
|
| 134 |
+
"emotion_f1": 0.43345397762862603,
|
| 135 |
+
"summarization_loss": 3.7095002406409807,
|
| 136 |
+
"summarization_rouge_like": 0.22328726116125275,
|
| 137 |
+
"summarization_rouge1": 0.3064035344877472,
|
| 138 |
+
"summarization_rouge2": 0.09935359454486654,
|
| 139 |
+
"summarization_rougeL": 0.23650841461700828,
|
| 140 |
+
"summarization_bleu4": 0.059165680810364656,
|
| 141 |
+
"total_loss": 5.231221632164746,
|
| 142 |
+
"topic_loss": 0.10774352340420275,
|
| 143 |
+
"topic_accuracy": 0.9777777777777826
|
| 144 |
},
|
| 145 |
"val_epoch_6": {
|
| 146 |
+
"summarization_loss": 3.658109269142151,
|
| 147 |
+
"summarization_rouge_like": 0.22934290201883448,
|
| 148 |
+
"summarization_rouge1": 0.2752052666208255,
|
| 149 |
+
"summarization_rouge2": 0.09292038370832255,
|
| 150 |
+
"summarization_rougeL": 0.2137414809166316,
|
| 151 |
+
"summarization_bleu4": 0.053427338475007496,
|
| 152 |
+
"emotion_loss": 0.10808507531881333,
|
| 153 |
+
"emotion_f1": 0.4517777989556392,
|
| 154 |
+
"topic_loss": 0.5295590771238009,
|
| 155 |
+
"topic_accuracy": 0.8734444444444451,
|
| 156 |
+
"total_loss": 3.9250620675981014
|
| 157 |
},
|
| 158 |
"train_epoch_7": {
|
| 159 |
+
"emotion_loss": 0.10953594371440704,
|
| 160 |
+
"emotion_f1": 0.44384909393388133,
|
| 161 |
+
"topic_loss": 0.0957411853224039,
|
| 162 |
+
"topic_accuracy": 0.980394366197187,
|
| 163 |
+
"total_loss": 5.154093306898701,
|
| 164 |
+
"summarization_loss": 3.7035266418583594,
|
| 165 |
+
"summarization_rouge_like": 0.22378105952974536,
|
| 166 |
+
"summarization_rouge1": 0.3070619920824417,
|
| 167 |
+
"summarization_rouge2": 0.09984959921270933,
|
| 168 |
+
"summarization_rougeL": 0.23710279675635842,
|
| 169 |
+
"summarization_bleu4": 0.05954598113800495
|
| 170 |
},
|
| 171 |
"val_epoch_7": {
|
| 172 |
+
"summarization_loss": 3.654966928164164,
|
| 173 |
+
"summarization_rouge_like": 0.2296679906954514,
|
| 174 |
+
"summarization_rouge1": 0.27616327736195406,
|
| 175 |
+
"summarization_rouge2": 0.09329265746038877,
|
| 176 |
+
"summarization_rougeL": 0.2144202156909426,
|
| 177 |
+
"summarization_bleu4": 0.05381191556748925,
|
| 178 |
+
"emotion_loss": 0.10733611459533374,
|
| 179 |
+
"emotion_f1": 0.4582889095693827,
|
| 180 |
+
"topic_loss": 0.5517185291647911,
|
| 181 |
+
"topic_accuracy": 0.8574444444444457,
|
| 182 |
+
"total_loss": 3.9278186015089296
|
| 183 |
+
},
|
| 184 |
+
"train_epoch_8": {
|
| 185 |
+
"summarization_loss": 3.6991967220660666,
|
| 186 |
+
"summarization_rouge_like": 0.22392498422300275,
|
| 187 |
+
"summarization_rouge1": 0.30751530664889926,
|
| 188 |
+
"summarization_rouge2": 0.10003700619268063,
|
| 189 |
+
"summarization_rougeL": 0.23750205422812004,
|
| 190 |
+
"summarization_bleu4": 0.05974583539783897,
|
| 191 |
+
"emotion_loss": 0.10842480880968565,
|
| 192 |
+
"emotion_f1": 0.44919879924130307,
|
| 193 |
+
"total_loss": 5.178375478314296,
|
| 194 |
+
"topic_loss": 0.09057999134229244,
|
| 195 |
+
"topic_accuracy": 0.9817882611080675
|
| 196 |
+
},
|
| 197 |
+
"val_epoch_8": {
|
| 198 |
+
"summarization_loss": 3.652825821240743,
|
| 199 |
+
"summarization_rouge_like": 0.22990084413830203,
|
| 200 |
+
"summarization_rouge1": 0.2765755402183266,
|
| 201 |
+
"summarization_rouge2": 0.09348690574327727,
|
| 202 |
+
"summarization_rougeL": 0.2147442273650338,
|
| 203 |
+
"summarization_bleu4": 0.053967258926407226,
|
| 204 |
+
"emotion_loss": 0.10670269103099903,
|
| 205 |
+
"emotion_f1": 0.4594111327578624,
|
| 206 |
+
"topic_loss": 0.5511919154723486,
|
| 207 |
+
"topic_accuracy": 0.8574444444444457,
|
| 208 |
+
"total_loss": 3.924886086913453
|
| 209 |
}
|
| 210 |
}
|
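Since the updated history is plain JSON keyed by `val_epoch_N`, the best checkpoint can be picked by validation `total_loss` with a few lines. A minimal sketch; the helper name `best_val_epoch` is illustrative and not part of the repository:

```python
def best_val_epoch(history: dict) -> tuple[int, float]:
    """Return (epoch, total_loss) for the lowest-loss validation epoch."""
    best_epoch, best_loss = -1, float("inf")
    for key, metrics in history.items():
        if key.startswith("val_epoch_"):
            epoch = int(key.rsplit("_", 1)[1])
            if metrics["total_loss"] < best_loss:
                best_epoch, best_loss = epoch, metrics["total_loss"]
    return best_epoch, best_loss

# Abbreviated values from the last three validation epochs above:
sample = {
    "val_epoch_6": {"total_loss": 3.9251},
    "val_epoch_7": {"total_loss": 3.9278},
    "val_epoch_8": {"total_loss": 3.9249},
}
print(best_val_epoch(sample))  # -> (8, 3.9249)
```

In the full history above, epoch 8 indeed has the lowest validation total loss (3.9249), narrowly beating epoch 6.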
pyproject.toml
CHANGED
|
@@ -39,7 +39,7 @@ mlflow = ">=2.0.0"
 sentencepiece = ">=0.1.99"
 triton = { version = "*", markers = "sys_platform == 'linux'" }

-[tool.poetry.
+[tool.poetry.dev-dependencies]
 pytest = "^7.4.0"
 pytest-cov = "^4.1.0"
 ruff = "^0.4.0"

@@ -106,7 +106,11 @@ module = [
     "fastapi.*",
     "mlflow.*",
     "pydantic.*",
-    "rouge_score.*"
+    "rouge_score.*",
+    "bert_score.*",
+    "pytest",
+    "pytest.*",
+    "mpl_toolkits.*"
 ]
 ignore_missing_imports = true
 follow_imports = "skip"
scripts/build_discovery_dataset.py
CHANGED
|
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """Build a discovery dataset for the HuggingFace Space demo.

 This script samples from the already-filtered training data (processed by

@@ -22,12 +21,11 @@ from typing import Any
 # Add project root to path
 sys.path.insert(0, str(Path(__file__).parent.parent))

-import torch
-from datasets import Dataset
-from tqdm import tqdm
-
-from src.inference.factory import create_inference_pipeline
+import torch  # noqa: E402
+from datasets import Dataset  # noqa: E402
+from tqdm import tqdm  # noqa: E402
+
+from src.inference.factory import create_inference_pipeline  # noqa: E402

 # --------------- Data Loading ---------------

@@ -176,8 +174,8 @@ def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
         results.append(result)

     # Print distribution stats
-    topic_dist = defaultdict(int)
-    emotion_dist = defaultdict(int)
+    topic_dist: dict[str, int] = defaultdict(int)
+    emotion_dist: dict[str, int] = defaultdict(int)
     for r in results:
         topic_dist[r["topic"]] += 1
         emotion_dist[r["emotion"]] += 1
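The `dict[str, int]` annotations added above exist to satisfy strict type checking while keeping `defaultdict` semantics. A self-contained sketch of the same counting pattern, with sample data invented for illustration:

```python
from collections import defaultdict

# Count topic/emotion label frequencies the same way run_inference does;
# the dict[str, int] annotations are what mypy sees, while the runtime
# object is still a defaultdict that auto-initializes missing keys to 0.
results = [
    {"topic": "Science", "emotion": "curiosity"},
    {"topic": "Fiction", "emotion": "joy"},
    {"topic": "Science", "emotion": "joy"},
]
topic_dist: dict[str, int] = defaultdict(int)
emotion_dist: dict[str, int] = defaultdict(int)
for r in results:
    topic_dist[r["topic"]] += 1
    emotion_dist[r["emotion"]] += 1

print(dict(topic_dist))    # -> {'Science': 2, 'Fiction': 1}
print(dict(emotion_dist))  # -> {'curiosity': 1, 'joy': 2}
```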
scripts/demo_gradio.py
CHANGED
|
@@ -93,10 +93,8 @@ def format_item_card(item: dict) -> str:
|
|
| 93 |
|
| 94 |
# Icon based on type
|
| 95 |
if source_type == "academic":
|
| 96 |
-
icon = "📄"
|
| 97 |
type_label = "Research Paper"
|
| 98 |
else:
|
| 99 |
-
icon = "📖"
|
| 100 |
type_label = "Literature"
|
| 101 |
|
| 102 |
# Topic and emotion with confidence
|
|
@@ -109,10 +107,10 @@ def format_item_card(item: dict) -> str:
|
|
| 109 |
use_reference = item.get("use_reference_summary", False)
|
| 110 |
if use_reference or source_type == "literary":
|
| 111 |
summary = item.get("reference_summary", "")
|
| 112 |
-
summary_label = "
|
| 113 |
else:
|
| 114 |
summary = item.get("generated_summary", "")
|
| 115 |
-
summary_label = "
|
| 116 |
|
| 117 |
if not summary:
|
| 118 |
summary = "No summary available."
|
|
@@ -124,23 +122,17 @@ def format_item_card(item: dict) -> str:
|
|
| 124 |
# Preview of original text
|
| 125 |
text_preview = item.get("text", "")[:400] + "..." if len(item.get("text", "")) > 400 else item.get("text", "")
|
| 126 |
|
| 127 |
-
|
| 128 |
-
topic_badge = "🟢" if topic_conf > 0.6 else "🟡" if topic_conf > 0.3 else "🔴"
|
| 129 |
-
emotion_badge = "🟢" if emotion_conf > 0.6 else "🟡" if emotion_conf > 0.3 else "🔴"
|
| 130 |
-
|
| 131 |
-
return f"""### {icon} **{title}**
|
| 132 |
|
| 133 |
<small>*{type_label}* from {dataset_name}</small>
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|-------|---------|
|
| 137 |
-
| {topic_badge} {topic} ({topic_conf:.0%}) | {emotion_badge} {emotion.title()} ({emotion_conf:.0%}) |
|
| 138 |
|
| 139 |
{summary_label}
|
| 140 |
> {summary}
|
| 141 |
|
| 142 |
<details>
|
| 143 |
-
<summary
|
| 144 |
|
| 145 |
{text_preview}
|
| 146 |
|
|
@@ -164,12 +156,12 @@ def browse_by_topic(topic: str) -> str:
|
|
| 164 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 165 |
|
| 166 |
if literary:
|
| 167 |
-
result += "###
|
| 168 |
for item in literary[:25]: # Limit to avoid huge pages
|
| 169 |
result += format_item_card(item)
|
| 170 |
|
| 171 |
if academic:
|
| 172 |
-
result += "###
|
| 173 |
for item in academic[:25]:
|
| 174 |
result += format_item_card(item)
|
| 175 |
|
|
@@ -189,12 +181,12 @@ def browse_by_emotion(emotion: str) -> str:
|
|
| 189 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 190 |
|
| 191 |
if literary:
|
| 192 |
-
result += "###
|
| 193 |
for item in literary[:25]:
|
| 194 |
result += format_item_card(item)
|
| 195 |
|
| 196 |
if academic:
|
| 197 |
-
result += "###
|
| 198 |
for item in academic[:25]:
|
| 199 |
result += format_item_card(item)
|
| 200 |
|
|
@@ -239,20 +231,10 @@ with gr.Blocks(
|
|
| 239 |
|
| 240 |
gr.Markdown(
|
| 241 |
"""
|
| 242 |
-
#
|
| 243 |
-
###
|
| 244 |
-
|
| 245 |
-
Explore **{total_count}** items analyzed by the LexiMind multi-task transformer:
|
| 246 |
-
|
| 247 |
-
| Source | Count | Description |
|
| 248 |
-
|--------|-------|-------------|
|
| 249 |
-
| 📖 Literature | {lit_count} | Classic books with Goodreads-style descriptions |
|
| 250 |
-
| 📄 Research | {paper_count} | Scientific papers from arXiv |
|
| 251 |
|
| 252 |
-
**
|
| 253 |
-
- 🏷️ **Topic Classification**: Fiction, Science, History, Philosophy, Arts, Business, Technology
|
| 254 |
-
- 💭 **Emotion Detection**: 28 emotions (joy, sadness, anger, fear, surprise, love, etc.)
|
| 255 |
-
- 📝 **Book Descriptions**: Back-cover style summaries of what texts are about
|
| 256 |
|
| 257 |
---
|
| 258 |
""".format(
|
|
@@ -264,7 +246,7 @@ with gr.Blocks(
|
|
| 264 |
|
| 265 |
with gr.Tabs():
|
| 266 |
# ===================== TAB 1: BROWSE BY TOPIC =====================
|
| 267 |
-
with gr.Tab("
|
| 268 |
gr.Markdown("*Select a topic to explore related books and papers*")
|
| 269 |
|
| 270 |
topic_dropdown = gr.Dropdown(
|
|
@@ -286,7 +268,7 @@ with gr.Blocks(
|
|
| 286 |
)
|
| 287 |
|
| 288 |
# ===================== TAB 2: BROWSE BY EMOTION =====================
|
| 289 |
-
with gr.Tab("
|
| 290 |
gr.Markdown("*Find books and papers that evoke specific emotions*")
|
| 291 |
|
| 292 |
emotion_dropdown = gr.Dropdown(
|
|
@@ -308,7 +290,7 @@ with gr.Blocks(
|
|
| 308 |
)
|
| 309 |
|
| 310 |
# ===================== TAB 3: SEARCH =====================
|
| 311 |
-
with gr.Tab("
|
| 312 |
gr.Markdown("*Search through all books and papers by keyword*")
|
| 313 |
|
| 314 |
search_input = gr.Textbox(
|
|
@@ -329,45 +311,39 @@ with gr.Blocks(
|
|
| 329 |
)
|
| 330 |
|
| 331 |
# ===================== TAB 4: METRICS =====================
|
| 332 |
-
with gr.Tab("
|
| 333 |
gr.Markdown(
|
| 334 |
"""
|
| 335 |
### Evaluation Metrics
|
| 336 |
|
| 337 |
-
|
| 338 |
-
Metrics are computed on held-out validation data.
|
| 339 |
"""
|
| 340 |
)
|
| 341 |
|
| 342 |
# Summarization Metrics
|
| 343 |
-
gr.Markdown("####
|
| 344 |
|
| 345 |
if METRICS.get("summarization"):
|
| 346 |
summ = METRICS["summarization"]
|
| 347 |
summ_md = """
|
| 348 |
-
| Metric | Score |
|
| 349 |
-
|
| 350 |
-
| **ROUGE-1** | {rouge1:.4f} |
|
| 351 |
-
| **ROUGE-2** | {rouge2:.4f} |
|
| 352 |
-
| **ROUGE-L** | {rougeL:.4f} |
|
| 353 |
-
| **BLEU-4** | {bleu4:.4f} |
|
| 354 |
-
| **BERTScore F1** | {bertscore:.4f} | Semantic similarity (contextual) |
|
| 355 |
-
|
| 356 |
-
*Note: For back-cover style descriptions, BERTScore is more meaningful than ROUGE
|
| 357 |
-
since descriptions paraphrase rather than quote the source text.*
|
| 358 |
""".format(
|
| 359 |
rouge1=summ.get("rouge_rouge1", summ.get("rouge1", 0)),
|
| 360 |
rouge2=summ.get("rouge_rouge2", summ.get("rouge2", 0)),
|
| 361 |
rougeL=summ.get("rouge_rougeL", summ.get("rougeL", 0)),
|
| 362 |
bleu4=summ.get("bleu4", 0),
|
| 363 |
-
bertscore=summ.get("bertscore_f1", 0),
|
| 364 |
)
|
| 365 |
gr.Markdown(summ_md)
|
| 366 |
else:
|
| 367 |
gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
|
| 368 |
|
| 369 |
# Topic Classification Metrics
|
| 370 |
-
gr.Markdown("####
|
| 371 |
|
| 372 |
if METRICS.get("topic"):
|
| 373 |
topic = METRICS["topic"]
|
|
@@ -376,125 +352,66 @@ since descriptions paraphrase rather than quote the source text.*
|
|
| 376 |
|--------|-------|
|
| 377 |
| **Accuracy** | {accuracy:.2%} |
|
| 378 |
| **Macro F1** | {f1:.4f} |
|
| 379 |
-
| **Precision** | {precision:.4f} |
|
| 380 |
-
| **Recall** | {recall:.4f} |
|
| 381 |
""".format(
|
| 382 |
accuracy=topic.get("accuracy", 0),
|
| 383 |
f1=topic.get("f1", topic.get("macro_f1", 0)),
|
| 384 |
-
precision=topic.get("precision", 0),
|
| 385 |
-
recall=topic.get("recall", 0),
|
| 386 |
)
|
| 387 |
gr.Markdown(topic_md)
|
| 388 |
else:
|
| 389 |
gr.Markdown("*Topic classification metrics not available.*")
|
| 390 |
|
| 391 |
# Emotion Detection Metrics
|
| 392 |
-
gr.Markdown("####
|
| 393 |
|
| 394 |
if METRICS.get("emotion"):
|
| 395 |
emotion = METRICS["emotion"]
|
| 396 |
emotion_md = """
|
| 397 |
| Metric | Score |
|
| 398 |
|--------|-------|
|
| 399 |
-
| **
|
| 400 |
-
| **
|
| 401 |
-
| **
|
| 402 |
|
| 403 |
-
*
|
| 404 |
""".format(
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
)
|
| 409 |
gr.Markdown(emotion_md)
|
| 410 |
else:
|
| 411 |
gr.Markdown("*Emotion detection metrics not available.*")
|
| 412 |
|
| 413 |
# Dataset Statistics
|
| 414 |
-
gr.Markdown("####
|
| 415 |
-
|
| 416 |
-
# Build topic list with indicators for observed vs possible
|
| 417 |
-
topic_list = ", ".join([
|
| 418 |
-
f"**{t}**" if t in TOPICS else t for t in ALL_TOPICS
|
| 419 |
-
])
|
| 420 |
-
emotion_list = ", ".join([
|
| 421 |
-
f"**{e}**" if e in EMOTIONS else e for e in ALL_EMOTIONS
|
| 422 |
-
])
|
| 423 |
|
| 424 |
gr.Markdown(f"""
|
| 425 |
| Statistic | Value |
|
| 426 |
|-----------|-------|
|
| 427 |
-
| Total
|
| 428 |
| Literary Works | {len(BOOKS)} |
|
| 429 |
-
| Academic Papers
|
| 430 |
-
| Topics
|
| 431 |
-
| Emotions
|
| 432 |
-
|
| 433 |
-
**All Model Topics ({len(ALL_TOPICS)}):** {topic_list}
|
| 434 |
-
|
| 435 |
-
**All Model Emotions ({len(ALL_EMOTIONS)}):** {emotion_list}
|
| 436 |
-
|
| 437 |
-
*Bold items appear in the discovery dataset. The model can predict all listed labels.*
|
| 438 |
-
|
| 439 |
-
---
|
| 440 |
-
|
| 441 |
-
**Note on Content Types:**
|
| 442 |
-
- 📄 **Academic Papers** include CS/AI papers (Technology), Physics/Math (Science), Economics (Business)
|
| 443 |
-
- 📖 **Literary Works** include novels (Fiction), biographies (History), philosophical texts (Philosophy)
|
| 444 |
-
- Technical blogs and tutorials would be classified under **Technology**
|
| 445 |
""")
|
| 446 |
|
| 447 |
# ===================== TAB 5: ABOUT =====================
|
| 448 |
-
with gr.Tab("
|
| 449 |
gr.Markdown(
|
| 450 |
"""
|
| 451 |
### About LexiMind
|
| 452 |
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
| Task | Description |
|
| 456 |
-
|------|-------------|
|
| 457 |
-
| **Book Descriptions** | Generate back-cover style descriptions of what books are about |
|
| 458 |
-
| **Topic Classification** | Categorize into Fiction, Science, Technology, Philosophy, History, Business, Arts |
|
| 459 |
-
| **Emotion Detection** | Identify emotional tones (28 emotions from GoEmotions) |
|
| 460 |
-
|
| 461 |
-
### Architecture
|
| 462 |
-
|
| 463 |
-
- **Base:** FLAN-T5-base (Google)
|
| 464 |
-
- **Encoder:** 12 layers, 768 dim, 12 attention heads
|
| 465 |
-
- **Decoder:** 12 layers with causal attention
|
| 466 |
-
- **Position:** T5 relative position bias
|
| 467 |
-
- **Training:** Multi-task learning with task-specific heads
|
| 468 |
-
|
| 469 |
-
### Training Data
|
| 470 |
-
|
| 471 |
-
| Dataset | Task | Samples |
|
| 472 |
-
|---------|------|---------|
|
| 473 |
-
| Gutenberg + Goodreads | Book Descriptions | ~4K literary pairs |
|
| 474 |
-
| arXiv (body → abstract) | Paper Abstracts | ~45K academic pairs |
|
| 475 |
-
| 20 Newsgroups + Gutenberg + arXiv | Topic Classification | 3.4K (7 classes) |
|
| 476 |
-
| GoEmotions (Reddit) | Emotion Detection | 43K (28 labels) |
|
| 477 |
-
|
| 478 |
-
### Key Design Decision
|
| 479 |
-
|
| 480 |
-
LexiMind generates **back-cover style descriptions** (what a book is about) rather than
|
| 481 |
-
plot summaries (what happens in the book). This is achieved by training on Goodreads
|
| 482 |
-
descriptions paired with Project Gutenberg book texts.
|
| 483 |
-
|
| 484 |
-
### Evaluation Metrics
|
| 485 |
|
| 486 |
-
- **
|
| 487 |
-
- **
|
| 488 |
-
- **
|
| 489 |
|
| 490 |
-
|
| 491 |
|
| 492 |
-
|
| 493 |
-
- 🤗 [Model](https://huggingface.co/OliverPerrin/LexiMind-Model)
|
| 494 |
-
- 📊 [Discovery Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)
|
| 495 |
|
| 496 |
-
|
| 497 |
-
*Built by Oliver Perrin • Appalachian State University • 2025-2026*
|
| 498 |
"""
|
| 499 |
)
|
| 500 |
|
|
|
|
| 93 |
|
| 94 |
# Icon based on type
|
| 95 |
if source_type == "academic":
|
|
|
|
| 96 |
type_label = "Research Paper"
|
| 97 |
else:
|
|
|
|
| 98 |
type_label = "Literature"
|
| 99 |
|
| 100 |
# Topic and emotion with confidence
|
|
|
|
| 107 |
use_reference = item.get("use_reference_summary", False)
|
| 108 |
if use_reference or source_type == "literary":
|
| 109 |
summary = item.get("reference_summary", "")
|
| 110 |
+
summary_label = "**Book Description:**"
|
| 111 |
else:
|
| 112 |
summary = item.get("generated_summary", "")
|
| 113 |
+
summary_label = "**AI-Generated Description:**"
|
| 114 |
|
| 115 |
if not summary:
|
| 116 |
summary = "No summary available."
|
|
|
|
| 122 |
# Preview of original text
|
| 123 |
text_preview = item.get("text", "")[:400] + "..." if len(item.get("text", "")) > 400 else item.get("text", "")
|
| 124 |
|
| 125 |
+
return f"""### **{title}**
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
<small>*{type_label}* from {dataset_name}</small>
|
| 128 |
|
| 129 |
+
**Topic:** {topic} ({topic_conf:.0%}) | **Emotion:** {emotion.title()} ({emotion_conf:.0%})
|
|
|
|
|
|
|
| 130 |
|
| 131 |
{summary_label}
|
| 132 |
> {summary}
|
| 133 |
|
| 134 |
<details>
|
| 135 |
+
<summary>View Original Text</summary>
|
| 136 |
|
| 137 |
{text_preview}
|
| 138 |
|
|
|
|
| 156 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 157 |
|
| 158 |
if literary:
|
| 159 |
+
result += "### Literary Works\n\n"
|
| 160 |
for item in literary[:25]: # Limit to avoid huge pages
|
| 161 |
result += format_item_card(item)
|
| 162 |
|
| 163 |
if academic:
|
| 164 |
+
result += "### Academic Papers\n\n"
|
| 165 |
for item in academic[:25]:
|
| 166 |
result += format_item_card(item)
|
| 167 |
|
|
|
|
| 181 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 182 |
|
| 183 |
if literary:
|
| 184 |
+
result += "### Literary Works\n\n"
|
| 185 |
for item in literary[:25]:
|
| 186 |
result += format_item_card(item)
|
| 187 |
|
| 188 |
if academic:
|
| 189 |
+
result += "### Academic Papers\n\n"
|
| 190 |
for item in academic[:25]:
|
| 191 |
result += format_item_card(item)
|
| 192 |
|
|
|
|
| 231 |
|
| 232 |
gr.Markdown(
|
| 233 |
"""
|
| 234 |
+
# LexiMind
|
| 235 |
+
### Discover Books & Papers by Topic, Emotion, or Keyword
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
Browse **{total_count}** texts — {lit_count} classic books and {paper_count} research papers — analyzed by a multi-task transformer.
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
---
|
| 240 |
""".format(
|
|
|
|
| 246 |
|
| 247 |
with gr.Tabs():
|
| 248 |
# ===================== TAB 1: BROWSE BY TOPIC =====================
|
| 249 |
+
with gr.Tab("By Topic"):
|
| 250 |
gr.Markdown("*Select a topic to explore related books and papers*")
|
| 251 |
|
| 252 |
topic_dropdown = gr.Dropdown(
|
|
|
|
| 268 |
)
|
| 269 |
|
| 270 |
# ===================== TAB 2: BROWSE BY EMOTION =====================
|
| 271 |
+
with gr.Tab("By Emotion"):
|
| 272 |
gr.Markdown("*Find books and papers that evoke specific emotions*")
|
| 273 |
|
| 274 |
emotion_dropdown = gr.Dropdown(
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
# ===================== TAB 3: SEARCH =====================
|
| 293 |
+
with gr.Tab("Search"):
|
| 294 |
gr.Markdown("*Search through all books and papers by keyword*")
|
| 295 |
|
| 296 |
search_input = gr.Textbox(
|
|
|
|
| 311 |
)
|
| 312 |
|
| 313 |
# ===================== TAB 4: METRICS =====================
|
| 314 |
+
with gr.Tab("Metrics"):
|
| 315 |
gr.Markdown(
|
| 316 |
"""
|
| 317 |
### Evaluation Metrics
|
| 318 |
|
| 319 |
+
Computed on held-out validation data.
|
|
|
|
| 320 |
"""
|
| 321 |
)
|
| 322 |
|
| 323 |
# Summarization Metrics
|
| 324 |
+
gr.Markdown("#### Summarization")
|
| 325 |
|
| 326 |
if METRICS.get("summarization"):
|
| 327 |
summ = METRICS["summarization"]
|
| 328 |
summ_md = """
|
| 329 |
+
| Metric | Score |
|
| 330 |
+
|--------|-------|
|
| 331 |
+
| **ROUGE-1** | {rouge1:.4f} |
|
| 332 |
+
| **ROUGE-2** | {rouge2:.4f} |
|
| 333 |
+
| **ROUGE-L** | {rougeL:.4f} |
|
| 334 |
+
| **BLEU-4** | {bleu4:.4f} |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
""".format(
|
| 336 |
             rouge1=summ.get("rouge_rouge1", summ.get("rouge1", 0)),
             rouge2=summ.get("rouge_rouge2", summ.get("rouge2", 0)),
             rougeL=summ.get("rouge_rougeL", summ.get("rougeL", 0)),
             bleu4=summ.get("bleu4", 0),
         )
         gr.Markdown(summ_md)
     else:
         gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
 
     # Topic Classification Metrics
+    gr.Markdown("#### Topic Classification")
     if METRICS.get("topic"):
         topic = METRICS["topic"]
         topic_md = """
         | Metric | Score |
         |--------|-------|
         | **Accuracy** | {accuracy:.2%} |
         | **Macro F1** | {f1:.4f} |
         """.format(
             accuracy=topic.get("accuracy", 0),
             f1=topic.get("f1", topic.get("macro_f1", 0)),
         )
         gr.Markdown(topic_md)
     else:
         gr.Markdown("*Topic classification metrics not available.*")
 
     # Emotion Detection Metrics
+    gr.Markdown("#### Emotion Detection")
     if METRICS.get("emotion"):
         emotion = METRICS["emotion"]
         emotion_md = """
         | Metric | Score |
         |--------|-------|
+        | **Sample-avg F1** | {sample_f1:.4f} |
+        | **Macro F1** | {macro_f1:.4f} |
+        | **Micro F1** | {micro_f1:.4f} |
 
+        *28-label multi-label classification from GoEmotions.*
         """.format(
+            sample_f1=emotion.get("sample_avg_f1", emotion.get("f1", emotion.get("multilabel_f1", 0))),
+            macro_f1=emotion.get("macro_f1", 0),
+            micro_f1=emotion.get("micro_f1", 0),
         )
         gr.Markdown(emotion_md)
     else:
         gr.Markdown("*Emotion detection metrics not available.*")
 
     # Dataset Statistics
+    gr.Markdown("#### Dataset Statistics")
 
     gr.Markdown(f"""
     | Statistic | Value |
     |-----------|-------|
+    | Total Items | {len(ALL_ITEMS)} |
     | Literary Works | {len(BOOKS)} |
+    | Academic Papers | {len(PAPERS)} |
+    | Topics | {len(TOPICS)} |
+    | Emotions | {len(EMOTIONS)} |
     """)
 
 # ===================== TAB 5: ABOUT =====================
+with gr.Tab("About"):
     gr.Markdown(
         """
         ### About LexiMind
 
+        A **272M parameter encoder-decoder transformer** (FLAN-T5-base) trained on three tasks:
 
+        - **Summarization**: Generate back-cover style descriptions from full text
+        - **Topic Classification**: 7 categories (Fiction, Science, History, Philosophy, Arts, Business, Technology)
+        - **Emotion Detection**: 28 emotions via GoEmotions
 
+        Training data: ~49K summarization pairs (arXiv + Goodreads), 43K emotion samples, 3.4K topic samples.
 
+        [GitHub](https://github.com/OliverPerrin/LexiMind) | [Model](https://huggingface.co/OliverPerrin/LexiMind-Model) | [Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)
 
+        *Oliver Perrin — Appalachian State University — 2025-2026*
         """
     )
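The metrics tab above reads scores through nested `.get` fallbacks so evaluation reports written with older key names still render. A minimal standalone sketch of that pattern (the `fmt_summarization` helper name is illustrative, not from the repo):

```python
def fmt_summarization(summ: dict) -> str:
    """Build the markdown metrics table, tolerating both old ("rouge1")
    and new ("rouge_rouge1") key names, defaulting missing metrics to 0."""
    rouge1 = summ.get("rouge_rouge1", summ.get("rouge1", 0))
    rouge2 = summ.get("rouge_rouge2", summ.get("rouge2", 0))
    rougeL = summ.get("rouge_rougeL", summ.get("rougeL", 0))
    bleu4 = summ.get("bleu4", 0)
    return (
        "| Metric | Score |\n"
        "|--------|-------|\n"
        f"| **ROUGE-1** | {rouge1:.4f} |\n"
        f"| **ROUGE-2** | {rouge2:.4f} |\n"
        f"| **ROUGE-L** | {rougeL:.4f} |\n"
        f"| **BLEU-4** | {bleu4:.4f} |"
    )

# An old-style report still renders; absent BLEU-4 falls back to 0.
print(fmt_summarization({"rouge1": 0.41, "rougeL": 0.38}))
```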
scripts/download_data.py
CHANGED
@@ -1,7 +1,3 @@
-#!/usr/bin/env python3
-# pyright: reportAttributeAccessIssue=false
-# pyright: reportArgumentType=false
-# pyright: reportCallIssue=false
 """
 Dataset download script for LexiMind.
 
@@ -45,7 +41,7 @@ from tqdm import tqdm
 # Output directory
 OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
 
-#
+# ------------ LABEL DEFINITIONS ------------
 
 # 28 emotions from GoEmotions - works for all text types
 EMOTION_LABELS = [
@@ -115,10 +111,10 @@ def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing"
     with path.open("w", encoding="utf-8") as f:
         for record in tqdm(records, desc=desc, leave=False):
             f.write(json.dumps(record, ensure_ascii=False) + "\n")
-    print(f"
+    print(f"  {len(records):,} samples -> {path}")
 
 
-#
+# ------------ ENGLISH LANGUAGE FILTER ------------
 
 # Common English words for detection
 ENGLISH_WORDS = {
@@ -144,7 +140,7 @@
     r"\b(et|in|ad|cum|de|ex|per|pro|sub|ab|ante|post|inter|contra|super|trans|apud)\b",
 ]
 
-#
+# ------------ TEXT QUALITY FILTERS ------------
 
 # Patterns that indicate garbage/metadata text
 GARBAGE_PATTERNS = [
@@ -320,7 +316,7 @@
     return title.lower().strip()
 
 
-#
+# -------- SUMMARIZATION: BOOKS + ARXIV ----------
 
 def download_goodreads_descriptions() -> dict[str, dict]:
     """
@@ -329,7 +325,7 @@
     These are "what the book is about" descriptions, not plot summaries.
     Returns dict mapping normalized title -> {title, description}
     """
-    print("\
+    print("\nLoading Goodreads book descriptions...")
 
     descriptions = {}
 
@@ -392,7 +388,7 @@ def download_book_descriptions(
     This gives us (book_excerpt, book_description) training pairs where descriptions
     are back-cover style "what is this book about" blurbs, not plot summaries.
     """
-    print("\
+    print("\nMatching Gutenberg books with Goodreads descriptions...")
 
     try:
         gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
@@ -497,7 +493,7 @@ def download_booksum(max_samples: int = 20000) -> list[dict[str, Any]]:
     Note: These are chapter-level plot summaries, useful as supplementary training data.
     The primary book training comes from Goodreads descriptions (back-cover style).
     """
-    print("\
+    print("\nLoading BookSum (supplementary literary data)...")
 
     all_records: list[dict[str, Any]] = []
     booksum = load_dataset("kmfoda/booksum")
@@ -600,7 +596,7 @@ def download_arxiv_summarization(max_samples: int = 50000) -> list[dict[str, Any
 
     Returns: summarization_records
     """
-    print("\
+    print("\nLoading arXiv (academic papers for summarization)...")
 
     print("  Loading dataset (this may take a minute)...")
     arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
@@ -663,7 +659,7 @@ def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, An
     - 20 Newsgroups (classic topic classification)
     - Wikipedia (article categories)
     """
-    print("\
+    print("\nLoading topic classification datasets...")
 
     records: list[dict[str, Any]] = []
 
@@ -747,7 +743,7 @@ def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> No
     plot summaries. This trains the model to describe "what the book is about"
     rather than summarizing the plot.
     """
-    print("\
+    print("\nDownloading Summarization Data...")
     out_dir = OUTPUT_DIR / "summarization"
 
     all_records: list[dict[str, Any]] = []
@@ -793,12 +789,12 @@
     # Print breakdown
     literary_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "literary")
     academic_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "academic")
-    print(f"\n
+    print(f"\n  Total summarization: {len(train_records) + len(val_records) + len(test_records):,}")
     print(f"  Literary (book descriptions): {literary_count:,}")
     print(f"  Academic (paper abstracts): {academic_count:,}")
 
 
-#
+# ------------ TOPIC CLASSIFICATION ------------
 
 def download_topics(max_samples: int = 50000) -> None:
     """
@@ -809,7 +805,7 @@ def download_topics(max_samples: int = 50000) -> None:
     - Gutenberg books (Fiction)
     - Scientific papers (Science, Technology)
     """
-    print("\
+    print("\nDownloading Topic Classification...")
     out_dir = OUTPUT_DIR / "topic"
 
     # Get topic records from various sources
@@ -830,14 +826,14 @@
     # Balance to min count (with some tolerance) - only from topics that have data
     counts_with_data = [len(v) for v in topic_counts.values() if v]
     if not counts_with_data:
-        print("
+        print("  Warning: No topic data found!")
         return
 
     min_count = min(counts_with_data)
     target_count = min(min_count, max_samples // len(TOPIC_LABELS))
 
     balanced: list[dict[str, Any]] = []
-    for
+    for _topic, records in topic_counts.items():
         if records:
             random.shuffle(records)
             balanced.extend(records[:target_count])
@@ -857,12 +853,12 @@
     # Save labels - only labels that have data
     used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
     (out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
-    print(f"\n
+    print(f"\n  {len(used_labels)} topic labels with data: {used_labels}")
 
 
 def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
     """Extract topic-labeled samples from Gutenberg books (English only)."""
-    print("\
+    print("\nLoading Gutenberg for topic classification...")
 
     try:
         gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
@@ -926,11 +922,11 @@ def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
     return records
 
 
-#
+# ------------ EMOTIONS (unchanged) -------------
 
 def download_emotions() -> None:
     """Download GoEmotions for emotion classification."""
-    print("\
+    print("\nDownloading Emotions (GoEmotions)...")
     out_dir = OUTPUT_DIR / "emotion"
 
     ds = load_dataset("google-research-datasets/go_emotions", "simplified")
@@ -950,10 +946,10 @@
         write_jsonl(records, out_dir / f"{split}.jsonl", split)
 
     (out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
-    print(f"
+    print(f"  {len(EMOTION_LABELS)} emotion labels saved")
 
 
-#
+# --------------- GUTENBERG BOOKS (for language modeling) ---------------
 
 GUTENBERG_JUNK_PATTERNS = [
     r"Project Gutenberg", r"www\.gutenberg\.org", r"This ebook is for",
@@ -988,7 +984,7 @@ def is_clean_prose(text: str) -> bool:
 
 def download_gutenberg(max_samples: int = 30000) -> None:
     """Download Gutenberg books for language modeling (English only)."""
-    print("\
+    print("\nDownloading Gutenberg Books (English only)...")
     out_dir = OUTPUT_DIR / "books"
     out_dir.mkdir(parents=True, exist_ok=True)
 
@@ -1044,7 +1040,7 @@
     write_jsonl(records[int(n*0.95):], out_dir / "test.jsonl", "test")
 
 
-#
+# ------------ MAIN ------------
 
 def main() -> None:
    parser = argparse.ArgumentParser(description="Download LexiMind datasets")
@@ -1078,7 +1074,7 @@ def main() -> None:
     download_gutenberg(args.max_gutenberg)
 
     print("\n" + "=" * 60)
-    print("
+    print("Download complete!")
     print("=" * 60)
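The class-balancing step in `download_topics` above caps every topic at `min(min_count, max_samples // len(TOPIC_LABELS))`, i.e. the size of the smallest non-empty topic, further limited by an even share of the sample budget. A standalone sketch of that logic (function name and toy records are illustrative; the real code divides by the global `TOPIC_LABELS` list rather than the dict passed in):

```python
import random

def balance_topics(topic_counts: dict, max_samples: int) -> list:
    """Cap every topic at the size of the smallest non-empty topic,
    bounded by an even per-label share of max_samples."""
    counts_with_data = [len(v) for v in topic_counts.values() if v]
    if not counts_with_data:
        return []  # mirrors the "No topic data found!" early return
    target_count = min(min(counts_with_data), max_samples // len(topic_counts))
    balanced = []
    for _topic, records in topic_counts.items():
        if records:
            random.shuffle(records)
            balanced.extend(records[:target_count])
    return balanced

counts = {"Fiction": list(range(50)), "Science": list(range(20)), "History": []}
# target = min(20, 90 // 3) = 20, so two non-empty topics contribute 20 each
print(len(balance_topics(counts, max_samples=90)))  # → 40
```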
scripts/evaluate.py
CHANGED
@@ -1,18 +1,17 @@
-#!/usr/bin/env python3
 """
 Comprehensive evaluation script for LexiMind.
 
 Evaluates all three tasks with full metrics:
-- Summarization: ROUGE-1/2/L, BLEU-4,
+- Summarization: ROUGE-1/2/L, BLEU-4, per-domain breakdown (BERTScore optional)
 - Emotion: Sample-avg F1, Macro F1, Micro F1, per-class metrics, threshold tuning
 - Topic: Accuracy, Macro F1, Per-class metrics, bootstrap confidence intervals
 
 Usage:
     python scripts/evaluate.py
     python scripts/evaluate.py --checkpoint checkpoints/best.pt
-    python scripts/evaluate.py --
-    python scripts/evaluate.py --tune-thresholds
-    python scripts/evaluate.py --bootstrap
+    python scripts/evaluate.py --include-bertscore  # Include BERTScore (slow)
+    python scripts/evaluate.py --tune-thresholds    # Tune per-class emotion thresholds
+    python scripts/evaluate.py --bootstrap          # Compute confidence intervals
 
 Author: Oliver Perrin
 Date: January 2026
@@ -419,7 +418,7 @@ def main():
     parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
     parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
     parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
-    parser.add_argument("--
+    parser.add_argument("--include-bertscore", action="store_true", help="Include BERTScore (slow, optional)")
     parser.add_argument("--tune-thresholds", action="store_true", help="Tune per-class emotion thresholds on val set")
     parser.add_argument("--bootstrap", action="store_true", help="Compute bootstrap confidence intervals")
     parser.add_argument("--summarization-only", action="store_true")
@@ -459,7 +458,7 @@
         results["summarization"] = evaluate_summarization(
             pipeline, val_path,
             max_samples=args.max_samples,
-            include_bertscore=
+            include_bertscore=args.include_bertscore,
             compute_bootstrap=args.bootstrap,
         )
     else:
@@ -515,13 +514,18 @@
         s = results["summarization"]
         print("\n  Summarization:")
         print(f"    ROUGE-1: {s['rouge1']:.4f}")
+        print(f"    ROUGE-2: {s['rouge2']:.4f}")
         print(f"    ROUGE-L: {s['rougeL']:.4f}")
+        print(f"    BLEU-4: {s['bleu4']:.4f}")
         if "bertscore_f1" in s:
             print(f"    BERTScore F1: {s['bertscore_f1']:.4f}")
 
     if "emotion" in results:
+        e = results["emotion"]
         print("\n  Emotion:")
-        print(f"
+        print(f"    Sample-avg F1: {e['sample_avg_f1']:.4f}")
+        print(f"    Macro F1: {e['macro_f1']:.4f}")
+        print(f"    Micro F1: {e['micro_f1']:.4f}")
 
     if "topic" in results:
         print("\n  Topic:")
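The `--bootstrap` flag above reports confidence intervals for the topic metrics. A percentile-bootstrap sketch over per-sample correctness flags (the exact resampling scheme inside `evaluate.py` may differ; this shows the general technique):

```python
import random

def bootstrap_ci(correct, n_resamples=1000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for accuracy from a list of 0/1 flags:
    resample with replacement, collect accuracies, take the percentiles."""
    rng = random.Random(seed)
    n = len(correct)
    stats = sorted(
        sum(correct[rng.randrange(n)] for _ in range(n)) / n
        for _ in range(n_resamples)
    )
    lo = stats[int((alpha / 2) * n_resamples)]
    hi = stats[int((1 - alpha / 2) * n_resamples) - 1]
    return lo, hi

lo, hi = bootstrap_ci([1] * 90 + [0] * 10)
print(f"accuracy 0.90, 95% CI [{lo:.2f}, {hi:.2f}]")
```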
scripts/profile_training.py
ADDED
|
@@ -0,0 +1,314 @@
+"""
+Profile LexiMind training with PyTorch Profiler.
+
+Runs a few training steps under torch.profiler to capture:
+- CUDA kernel timing (per-operator breakdown)
+- GPU memory usage (peak allocations, memory timeline)
+- CPU/GPU overlap and idle time
+- Chrome trace (viewable in chrome://tracing or Perfetto UI)
+
+Outputs:
+    outputs/profile/  -- Chrome trace + stacks
+    stdout            -- Summary table of top CUDA operations
+
+Usage:
+    python scripts/profile_training.py                   # default: 20 steps
+    python scripts/profile_training.py training=full     # use full config
+    PROFILE_STEPS=40 python scripts/profile_training.py  # custom step count
+
+Author: Oliver Perrin
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+
+import hydra
+import torch
+from omegaconf import DictConfig
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.data.dataloader import (
+    build_emotion_dataloader,
+    build_summarization_dataloader,
+    build_topic_dataloader,
+)
+from src.data.dataset import (
+    EmotionDataset,
+    SummarizationDataset,
+    TopicDataset,
+    load_emotion_jsonl,
+    load_summarization_jsonl,
+    load_topic_jsonl,
+)
+from src.data.tokenization import Tokenizer, TokenizerConfig
+from src.models.factory import ModelConfig, build_multitask_model
+
+
+def load_splits(data_dir: Path, loader_fn):
+    splits = {}
+    for name, aliases in [("train", ["train"]), ("val", ["val", "validation"])]:
+        for alias in aliases:
+            path = data_dir / f"{alias}.jsonl"
+            if path.exists():
+                splits[name] = loader_fn(str(path))
+                break
+    return splits
+
+
+@hydra.main(version_base=None, config_path="../configs", config_name="config")
+def main(cfg: DictConfig) -> None:
+    profile_steps = int(os.environ.get("PROFILE_STEPS", 20))
+    warmup_steps = 3  # let CUDA graphs / torch.compile settle
+    active_steps = profile_steps - warmup_steps
+
+    device = torch.device(cfg.device)
+    if device.type != "cuda":
+        print("Profiler requires CUDA. Set device=cuda.")
+        return
+
+    print(f"Profiling {profile_steps} steps ({warmup_steps} warmup + {active_steps} active)")
+    print(f"GPU: {torch.cuda.get_device_name()}")
+
+    # ---------- Setup (mirrors train.py) ----------
+
+    torch.backends.cudnn.benchmark = True
+    if torch.cuda.get_device_capability()[0] >= 8:
+        torch.set_float32_matmul_precision("high")
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    data_cfg = cfg.data
+    trainer_cfg = cfg.training.get("trainer", {})
+
+    # Load small subsets -- profiling doesn't need the full dataset
+    max_samples = max(200, profile_steps * 10 * 3)
+    summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
+    emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
+    topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
+    for splits in [summ_splits, emot_splits, topic_splits]:
+        splits["train"] = splits["train"][:max_samples]
+
+    tok_cfg = data_cfg.get("tokenizer", {})
+    max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
+    tokenizer = Tokenizer(TokenizerConfig(
+        pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
+        max_length=max_len,
+    ))
+
+    summ_train = SummarizationDataset(summ_splits["train"])
+    emot_train = EmotionDataset(emot_splits["train"])
+    topic_train = TopicDataset(topic_splits["train"])
+
+    dl_cfg = cfg.training.get("dataloader", {})
+    batch_size = int(dl_cfg.get("batch_size", 8))
+    num_workers = int(dl_cfg.get("num_workers", 4))
+    classification_max_len = min(256, max_len)
+
+    train_loaders = {
+        "summarization": build_summarization_dataloader(
+            summ_train, tokenizer, shuffle=True,
+            max_source_length=max_len, max_target_length=max_len,
+            batch_size=batch_size, num_workers=num_workers, pin_memory=True,
+        ),
+        "emotion": build_emotion_dataloader(
+            emot_train, tokenizer, shuffle=True, max_length=classification_max_len,
+            batch_size=batch_size, num_workers=num_workers, pin_memory=True,
+        ),
+        "topic": build_topic_dataloader(
+            topic_train, tokenizer, shuffle=True, max_length=classification_max_len,
+            batch_size=batch_size, num_workers=num_workers, pin_memory=True,
+        ),
+    }
+
+    # Build model
+    grad_ckpt = cfg.training.get("gradient_checkpointing", cfg.model.get("gradient_checkpointing", False))
+    use_rel_pos = cfg.training.get("use_relative_position_bias", cfg.model.get("use_relative_position_bias", False))
+
+    model_cfg = ModelConfig(
+        d_model=cfg.model.d_model,
+        vocab_size=getattr(cfg.model, "vocab_size", None),
+        num_encoder_layers=cfg.model.num_encoder_layers,
+        num_decoder_layers=cfg.model.num_decoder_layers,
+        num_attention_heads=cfg.model.num_attention_heads,
+        ffn_dim=cfg.model.ffn_dim,
+        dropout=cfg.model.dropout,
+        use_pretrained=cfg.model.use_pretrained,
+        pretrained_model_name=cfg.model.pretrained_model_name,
+        activation=getattr(cfg.model, "activation", "gelu"),
+        use_relative_position_bias=use_rel_pos,
+        gradient_checkpointing=grad_ckpt,
+    )
+
+    model = build_multitask_model(
+        tokenizer,
+        num_emotions=len(emot_train.emotion_classes),
+        num_topics=len(topic_train.topic_classes),
+        config=model_cfg,
+    ).to(device)
+
+    # Freeze layers (same as train.py)
+    freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
+    if freeze_layers > 0:
+        if hasattr(model.encoder, "embed_tokens"):
+            for p in model.encoder.embed_tokens.parameters():
+                p.requires_grad = False
+        if hasattr(model.encoder, "layers"):
+            for i, layer in enumerate(model.encoder.layers):
+                if i < freeze_layers:
+                    for p in layer.parameters():
+                        p.requires_grad = False
+
+    # Compile (same as train.py)
+    compile_mode = "default" if grad_ckpt else "reduce-overhead"
+    if cfg.training.get("compile_encoder", True):
+        model.encoder = torch.compile(model.encoder, mode=compile_mode)
+    if cfg.training.get("compile_decoder", True):
+        model.decoder = torch.compile(model.decoder, mode=compile_mode)
+
+    # Optimizer
+    opt_cfg = cfg.training.get("optimizer", {})
+    use_fused = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=float(opt_cfg.get("lr", 3e-5)),
+        weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
+        fused=use_fused,
+    )
+
+    # ---------- Profile loop ----------
+
+    out_dir = PROJECT_ROOT / "outputs" / "profile"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    model.train()
+    iterators = {task: iter(loader) for task, loader in train_loaders.items()}
+    task_names = list(train_loaders.keys())
+    accum = int(trainer_cfg.get("gradient_accumulation_steps", 4))
+    use_bf16 = torch.cuda.is_bf16_supported()
+    task_weights = trainer_cfg.get("task_weights") or {}
+
+    emotion_loss_fn = torch.nn.BCEWithLogitsLoss()
+    topic_loss_fn = torch.nn.CrossEntropyLoss()
+
+    def get_batch(task):
+        try:
+            batch = next(iterators[task])
+        except StopIteration:
+            iterators[task] = iter(train_loaders[task])
+            batch = next(iterators[task])
+        return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
+                for k, v in batch.items()}
+
+    def training_step(step):
+        """One training step across all tasks."""
+        for task in task_names:
+            batch = get_batch(task)
+            dtype = torch.bfloat16 if use_bf16 else torch.float16
+            with torch.autocast("cuda", dtype=dtype):
+                if task == "summarization":
+                    inputs = {"src_ids": batch["src_ids"], "tgt_ids": batch["tgt_ids"]}
+                    if "src_mask" in batch:
+                        inputs["src_mask"] = batch["src_mask"]
+                    logits = model.forward("summarization", inputs)
+                    loss = torch.nn.functional.cross_entropy(
+                        logits.view(-1, logits.size(-1)),
+                        batch["labels"].view(-1),
+                        ignore_index=-100, label_smoothing=0.1,
+                    )
+                elif task == "emotion":
+                    inputs = {"input_ids": batch["input_ids"]}
+                    if "attention_mask" in batch:
+                        inputs["attention_mask"] = batch["attention_mask"]
+                    logits = model.forward("emotion", inputs)
+                    loss = emotion_loss_fn(logits, batch["labels"].float())
+                elif task == "topic":
+                    inputs = {"input_ids": batch["input_ids"]}
+                    if "attention_mask" in batch:
+                        inputs["attention_mask"] = batch["attention_mask"]
+                    logits = model.forward("topic", inputs)
+                    loss = topic_loss_fn(logits, batch["labels"])
+                else:
+                    continue
+
+            weight = task_weights.get(task, 1.0)
+            scaled = (loss * weight) / accum
+            scaled.backward()
+
+        if (step + 1) % accum == 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+
+    # Warmup outside profiler to let torch.compile finish
print(f"\nWarmup ({warmup_steps} steps)...")
|
| 250 |
+
for s in range(warmup_steps):
|
| 251 |
+
training_step(s)
|
| 252 |
+
optimizer.zero_grad()
|
| 253 |
+
torch.cuda.synchronize()
|
| 254 |
+
|
| 255 |
+
# Profile
|
| 256 |
+
print(f"Profiling ({active_steps} steps)...")
|
| 257 |
+
trace_path = str(out_dir / "trace")
|
| 258 |
+
|
| 259 |
+
with torch.profiler.profile(
|
| 260 |
+
activities=[
|
| 261 |
+
torch.profiler.ProfilerActivity.CPU,
|
| 262 |
+
torch.profiler.ProfilerActivity.CUDA,
|
| 263 |
+
],
|
| 264 |
+
schedule=torch.profiler.schedule(
|
| 265 |
+
wait=1, warmup=2, active=active_steps - 3, repeat=1,
|
| 266 |
+
),
|
| 267 |
+
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
|
| 268 |
+
record_shapes=True,
|
| 269 |
+
profile_memory=True,
|
| 270 |
+
with_stack=True,
|
| 271 |
+
with_flops=True,
|
| 272 |
+
) as prof:
|
| 273 |
+
for s in range(active_steps):
|
| 274 |
+
training_step(warmup_steps + s)
|
| 275 |
+
prof.step()
|
| 276 |
+
|
| 277 |
+
torch.cuda.synchronize()
|
| 278 |
+
|
| 279 |
+
# ---------- Summary ----------
|
| 280 |
+
|
| 281 |
+
print("\n" + "=" * 80)
|
| 282 |
+
print("TOP CUDA OPERATIONS (by total CUDA time)")
|
| 283 |
+
print("=" * 80)
|
| 284 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=25))
|
| 285 |
+
|
| 286 |
+
print("\n" + "=" * 80)
|
| 287 |
+
print("TOP CUDA OPERATIONS (by GPU memory)")
|
| 288 |
+
print("=" * 80)
|
| 289 |
+
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=15))
|
| 290 |
+
|
| 291 |
+
# Memory summary
|
| 292 |
+
print("\n" + "=" * 80)
|
| 293 |
+
print("GPU MEMORY SUMMARY")
|
| 294 |
+
print("=" * 80)
|
| 295 |
+
print(torch.cuda.memory_summary(abbreviated=True))
|
| 296 |
+
|
| 297 |
+
# Export Chrome trace
|
| 298 |
+
chrome_trace = out_dir / "chrome_trace.json"
|
| 299 |
+
prof.export_chrome_trace(str(chrome_trace))
|
| 300 |
+
print(f"\nChrome trace: {chrome_trace}")
|
| 301 |
+
print(" Open in: chrome://tracing or https://ui.perfetto.dev")
|
| 302 |
+
|
| 303 |
+
# Export stacks for flamegraph
|
| 304 |
+
stacks_path = out_dir / "profiler_stacks.txt"
|
| 305 |
+
prof.export_stacks(str(stacks_path), "self_cuda_time_total")
|
| 306 |
+
print(f"CUDA stacks: {stacks_path}")
|
| 307 |
+
print(f" Generate flamegraph: flamegraph.pl {stacks_path} > flamegraph.svg")
|
| 308 |
+
|
| 309 |
+
print(f"\nTensorBoard traces: {trace_path}/")
|
| 310 |
+
print(f" View with: tensorboard --logdir={trace_path}")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
|
| 314 |
+
main()
|
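The profile loop above passes `schedule=torch.profiler.schedule(wait=1, warmup=2, active=active_steps - 3, repeat=1)`. A pure-Python sketch of how that schedule gates recording (an approximation for illustration, not PyTorch's implementation):

```python
def schedule_phase(step: int, wait: int = 1, warmup: int = 2, active: int = 5) -> str:
    """Approximate phase of torch.profiler.schedule with repeat=1.

    Steps 0..wait-1 are skipped, the next `warmup` steps prime the profiler
    (recorded but discarded), the next `active` steps are actually traced,
    and everything afterwards is skipped again.
    """
    if step < wait:
        return "wait"
    if step < wait + warmup:
        return "warmup"
    if step < wait + warmup + active:
        return "active"
    return "done"


# With wait=1, warmup=2, active=5 the first traced step is step 3.
phases = [schedule_phase(s) for s in range(10)]
```

This is why the script also runs a separate warmup loop before profiling: torch.compile's first iterations would otherwise dominate the trace.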
scripts/train.py
CHANGED
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 Training script for LexiMind.
 
@@ -97,9 +96,9 @@ def main(cfg: DictConfig) -> None:
         torch.set_float32_matmul_precision("high")
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-        print("
+        print(" TF32 + cudnn.benchmark enabled (Ampere GPU)")
     else:
-        print("
+        print(" cudnn.benchmark enabled")
 
     # --------------- Load Data ---------------
 
@@ -218,9 +217,9 @@ def main(cfg: DictConfig) -> None:
     )
 
     if grad_ckpt:
-        print("
+        print(" Gradient checkpointing: on")
     if not use_rel_pos:
-        print("
+        print(" FlashAttention: on (no relative position bias)")
 
     model = build_multitask_model(
         tokenizer,
@@ -249,7 +248,7 @@ def main(cfg: DictConfig) -> None:
             p.requires_grad = False
            frozen_params += p.numel()
         trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"
+        print(f" Frozen layers: 0-{freeze_layers-1} ({frozen_params/1e6:.1f}M params)")
         print(f" Trainable: {trainable:,} ({trainable/1e6:.1f}M)")
 
     # Resume from checkpoint?
@@ -269,10 +268,10 @@ def main(cfg: DictConfig) -> None:
     compile_mode = "default" if grad_ckpt else "reduce-overhead"
     if cfg.training.get("compile_encoder", True):
         model.encoder = torch.compile(model.encoder, mode=compile_mode)  # type: ignore[assignment]
-        print(f"
+        print(f" Encoder compiled ({compile_mode})")
     if cfg.training.get("compile_decoder", True):
         model.decoder = torch.compile(model.decoder, mode=compile_mode)  # type: ignore[assignment]
-        print(f"
+        print(f" Decoder compiled ({compile_mode})")
 
     # --------------- Train ---------------
 
@@ -289,7 +288,7 @@ def main(cfg: DictConfig) -> None:
         fused=use_fused,
     )
     if use_fused:
-        print("
+        print(" Fused AdamW: on")
 
     trainer = Trainer(
         model=model,
@@ -303,7 +302,7 @@ def main(cfg: DictConfig) -> None:
         scheduler_type=str(sched_cfg.get("name", "cosine")),
         warmup_steps=int(sched_cfg.get("warmup_steps", 500)),
         early_stopping_patience=trainer_cfg.get("early_stopping_patience"),
-        task_sampling=str(trainer_cfg.get("task_sampling", "
+        task_sampling=str(trainer_cfg.get("task_sampling", "temperature")),
         task_sampling_alpha=float(trainer_cfg.get("task_sampling_alpha", 0.5)),
         gradient_conflict_frequency=int(trainer_cfg.get("gradient_conflict_frequency", 0)),
     ),
@@ -329,7 +328,7 @@ def main(cfg: DictConfig) -> None:
         if val_loss < best_val_loss:
             best_val_loss = val_loss
             save_state(model, str(ckpt_dir / "best.pt"))
-            print(f"
+            print(f" New best model saved (val_loss={val_loss:.4f})")
 
     history = trainer.fit(
         train_loaders,
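The trainer wired up here defaults to a cosine schedule with warmup (`scheduler_type`, `warmup_steps`). A sketch of the resulting LR multiplier, assuming linear warmup followed by the cosine decay floored at 0.1 that the trainer uses; the repo's exact warmup handling may differ slightly:

```python
import math


def lr_lambda(step: int, warmup: int = 500, total_steps: int = 10_000) -> float:
    """Warmup-then-cosine LR multiplier with a 0.1 floor.

    Linear ramp 0 -> 1 over `warmup` steps, then cosine decay toward the
    floor over the remaining steps.
    """
    if step < warmup:
        return step / max(1, warmup)  # linear warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
```

At step 0 the multiplier is 0, it reaches 1.0 exactly when warmup ends, and it never drops below 0.1 of the base learning rate.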
scripts/train_multiseed.py
CHANGED
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 Multi-seed training wrapper for LexiMind.
 
@@ -53,7 +52,8 @@ def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
     history_path = seed_dir / "training_history.json"
     if history_path.exists():
         with open(history_path) as f:
-
+            data: Dict = json.load(f)  # type: ignore[no-any-return]
+            return data
     return {}
 
 
@@ -88,7 +88,8 @@ def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = None
 
     if output.exists():
         with open(output) as f:
-
+            data: Dict = json.load(f)  # type: ignore[no-any-return]
+            return data
     return {}
 
 
@@ -99,7 +100,7 @@ def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
 
     # Collect all metric paths
     metric_values: Dict[str, List[float]] = {}
-    for
+    for _seed, results in all_results.items():
         for task, task_metrics in results.items():
             if not isinstance(task_metrics, dict):
                 continue
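`aggregate_results` collects metric values across seeds keyed by task and metric name. A self-contained sketch of that aggregation reporting mean and standard deviation per metric path (the helper name and output shape here are illustrative, not the script's exact API):

```python
import statistics


def aggregate(all_results: dict[int, dict[str, dict[str, float]]]) -> dict[str, dict[str, float]]:
    """Collect every task/metric pair across seeds and report mean/std."""
    metric_values: dict[str, list[float]] = {}
    for _seed, results in all_results.items():
        for task, task_metrics in results.items():
            if not isinstance(task_metrics, dict):
                continue
            for name, value in task_metrics.items():
                if isinstance(value, (int, float)):
                    metric_values.setdefault(f"{task}/{name}", []).append(float(value))
    return {
        path: {
            "mean": statistics.mean(vals),
            "std": statistics.stdev(vals) if len(vals) > 1 else 0.0,  # sample std needs >= 2 seeds
        }
        for path, vals in metric_values.items()
    }
```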
scripts/visualize_training.py
CHANGED
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 LexiMind Training Visualization Suite.
 
@@ -63,16 +62,14 @@
     pass
 
 try:
-    from mpl_toolkits.mplot3d import Axes3D  # type: ignore[import-
+    from mpl_toolkits.mplot3d import Axes3D  # type: ignore[import-not-found]  # noqa: F401
 
     HAS_MPLOT3D = True
 except ImportError:
     pass
 
 
-# =============================================================================
 # Configuration
-# =============================================================================
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)
 
@@ -116,10 +113,7 @@ HEATMAP_CMAP = LinearSegmentedColormap.from_list(
 )
 
 
-# =============================================================================
 # MLflow Utilities
-# =============================================================================
-
 
 def get_mlflow_client():
     """Get MLflow client with correct tracking URI."""
 
@@ -157,10 +151,7 @@ def get_metric_history(run, metric_name: str) -> tuple[list, list]:
     return [m.step for m in metrics], [m.value for m in metrics]
 
 
-# =============================================================================
 # Core Training Visualizations
-# =============================================================================
-
 
 def plot_loss_curves(run, interactive: bool = False) -> None:
     """
 
@@ -208,7 +199,7 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
 
         output_path = OUTPUTS_DIR / "training_loss_curve.html"
         fig.write_html(str(output_path))
-        logger.info(f"
+        logger.info(f"Saved interactive loss curve to {output_path}")
         return
 
     # Static matplotlib version
 
@@ -253,7 +244,7 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "training_loss_curve.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved loss curve to {output_path}")
     plt.close()
 
 
@@ -387,7 +378,7 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "task_metrics.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved task metrics to {output_path}")
     plt.close()
 
 
@@ -474,14 +465,11 @@ def plot_learning_rate(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "learning_rate_schedule.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved LR schedule to {output_path}")
     plt.close()
 
 
-# =============================================================================
 # Advanced Visualizations
-# =============================================================================
-
 
 def plot_confusion_matrix(run, task: str = "topic") -> None:
     """
 
@@ -544,7 +532,7 @@ def plot_confusion_matrix(run, task: str = "topic") -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / f"confusion_matrix_{task}.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved confusion matrix to {output_path}")
     plt.close()
 
 
@@ -646,7 +634,7 @@ def plot_3d_loss_landscape(run) -> None:
 
     output_path = OUTPUTS_DIR / "loss_landscape_3d.html"
     fig.write_html(str(output_path))
-    logger.info(f"
+    logger.info(f"Saved 3D loss landscape to {output_path}")
 
 
 def plot_3d_loss_landscape_static(run) -> None:
 
@@ -702,7 +690,7 @@ def plot_3d_loss_landscape_static(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "loss_landscape_3d.png"
    plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved 3D loss landscape to {output_path}")
     plt.close()
 
 
@@ -770,7 +758,7 @@ def plot_embedding_space(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "embedding_space.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved embedding visualization to {output_path}")
     plt.close()
 
 
@@ -868,14 +856,11 @@ def plot_training_dynamics(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "training_dynamics.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved training dynamics to {output_path}")
     plt.close()
 
 
-# =============================================================================
 # Dashboard Generator
-# =============================================================================
-
 
 def generate_dashboard(run) -> None:
     """
 
@@ -959,13 +944,10 @@ def generate_dashboard(run) -> None:
 
     output_path = OUTPUTS_DIR / "training_dashboard.html"
     fig.write_html(str(output_path))
-    logger.info(f"
+    logger.info(f"Saved interactive dashboard to {output_path}")
 
 
-# =============================================================================
 # Main Entry Point
-# =============================================================================
-
 
 def main():
     """Generate all training visualizations."""
 
@@ -1026,7 +1008,7 @@ def main():
     # Summary
     logger.info("")
     logger.info("=" * 60)
-    logger.info("
+    logger.info("All visualizations saved to outputs/")
     logger.info("=" * 60)
 
     outputs = [
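The loss-curve plots above render raw per-step values; when curves are noisy, TensorBoard-style exponential smoothing is a common companion. A minimal sketch (the repo's plots do not necessarily apply any smoothing):

```python
def ema_smooth(values: list[float], alpha: float = 0.9) -> list[float]:
    """Exponential moving average of a metric series.

    alpha is the smoothing weight on the running average: higher alpha
    means a smoother (laggier) curve, matching TensorBoard's slider.
    """
    smoothed: list[float] = []
    avg = values[0]  # seed with the first value to avoid startup bias toward 0
    for v in values:
        avg = alpha * avg + (1 - alpha) * v
        smoothed.append(avg)
    return smoothed
```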
src/data/dataset.py
CHANGED
@@ -24,7 +24,6 @@ from torch.utils.data import Dataset
 @dataclass
 class SummarizationExample:
     """Container for abstractive summarization samples."""
-
     source: str
     summary: str
 
@@ -32,7 +31,6 @@ class SummarizationExample:
 @dataclass
 class EmotionExample:
     """Container for multi-label emotion classification samples."""
-
     text: str
     emotions: Sequence[str]
 
@@ -40,14 +38,12 @@ class EmotionExample:
 @dataclass
 class TopicExample:
     """Container for topic clustering / classification samples."""
-
     text: str
     topic: str
 
 
 class SummarizationDataset(Dataset[SummarizationExample]):
     """Dataset yielding encoder-decoder training pairs."""
-
     def __init__(self, examples: Iterable[SummarizationExample]) -> None:
         self._examples = list(examples)
 
@@ -60,7 +56,6 @@ class SummarizationDataset(Dataset[SummarizationExample]):
 
 class EmotionDataset(Dataset[EmotionExample]):
     """Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
-
     def __init__(
         self,
         examples: Iterable[EmotionExample],
 
@@ -96,7 +91,6 @@ class EmotionDataset(Dataset[EmotionExample]):
 
 class TopicDataset(Dataset[TopicExample]):
     """Dataset that owns a LabelEncoder for topic ids."""
-
     def __init__(
         self,
         examples: Iterable[TopicExample],
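`EmotionDataset` owns a scikit-learn `MultiLabelBinarizer` to turn per-example emotion lists into fixed-order multi-hot vectors. Its core behavior can be sketched without sklearn (the function name is illustrative):

```python
from typing import Sequence


def binarize(batch: Sequence[Sequence[str]], classes: Sequence[str]) -> list[list[int]]:
    """Map each example's label list to a multi-hot row over `classes`.

    Column order is fixed by `classes`, so every row has the same width
    regardless of how many labels an example carries.
    """
    index = {label: i for i, label in enumerate(classes)}
    out = []
    for labels in batch:
        row = [0] * len(classes)
        for label in labels:
            row[index[label]] = 1
        out.append(row)
    return out
```

These rows are exactly what `BCEWithLogitsLoss` expects as multi-label targets after a `.float()` cast.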
src/models/factory.py
CHANGED
@@ -102,7 +102,7 @@ def _load_pretrained_weights(
     Load pretrained T5/FLAN-T5 weights into custom encoder/decoder.
 
     T5 architecture compatibility with our custom Transformer:
-    - T5 uses Pre-LN (RMSNorm before sublayers)
+    - T5 uses Pre-LN (RMSNorm before sublayers) - matches our design
     - T5 uses relative position bias instead of absolute embeddings
       -> We now load T5's relative position bias weights into our T5RelativePositionBias modules
       -> This allows exact weight transfer without requiring fine-tuning
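The relative position bias mentioned here buckets signed query-key offsets before looking up learned biases. A pure-Python sketch of T5's bidirectional bucketing scheme, following the published T5 design (for intuition only; the repo's `T5RelativePositionBias` module is the authoritative implementation):

```python
import math


def relative_bucket(relative_position: int, num_buckets: int = 32, max_distance: int = 128) -> int:
    """Bucket a signed key-minus-query offset, T5-style (bidirectional).

    Half the buckets encode sign, small offsets get exact buckets, and
    larger offsets share logarithmically sized buckets up to max_distance.
    """
    num_buckets //= 2
    bucket = num_buckets if relative_position > 0 else 0  # sign half
    n = abs(relative_position)
    max_exact = num_buckets // 2
    if n < max_exact:
        return bucket + n  # small offsets: one bucket each
    log_bucket = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return bucket + min(log_bucket, num_buckets - 1)
```

Because buckets, not raw offsets, index the bias table, the same learned weights generalize to sequence lengths beyond those seen in pretraining.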
src/training/metrics.py
CHANGED
@@ -90,7 +90,7 @@ def calculate_bertscore(
         return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
 
     try:
-        from bert_score import score as bert_score
+        from bert_score import score as bert_score  # type: ignore[import-not-found]
     except ImportError:
         print("Warning: bert-score not installed. Run: pip install bert-score")
         return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
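`calculate_bertscore` degrades gracefully to zeroed metrics when `bert_score` is not importable. The guard pattern generalizes to any optional scoring backend (the module name and placeholder values below are illustrative):

```python
import importlib


def optional_metric(module_name: str) -> dict[str, float]:
    """Return zeroed metrics when an optional scoring backend is missing."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        # Keep the same schema so downstream aggregation never branches.
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    # Placeholder: the real scoring call would go here.
    return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
```

Returning a fixed schema on failure keeps evaluation reports structurally identical whether or not the optional dependency is present.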
src/training/trainer.py
CHANGED
@@ -59,7 +59,7 @@ class TrainerConfig:
     # Task sampling strategy: "round_robin" or "temperature"
     # Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
     # alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
-    task_sampling: str = "
+    task_sampling: str = "temperature"
     task_sampling_alpha: float = 0.5
 
     # Gradient conflict diagnostics
 
@@ -180,8 +180,8 @@ class Trainer:
             if self.early_stopping:
                 val_loss = val_metrics.get("total_loss", float('inf'))
                 if self.early_stopping(val_loss):
-                    tqdm.write(f"\
-
+                    tqdm.write(f"\nEarly stopping at epoch {epoch} (best loss: {self.early_stopping.best_value:.4f})")
+
                     break
 
             # Checkpoint
 
@@ -194,7 +194,7 @@ class Trainer:
             pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})
 
         total_time = time.perf_counter() - total_start
-        print(f"\
+        print(f"\nTraining complete in {total_time/60:.1f} minutes")
         return history
 
     def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
 
@@ -214,7 +214,7 @@ class Trainer:
             return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
 
         self.scheduler = LambdaLR(self.optimizer, lr_lambda)
-        print(f"
+        print(f" LR schedule: cosine, {warmup} warmup, {total_steps} total steps")
 
     def _run_epoch(
         self,
|