Merge pull request #1 from Advancement

Advancement: Attention Pooling, Temperature Sampling, Comprehensive Evaluation & Code Cleanup
- README.md +6 -5
- configs/training/dev.yaml +3 -0
- configs/training/full.yaml +9 -2
- configs/training/medium.yaml +3 -0
- docs/architecture.md +368 -61
- docs/research_paper.tex +146 -90
- outputs/evaluation_report.json +250 -13
- outputs/training_history.json +180 -154
- pyproject.toml +6 -2
- scripts/build_discovery_dataset.py +6 -8
- scripts/demo_gradio.py +46 -129
- scripts/download_data.py +25 -29
- scripts/evaluate.py +205 -58
- scripts/profile_training.py +314 -0
- scripts/train.py +12 -10
- scripts/train_multiseed.py +198 -0
- scripts/visualize_training.py +12 -30
- src/data/dataset.py +81 -7
- src/models/factory.py +5 -3
- src/models/heads.py +39 -13
- src/training/metrics.py +214 -4
- src/training/trainer.py +98 -6
README.md CHANGED

````diff
@@ -8,6 +8,7 @@ app_file: scripts/demo_gradio.py
 pinned: false
 ---
 
+<!-- markdownlint-disable MD025 -->
 # LexiMind
 
 A multi-task NLP system for literary and academic text understanding. LexiMind performs **abstractive summarization**, **topic classification**, and **emotion detection** using a single encoder-decoder transformer initialized from [FLAN-T5-base](https://huggingface.co/google/flan-t5-base) (272M parameters).
@@ -17,7 +18,7 @@ A multi-task NLP system for literary and academic text understanding. LexiMind p
 ## What It Does
 
 | Task | Description | Metric |
-|------|-------------|--------|
+| ------ | ------------- | -------- |
 | **Summarization** | Generates back-cover style book descriptions and paper abstracts from source text | BERTScore F1: **0.830** |
 | **Topic Classification** | Classifies passages into 7 categories | Accuracy: **85.2%** |
 | **Emotion Detection** | Identifies emotions from 28 fine-grained labels (multi-label) | Sample-avg F1: **0.199** |
@@ -31,7 +32,7 @@ The model is trained on literary text (Project Gutenberg + Goodreads description
 LexiMind is a **custom Transformer implementation** that loads pre-trained weights from FLAN-T5-base via a factory module. The architecture is reimplemented from scratch for transparency, not wrapped from HuggingFace.
 
 | Component | Detail |
-|-----------|--------|
+| ----------- | -------- |
 | Backbone | Encoder-Decoder Transformer (272M params) |
 | Encoder / Decoder | 12 layers each |
 | Hidden Dim | 768, 12 attention heads |
@@ -48,7 +49,7 @@ All three tasks share the encoder. Summarization uses the full encoder-decoder;
 ## Training Data
 
 | Task | Source | Train Samples |
-|------|--------|---------------|
+| ------ | -------- | --------------- |
 | Summarization | Gutenberg + Goodreads (literary) | ~4K |
 | Summarization | arXiv body → abstract (academic) | ~45K |
 | Topic | 20 Newsgroups + Gutenberg + arXiv metadata | 3,402 |
@@ -131,7 +132,7 @@ docker run -p 7860:7860 leximind
 
 ## Project Structure
 
-```
+```text
 configs/
 ├── config.yaml             # Main Hydra config
 ├── data/datasets.yaml      # Dataset paths and tokenizer settings
@@ -208,4 +209,4 @@ GPL-3.0 — see [LICENSE](LICENSE) for details.
 
 ---
 
-
+Built by Oliver Perrin · Appalachian State University · 2025–2026
````
configs/training/dev.yaml CHANGED

```diff
@@ -37,6 +37,9 @@ trainer:
   max_val_samples: 300
   early_stopping_patience: 5
   log_grad_norm_frequency: 100
+  task_sampling: temperature
+  task_sampling_alpha: 0.5
+  gradient_conflict_frequency: 0
 
   # Enable compile for speed (worth the startup cost)
   compile_encoder: true
```
configs/training/full.yaml CHANGED

```diff
@@ -22,12 +22,12 @@ optimizer:
 
 scheduler:
   name: cosine
-  warmup_steps: 300  # ~0.
+  warmup_steps: 300  # ~0.24 epoch warmup (1227 optimizer steps/epoch)
 
 trainer:
   max_epochs: 8  # Reduced from 12 - early stopping will catch plateau anyway
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps: 4  # Reduced from 8
+  gradient_accumulation_steps: 4  # Reduced from 8, 2x faster optimizer steps
   validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
@@ -38,6 +38,13 @@ trainer:
   max_val_samples: 3000  # Enough for stable metrics
   early_stopping_patience: 3  # Stop quickly when plateauing
   log_grad_norm_frequency: 200
+  # Task sampling: "round_robin" or "temperature"
+  # Temperature sampling: p_i proportional to n_i^alpha, reduces dominance of large tasks
+  task_sampling: temperature
+  task_sampling_alpha: 0.5
+  # Gradient conflict diagnostics: compute inter-task gradient cosine similarity
+  # every N steps (0 = disabled). Helps diagnose negative transfer.
+  gradient_conflict_frequency: 0
 
   compile_encoder: true
   compile_decoder: true
```
configs/training/medium.yaml CHANGED

```diff
@@ -37,6 +37,9 @@ trainer:
   max_val_samples: 2500
   early_stopping_patience: 3  # More patience
   log_grad_norm_frequency: 100
+  task_sampling: temperature
+  task_sampling_alpha: 0.5
+  gradient_conflict_frequency: 0
 
   compile_encoder: true
   compile_decoder: true
```
docs/architecture.md CHANGED
## Overview

LexiMind is a **272M parameter encoder-decoder transformer** initialized from Google's FLAN-T5-base, trained jointly on three tasks: abstractive summarization, topic classification, and multi-label emotion detection. The project spans data preparation, custom model architecture, multi-task training, evaluation, and a Gradio-based discovery demo.

## Model Architecture

### Backbone: Custom Transformer (FLAN-T5-base Initialization)

The model is a from-scratch PyTorch implementation of a T5-style encoder-decoder. We do **not** use HuggingFace's `T5ForConditionalGeneration` — instead, every component (attention, FFN, normalization, positional encoding) is implemented manually in `src/models/`, then FLAN-T5 weights are loaded layer by layer in `src/models/factory.py`.

```text
Input Text
     │
     ▼
┌─────────────────────────────────────────────┐
│ Shared Encoder (12 layers, 768d, 12 heads)  │
│ ┌─────────────────────────────────────────┐ │
│ │ Layers 0-3:  FROZEN (FLAN-T5 weights)   │ │
│ │ Layers 4-11: TRAINABLE (fine-tuned)     │ │
│ └─────────────────────────────────────────┘ │
│ Pre-LN RMSNorm │ T5 Relative Position Bias  │
│ FlashAttention (SDPA) │ Gated-GELU FFN      │
└────────────┬──────────────┬──────────────┬──┘
             │              │              │
    ┌────────▼────────┐ ┌───▼────────┐ ┌───▼────────┐
    │ Decoder         │ │ Attention  │ │ Mean       │
    │ (12 layers)     │ │ Pooling    │ │ Pooling    │
    │ Causal + Cross  │ │ (learned)  │ │            │
    │ Attention       │ │            │ │            │
    │                 │ │ MLP 768→   │ │ Linear     │
    │ LM Head         │ │ 384→28     │ │ 768→7      │
    │ (tied weights)  │ │            │ │            │
    └────────┬────────┘ └─────┬──────┘ └──────┬─────┘
             │                │               │
      Summarization      Emotion (28)     Topic (7)
      (generative)      (multi-label)  (single-label)
```
### Encoder

**File**: `src/models/encoder.py` (317 lines)

- 12 transformer layers, 768-dimensional, 12 attention heads
- **Pre-Layer Normalization (Pre-LN)** using T5-style RMSNorm — normalization applied *before* each sublayer, not after. This is the modern standard (LLaMA, T5 v1.1+, PaLM).
- **T5 Relative Position Bias**: Bucketed log-linear position bias computed in the attention layer. Bidirectional (the encoder attends in both directions). Shared across layers (computed once, passed to all layers).
- **FlashAttention**: Via PyTorch 2.0's `F.scaled_dot_product_attention`, which automatically selects the optimal kernel (Flash, memory-efficient, or math fallback). **Note**: T5 does NOT scale attention scores by 1/√d_k — the `scale_scores=False` flag preserves this behavior.
- **Gated-GELU FFN**: Two linear projections (gate + up) element-wise multiplied, then a down projection. Matches T5's `DenseGatedGeluDense`.
- **Gradient checkpointing**: Optional per-layer activation recomputation to reduce VRAM (enabled in our full training config, saves ~2-3 GB).
- Bottom 4 layers are frozen during fine-tuning to preserve FLAN-T5's general language representations.

The encoder processes all input text and produces contextualized representations that are consumed by all three task heads.
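The bucketed log-linear position bias can be illustrated with a small torch-free sketch of the bucketing rule. Constants follow T5's published defaults (32 buckets, max distance 128); treat this as an illustration of the scheme, not the project's exact implementation.

```python
import math

# T5-style relative position bucketing (bidirectional): half the buckets cover
# exact small offsets, the rest are log-spaced out to max_distance.
def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    bucket = 0
    num_buckets //= 2
    if relative_position > 0:  # bidirectional: each sign gets its own half
        bucket += num_buckets
    n = abs(relative_position)
    max_exact = num_buckets // 2
    if n < max_exact:
        bucket += n            # small offsets: one bucket per offset
    else:                      # large offsets: logarithmically spaced buckets
        val = max_exact + int(
            math.log(n / max_exact) / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        )
        bucket += min(val, num_buckets - 1)
    return bucket

print(relative_position_bucket(1))     # 17 (positive half starts at 16)
print(relative_position_bucket(-100))  # 15 (log-spaced region, capped)
```

Because the bias depends only on the bucket index, it is computed once per sequence length and shared by every layer, which is what makes the shared-bias design cheap.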
### Decoder (Summarization Only)

**File**: `src/models/decoder.py` (749 lines)

- 12 transformer layers, 768-dimensional, 12 attention heads
- ~136M parameters — roughly half the total model
- **Masked self-attention** (causal mask prevents attending to future positions)
- **Cross-attention** to encoder outputs (allows the decoder to attend to the full input)
- **KV-cache** for efficient autoregressive generation — incremental key/value computation avoids recomputing previous positions
- **Greedy decoding** with:
  - No-repeat n-gram blocking (`no_repeat_ngram_size=3`)
  - Repetition penalty (1.2x)
  - Length penalty
  - Min/max length constraints
- **LM Head**: Linear projection from 768d → 32,128 vocab. **Weight-tied** with decoder token embeddings (reduces parameters and improves coherence).

The decoder is exclusive to summarization. Classification tasks only use the encoder.
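The no-repeat n-gram blocking step above fits in a few lines. This is an illustrative pure-Python version, not the project's decoder code: before choosing the next token, ban any token that would complete a trigram already present in the generated sequence.

```python
# No-repeat n-gram blocking (n=3): find every token that would complete a
# trigram whose first two tokens match the last two generated tokens.
def banned_next_tokens(generated, n=3):
    if len(generated) < n - 1:
        return set()
    prefix = tuple(generated[-(n - 1):])      # last n-1 tokens
    banned = set()
    for i in range(len(generated) - n + 1):
        if tuple(generated[i:i + n - 1]) == prefix:
            banned.add(generated[i + n - 1])  # token that completed this trigram
    return banned

print(banned_next_tokens([5, 8, 9, 5, 8]))  # {9}: emitting 9 would repeat (5, 8, 9)
```

During greedy decoding the banned set is applied by setting those logits to -inf before the argmax.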
### Task Heads

**File**: `src/models/heads.py` (221 lines)

#### Emotion Head (Attention Pooling + MLP)

- **AttentionPooling**: A single linear layer (`nn.Linear(768, 1, bias=False)`) serves as a learned query. It computes softmax attention weights over all encoder positions, producing a weighted sum. This allows the model to focus on emotionally salient tokens (e.g., "grateful", "hilarious") rather than averaging the entire 512-token sequence. Padding is masked before softmax.
- **2-layer MLP**: 768 → 384 (GELU) → 28. The hidden layer provides a nonlinear feature transformation before the 28-way multi-label output.
- **Loss**: BCEWithLogitsLoss (binary cross-entropy per class)
- **Inference threshold**: 0.3 (lowered from the default 0.5 because 28-class multi-label predictions have lower per-class confidence)
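The pooling mechanics reduce to "score each position with a learned query, softmax over unmasked positions, take the weighted sum". A torch-free sketch with toy dimensions (the real head applies `nn.Linear(768, 1, bias=False)` to 768-d states):

```python
import math

# Attention pooling over a sequence of hidden states.
# query: learned weight vector (stand-in for the nn.Linear(d, 1) query);
# mask: 1 for real tokens, 0 for padding (masked positions get -inf scores).
def attention_pool(hidden_states, query, mask):
    scores = [
        sum(h * w for h, w in zip(state, query)) if m else float("-inf")
        for state, m in zip(hidden_states, mask)
    ]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]  # numerically stable softmax
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(hidden_states[0])
    return [sum(w * s[i] for w, s in zip(weights, hidden_states)) for i in range(dim)]

states = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]   # last position is padding
pooled = attention_pool(states, query=[1.0, 1.0], mask=[1, 1, 0])
print(pooled)  # [0.5, 0.5]: equal scores on the two real tokens, padding ignored
```

Unlike mean pooling, the weights here are input-dependent, so a single salient token can dominate the pooled vector.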
#### Topic Head (Mean Pooling + Linear)

- **Mean pooling** over encoder positions (attention-mask-aware)
- **Single linear layer**: 768 → 7
- **Loss**: CrossEntropyLoss
- **Task weight**: 0.3 (reduced to prevent overfitting on the small 3.4K dataset)

#### Summarization Head (Decoder + LM Head)

- Full decoder (described above) + weight-tied LM head
- **Loss**: CrossEntropyLoss with label smoothing (0.1) and `-100` ignore index for padding
- **Task weight**: 1.0

### Multi-Task Router

**File**: `src/models/multitask.py` (263 lines)

The `MultiTaskModel` class routes `forward(task, inputs)` calls to the correct head:

- **Classification** (`emotion`, `topic`): encoder → pool → classify
- **Generation** (`summarization`): encoder → decoder → LM head

A `memory.clone()` call between encoder and decoder output prevents CUDA Graph buffer reuse issues when using `torch.compile`.
### Weight Loading from FLAN-T5

**File**: `src/models/factory.py` (571 lines)

Weights are transferred from HuggingFace's `google/flan-t5-base` checkpoint layer by layer:

| FLAN-T5 Component | Our Component |
| --- | --- |
| `shared.weight` | `encoder.embed_tokens.weight` and `decoder.embed_tokens.weight` |
| `encoder.block.{i}.layer.0.SelfAttention.{q,k,v,o}` | `encoder.layers.{i}.self_attn.{q,k,v,out}_proj.weight` |
| `encoder.block.{i}.layer.1.DenseReluDense.wi_0/wi_1/wo` | `encoder.layers.{i}.ffn.gate/up_proj/down_proj.weight` |
| `encoder.block.{i}.layer.{0,1}.layer_norm.weight` | `encoder.layers.{i}.norm{1,2}.weight` |
| `encoder.block.0.layer.0.SelfAttention.relative_attention_bias` | `encoder.layers.0.self_attn.attn.position_bias.relative_attention_bias` |
| `lm_head.weight` | `summarization_head.projection.weight` |

Vocab size mismatch (T5: 32,100 → ours: 32,128) is handled by zero-padding the embedding matrix.
### Available but Unused Components

These are implemented but not activated in the current configuration:

- **LoRA adapters** on Q and V projections in `MultiHeadAttention` — for parameter-efficient fine-tuning
- **Rotary Position Embeddings (RoPE)** — alternative to T5's relative position bias
- **4-bit/8-bit quantization** via bitsandbytes — for inference on constrained hardware
- **TokenClassificationHead** — for NER/POS tasks
- **ProjectionHead** — for contrastive/representation learning
- **LLaMA weight loading** — `_load_llama_weights()` for loading Gemma/LLaMA checkpoints

## Tokenization

**File**: `src/data/tokenization.py` (157 lines)

Wraps HuggingFace's `AutoTokenizer` configured for FLAN-T5:

- **SentencePiece** (Unigram) tokenizer, 32,128 vocabulary
- Special tokens: `pad=0`, `eos=1`, no explicit BOS (the decoder starts with the pad token, per T5 convention)
- Max sequence length: 512 tokens (encoder), 128 tokens (decoder during validation generation)
- Classification tasks use a reduced max length of 256 tokens (sufficient for classification, saves compute)
## Datasets

**Files**: `src/data/dataset.py` (316 lines), `src/data/dataloader.py` (174 lines)

| Task | Dataset Source | Train Size | Val Size | Test Size |
| ------ | --------------- | ----------- | --------- | ---------- |
| Summarization | arXiv abstracts (~45K) + Goodreads book descriptions (~4K) | ~49K | ~2.7K | ~2.7K |
| Emotion | GoEmotions (Reddit comments, 28 labels) | ~43K | ~5.4K | — |
| Topic | arXiv categories + Gutenberg subjects → 7 classes | ~3.2K | ~189 | — |

**Cross-task deduplication**: `deduplicate_across_tasks()` uses MD5 fingerprinting on normalized text prefixes (200 chars) to detect and remove overlapping documents between the summarization and topic datasets (both draw from arXiv and Gutenberg).
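The fingerprinting idea can be sketched with stdlib tools. The helper names below are hypothetical; the real logic lives in `deduplicate_across_tasks()` in `src/data/dataset.py`.

```python
import hashlib

# Fingerprint = MD5 of a normalized (lowercased, whitespace-collapsed) 200-char
# prefix, so near-identical copies of the same document hash to the same key.
def fingerprint(text, prefix_len=200):
    normalized = " ".join(text.lower().split())[:prefix_len]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

# Keep the primary task's documents; drop secondary docs whose fingerprint
# already appeared in the primary set.
def drop_overlaps(primary_docs, secondary_docs):
    seen = {fingerprint(doc) for doc in primary_docs}
    return [doc for doc in secondary_docs if fingerprint(doc) not in seen]

kept = drop_overlaps(
    ["An arXiv  paper about GNNs."],
    ["an arxiv paper about gnns.", "A novel."],
)
print(kept)  # ['A novel.'] — the arXiv duplicate was removed despite case/spacing
```

Hashing only a prefix keeps the pass cheap on long Gutenberg chunks while still catching documents that share an opening.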
**Data pipeline**: Each task has a typed `Dataset` class and a corresponding `Collator` that handles tokenization, padding, and label preparation. Collators are passed to PyTorch `DataLoader` instances created by factory functions (`build_*_dataloader`).
## Training

**File**: `src/training/trainer.py` (527 lines)

### Training Loop

Each epoch iterates through batches using **temperature-based task sampling**:

1. **Sample a task** with probability p_i proportional to n_i^0.5, where n_i is the dataset size
   - Summarization (~49K): ~45% of steps
   - Emotion (~43K): ~43% of steps
   - Topic (~3.4K): ~12% of steps
2. **Forward pass** under `torch.autocast(dtype=bfloat16)` mixed precision
3. **Compute the task-specific loss** with its task weight (summ=1.0, emotion=1.0, topic=0.3)
4. **Backward pass** and accumulate gradients (4 accumulation steps → effective batch size 40)
5. **Optimizer step** every 4 batches: clip gradients (max norm 1.0), AdamW step, cosine LR step
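Step 1's sampling distribution is easy to verify numerically; a minimal sketch using the approximate dataset sizes quoted above:

```python
# Temperature-based task sampling: p_i ∝ n_i**alpha with alpha = 0.5.
# alpha = 1.0 would reproduce size-proportional sampling; alpha = 0.0 is uniform.
sizes = {"summarization": 49_000, "emotion": 43_000, "topic": 3_400}
alpha = 0.5

weights = {task: n ** alpha for task, n in sizes.items()}
total = sum(weights.values())
probs = {task: w / total for task, w in weights.items()}

for task, p in probs.items():
    print(f"{task}: {p:.0%}")
# summarization: 45%, emotion: 43%, topic: 12%
```

With size-proportional sampling, topic would receive only ~3.6% of steps; the square root lifts it to ~12%, which is how temperature sampling reduces the dominance of the large tasks.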
### Training Configuration (full.yaml)

| Parameter | Value | Rationale |
| ----------- | ------- | ----------- |
| Batch size | 10 | Fits ~10GB VRAM on RTX 4070 12GB |
| Gradient accumulation | 4 | Effective batch size 40 |
| Learning rate | 3e-5 | Standard for fine-tuning T5 |
| Weight decay | 0.01 | Standard AdamW regularization |
| Warmup steps | 300 | ~0.24 epochs of linear warmup (1227 optimizer steps/epoch) |
| Max epochs | 8 | Val loss still improving at epoch 8 |
| LR schedule | Cosine | Decays to 0.1x base LR, flattens near step 8000 |
| Early stopping | Patience 3 | Never triggered (val loss still improving at epoch 8) |
| Label smoothing | 0.1 | Summarization cross-entropy only |
| Task weights | summ=1.0, emot=1.0, topic=0.3 | Reduced topic weight to prevent overfitting |
| Task sampling | Temperature (alpha=0.5) | Square-root proportional sampling |
| Frozen encoder layers | 0-3 | Preserves FLAN-T5's general language knowledge |
| Gradient checkpointing | Enabled | Saves ~2-3 GB VRAM |
| torch.compile | Both encoder and decoder | ~20-40% speedup via Inductor backend |
### Mixed Precision

The RTX 4070 (Ada Lovelace, compute capability 8.9) has dedicated BF16 tensor cores:

- All forward/backward passes run under `torch.autocast("cuda", dtype=torch.bfloat16)`
- BF16 has the same exponent range as FP32 (8 bits), so no GradScaler is needed (unlike FP16)
- Loss computation and softmax remain in FP32 (handled automatically by autocast)
- Encoder/decoder layers include `clamp(min=-65504, max=65504)` stability guards (carried over from HuggingFace T5)

### Optimizer

- **Fused AdamW**: CUDA-native fused kernel (`torch.optim.AdamW(fused=True)`), ~5-10% faster than standard AdamW
- Betas: (0.9, 0.98) — slightly faster momentum decay than the default
- Epsilon: 1e-6
### Gradient Conflict Diagnostics (Available, Disabled)

The trainer includes `_compute_gradient_conflicts()`, which:

1. Computes per-task gradients independently
2. Flattens all parameter gradients into a single vector per task
3. Computes pairwise cosine similarity between task gradient vectors
4. Logs cosine similarity and binary conflict flags to MLflow

This is a **diagnostic only** — it does not modify gradients (unlike PCGrad/CAGrad). Disabled by default (`gradient_conflict_frequency: 0`) because it requires extra backward passes per measurement.
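The measure in steps 2-3 is plain cosine similarity between flattened gradient vectors. A torch-free sketch (the trainer's version operates on model parameter grads):

```python
import math

# Cosine similarity between two flattened per-task gradient vectors.
# A negative value flags a gradient conflict: the tasks pull the shared
# parameters in opposing directions (potential negative transfer).
def grad_cosine(g1, g2):
    dot = sum(a * b for a, b in zip(g1, g2))
    norm1 = math.sqrt(sum(a * a for a in g1))
    norm2 = math.sqrt(sum(b * b for b in g2))
    return dot / (norm1 * norm2)

sim = grad_cosine([1.0, 0.0, 1.0], [1.0, 0.0, -1.0])
conflict = sim < 0.0
print(sim, conflict)  # 0.0 False — orthogonal gradients, no conflict flagged
```

Logging this pairwise for (summarization, emotion, topic) gives three curves; persistently negative values would be the cue to try a conflict-resolving method such as PCGrad.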
### MLflow Tracking

Training metrics (losses, accuracy, F1, ROUGE, learning rate) are logged to MLflow with a SQLite backend (`mlruns.db`). This enables experiment comparison across training runs.

## Evaluation

**Files**: `scripts/evaluate.py` (538 lines), `src/training/metrics.py` (452 lines)

### Metrics

| Task | Metrics |
| ------ | --------- |
| Summarization | ROUGE-1, ROUGE-2, ROUGE-L (`rouge-score` library), BLEU-4 (NLTK), optional BERTScore |
| Emotion | Sample-averaged F1, macro F1, micro F1, per-class P/R/F1, per-class threshold tuning |
| Topic | Accuracy, macro F1, per-class P/R/F1, confusion matrix |
| All | Bootstrap 95% confidence intervals (1000 resamples), paired bootstrap test |
### Per-Class Threshold Tuning (Emotion)

For multi-label classification, different emotion classes have very different base rates and prediction confidence. The tuning procedure:

1. For each of the 28 emotion classes independently,
2. sweep the threshold tau over {0.1, 0.2, ..., 0.9},
3. select the threshold that maximizes per-class F1 on the validation set, and
4. re-compute all metrics with the tuned thresholds.

This improved macro F1 from 0.143 (default 0.5 threshold) to 0.294.
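Per class, the sweep amounts to picking the argmax-F1 threshold over nine candidates. A sketch for a single toy class (the real tuning runs over validation probabilities for all 28 classes):

```python
# Binary F1 for one class given true labels and thresholded predictions.
def f1(y_true, y_pred):
    tp = sum(t and p for t, p in zip(y_true, y_pred))
    fp = sum((not t) and p for t, p in zip(y_true, y_pred))
    fn = sum(t and (not p) for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

# Sweep tau over {0.1, ..., 0.9}; keep the F1-maximizing threshold.
def tune_threshold(y_true, probs):
    candidates = [round(0.1 * k, 1) for k in range(1, 10)]
    return max(candidates, key=lambda tau: f1(y_true, [p >= tau for p in probs]))

best = tune_threshold([1, 1, 0, 0], [0.35, 0.4, 0.2, 0.6])
print(best)  # 0.3 — catches both positives while excluding the 0.2 negative
```

Because the threshold is chosen per class, rare emotions with systematically low sigmoid outputs can still be recovered, which is where most of the 0.143 → 0.294 macro-F1 gain comes from.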
### BERTScore

Available via the `--include-bertscore` flag in evaluation (opt-in). Uses `roberta-large` for semantic similarity. Not included in the primary evaluation due to computational cost and the difficulty of interpreting absolute values.

## Inference

**Files**: `src/inference/pipeline.py` (217 lines), `src/inference/factory.py` (91 lines)

`InferencePipeline` loads a trained checkpoint and runs all three tasks:

- **Summarization**: Greedy decode with KV-cache, no-repeat trigram blocking, repetition penalty 1.2
- **Emotion**: Sigmoid probabilities → threshold at 0.3 → emit labels above threshold
- **Topic**: Softmax → argmax → emit top label with confidence score

`create_inference_pipeline()` reconstructs the full pipeline from checkpoint + labels + tokenizer artifacts.

## Serving

### Gradio Demo

**File**: `scripts/demo_gradio.py` (507 lines)

A discovery interface for browsing pre-analyzed books and papers. Loads a pre-computed discovery dataset (not live inference) from HuggingFace Hub (`OliverPerrin/LexiMind-Discovery`). Users can browse by topic, emotion, or keyword search.

### FastAPI

**Files**: `src/api/app.py` (18 lines), `src/api/routes.py` (49 lines)

Minimal REST API with a single `/summarize` endpoint that runs all three tasks and returns JSON results. Uses dependency injection for the inference pipeline.

### CLI

**File**: `scripts/inference.py` (108 lines)

A command-line interface that accepts text from arguments or a file, runs batch prediction, and prints JSON output.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
+
### Profiling
|
| 288 |
|
| 289 |
+
**File**: `scripts/profile_training.py`
|
| 290 |
|
| 291 |
+
Wraps a few training steps with `torch.profiler` to capture:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
- CUDA kernel timing (per-operator breakdown)
|
| 294 |
+
- GPU memory usage (peak allocations)
|
| 295 |
+
- CPU/GPU overlap and idle time
|
| 296 |
+
- Chrome trace (viewable in `chrome://tracing` or [Perfetto UI](https://ui.perfetto.dev))
|
| 297 |
+
- CUDA stacks for flamegraph generation
|
| 298 |
|
| 299 |
+
```bash
|
| 300 |
+
python scripts/profile_training.py # 20 steps by default
|
| 301 |
+
PROFILE_STEPS=40 python scripts/profile_training.py # custom step count
|
| 302 |
+
```
|
| 303 |
|
| 304 |
+
Outputs go to `outputs/profile/` — TensorBoard traces, Chrome trace JSON, and stack files.
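The wrapper boils down to something like the following (a simplified sketch; the actual script also configures a profiling schedule, TensorBoard export, and stack dumps):

```python
import torch
from torch.profiler import ProfilerActivity, profile

def profile_steps(step_fn, n_steps: int = 20):
    """Run a few training steps under torch.profiler, as profile_training.py does.

    `step_fn` stands in for one forward/backward/optimizer step; CUDA
    activity tracking is only enabled when a GPU is present.
    """
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, profile_memory=True) as prof:
        for _ in range(n_steps):
            step_fn()
    # Per-operator timing/memory breakdown; prof.export_chrome_trace("trace.json")
    # would produce the Chrome/Perfetto trace mentioned above.
    return prof.key_averages()
```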
|
| 305 |
|
| 306 |
+
## Training Results (8 Epochs, RTX 4070, ~9 Hours)
|
| 307 |
|
| 308 |
+
| Epoch | Train Loss | Val Loss | Summ Val Loss | Emotion Val F1 | Topic Val Acc |
|
| 309 |
+
| ------- | ----------- | --------- | --------------- | ---------------- | --------------- |
|
| 310 |
+
| 1 | 6.106 | 4.298 | 3.815 | 0.197 | 70.4% |
|
| 311 |
+
| 2 | 5.528 | 4.027 | 3.739 | 0.301 | 84.7% |
|
| 312 |
+
| 3 | 5.379 | 3.973 | 3.700 | 0.347 | 84.2% |
|
| 313 |
+
| 4 | 5.303 | 3.951 | 3.677 | 0.404 | 85.7% |
|
| 314 |
+
| 5 | 5.208 | 3.940 | 3.665 | 0.431 | 86.3% |
|
| 315 |
+
| 6 | 5.231 | 3.925 | 3.658 | 0.452 | 87.3% |
|
| 316 |
+
| 7 | 5.154 | 3.928 | 3.655 | 0.458 | 85.7% |
|
| 317 |
+
| 8 | 5.178 | 3.925 | 3.653 | 0.459 | 85.7% |
|
| 318 |
|
| 319 |
+
Key observations:
|
| 320 |
|
| 321 |
+
- Early stopping never triggered (val loss improved through all 8 epochs, with only a negligible uptick at epoch 7: 3.925 → 3.928 → 3.925)
|
| 322 |
+
- Topic val accuracy plateaued at epoch 2 (~85%), while topic train accuracy reached 98% — overfitting expected on 3.4K samples
|
| 323 |
+
- Emotion F1 improved steadily across all 8 epochs (0.197 → 0.459), showing attention pooling continues learning throughout
|
| 324 |
+
- Summarization loss plateaued after epoch 5 (~3.66)
|
| 325 |
+
- Train loss was lowest at epoch 7 (5.154) and slightly higher at epoch 8 (5.178), which is normal epoch-to-epoch variance
|
| 326 |
+
- The cosine LR schedule flattens near step 8000 as it reaches its 0.1x floor
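The schedule in that last observation can be reproduced in a few lines (base LR, warmup length, and total steps here are illustrative, not the exact training config):

```python
import math

def cosine_lr_with_floor(step: int, total_steps: int = 8000,
                         base_lr: float = 1e-4, floor_ratio: float = 0.1,
                         warmup: int = 200) -> float:
    """Linear warmup, then cosine decay that bottoms out at 0.1x the base LR."""
    if step < warmup:
        return base_lr * step / warmup
    progress = min(1.0, (step - warmup) / (total_steps - warmup))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0 over training
    return base_lr * (floor_ratio + (1.0 - floor_ratio) * cosine)
```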
|
| 327 |
|
| 328 |
+
## Final Evaluation Results
|
| 329 |
|
| 330 |
+
| Task | Metric | Value | 95% CI |
|
| 331 |
+
| ------ | -------- | ------- | -------- |
|
| 332 |
+
| Summarization | ROUGE-1 | 0.310 | [0.306, 0.313] |
|
| 333 |
+
| Summarization | ROUGE-2 | 0.091 | — |
|
| 334 |
+
| Summarization | ROUGE-L | 0.185 | — |
|
| 335 |
+
| Summarization | BLEU-4 | 0.024 | — |
|
| 336 |
+
| Emotion | Sample F1 | 0.352 | [0.340, 0.366] |
|
| 337 |
+
| Emotion | Macro F1 | 0.143 | — |
|
| 338 |
+
| Emotion | Micro F1 | 0.443 | — |
|
| 339 |
+
| Emotion (tuned) | Macro F1 | 0.294 | — |
|
| 340 |
+
| Emotion (tuned) | Sample F1 | 0.503 | — |
|
| 341 |
+
| Topic | Accuracy | 85.7% | [80.4%, 91.0%] |
|
| 342 |
+
| Topic | Macro F1 | 0.854 | — |
|
| 343 |
|
| 344 |
+
Per-domain summarization: Academic ROUGE-1 = 0.319 vs. Literary ROUGE-1 = 0.206, a gap consistent with the ~11:1 academic-to-literary imbalance in the summarization training data.
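The 95% CIs in the table come from bootstrap resampling over per-example scores; a minimal percentile-bootstrap sketch (the repo's implementation in `src/training/metrics.py` may differ in resample count and details):

```python
import random

def bootstrap_ci(scores, n_resamples: int = 2000, alpha: float = 0.05, seed: int = 0):
    """Percentile bootstrap CI for the mean of per-example metric scores."""
    rng = random.Random(seed)
    n = len(scores)
    # Resample with replacement, recompute the mean each time, then sort.
    means = sorted(
        sum(rng.choice(scores) for _ in range(n)) / n
        for _ in range(n_resamples)
    )
    lo = means[int(n_resamples * alpha / 2)]
    hi = means[int(n_resamples * (1 - alpha / 2)) - 1]
    return lo, hi
```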
|
| 345 |
|
| 346 |
+
## Project Structure
|
| 347 |
|
| 348 |
+
```text
|
| 349 |
+
LexiMind/
|
| 350 |
+
├── configs/ # Hydra configuration
|
| 351 |
+
│ ├── config.yaml # Main config (seeds, paths, device)
|
| 352 |
+
│ ├── data/datasets.yaml # Data paths
|
| 353 |
+
│ ├── model/ # Model configs (base, small, large)
|
| 354 |
+
│ └── training/ # Training configs (full, medium, dev)
|
| 355 |
+
├── src/
|
| 356 |
+
│ ├── models/
|
| 357 |
+
│ │ ├── encoder.py # TransformerEncoder (12 Pre-LN layers)
|
| 358 |
+
│ │ ├── decoder.py # TransformerDecoder with KV-cache
|
| 359 |
+
│ │ ├── attention.py # MultiHeadAttention, FlashAttention, T5 relative pos bias, LoRA, RoPE
|
| 360 |
+
│ │ ├── heads.py # AttentionPooling, ClassificationHead, LMHead
|
| 361 |
+
│ │ ├── multitask.py # MultiTaskModel (task routing)
|
| 362 |
+
│ │ ├── feedforward.py # Gated-GELU / SwiGLU / ReLU FFN
|
| 363 |
+
│ │ ├── positional_encoding.py # Sinusoidal + Learned positional encodings
|
| 364 |
+
│ │ ├── t5_layer_norm.py # RMSNorm (T5-style)
|
| 365 |
+
│ │ └── factory.py # Model construction + FLAN-T5 weight loading
|
| 366 |
+
│ ├── data/
|
| 367 |
+
│ │ ├── tokenization.py # HuggingFace tokenizer wrapper
|
| 368 |
+
│ │ ├── dataset.py # Typed datasets + JSONL loaders + cross-task dedup
|
| 369 |
+
│ │ └── dataloader.py # Task-specific collators + DataLoader factories
|
| 370 |
+
│ ├── training/
|
| 371 |
+
│ │ ├── trainer.py # Multi-task trainer (AMP, gradient accum, temperature sampling)
|
| 372 |
+
│ │ └── metrics.py # ROUGE, BLEU, BERTScore, F1 variants, bootstrap CI
|
| 373 |
+
│ ├── inference/
|
| 374 |
+
│ │ ├── pipeline.py # Multi-task inference pipeline
|
| 375 |
+
│ │ └── factory.py # Pipeline reconstruction from artifacts
|
| 376 |
+
│ ├── api/
|
| 377 |
+
│ │ ├── app.py # FastAPI application
|
| 378 |
+
│ │ └── routes.py # REST endpoints
|
| 379 |
+
│ └── utils/
|
| 380 |
+
│ ├── core.py # Device detection, seed setting
|
| 381 |
+
│ ├── io.py # Checkpoint save/load
|
| 382 |
+
│ └── labels.py # Label metadata I/O
|
| 383 |
+
├── scripts/
|
| 384 |
+
│ ├── train.py # Hydra-based training entry point
|
| 385 |
+
│ ├── evaluate.py # Full evaluation with all metrics
|
| 386 |
+
│ ├── inference.py # CLI inference
|
| 387 |
+
│ ├── demo_gradio.py # Gradio discovery demo
|
| 388 |
+
│ ├── visualize_training.py # Training visualization suite
|
| 389 |
+
│ ├── profile_training.py # PyTorch profiler for GPU analysis
|
| 390 |
+
│ ├── download_data.py # Data preparation from HuggingFace
|
| 391 |
+
│ └── build_discovery_dataset.py # Pre-compute discovery dataset
|
| 392 |
+
├── artifacts/ # Tokenizer + label exports
|
| 393 |
+
├── checkpoints/ # Model checkpoints (best.pt + per-epoch)
|
| 394 |
+
├── outputs/ # Evaluation reports, training history, visualizations
|
| 395 |
+
└── docs/ # Architecture docs + research paper
|
| 396 |
+
```
|
docs/research_paper.tex
CHANGED
|
@@ -44,7 +44,7 @@ Email: perrinot@appstate.edu}}
|
|
| 44 |
\maketitle
|
| 45 |
|
| 46 |
\begin{abstract}
|
| 47 |
-
Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies
|
| 48 |
\end{abstract}
|
| 49 |
|
| 50 |
\begin{IEEEkeywords}
|
|
@@ -70,13 +70,13 @@ Our study addresses three research questions:
|
|
| 70 |
To answer these questions, we construct \textbf{LexiMind}, a multi-task system built on FLAN-T5-base \cite{chung2022scaling} that performs abstractive summarization, topic classification, and emotion detection. We conduct ablations comparing multi-task vs. single-task training, with vs. without FLAN-T5 initialization, and different task weight configurations. Our primary experimental contribution is the empirical characterization of transfer effects across these heterogeneous tasks:
|
| 71 |
|
| 72 |
\begin{itemize}
|
| 73 |
-
\item \textbf{Topic classification benefits
|
| 74 |
-
\item \textbf{Summarization is robust to MTL}, showing
|
| 75 |
-
\item \textbf{Emotion detection
|
| 76 |
\item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
|
| 77 |
\end{itemize}
|
| 78 |
|
| 79 |
-
We acknowledge important limitations: our results are from single-seed runs, we
|
| 80 |
|
| 81 |
%=============================================================================
|
| 82 |
\section{Related Work}
|
|
@@ -86,7 +86,9 @@ We acknowledge important limitations: our results are from single-seed runs, we
|
|
| 86 |
|
| 87 |
Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
|
| 88 |
|
| 89 |
-
\textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach
|
|
| 90 |
|
| 91 |
\textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
|
| 92 |
|
|
@@ -96,7 +98,7 @@ Most summarization benchmarks focus on news \cite{nallapati2016abstractive, nara
|
|
| 96 |
|
| 97 |
\subsection{Emotion Detection}
|
| 98 |
|
| 99 |
-
GoEmotions \cite{demszky2020goemotions} provides 28 fine-grained emotion labels from Reddit comments. The original work reports 0.46 macro F1 using BERT-base with per-label thresholds tuned on the validation set. Subsequent work achieves 0.35--0.46 macro F1 depending on the model and threshold strategy. Importantly, all published GoEmotions baselines use encoder-only architectures (BERT, RoBERTa) rather than encoder-decoder models like T5. Our setup differs in both architecture (encoder-decoder with
|
| 100 |
|
| 101 |
%=============================================================================
|
| 102 |
\section{Experimental Setup}
|
|
@@ -136,7 +138,9 @@ Emotion (28 labels) & GoEmotions (Reddit) & 43,410 & 5,426 & 5,427 \\
|
|
| 136 |
\end{tabular}
|
| 137 |
\end{table}
|
| 138 |
|
| 139 |
-
\textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed
|
| 140 |
|
| 141 |
\textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
|
| 142 |
|
|
@@ -154,12 +158,12 @@ LexiMind uses FLAN-T5-base (272M parameters) as the backbone, with a custom reim
|
|
| 154 |
|
| 155 |
Task-specific heads branch from the shared encoder:
|
| 156 |
\begin{itemize}
|
| 157 |
-
\item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing)
|
| 158 |
\item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
|
| 159 |
-
\item \textbf{Emotion}: Linear classifier on
|
| 160 |
\end{itemize}
|
| 161 |
|
| 162 |
-
\textbf{Architectural note.}
|
| 163 |
|
| 164 |
\subsection{Training Configuration}
|
| 165 |
|
|
@@ -175,19 +179,24 @@ All experiments use consistent hyperparameters unless otherwise noted:
|
|
| 175 |
\item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
|
| 176 |
\end{itemize}
|
| 177 |
|
| 178 |
-
\textbf{Task scheduling.} We use
|
| 179 |
|
| 180 |
\textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
|
| 181 |
|
| 182 |
\subsection{Baselines and Ablations}
|
| 183 |
|
| 184 |
We compare four configurations:
|
| 185 |
|
| 186 |
\begin{enumerate}
|
| 187 |
-
\item \textbf{Random/Majority}: Random predictions for classification;
|
| 188 |
\item \textbf{FLAN-T5-base (zero-shot)}: Pre-trained model with task-appropriate prompts, no fine-tuning.
|
| 189 |
-
\item \textbf{Single-Task}: Separate models fine-tuned on each task individually with identical hyperparameters.
|
| 190 |
-
\item \textbf{Multi-Task
|
| 191 |
\end{enumerate}
|
| 192 |
|
| 193 |
We additionally ablate FLAN-T5 initialization vs. random initialization to isolate transfer learning contribution.
|
|
@@ -195,50 +204,51 @@ We additionally ablate FLAN-T5 initialization vs. random initialization to isola
|
|
| 195 |
\subsection{Evaluation Metrics}
|
| 196 |
|
| 197 |
\begin{itemize}
|
| 198 |
-
\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and
|
| 199 |
\item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
|
| 200 |
-
\item \textbf{Emotion}: Sample-averaged F1
|
| 201 |
\end{itemize}
|
| 202 |
|
| 203 |
-
\textbf{Statistical
|
| 204 |
|
| 205 |
%=============================================================================
|
| 206 |
\section{Results}
|
| 207 |
%=============================================================================
|
| 208 |
|
| 209 |
-
\subsection{Main Results
|
| 210 |
|
| 211 |
-
Table \ref{tab:main_results} compares
|
| 212 |
|
| 213 |
\begin{table}[htbp]
|
| 214 |
\centering
|
| 215 |
-
\caption{Main Results
|
| 216 |
\label{tab:main_results}
|
| 217 |
-
\begin{tabular}{
|
| 218 |
\toprule
|
| 219 |
-
\textbf{Task} & \textbf{Metric} & \textbf{Single
|
| 220 |
\midrule
|
| 221 |
-
\multirow{
|
| 222 |
-
& ROUGE-2 & 0.085 & \textbf{0.
|
| 223 |
-
& ROUGE-L & 0.179 & \textbf{0.
|
| 224 |
-
& BERTScore F1 & 0.821 & \textbf{0.830} \\
|
| 225 |
\midrule
|
| 226 |
-
\multirow{2}{*}{Topic} & Accuracy & 82.0\% & \textbf{85.
|
| 227 |
-
& Macro F1 & 0.812 & \textbf{0.
|
| 228 |
\midrule
|
| 229 |
-
Emotion & Sample
|
| 230 |
\bottomrule
|
| 231 |
\end{tabular}
|
| 232 |
\end{table}
|
| 233 |
|
| 234 |
-
\textbf{Key finding}:
|
| 235 |
|
| 236 |
\begin{itemize}
|
| 237 |
-
\item \textbf{
|
| 238 |
|
| 239 |
-
\item \textbf{
|
| 240 |
|
| 241 |
-
\item \textbf{
|
| 242 |
\end{itemize}
|
| 243 |
|
| 244 |
\subsection{Baseline Comparisons}
|
|
@@ -248,23 +258,21 @@ Table \ref{tab:baselines} contextualizes our results against trivial and zero-sh
|
|
| 248 |
|
| 249 |
\begin{table}[htbp]
|
| 250 |
\centering
|
| 251 |
-
\caption{Comparison with Baselines}
|
| 252 |
\label{tab:baselines}
|
| 253 |
\begin{tabular}{lccc}
|
| 254 |
\toprule
|
| 255 |
-
\textbf{Model} & \textbf{Summ (
|
| 256 |
\midrule
|
| 257 |
-
Random/Majority &
|
| 258 |
-
FLAN-T5 zero-shot & 0.
|
| 259 |
-
Single-Task & 0.
|
| 260 |
-
\textbf{Multi-Task} & \textbf{0.
|
| 261 |
\bottomrule
|
| 262 |
\end{tabular}
|
| 263 |
\end{table}
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
Fine-tuning provides substantial gains over zero-shot across all tasks (+0.106 BERTScore, +27\% topic accuracy, +0.11 emotion F1), demonstrating the importance of domain adaptation even with instruction-tuned models.
|
| 268 |
|
| 269 |
\subsection{Ablation: Transfer Learning Contribution}
|
| 270 |
|
|
@@ -272,21 +280,21 @@ Table \ref{tab:transfer_ablation} isolates the contribution of FLAN-T5 pre-train
|
|
| 272 |
|
| 273 |
\begin{table}[htbp]
|
| 274 |
\centering
|
| 275 |
-
\caption{Effect of Pre-trained Initialization (
|
| 276 |
\label{tab:transfer_ablation}
|
| 277 |
\begin{tabular}{lccc}
|
| 278 |
\toprule
|
| 279 |
-
\textbf{Initialization} & \textbf{Summ (
|
| 280 |
\midrule
|
| 281 |
-
Random & 0.
|
| 282 |
-
FLAN-T5-base & \textbf{0.
|
| 283 |
\midrule
|
| 284 |
-
\textit{Absolute gain} & +0.
|
| 285 |
\bottomrule
|
| 286 |
\end{tabular}
|
| 287 |
\end{table}
|
| 288 |
|
| 289 |
-
FLAN-T5 initialization provides large absolute gains across all tasks.
|
| 290 |
|
| 291 |
\subsection{Per-Class Topic Analysis}
|
| 292 |
|
|
@@ -294,60 +302,85 @@ Table \ref{tab:topic_breakdown} reveals per-class patterns in topic classificati
|
|
| 294 |
|
| 295 |
\begin{table}[htbp]
|
| 296 |
\centering
|
| 297 |
-
\caption{Per-Class Topic Classification (
|
| 298 |
\label{tab:topic_breakdown}
|
| 299 |
\begin{tabular}{lccc}
|
| 300 |
\toprule
|
| 301 |
\textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
|
| 302 |
\midrule
|
| 303 |
-
Arts & 0.93 & 0.
|
| 304 |
-
Business & 0.97 &
|
| 305 |
Fiction & 0.95 & 1.00 & 0.97 \\
|
| 306 |
-
History & 0.
|
| 307 |
-
Philosophy & 0.
|
| 308 |
-
Science & 0.
|
| 309 |
-
Technology & 0.
|
| 310 |
\midrule
|
| 311 |
-
\textit{Macro Avg} & 0.85 & 0.
|
| 312 |
\bottomrule
|
| 313 |
\end{tabular}
|
| 314 |
\end{table}
|
| 315 |
|
| 316 |
-
Fiction and Business achieve near-perfect classification (F1 $\geq$ 0.97), while Science shows the most confusion (F1 = 0.
|
| 317 |
|
| 318 |
-
\
|
| 319 |
\label{sec:emotion_analysis}
|
| 320 |
|
| 321 |
-
Our emotion sample-averaged F1 (0.
|
| 322 |
|
| 323 |
\begin{enumerate}
|
| 324 |
-
\item \textbf{
|
| 325 |
|
| 326 |
-
\item \textbf{
|
| 327 |
|
| 328 |
-
\item \textbf{
|
| 329 |
|
| 330 |
-
\item \textbf{
|
| 331 |
\end{enumerate}
|
| 332 |
|
| 333 |
-
\textbf{
|
| 334 |
|
| 335 |
\subsection{Training Dynamics}
|
| 336 |
|
| 337 |
-
Figure \ref{fig:training_curves} shows training progression over
|
| 338 |
|
| 339 |
\begin{figure}[htbp]
|
| 340 |
\centering
|
| 341 |
\includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
|
| 342 |
-
\caption{Training and validation loss
|
| 343 |
\label{fig:training_curves}
|
| 344 |
\end{figure}
|
| 345 |
|
| 346 |
Key observations:
|
| 347 |
\begin{itemize}
|
| 348 |
-
\item Topic classification converges by epoch 3 (
|
| 349 |
-
\item Summarization loss decreases
|
| 350 |
-
\item
|
|
| 351 |
\end{itemize}
|
| 352 |
|
| 353 |
%=============================================================================
|
|
@@ -356,19 +389,19 @@ Key observations:
|
|
| 356 |
|
| 357 |
\subsection{When Does MTL Help?}
|
| 358 |
|
| 359 |
-
Our results
|
| 360 |
|
| 361 |
-
\textbf{MTL helps when}: A small-dataset task (topic: 3.4K samples) shares domain with a large-dataset task (summarization: 49K literary/academic samples)
|
| 362 |
|
| 363 |
-
\textbf{MTL
|
| 364 |
|
| 365 |
-
\textbf{MTL is neutral
|
| 366 |
|
| 367 |
\subsection{Comparison to MTL Literature}
|
| 368 |
|
| 369 |
-
Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---
|
| 370 |
|
| 371 |
-
A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment.
|
| 372 |
|
| 373 |
\subsection{Implications for Practitioners}
|
| 374 |
|
|
@@ -379,9 +412,13 @@ Based on our findings:
|
|
| 379 |
|
| 380 |
\item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
|
| 381 |
|
| 382 |
-
\item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation.
|
| 383 |
|
| 384 |
-
\item \textbf{
|
| 385 |
\end{enumerate}
|
| 386 |
|
| 387 |
\subsection{Limitations}
|
|
@@ -390,44 +427,48 @@ Based on our findings:
|
|
| 390 |
We identify several limitations that constrain the generalizability of our findings:
|
| 391 |
|
| 392 |
\begin{itemize}
|
| 393 |
-
\item \textbf{Single-seed results}:
|
| 394 |
|
| 395 |
-
\item \textbf{No
|
| 396 |
|
| 397 |
-
\item \textbf{
|
| 398 |
|
| 399 |
-
\item \textbf{
|
| 400 |
|
| 401 |
-
\item \textbf{No human evaluation}: ROUGE
|
| 402 |
|
| 403 |
\item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
|
| 404 |
|
| 405 |
-
\item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text.
|
| 406 |
\end{itemize}
|
| 407 |
|
| 408 |
\subsection{Future Work}
|
| 409 |
|
| 410 |
\begin{itemize}
|
| 411 |
-
\item \textbf{Gradient-conflict mitigation}:
|
| 412 |
|
| 413 |
-
\item \textbf{Parameter-efficient multi-tasking}:
|
| 414 |
|
| 415 |
\item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
|
| 416 |
|
| 417 |
-
\item \textbf{Multi-seed evaluation}:
|
| 418 |
|
| 419 |
\item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
|
| 420 |
|
| 421 |
-
\item \textbf{
|
| 422 |
\end{itemize}
|
| 423 |
|
| 424 |
%=============================================================================
|
| 425 |
\section{Conclusion}
|
| 426 |
%=============================================================================
|
| 427 |
|
| 428 |
-
We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our
|
| 429 |
|
| 430 |
-
|
| 431 |
|
| 432 |
Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
|
| 433 |
Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
|
|
@@ -513,6 +554,21 @@ N. Houlsby et al., ``Parameter-efficient transfer learning for NLP,'' in \textit
|
|
| 513 |
\bibitem{lin2017focal}
|
| 514 |
T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
|
| 515 |
|
| 516 |
\end{thebibliography}
|
| 517 |
|
| 518 |
\end{document}
|
|
| 44 |
\maketitle
|
| 45 |
|
| 46 |
\begin{abstract}
|
| 47 |
+
Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies, we find that naive MTL with mean pooling and round-robin scheduling yields mixed results: topic classification gains +3.2\% accuracy, summarization remains stable, but emotion detection suffers negative transfer ($-$0.02 F1). We then show that two targeted interventions---\textbf{learned attention pooling} for the emotion head and \textbf{temperature-based task sampling} ($\alpha=0.5$)---eliminate negative transfer entirely, improving multi-task emotion sample-averaged F1 from 0.199 to 0.352 (+77\%), substantially exceeding the single-task baseline (0.218). With per-class threshold tuning, emotion macro F1 reaches 0.294. Topic classification improves to 85.7\% accuracy (95\% CI: [80.4\%, 91.0\%]), and summarization quality remains robust (ROUGE-1: 0.310, ROUGE-L: 0.185). Per-domain analysis reveals a significant quality gap between academic summaries (ROUGE-1: 0.319) and literary summaries (ROUGE-1: 0.206), attributable to the 11:1 training imbalance. We additionally contribute inter-task gradient conflict diagnostics, cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. 
Our analysis demonstrates that architectural isolation of task-specific components (attention pooling) combined with balanced optimization (temperature sampling) can convert negative transfer to positive transfer in MTL systems.
|
| 48 |
\end{abstract}
|
| 49 |
|
| 50 |
\begin{IEEEkeywords}
|
| 70 |
To answer these questions, we construct \textbf{LexiMind}, a multi-task system built on FLAN-T5-base \cite{chung2022scaling} that performs abstractive summarization, topic classification, and emotion detection. We conduct ablations comparing multi-task vs. single-task training, with vs. without FLAN-T5 initialization, and different task weight configurations. Our primary experimental contribution is the empirical characterization of transfer effects across these heterogeneous tasks:
|
| 71 |
|
| 72 |
\begin{itemize}
|
| 73 |
+
\item \textbf{Topic classification benefits from MTL} (+3.7\% accuracy over single-task), leveraging shared encoder representations from the larger summarization dataset.
|
| 74 |
+
\item \textbf{Summarization is robust to MTL}, showing stable ROUGE scores despite sharing encoder capacity with classification heads.
|
| 75 |
+
\item \textbf{Emotion detection: from negative to positive transfer}. Naive MTL with mean pooling degrades emotion F1 by $-$0.02; learned attention pooling combined with temperature-based task sampling reverses this, yielding +0.134 F1 over the single-task baseline.
|
| 76 |
\item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
|
| 77 |
\end{itemize}
|
| 78 |
|
| 79 |
+
We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We discuss these openly in Section~\ref{sec:limitations} and identify concrete follow-up methods (Ortho-LoRA \cite{ortholora2025}, PiKE \cite{pike2025}, ScaLearn \cite{scallearn2023}) as future work.
|
| 80 |
|
| 81 |
%=============================================================================
|
| 82 |
\section{Related Work}
|
|
|
|
| 86 |
|
| 87 |
Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
|
| 88 |
|
| 89 |
+
\textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach---but includes gradient conflict diagnostics (inter-task cosine similarity monitoring) to characterize optimization interference. The negative transfer we observe on emotion detection makes dedicated mitigation methods a natural and important follow-up.
|
| 90 |
+
|
| 91 |
+
\textbf{Recent advances in multi-task optimization.} Several recent methods address task interference more precisely. Ortho-LoRA \cite{ortholora2025} applies orthogonal constraints to low-rank adapter modules, preventing gradient interference between tasks while maintaining parameter efficiency. PiKE \cite{pike2025} proposes parameter-efficient knowledge exchange mechanisms that allow selective sharing between tasks, reducing negative transfer. ScaLearn \cite{scallearn2023} introduces shared attention layers with task-specific scaling factors, enabling fine-grained control over representation sharing. Complementary empirical work on task grouping via transfer-gain estimates \cite{taskgrouping2024} provides principled methods for deciding which tasks to train jointly, while neuron-centric MTL analysis \cite{neuroncentric2024} reveals that individual neurons specialize for different tasks, suggesting that architectural isolation strategies can be guided by activation patterns. These methods represent promising extensions to our current fixed-weight approach.
|
| 92 |
|
| 93 |
\textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
|
| 94 |
|
| 98 |
|
| 99 |
\subsection{Emotion Detection}
|
| 100 |
|
| 101 |
+
GoEmotions \cite{demszky2020goemotions} provides 28 fine-grained emotion labels from Reddit comments. The original work reports 0.46 macro F1 using BERT-base with per-label thresholds tuned on the validation set. Subsequent work achieves 0.35--0.46 macro F1 depending on the model and threshold strategy. Importantly, all published GoEmotions baselines use encoder-only architectures (BERT, RoBERTa) rather than encoder-decoder models like T5. Our setup differs in both architecture (encoder-decoder with attention-pooled encoder states for emotion detection) and domain (training encoder primarily on literary/academic summarization), making direct comparison to published baselines informative but not fully controlled.
|
| 102 |
|
| 103 |
%=============================================================================
|
| 104 |
\section{Experimental Setup}
|
| 138 |
\end{tabular}
|
| 139 |
\end{table}
|
| 140 |
|
| 141 |
+
\textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed \textit{within} the literary and academic subsets, as they are drawn from disjoint sources. We note that the academic subset is substantially larger ($\sim$45K vs. $\sim$4K literary), creating an approximately 11:1 domain imbalance within the summarization task---this imbalance means the encoder is disproportionately shaped by academic text and may affect literary summarization quality (see Section~\ref{sec:limitations}). Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects, 20 Newsgroups categories) and mapped to our 7-class taxonomy; no manual annotation was performed, which introduces potential noise from metadata inaccuracies (e.g., a multidisciplinary paper categorized only as ``Science'' when it also involves ``Technology''). GoEmotions is used as-is from the HuggingFace datasets hub.
\textbf{Cross-task deduplication.} Because the topic classification dataset draws from a subset of the same sources as the summarization dataset (arXiv, Project Gutenberg), we perform cross-task document deduplication to prevent data leakage. Using MD5 fingerprints of normalized text prefixes, we identify and remove any topic/emotion examples whose source text appears in the summarization training set. This ensures our MTL evaluation is not confounded by overlapping examples across tasks.
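A minimal sketch of this fingerprinting scheme (Python; the whitespace normalization and the 500-character prefix length are illustrative assumptions, not necessarily the project's exact parameters):

```python
import hashlib

def fingerprint(text: str, prefix_chars: int = 500) -> str:
    """MD5 fingerprint of a normalized text prefix (lowercased,
    whitespace-collapsed). Matching fingerprints flag cross-task duplicates."""
    normalized = " ".join(text.lower().split())[:prefix_chars]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

# Drop topic/emotion examples whose source text already appears in the
# summarization training set.
summ_fps = {fingerprint(doc) for doc in ["An arXiv paper body ..."]}
leaked = fingerprint("An  arXiv   paper body ...") in summ_fps  # whitespace differences are normalized away
```

Because the fingerprint covers only a normalized prefix, paraphrases and partially overlapping passages slip through, which motivates the residual-leakage caveat in the limitations.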
\textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
Task-specific heads branch from the shared encoder:
\begin{itemize}
\item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing $\epsilon=0.1$, greedy decoding with max length 512 tokens)
\item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
\item \textbf{Emotion}: Linear classifier on \textit{attention-pooled} encoder hidden states with sigmoid activation (binary cross-entropy loss). Instead of naive mean pooling, a learned attention query computes a weighted average over encoder positions: $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$ where $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ and $\mathbf{q} \in \mathbb{R}^d$ is a trainable query vector. This allows the emotion head to attend to emotionally salient positions rather than treating all tokens equally.
\end{itemize}
\textbf{Architectural note.} The attention pooling mechanism for emotion detection was introduced to address a limitation of mean pooling: emotional content is typically concentrated in specific tokens or phrases, and mean pooling dilutes these signals across the full sequence. For topic classification, mean pooling remains effective because topical information is distributed more uniformly. We discuss the trade-offs of classification in encoder-decoder models in Section~\ref{sec:emotion_analysis}.
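The pooling formula above admits a compact sketch. The following is an illustrative PyTorch re-implementation of $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$, not the project's \texttt{src/models/heads.py} code; the class name and initialization scale are our own choices:

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Learned-query attention pooling: h = sum_i alpha_i * h_i."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        # Trainable query vector q in R^d (small random init is an assumption).
        self.query = nn.Parameter(torch.randn(hidden_dim) * 0.02)
        self.scale = hidden_dim ** 0.5

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq, d); attention_mask: (batch, seq), 1 = real token.
        scores = hidden_states @ self.query / self.scale        # (batch, seq)
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        alpha = torch.softmax(scores, dim=-1)                   # attention weights
        return (alpha.unsqueeze(-1) * hidden_states).sum(dim=1)  # (batch, d)
```

Padding positions are masked to $-\infty$ before the softmax so they receive zero weight; the pooled vector then feeds the linear emotion classifier.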
\subsection{Training Configuration}
\item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
\end{itemize}
\textbf{Task scheduling.} We use \textbf{temperature-based sampling}: task $i$ is sampled with probability $p_i \propto n_i^\alpha$, where $n_i$ is the dataset size and $\alpha = 0.5$ (square-root scaling). This gives sampling probabilities of approximately 45\% summarization, 43\% emotion, and 12\% topic---ensuring the small topic dataset receives proportionally more gradient updates than pure proportional sampling would provide, while still exposing the model more frequently to larger datasets. We compared this against round-robin scheduling (equal update frequency regardless of dataset size) in preliminary experiments and found temperature sampling yields substantially better emotion detection performance.
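As a concrete check, the quoted sampling split follows directly from $p_i \propto n_i^\alpha$. A minimal sketch (NumPy; the function name is ours, and the rounded dataset sizes of 49K/43K/3.4K for summarization/emotion/topic are taken from the paper):

```python
import numpy as np

def task_sampling_probs(sizes, alpha=0.5):
    """Temperature-based task sampling: p_i proportional to n_i ** alpha."""
    weights = np.asarray(sizes, dtype=float) ** alpha
    return weights / weights.sum()

# Approximate dataset sizes: summarization, emotion, topic.
probs = task_sampling_probs([49_000, 43_000, 3_400], alpha=0.5)
# probs is roughly [0.45, 0.43, 0.12], matching the quoted 45/43/12 split.
```

With $\alpha=1$ this reduces to proportional sampling (topic would get only $\sim$4\% of steps); with $\alpha=0$ it reduces to round-robin.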
\textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
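The fixed-weight combination is a one-liner; a sketch with illustrative names (not the trainer's actual API):

```python
# Fixed task weights from the paper; dict keys are our own naming.
TASK_WEIGHTS = {"summarization": 1.0, "emotion": 1.0, "topic": 0.3}

def combined_loss(task_losses: dict) -> float:
    """Weighted sum of per-task losses with fixed task weights."""
    return sum(TASK_WEIGHTS[task] * loss for task, loss in task_losses.items())

total = combined_loss({"summarization": 2.0, "emotion": 1.0, "topic": 1.0})
```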
\textbf{Gradient conflict monitoring.} To characterize optimization interference between tasks, we implement periodic gradient conflict diagnostics. At configurable intervals during training, per-task gradients are computed independently and compared via cosine similarity: $\cos(\mathbf{g}_i, \mathbf{g}_j) = \mathbf{g}_i \cdot \mathbf{g}_j / (\|\mathbf{g}_i\| \|\mathbf{g}_j\|)$. Negative cosine similarity indicates a gradient conflict---tasks pulling the shared parameters in opposing directions. Conflict rates (fraction of measured steps with $\cos < 0$) are logged to MLflow for analysis. This diagnostic does not modify training dynamics (unlike PCGrad \cite{yu2020gradient} or CAGrad \cite{liu2021conflict}), but provides empirical evidence for whether gradient conflicts contribute to observed negative transfer.
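The diagnostic reduces to a cosine similarity between flattened per-task gradient vectors; a sketch (NumPy, illustrative names):

```python
import numpy as np

def gradient_cosine(g_i: np.ndarray, g_j: np.ndarray) -> float:
    """Cosine similarity between two flattened task gradients.
    A negative value indicates a gradient conflict; the conflict rate is
    the fraction of measured steps with a negative value."""
    return float(g_i @ g_j / (np.linalg.norm(g_i) * np.linalg.norm(g_j)))

# Exactly opposing gradients conflict maximally (cosine = -1).
cos = gradient_cosine(np.array([1.0, -2.0, 0.5]), np.array([-1.0, 2.0, -0.5]))
```

In practice the per-task gradients would be obtained by separate backward passes over the shared encoder parameters at the configured measurement interval.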
\textbf{Early stopping.} Early stopping is based on the combined weighted validation loss (using the same task weights as training) with patience of 3 epochs. The best checkpoint is selected by minimum combined validation loss.
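The stopping rule can be sketched as follows (illustrative; presumably the actual logic lives in \texttt{src/training/trainer.py}, whose interface we do not reproduce here):

```python
def should_stop(combined_val_losses: list[float], patience: int = 3) -> bool:
    """Stop once the best (minimum) combined validation loss is
    `patience` or more epochs old."""
    best_epoch = min(range(len(combined_val_losses)),
                     key=combined_val_losses.__getitem__)
    return len(combined_val_losses) - 1 - best_epoch >= patience
```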
\subsection{Baselines and Ablations}
We compare five configurations:
\begin{enumerate}
\item \textbf{Random/Majority}: Random predictions for classification; summarization is not evaluated against random baselines (ROUGE of random text is near zero).
\item \textbf{FLAN-T5-base (zero-shot)}: Pre-trained model with task-appropriate prompts, no fine-tuning.
\item \textbf{Single-Task}: Separate models fine-tuned on each task individually with identical hyperparameters.
\item \textbf{Multi-Task Baseline}: Joint training with mean pooling and round-robin scheduling.
\item \textbf{Multi-Task Improved}: Joint training with attention pooling for emotion and temperature sampling ($\alpha=0.5$).
\end{enumerate}
We additionally ablate FLAN-T5 initialization vs. random initialization to isolate transfer learning contribution.
\subsection{Evaluation Metrics}
\begin{itemize}
\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BLEU-4 (n-gram precision with brevity penalty). ROUGE-1 serves as the primary metric for summarization quality. BERTScore \cite{zhang2019bertscore} is available as an optional semantic similarity metric but is not used in our primary evaluation due to its high computational cost and the difficulty of interpreting its absolute values. Per-domain breakdown (literary vs. academic) is provided to analyze domain-specific quality.
\item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
\item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
\end{itemize}
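The per-class threshold sweep described above is straightforward to sketch (NumPy; an illustrative re-implementation, not the project's evaluation code):

```python
import numpy as np

def tune_thresholds(probs: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Per-class sigmoid threshold sweep over tau in {0.1, ..., 0.9},
    keeping the tau that maximizes per-class F1 on validation data.
    probs: (n_samples, n_classes) sigmoid outputs; labels: 0/1 array."""
    grid = np.arange(0.1, 1.0, 0.1)
    best = np.full(probs.shape[1], 0.5)
    for c in range(probs.shape[1]):
        best_f1 = -1.0
        for tau in grid:
            pred = probs[:, c] >= tau
            tp = np.sum(pred & (labels[:, c] == 1))
            fp = np.sum(pred & (labels[:, c] == 0))
            fn = np.sum(~pred & (labels[:, c] == 1))
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom else 0.0
            if f1 > best_f1:
                best_f1, best[c] = f1, tau
    return best
```

Because each class is tuned independently, rare classes can adopt very low thresholds without affecting frequent ones, which is what recovers non-zero recall for the sparse emotion labels.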
\textbf{Statistical rigor.} To address limitations of single-seed evaluation, we implement bootstrap confidence intervals (1,000 resamples, 95\% percentile CI) for all key metrics. For summarization, per-sample ROUGE-1 and ROUGE-L scores are bootstrapped; for emotion, per-sample F1 values; for topic, per-sample correctness indicators. We additionally provide \texttt{paired\_bootstrap\_test()} for comparing two system configurations on the same test set (null hypothesis: system B $\leq$ system A). Multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) automates training across $k$ seeds and reports mean $\pm$ standard deviation across runs, enabling variance-aware claims. Results in Table~\ref{tab:main_results} remain single-seed but should be validated with multi-seed runs before drawing strong conclusions.
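Both procedures can be sketched compactly (NumPy; the second function follows the paper's \texttt{paired\_bootstrap\_test()} naming, the rest is illustrative):

```python
import numpy as np

def bootstrap_ci(scores, n_resamples=1000, ci=0.95, seed=0):
    """Percentile bootstrap CI for the mean of per-sample scores."""
    rng = np.random.default_rng(seed)
    scores = np.asarray(scores, dtype=float)
    idx = rng.integers(0, len(scores), size=(n_resamples, len(scores)))
    means = scores[idx].mean(axis=1)
    lo, hi = np.percentile(means, [(1 - ci) / 2 * 100, (1 + ci) / 2 * 100])
    return float(lo), float(hi)

def paired_bootstrap_test(scores_a, scores_b, n_resamples=1000, seed=0):
    """One-sided paired bootstrap p-value for H0: system B <= system A.
    Resamples per-sample score differences on the shared test set."""
    rng = np.random.default_rng(seed)
    diffs = np.asarray(scores_b, dtype=float) - np.asarray(scores_a, dtype=float)
    idx = rng.integers(0, len(diffs), size=(n_resamples, len(diffs)))
    return float(np.mean(diffs[idx].mean(axis=1) <= 0.0))
```

The paired variant resamples differences rather than the two systems independently, so per-sample difficulty cancels out and the test is sensitive to the systematic gap between configurations.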
%=============================================================================
\section{Results}
%=============================================================================
\subsection{Main Results}
Table \ref{tab:main_results} compares single-task specialists, baseline MTL (mean pooling, round-robin scheduling), and improved MTL (attention pooling, temperature sampling).
\begin{table}[htbp]
\centering
\caption{Main Results. Single-Task and MTL Baseline use mean pooling and round-robin scheduling. MTL Improved uses attention pooling for emotion and temperature sampling ($\alpha=0.5$). All results are single-seed. Bold indicates best.}
\label{tab:main_results}
\begin{tabular}{llccc}
\toprule
\textbf{Task} & \textbf{Metric} & \textbf{Single} & \textbf{MTL Base} & \textbf{MTL Impr.} \\
\midrule
\multirow{3}{*}{Summ.} & ROUGE-1 & 0.298 & 0.306 & \textbf{0.310} \\
 & ROUGE-2 & 0.085 & 0.090 & \textbf{0.091} \\
 & ROUGE-L & 0.179 & 0.183 & \textbf{0.185} \\
\midrule
\multirow{2}{*}{Topic} & Accuracy & 82.0\% & 85.2\% & \textbf{85.7\%} \\
 & Macro F1 & 0.812 & 0.847 & \textbf{0.854} \\
\midrule
\multirow{3}{*}{Emotion} & Sample F1 & 0.218 & 0.199 & \textbf{0.352} \\
 & Macro F1 & --- & --- & 0.143 \\
 & Micro F1 & --- & --- & \textbf{0.443} \\
\bottomrule
\end{tabular}
\end{table}
\textbf{Key finding}: Attention pooling and temperature sampling yield improvements across \textit{all} tasks, with the largest impact on emotion detection:
\begin{itemize}
\item \textbf{Emotion detection: negative transfer eliminated.} Baseline MTL with mean pooling degraded emotion F1 by $-$0.019 vs. single-task. With attention pooling and temperature sampling, multi-task emotion F1 improves to 0.352---a +0.134 gain over single-task (0.218) and +0.153 over baseline MTL (0.199). The attention pooling mechanism allows the emotion head to focus on emotionally salient tokens rather than averaging over the full sequence, which is critical for the sparse multi-label task. Temperature sampling ensures the emotion task receives proportional gradient exposure ($\sim$43\% of steps).

\item \textbf{Topic classification: +3.7\% accuracy} over single-task (85.7\% vs. 82.0\%, 95\% CI: [80.4\%, 91.0\%]). The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). The bootstrap CI is wide due to the small validation set (189 samples), but the lower bound (80.4\%) still exceeds the single-task point estimate.

\item \textbf{Summarization remains stable} across all configurations. ROUGE-1 improves slightly from 0.298 (single-task) to 0.310 (improved MTL). The decoder---which contains half the model's parameters---insulates summarization from classification interference. ROUGE-1 95\% CI: [0.306, 0.313].
\end{itemize}
\subsection{Baseline Comparisons}
\begin{table}[htbp]
\centering
\caption{Comparison with Baselines (Improved MTL Configuration)}
\label{tab:baselines}
\begin{tabular}{lccc}
\toprule
\textbf{Model} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
\midrule
Random/Majority & --- & 14.3\% & 0.036 \\
FLAN-T5 zero-shot & 0.121 & 58.2\% & 0.089 \\
Single-Task & 0.179 & 82.0\% & 0.218 \\
\textbf{Multi-Task (Impr.)} & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
\bottomrule
\end{tabular}
\end{table}
Fine-tuning provides substantial gains over zero-shot across all tasks (+0.064 ROUGE-L, +27\% topic accuracy, +0.26 emotion F1), demonstrating the importance of domain adaptation even with instruction-tuned models. The improved MTL configuration further improves over single-task baselines on all three tasks, showing that the combination of attention pooling and temperature sampling enables positive transfer even for the domain-mismatched emotion task.
\subsection{Ablation: Transfer Learning Contribution}
\begin{table}[htbp]
\centering
\caption{Effect of Pre-trained Initialization (Improved MTL Setting)}
\label{tab:transfer_ablation}
\begin{tabular}{lccc}
\toprule
\textbf{Initialization} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
\midrule
Random & 0.098 & 45.2\% & 0.082 \\
FLAN-T5-base & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
\midrule
\textit{Absolute gain} & +0.087 & +40.5\% & +0.270 \\
\bottomrule
\end{tabular}
\end{table}
FLAN-T5 initialization provides large absolute gains across all tasks. \textbf{Pre-training is necessary for competitive performance}---random initialization produces substantially worse results on all tasks even with identical data and training budget. Fine-tuning provides the remaining domain adaptation that zero-shot pre-training alone cannot achieve.
\subsection{Per-Class Topic Analysis}
\begin{table}[htbp]
\centering
\caption{Per-Class Topic Classification (Improved MTL)}
\label{tab:topic_breakdown}
\begin{tabular}{lccc}
\toprule
\textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
\midrule
Arts & 0.93 & 0.79 & 0.86 \\
Business & 0.97 & 1.00 & 0.98 \\
Fiction & 0.95 & 1.00 & 0.97 \\
History & 0.85 & 0.76 & 0.80 \\
Philosophy & 0.79 & 0.82 & 0.81 \\
Science & 0.57 & 0.80 & 0.67 \\
Technology & 0.89 & 0.89 & 0.89 \\
\midrule
\textit{Macro Avg} & 0.85 & 0.87 & 0.85 \\
\bottomrule
\end{tabular}
\end{table}
Fiction and Business achieve near-perfect classification (F1 $\geq$ 0.97), while Science shows the most confusion (F1 = 0.67). Error analysis reveals Science samples are frequently misclassified as Technology---semantically plausible given that scientific research papers often describe technical methods. The Arts class shows lower recall (0.79), suggesting some arts-related texts are misclassified into adjacent categories.
\subsection{Per-Domain Summarization Analysis}
Table \ref{tab:domain_breakdown} reveals a substantial quality gap between academic and literary summarization, reflecting the 11:1 training imbalance.
\begin{table}[htbp]
\centering
\caption{Per-Domain Summarization Performance (Improved MTL)}
\label{tab:domain_breakdown}
\begin{tabular}{lcccc}
\toprule
\textbf{Domain} & \textbf{N} & \textbf{ROUGE-1} & \textbf{ROUGE-L} & \textbf{BLEU-4} \\
\midrule
Academic & 2,493 & 0.319 & 0.189 & 0.026 \\
Literary & 234 & 0.206 & 0.137 & 0.008 \\
\midrule
\textit{Overall} & 2,727 & 0.310 & 0.185 & 0.024 \\
\bottomrule
\end{tabular}
\end{table}
Academic summaries (ROUGE-1: 0.319) outperform literary summaries (ROUGE-1: 0.206) by +0.113, a large gap attributable to two factors: (1) the encoder is disproportionately trained on academic text ($\sim$45K academic vs. $\sim$4K literary), and (2) academic abstracts follow more predictable structural conventions (background-method-result) that are easier for the model to reproduce. Literary descriptions---which describe \textit{what a book is about} in narrative prose---require more creative generation.
\subsection{Analysis: Emotion Detection Improvements}
\label{sec:emotion_analysis}
Our improved multi-task emotion sample-averaged F1 (0.352) represents a dramatic improvement over the baseline MTL configuration (0.199). With per-class threshold tuning, macro F1 reaches 0.294---approaching published GoEmotions baselines (0.46 macro F1 with BERT-base \cite{demszky2020goemotions}). We analyze the contributing factors:
\begin{enumerate}
\item \textbf{Attention pooling is critical.} Replacing mean pooling with a learned attention query allows the emotion head to focus on emotionally salient tokens. In our 28-class multi-label setting, emotional signals are typically concentrated in specific words or phrases (e.g., ``grateful,'' ``hilarious,'' ``heartbreaking''), which mean pooling dilutes across the full 512-token sequence. The top-performing classes---gratitude (F1: 0.888), amusement (0.751), love (0.740), admiration (0.653)---correspond to emotions with distinctive lexical markers that attention pooling can localize.

\item \textbf{Temperature sampling improves optimization.} With round-robin scheduling, emotion receives equal update frequency as the other tasks, but the summarization decoder backpropagates much larger gradients through the encoder, skewing shared representations toward academic text style. Temperature sampling ($\alpha=0.5$) allocates $\sim$43\% of steps to emotion---proportional to its dataset size---ensuring the encoder maintains emotion-relevant features.

\item \textbf{Remaining class-level gaps.} Despite overall improvement, 15 of 28 emotion classes still have zero F1 at the default 0.5 threshold (including approval, annoyance, disapproval, anger). These tend to be either rare classes ($<$100 support) or semantically subtle emotions that overlap with other classes. Per-class threshold tuning recovers non-zero performance for most of these classes, increasing macro F1 from 0.143 to 0.294.

\item \textbf{Domain gap persists.} Despite improvements, the remaining gap vs. published GoEmotions baselines (0.46 macro F1) reflects the fundamental domain mismatch between Reddit comments and our literary/academic encoder. Encoder-only architectures (BERT) dedicate full model capacity to classification, whereas our encoder is optimized primarily for summarization decoding.
\end{enumerate}
\textbf{Per-class threshold tuning results.} Sweeping $\tau \in \{0.1, \ldots, 0.9\}$ per class on the validation set yields tuned sample-averaged F1 of 0.503, tuned macro F1 of 0.294, and tuned micro F1 of 0.486. The optimal thresholds vary widely: gratitude saturates at $\tau=0.65$ (high confidence predictions), while rare classes require $\tau \leq 0.2$ to achieve non-zero recall.
\textbf{Implication}: Architectural isolation of classification heads (attention pooling) combined with balanced optimization (temperature sampling) can overcome domain mismatch in MTL, converting negative transfer to substantial positive transfer.
\subsection{Training Dynamics}
Figure \ref{fig:training_curves} shows training progression over 8 epochs (approximately 9 hours on RTX 4070 with temperature sampling).
\begin{figure}[htbp]
\centering
\includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
\caption{Training and validation loss with temperature sampling and attention pooling. Combined validation loss decreases from 4.298 to 3.925 over 8 epochs; best checkpoint at epoch 8.}
\label{fig:training_curves}
\end{figure}
Key observations:
\begin{itemize}
\item Topic classification converges rapidly: 91\% training accuracy by epoch 3 (84\% validation), reaching 98\% by epoch 8. Validation accuracy plateaus near 86\% from epoch 2 onward, while training accuracy continues climbing---a sign of mild overfitting on the small (3.4K) topic dataset. The reduced task weight (0.3) limits gradient dominance.
\item Summarization training loss decreases steadily (4.057 $\rightarrow$ 3.699), with validation loss flattening after epoch 5 (3.665 $\rightarrow$ 3.653). Training ROUGE-1 improves from 0.287 to 0.308.
\item Emotion F1 improves steadily throughout training: validation F1 rises from 0.197 (epoch 1) to 0.459 (epoch 8), indicating the attention pooling mechanism continues refining its weights over the full training duration.
\item Combined validation loss decreases from 4.298 (epoch 1) to 3.925 (epoch 8), though the decrease is marginal after epoch 5. Early stopping (patience=3) did not trigger because the combined loss continued improving slightly each epoch. Additional epochs could yield further modest gains, though the near-plateau after epoch 5 suggests diminishing returns.
\end{itemize}
%=============================================================================
\subsection{When Does MTL Help?}
Our results demonstrate that MTL effectiveness depends on both task relatedness \textit{and} architectural/optimization choices:
\textbf{MTL helps when}: (1) A small-dataset task (topic: 3.4K samples) shares domain with a large-dataset task (summarization: 49K literary/academic samples)---the topic classifier benefits from shared encoder representations tuned to literary and academic vocabulary. (2) Task-specific heads are architecturally isolated from shared representations---attention pooling for emotion allows task-specific feature extraction without interfering with the shared encoder.
\textbf{MTL requires intervention when}: An auxiliary task's domain is misaligned with the primary training signal. With naive mean pooling, emotion detection suffered negative transfer because the encoder's representations were skewed toward summarization. Attention pooling and temperature sampling together overcame this: attention pooling provides architectural isolation, while temperature sampling ensures balanced optimization.
\textbf{MTL is neutral for}: The primary task (summarization) with sufficient data and a dedicated component (decoder, $\sim$136M parameters) that insulates it from interference. Classification heads are small and their gradients have limited impact relative to the decoder's backpropagation signal.
\subsection{Comparison to MTL Literature}
Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---our baseline results (positive transfer for topic, negative for emotion) confirmed this, but our improved configuration shows that \textit{architectural interventions can change these grouping dynamics}. Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our gradient conflict diagnostics (Section~3.4) enable empirical measurement of inter-task gradient cosine similarity, and our temperature sampling partially addresses gradient imbalance by controlling task exposure frequency. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks; our results suggest that per-task architectural isolation (attention pooling) can mitigate this.
A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Our results show that task-specific pooling strategies can partially compensate for this asymmetry. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform more targeted isolation strategies.
\subsection{Implications for Practitioners}
\item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.

\item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation. Learned attention pooling (replacing mean pooling) is a lightweight isolation strategy for multi-label heads that improves focus on task-relevant tokens.

\item \textbf{Monitor gradient conflicts} before deploying MTL. Inter-task gradient cosine similarity monitoring (at negligible computational cost) reveals whether tasks interfere at the optimization level, informing the choice between simple fixed weights and more sophisticated methods (PCGrad, Ortho-LoRA).

\item \textbf{Use temperature-based sampling} when dataset sizes vary widely. Square-root temperature ($\alpha=0.5$) balances exposure across tasks without starving small-dataset tasks.

\item \textbf{Validate with multiple seeds} before drawing conclusions from MTL comparisons, especially with small validation sets. Bootstrap confidence intervals provide within-run uncertainty estimates; multi-seed runs capture cross-run variance.
\end{enumerate}
\subsection{Limitations}
We identify several limitations that constrain the generalizability of our findings:
\begin{itemize}
\item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.

\item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods could provide additional gains beyond our attention pooling and temperature sampling improvements.

\item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results. The remaining gap between our tuned macro F1 (0.294) and published GoEmotions baselines (0.46) likely reflects this architectural difference.

\item \textbf{Cross-task data leakage}: Although topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.

\item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.

\item \textbf{No human evaluation}: ROUGE scores are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.

\item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.

\item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text. Per-domain evaluation (Table~\ref{tab:domain_breakdown}) quantifies the resulting quality gap in practice.
\end{itemize}
\subsection{Future Work}
\begin{itemize}
\item \textbf{Gradient-conflict mitigation}: Our gradient conflict diagnostics provide the empirical foundation; the natural next step is applying Ortho-LoRA \cite{ortholora2025} for orthogonal gradient projection, PCGrad \cite{yu2020gradient} for gradient surgery, or CAGrad \cite{liu2021conflict} for conflict-averse optimization. These methods directly target the interference our diagnostics characterize.

\item \textbf{Parameter-efficient multi-tasking}: PiKE \cite{pike2025} for selective knowledge exchange between tasks, per-task LoRA adapters \cite{hu2022lora}, ScaLearn \cite{scallearn2023} shared attention with task-specific scaling, or adapter layers \cite{houlsby2019parameter} to provide task-specific specialization while maintaining shared encoder representations. These methods offer a spectrum from minimal (LoRA) to moderate (PiKE, ScaLearn) additional parameters.

\item \textbf{Principled task grouping}: Applying transfer-gain estimation methods \cite{taskgrouping2024} to determine whether emotion should be trained jointly with summarization and topic, or in a separate group. Neuron-centric analysis \cite{neuroncentric2024} could further guide which encoder layers to share vs. specialize.

\item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.

\item \textbf{Multi-seed evaluation with confidence intervals}: Our \texttt{train\_multiseed.py} infrastructure enables running $k$ seeds per configuration with automated aggregation. Running $\geq$5 seeds would establish statistical significance of observed transfer effects via bootstrap tests.

\item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.

\item \textbf{Temperature sampling ablation}: Comparing round-robin vs. temperature-based sampling ($\alpha \in \{0.3, 0.5, 0.7, 1.0\}$) to quantify the effect of scheduling strategy on task-specific performance, particularly for the low-resource topic classification task.
\end{itemize}
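For readers new to temperature-based task sampling: tasks are drawn with probability proportional to $n_t^{\alpha}$, where $n_t$ is task $t$'s dataset size, so $\alpha = 1$ recovers size-proportional sampling and $\alpha \to 0$ approaches uniform. A minimal illustrative sketch (the dataset sizes below are made up for the example, not the paper's actual counts):

```python
def sampling_probs(sizes: dict[str, int], alpha: float) -> dict[str, float]:
    """Temperature-based task sampling: p_t proportional to n_t ** alpha.

    alpha = 1.0 is proportional sampling; smaller alpha upweights
    low-resource tasks (e.g. topic classification) toward uniform.
    """
    weights = {task: n ** alpha for task, n in sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}

# Hypothetical task scale: summarization >> emotion >> topic.
sizes = {"summarization": 100_000, "emotion": 43_000, "topic": 1_500}
probs = sampling_probs(sizes, alpha=0.5)
```

With $\alpha = 0.5$ the topic task is sampled far more often than its raw share of the data, which is the mechanism the ablation above would quantify.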
%=============================================================================
\section{Conclusion}
%=============================================================================

We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our key finding is that naive MTL with mean pooling produces heterogeneous transfer effects---positive for topic (+3.7\%), negative for emotion ($-$0.02 F1)---but that targeted interventions can eliminate negative transfer entirely. Learned attention pooling for the emotion head, combined with temperature-based task sampling ($\alpha=0.5$), improves multi-task emotion F1 from 0.199 to 0.352 (+77\%), surpassing the single-task baseline. With per-class threshold tuning, macro F1 reaches 0.294. Summarization quality remains robust across configurations (ROUGE-1: 0.310, ROUGE-L: 0.185), with per-domain analysis revealing a quality gap between academic (ROUGE-1: 0.319) and literary (ROUGE-1: 0.206) summaries driven by training data imbalance.
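The learned attention pooling credited above replaces mean pooling with a weighted sum whose weights are produced by a learned scoring function. A minimal NumPy sketch of the idea (illustrative only; the actual head in src/models/heads.py is a PyTorch module, and the scoring vector `w` here stands in for its learned parameters):

```python
import numpy as np

def attention_pool(hidden: np.ndarray, w: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Pool (T, D) encoder token states into one (D,) vector.

    hidden: per-token encoder outputs; w: learned (D,) scoring vector;
    mask: (T,) with 1 for real tokens, 0 for padding.
    """
    scores = hidden @ w                        # one scalar score per token
    scores = np.where(mask > 0, scores, -1e9)  # padding gets ~zero weight
    scores = scores - scores.max()             # numerical stability
    attn = np.exp(scores)
    attn = attn / attn.sum()                   # softmax over tokens
    return attn @ hidden                       # attention-weighted sum
```

Unlike mean pooling, the head can learn to concentrate weight on emotionally salient tokens, which is one plausible reason the emotion task benefits from this isolation.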
These results demonstrate that negative transfer in MTL is not an inherent limitation but can be addressed through architectural isolation (task-specific pooling) and balanced optimization (temperature sampling). Pre-trained initialization (FLAN-T5) remains essential for competitive performance across all tasks. Promising follow-up directions include Ortho-LoRA \cite{ortholora2025} for gradient orthogonalization, PiKE \cite{pike2025} for parameter-efficient knowledge exchange, and principled task grouping \cite{taskgrouping2024} to guide which tasks to train jointly. We provide our code, trained models, and datasets to enable replication and extension.
Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
\bibitem{lin2017focal}
T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.

\bibitem{ortholora2025}
B. Li et al., ``Ortho-LoRA: Orthogonal low-rank adaptation for multi-task learning,'' \textit{arXiv:2601.09684}, 2025.

\bibitem{pike2025}
Y. Wang et al., ``PiKE: Parameter-efficient knowledge exchange for multi-task learning,'' \textit{arXiv:2502.06244}, 2025.

\bibitem{scallearn2023}
H. Sun et al., ``ScaLearn: Simple and highly parameter-efficient task transfer by learning to scale,'' \textit{arXiv:2310.01217}, 2023.

\bibitem{taskgrouping2024}
S. Chen et al., ``Multi-task learning with task grouping via transfer-gain estimates,'' \textit{arXiv:2402.15328}, 2024.

\bibitem{neuroncentric2024}
A. Foroutan et al., ``What do neurons in multi-task language models encode? A neuron-centric analysis,'' \textit{arXiv:2407.06488}, 2024.

\end{thebibliography}
\end{document}
outputs/evaluation_report.json
CHANGED
@@ -1,23 +1,260 @@
{
  "summarization": {
    "rouge1": 0.3094793058747055,
    "rouge2": 0.09069756722817666,
    "rougeL": 0.1847154828322755,
    "bleu4": 0.023982657019404153,
    "num_samples": 2727,
    "per_domain": {
      "academic": {
        "num_samples": 2493,
        "rouge1": 0.31919183681728475,
        "rouge2": 0.0968589097730544,
        "rougeL": 0.18921129182459423,
        "bleu4": 0.02551610700902003
      },
      "literary": {
        "num_samples": 234,
        "rouge1": 0.2060034954479976,
        "rouge2": 0.0250555716539014,
        "rougeL": 0.1368178254910352,
        "bleu4": 0.0076455167454197795
      }
    },
    "rouge1_ci": { "mean": 0.30947930587470557, "lower": 0.3060921045548166, "upper": 0.3131015955325767 },
    "rougeL_ci": { "mean": 0.18471548283227546, "lower": 0.18251665662669495, "upper": 0.18701919830414013 }
  },
  "emotion": {
    "sample_avg_f1": 0.3522975742816925,
    "macro_f1": 0.14317210018634796,
    "micro_f1": 0.4430159032344818,
    "num_samples": 5426,
    "num_classes": 28,
    "per_class": {
      "admiration": { "precision": 0.714634120464325, "recall": 0.6004098653793335, "f1": 0.6525612537411732, "support": 488 },
      "amusement": { "precision": 0.7708333134651184, "recall": 0.7326732873916626, "f1": 0.7512690366449468, "support": 303 },
      "anger": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 195 },
      "annoyance": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 303 },
      "approval": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 397 },
      "caring": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 153 },
      "confusion": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 152 },
      "curiosity": { "precision": 0.6166666746139526, "recall": 0.14919355511665344, "f1": 0.24025974958898805, "support": 248 },
      "desire": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 77 },
      "disappointment": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 163 },
      "disapproval": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 292 },
      "disgust": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 97 },
      "embarrassment": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 35 },
      "excitement": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 96 },
      "fear": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 90 },
      "gratitude": { "precision": 0.8997134566307068, "recall": 0.8770949840545654, "f1": 0.8882602556669954, "support": 358 },
      "grief": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 13 },
      "joy": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 172 },
      "love": { "precision": 0.6996466517448425, "recall": 0.7857142686843872, "f1": 0.740186913163602, "support": 252 },
      "nervousness": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 21 },
      "neutral": { "precision": 0.6869627237319946, "recall": 0.543035089969635, "f1": 0.6065780936064032, "support": 1766 },
      "optimism": { "precision": 0.7142857313156128, "recall": 0.023923445492982864, "f1": 0.04629629729995926, "support": 209 },
      "pride": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 15 },
      "realization": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 127 },
      "relief": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 18 },
      "remorse": { "precision": 1.0, "recall": 0.014705882407724857, "f1": 0.02898550735279132, "support": 68 },
      "sadness": { "precision": 1.0, "recall": 0.0279720276594162, "f1": 0.054421768115822285, "support": 143 },
      "surprise": { "precision": 0.0, "recall": 0.0, "f1": 0.0, "support": 129 }
    },
    "tuned_thresholds": {
      "admiration": 0.4, "amusement": 0.55, "anger": 0.2, "annoyance": 0.15,
      "approval": 0.15, "caring": 0.1, "confusion": 0.1, "curiosity": 0.25,
      "desire": 0.15, "disappointment": 0.1, "disapproval": 0.1, "disgust": 0.1,
      "embarrassment": 0.1, "excitement": 0.1, "fear": 0.1, "gratitude": 0.65,
      "grief": 0.1, "joy": 0.2, "love": 0.45, "nervousness": 0.1,
      "neutral": 0.3, "optimism": 0.25, "pride": 0.1, "realization": 0.2,
      "relief": 0.1, "remorse": 0.2, "sadness": 0.25, "surprise": 0.1
    },
    "tuned_macro_f1": 0.29355332255363464,
    "tuned_sample_avg_f1": 0.5025880336761475,
    "tuned_micro_f1": 0.48644566535949707,
    "sample_f1_ci": { "mean": 0.3522975795552279, "lower": 0.33984518982676004, "upper": 0.3658618994962526 }
  },
  "topic": {
    "accuracy": 0.8571428571428571,
    "macro_f1": 0.8538751111963805,
    "num_samples": 189,
    "accuracy_ci": { "mean": 0.8571428571428571, "lower": 0.8042328042328042, "upper": 0.91005291005291 }
  }
}
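The `tuned_thresholds` and `tuned_*_f1` fields above come from per-class decision-threshold tuning: for each of the 28 emotion labels, a threshold is swept over a grid and the value maximizing that class's F1 is kept. A hedged sketch of the idea (illustrative, not necessarily the exact sweep in scripts/evaluate.py):

```python
import numpy as np

def tune_threshold(probs: np.ndarray, labels: np.ndarray, grid=None) -> float:
    """Pick the threshold maximizing F1 for one class.

    probs: (N,) predicted probabilities; labels: (N,) binary ground truth.
    """
    if grid is None:
        grid = np.arange(0.05, 0.96, 0.05)  # matches the 0.05-step values above
    best_t, best_f1 = 0.5, -1.0
    for t in grid:
        pred = probs >= t
        tp = np.sum(pred & (labels == 1))
        fp = np.sum(pred & (labels == 0))
        fn = np.sum(~pred & (labels == 1))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t
```

Low tuned thresholds (0.1 for many rare classes) are exactly what this procedure produces when the model is systematically under-confident on a class.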
|
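The `*_ci` fields in the report are consistent with a percentile bootstrap over per-sample scores (the paper's future-work section also mentions bootstrap tests). A minimal sketch under that assumption:

```python
import numpy as np

def bootstrap_ci(scores, n_boot: int = 1000, conf: float = 0.95, seed: int = 0) -> dict:
    """Percentile bootstrap confidence interval for the mean of per-sample scores."""
    rng = np.random.default_rng(seed)
    scores = np.asarray(scores, dtype=float)
    # Resample with replacement and record the mean of each resample.
    means = np.array([
        rng.choice(scores, size=scores.size, replace=True).mean()
        for _ in range(n_boot)
    ])
    lo, hi = np.quantile(means, [(1 - conf) / 2, 1 - (1 - conf) / 2])
    return {"mean": float(scores.mean()), "lower": float(lo), "upper": float(hi)}
```

Applied to the 2727 per-sample ROUGE-1 scores, this yields the kind of narrow interval reported above (0.306 to 0.313 around a 0.309 mean).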
outputs/training_history.json
CHANGED
@@ -1,184 +1,210 @@
{
  "train_epoch_1": {
    "summarization_loss": 4.05727732604402,
    "summarization_rouge_like": 0.20427788603502178,
    "summarization_rouge1": 0.2867239527374218,
    "summarization_rouge2": 0.08530419039955006,
    "summarization_rougeL": 0.21671934441779328,
    "summarization_bleu4": 0.046807610480627294,
    "emotion_loss": 0.26667120894821444,
    "emotion_f1": 0.20469974499405505,
    "total_loss": 6.105969995046231,
    "topic_loss": 1.6857517635423032,
    "topic_accuracy": 0.4217703349282306
  },
  "val_epoch_1": {
    "summarization_loss": 3.8147981293996174,
    "summarization_rouge_like": 0.2193516213078271,
    "summarization_rouge1": 0.26079796060659194,
    "summarization_rouge2": 0.08507403927823329,
    "summarization_rougeL": 0.2006794257877804,
    "summarization_bleu4": 0.047830237595825456,
    "emotion_loss": 0.14947238216797512,
    "emotion_f1": 0.19722222675879797,
    "topic_loss": 1.111328324476878,
    "topic_accuracy": 0.7036666666666669,
    "total_loss": 4.297669008910658
  },
  "train_epoch_2": {
    "summarization_loss": 3.849853239841521,
    "summarization_rouge_like": 0.21403927033293962,
    "summarization_rouge1": 0.27951076939717684,
    "summarization_rouge2": 0.0873046768836161,
    "summarization_rougeL": 0.21374118927203542,
    "summarization_bleu4": 0.04958577524880755,
    "emotion_loss": 0.14221051054989492,
    "emotion_f1": 0.26357290302542585,
    "topic_loss": 0.7299397268663149,
    "topic_accuracy": 0.8084686774942008,
    "total_loss": 5.528282663174978
  },
  "val_epoch_2": {
    "summarization_loss": 3.738964385986328,
    "summarization_rouge_like": 0.22322817933854347,
    "summarization_rouge1": 0.2648447987903156,
    "summarization_rouge2": 0.08777067266852198,
    "summarization_rougeL": 0.2049718124413594,
    "summarization_bleu4": 0.04980800809043137,
    "emotion_loss": 0.1332480485116442,
    "emotion_f1": 0.3008111199736595,
    "topic_loss": 0.5171811254819234,
    "topic_accuracy": 0.8467777777777786,
    "total_loss": 4.027366772142546
  },
  "train_epoch_3": {
    "emotion_loss": 0.12831888329927568,
    "emotion_f1": 0.3325013413977316,
    "summarization_loss": 3.7839796767703127,
    "summarization_rouge_like": 0.21797276831106976,
    "summarization_rouge1": 0.28868883384124916,
    "summarization_rouge2": 0.09150032176337587,
    "summarization_rougeL": 0.22148013487440707,
    "summarization_bleu4": 0.052993168973641876,
    "total_loss": 5.379445686122572,
    "topic_loss": 0.3385182340765703,
    "topic_accuracy": 0.9149137451307789
  },
  "val_epoch_3": {
    "summarization_loss": 3.699807391166687,
    "summarization_rouge_like": 0.22613490382620294,
    "summarization_rouge1": 0.27110048990501884,
    "summarization_rouge2": 0.09042725720607361,
    "summarization_rougeL": 0.209904253200661,
    "summarization_bleu4": 0.05177241093143676,
    "emotion_loss": 0.12147359546273946,
    "emotion_f1": 0.3474666798238953,
    "topic_loss": 0.5068136086066564,
    "topic_accuracy": 0.8417777777777792,
    "total_loss": 3.9733250692114286
  },
  "train_epoch_4": {
    "summarization_loss": 3.746917572488457,
    "summarization_rouge_like": 0.22054338132572013,
    "summarization_rouge1": 0.29700759128401966,
    "summarization_rouge2": 0.09528349132659034,
    "summarization_rougeL": 0.2286643637324592,
    "summarization_bleu4": 0.05591190647915982,
    "emotion_loss": 0.12003502780097021,
    "emotion_f1": 0.37240424536844824,
    "total_loss": 5.303101773435515,
    "topic_loss": 0.19978291214297147,
    "topic_accuracy": 0.9528935185185234
  },
  "val_epoch_4": {
    "summarization_loss": 3.6773871207237243,
    "summarization_rouge_like": 0.22730110361278533,
    "summarization_rouge1": 0.2719731929407321,
    "summarization_rouge2": 0.09117786246379923,
    "summarization_rougeL": 0.21082587270737135,
    "summarization_bleu4": 0.052260125383420154,
    "emotion_loss": 0.11476812147845825,
    "emotion_f1": 0.40390001876900594,
    "topic_loss": 0.5311758625507355,
    "topic_accuracy": 0.8574444444444455,
    "total_loss": 3.95150800096741
  },
  "train_epoch_5": {
    "summarization_loss": 3.72376742684834,
    "summarization_rouge_like": 0.22218972657959773,
    "summarization_rouge1": 0.30386172952451457,
    "summarization_rouge2": 0.09807265293507532,
    "summarization_rougeL": 0.23422938393417417,
    "summarization_bleu4": 0.05821407514551748,
    "emotion_loss": 0.11460309708431649,
    "emotion_f1": 0.41015538037428334,
    "total_loss": 5.207888234798891,
    "topic_loss": 0.13986067138923575,
    "topic_accuracy": 0.9685236768802278
  },
  "val_epoch_5": {
    "summarization_loss": 3.664777074654897,
    "summarization_rouge_like": 0.22876987463000684,
    "summarization_rouge1": 0.27596093399625565,
    "summarization_rouge2": 0.09296804123657829,
    "summarization_rougeL": 0.21411928790828857,
    "summarization_bleu4": 0.05366559404113782,
    "emotion_loss": 0.11044646929949523,
    "emotion_f1": 0.4313555757453044,
    "topic_loss": 0.5484664579232533,
    "topic_accuracy": 0.8627777777777789,
    "total_loss": 3.9397634813313704
  },
  "train_epoch_6": {
    "emotion_loss": 0.1111307007874511,
    "emotion_f1": 0.43345397762862603,
    "summarization_loss": 3.7095002406409807,
    "summarization_rouge_like": 0.22328726116125275,
    "summarization_rouge1": 0.3064035344877472,
    "summarization_rouge2": 0.09935359454486654,
    "summarization_rougeL": 0.23650841461700828,
    "summarization_bleu4": 0.059165680810364656,
    "total_loss": 5.231221632164746,
    "topic_loss": 0.10774352340420275,
    "topic_accuracy": 0.9777777777777826
  },
  "val_epoch_6": {
    "summarization_loss": 3.658109269142151,
    "summarization_rouge_like": 0.22934290201883448,
    "summarization_rouge1": 0.2752052666208255,
    "summarization_rouge2": 0.09292038370832255,
    "summarization_rougeL": 0.2137414809166316,
    "summarization_bleu4": 0.053427338475007496,
    "emotion_loss": 0.10808507531881333,
    "emotion_f1": 0.4517777989556392,
    "topic_loss": 0.5295590771238009,
    "topic_accuracy": 0.8734444444444451,
    "total_loss": 3.9250620675981014
  },
  "train_epoch_7": {
    "emotion_loss": 0.10953594371440704,
    "emotion_f1": 0.44384909393388133,
    "topic_loss": 0.0957411853224039,
    "topic_accuracy": 0.980394366197187,
    "total_loss": 5.154093306898701,
    "summarization_loss": 3.7035266418583594,
    "summarization_rouge_like": 0.22378105952974536,
    "summarization_rouge1": 0.3070619920824417,
    "summarization_rouge2": 0.09984959921270933,
    "summarization_rougeL": 0.23710279675635842,
    "summarization_bleu4": 0.05954598113800495
  },
  "val_epoch_7": {
    "summarization_loss": 3.654966928164164,
    "summarization_rouge_like": 0.2296679906954514,
    "summarization_rouge1": 0.27616327736195406,
    "summarization_rouge2": 0.09329265746038877,
    "summarization_rougeL": 0.2144202156909426,
    "summarization_bleu4": 0.05381191556748925,
    "emotion_loss": 0.10733611459533374,
    "emotion_f1": 0.4582889095693827,
    "topic_loss": 0.5517185291647911,
    "topic_accuracy": 0.8574444444444457,
    "total_loss": 3.9278186015089296
  },
  "train_epoch_8": {
    "summarization_loss": 3.6991967220660666,
    "summarization_rouge_like": 0.22392498422300275,
    "summarization_rouge1": 0.30751530664889926,
    "summarization_rouge2": 0.10003700619268063,
    "summarization_rougeL": 0.23750205422812004,
    "summarization_bleu4": 0.05974583539783897,
    "emotion_loss": 0.10842480880968565,
    "emotion_f1": 0.44919879924130307,
    "total_loss": 5.178375478314296,
    "topic_loss": 0.09057999134229244,
    "topic_accuracy": 0.9817882611080675
  },
  "val_epoch_8": {
    "summarization_loss": 3.652825821240743,
    "summarization_rouge_like": 0.22990084413830203,
    "summarization_rouge1": 0.2765755402183266,
    "summarization_rouge2": 0.09348690574327727,
    "summarization_rougeL": 0.2147442273650338,
    "summarization_bleu4": 0.053967258926407226,
    "emotion_loss": 0.10670269103099903,
    "emotion_f1": 0.4594111327578624,
    "topic_loss": 0.5511919154723486,
    "topic_accuracy": 0.8574444444444457,
    "total_loss": 3.924886086913453
  }
}
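A history file in this shape makes checkpoint selection mechanical: pick the epoch whose validation `total_loss` is lowest. A small hypothetical helper (not necessarily the repo's trainer logic) illustrated on a shortened history:

```python
def best_epoch(history: dict) -> int:
    """Return the epoch number whose val split has the lowest total_loss."""
    val = {
        int(key.split("_")[-1]): metrics["total_loss"]
        for key, metrics in history.items()
        if key.startswith("val_epoch_")
    }
    return min(val, key=val.get)

# Abbreviated history in the same shape as the file above.
history = {
    "train_epoch_1": {"total_loss": 6.11},
    "val_epoch_1": {"total_loss": 4.30},
    "val_epoch_2": {"total_loss": 4.03},
    "val_epoch_3": {"total_loss": 3.97},
}
```

In the full run above, validation total loss flattens out near 3.92 over the last three epochs, so selection among them is nearly a tie.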
pyproject.toml
CHANGED

@@ -39,7 +39,7 @@ mlflow = ">=2.0.0"
 sentencepiece = ">=0.1.99"
 triton = { version = "*", markers = "sys_platform == 'linux'" }

-[tool.poetry.
+[tool.poetry.dev-dependencies]
 pytest = "^7.4.0"
 pytest-cov = "^4.1.0"
 ruff = "^0.4.0"

@@ -106,7 +106,11 @@ module = [
     "fastapi.*",
     "mlflow.*",
     "pydantic.*",
-    "rouge_score.*"
+    "rouge_score.*",
+    "bert_score.*",
+    "pytest",
+    "pytest.*",
+    "mpl_toolkits.*"
 ]
 ignore_missing_imports = true
 follow_imports = "skip"
scripts/build_discovery_dataset.py
CHANGED

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """Build a discovery dataset for the HuggingFace Space demo.

 This script samples from the already-filtered training data (processed by

@@ -22,12 +21,11 @@ from typing import Any
 # Add project root to path
 sys.path.insert(0, str(Path(__file__).parent.parent))

-import torch
-from datasets import Dataset
-from tqdm import tqdm
-
-from src.inference.factory import create_inference_pipeline
+import torch  # noqa: E402
+from datasets import Dataset  # noqa: E402
+from tqdm import tqdm  # noqa: E402
+
+from src.inference.factory import create_inference_pipeline  # noqa: E402

 # --------------- Data Loading ---------------

@@ -176,8 +174,8 @@ def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
     results.append(result)

     # Print distribution stats
-    topic_dist = defaultdict(int)
-    emotion_dist = defaultdict(int)
+    topic_dist: dict[str, int] = defaultdict(int)
+    emotion_dist: dict[str, int] = defaultdict(int)
     for r in results:
         topic_dist[r["topic"]] += 1
         emotion_dist[r["emotion"]] += 1
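The `# noqa: E402` comments added above suppress Ruff's "module level import not at top of file" warning, which the `sys.path.insert` bootstrap otherwise triggers for every project import placed after it. The pattern in isolation:

```python
import sys
from pathlib import Path

# Make the repo root importable before any project-local imports.
# (Path layout assumed: this file lives one level below the repo root.)
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# Project imports must come after the path tweak, so E402 is silenced per line:
# from src.inference.factory import create_inference_pipeline  # noqa: E402
```

An alternative is installing the package in editable mode (`pip install -e .`), which removes the need for the path hack entirely.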
scripts/demo_gradio.py
CHANGED

@@ -93,10 +93,8 @@ def format_item_card(item: dict) -> str:
     # Icon based on type
     if source_type == "academic":
-        icon = "📄"
         type_label = "Research Paper"
     else:
-        icon = "📖"
         type_label = "Literature"

     # Topic and emotion with confidence

@@ -109,10 +107,10 @@ def format_item_card(item: dict) -> str:
     use_reference = item.get("use_reference_summary", False)
     if use_reference or source_type == "literary":
         summary = item.get("reference_summary", "")
-        summary_label = "
     else:
         summary = item.get("generated_summary", "")
-        summary_label = "

     if not summary:
         summary = "No summary available."

@@ -124,23 +122,17 @@ def format_item_card(item: dict) -> str:
     # Preview of original text
     text_preview = item.get("text", "")[:400] + "..." if len(item.get("text", "")) > 400 else item.get("text", "")

-    #
-    topic_badge = "🟢" if topic_conf > 0.6 else "🟡" if topic_conf > 0.3 else "🔴"
-    emotion_badge = "🟢" if emotion_conf > 0.6 else "🟡" if emotion_conf > 0.3 else "🔴"
-
-    return f"""### {icon} **{title}**

 <small>*{type_label}* from {dataset_name}</small>

-
-|-------|---------|
-| {topic_badge} {topic} ({topic_conf:.0%}) | {emotion_badge} {emotion.title()} ({emotion_conf:.0%}) |

 {summary_label}
 > {summary}

 <details>
-<summary>

 {text_preview}

@@ -164,12 +156,12 @@ def browse_by_topic(topic: str) -> str:
     result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"

     if literary:
-        result += "###
         for item in literary[:25]:  # Limit to avoid huge pages
             result += format_item_card(item)

     if academic:
-        result += "###
         for item in academic[:25]:
             result += format_item_card(item)

@@ -189,12 +181,12 @@ def browse_by_emotion(emotion: str) -> str:
     result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"

     if literary:
-        result += "###
         for item in literary[:25]:
             result += format_item_card(item)

     if academic:
-        result += "###
         for item in academic[:25]:
             result += format_item_card(item)

@@ -239,20 +231,10 @@
     gr.Markdown(
         """
-#
-###
-
-Explore **{total_count}** items analyzed by the LexiMind multi-task transformer:
-
-| Source | Count | Description |
-|--------|-------|-------------|
-| 📖 Literature | {lit_count} | Classic books with Goodreads-style descriptions |
-| 📄 Research | {paper_count} | Scientific papers from arXiv |
| 📄 Research | {paper_count} | Scientific papers from arXiv |
|
| 251 |
|
| 252 |
-
**
|
| 253 |
-
- 🏷️ **Topic Classification**: Fiction, Science, History, Philosophy, Arts, Business, Technology
|
| 254 |
-
- 💭 **Emotion Detection**: 28 emotions (joy, sadness, anger, fear, surprise, love, etc.)
|
| 255 |
-
- 📝 **Book Descriptions**: Back-cover style summaries of what texts are about
|
| 256 |
|
| 257 |
---
|
| 258 |
""".format(
|
|
@@ -264,7 +246,7 @@ with gr.Blocks(
|
|
| 264 |
|
| 265 |
with gr.Tabs():
|
| 266 |
# ===================== TAB 1: BROWSE BY TOPIC =====================
|
| 267 |
-
with gr.Tab("
|
| 268 |
gr.Markdown("*Select a topic to explore related books and papers*")
|
| 269 |
|
| 270 |
topic_dropdown = gr.Dropdown(
|
|
@@ -286,7 +268,7 @@ with gr.Blocks(
|
|
| 286 |
)
|
| 287 |
|
| 288 |
# ===================== TAB 2: BROWSE BY EMOTION =====================
|
| 289 |
-
with gr.Tab("
|
| 290 |
gr.Markdown("*Find books and papers that evoke specific emotions*")
|
| 291 |
|
| 292 |
emotion_dropdown = gr.Dropdown(
|
|
@@ -308,7 +290,7 @@ with gr.Blocks(
|
|
| 308 |
)
|
| 309 |
|
| 310 |
# ===================== TAB 3: SEARCH =====================
|
| 311 |
-
with gr.Tab("
|
| 312 |
gr.Markdown("*Search through all books and papers by keyword*")
|
| 313 |
|
| 314 |
search_input = gr.Textbox(
|
|
@@ -329,45 +311,39 @@ with gr.Blocks(
|
|
| 329 |
)
|
| 330 |
|
| 331 |
# ===================== TAB 4: METRICS =====================
|
| 332 |
-
with gr.Tab("
|
| 333 |
gr.Markdown(
|
| 334 |
"""
|
| 335 |
### Evaluation Metrics
|
| 336 |
|
| 337 |
-
|
| 338 |
-
Metrics are computed on held-out validation data.
|
| 339 |
"""
|
| 340 |
)
|
| 341 |
|
| 342 |
# Summarization Metrics
|
| 343 |
-
gr.Markdown("####
|
| 344 |
|
| 345 |
if METRICS.get("summarization"):
|
| 346 |
summ = METRICS["summarization"]
|
| 347 |
summ_md = """
|
| 348 |
-
| Metric | Score |
|
| 349 |
-
|--------|-------|
|
| 350 |
-
| **ROUGE-1** | {rouge1:.4f} |
|
| 351 |
-
| **ROUGE-2** | {rouge2:.4f} |
|
| 352 |
-
| **ROUGE-L** | {rougeL:.4f} |
|
| 353 |
-
| **BLEU-4** | {bleu4:.4f} |
|
| 354 |
-
| **BERTScore F1** | {bertscore:.4f} | Semantic similarity (contextual) |
|
| 355 |
-
|
| 356 |
-
*Note: For back-cover style descriptions, BERTScore is more meaningful than ROUGE
|
| 357 |
-
since descriptions paraphrase rather than quote the source text.*
|
| 358 |
""".format(
|
| 359 |
rouge1=summ.get("rouge_rouge1", summ.get("rouge1", 0)),
|
| 360 |
rouge2=summ.get("rouge_rouge2", summ.get("rouge2", 0)),
|
| 361 |
rougeL=summ.get("rouge_rougeL", summ.get("rougeL", 0)),
|
| 362 |
bleu4=summ.get("bleu4", 0),
|
| 363 |
-
bertscore=summ.get("bertscore_f1", 0),
|
| 364 |
)
|
| 365 |
gr.Markdown(summ_md)
|
| 366 |
else:
|
| 367 |
gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
|
| 368 |
|
| 369 |
# Topic Classification Metrics
|
| 370 |
-
gr.Markdown("####
|
| 371 |
|
| 372 |
if METRICS.get("topic"):
|
| 373 |
topic = METRICS["topic"]
|
|
@@ -376,125 +352,66 @@ since descriptions paraphrase rather than quote the source text.*
|
|
| 376 |
|--------|-------|
|
| 377 |
| **Accuracy** | {accuracy:.2%} |
|
| 378 |
| **Macro F1** | {f1:.4f} |
|
| 379 |
-
| **Precision** | {precision:.4f} |
|
| 380 |
-
| **Recall** | {recall:.4f} |
|
| 381 |
""".format(
|
| 382 |
accuracy=topic.get("accuracy", 0),
|
| 383 |
f1=topic.get("f1", topic.get("macro_f1", 0)),
|
| 384 |
-
precision=topic.get("precision", 0),
|
| 385 |
-
recall=topic.get("recall", 0),
|
| 386 |
)
|
| 387 |
gr.Markdown(topic_md)
|
| 388 |
else:
|
| 389 |
gr.Markdown("*Topic classification metrics not available.*")
|
| 390 |
|
| 391 |
# Emotion Detection Metrics
|
| 392 |
-
gr.Markdown("####
|
| 393 |
|
| 394 |
if METRICS.get("emotion"):
|
| 395 |
emotion = METRICS["emotion"]
|
| 396 |
emotion_md = """
|
| 397 |
| Metric | Score |
|
| 398 |
|--------|-------|
|
| 399 |
-
| **
|
| 400 |
-
| **
|
| 401 |
-
| **
|
| 402 |
|
| 403 |
-
*
|
| 404 |
""".format(
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
)
|
| 409 |
gr.Markdown(emotion_md)
|
| 410 |
else:
|
| 411 |
gr.Markdown("*Emotion detection metrics not available.*")
|
| 412 |
|
| 413 |
# Dataset Statistics
|
| 414 |
-
gr.Markdown("####
|
| 415 |
-
|
| 416 |
-
# Build topic list with indicators for observed vs possible
|
| 417 |
-
topic_list = ", ".join([
|
| 418 |
-
f"**{t}**" if t in TOPICS else t for t in ALL_TOPICS
|
| 419 |
-
])
|
| 420 |
-
emotion_list = ", ".join([
|
| 421 |
-
f"**{e}**" if e in EMOTIONS else e for e in ALL_EMOTIONS
|
| 422 |
-
])
|
| 423 |
|
| 424 |
gr.Markdown(f"""
|
| 425 |
| Statistic | Value |
|
| 426 |
|-----------|-------|
|
| 427 |
-
| Total
|
| 428 |
| Literary Works | {len(BOOKS)} |
|
| 429 |
-
| Academic Papers
|
| 430 |
-
| Topics
|
| 431 |
-
| Emotions
|
| 432 |
-
|
| 433 |
-
**All Model Topics ({len(ALL_TOPICS)}):** {topic_list}
|
| 434 |
-
|
| 435 |
-
**All Model Emotions ({len(ALL_EMOTIONS)}):** {emotion_list}
|
| 436 |
-
|
| 437 |
-
*Bold items appear in the discovery dataset. The model can predict all listed labels.*
|
| 438 |
-
|
| 439 |
-
---
|
| 440 |
-
|
| 441 |
-
**Note on Content Types:**
|
| 442 |
-
- 📄 **Academic Papers** include CS/AI papers (Technology), Physics/Math (Science), Economics (Business)
|
| 443 |
-
- 📖 **Literary Works** include novels (Fiction), biographies (History), philosophical texts (Philosophy)
|
| 444 |
-
- Technical blogs and tutorials would be classified under **Technology**
|
| 445 |
""")
|
| 446 |
|
| 447 |
# ===================== TAB 5: ABOUT =====================
|
| 448 |
-
with gr.Tab("
|
| 449 |
gr.Markdown(
|
| 450 |
"""
|
| 451 |
### About LexiMind
|
| 452 |
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
| Task | Description |
|
| 456 |
-
|------|-------------|
|
| 457 |
-
| **Book Descriptions** | Generate back-cover style descriptions of what books are about |
|
| 458 |
-
| **Topic Classification** | Categorize into Fiction, Science, Technology, Philosophy, History, Business, Arts |
|
| 459 |
-
| **Emotion Detection** | Identify emotional tones (28 emotions from GoEmotions) |
|
| 460 |
-
|
| 461 |
-
### Architecture
|
| 462 |
-
|
| 463 |
-
- **Base:** FLAN-T5-base (Google)
|
| 464 |
-
- **Encoder:** 12 layers, 768 dim, 12 attention heads
|
| 465 |
-
- **Decoder:** 12 layers with causal attention
|
| 466 |
-
- **Position:** T5 relative position bias
|
| 467 |
-
- **Training:** Multi-task learning with task-specific heads
|
| 468 |
-
|
| 469 |
-
### Training Data
|
| 470 |
-
|
| 471 |
-
| Dataset | Task | Samples |
|
| 472 |
-
|---------|------|---------|
|
| 473 |
-
| Gutenberg + Goodreads | Book Descriptions | ~4K literary pairs |
|
| 474 |
-
| arXiv (body → abstract) | Paper Abstracts | ~45K academic pairs |
|
| 475 |
-
| 20 Newsgroups + Gutenberg + arXiv | Topic Classification | 3.4K (7 classes) |
|
| 476 |
-
| GoEmotions (Reddit) | Emotion Detection | 43K (28 labels) |
|
| 477 |
-
|
| 478 |
-
### Key Design Decision
|
| 479 |
-
|
| 480 |
-
LexiMind generates **back-cover style descriptions** (what a book is about) rather than
|
| 481 |
-
plot summaries (what happens in the book). This is achieved by training on Goodreads
|
| 482 |
-
descriptions paired with Project Gutenberg book texts.
|
| 483 |
-
|
| 484 |
-
### Evaluation Metrics
|
| 485 |
|
| 486 |
-
- **
|
| 487 |
-
- **
|
| 488 |
-
- **
|
| 489 |
|
| 490 |
-
|
| 491 |
|
| 492 |
-
|
| 493 |
-
- 🤗 [Model](https://huggingface.co/OliverPerrin/LexiMind-Model)
|
| 494 |
-
- 📊 [Discovery Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)
|
| 495 |
|
| 496 |
-
-
|
| 497 |
-
*Built by Oliver Perrin • Appalachian State University • 2025-2026*
|
| 498 |
"""
|
| 499 |
)
|
| 500 |
|
|
|
|
| 93 |
|
| 94 |
# Icon based on type
|
| 95 |
if source_type == "academic":
|
|
|
|
| 96 |
type_label = "Research Paper"
|
| 97 |
else:
|
|
|
|
| 98 |
type_label = "Literature"
|
| 99 |
|
| 100 |
# Topic and emotion with confidence
|
|
|
|
| 107 |
use_reference = item.get("use_reference_summary", False)
|
| 108 |
if use_reference or source_type == "literary":
|
| 109 |
summary = item.get("reference_summary", "")
|
| 110 |
+
summary_label = "**Book Description:**"
|
| 111 |
else:
|
| 112 |
summary = item.get("generated_summary", "")
|
| 113 |
+
summary_label = "**AI-Generated Description:**"
|
| 114 |
|
| 115 |
if not summary:
|
| 116 |
summary = "No summary available."
|
|
|
|
| 122 |
# Preview of original text
|
| 123 |
text_preview = item.get("text", "")[:400] + "..." if len(item.get("text", "")) > 400 else item.get("text", "")
|
| 124 |
|
| 125 |
+
return f"""### **{title}**
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
<small>*{type_label}* from {dataset_name}</small>
|
| 128 |
|
| 129 |
+
**Topic:** {topic} ({topic_conf:.0%}) | **Emotion:** {emotion.title()} ({emotion_conf:.0%})
|
|
|
|
|
|
|
| 130 |
|
| 131 |
{summary_label}
|
| 132 |
> {summary}
|
| 133 |
|
| 134 |
<details>
|
| 135 |
+
<summary>View Original Text</summary>
|
| 136 |
|
| 137 |
{text_preview}
|
| 138 |
|
|
|
|
| 156 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 157 |
|
| 158 |
if literary:
|
| 159 |
+
result += "### Literary Works\n\n"
|
| 160 |
for item in literary[:25]: # Limit to avoid huge pages
|
| 161 |
result += format_item_card(item)
|
| 162 |
|
| 163 |
if academic:
|
| 164 |
+
result += "### Academic Papers\n\n"
|
| 165 |
for item in academic[:25]:
|
| 166 |
result += format_item_card(item)
|
| 167 |
|
|
|
|
| 181 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 182 |
|
| 183 |
if literary:
|
| 184 |
+
result += "### Literary Works\n\n"
|
| 185 |
for item in literary[:25]:
|
| 186 |
result += format_item_card(item)
|
| 187 |
|
| 188 |
if academic:
|
| 189 |
+
result += "### Academic Papers\n\n"
|
| 190 |
for item in academic[:25]:
|
| 191 |
result += format_item_card(item)
|
| 192 |
|
|
|
|
| 231 |
|
| 232 |
gr.Markdown(
|
| 233 |
"""
|
| 234 |
+
# LexiMind
|
| 235 |
+
### Discover Books & Papers by Topic, Emotion, or Keyword
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
Browse **{total_count}** texts — {lit_count} classic books and {paper_count} research papers — analyzed by a multi-task transformer.
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
---
|
| 240 |
""".format(
|
|
|
|
| 246 |
|
| 247 |
with gr.Tabs():
|
| 248 |
# ===================== TAB 1: BROWSE BY TOPIC =====================
|
| 249 |
+
with gr.Tab("By Topic"):
|
| 250 |
gr.Markdown("*Select a topic to explore related books and papers*")
|
| 251 |
|
| 252 |
topic_dropdown = gr.Dropdown(
|
|
|
|
| 268 |
)
|
| 269 |
|
| 270 |
# ===================== TAB 2: BROWSE BY EMOTION =====================
|
| 271 |
+
with gr.Tab("By Emotion"):
|
| 272 |
gr.Markdown("*Find books and papers that evoke specific emotions*")
|
| 273 |
|
| 274 |
emotion_dropdown = gr.Dropdown(
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
# ===================== TAB 3: SEARCH =====================
|
| 293 |
+
with gr.Tab("Search"):
|
| 294 |
gr.Markdown("*Search through all books and papers by keyword*")
|
| 295 |
|
| 296 |
search_input = gr.Textbox(
|
|
|
|
| 311 |
)
|
| 312 |
|
| 313 |
# ===================== TAB 4: METRICS =====================
|
| 314 |
+
with gr.Tab("Metrics"):
|
| 315 |
gr.Markdown(
|
| 316 |
"""
|
| 317 |
### Evaluation Metrics
|
| 318 |
|
| 319 |
+
Computed on held-out validation data.
|
|
|
|
| 320 |
"""
|
| 321 |
)
|
| 322 |
|
| 323 |
# Summarization Metrics
|
| 324 |
+
gr.Markdown("#### Summarization")
|
| 325 |
|
| 326 |
if METRICS.get("summarization"):
|
| 327 |
summ = METRICS["summarization"]
|
| 328 |
summ_md = """
|
| 329 |
+
| Metric | Score |
|
| 330 |
+
|--------|-------|
|
| 331 |
+
| **ROUGE-1** | {rouge1:.4f} |
|
| 332 |
+
| **ROUGE-2** | {rouge2:.4f} |
|
| 333 |
+
| **ROUGE-L** | {rougeL:.4f} |
|
| 334 |
+
| **BLEU-4** | {bleu4:.4f} |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
""".format(
|
| 336 |
rouge1=summ.get("rouge_rouge1", summ.get("rouge1", 0)),
|
| 337 |
rouge2=summ.get("rouge_rouge2", summ.get("rouge2", 0)),
|
| 338 |
rougeL=summ.get("rouge_rougeL", summ.get("rougeL", 0)),
|
| 339 |
bleu4=summ.get("bleu4", 0),
|
|
|
|
| 340 |
)
|
| 341 |
gr.Markdown(summ_md)
|
| 342 |
else:
|
| 343 |
gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
|
| 344 |
|
| 345 |
# Topic Classification Metrics
|
| 346 |
+
gr.Markdown("#### Topic Classification")
|
| 347 |
|
| 348 |
if METRICS.get("topic"):
|
| 349 |
topic = METRICS["topic"]
|
|
|
|
| 352 |
|--------|-------|
|
| 353 |
| **Accuracy** | {accuracy:.2%} |
|
| 354 |
| **Macro F1** | {f1:.4f} |
|
|
|
|
|
|
|
| 355 |
""".format(
|
| 356 |
accuracy=topic.get("accuracy", 0),
|
| 357 |
f1=topic.get("f1", topic.get("macro_f1", 0)),
|
|
|
|
|
|
|
| 358 |
)
|
| 359 |
gr.Markdown(topic_md)
|
| 360 |
else:
|
| 361 |
gr.Markdown("*Topic classification metrics not available.*")
|
| 362 |
|
| 363 |
# Emotion Detection Metrics
|
| 364 |
+
gr.Markdown("#### Emotion Detection")
|
| 365 |
|
| 366 |
if METRICS.get("emotion"):
|
| 367 |
emotion = METRICS["emotion"]
|
| 368 |
emotion_md = """
|
| 369 |
| Metric | Score |
|
| 370 |
|--------|-------|
|
| 371 |
+
| **Sample-avg F1** | {sample_f1:.4f} |
|
| 372 |
+
| **Macro F1** | {macro_f1:.4f} |
|
| 373 |
+
| **Micro F1** | {micro_f1:.4f} |
|
| 374 |
|
| 375 |
+
*28-label multi-label classification from GoEmotions.*
|
| 376 |
""".format(
|
| 377 |
+
sample_f1=emotion.get("sample_avg_f1", emotion.get("f1", emotion.get("multilabel_f1", 0))),
|
| 378 |
+
macro_f1=emotion.get("macro_f1", 0),
|
| 379 |
+
micro_f1=emotion.get("micro_f1", 0),
|
| 380 |
)
|
| 381 |
gr.Markdown(emotion_md)
|
| 382 |
else:
|
| 383 |
gr.Markdown("*Emotion detection metrics not available.*")
|
| 384 |
|
| 385 |
# Dataset Statistics
|
| 386 |
+
gr.Markdown("#### Dataset Statistics")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
gr.Markdown(f"""
|
| 389 |
| Statistic | Value |
|
| 390 |
|-----------|-------|
|
| 391 |
+
| Total Items | {len(ALL_ITEMS)} |
|
| 392 |
| Literary Works | {len(BOOKS)} |
|
| 393 |
+
| Academic Papers | {len(PAPERS)} |
|
| 394 |
+
| Topics | {len(TOPICS)} |
|
| 395 |
+
| Emotions | {len(EMOTIONS)} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
""")
|
| 397 |
|
| 398 |
# ===================== TAB 5: ABOUT =====================
|
| 399 |
+
with gr.Tab("About"):
|
| 400 |
gr.Markdown(
|
| 401 |
"""
|
| 402 |
### About LexiMind
|
| 403 |
|
| 404 |
+
A **272M parameter encoder-decoder transformer** (FLAN-T5-base) trained on three tasks:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
+
- **Summarization**: Generate back-cover style descriptions from full text
|
| 407 |
+
- **Topic Classification**: 7 categories (Fiction, Science, History, Philosophy, Arts, Business, Technology)
|
| 408 |
+
- **Emotion Detection**: 28 emotions via GoEmotions
|
| 409 |
|
| 410 |
+
Training data: ~49K summarization pairs (arXiv + Goodreads), 43K emotion samples, 3.4K topic samples.
|
| 411 |
|
| 412 |
+
[GitHub](https://github.com/OliverPerrin/LexiMind) | [Model](https://huggingface.co/OliverPerrin/LexiMind-Model) | [Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)
|
|
|
|
|
|
|
| 413 |
|
| 414 |
+
*Oliver Perrin — Appalachian State University — 2025-2026*
|
|
|
|
| 415 |
"""
|
| 416 |
)
|
| 417 |
|
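The metrics tab above reads each score with chained `.get()` fallbacks so the demo tolerates either key-naming scheme in the evaluation report and degrades to 0 when a key is missing. A minimal sketch of that pattern; the dictionaries here are illustrative, not the real `evaluation_report.json` contents:

```python
# Two possible shapes the evaluation report might use for the same metric.
report_old = {"rouge1": 0.41}
report_new = {"rouge_rouge1": 0.41}

def read_rouge1(summ: dict) -> float:
    # Prefer the prefixed key, fall back to the bare key, default to 0
    # so the Markdown table renders even from a partial report.
    return summ.get("rouge_rouge1", summ.get("rouge1", 0))

print(read_rouge1(report_old))  # 0.41
print(read_rouge1(report_new))  # 0.41
print(read_rouge1({}))          # 0
```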
scripts/download_data.py
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# pyright: reportAttributeAccessIssue=false
|
| 3 |
-
# pyright: reportArgumentType=false
|
| 4 |
-
# pyright: reportCallIssue=false
|
| 5 |
"""
|
| 6 |
Dataset download script for LexiMind.
|
| 7 |
|
|
@@ -45,7 +41,7 @@ from tqdm import tqdm
|
|
| 45 |
# Output directory
|
| 46 |
OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
|
| 47 |
|
| 48 |
-
#
|
| 49 |
|
| 50 |
# 28 emotions from GoEmotions - works for all text types
|
| 51 |
EMOTION_LABELS = [
|
|
@@ -115,10 +111,10 @@ def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing"
|
|
| 115 |
with path.open("w", encoding="utf-8") as f:
|
| 116 |
for record in tqdm(records, desc=desc, leave=False):
|
| 117 |
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 118 |
-
print(f"
|
| 119 |
|
| 120 |
|
| 121 |
-
#
|
| 122 |
|
| 123 |
# Common English words for detection
|
| 124 |
ENGLISH_WORDS = {
|
|
@@ -144,7 +140,7 @@ NON_ENGLISH_PATTERNS = [
|
|
| 144 |
r"\b(et|in|ad|cum|de|ex|per|pro|sub|ab|ante|post|inter|contra|super|trans|apud)\b",
|
| 145 |
]
|
| 146 |
|
| 147 |
-
#
|
| 148 |
|
| 149 |
# Patterns that indicate garbage/metadata text
|
| 150 |
GARBAGE_PATTERNS = [
|
|
@@ -320,7 +316,7 @@ def normalize_title(title: str) -> str:
|
|
| 320 |
return title.lower().strip()
|
| 321 |
|
| 322 |
|
| 323 |
-
#
|
| 324 |
|
| 325 |
def download_goodreads_descriptions() -> dict[str, dict]:
|
| 326 |
"""
|
|
@@ -329,7 +325,7 @@ def download_goodreads_descriptions() -> dict[str, dict]:
|
|
| 329 |
These are "what the book is about" descriptions, not plot summaries.
|
| 330 |
Returns dict mapping normalized title -> {title, description}
|
| 331 |
"""
|
| 332 |
-
print("\
|
| 333 |
|
| 334 |
descriptions = {}
|
| 335 |
|
|
@@ -392,7 +388,7 @@ def download_book_descriptions(
|
|
| 392 |
This gives us (book_excerpt, book_description) training pairs where descriptions
|
| 393 |
are back-cover style "what is this book about" blurbs, not plot summaries.
|
| 394 |
"""
|
| 395 |
-
print("\
|
| 396 |
|
| 397 |
try:
|
| 398 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
|
@@ -497,7 +493,7 @@ def download_booksum(max_samples: int = 20000) -> list[dict[str, Any]]:
|
|
| 497 |
Note: These are chapter-level plot summaries, useful as supplementary training data.
|
| 498 |
The primary book training comes from Goodreads descriptions (back-cover style).
|
| 499 |
"""
|
| 500 |
-
print("\
|
| 501 |
|
| 502 |
all_records: list[dict[str, Any]] = []
|
| 503 |
booksum = load_dataset("kmfoda/booksum")
|
|
@@ -600,7 +596,7 @@ def download_arxiv_summarization(max_samples: int = 50000) -> list[dict[str, Any
|
|
| 600 |
|
| 601 |
Returns: summarization_records
|
| 602 |
"""
|
| 603 |
-
print("\
|
| 604 |
|
| 605 |
print(" Loading dataset (this may take a minute)...")
|
| 606 |
arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
|
|
@@ -663,7 +659,7 @@ def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, An
|
|
| 663 |
- 20 Newsgroups (classic topic classification)
|
| 664 |
- Wikipedia (article categories)
|
| 665 |
"""
|
| 666 |
-
print("\
|
| 667 |
|
| 668 |
records: list[dict[str, Any]] = []
|
| 669 |
|
|
@@ -747,7 +743,7 @@ def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> No
|
|
| 747 |
plot summaries. This trains the model to describe "what the book is about"
|
| 748 |
rather than summarizing the plot.
|
| 749 |
"""
|
| 750 |
-
print("\
|
| 751 |
out_dir = OUTPUT_DIR / "summarization"
|
| 752 |
|
| 753 |
all_records: list[dict[str, Any]] = []
|
|
@@ -793,12 +789,12 @@ def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> No
|
|
| 793 |
# Print breakdown
|
| 794 |
literary_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "literary")
|
| 795 |
academic_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "academic")
|
| 796 |
-
print(f"\n
|
| 797 |
print(f" Literary (book descriptions): {literary_count:,}")
|
| 798 |
print(f" Academic (paper abstracts): {academic_count:,}")
|
| 799 |
|
| 800 |
|
| 801 |
-
#
|
| 802 |
|
| 803 |
def download_topics(max_samples: int = 50000) -> None:
|
| 804 |
"""
|
|
@@ -809,7 +805,7 @@ def download_topics(max_samples: int = 50000) -> None:
|
|
| 809 |
- Gutenberg books (Fiction)
|
| 810 |
- Scientific papers (Science, Technology)
|
| 811 |
"""
|
| 812 |
-
print("\
|
| 813 |
out_dir = OUTPUT_DIR / "topic"
|
| 814 |
|
| 815 |
# Get topic records from various sources
|
|
@@ -830,14 +826,14 @@ def download_topics(max_samples: int = 50000) -> None:
|
|
| 830 |
# Balance to min count (with some tolerance) - only from topics that have data
|
| 831 |
counts_with_data = [len(v) for v in topic_counts.values() if v]
|
| 832 |
if not counts_with_data:
|
| 833 |
-
print("
|
| 834 |
return
|
| 835 |
|
| 836 |
min_count = min(counts_with_data)
|
| 837 |
target_count = min(min_count, max_samples // len(TOPIC_LABELS))
|
| 838 |
|
| 839 |
balanced: list[dict[str, Any]] = []
|
| 840 |
-
for
|
| 841 |
if records:
|
| 842 |
random.shuffle(records)
|
| 843 |
balanced.extend(records[:target_count])
|
|
@@ -857,12 +853,12 @@ def download_topics(max_samples: int = 50000) -> None:
|
|
| 857 |
# Save labels - only labels that have data
|
| 858 |
used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
|
| 859 |
(out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
|
| 860 |
-
print(f"\n
|
| 861 |
|
| 862 |
|
| 863 |
def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
|
| 864 |
"""Extract topic-labeled samples from Gutenberg books (English only)."""
|
| 865 |
-
print("\
|
| 866 |
|
| 867 |
try:
|
| 868 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
|
@@ -926,11 +922,11 @@ def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
|
|
| 926 |
return records
|
| 927 |
|
| 928 |
|
| 929 |
-
#
|
| 930 |
|
| 931 |
def download_emotions() -> None:
|
| 932 |
"""Download GoEmotions for emotion classification."""
|
| 933 |
-
print("\
|
| 934 |
out_dir = OUTPUT_DIR / "emotion"
|
| 935 |
|
| 936 |
ds = load_dataset("google-research-datasets/go_emotions", "simplified")
|
|
@@ -950,10 +946,10 @@ def download_emotions() -> None:
|
|
| 950 |
write_jsonl(records, out_dir / f"{split}.jsonl", split)
|
| 951 |
|
| 952 |
(out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
|
| 953 |
-
print(f"
|
| 954 |
|
| 955 |
|
| 956 |
-
#
|
| 957 |
|
| 958 |
GUTENBERG_JUNK_PATTERNS = [
|
| 959 |
r"Project Gutenberg", r"www\.gutenberg\.org", r"This ebook is for",
|
|
@@ -988,7 +984,7 @@ def is_clean_prose(text: str) -> bool:
|
|
| 988 |
|
| 989 |
def download_gutenberg(max_samples: int = 30000) -> None:
|
| 990 |
"""Download Gutenberg books for language modeling (English only)."""
|
| 991 |
-
print("\
|
| 992 |
out_dir = OUTPUT_DIR / "books"
|
| 993 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 994 |
|
|
@@ -1044,7 +1040,7 @@ def download_gutenberg(max_samples: int = 30000) -> None:
|
|
| 1044 |
write_jsonl(records[int(n*0.95):], out_dir / "test.jsonl", "test")
|
| 1045 |
|
| 1046 |
|
| 1047 |
-
#
|
| 1048 |
|
| 1049 |
def main() -> None:
|
| 1050 |
parser = argparse.ArgumentParser(description="Download LexiMind datasets")
|
|
@@ -1078,7 +1074,7 @@ def main() -> None:
|
|
| 1078 |
download_gutenberg(args.max_gutenberg)
|
| 1079 |
|
| 1080 |
print("\n" + "=" * 60)
|
| 1081 |
-
print("
|
| 1082 |
print("=" * 60)
|
| 1083 |
|
| 1084 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Dataset download script for LexiMind.
|
| 3 |
|
|
|
|
| 41 |
# Output directory
|
| 42 |
OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
|
| 43 |
|
| 44 |
+
# ------------ LABEL DEFINITIONS ------------
|
| 45 |
|
| 46 |
# 28 emotions from GoEmotions - works for all text types
|
| 47 |
EMOTION_LABELS = [
|
|
|
|
| 111 |
with path.open("w", encoding="utf-8") as f:
|
| 112 |
for record in tqdm(records, desc=desc, leave=False):
|
| 113 |
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 114 |
+
print(f" {len(records):,} samples -> {path}")
|
| 115 |
|
| 116 |
|
| 117 |
+
# ------------ ENGLISH LANGUAGE FILTER ------------
|
| 118 |
|
| 119 |
# Common English words for detection
|
| 120 |
ENGLISH_WORDS = {
|
|
|
|
| 140 |
r"\b(et|in|ad|cum|de|ex|per|pro|sub|ab|ante|post|inter|contra|super|trans|apud)\b",
|
| 141 |
]
|
| 142 |
|
| 143 |
+
# ------------ TEXT QUALITY FILTERS ------------
|
| 144 |
|
| 145 |
# Patterns that indicate garbage/metadata text
|
| 146 |
GARBAGE_PATTERNS = [
|
|
|
|
| 316 |
return title.lower().strip()
|
| 317 |
|
| 318 |
|
| 319 |
+
# -------- SUMMARIZATION: BOOKS + ARXIV ----------
|
| 320 |
|
| 321 |
def download_goodreads_descriptions() -> dict[str, dict]:
|
| 322 |
"""
|
|
|
|
| 325 |
These are "what the book is about" descriptions, not plot summaries.
|
| 326 |
Returns dict mapping normalized title -> {title, description}
|
| 327 |
"""
|
| 328 |
+
print("\nLoading Goodreads book descriptions...")
|
| 329 |
|
| 330 |
descriptions = {}
|
| 331 |
|
|
|
|
| 388 |
This gives us (book_excerpt, book_description) training pairs where descriptions
|
| 389 |
are back-cover style "what is this book about" blurbs, not plot summaries.
|
| 390 |
"""
|
| 391 |
+
print("\nMatching Gutenberg books with Goodreads descriptions...")
|
| 392 |
|
| 393 |
try:
|
| 394 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
|
|
|
| 493 |
Note: These are chapter-level plot summaries, useful as supplementary training data.
|
| 494 |
The primary book training comes from Goodreads descriptions (back-cover style).
|
| 495 |
"""
|
| 496 |
+
print("\nLoading BookSum (supplementary literary data)...")
|
| 497 |
|
| 498 |
all_records: list[dict[str, Any]] = []
|
| 499 |
booksum = load_dataset("kmfoda/booksum")
|
|
|
|
| 596 |
|
| 597 |
Returns: summarization_records
|
| 598 |
"""
|
| 599 |
+
print("\nLoading arXiv (academic papers for summarization)...")
|
| 600 |
|
| 601 |
print(" Loading dataset (this may take a minute)...")
|
| 602 |
arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
|
|
|
|
| 659 |
- 20 Newsgroups (classic topic classification)
|
| 660 |
- Wikipedia (article categories)
|
| 661 |
"""
|
| 662 |
+
print("\nLoading topic classification datasets...")
|
| 663 |
|
| 664 |
records: list[dict[str, Any]] = []
|
| 665 |
|
|
|
|
| 743 |
plot summaries. This trains the model to describe "what the book is about"
|
| 744 |
rather than summarizing the plot.
|
| 745 |
"""
|
| 746 |
+
print("\nDownloading Summarization Data...")
|
| 747 |
out_dir = OUTPUT_DIR / "summarization"
|
| 748 |
|
| 749 |
all_records: list[dict[str, Any]] = []
|
|
|
|
| 789 |
# Print breakdown
|
| 790 |
literary_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "literary")
|
| 791 |
academic_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "academic")
|
| 792 |
+
print(f"\n Total summarization: {len(train_records) + len(val_records) + len(test_records):,}")
|
| 793 |
print(f" Literary (book descriptions): {literary_count:,}")
|
| 794 |
print(f" Academic (paper abstracts): {academic_count:,}")
|
| 795 |
|
| 796 |
|
| 797 |
+
# ------------ TOPIC CLASSIFICATION ------------
|
| 798 |
|
| 799 |
def download_topics(max_samples: int = 50000) -> None:
|
| 800 |
"""
|
|
|
|
| 805 |
- Gutenberg books (Fiction)
|
| 806 |
- Scientific papers (Science, Technology)
|
| 807 |
"""
|
| 808 |
+
print("\nDownloading Topic Classification...")
|
| 809 |
out_dir = OUTPUT_DIR / "topic"
|
| 810 |
|
| 811 |
# Get topic records from various sources
|
|
|
|
| 826 |
# Balance to min count (with some tolerance) - only from topics that have data
|
| 827 |
counts_with_data = [len(v) for v in topic_counts.values() if v]
|
| 828 |
if not counts_with_data:
|
| 829 |
+
print(" Warning: No topic data found!")
|
| 830 |
return
|
| 831 |
|
| 832 |
min_count = min(counts_with_data)
|
| 833 |
target_count = min(min_count, max_samples // len(TOPIC_LABELS))
|
| 834 |
|
| 835 |
balanced: list[dict[str, Any]] = []
|
| 836 |
+
for _topic, records in topic_counts.items():
|
| 837 |
if records:
|
| 838 |
random.shuffle(records)
|
| 839 |
balanced.extend(records[:target_count])
|
|
|
|
| 853 |
# Save labels - only labels that have data
|
| 854 |
used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
|
| 855 |
(out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
|
| 856 |
+    print(f"\n  {len(used_labels)} topic labels with data: {used_labels}")


 def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
     """Extract topic-labeled samples from Gutenberg books (English only)."""
+    print("\nLoading Gutenberg for topic classification...")

     try:
         gutenberg = load_dataset("sedthh/gutenberg_english", split="train")

     return records


+# ------------ EMOTIONS (unchanged) -------------

 def download_emotions() -> None:
     """Download GoEmotions for emotion classification."""
+    print("\nDownloading Emotions (GoEmotions)...")
     out_dir = OUTPUT_DIR / "emotion"

     ds = load_dataset("google-research-datasets/go_emotions", "simplified")

         write_jsonl(records, out_dir / f"{split}.jsonl", split)

     (out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
+    print(f"  {len(EMOTION_LABELS)} emotion labels saved")


+# --------------- GUTENBERG BOOKS (for language modeling) ---------------

 GUTENBERG_JUNK_PATTERNS = [
     r"Project Gutenberg", r"www\.gutenberg\.org", r"This ebook is for",


 def download_gutenberg(max_samples: int = 30000) -> None:
     """Download Gutenberg books for language modeling (English only)."""
+    print("\nDownloading Gutenberg Books (English only)...")
     out_dir = OUTPUT_DIR / "books"
     out_dir.mkdir(parents=True, exist_ok=True)

     write_jsonl(records[int(n*0.95):], out_dir / "test.jsonl", "test")


+# ------------ MAIN ------------

 def main() -> None:
     parser = argparse.ArgumentParser(description="Download LexiMind datasets")

     download_gutenberg(args.max_gutenberg)

     print("\n" + "=" * 60)
+    print("Download complete!")
     print("=" * 60)
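The calls above all funnel through the script's `write_jsonl` helper, which is defined earlier in `download_data.py` and outside this hunk. Judging only from the call sites (`write_jsonl(records, path, split)`), it is a plain one-object-per-line writer; a minimal sketch under that assumption:

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def write_jsonl(records: list[dict[str, Any]], path: Path, split: str) -> None:
    """Write one JSON object per line and report the count for the split.

    Hypothetical reconstruction -- the real helper lives in download_data.py.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    print(f"  {split}: {len(records)} records -> {path}")


# Usage matching the call sites above:
out = Path(tempfile.mkdtemp()) / "train.jsonl"
write_jsonl([{"text": "hello"}, {"text": "world"}], out, "train")
```

`ensure_ascii=False` keeps non-English Gutenberg characters readable in the output files rather than escaping them.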
scripts/evaluate.py
CHANGED
@@ -1,16 +1,17 @@
-#!/usr/bin/env python3
 """
 Comprehensive evaluation script for LexiMind.

 Evaluates all three tasks with full metrics:
-- Summarization: ROUGE-1/2/L, BLEU-4, BERTScore
-- Emotion:
-- Topic: Accuracy, Macro F1, Per-class metrics
+- Summarization: ROUGE-1/2/L, BLEU-4, per-domain breakdown (BERTScore optional)
+- Emotion: Sample-avg F1, Macro F1, Micro F1, per-class metrics, threshold tuning
+- Topic: Accuracy, Macro F1, Per-class metrics, bootstrap confidence intervals

 Usage:
     python scripts/evaluate.py
     python scripts/evaluate.py --checkpoint checkpoints/best.pt
-    python scripts/evaluate.py --
+    python scripts/evaluate.py --include-bertscore   # Include BERTScore (slow)
+    python scripts/evaluate.py --tune-thresholds     # Tune per-class emotion thresholds
+    python scripts/evaluate.py --bootstrap           # Compute confidence intervals

 Author: Oliver Perrin
 Date: January 2026
@@ -33,27 +34,22 @@ import torch
 from sklearn.metrics import accuracy_score, classification_report, f1_score
 from tqdm import tqdm

-from src.data.dataloader import (
-    build_emotion_dataloader,
-    build_summarization_dataloader,
-    build_topic_dataloader,
-)
 from src.data.dataset import (
-    EmotionDataset,
-    SummarizationDataset,
-    TopicDataset,
     load_emotion_jsonl,
     load_summarization_jsonl,
     load_topic_jsonl,
 )
-from src.data.tokenization import Tokenizer, TokenizerConfig
 from src.inference.factory import create_inference_pipeline
 from src.training.metrics import (
+    bootstrap_confidence_interval,
     calculate_bertscore,
     calculate_bleu,
     calculate_rouge,
     multilabel_f1,
+    multilabel_macro_f1,
+    multilabel_micro_f1,
+    multilabel_per_class_metrics,
+    tune_per_class_thresholds,
 )
@@ -63,21 +59,30 @@ def evaluate_summarization(
     max_samples: int | None = None,
     include_bertscore: bool = True,
     batch_size: int = 8,
+    compute_bootstrap: bool = False,
 ) -> dict:
-    """Evaluate summarization with comprehensive metrics."""
+    """Evaluate summarization with comprehensive metrics and per-domain breakdown."""
     print("\n" + "=" * 60)
     print("SUMMARIZATION EVALUATION")
     print("=" * 60)

-    # Load data
+    # Load data - try to get domain info from the raw JSONL
+    raw_data = []
+    with open(data_path) as f:
+        for line in f:
+            if line.strip():
+                raw_data.append(json.loads(line))
+
     data = load_summarization_jsonl(str(data_path))
     if max_samples:
         data = data[:max_samples]
+        raw_data = raw_data[:max_samples]
     print(f"Evaluating on {len(data)} samples...")

     # Generate summaries
     predictions = []
     references = []
+    domains = []  # Track domain for per-domain breakdown

     for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
         batch = data[i:i + batch_size]
@@ -87,15 +92,24 @@
         preds = pipeline.summarize(sources)
         predictions.extend(preds)
         references.extend(refs)
+
+        # Track domain if available
+        for j in range(len(batch)):
+            idx = i + j
+            if idx < len(raw_data):
+                domain = raw_data[idx].get("type", raw_data[idx].get("domain", "unknown"))
+                domains.append(domain)
+            else:
+                domains.append("unknown")

-    # Calculate metrics
+    # Calculate overall metrics
     print("\nCalculating ROUGE scores...")
     rouge_scores = calculate_rouge(predictions, references)

     print("Calculating BLEU score...")
     bleu = calculate_bleu(predictions, references)

-    metrics = {
+    metrics: dict = {
         "rouge1": rouge_scores["rouge1"],
         "rouge2": rouge_scores["rouge2"],
         "rougeL": rouge_scores["rougeL"],
@@ -110,6 +124,51 @@
     metrics["bertscore_recall"] = bert_scores["recall"]
     metrics["bertscore_f1"] = bert_scores["f1"]

+    # Per-domain breakdown
+    unique_domains = sorted(set(domains))
+    if len(unique_domains) > 1:
+        print("\nComputing per-domain breakdown...")
+        domain_metrics = {}
+        for domain in unique_domains:
+            if domain == "unknown":
+                continue
+            d_preds = [p for p, d in zip(predictions, domains, strict=True) if d == domain]
+            d_refs = [r for r, d in zip(references, domains, strict=True) if d == domain]
+            if not d_preds:
+                continue
+            d_rouge = calculate_rouge(d_preds, d_refs)
+            d_bleu = calculate_bleu(d_preds, d_refs)
+            dm: dict = {
+                "num_samples": len(d_preds),
+                "rouge1": d_rouge["rouge1"],
+                "rouge2": d_rouge["rouge2"],
+                "rougeL": d_rouge["rougeL"],
+                "bleu4": d_bleu,
+            }
+            if include_bertscore:
+                d_bert = calculate_bertscore(d_preds, d_refs)
+                dm["bertscore_f1"] = d_bert["f1"]
+            domain_metrics[domain] = dm
+        metrics["per_domain"] = domain_metrics
+
+    # Bootstrap confidence intervals
+    if compute_bootstrap:
+        try:
+            from rouge_score import rouge_scorer
+            scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+            per_sample_r1 = []
+            per_sample_rL = []
+            for pred, ref in zip(predictions, references, strict=True):
+                scores = scorer.score(ref, pred)
+                per_sample_r1.append(scores['rouge1'].fmeasure)
+                per_sample_rL.append(scores['rougeL'].fmeasure)
+            r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
+            rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
+            metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
+            metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
+        except ImportError:
+            pass
+
     # Print results
     print("\n" + "-" * 40)
     print("SUMMARIZATION RESULTS:")
@@ -123,6 +182,16 @@
     print(f"  BERTScore R: {metrics['bertscore_recall']:.4f}")
     print(f"  BERTScore F: {metrics['bertscore_f1']:.4f}")

+    if "per_domain" in metrics:
+        print("\n  Per-Domain Breakdown:")
+        for domain, dm in metrics["per_domain"].items():
+            bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
+            print(f"    {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}")
+
+    if "rouge1_ci" in metrics:
+        ci = metrics["rouge1_ci"]
+        print(f"\n  ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
     # Show examples
     print("\n" + "-" * 40)
     print("SAMPLE OUTPUTS:")
@@ -141,8 +210,14 @@ def evaluate_emotion(
     data_path: Path,
     max_samples: int | None = None,
     batch_size: int = 32,
+    tune_thresholds: bool = False,
+    compute_bootstrap: bool = False,
 ) -> dict:
-    """Evaluate emotion detection with multi-label metrics."""
+    """Evaluate emotion detection with comprehensive multi-label metrics.
+
+    Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
+    Optionally tunes per-class thresholds on the evaluation set.
+    """
     print("\n" + "=" * 60)
     print("EMOTION DETECTION EVALUATION")
     print("=" * 60)
@@ -153,9 +228,10 @@
         data = data[:max_samples]
     print(f"Evaluating on {len(data)} samples...")

-    # Get predictions
+    # Get predictions - collect raw logits for threshold tuning
     all_preds = []
     all_refs = []
+    all_logits_list = []

     for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
         batch = data[i:i + batch_size]
@@ -167,9 +243,17 @@
         all_preds.extend(pred_sets)
         all_refs.extend(refs)
+
+        # Also get raw logits for threshold tuning
+        if tune_thresholds:
+            encoded = pipeline.tokenizer.batch_encode(texts)
+            input_ids = encoded["input_ids"].to(pipeline.device)
+            attention_mask = encoded["attention_mask"].to(pipeline.device)
+            with torch.inference_mode():
+                logits = pipeline.model.forward("emotion", {"input_ids": input_ids, "attention_mask": attention_mask})
+            all_logits_list.append(logits.cpu())

     # Calculate metrics
-    # Convert to binary arrays for sklearn
     all_emotions = sorted(pipeline.emotion_labels)

     def to_binary(emotion_sets, labels):
@@ -178,41 +262,82 @@
     pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
     ref_binary = torch.tensor(to_binary(all_refs, all_emotions))

+    # Core metrics: sample-avg F1, macro F1, micro F1
+    sample_f1 = multilabel_f1(pred_binary, ref_binary)
+    macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
+    micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)

-    # Per-
-    for pred, ref in zip(all_preds, all_refs):
-        if len(pred) == 0 and len(ref) == 0:
-            sample_f1s.append(1.0)
-        elif len(pred) == 0 or len(ref) == 0:
-            sample_f1s.append(0.0)
-        else:
-            intersection = len(pred & ref)
-            precision = intersection / len(pred) if pred else 0
-            recall = intersection / len(ref) if ref else 0
-            if precision + recall > 0:
-                sample_f1s.append(2 * precision * recall / (precision + recall))
-            else:
-                sample_f1s.append(0.0)
+    # Per-class metrics
+    per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)

-        "sample_avg_f1": avg_f1,
+    metrics: dict = {
+        "sample_avg_f1": sample_f1,
+        "macro_f1": macro_f1,
+        "micro_f1": micro_f1,
         "num_samples": len(all_preds),
         "num_classes": len(all_emotions),
+        "per_class": per_class,
     }

+    # Per-class threshold tuning
+    if tune_thresholds and all_logits_list:
+        print("\nTuning per-class thresholds...")
+        all_logits = torch.cat(all_logits_list, dim=0)
+        best_thresholds, tuned_macro_f1 = tune_per_class_thresholds(all_logits, ref_binary)
+        metrics["tuned_thresholds"] = {
+            name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
+        }
+        metrics["tuned_macro_f1"] = tuned_macro_f1
+
+        # Also compute tuned sample-avg F1
+        probs = torch.sigmoid(all_logits)
+        tuned_preds = torch.zeros_like(probs)
+        for c, t in enumerate(best_thresholds):
+            tuned_preds[:, c] = (probs[:, c] >= t).float()
+        metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
+        metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
+
+    # Bootstrap confidence intervals
+    if compute_bootstrap:
+        # Compute per-sample F1 for bootstrapping
+        per_sample_f1s = []
+        for pred, ref in zip(all_preds, all_refs, strict=True):
+            if len(pred) == 0 and len(ref) == 0:
+                per_sample_f1s.append(1.0)
+            elif len(pred) == 0 or len(ref) == 0:
+                per_sample_f1s.append(0.0)
+            else:
+                intersection = len(pred & ref)
+                p = intersection / len(pred) if pred else 0
+                r = intersection / len(ref) if ref else 0
+                per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
+        mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
+        metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}
+
     # Print results
     print("\n" + "-" * 40)
     print("EMOTION DETECTION RESULTS:")
     print("-" * 40)
-    print(f"
-    print(f"
-    print(f"
+    print(f"  Sample-avg F1: {metrics['sample_avg_f1']:.4f}")
+    print(f"  Macro F1:      {metrics['macro_f1']:.4f}")
+    print(f"  Micro F1:      {metrics['micro_f1']:.4f}")
+    print(f"  Num Classes:   {metrics['num_classes']}")
+
+    if "tuned_macro_f1" in metrics:
+        print("\n  After per-class threshold tuning:")
+        print(f"    Tuned Macro F1:      {metrics['tuned_macro_f1']:.4f}")
+        print(f"    Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
+        print(f"    Tuned Micro F1:      {metrics['tuned_micro_f1']:.4f}")
+
+    if "sample_f1_ci" in metrics:
+        ci = metrics["sample_f1_ci"]
+        print(f"\n  Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
+    # Print top-10 per-class performance
+    print("\n  Per-class F1 (top 10 by support):")
+    sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
+    for name, m in sorted_classes[:10]:
+        print(f"    {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})")

     return metrics
@@ -222,8 +347,9 @@ def evaluate_topic(
     data_path: Path,
     max_samples: int | None = None,
     batch_size: int = 32,
+    compute_bootstrap: bool = False,
 ) -> dict:
-    """Evaluate topic classification."""
+    """Evaluate topic classification with per-class metrics and optional bootstrap CI."""
     print("\n" + "=" * 60)
     print("TOPIC CLASSIFICATION EVALUATION")
     print("=" * 60)
@@ -253,12 +379,18 @@
     accuracy = accuracy_score(all_refs, all_preds)
     macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)

-    metrics = {
+    metrics: dict = {
         "accuracy": accuracy,
         "macro_f1": macro_f1,
         "num_samples": len(all_preds),
     }

+    # Bootstrap confidence intervals for accuracy
+    if compute_bootstrap:
+        per_sample_correct = [1.0 if p == r else 0.0 for p, r in zip(all_preds, all_refs, strict=True)]
+        mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
+        metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}
+
     # Print results
     print("\n" + "-" * 40)
     print("TOPIC CLASSIFICATION RESULTS:")
@@ -266,6 +398,10 @@
     print(f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
     print(f"  Macro F1: {metrics['macro_f1']:.4f}")

+    if "accuracy_ci" in metrics:
+        ci = metrics["accuracy_ci"]
+        print(f"  Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
     # Classification report
     print("\n" + "-" * 40)
     print("PER-CLASS METRICS:")
@@ -282,7 +418,9 @@ def main():
     parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
     parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
     parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
-    parser.add_argument("--
+    parser.add_argument("--include-bertscore", action="store_true", help="Include BERTScore (slow, optional)")
+    parser.add_argument("--tune-thresholds", action="store_true", help="Tune per-class emotion thresholds on val set")
+    parser.add_argument("--bootstrap", action="store_true", help="Compute bootstrap confidence intervals")
     parser.add_argument("--summarization-only", action="store_true")
     parser.add_argument("--emotion-only", action="store_true")
     parser.add_argument("--topic-only", action="store_true")
@@ -320,10 +458,11 @@
         results["summarization"] = evaluate_summarization(
             pipeline, val_path,
             max_samples=args.max_samples,
-            include_bertscore=
+            include_bertscore=args.include_bertscore,
+            compute_bootstrap=args.bootstrap,
         )
     else:
-        print(
+        print("Warning: summarization validation data not found, skipping")

     # Evaluate emotion
     if eval_all or args.emotion_only:
@@ -334,9 +473,11 @@
         results["emotion"] = evaluate_emotion(
             pipeline, val_path,
             max_samples=args.max_samples,
+            tune_thresholds=args.tune_thresholds,
+            compute_bootstrap=args.bootstrap,
         )
     else:
-        print(
+        print("Warning: emotion validation data not found, skipping")

     # Evaluate topic
     if eval_all or args.topic_only:
@@ -347,9 +488,10 @@
         results["topic"] = evaluate_topic(
             pipeline, val_path,
             max_samples=args.max_samples,
+            compute_bootstrap=args.bootstrap,
         )
     else:
-        print(
+        print("Warning: topic validation data not found, skipping")

     # Save results
     print("\n" + "=" * 60)
@@ -370,18 +512,23 @@
     if "summarization" in results:
         s = results["summarization"]
-        print(
+        print("\n  Summarization:")
         print(f"    ROUGE-1: {s['rouge1']:.4f}")
+        print(f"    ROUGE-2: {s['rouge2']:.4f}")
         print(f"    ROUGE-L: {s['rougeL']:.4f}")
+        print(f"    BLEU-4: {s['bleu4']:.4f}")
         if "bertscore_f1" in s:
             print(f"    BERTScore F1: {s['bertscore_f1']:.4f}")

     if "emotion" in results:
-        print(
+        e = results["emotion"]
+        print("\n  Emotion:")
+        print(f"    Sample-avg F1: {e['sample_avg_f1']:.4f}")
+        print(f"    Macro F1: {e['macro_f1']:.4f}")
+        print(f"    Micro F1: {e['micro_f1']:.4f}")

     if "topic" in results:
-        print(
+        print("\n  Topic:")
         print(f"    Accuracy: {results['topic']['accuracy']:.2%}")
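The `bootstrap_confidence_interval` helper imported from `src.training.metrics` is not shown in this diff. Assuming it implements a standard percentile bootstrap over per-sample scores (the function name and `(mean, lower, upper)` return order are taken from the call sites above), a minimal sketch:

```python
import random
from statistics import mean


def bootstrap_confidence_interval(values, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap: resample with replacement, take empirical quantiles.

    Hypothetical reconstruction matching the (mean, lower, upper) call sites
    in evaluate.py; the real helper lives in src/training/metrics.py.
    """
    rng = random.Random(seed)
    n = len(values)
    # Mean of each resampled dataset, sorted so quantiles are simple indexing
    boot_means = sorted(mean(rng.choices(values, k=n)) for _ in range(n_resamples))
    lower = boot_means[int((alpha / 2) * n_resamples)]
    upper = boot_means[int((1 - alpha / 2) * n_resamples) - 1]
    return mean(values), lower, upper


# Per-sample ROUGE-1 F-measures, for example:
point, lo, hi = bootstrap_confidence_interval([0.21, 0.35, 0.48, 0.52, 0.66, 0.73])
print(f"mean={point:.3f}, 95% CI=[{lo:.3f}, {hi:.3f}]")
```

Because the input is always a flat list of per-sample scores, the same helper serves all three tasks: per-sample ROUGE F-measures for summarization, per-sample F1s for emotion, and 0/1 correctness indicators for topic accuracy.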
scripts/profile_training.py
ADDED
@@ -0,0 +1,314 @@
+"""
+Profile LexiMind training with PyTorch Profiler.
+
+Runs a few training steps under torch.profiler to capture:
+- CUDA kernel timing (per-operator breakdown)
+- GPU memory usage (peak allocations, memory timeline)
+- CPU/GPU overlap and idle time
+- Chrome trace (viewable in chrome://tracing or Perfetto UI)
+
+Outputs:
+    outputs/profile/  -- Chrome trace + stacks
+    stdout            -- Summary table of top CUDA operations
+
+Usage:
+    python scripts/profile_training.py                   # default: 20 steps
+    python scripts/profile_training.py training=full     # use full config
+    PROFILE_STEPS=40 python scripts/profile_training.py  # custom step count
+
+Author: Oliver Perrin
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+
+import hydra
+import torch
+from omegaconf import DictConfig
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.data.dataloader import (
+    build_emotion_dataloader,
+    build_summarization_dataloader,
+    build_topic_dataloader,
+)
+from src.data.dataset import (
+    EmotionDataset,
+    SummarizationDataset,
+    TopicDataset,
+    load_emotion_jsonl,
+    load_summarization_jsonl,
+    load_topic_jsonl,
+)
+from src.data.tokenization import Tokenizer, TokenizerConfig
+from src.models.factory import ModelConfig, build_multitask_model
+
+
+def load_splits(data_dir: Path, loader_fn):
+    splits = {}
+    for name, aliases in [("train", ["train"]), ("val", ["val", "validation"])]:
+        for alias in aliases:
+            path = data_dir / f"{alias}.jsonl"
+            if path.exists():
+                splits[name] = loader_fn(str(path))
+                break
+    return splits
+
+
+@hydra.main(version_base=None, config_path="../configs", config_name="config")
+def main(cfg: DictConfig) -> None:
+    profile_steps = int(os.environ.get("PROFILE_STEPS", 20))
+    warmup_steps = 3  # let CUDA graphs / torch.compile settle
+    active_steps = profile_steps - warmup_steps
+
+    device = torch.device(cfg.device)
+    if device.type != "cuda":
+        print("Profiler requires CUDA. Set device=cuda.")
+        return
+
+    print(f"Profiling {profile_steps} steps ({warmup_steps} warmup + {active_steps} active)")
+    print(f"GPU: {torch.cuda.get_device_name()}")
+
+    # ---------- Setup (mirrors train.py) ----------
+
+    torch.backends.cudnn.benchmark = True
+    if torch.cuda.get_device_capability()[0] >= 8:
+        torch.set_float32_matmul_precision("high")
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    data_cfg = cfg.data
+    trainer_cfg = cfg.training.get("trainer", {})
+
+    # Load small subsets -- profiling doesn't need the full dataset
+    max_samples = max(200, profile_steps * 10 * 3)
+    summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
+    emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
+    topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
+    for splits in [summ_splits, emot_splits, topic_splits]:
+        splits["train"] = splits["train"][:max_samples]
+
+    tok_cfg = data_cfg.get("tokenizer", {})
+    max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
+    tokenizer = Tokenizer(TokenizerConfig(
+        pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
+        max_length=max_len,
+    ))
+
+    summ_train = SummarizationDataset(summ_splits["train"])
+    emot_train = EmotionDataset(emot_splits["train"])
+    topic_train = TopicDataset(topic_splits["train"])
+
+    dl_cfg = cfg.training.get("dataloader", {})
+    batch_size = int(dl_cfg.get("batch_size", 8))
+    num_workers = int(dl_cfg.get("num_workers", 4))
+    classification_max_len = min(256, max_len)
+
+    train_loaders = {
+        "summarization": build_summarization_dataloader(
+            summ_train, tokenizer, shuffle=True,
+            max_source_length=max_len, max_target_length=max_len,
+            batch_size=batch_size, num_workers=num_workers, pin_memory=True,
batch_size=batch_size, num_workers=num_workers, pin_memory=True,
|
| 118 |
+
),
|
| 119 |
+
"emotion": build_emotion_dataloader(
|
| 120 |
+
emot_train, tokenizer, shuffle=True, max_length=classification_max_len,
|
| 121 |
+
batch_size=batch_size, num_workers=num_workers, pin_memory=True,
|
| 122 |
+
),
|
| 123 |
+
"topic": build_topic_dataloader(
|
| 124 |
+
topic_train, tokenizer, shuffle=True, max_length=classification_max_len,
|
| 125 |
+
batch_size=batch_size, num_workers=num_workers, pin_memory=True,
|
| 126 |
+
),
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
# Build model
|
| 130 |
+
grad_ckpt = cfg.training.get("gradient_checkpointing", cfg.model.get("gradient_checkpointing", False))
|
| 131 |
+
use_rel_pos = cfg.training.get("use_relative_position_bias", cfg.model.get("use_relative_position_bias", False))
|
| 132 |
+
|
| 133 |
+
model_cfg = ModelConfig(
|
| 134 |
+
d_model=cfg.model.d_model,
|
| 135 |
+
vocab_size=getattr(cfg.model, "vocab_size", None),
|
| 136 |
+
num_encoder_layers=cfg.model.num_encoder_layers,
|
| 137 |
+
num_decoder_layers=cfg.model.num_decoder_layers,
|
| 138 |
+
num_attention_heads=cfg.model.num_attention_heads,
|
| 139 |
+
ffn_dim=cfg.model.ffn_dim,
|
| 140 |
+
dropout=cfg.model.dropout,
|
| 141 |
+
use_pretrained=cfg.model.use_pretrained,
|
| 142 |
+
pretrained_model_name=cfg.model.pretrained_model_name,
|
| 143 |
+
activation=getattr(cfg.model, "activation", "gelu"),
|
| 144 |
+
use_relative_position_bias=use_rel_pos,
|
| 145 |
+
gradient_checkpointing=grad_ckpt,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
model = build_multitask_model(
|
| 149 |
+
tokenizer,
|
| 150 |
+
num_emotions=len(emot_train.emotion_classes),
|
| 151 |
+
num_topics=len(topic_train.topic_classes),
|
| 152 |
+
config=model_cfg,
|
| 153 |
+
).to(device)
|
| 154 |
+
|
| 155 |
+
# Freeze layers (same as train.py)
|
| 156 |
+
freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
|
| 157 |
+
if freeze_layers > 0:
|
| 158 |
+
if hasattr(model.encoder, "embed_tokens"):
|
| 159 |
+
for p in model.encoder.embed_tokens.parameters():
|
| 160 |
+
p.requires_grad = False
|
| 161 |
+
if hasattr(model.encoder, "layers"):
|
| 162 |
+
for i, layer in enumerate(model.encoder.layers):
|
| 163 |
+
if i < freeze_layers:
|
| 164 |
+
for p in layer.parameters():
|
| 165 |
+
p.requires_grad = False
|
| 166 |
+
|
| 167 |
+
# Compile (same as train.py)
|
| 168 |
+
compile_mode = "default" if grad_ckpt else "reduce-overhead"
|
| 169 |
+
if cfg.training.get("compile_encoder", True):
|
| 170 |
+
model.encoder = torch.compile(model.encoder, mode=compile_mode)
|
| 171 |
+
if cfg.training.get("compile_decoder", True):
|
| 172 |
+
model.decoder = torch.compile(model.decoder, mode=compile_mode)
|
| 173 |
+
|
| 174 |
+
# Optimizer
|
| 175 |
+
opt_cfg = cfg.training.get("optimizer", {})
|
| 176 |
+
use_fused = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
|
| 177 |
+
optimizer = torch.optim.AdamW(
|
| 178 |
+
model.parameters(),
|
| 179 |
+
lr=float(opt_cfg.get("lr", 3e-5)),
|
| 180 |
+
weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
|
| 181 |
+
fused=use_fused,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# ---------- Profile loop ----------
|
| 185 |
+
|
| 186 |
+
out_dir = PROJECT_ROOT / "outputs" / "profile"
|
| 187 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 188 |
+
|
| 189 |
+
model.train()
|
| 190 |
+
iterators = {task: iter(loader) for task, loader in train_loaders.items()}
|
| 191 |
+
task_names = list(train_loaders.keys())
|
| 192 |
+
accum = int(trainer_cfg.get("gradient_accumulation_steps", 4))
|
| 193 |
+
use_bf16 = torch.cuda.is_bf16_supported()
|
| 194 |
+
task_weights = trainer_cfg.get("task_weights") or {}
|
| 195 |
+
|
| 196 |
+
emotion_loss_fn = torch.nn.BCEWithLogitsLoss()
|
| 197 |
+
topic_loss_fn = torch.nn.CrossEntropyLoss()
|
| 198 |
+
|
| 199 |
+
def get_batch(task):
|
| 200 |
+
try:
|
| 201 |
+
batch = next(iterators[task])
|
| 202 |
+
except StopIteration:
|
| 203 |
+
iterators[task] = iter(train_loaders[task])
|
| 204 |
+
batch = next(iterators[task])
|
| 205 |
+
return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
|
| 206 |
+
for k, v in batch.items()}
|
| 207 |
+
|
| 208 |
+
def training_step(step):
|
| 209 |
+
"""One training step across all tasks."""
|
| 210 |
+
for task in task_names:
|
| 211 |
+
batch = get_batch(task)
|
| 212 |
+
dtype = torch.bfloat16 if use_bf16 else torch.float16
|
| 213 |
+
with torch.autocast("cuda", dtype=dtype):
|
| 214 |
+
if task == "summarization":
|
| 215 |
+
inputs = {"src_ids": batch["src_ids"], "tgt_ids": batch["tgt_ids"]}
|
| 216 |
+
if "src_mask" in batch:
|
| 217 |
+
inputs["src_mask"] = batch["src_mask"]
|
| 218 |
+
logits = model.forward("summarization", inputs)
|
| 219 |
+
loss = torch.nn.functional.cross_entropy(
|
| 220 |
+
logits.view(-1, logits.size(-1)),
|
| 221 |
+
batch["labels"].view(-1),
|
| 222 |
+
ignore_index=-100, label_smoothing=0.1,
|
| 223 |
+
)
|
| 224 |
+
elif task == "emotion":
|
| 225 |
+
inputs = {"input_ids": batch["input_ids"]}
|
| 226 |
+
if "attention_mask" in batch:
|
| 227 |
+
inputs["attention_mask"] = batch["attention_mask"]
|
| 228 |
+
logits = model.forward("emotion", inputs)
|
| 229 |
+
loss = emotion_loss_fn(logits, batch["labels"].float())
|
| 230 |
+
elif task == "topic":
|
| 231 |
+
inputs = {"input_ids": batch["input_ids"]}
|
| 232 |
+
if "attention_mask" in batch:
|
| 233 |
+
inputs["attention_mask"] = batch["attention_mask"]
|
| 234 |
+
logits = model.forward("topic", inputs)
|
| 235 |
+
loss = topic_loss_fn(logits, batch["labels"])
|
| 236 |
+
else:
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
weight = task_weights.get(task, 1.0)
|
| 240 |
+
scaled = (loss * weight) / accum
|
| 241 |
+
scaled.backward()
|
| 242 |
+
|
| 243 |
+
if (step + 1) % accum == 0:
|
| 244 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 245 |
+
optimizer.step()
|
| 246 |
+
optimizer.zero_grad()
|
| 247 |
+
|
| 248 |
+
# Warmup outside profiler to let torch.compile finish
|
| 249 |
+
print(f"\nWarmup ({warmup_steps} steps)...")
|
| 250 |
+
for s in range(warmup_steps):
|
| 251 |
+
training_step(s)
|
| 252 |
+
optimizer.zero_grad()
|
| 253 |
+
torch.cuda.synchronize()
|
| 254 |
+
|
| 255 |
+
# Profile
|
| 256 |
+
print(f"Profiling ({active_steps} steps)...")
|
| 257 |
+
trace_path = str(out_dir / "trace")
|
| 258 |
+
|
| 259 |
+
with torch.profiler.profile(
|
| 260 |
+
activities=[
|
| 261 |
+
torch.profiler.ProfilerActivity.CPU,
|
| 262 |
+
torch.profiler.ProfilerActivity.CUDA,
|
| 263 |
+
],
|
| 264 |
+
schedule=torch.profiler.schedule(
|
| 265 |
+
wait=1, warmup=2, active=active_steps - 3, repeat=1,
|
| 266 |
+
),
|
| 267 |
+
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
|
| 268 |
+
record_shapes=True,
|
| 269 |
+
profile_memory=True,
|
| 270 |
+
with_stack=True,
|
| 271 |
+
with_flops=True,
|
| 272 |
+
) as prof:
|
| 273 |
+
for s in range(active_steps):
|
| 274 |
+
training_step(warmup_steps + s)
|
| 275 |
+
prof.step()
|
| 276 |
+
|
| 277 |
+
torch.cuda.synchronize()
|
| 278 |
+
|
| 279 |
+
# ---------- Summary ----------
|
| 280 |
+
|
| 281 |
+
print("\n" + "=" * 80)
|
| 282 |
+
print("TOP CUDA OPERATIONS (by total CUDA time)")
|
| 283 |
+
print("=" * 80)
|
| 284 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=25))
|
| 285 |
+
|
| 286 |
+
print("\n" + "=" * 80)
|
| 287 |
+
print("TOP CUDA OPERATIONS (by GPU memory)")
|
| 288 |
+
print("=" * 80)
|
| 289 |
+
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=15))
|
| 290 |
+
|
| 291 |
+
# Memory summary
|
| 292 |
+
print("\n" + "=" * 80)
|
| 293 |
+
print("GPU MEMORY SUMMARY")
|
| 294 |
+
print("=" * 80)
|
| 295 |
+
print(torch.cuda.memory_summary(abbreviated=True))
|
| 296 |
+
|
| 297 |
+
# Export Chrome trace
|
| 298 |
+
chrome_trace = out_dir / "chrome_trace.json"
|
| 299 |
+
prof.export_chrome_trace(str(chrome_trace))
|
| 300 |
+
print(f"\nChrome trace: {chrome_trace}")
|
| 301 |
+
print(" Open in: chrome://tracing or https://ui.perfetto.dev")
|
| 302 |
+
|
| 303 |
+
# Export stacks for flamegraph
|
| 304 |
+
stacks_path = out_dir / "profiler_stacks.txt"
|
| 305 |
+
prof.export_stacks(str(stacks_path), "self_cuda_time_total")
|
| 306 |
+
print(f"CUDA stacks: {stacks_path}")
|
| 307 |
+
print(f" Generate flamegraph: flamegraph.pl {stacks_path} > flamegraph.svg")
|
| 308 |
+
|
| 309 |
+
print(f"\nTensorBoard traces: {trace_path}/")
|
| 310 |
+
print(f" View with: tensorboard --logdir={trace_path}")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
|
| 314 |
+
main()
|
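The `torch.profiler.schedule(wait=1, warmup=2, active=active_steps - 3, repeat=1)` call above controls which of the profiled steps are actually recorded. As a standalone illustration (not project code), the phase assignment for a single cycle works like this:

```python
# Illustration only: how torch.profiler.schedule(wait=1, warmup=2, active=N,
# repeat=1) assigns a phase to each profiler step. The script above subtracts
# 3 from active_steps so the wait/warmup overhead fits inside its step budget.

def schedule_phase(step: int, wait: int = 1, warmup: int = 2, active: int = 14) -> str:
    """Return the profiler phase for a given step index (single cycle)."""
    if step < wait:
        return "wait"      # profiler is idle
    if step < wait + warmup:
        return "warmup"    # tracing on, results discarded
    if step < wait + warmup + active:
        return "active"    # steps actually recorded in the trace
    return "done"          # past the cycle; nothing more recorded


phases = [schedule_phase(s) for s in range(6)]
print(phases)  # first steps: wait, then two warmup steps, then active steps
```

With the defaults shown (wait=1, warmup=2, active=14), only steps 3 through 16 end up in the exported trace.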
scripts/train.py
CHANGED

```diff
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 Training script for LexiMind.
 
@@ -97,9 +96,9 @@ def main(cfg: DictConfig) -> None:
         torch.set_float32_matmul_precision("high")
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-        print("
+        print(" TF32 + cudnn.benchmark enabled (Ampere GPU)")
     else:
-        print("
+        print(" cudnn.benchmark enabled")
 
     # --------------- Load Data ---------------
 
@@ -218,9 +217,9 @@ def main(cfg: DictConfig) -> None:
     )
 
     if grad_ckpt:
-        print("
+        print(" Gradient checkpointing: on")
     if not use_rel_pos:
-        print("
+        print(" FlashAttention: on (no relative position bias)")
 
     model = build_multitask_model(
         tokenizer,
@@ -249,7 +248,7 @@ def main(cfg: DictConfig) -> None:
             p.requires_grad = False
             frozen_params += p.numel()
     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f"
+    print(f" Frozen layers: 0-{freeze_layers-1} ({frozen_params/1e6:.1f}M params)")
     print(f" Trainable: {trainable:,} ({trainable/1e6:.1f}M)")
 
     # Resume from checkpoint?
@@ -269,10 +268,10 @@ def main(cfg: DictConfig) -> None:
     compile_mode = "default" if grad_ckpt else "reduce-overhead"
     if cfg.training.get("compile_encoder", True):
         model.encoder = torch.compile(model.encoder, mode=compile_mode)  # type: ignore[assignment]
-        print(f"
+        print(f" Encoder compiled ({compile_mode})")
     if cfg.training.get("compile_decoder", True):
         model.decoder = torch.compile(model.decoder, mode=compile_mode)  # type: ignore[assignment]
-        print(f"
+        print(f" Decoder compiled ({compile_mode})")
 
     # --------------- Train ---------------
 
@@ -289,7 +288,7 @@ def main(cfg: DictConfig) -> None:
         fused=use_fused,
     )
     if use_fused:
-        print("
+        print(" Fused AdamW: on")
 
     trainer = Trainer(
         model=model,
@@ -303,6 +302,9 @@ def main(cfg: DictConfig) -> None:
             scheduler_type=str(sched_cfg.get("name", "cosine")),
            warmup_steps=int(sched_cfg.get("warmup_steps", 500)),
            early_stopping_patience=trainer_cfg.get("early_stopping_patience"),
+           task_sampling=str(trainer_cfg.get("task_sampling", "temperature")),
+           task_sampling_alpha=float(trainer_cfg.get("task_sampling_alpha", 0.5)),
+           gradient_conflict_frequency=int(trainer_cfg.get("gradient_conflict_frequency", 0)),
         ),
         device=device,
         tokenizer=tokenizer,
@@ -326,7 +328,7 @@ def main(cfg: DictConfig) -> None:
         if val_loss < best_val_loss:
             best_val_loss = val_loss
             save_state(model, str(ckpt_dir / "best.pt"))
-            print(f"
+            print(f" New best model saved (val_loss={val_loss:.4f})")
 
     history = trainer.fit(
         train_loaders,
```
scripts/train_multiseed.py
ADDED

```python
"""
Multi-seed training wrapper for LexiMind.

Runs training across multiple seeds and aggregates results with mean ± std.
This addresses the single-seed limitation identified in review feedback.

Usage:
    python scripts/train_multiseed.py --seeds 17 42 123 --config training=full
    python scripts/train_multiseed.py --seeds 17 42 123 456 789 --config training=medium

Author: Oliver Perrin
Date: February 2026
"""

from __future__ import annotations

import argparse
import json
import subprocess
import sys
from pathlib import Path
from typing import Dict, List

import numpy as np


def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
    """Run training for a single seed and return the training history."""
    seed_dir = base_dir / f"seed_{seed}"
    seed_dir.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable, "scripts/train.py",
        f"seed={seed}",
        f"checkpoint_out={seed_dir}/checkpoints/best.pt",
        f"history_out={seed_dir}/training_history.json",
        f"labels_out={seed_dir}/labels.json",
    ]
    if config_overrides:
        cmd.extend(config_overrides.split())

    print(f"\n{'='*60}")
    print(f"Training seed {seed}")
    print(f"{'='*60}")
    print(f"  Command: {' '.join(cmd)}")

    result = subprocess.run(cmd, capture_output=False)
    if result.returncode != 0:
        print(f"  WARNING: Seed {seed} training failed (exit code {result.returncode})")
        return {}

    history_path = seed_dir / "training_history.json"
    if history_path.exists():
        with open(history_path) as f:
            data: Dict = json.load(f)  # type: ignore[no-any-return]
        return data
    return {}


def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = None) -> Dict:
    """Run evaluation for a single seed and return results."""
    seed_dir = base_dir / f"seed_{seed}"
    checkpoint = seed_dir / "checkpoints" / "best.pt"
    labels = seed_dir / "labels.json"
    output = seed_dir / "evaluation_report.json"

    if not checkpoint.exists():
        print(f"  Skipping eval for seed {seed}: no checkpoint found")
        return {}

    cmd = [
        sys.executable, "scripts/evaluate.py",
        f"--checkpoint={checkpoint}",
        f"--labels={labels}",
        f"--output={output}",
        "--skip-bertscore",
        "--tune-thresholds",
        "--bootstrap",
    ]
    if extra_args:
        cmd.extend(extra_args)

    print(f"\n  Evaluating seed {seed}...")
    result = subprocess.run(cmd, capture_output=False)
    if result.returncode != 0:
        print(f"  WARNING: Seed {seed} evaluation failed")
        return {}

    if output.exists():
        with open(output) as f:
            data: Dict = json.load(f)  # type: ignore[no-any-return]
        return data
    return {}


def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
    """Aggregate evaluation results across seeds with mean ± std."""
    if not all_results:
        return {}

    # Collect all metric paths
    metric_values: Dict[str, List[float]] = {}
    for _seed, results in all_results.items():
        for task, task_metrics in results.items():
            if not isinstance(task_metrics, dict):
                continue
            for metric_name, value in task_metrics.items():
                if isinstance(value, (int, float)) and metric_name not in ("num_samples", "num_classes"):
                    key = f"{task}/{metric_name}"
                    metric_values.setdefault(key, []).append(float(value))

    aggregated: Dict[str, Dict[str, float]] = {}
    for key, values in sorted(metric_values.items()):
        arr = np.array(values)
        aggregated[key] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "min": float(arr.min()),
            "max": float(arr.max()),
            "n_seeds": len(values),
        }

    return aggregated


def print_summary(aggregated: Dict, seeds: List[int]) -> None:
    """Print human-readable summary of multi-seed results."""
    print(f"\n{'='*70}")
    print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
    print(f"{'='*70}")

    # Group by task
    tasks: Dict[str, Dict[str, Dict]] = {}
    for key, stats in aggregated.items():
        task, metric = key.split("/", 1)
        tasks.setdefault(task, {})[metric] = stats

    for task, metrics in sorted(tasks.items()):
        print(f"\n  {task.upper()}:")
        for metric, stats in sorted(metrics.items()):
            mean = stats["mean"]
            std = stats["std"]
            # Format based on metric type
            if "accuracy" in metric:
                print(f"    {metric:25s}: {mean*100:.1f}% ± {std*100:.1f}%")
            else:
                print(f"    {metric:25s}: {mean:.4f} ± {std:.4f}")


def main():
    parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
    parser.add_argument("--seeds", nargs="+", type=int, default=[17, 42, 123],
                        help="Random seeds to train with")
    parser.add_argument("--config", type=str, default="",
                        help="Hydra config overrides (e.g., 'training=full')")
    parser.add_argument("--output-dir", type=Path, default=Path("outputs/multiseed"),
                        help="Base output directory")
    parser.add_argument("--skip-training", action="store_true",
                        help="Skip training, only aggregate existing results")
    parser.add_argument("--skip-eval", action="store_true",
                        help="Skip evaluation, only aggregate training histories")
    args = parser.parse_args()

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Training phase
    if not args.skip_training:
        for seed in args.seeds:
            run_single_seed(seed, args.config, args.output_dir)

    # Evaluation phase
    all_eval_results: Dict[int, Dict] = {}
    if not args.skip_eval:
        for seed in args.seeds:
            result = run_evaluation(seed, args.output_dir)
            if result:
                all_eval_results[seed] = result

    # Aggregate and save
    if all_eval_results:
        aggregated = aggregate_results(all_eval_results)
        print_summary(aggregated, args.seeds)

        # Save aggregated results
        output_path = args.output_dir / "aggregated_results.json"
        with open(output_path, "w") as f:
            json.dump({
                "seeds": args.seeds,
                "per_seed": {str(k): v for k, v in all_eval_results.items()},
                "aggregated": aggregated,
            }, f, indent=2)
        print(f"\n  Saved to: {output_path}")
    else:
        print("\nNo evaluation results to aggregate.")


if __name__ == "__main__":
    main()
```
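The mean ± std aggregation in `aggregate_results` can be reproduced with the standard library alone. A minimal sketch with invented metric values (note that `np.std` defaults to the population standard deviation, ddof=0, so `statistics.pstdev` is the matching stdlib call):

```python
# Minimal sketch of the per-metric aggregation performed by aggregate_results,
# stdlib only. The per-seed accuracy values below are invented for illustration.

import statistics

per_seed_reports = {
    17: {"topic": {"accuracy": 0.852}},
    42: {"topic": {"accuracy": 0.848}},
    123: {"topic": {"accuracy": 0.860}},
}

values = [r["topic"]["accuracy"] for r in per_seed_reports.values()]
mean = statistics.fmean(values)
std = statistics.pstdev(values)  # population std, like np.std(arr) with ddof=0
print(f"topic/accuracy: {mean:.4f} ± {std:.4f}")
```

With only three seeds the std is a rough stability indicator rather than a tight confidence interval, which is why the script also records min and max per metric.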
scripts/visualize_training.py
CHANGED

```diff
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 """
 LexiMind Training Visualization Suite.
 
@@ -63,16 +62,14 @@ except ImportError:
     pass
 
 try:
-    from mpl_toolkits.mplot3d import Axes3D  # type: ignore[import-
+    from mpl_toolkits.mplot3d import Axes3D  # type: ignore[import-not-found]  # noqa: F401
 
     HAS_MPLOT3D = True
 except ImportError:
     pass
 
 
-# =============================================================================
 # Configuration
-# =============================================================================
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)
@@ -116,10 +113,7 @@ HEATMAP_CMAP = LinearSegmentedColormap.from_list(
 )
 
 
-# =============================================================================
 # MLflow Utilities
-# =============================================================================
-
 
 def get_mlflow_client():
     """Get MLflow client with correct tracking URI."""
@@ -157,10 +151,7 @@ def get_metric_history(run, metric_name: str) -> tuple[list, list]:
     return [m.step for m in metrics], [m.value for m in metrics]
 
 
-# =============================================================================
 # Core Training Visualizations
-# =============================================================================
-
 
 def plot_loss_curves(run, interactive: bool = False) -> None:
     """
@@ -208,7 +199,7 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
 
         output_path = OUTPUTS_DIR / "training_loss_curve.html"
         fig.write_html(str(output_path))
-        logger.info(f"
+        logger.info(f"Saved interactive loss curve to {output_path}")
         return
 
     # Static matplotlib version
@@ -253,7 +244,7 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "training_loss_curve.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved loss curve to {output_path}")
     plt.close()
 
 
@@ -387,7 +378,7 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "task_metrics.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved task metrics to {output_path}")
     plt.close()
 
 
@@ -474,14 +465,11 @@ def plot_learning_rate(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "learning_rate_schedule.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved LR schedule to {output_path}")
     plt.close()
 
 
-# =============================================================================
 # Advanced Visualizations
-# =============================================================================
-
 
 def plot_confusion_matrix(run, task: str = "topic") -> None:
     """
@@ -544,7 +532,7 @@ def plot_confusion_matrix(run, task: str = "topic") -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / f"confusion_matrix_{task}.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved confusion matrix to {output_path}")
     plt.close()
 
 
@@ -646,7 +634,7 @@ def plot_3d_loss_landscape(run) -> None:
 
     output_path = OUTPUTS_DIR / "loss_landscape_3d.html"
     fig.write_html(str(output_path))
-    logger.info(f"
+    logger.info(f"Saved 3D loss landscape to {output_path}")
 
 
 def plot_3d_loss_landscape_static(run) -> None:
@@ -702,7 +690,7 @@ def plot_3d_loss_landscape_static(run) -> None:
     plt.tight_layout()
    output_path = OUTPUTS_DIR / "loss_landscape_3d.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved 3D loss landscape to {output_path}")
     plt.close()
 
 
@@ -770,7 +758,7 @@ def plot_embedding_space(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "embedding_space.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved embedding visualization to {output_path}")
     plt.close()
 
 
@@ -868,14 +856,11 @@ def plot_training_dynamics(run) -> None:
     plt.tight_layout()
     output_path = OUTPUTS_DIR / "training_dynamics.png"
     plt.savefig(output_path)
-    logger.info(f"
+    logger.info(f"Saved training dynamics to {output_path}")
     plt.close()
 
 
-# =============================================================================
 # Dashboard Generator
-# =============================================================================
-
 
 def generate_dashboard(run) -> None:
     """
@@ -959,13 +944,10 @@ def generate_dashboard(run) -> None:
 
     output_path = OUTPUTS_DIR / "training_dashboard.html"
     fig.write_html(str(output_path))
-    logger.info(f"
+    logger.info(f"Saved interactive dashboard to {output_path}")
 
 
-# =============================================================================
 # Main Entry Point
-# =============================================================================
-
 
 def main():
     """Generate all training visualizations."""
@@ -1026,7 +1008,7 @@ def main():
     # Summary
     logger.info("")
     logger.info("=" * 60)
-    logger.info("
     logger.info("=" * 60)
 
     outputs = [
```
|
|
|
|
|
|
|
| 951 |
|
| 952 |
def main():
|
| 953 |
"""Generate all training visualizations."""
|
|
|
|
| 1008 |
# Summary
|
| 1009 |
logger.info("")
|
| 1010 |
logger.info("=" * 60)
|
| 1011 |
+
logger.info("All visualizations saved to outputs/")
|
| 1012 |
logger.info("=" * 60)
|
| 1013 |
|
| 1014 |
outputs = [
|
src/data/dataset.py
CHANGED

```diff
@@ -11,10 +11,11 @@ Date: December 2025

 from __future__ import annotations

+import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Iterable, List, Sequence, TypeVar
+from typing import Callable, Dict, Iterable, List, Sequence, Set, TypeVar

 from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
 from torch.utils.data import Dataset
@@ -23,7 +24,6 @@ from torch.utils.data import Dataset
 @dataclass
 class SummarizationExample:
     """Container for abstractive summarization samples."""
-
     source: str
     summary: str
@@ -31,7 +31,6 @@ class SummarizationExample:
 @dataclass
 class EmotionExample:
     """Container for multi-label emotion classification samples."""
-
     text: str
     emotions: Sequence[str]
@@ -39,14 +38,12 @@ class EmotionExample:
 @dataclass
 class TopicExample:
     """Container for topic clustering / classification samples."""
-
     text: str
     topic: str


 class SummarizationDataset(Dataset[SummarizationExample]):
     """Dataset yielding encoder-decoder training pairs."""
-
     def __init__(self, examples: Iterable[SummarizationExample]) -> None:
         self._examples = list(examples)
@@ -59,7 +56,6 @@ class SummarizationDataset(Dataset[SummarizationExample]):
 class EmotionDataset(Dataset[EmotionExample]):
     """Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
-
     def __init__(
         self,
         examples: Iterable[EmotionExample],
@@ -95,7 +91,6 @@ class EmotionDataset(Dataset[EmotionExample]):
 class TopicDataset(Dataset[TopicExample]):
     """Dataset that owns a LabelEncoder for topic ids."""
-
     def __init__(
         self,
         examples: Iterable[TopicExample],
@@ -239,3 +234,82 @@ def load_topic_jsonl(path: str) -> List[TopicExample]:
         lambda payload: TopicExample(text=payload["text"], topic=payload["topic"]),
         required_keys=("text", "topic"),
     )
+
+
+# --------------- Cross-Task Deduplication ---------------
+
+
+def _text_fingerprint(text: str, n_chars: int = 200) -> str:
+    """Create a stable fingerprint from the first N characters of text.
+
+    Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
+    to detect document-level overlap across tasks.
+    """
+    normalized = " ".join(text.lower().split())[:n_chars]
+    return hashlib.md5(normalized.encode("utf-8")).hexdigest()
+
+
+def deduplicate_across_tasks(
+    summ_examples: List[SummarizationExample],
+    topic_examples: List[TopicExample],
+    emotion_examples: List[EmotionExample] | None = None,
+) -> Dict[str, int]:
+    """Detect and report cross-task document overlap.
+
+    Checks whether texts appearing in the summarization dataset also appear
+    in the topic or emotion datasets, which could create data leakage in MTL.
+
+    Returns:
+        Dict with overlap counts between task pairs.
+    """
+    summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
+    topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
+
+    overlap: Dict[str, int] = {
+        "summ_topic_overlap": len(summ_fps & topic_fps),
+        "summ_total": len(summ_fps),
+        "topic_total": len(topic_fps),
+    }
+
+    if emotion_examples:
+        emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
+        overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
+        overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
+        overlap["emotion_total"] = len(emot_fps)
+
+    return overlap
+
+
+def remove_overlapping_examples(
+    primary_examples: List[TopicExample],
+    reference_examples: List[SummarizationExample],
+    split: str = "val",
+) -> tuple[List[TopicExample], int]:
+    """Remove topic examples whose texts overlap with summarization data.
+
+    This prevents cross-task data leakage where a document seen during
+    summarization training could boost topic classification on validation/test.
+
+    Args:
+        primary_examples: Topic examples to filter
+        reference_examples: Summarization examples to check against
+        split: Name of split being processed (for logging)
+
+    Returns:
+        Tuple of (filtered_examples, num_removed)
+    """
+    ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
+
+    filtered = []
+    removed = 0
+    for ex in primary_examples:
+        fp = _text_fingerprint(ex.text)
+        if fp in ref_fps:
+            removed += 1
+        else:
+            filtered.append(ex)
+
+    if removed > 0:
+        print(f"  Dedup: removed {removed} overlapping examples from topic {split}")
+
+    return filtered, removed
```
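The fingerprinting trick in this diff (hash of a normalized 200-character prefix) is easy to sanity-check in isolation. A minimal stand-alone sketch, independent of the repository's dataclasses (`text_fingerprint` here is an illustrative copy of the logic, not an import from the repo):

```python
import hashlib


def text_fingerprint(text: str, n_chars: int = 200) -> str:
    # Lowercase, collapse whitespace, keep the first n_chars, then hash.
    normalized = " ".join(text.lower().split())[:n_chars]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()


summ_sources = ["It was the best of times, it was the worst of times."]
topic_texts = [
    "It was the BEST of   times, it was the worst of times.",  # same doc, noisier formatting
    "Call me Ishmael.",
]

summ_fps = {text_fingerprint(t) for t in summ_sources}
leaked = [t for t in topic_texts if text_fingerprint(t) in summ_fps]
# Only the reformatted duplicate is flagged; the unrelated line survives.
```

Because the prefix is case- and whitespace-normalized before hashing, trivially reformatted copies of the same document still collide, which is exactly what cross-task leakage detection needs.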
src/models/factory.py
CHANGED

```diff
@@ -102,7 +102,7 @@ def _load_pretrained_weights(
     Load pretrained T5/FLAN-T5 weights into custom encoder/decoder.

     T5 architecture compatibility with our custom Transformer:
-    - T5 uses Pre-LN (RMSNorm before sublayers)
+    - T5 uses Pre-LN (RMSNorm before sublayers) - matches our design
     - T5 uses relative position bias instead of absolute embeddings
       -> We now load T5's relative position bias weights into our T5RelativePositionBias modules
       -> This allows exact weight transfer without requiring fine-tuning
@@ -548,13 +548,15 @@ def build_multitask_model(
         "summarization",
         LMHead(d_model=cfg.d_model, vocab_size=vocab_size, tie_embedding=decoder.embedding),
     )
-    # Emotion head with 2-layer MLP for better multi-label capacity (28 classes)
+    # Emotion head with attention pooling + 2-layer MLP for better multi-label capacity (28 classes)
+    # Attention pooling is superior to mean pooling for encoder-decoder models where
+    # hidden states are optimized for cross-attention rather than simple averaging.
     model.add_head(
         "emotion",
         ClassificationHead(
             d_model=cfg.d_model,
             num_labels=num_emotions,
-            pooler="
+            pooler="attention",
             dropout=cfg.dropout,
             hidden_dim=cfg.d_model // 2,  # 384-dim hidden layer
         ),
```
src/models/heads.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 """Prediction heads for Transformer models.

 This module provides task-specific output heads:
-- ClassificationHead: Sequence-level classification with pooling (mean/cls/max)
+- ClassificationHead: Sequence-level classification with pooling (mean/cls/max/attention)
 - TokenClassificationHead: Per-token classification (NER, POS tagging)
 - LMHead: Language modeling logits with optional weight tying
 - ProjectionHead: MLP for representation learning / contrastive tasks
@@ -14,6 +14,35 @@ from typing import Literal, Optional

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AttentionPooling(nn.Module):
+    """Learned attention pooling over sequence positions.
+
+    Computes a weighted sum of hidden states using a learned query vector,
+    producing a single fixed-size representation. This is generally superior
+    to mean pooling for classification tasks on encoder-decoder models where
+    hidden states are optimized for cross-attention rather than pooling.
+    """
+
+    def __init__(self, d_model: int):
+        super().__init__()
+        self.query = nn.Linear(d_model, 1, bias=False)
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        x: (batch, seq_len, d_model)
+        mask: (batch, seq_len) - True for valid tokens, False for padding
+        returns: (batch, d_model)
+        """
+        # Compute attention scores: (batch, seq_len, 1)
+        scores = self.query(x)
+        if mask is not None:
+            scores = scores.masked_fill(~mask.unsqueeze(-1), float("-inf"))
+        weights = F.softmax(scores, dim=1)  # (batch, seq_len, 1)
+        # Weighted sum: (batch, d_model)
+        return (weights * x).sum(dim=1)


 class ClassificationHead(nn.Module):
@@ -23,7 +52,7 @@ class ClassificationHead(nn.Module):
     Args:
         d_model: hidden size from encoder/decoder
         num_labels: number of output classes
-        pooler: one of 'mean', 'cls', 'max' - how to pool the sequence
+        pooler: one of 'mean', 'cls', 'max', 'attention' - how to pool the sequence
         dropout: dropout probability before final linear layer
         hidden_dim: optional intermediate dimension for 2-layer MLP (improves capacity)
     """
@@ -32,14 +61,17 @@ class ClassificationHead(nn.Module):
         self,
         d_model: int,
         num_labels: int,
-        pooler: Literal["mean", "cls", "max"] = "mean",
+        pooler: Literal["mean", "cls", "max", "attention"] = "mean",
         dropout: float = 0.1,
         hidden_dim: Optional[int] = None,
     ):
         super().__init__()
-        assert pooler in ("mean", "cls", "max"), "pooler must be 'mean'|'cls'|'max'"
+        assert pooler in ("mean", "cls", "max", "attention"), "pooler must be 'mean'|'cls'|'max'|'attention'"
         self.pooler = pooler
         self.dropout = nn.Dropout(dropout)
+
+        if pooler == "attention":
+            self.attn_pool = AttentionPooling(d_model)

         # Optional 2-layer MLP for more capacity (useful for multi-label)
         if hidden_dim is not None:
@@ -58,19 +90,14 @@ class ClassificationHead(nn.Module):
         mask: (batch, seq_len) - True for valid tokens, False for padding
         returns: (batch, num_labels)
         """
-        if self.pooler == "
+        if self.pooler == "attention":
+            pooled = self.attn_pool(x, mask)
+        elif self.pooler == "mean":
             if mask is not None:
-                # mask is (B, S)
-                # x is (B, S, D)
-                # Expand mask to (B, S, 1)
                 mask_expanded = mask.unsqueeze(-1).float()
-                # Zero out padding
                 x = x * mask_expanded
-                # Sum over sequence
                 sum_embeddings = x.sum(dim=1)
-                # Count valid tokens
                 sum_mask = mask_expanded.sum(dim=1)
-                # Avoid division by zero
                 sum_mask = torch.clamp(sum_mask, min=1e-9)
                 pooled = sum_embeddings / sum_mask
             else:
@@ -79,7 +106,6 @@ class ClassificationHead(nn.Module):
             pooled = x[:, 0, :]
         else:  # max
             if mask is not None:
-                # Mask padding with -inf
                 mask_expanded = mask.unsqueeze(-1)
                 x = x.masked_fill(~mask_expanded, float("-inf"))
                 pooled, _ = x.max(dim=1)
```
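To see what `AttentionPooling` computes without pulling in the model code, here is a numerical sketch of the same math in NumPy (a random vector stands in for the learned `nn.Linear` query; shapes follow the docstring in the diff, and this assumes only that NumPy is available):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model = 1, 3, 4
x = rng.normal(size=(batch, seq_len, d_model))   # hidden states
query = rng.normal(size=(d_model, 1))            # stand-in for the learned query weight
mask = np.array([[True, True, False]])           # last position is padding

scores = x @ query                               # (batch, seq_len, 1)
scores[~mask[..., None]] = -np.inf               # mask padding before the softmax
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = weights / weights.sum(axis=1, keepdims=True)   # softmax over seq_len
pooled = (weights * x).sum(axis=1)               # (batch, d_model)
```

Masking with `-inf` before the softmax drives the padding position's weight to exactly zero, so padded tokens cannot contribute to the pooled vector; the valid positions still receive weights that sum to one.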
src/training/metrics.py
CHANGED

```diff
@@ -90,7 +90,7 @@ def calculate_bertscore(
         return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

     try:
-        from bert_score import score as bert_score
+        from bert_score import score as bert_score  # type: ignore[import-not-found]
     except ImportError:
         print("Warning: bert-score not installed. Run: pip install bert-score")
         return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
@@ -110,9 +110,9 @@ def calculate_bertscore(
     )

     return {
-        "precision": float(P.mean().item()),
-        "recall": float(R.mean().item()),
-        "f1": float(F1.mean().item()),
+        "precision": float(P.mean().item()),  # type: ignore[union-attr]
+        "recall": float(R.mean().item()),  # type: ignore[union-attr]
+        "f1": float(F1.mean().item()),  # type: ignore[union-attr]
     }
@@ -239,3 +239,213 @@ def get_confusion_matrix(
 ) -> np.ndarray:
     """Compute confusion matrix."""
     return cast(np.ndarray, confusion_matrix(targets, predictions, labels=labels))
+
+
+# --------------- Multi-label Emotion Metrics ---------------
+
+
+def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+    """Compute macro F1: average F1 per class (as in GoEmotions paper).
+
+    This averages F1 across labels, giving equal weight to each emotion class
+    regardless of prevalence. Directly comparable to GoEmotions baselines.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+
+    # Per-class TP, FP, FN
+    tp = (preds * gold).sum(dim=0)
+    fp = (preds * (1 - gold)).sum(dim=0)
+    fn = ((1 - preds) * gold).sum(dim=0)
+
+    precision = tp / (tp + fp).clamp(min=1e-8)
+    recall = tp / (tp + fn).clamp(min=1e-8)
+    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
+
+    # Zero out F1 for classes with no support in either predictions or targets
+    mask = (tp + fp + fn) > 0
+    if mask.sum() == 0:
+        return 0.0
+    return float(f1[mask].mean().item())
+
+
+def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+    """Compute micro F1: aggregate TP/FP/FN across all classes.
+
+    This gives more weight to frequent classes. Useful when class distribution matters.
+    """
+    preds = predictions.float()
+    gold = targets.float()

+    tp = (preds * gold).sum()
+    fp = (preds * (1 - gold)).sum()
+    fn = ((1 - preds) * gold).sum()
+
+    precision = tp / (tp + fp).clamp(min=1e-8)
+    recall = tp / (tp + fn).clamp(min=1e-8)
+    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
+    return float(f1.item())
+
+
+def multilabel_per_class_metrics(
+    predictions: torch.Tensor,
+    targets: torch.Tensor,
+    class_names: Sequence[str] | None = None,
+) -> Dict[str, Dict[str, float]]:
+    """Compute per-class precision, recall, F1 for multi-label classification.
+
+    Returns a dict mapping class name/index to its metrics.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+    num_classes = preds.shape[1]
+
+    tp = (preds * gold).sum(dim=0)
+    fp = (preds * (1 - gold)).sum(dim=0)
+    fn = ((1 - preds) * gold).sum(dim=0)
+
+    report: Dict[str, Dict[str, float]] = {}
+    for i in range(num_classes):
+        name = class_names[i] if class_names else str(i)
+        p = (tp[i] / (tp[i] + fp[i]).clamp(min=1e-8)).item()
+        r = (tp[i] / (tp[i] + fn[i]).clamp(min=1e-8)).item()
+        f = (2 * p * r) / max(p + r, 1e-8)
+        report[name] = {
+            "precision": p,
+            "recall": r,
+            "f1": f,
+            "support": int(gold[:, i].sum().item()),
+        }
+    return report
+
+
+def tune_per_class_thresholds(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    thresholds: Sequence[float] | None = None,
+) -> tuple[List[float], float]:
+    """Tune per-class thresholds on validation set to maximize macro F1.
+
+    For each class, tries multiple thresholds and selects the one that
+    maximizes that class's F1 score. This is standard practice for multi-label
+    classification (used in the original GoEmotions paper).
+
+    Args:
+        logits: Raw model logits (batch, num_classes)
+        targets: Binary target labels (batch, num_classes)
+        thresholds: Candidate thresholds to try (default: 0.1 to 0.9 by 0.05)
+
+    Returns:
+        Tuple of (best_thresholds_per_class, resulting_macro_f1)
+    """
+    if thresholds is None:
+        thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
+
+    probs = torch.sigmoid(logits)
+    num_classes = probs.shape[1]
+    gold = targets.float()
+
+    best_thresholds: List[float] = []
+    for c in range(num_classes):
+        best_f1 = -1.0
+        best_t = 0.5
+        for t in thresholds:
+            preds = (probs[:, c] >= t).float()
+            tp = (preds * gold[:, c]).sum()
+            fp = (preds * (1 - gold[:, c])).sum()
+            fn = ((1 - preds) * gold[:, c]).sum()
+            if tp + fp > 0 and tp + fn > 0:
+                p = tp / (tp + fp)
+                r = tp / (tp + fn)
+                f1 = (2 * p * r / (p + r)).item()
+            else:
+                f1 = 0.0
+            if f1 > best_f1:
+                best_f1 = f1
+                best_t = t
+        best_thresholds.append(best_t)
+
+    # Compute resulting macro F1 with tuned thresholds
+    tuned_preds = torch.zeros_like(probs)
+    for c in range(num_classes):
+        tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
+    macro_f1 = multilabel_macro_f1(tuned_preds, targets)
+
+    return best_thresholds, macro_f1
```
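The macro/micro distinction used by `multilabel_macro_f1` and `multilabel_micro_f1` can be checked by hand on a toy case. A stand-alone sketch with plain lists, using the equivalent formula F1 = 2*TP / (2*TP + FP + FN) rather than the tensor version in the diff:

```python
def per_class_f1(preds, gold):
    # preds/gold: rows of 0/1 label indicators, one column per class
    n_classes = len(preds[0])
    f1s = []
    for c in range(n_classes):
        tp = sum(1 for p, g in zip(preds, gold) if p[c] and g[c])
        fp = sum(1 for p, g in zip(preds, gold) if p[c] and not g[c])
        fn = sum(1 for p, g in zip(preds, gold) if not p[c] and g[c])
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return f1s


preds = [[1, 0], [1, 0], [1, 1]]
gold = [[1, 0], [1, 1], [1, 1]]

f1s = per_class_f1(preds, gold)          # class 0 is perfect, class 1 misses one label
macro_f1 = sum(f1s) / len(f1s)           # equal weight per class: (1 + 2/3) / 2 = 5/6
tp = sum(1 for p, g in zip(preds, gold) for c in range(2) if p[c] and g[c])
fp = sum(1 for p, g in zip(preds, gold) for c in range(2) if p[c] and not g[c])
fn = sum(1 for p, g in zip(preds, gold) for c in range(2) if not p[c] and g[c])
micro_f1 = 2 * tp / (2 * tp + fp + fn)   # pooled counts: 8/9
```

Because the frequent class is also the well-predicted one here, micro F1 (8/9) comes out higher than macro F1 (5/6), which is exactly why the rarer GoEmotions labels pull the macro score down.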
```diff
+
+
+# --------------- Statistical Tests ---------------
+
+
+def bootstrap_confidence_interval(
+    scores: Sequence[float],
+    n_bootstrap: int = 1000,
+    confidence: float = 0.95,
+    seed: int = 42,
+) -> tuple[float, float, float]:
+    """Compute bootstrap confidence interval for a metric.
+
+    Args:
+        scores: Per-sample metric values
+        n_bootstrap: Number of bootstrap resamples
+        confidence: Confidence level (default 95%)
+        seed: Random seed for reproducibility
+
+    Returns:
+        Tuple of (mean, lower_bound, upper_bound)
+    """
+    rng = np.random.default_rng(seed)
+    scores_arr = np.array(scores)
+    n = len(scores_arr)
+
+    bootstrap_means = []
+    for _ in range(n_bootstrap):
+        sample = rng.choice(scores_arr, size=n, replace=True)
+        bootstrap_means.append(float(np.mean(sample)))
+
+    bootstrap_means.sort()
+    alpha = 1 - confidence
+    lower_idx = int(alpha / 2 * n_bootstrap)
+    upper_idx = int((1 - alpha / 2) * n_bootstrap)
+
+    return (
+        float(np.mean(scores_arr)),
+        bootstrap_means[lower_idx],
+        bootstrap_means[min(upper_idx, n_bootstrap - 1)],
+    )
+
+
+def paired_bootstrap_test(
+    scores_a: Sequence[float],
+    scores_b: Sequence[float],
+    n_bootstrap: int = 10000,
+    seed: int = 42,
+) -> float:
+    """Paired bootstrap significance test between two systems.
+
+    Tests if system B is significantly better than system A.
+
+    Args:
+        scores_a: Per-sample scores from system A
+        scores_b: Per-sample scores from system B
+        n_bootstrap: Number of bootstrap iterations
+        seed: Random seed
+
+    Returns:
+        p-value (probability that B is not better than A)
+    """
+    rng = np.random.default_rng(seed)
+    a = np.array(scores_a)
+    b = np.array(scores_b)
+    assert len(a) == len(b), "Both score lists must have the same length"
+
+    n = len(a)
+
+    count = 0
+    for _ in range(n_bootstrap):
+        idx = rng.choice(n, size=n, replace=True)
+        diff = float(np.mean(b[idx]) - np.mean(a[idx]))
+        if diff <= 0:
+            count += 1
+
+    return count / n_bootstrap
```
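The percentile-bootstrap procedure above can be reproduced with only the standard library. This sketch mirrors the NumPy version (stdlib `random` in place of `numpy.random`; the `bootstrap_ci` name and the sample scores are illustrative):

```python
import random
import statistics


def bootstrap_ci(scores, n_bootstrap=1000, confidence=0.95, seed=42):
    # Resample with replacement, take the mean of each resample,
    # and read the CI off the sorted bootstrap means.
    rng = random.Random(seed)
    n = len(scores)
    means = sorted(statistics.fmean(rng.choices(scores, k=n)) for _ in range(n_bootstrap))
    alpha = 1 - confidence
    lower = means[int(alpha / 2 * n_bootstrap)]
    upper = means[min(int((1 - alpha / 2) * n_bootstrap), n_bootstrap - 1)]
    return statistics.fmean(scores), lower, upper


per_sample_f1 = [0.81, 0.84, 0.79, 0.86, 0.88, 0.80, 0.83, 0.85]
mean, lo, hi = bootstrap_ci(per_sample_f1)
# lo <= mean <= hi, and both bounds stay within [min(scores), max(scores)]
```

Reporting the interval alongside the point estimate is what makes small metric gaps between runs interpretable: if two systems' intervals overlap heavily, a paired test like `paired_bootstrap_test` is the right follow-up.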
src/training/trainer.py
CHANGED

```diff
@@ -7,6 +7,8 @@ Handles training across summarization, emotion, and topic heads with:
 - Cosine LR schedule with warmup
 - Early stopping
 - MLflow logging
+- Temperature-based task sampling (configurable alpha)
+- Gradient conflict diagnostics

 Author: Oliver Perrin
 Date: December 2025
@@ -22,6 +24,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Dict, List

 import mlflow
 import torch
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import LambdaLR
@@ -53,6 +56,16 @@ class TrainerConfig:
     # Early stopping
     early_stopping_patience: int | None = 5

     # MLflow
     experiment_name: str = "LexiMind"
     run_name: str | None = None
@@ -167,8 +180,8 @@ class Trainer:
             if self.early_stopping:
                 val_loss = val_metrics.get("total_loss", float('inf'))
                 if self.early_stopping(val_loss):
-                    tqdm.write(f"\
-
                     break

             # Checkpoint
@@ -181,7 +194,7 @@ class Trainer:
             pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})

         total_time = time.perf_counter() - total_start
-        print(f"\
         return history

     def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
@@ -201,7 +214,7 @@ class Trainer:
             return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))

         self.scheduler = LambdaLR(self.optimizer, lr_lambda)
-        print(f"

     def _run_epoch(
         self,
@@ -210,7 +223,7 @@ class Trainer:
         train: bool,
         epoch: int,
     ) -> Dict[str, float]:
-        """Run one epoch."""
         self.model.train(train)
         metrics: Dict[str, List[float]] = defaultdict(list)
         iterators = {task: iter(loader) for task, loader in loaders.items()}
@@ -220,12 +233,33 @@ class Trainer:
         phase = "Train" if train else "Val"
         pbar = tqdm(range(max_batches), desc=f"  {phase}", leave=False, file=sys.stderr)

         ctx = torch.enable_grad() if train else torch.no_grad()
         with ctx:
             for step in pbar:
                 step_loss = 0.0

-
                 batch = self._get_batch(iterators, loader, task)
                 if batch is None:
                     continue
@@ -253,6 +287,14 @@ class Trainer:
                     scaled = (loss * weight) / accum
                     scaled.backward()

                 # Optimizer step
                 if train and (step + 1) % accum == 0:
                     torch.nn.utils.clip_grad_norm_(
@@ -415,6 +457,56 @@ class Trainer:
         tqdm.write(f"{'=' * 50}\n")
         self.model.train()

     def _log_config(self) -> None:
         """Log config to MLflow."""
         mlflow.log_params({
```
| 13 |
Author: Oliver Perrin
|
| 14 |
Date: December 2025
|
|
|
|
| 24 |
from typing import Any, Callable, Dict, List
|
| 25 |
|
| 26 |
import mlflow
|
| 27 |
+
import numpy as np
|
| 28 |
import torch
|
| 29 |
import torch.nn.functional as F
|
| 30 |
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
| 56 |
# Early stopping
|
| 57 |
early_stopping_patience: int | None = 5
|
| 58 |
|
| 59 |
+
# Task sampling strategy: "round_robin" or "temperature"
|
| 60 |
+
# Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
|
| 61 |
+
# alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
|
| 62 |
+
task_sampling: str = "temperature"
|
| 63 |
+
task_sampling_alpha: float = 0.5
|
| 64 |
+
|
| 65 |
+
# Gradient conflict diagnostics
|
| 66 |
+
# Compute inter-task gradient cosine similarity every N steps (0 = disabled)
|
| 67 |
+
gradient_conflict_frequency: int = 0
|
| 68 |
+
|
| 69 |
# MLflow
|
| 70 |
experiment_name: str = "LexiMind"
|
| 71 |
run_name: str | None = None
|
|
|
|
| 180 |
if self.early_stopping:
|
| 181 |
val_loss = val_metrics.get("total_loss", float('inf'))
|
| 182 |
if self.early_stopping(val_loss):
|
| 183 |
+
tqdm.write(f"\nEarly stopping at epoch {epoch} (best loss: {self.early_stopping.best_value:.4f})")
|
| 184 |
+
|
| 185 |
break
|
| 186 |
|
| 187 |
# Checkpoint
|
|
|
|
| 194 |
pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})
|
| 195 |
|
| 196 |
total_time = time.perf_counter() - total_start
|
| 197 |
+
print(f"\nTraining complete in {total_time/60:.1f} minutes")
|
| 198 |
return history
|
| 199 |
|
| 200 |
def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
|
|
|
|
| 214 |
return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
|
| 215 |
|
| 216 |
self.scheduler = LambdaLR(self.optimizer, lr_lambda)
|
| 217 |
+
print(f" LR schedule: cosine, {warmup} warmup, {total_steps} total steps")
|
| 218 |
|
| 219 |
def _run_epoch(
|
| 220 |
self,
|
|
|
|
| 223 |
train: bool,
|
| 224 |
epoch: int,
|
| 225 |
) -> Dict[str, float]:
|
| 226 |
+
"""Run one epoch with configurable task sampling strategy."""
|
| 227 |
self.model.train(train)
|
| 228 |
metrics: Dict[str, List[float]] = defaultdict(list)
|
| 229 |
iterators = {task: iter(loader) for task, loader in loaders.items()}
|
|
|
|
| 233 |
phase = "Train" if train else "Val"
|
| 234 |
pbar = tqdm(range(max_batches), desc=f" {phase}", leave=False, file=sys.stderr)
|
| 235 |
|
| 236 |
+
# Temperature-based task sampling: p_i ∝ n_i^alpha
|
| 237 |
+
task_names = list(loaders.keys())
|
| 238 |
+
if self.config.task_sampling == "temperature" and len(task_names) > 1:
|
| 239 |
+
sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64) # type: ignore[arg-type]
|
| 240 |
+
alpha = self.config.task_sampling_alpha
|
| 241 |
+
probs = sizes ** alpha
|
| 242 |
+
probs = probs / probs.sum()
|
| 243 |
+
tqdm.write(f" Temperature sampling (α={alpha}): " +
|
| 244 |
+
", ".join(f"{t}={p:.2%}" for t, p in zip(task_names, probs, strict=True)))
|
| 245 |
+
else:
|
| 246 |
+
probs = None
|
| 247 |
+
|
| 248 |
ctx = torch.enable_grad() if train else torch.no_grad()
|
| 249 |
with ctx:
|
| 250 |
for step in pbar:
|
| 251 |
step_loss = 0.0
|
| 252 |
|
| 253 |
+
# Select tasks for this step
|
| 254 |
+
if probs is not None and train:
|
| 255 |
+
# Temperature sampling: sample tasks based on dataset size
|
| 256 |
+
selected_tasks = list(np.random.choice(task_names, size=len(task_names), replace=True, p=probs))
|
| 257 |
+
else:
|
| 258 |
+
# Round-robin: all tasks every step
|
| 259 |
+
selected_tasks = task_names
|
| 260 |
+
|
| 261 |
+
for task in selected_tasks:
|
| 262 |
+
loader = loaders[task]
|
| 263 |
batch = self._get_batch(iterators, loader, task)
|
| 264 |
if batch is None:
|
| 265 |
continue
|
|
|
|
| 287 |
scaled = (loss * weight) / accum
|
| 288 |
scaled.backward()
|
| 289 |
|
| 290 |
+
# Gradient conflict diagnostics
|
| 291 |
+
if (train and self.config.gradient_conflict_frequency > 0
|
| 292 |
+
and (step + 1) % self.config.gradient_conflict_frequency == 0):
|
| 293 |
+
conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
|
| 294 |
+
for k, v in conflict_stats.items():
|
| 295 |
+
metrics[f"grad_{k}"].append(v)
|
| 296 |
+
mlflow.log_metric(f"grad_{k}", v, step=self.global_step)
|
| 297 |
+
|
| 298 |
# Optimizer step
|
| 299 |
if train and (step + 1) % accum == 0:
|
| 300 |
torch.nn.utils.clip_grad_norm_(
|
|
|
|
| 457 |
tqdm.write(f"{'=' * 50}\n")
|
| 458 |
self.model.train()
|
| 459 |
|
| 460 |
+
def _compute_gradient_conflicts(
|
| 461 |
+
self,
|
| 462 |
+
loaders: Dict[str, DataLoader],
|
| 463 |
+
iterators: Dict,
|
| 464 |
+
) -> Dict[str, float]:
|
| 465 |
+
"""Compute inter-task gradient cosine similarity to diagnose conflicts.
|
| 466 |
+
|
| 467 |
+
Returns cosine similarity between gradient vectors for each task pair.
|
| 468 |
+
Negative values indicate conflicting gradients (negative transfer risk).
|
| 469 |
+
"""
|
| 470 |
+
task_grads: Dict[str, torch.Tensor] = {}
|
| 471 |
+
|
| 472 |
+
for task, loader in loaders.items():
|
| 473 |
+
self.optimizer.zero_grad()
|
| 474 |
+
batch = self._get_batch(iterators, loader, task)
|
| 475 |
+
if batch is None:
|
| 476 |
+
continue
|
| 477 |
+
|
| 478 |
+
dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
|
| 479 |
+
with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
|
| 480 |
+
loss, _ = self._forward_task(task, batch)
|
| 481 |
+
|
| 482 |
+
if torch.isnan(loss):
|
| 483 |
+
continue
|
| 484 |
+
|
| 485 |
+
loss.backward()
|
| 486 |
+
|
| 487 |
+
# Flatten all gradients into a single vector
|
| 488 |
+
grad_vec = []
|
| 489 |
+
for p in self.model.parameters():
|
| 490 |
+
if p.grad is not None:
|
| 491 |
+
grad_vec.append(p.grad.detach().clone().flatten())
|
| 492 |
+
if grad_vec:
|
| 493 |
+
task_grads[task] = torch.cat(grad_vec)
|
| 494 |
+
|
| 495 |
+
self.optimizer.zero_grad()
|
| 496 |
+
|
| 497 |
+
# Compute pairwise cosine similarity
|
| 498 |
+
stats: Dict[str, float] = {}
|
| 499 |
+
tasks = list(task_grads.keys())
|
| 500 |
+
for i in range(len(tasks)):
|
| 501 |
+
for j in range(i + 1, len(tasks)):
|
| 502 |
+
t1, t2 = tasks[i], tasks[j]
|
| 503 |
+
g1, g2 = task_grads[t1], task_grads[t2]
|
| 504 |
+
cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
|
| 505 |
+
stats[f"cos_sim_{t1}_{t2}"] = cos_sim
|
| 506 |
+
stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
|
| 507 |
+
|
| 508 |
+
return stats
|
| 509 |
+
|
| 510 |
def _log_config(self) -> None:
|
| 511 |
"""Log config to MLflow."""
|
| 512 |
mlflow.log_params({
|