Spaces:
Running
Running
OliverPerrin commited on
Commit ·
df3ebbd
1
Parent(s): c472a19
updated readme, ruff formatted all files
Browse files- README.md +88 -114
- scripts/build_discovery_dataset.py +72 -67
- scripts/demo_gradio.py +63 -58
- scripts/download_data.py +542 -326
- scripts/evaluate.py +103 -83
- scripts/profile_training.py +44 -17
- scripts/train.py +122 -76
- scripts/train_multiseed.py +42 -23
- scripts/visualize_training.py +406 -193
- src/data/dataset.py +19 -13
- src/models/decoder.py +10 -3
- src/models/encoder.py +11 -3
- src/models/factory.py +12 -6
- src/models/heads.py +4 -2
- src/training/metrics.py +57 -59
- src/training/trainer.py +73 -50
- src/utils/__init__.py +12 -5
- src/utils/core.py +9 -7
- tests/test_training/test_trainer.py +10 -10
README.md
CHANGED
|
@@ -11,56 +11,76 @@ pinned: false
|
|
| 11 |
<!-- markdownlint-disable MD025 -->
|
| 12 |
# LexiMind
|
| 13 |
|
| 14 |
-
A multi-task NLP system for literary and academic text understanding. LexiMind performs **abstractive summarization**, **topic classification**, and **emotion detection** using a single encoder-decoder transformer initialized from [FLAN-T5-base](https://huggingface.co/google/flan-t5-base) (272M parameters).
|
| 15 |
|
| 16 |
**[Live Demo](https://huggingface.co/spaces/OliverPerrin/LexiMind)** · **[Model](https://huggingface.co/OliverPerrin/LexiMind-Model)** · **[Discovery Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)** · **[Research Paper](docs/research_paper.tex)**
|
| 17 |
|
| 18 |
-
##
|
| 19 |
|
| 20 |
-
| Task |
|
| 21 |
-
| ----
|
| 22 |
-
|
|
| 23 |
-
|
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
## Architecture
|
| 31 |
|
| 32 |
-
LexiMind is a **
|
| 33 |
|
| 34 |
| Component | Detail |
|
| 35 |
-
| ---------
|
| 36 |
| Backbone | Encoder-Decoder Transformer (272M params) |
|
| 37 |
-
| Encoder / Decoder | 12 layers each |
|
| 38 |
-
|
|
| 39 |
-
|
|
| 40 |
-
|
|
| 41 |
-
|
|
| 42 |
-
|
|
| 43 |
-
|
|
| 44 |
|
| 45 |
### Multi-Task Training
|
| 46 |
|
| 47 |
-
All three tasks share the encoder. Summarization uses the full encoder-decoder;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
## Training Data
|
| 50 |
|
| 51 |
-
| Task | Source |
|
| 52 |
-
| ----
|
| 53 |
-
| Summarization | Gutenberg + Goodreads (literary) | ~4K |
|
| 54 |
| Summarization | arXiv body → abstract (academic) | ~45K |
|
| 55 |
-
| Topic |
|
| 56 |
-
| Emotion | GoEmotions
|
|
|
|
|
|
|
| 57 |
|
| 58 |
## Getting Started
|
| 59 |
|
| 60 |
### Prerequisites
|
| 61 |
|
| 62 |
- Python 3.10+
|
| 63 |
-
- [Poetry](https://python-poetry.org/) for dependency management
|
| 64 |
- NVIDIA GPU with CUDA (for training; CPU works for inference)
|
| 65 |
|
| 66 |
### Installation
|
|
@@ -68,59 +88,50 @@ All three tasks share the encoder. Summarization uses the full encoder-decoder;
|
|
| 68 |
```bash
|
| 69 |
git clone https://github.com/OliverPerrin/LexiMind.git
|
| 70 |
cd LexiMind
|
| 71 |
-
|
| 72 |
```
|
| 73 |
|
| 74 |
-
### Download Data
|
| 75 |
-
|
| 76 |
-
```bash
|
| 77 |
-
poetry run python scripts/download_data.py
|
| 78 |
-
```
|
| 79 |
-
|
| 80 |
-
Downloads Goodreads descriptions, arXiv papers, GoEmotions, 20 Newsgroups, and Gutenberg texts.
|
| 81 |
-
|
| 82 |
### Training
|
| 83 |
|
| 84 |
```bash
|
| 85 |
-
# Full training (~
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# Quick dev run (~10-15 min)
|
| 89 |
-
poetry run python scripts/train.py training=dev
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
|
| 93 |
|
| 94 |
# Override parameters
|
| 95 |
-
|
| 96 |
|
| 97 |
# Resume from checkpoint
|
| 98 |
-
|
| 99 |
```
|
| 100 |
|
| 101 |
-
|
| 102 |
|
| 103 |
### Evaluation
|
| 104 |
|
| 105 |
```bash
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
# Skip BERTScore for faster runs
|
| 110 |
-
poetry run python scripts/evaluate.py --skip-bertscore
|
| 111 |
-
|
| 112 |
-
# Single task
|
| 113 |
-
poetry run python scripts/evaluate.py --summarization-only
|
| 114 |
```
|
| 115 |
|
| 116 |
### Inference
|
| 117 |
|
| 118 |
```bash
|
| 119 |
# Command-line
|
| 120 |
-
|
| 121 |
|
| 122 |
# Gradio web demo
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
```
|
| 125 |
|
| 126 |
### Docker
|
|
@@ -133,76 +144,39 @@ docker run -p 7860:7860 leximind
|
|
| 133 |
## Project Structure
|
| 134 |
|
| 135 |
```text
|
| 136 |
-
configs/
|
| 137 |
-
├── config.yaml # Main Hydra config
|
| 138 |
-
├── data/datasets.yaml # Dataset paths and tokenizer settings
|
| 139 |
-
├── model/ # Architecture configs (base, small, large)
|
| 140 |
-
└── training/ # Training configs (dev, medium, full)
|
| 141 |
-
|
| 142 |
src/
|
| 143 |
-
├── models/
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
│ ├── t5_layer_norm.py # T5-style RMSNorm
|
| 150 |
-
│ ├── heads.py # Task-specific classification heads
|
| 151 |
-
│ ├── multitask.py # Multi-task model combining all components
|
| 152 |
-
│ └── factory.py # Model builder with FLAN-T5 weight loading
|
| 153 |
-
├── data/
|
| 154 |
-
│ ├── dataset.py # Dataset classes for all tasks
|
| 155 |
-
│ ├── dataloader.py # Multi-task dataloader with round-robin sampling
|
| 156 |
-
│ └── tokenization.py # Tokenizer wrapper
|
| 157 |
-
├── training/
|
| 158 |
-
│ ├── trainer.py # Training loop with AMP, grad accumulation, early stopping
|
| 159 |
-
│ ├── metrics.py # ROUGE, BERTScore, F1, accuracy computation
|
| 160 |
-
│ └── utils.py # Checkpointing, logging utilities
|
| 161 |
-
├── inference/
|
| 162 |
-
│ ├── pipeline.py # End-to-end inference pipeline
|
| 163 |
-
│ └── factory.py # Model loading for inference
|
| 164 |
-
├── api/ # FastAPI REST endpoint
|
| 165 |
-
└── utils/ # Shared utilities
|
| 166 |
|
| 167 |
scripts/
|
| 168 |
-
├── train.py
|
| 169 |
-
├── evaluate.py
|
| 170 |
-
├── inference.py
|
| 171 |
-
├── demo_gradio.py
|
| 172 |
-
├──
|
| 173 |
-
├──
|
| 174 |
-
├──
|
| 175 |
-
├──
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
tests/
|
| 181 |
-
docs/ # Research paper and architecture notes
|
| 182 |
-
artifacts/ # Tokenizer files and label definitions
|
| 183 |
-
checkpoints/ # Saved model checkpoints
|
| 184 |
```
|
| 185 |
|
| 186 |
## Code Quality
|
| 187 |
|
| 188 |
```bash
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
```
|
| 194 |
|
| 195 |
-
## Key Results
|
| 196 |
-
|
| 197 |
-
From the research paper ([docs/research_paper.tex](docs/research_paper.tex)):
|
| 198 |
-
|
| 199 |
-
- **Multi-task learning helps topic classification** (+3.2% accuracy over single-task) because the small topic dataset (3.4K) benefits from shared encoder representations trained on the larger summarization corpus (49K).
|
| 200 |
-
- **Summarization is robust to MTL**—quality stays comparable whether trained alone or jointly.
|
| 201 |
-
- **Emotion detection shows slight negative transfer** (−0.02 F1), likely due to domain mismatch between Reddit-sourced emotion labels and literary/academic text.
|
| 202 |
-
- **FLAN-T5 pre-training is essential**—random initialization produces dramatically worse results on all tasks.
|
| 203 |
-
|
| 204 |
-
See the paper for full ablations, per-class breakdowns, and discussion of limitations.
|
| 205 |
-
|
| 206 |
## License
|
| 207 |
|
| 208 |
GPL-3.0 — see [LICENSE](LICENSE) for details.
|
|
|
|
| 11 |
<!-- markdownlint-disable MD025 -->
|
| 12 |
# LexiMind
|
| 13 |
|
| 14 |
+
A multi-task NLP system for literary and academic text understanding. LexiMind jointly performs **abstractive summarization**, **topic classification**, and **multi-label emotion detection** using a single encoder-decoder transformer initialized from [FLAN-T5-base](https://huggingface.co/google/flan-t5-base) (272M parameters).
|
| 15 |
|
| 16 |
**[Live Demo](https://huggingface.co/spaces/OliverPerrin/LexiMind)** · **[Model](https://huggingface.co/OliverPerrin/LexiMind-Model)** · **[Discovery Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)** · **[Research Paper](docs/research_paper.tex)**
|
| 17 |
|
| 18 |
+
## Results
|
| 19 |
|
| 20 |
+
| Task | Metric | Score |
|
| 21 |
+
| ---- | ------ | ----- |
|
| 22 |
+
| Summarization | ROUGE-1 / ROUGE-L | 0.309 / 0.185 |
|
| 23 |
+
| Summarization (academic) | ROUGE-1 | 0.319 |
|
| 24 |
+
| Summarization (literary) | ROUGE-1 | 0.206 |
|
| 25 |
+
| Topic Classification | Accuracy (95% CI) | 85.7% (80.4–91.0%) |
|
| 26 |
+
| Emotion Detection | Sample-avg F1 | 0.352 |
|
| 27 |
+
| Emotion Detection (tuned thresholds) | Sample-avg F1 / Macro F1 | 0.503 / 0.294 |
|
| 28 |
|
| 29 |
+
Trained for 8 epochs on an RTX 4070 12GB (~9 hours) with BFloat16 mixed precision, `torch.compile`, and cosine LR decay.
|
| 30 |
|
| 31 |
+
## Key Findings
|
| 32 |
+
|
| 33 |
+
From the [research paper](docs/research_paper.tex):
|
| 34 |
+
|
| 35 |
+
- **Naive MTL produces mixed results**: topic classification benefits (+3.7% accuracy), but emotion detection suffers negative transfer (−0.02 F1) under mean pooling with round-robin scheduling.
|
| 36 |
+
- **Learned attention pooling + temperature sampling eliminates negative transfer entirely**: emotion F1 improves from 0.199 → 0.352 (+77%), surpassing the single-task baseline (0.218).
|
| 37 |
+
- **Summarization is robust to MTL** — quality remains stable across configurations.
|
| 38 |
+
- **FLAN-T5 pre-training is essential** — random initialization produces dramatically worse results on all tasks.
|
| 39 |
+
- **Domain gap matters**: academic summaries (ROUGE-1: 0.319) substantially outperform literary (0.206), driven by an 11:1 training data imbalance.
|
| 40 |
|
| 41 |
## Architecture
|
| 42 |
|
| 43 |
+
LexiMind is a **from-scratch PyTorch Transformer** that loads pre-trained FLAN-T5-base weights layer by layer via a custom factory module — no HuggingFace model wrappers.
|
| 44 |
|
| 45 |
| Component | Detail |
|
| 46 |
+
| --------- | ------ |
|
| 47 |
| Backbone | Encoder-Decoder Transformer (272M params) |
|
| 48 |
+
| Encoder / Decoder | 12 layers each, 768d, 12 attention heads |
|
| 49 |
+
| Normalization | RMSNorm (Pre-LN, T5-style) |
|
| 50 |
+
| Attention | FlashAttention via PyTorch SDPA + T5 relative position bias |
|
| 51 |
+
| FFN | Gated-GELU (wi\_0, wi\_1, wo) |
|
| 52 |
+
| Summarization | Full decoder → language modeling head |
|
| 53 |
+
| Emotion (28-class multi-label) | Learned attention pooling → linear head |
|
| 54 |
+
| Topic (7-class) | Mean pooling → linear head |
|
| 55 |
|
| 56 |
### Multi-Task Training
|
| 57 |
|
| 58 |
+
All three tasks share the encoder. Summarization uses the full encoder-decoder; classification heads branch off the encoder output. Key training details:
|
| 59 |
+
|
| 60 |
+
- **Temperature-based task sampling** (α=0.5): allocates training steps proportional to dataset size, preventing large tasks from dominating
|
| 61 |
+
- **Attention pooling** for emotion: a learned query attends over encoder outputs, focusing on emotionally salient tokens rather than averaging the full sequence
|
| 62 |
+
- **Fixed loss weights**: summarization=1.0, emotion=1.0, topic=0.3 (reduced to prevent overfitting on the small topic dataset)
|
| 63 |
+
- **Frozen encoder layers 0–3**: preserves FLAN-T5's language understanding in lower layers
|
| 64 |
+
- **Gradient conflict diagnostics**: optional inter-task gradient cosine similarity monitoring
|
| 65 |
+
|
| 66 |
+
See [docs/architecture.md](docs/architecture.md) for full implementation details, weight loading tables, and training configuration rationale.
|
| 67 |
|
| 68 |
## Training Data
|
| 69 |
|
| 70 |
+
| Task | Source | Samples |
|
| 71 |
+
| ---- | ------ | ------- |
|
| 72 |
+
| Summarization | Gutenberg + Goodreads descriptions (literary) | ~4K |
|
| 73 |
| Summarization | arXiv body → abstract (academic) | ~45K |
|
| 74 |
+
| Topic | Gutenberg + arXiv metadata → 7 categories | 3,402 |
|
| 75 |
+
| Emotion | GoEmotions — Reddit comments, 28 labels | 43,410 |
|
| 76 |
+
|
| 77 |
+
For summarization, the model learns to produce descriptive summaries — what a book *is about* — rather than plot recaps, by pairing Gutenberg full texts with Goodreads descriptions and arXiv papers with their abstracts.
|
| 78 |
|
| 79 |
## Getting Started
|
| 80 |
|
| 81 |
### Prerequisites
|
| 82 |
|
| 83 |
- Python 3.10+
|
|
|
|
| 84 |
- NVIDIA GPU with CUDA (for training; CPU works for inference)
|
| 85 |
|
| 86 |
### Installation
|
|
|
|
| 88 |
```bash
|
| 89 |
git clone https://github.com/OliverPerrin/LexiMind.git
|
| 90 |
cd LexiMind
|
| 91 |
+
pip install -r requirements.txt
|
| 92 |
```
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
### Training
|
| 95 |
|
| 96 |
```bash
|
| 97 |
+
# Full training (~9 hours on RTX 4070 12GB)
|
| 98 |
+
python scripts/train.py training=full
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
# Quick dev run
|
| 101 |
+
python scripts/train.py training=dev
|
| 102 |
|
| 103 |
# Override parameters
|
| 104 |
+
python scripts/train.py training=full training.optimizer.lr=5e-5
|
| 105 |
|
| 106 |
# Resume from checkpoint
|
| 107 |
+
python scripts/train.py training=full resume_from=checkpoints/epoch_5.pt
|
| 108 |
```
|
| 109 |
|
| 110 |
+
Experiments are tracked with MLflow (`mlflow ui` to browse).
|
| 111 |
|
| 112 |
### Evaluation
|
| 113 |
|
| 114 |
```bash
|
| 115 |
+
python scripts/evaluate.py
|
| 116 |
+
python scripts/evaluate.py --skip-bertscore # faster
|
| 117 |
+
python scripts/evaluate.py --tune-thresholds # per-class threshold tuning
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
```
|
| 119 |
|
| 120 |
### Inference
|
| 121 |
|
| 122 |
```bash
|
| 123 |
# Command-line
|
| 124 |
+
python scripts/inference.py "Your text to analyze"
|
| 125 |
|
| 126 |
# Gradio web demo
|
| 127 |
+
python scripts/demo_gradio.py
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Profiling
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
# Profile GPU usage (CUDA kernels, memory, Chrome trace)
|
| 134 |
+
python scripts/profile_training.py
|
| 135 |
```
|
| 136 |
|
| 137 |
### Docker
|
|
|
|
| 144 |
## Project Structure
|
| 145 |
|
| 146 |
```text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
src/
|
| 148 |
+
├── models/ # Encoder, decoder, attention, FFN, heads, factory
|
| 149 |
+
├── data/ # Datasets, dataloaders, tokenization, cross-task dedup
|
| 150 |
+
├── training/ # Trainer (AMP, grad accum, temperature sampling), metrics
|
| 151 |
+
├── inference/ # Pipeline + factory for checkpoint loading
|
| 152 |
+
├── api/ # FastAPI REST endpoint
|
| 153 |
+
└── utils/ # Device detection, checkpointing, label I/O
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
scripts/
|
| 156 |
+
├── train.py # Hydra training entry point
|
| 157 |
+
├── evaluate.py # Full evaluation suite
|
| 158 |
+
├── inference.py # CLI inference
|
| 159 |
+
├── demo_gradio.py # Gradio discovery demo
|
| 160 |
+
├── profile_training.py # PyTorch profiler
|
| 161 |
+
├── train_multiseed.py # Multi-seed training with aggregation
|
| 162 |
+
├── visualize_training.py # Training curve visualization
|
| 163 |
+
├── download_data.py # Dataset downloader
|
| 164 |
+
└── build_discovery_dataset.py # Pre-compute discovery dataset
|
| 165 |
+
|
| 166 |
+
configs/ # Hydra configs (model, training, data)
|
| 167 |
+
docs/ # Research paper + architecture documentation
|
| 168 |
+
tests/ # Pytest suite
|
|
|
|
|
|
|
|
|
|
| 169 |
```
|
| 170 |
|
| 171 |
## Code Quality
|
| 172 |
|
| 173 |
```bash
|
| 174 |
+
ruff check . # Linting
|
| 175 |
+
mypy src/ scripts/ tests/ # Type checking
|
| 176 |
+
pytest # Tests
|
| 177 |
+
pre-commit run --all-files # All checks
|
| 178 |
```
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
## License
|
| 181 |
|
| 182 |
GPL-3.0 — see [LICENSE](LICENSE) for details.
|
scripts/build_discovery_dataset.py
CHANGED
|
@@ -29,134 +29,140 @@ from src.inference.factory import create_inference_pipeline # noqa: E402
|
|
| 29 |
|
| 30 |
# --------------- Data Loading ---------------
|
| 31 |
|
|
|
|
| 32 |
def load_academic_papers(data_dir: Path, max_samples: int = 300) -> list[dict]:
|
| 33 |
"""Load academic paper samples from the training data."""
|
| 34 |
summ_file = data_dir / "summarization" / "train.jsonl"
|
| 35 |
-
|
| 36 |
if not summ_file.exists():
|
| 37 |
print(f" Warning: {summ_file} not found")
|
| 38 |
return []
|
| 39 |
-
|
| 40 |
academic = []
|
| 41 |
with open(summ_file) as f:
|
| 42 |
for line in f:
|
| 43 |
item = json.loads(line)
|
| 44 |
if item.get("type") != "academic":
|
| 45 |
continue
|
| 46 |
-
|
| 47 |
text = item.get("source", "")
|
| 48 |
if len(text) < 500:
|
| 49 |
continue
|
| 50 |
-
|
| 51 |
# Use title from data
|
| 52 |
title = item.get("title", "Research Paper")
|
| 53 |
-
|
| 54 |
-
academic.append(
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
random.seed(42)
|
| 61 |
samples = random.sample(academic, min(max_samples, len(academic)))
|
| 62 |
-
|
| 63 |
results = []
|
| 64 |
for i, item in enumerate(samples):
|
| 65 |
-
results.append(
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
print(f" Loaded {len(results)} academic papers")
|
| 75 |
return results
|
| 76 |
|
| 77 |
|
| 78 |
def load_literary(data_dir: Path, max_samples: int = 300) -> list[dict]:
|
| 79 |
"""Load literary samples from the training data.
|
| 80 |
-
|
| 81 |
Training data now contains Goodreads descriptions (back-cover style)
|
| 82 |
instead of plot summaries.
|
| 83 |
"""
|
| 84 |
summ_file = data_dir / "summarization" / "train.jsonl"
|
| 85 |
-
|
| 86 |
if not summ_file.exists():
|
| 87 |
print(f" Warning: {summ_file} not found")
|
| 88 |
return []
|
| 89 |
-
|
| 90 |
literary = []
|
| 91 |
seen_titles = set()
|
| 92 |
-
|
| 93 |
with open(summ_file) as f:
|
| 94 |
for line in f:
|
| 95 |
item = json.loads(line)
|
| 96 |
if item.get("type") != "literary":
|
| 97 |
continue
|
| 98 |
-
|
| 99 |
title = item.get("title", "")
|
| 100 |
if not title or title in seen_titles:
|
| 101 |
continue
|
| 102 |
-
|
| 103 |
text = item.get("source", "")
|
| 104 |
summary = item.get("summary", "")
|
| 105 |
-
|
| 106 |
if len(text) < 500 or len(summary) < 50:
|
| 107 |
continue
|
| 108 |
-
|
| 109 |
seen_titles.add(title)
|
| 110 |
-
literary.append(
|
| 111 |
-
"text": text[:2000],
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
})
|
| 115 |
-
|
| 116 |
random.seed(42)
|
| 117 |
samples = random.sample(literary, min(max_samples, len(literary)))
|
| 118 |
-
|
| 119 |
results = []
|
| 120 |
for i, item in enumerate(samples):
|
| 121 |
-
results.append(
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
print(f" Loaded {len(results)} literary works (unique titles)")
|
| 131 |
return results
|
| 132 |
|
| 133 |
|
| 134 |
# --------------- Inference ---------------
|
| 135 |
|
|
|
|
| 136 |
def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
|
| 137 |
"""Run model inference on all samples."""
|
| 138 |
results = []
|
| 139 |
-
|
| 140 |
for sample in tqdm(samples, desc="Running inference"):
|
| 141 |
text = sample["text"]
|
| 142 |
-
|
| 143 |
# Get model predictions using correct pipeline methods
|
| 144 |
summaries = pipeline.summarize([text])
|
| 145 |
topics = pipeline.predict_topics([text])
|
| 146 |
emotions = pipeline.predict_emotions([text])
|
| 147 |
-
|
| 148 |
# Extract first result from each list
|
| 149 |
summary = summaries[0] if summaries else ""
|
| 150 |
topic = topics[0] if topics else None
|
| 151 |
emotion = emotions[0] if emotions else None
|
| 152 |
-
|
| 153 |
# Get primary emotion (highest confidence if any detected)
|
| 154 |
primary_emotion = "neutral"
|
| 155 |
emotion_confidence = 0.0
|
| 156 |
if emotion and emotion.labels:
|
| 157 |
primary_emotion = emotion.labels[0]
|
| 158 |
emotion_confidence = emotion.scores[0]
|
| 159 |
-
|
| 160 |
result = {
|
| 161 |
"id": sample["id"],
|
| 162 |
"title": sample["title"],
|
|
@@ -170,24 +176,25 @@ def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
|
|
| 170 |
"generated_summary": summary,
|
| 171 |
"reference_summary": sample.get("reference_summary", ""),
|
| 172 |
}
|
| 173 |
-
|
| 174 |
results.append(result)
|
| 175 |
-
|
| 176 |
# Print distribution stats
|
| 177 |
topic_dist: dict[str, int] = defaultdict(int)
|
| 178 |
emotion_dist: dict[str, int] = defaultdict(int)
|
| 179 |
for r in results:
|
| 180 |
topic_dist[r["topic"]] += 1
|
| 181 |
emotion_dist[r["emotion"]] += 1
|
| 182 |
-
|
| 183 |
print(f"\nTopic distribution: {dict(topic_dist)}")
|
| 184 |
print(f"Emotion distribution: {dict(emotion_dist)}")
|
| 185 |
-
|
| 186 |
return results
|
| 187 |
|
| 188 |
|
| 189 |
def main():
|
| 190 |
import argparse
|
|
|
|
| 191 |
parser = argparse.ArgumentParser(description="Build discovery dataset for HuggingFace Space")
|
| 192 |
parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
|
| 193 |
parser.add_argument("--checkpoint", type=Path, default=Path("checkpoints/best.pt"))
|
|
@@ -197,41 +204,39 @@ def main():
|
|
| 197 |
parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
|
| 198 |
parser.add_argument("--hub-repo", type=str, default="OliverPerrin/LexiMind-Discovery")
|
| 199 |
args = parser.parse_args()
|
| 200 |
-
|
| 201 |
print("Loading data samples from training data...")
|
| 202 |
print("(Data has already been filtered by download_data.py)")
|
| 203 |
-
|
| 204 |
# Load samples from training data
|
| 205 |
papers = load_academic_papers(args.data_dir, args.num_papers)
|
| 206 |
literary = load_literary(args.data_dir, args.num_literary)
|
| 207 |
-
|
| 208 |
all_samples = papers + literary
|
| 209 |
print(f"\nTotal samples: {len(all_samples)} ({len(papers)} papers, {len(literary)} literary)")
|
| 210 |
-
|
| 211 |
if not all_samples:
|
| 212 |
print("ERROR: No samples loaded! Check if data/processed exists and has data.")
|
| 213 |
print("Run: python scripts/download_data.py --task summarization")
|
| 214 |
return
|
| 215 |
-
|
| 216 |
# Load model and run inference
|
| 217 |
print(f"\nLoading model from {args.checkpoint}...")
|
| 218 |
labels_path = Path("artifacts/labels.json")
|
| 219 |
pipeline, labels = create_inference_pipeline(
|
| 220 |
-
args.checkpoint,
|
| 221 |
-
labels_path,
|
| 222 |
-
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 223 |
)
|
| 224 |
-
|
| 225 |
print("Running inference on all samples...")
|
| 226 |
results = run_inference(pipeline, all_samples)
|
| 227 |
-
|
| 228 |
# Save locally
|
| 229 |
print(f"\nSaving to {args.output}...")
|
| 230 |
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 231 |
with open(args.output, "w") as f:
|
| 232 |
for item in results:
|
| 233 |
f.write(json.dumps(item) + "\n")
|
| 234 |
-
|
| 235 |
# Push to HuggingFace Hub
|
| 236 |
if args.push_to_hub:
|
| 237 |
print(f"\nPushing to HuggingFace Hub: {args.hub_repo}")
|
|
@@ -239,10 +244,10 @@ def main():
|
|
| 239 |
dataset.push_to_hub(
|
| 240 |
args.hub_repo,
|
| 241 |
private=False,
|
| 242 |
-
commit_message="Rebuild with Goodreads descriptions (back-cover style)"
|
| 243 |
)
|
| 244 |
print(f"Dataset available at: https://huggingface.co/datasets/{args.hub_repo}")
|
| 245 |
-
|
| 246 |
print("\nDone!")
|
| 247 |
|
| 248 |
|
|
|
|
| 29 |
|
| 30 |
# --------------- Data Loading ---------------
|
| 31 |
|
| 32 |
+
|
| 33 |
def load_academic_papers(data_dir: Path, max_samples: int = 300) -> list[dict]:
|
| 34 |
"""Load academic paper samples from the training data."""
|
| 35 |
summ_file = data_dir / "summarization" / "train.jsonl"
|
| 36 |
+
|
| 37 |
if not summ_file.exists():
|
| 38 |
print(f" Warning: {summ_file} not found")
|
| 39 |
return []
|
| 40 |
+
|
| 41 |
academic = []
|
| 42 |
with open(summ_file) as f:
|
| 43 |
for line in f:
|
| 44 |
item = json.loads(line)
|
| 45 |
if item.get("type") != "academic":
|
| 46 |
continue
|
| 47 |
+
|
| 48 |
text = item.get("source", "")
|
| 49 |
if len(text) < 500:
|
| 50 |
continue
|
| 51 |
+
|
| 52 |
# Use title from data
|
| 53 |
title = item.get("title", "Research Paper")
|
| 54 |
+
|
| 55 |
+
academic.append(
|
| 56 |
+
{
|
| 57 |
+
"text": text[:2000],
|
| 58 |
+
"title": title,
|
| 59 |
+
"reference_summary": item.get("summary", "")[:500],
|
| 60 |
+
}
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
random.seed(42)
|
| 64 |
samples = random.sample(academic, min(max_samples, len(academic)))
|
| 65 |
+
|
| 66 |
results = []
|
| 67 |
for i, item in enumerate(samples):
|
| 68 |
+
results.append(
|
| 69 |
+
{
|
| 70 |
+
"id": f"paper_{i}",
|
| 71 |
+
"title": item["title"],
|
| 72 |
+
"text": item["text"],
|
| 73 |
+
"source_type": "academic",
|
| 74 |
+
"dataset": "arxiv",
|
| 75 |
+
"reference_summary": item["reference_summary"],
|
| 76 |
+
}
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
print(f" Loaded {len(results)} academic papers")
|
| 80 |
return results
|
| 81 |
|
| 82 |
|
| 83 |
def load_literary(data_dir: Path, max_samples: int = 300) -> list[dict]:
|
| 84 |
"""Load literary samples from the training data.
|
| 85 |
+
|
| 86 |
Training data now contains Goodreads descriptions (back-cover style)
|
| 87 |
instead of plot summaries.
|
| 88 |
"""
|
| 89 |
summ_file = data_dir / "summarization" / "train.jsonl"
|
| 90 |
+
|
| 91 |
if not summ_file.exists():
|
| 92 |
print(f" Warning: {summ_file} not found")
|
| 93 |
return []
|
| 94 |
+
|
| 95 |
literary = []
|
| 96 |
seen_titles = set()
|
| 97 |
+
|
| 98 |
with open(summ_file) as f:
|
| 99 |
for line in f:
|
| 100 |
item = json.loads(line)
|
| 101 |
if item.get("type") != "literary":
|
| 102 |
continue
|
| 103 |
+
|
| 104 |
title = item.get("title", "")
|
| 105 |
if not title or title in seen_titles:
|
| 106 |
continue
|
| 107 |
+
|
| 108 |
text = item.get("source", "")
|
| 109 |
summary = item.get("summary", "")
|
| 110 |
+
|
| 111 |
if len(text) < 500 or len(summary) < 50:
|
| 112 |
continue
|
| 113 |
+
|
| 114 |
seen_titles.add(title)
|
| 115 |
+
literary.append(
|
| 116 |
+
{"text": text[:2000], "title": title, "reference_summary": summary[:600]}
|
| 117 |
+
)
|
| 118 |
+
|
|
|
|
|
|
|
| 119 |
random.seed(42)
|
| 120 |
samples = random.sample(literary, min(max_samples, len(literary)))
|
| 121 |
+
|
| 122 |
results = []
|
| 123 |
for i, item in enumerate(samples):
|
| 124 |
+
results.append(
|
| 125 |
+
{
|
| 126 |
+
"id": f"literary_{i}",
|
| 127 |
+
"title": item["title"],
|
| 128 |
+
"text": item["text"],
|
| 129 |
+
"source_type": "literary",
|
| 130 |
+
"dataset": "goodreads",
|
| 131 |
+
"reference_summary": item["reference_summary"],
|
| 132 |
+
}
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
print(f" Loaded {len(results)} literary works (unique titles)")
|
| 136 |
return results
|
| 137 |
|
| 138 |
|
| 139 |
# --------------- Inference ---------------
|
| 140 |
|
| 141 |
+
|
| 142 |
def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
|
| 143 |
"""Run model inference on all samples."""
|
| 144 |
results = []
|
| 145 |
+
|
| 146 |
for sample in tqdm(samples, desc="Running inference"):
|
| 147 |
text = sample["text"]
|
| 148 |
+
|
| 149 |
# Get model predictions using correct pipeline methods
|
| 150 |
summaries = pipeline.summarize([text])
|
| 151 |
topics = pipeline.predict_topics([text])
|
| 152 |
emotions = pipeline.predict_emotions([text])
|
| 153 |
+
|
| 154 |
# Extract first result from each list
|
| 155 |
summary = summaries[0] if summaries else ""
|
| 156 |
topic = topics[0] if topics else None
|
| 157 |
emotion = emotions[0] if emotions else None
|
| 158 |
+
|
| 159 |
# Get primary emotion (highest confidence if any detected)
|
| 160 |
primary_emotion = "neutral"
|
| 161 |
emotion_confidence = 0.0
|
| 162 |
if emotion and emotion.labels:
|
| 163 |
primary_emotion = emotion.labels[0]
|
| 164 |
emotion_confidence = emotion.scores[0]
|
| 165 |
+
|
| 166 |
result = {
|
| 167 |
"id": sample["id"],
|
| 168 |
"title": sample["title"],
|
|
|
|
| 176 |
"generated_summary": summary,
|
| 177 |
"reference_summary": sample.get("reference_summary", ""),
|
| 178 |
}
|
| 179 |
+
|
| 180 |
results.append(result)
|
| 181 |
+
|
| 182 |
# Print distribution stats
|
| 183 |
topic_dist: dict[str, int] = defaultdict(int)
|
| 184 |
emotion_dist: dict[str, int] = defaultdict(int)
|
| 185 |
for r in results:
|
| 186 |
topic_dist[r["topic"]] += 1
|
| 187 |
emotion_dist[r["emotion"]] += 1
|
| 188 |
+
|
| 189 |
print(f"\nTopic distribution: {dict(topic_dist)}")
|
| 190 |
print(f"Emotion distribution: {dict(emotion_dist)}")
|
| 191 |
+
|
| 192 |
return results
|
| 193 |
|
| 194 |
|
| 195 |
def main():
|
| 196 |
import argparse
|
| 197 |
+
|
| 198 |
parser = argparse.ArgumentParser(description="Build discovery dataset for HuggingFace Space")
|
| 199 |
parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
|
| 200 |
parser.add_argument("--checkpoint", type=Path, default=Path("checkpoints/best.pt"))
|
|
|
|
| 204 |
parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
|
| 205 |
parser.add_argument("--hub-repo", type=str, default="OliverPerrin/LexiMind-Discovery")
|
| 206 |
args = parser.parse_args()
|
| 207 |
+
|
| 208 |
print("Loading data samples from training data...")
|
| 209 |
print("(Data has already been filtered by download_data.py)")
|
| 210 |
+
|
| 211 |
# Load samples from training data
|
| 212 |
papers = load_academic_papers(args.data_dir, args.num_papers)
|
| 213 |
literary = load_literary(args.data_dir, args.num_literary)
|
| 214 |
+
|
| 215 |
all_samples = papers + literary
|
| 216 |
print(f"\nTotal samples: {len(all_samples)} ({len(papers)} papers, {len(literary)} literary)")
|
| 217 |
+
|
| 218 |
if not all_samples:
|
| 219 |
print("ERROR: No samples loaded! Check if data/processed exists and has data.")
|
| 220 |
print("Run: python scripts/download_data.py --task summarization")
|
| 221 |
return
|
| 222 |
+
|
| 223 |
# Load model and run inference
|
| 224 |
print(f"\nLoading model from {args.checkpoint}...")
|
| 225 |
labels_path = Path("artifacts/labels.json")
|
| 226 |
pipeline, labels = create_inference_pipeline(
|
| 227 |
+
args.checkpoint, labels_path, device="cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
| 228 |
)
|
| 229 |
+
|
| 230 |
print("Running inference on all samples...")
|
| 231 |
results = run_inference(pipeline, all_samples)
|
| 232 |
+
|
| 233 |
# Save locally
|
| 234 |
print(f"\nSaving to {args.output}...")
|
| 235 |
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 236 |
with open(args.output, "w") as f:
|
| 237 |
for item in results:
|
| 238 |
f.write(json.dumps(item) + "\n")
|
| 239 |
+
|
| 240 |
# Push to HuggingFace Hub
|
| 241 |
if args.push_to_hub:
|
| 242 |
print(f"\nPushing to HuggingFace Hub: {args.hub_repo}")
|
|
|
|
| 244 |
dataset.push_to_hub(
|
| 245 |
args.hub_repo,
|
| 246 |
private=False,
|
| 247 |
+
commit_message="Rebuild with Goodreads descriptions (back-cover style)",
|
| 248 |
)
|
| 249 |
print(f"Dataset available at: https://huggingface.co/datasets/{args.hub_repo}")
|
| 250 |
+
|
| 251 |
print("\nDone!")
|
| 252 |
|
| 253 |
|
scripts/demo_gradio.py
CHANGED
|
@@ -27,8 +27,12 @@ print(f"Loaded {len(_dataset)} items")
|
|
| 27 |
ALL_ITEMS: list[dict[str, Any]] = [dict(row) for row in _dataset]
|
| 28 |
|
| 29 |
# Extract unique topics and emotions FROM THE DATASET (what model predicted)
|
| 30 |
-
DATASET_TOPICS: list[str] = sorted(
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Load ALL possible labels from labels.json (what the model CAN predict)
|
| 34 |
_labels_path = Path(__file__).parent.parent / "artifacts" / "labels.json"
|
|
@@ -90,19 +94,19 @@ def format_item_card(item: dict) -> str:
|
|
| 90 |
title = item.get("title", "Unknown")
|
| 91 |
source_type = item.get("source_type", "unknown")
|
| 92 |
dataset_name = item.get("dataset", "").title()
|
| 93 |
-
|
| 94 |
# Icon based on type
|
| 95 |
if source_type == "academic":
|
| 96 |
type_label = "Research Paper"
|
| 97 |
else:
|
| 98 |
type_label = "Literature"
|
| 99 |
-
|
| 100 |
# Topic and emotion with confidence
|
| 101 |
topic = item.get("topic", "Unknown")
|
| 102 |
topic_conf = item.get("topic_confidence", 0)
|
| 103 |
emotion = item.get("emotion", "Unknown")
|
| 104 |
emotion_conf = item.get("emotion_confidence", 0)
|
| 105 |
-
|
| 106 |
# Summary - check if using reference or generated
|
| 107 |
use_reference = item.get("use_reference_summary", False)
|
| 108 |
if use_reference or source_type == "literary":
|
|
@@ -111,17 +115,21 @@ def format_item_card(item: dict) -> str:
|
|
| 111 |
else:
|
| 112 |
summary = item.get("generated_summary", "")
|
| 113 |
summary_label = "**AI-Generated Description:**"
|
| 114 |
-
|
| 115 |
if not summary:
|
| 116 |
summary = "No summary available."
|
| 117 |
-
|
| 118 |
# Truncate summary if too long
|
| 119 |
if len(summary) > 400:
|
| 120 |
-
summary = summary[:400].rsplit(
|
| 121 |
-
|
| 122 |
# Preview of original text
|
| 123 |
-
text_preview =
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
return f"""### **{title}**
|
| 126 |
|
| 127 |
<small>*{type_label}* from {dataset_name}</small>
|
|
@@ -147,24 +155,24 @@ def browse_by_topic(topic: str) -> str:
|
|
| 147 |
items = get_items_by_topic(topic)
|
| 148 |
if not items:
|
| 149 |
return "No items found for this topic."
|
| 150 |
-
|
| 151 |
# Group by type
|
| 152 |
literary = [i for i in items if i.get("source_type") == "literary"]
|
| 153 |
academic = [i for i in items if i.get("source_type") == "academic"]
|
| 154 |
-
|
| 155 |
result = f"## {topic if topic != 'All' else 'All Topics'}\n\n"
|
| 156 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 157 |
-
|
| 158 |
if literary:
|
| 159 |
result += "### Literary Works\n\n"
|
| 160 |
for item in literary[:25]: # Limit to avoid huge pages
|
| 161 |
result += format_item_card(item)
|
| 162 |
-
|
| 163 |
if academic:
|
| 164 |
result += "### Academic Papers\n\n"
|
| 165 |
for item in academic[:25]:
|
| 166 |
result += format_item_card(item)
|
| 167 |
-
|
| 168 |
return result
|
| 169 |
|
| 170 |
|
|
@@ -173,23 +181,23 @@ def browse_by_emotion(emotion: str) -> str:
|
|
| 173 |
items = get_items_by_emotion(emotion)
|
| 174 |
if not items:
|
| 175 |
return "No items found for this emotion."
|
| 176 |
-
|
| 177 |
literary = [i for i in items if i.get("source_type") == "literary"]
|
| 178 |
academic = [i for i in items if i.get("source_type") == "academic"]
|
| 179 |
-
|
| 180 |
result = f"## Feeling {emotion.title() if emotion != 'All' else 'All Emotions'}?\n\n"
|
| 181 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 182 |
-
|
| 183 |
if literary:
|
| 184 |
result += "### Literary Works\n\n"
|
| 185 |
for item in literary[:25]:
|
| 186 |
result += format_item_card(item)
|
| 187 |
-
|
| 188 |
if academic:
|
| 189 |
result += "### Academic Papers\n\n"
|
| 190 |
for item in academic[:25]:
|
| 191 |
result += format_item_card(item)
|
| 192 |
-
|
| 193 |
return result
|
| 194 |
|
| 195 |
|
|
@@ -197,24 +205,25 @@ def search_items(query: str) -> str:
|
|
| 197 |
"""Search items by text content."""
|
| 198 |
if not query or len(query) < 3:
|
| 199 |
return "Enter at least 3 characters to search."
|
| 200 |
-
|
| 201 |
query_lower = query.lower()
|
| 202 |
matches = [
|
| 203 |
-
item
|
|
|
|
| 204 |
if query_lower in item.get("text", "").lower()
|
| 205 |
or query_lower in item.get("generated_summary", "").lower()
|
| 206 |
or query_lower in item.get("title", "").lower()
|
| 207 |
]
|
| 208 |
-
|
| 209 |
if not matches:
|
| 210 |
return f"No results found for '{query}'."
|
| 211 |
-
|
| 212 |
result = f"## Search Results for '{query}'\n\n"
|
| 213 |
result += f"*Found {len(matches)} matching items*\n\n"
|
| 214 |
-
|
| 215 |
for item in matches[:30]:
|
| 216 |
result += format_item_card(item)
|
| 217 |
-
|
| 218 |
return result
|
| 219 |
|
| 220 |
|
|
@@ -226,9 +235,8 @@ with gr.Blocks(
|
|
| 226 |
css="""
|
| 227 |
.result-box { max-height: 700px; overflow-y: auto; }
|
| 228 |
h3 { margin-top: 0.5em !important; }
|
| 229 |
-
"""
|
| 230 |
) as demo:
|
| 231 |
-
|
| 232 |
gr.Markdown(
|
| 233 |
"""
|
| 234 |
# LexiMind
|
|
@@ -237,79 +245,75 @@ with gr.Blocks(
|
|
| 237 |
Browse **{total_count}** texts — {lit_count} classic books and {paper_count} research papers — analyzed by a multi-task transformer.
|
| 238 |
|
| 239 |
---
|
| 240 |
-
""".format(
|
| 241 |
-
total_count=len(ALL_ITEMS),
|
| 242 |
-
lit_count=len(BOOKS),
|
| 243 |
-
paper_count=len(PAPERS)
|
| 244 |
-
)
|
| 245 |
)
|
| 246 |
-
|
| 247 |
with gr.Tabs():
|
| 248 |
# ===================== TAB 1: BROWSE BY TOPIC =====================
|
| 249 |
with gr.Tab("By Topic"):
|
| 250 |
gr.Markdown("*Select a topic to explore related books and papers*")
|
| 251 |
-
|
| 252 |
topic_dropdown = gr.Dropdown(
|
| 253 |
choices=["All"] + TOPICS,
|
| 254 |
value="All",
|
| 255 |
label="Select Topic",
|
| 256 |
interactive=True,
|
| 257 |
)
|
| 258 |
-
|
| 259 |
topic_results = gr.Markdown(
|
| 260 |
value=browse_by_topic("All"),
|
| 261 |
elem_classes=["result-box"],
|
| 262 |
)
|
| 263 |
-
|
| 264 |
topic_dropdown.change(
|
| 265 |
fn=browse_by_topic,
|
| 266 |
inputs=[topic_dropdown],
|
| 267 |
outputs=[topic_results],
|
| 268 |
)
|
| 269 |
-
|
| 270 |
# ===================== TAB 2: BROWSE BY EMOTION =====================
|
| 271 |
with gr.Tab("By Emotion"):
|
| 272 |
gr.Markdown("*Find books and papers that evoke specific emotions*")
|
| 273 |
-
|
| 274 |
emotion_dropdown = gr.Dropdown(
|
| 275 |
choices=["All"] + [e.title() for e in EMOTIONS],
|
| 276 |
value="All",
|
| 277 |
label="Select Emotion",
|
| 278 |
interactive=True,
|
| 279 |
)
|
| 280 |
-
|
| 281 |
emotion_results = gr.Markdown(
|
| 282 |
value=browse_by_emotion("All"),
|
| 283 |
elem_classes=["result-box"],
|
| 284 |
)
|
| 285 |
-
|
| 286 |
emotion_dropdown.change(
|
| 287 |
fn=lambda e: browse_by_emotion(e.lower() if e != "All" else "All"),
|
| 288 |
inputs=[emotion_dropdown],
|
| 289 |
outputs=[emotion_results],
|
| 290 |
)
|
| 291 |
-
|
| 292 |
# ===================== TAB 3: SEARCH =====================
|
| 293 |
with gr.Tab("Search"):
|
| 294 |
gr.Markdown("*Search through all books and papers by keyword*")
|
| 295 |
-
|
| 296 |
search_input = gr.Textbox(
|
| 297 |
placeholder="Enter keywords to search...",
|
| 298 |
label="Search",
|
| 299 |
interactive=True,
|
| 300 |
)
|
| 301 |
-
|
| 302 |
search_results = gr.Markdown(
|
| 303 |
value="Enter at least 3 characters to search.",
|
| 304 |
elem_classes=["result-box"],
|
| 305 |
)
|
| 306 |
-
|
| 307 |
search_input.change(
|
| 308 |
fn=search_items,
|
| 309 |
inputs=[search_input],
|
| 310 |
outputs=[search_results],
|
| 311 |
)
|
| 312 |
-
|
| 313 |
# ===================== TAB 4: METRICS =====================
|
| 314 |
with gr.Tab("Metrics"):
|
| 315 |
gr.Markdown(
|
|
@@ -319,10 +323,10 @@ with gr.Blocks(
|
|
| 319 |
Computed on held-out validation data.
|
| 320 |
"""
|
| 321 |
)
|
| 322 |
-
|
| 323 |
# Summarization Metrics
|
| 324 |
gr.Markdown("#### Summarization")
|
| 325 |
-
|
| 326 |
if METRICS.get("summarization"):
|
| 327 |
summ = METRICS["summarization"]
|
| 328 |
summ_md = """
|
|
@@ -341,10 +345,10 @@ with gr.Blocks(
|
|
| 341 |
gr.Markdown(summ_md)
|
| 342 |
else:
|
| 343 |
gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
|
| 344 |
-
|
| 345 |
# Topic Classification Metrics
|
| 346 |
gr.Markdown("#### Topic Classification")
|
| 347 |
-
|
| 348 |
if METRICS.get("topic"):
|
| 349 |
topic = METRICS["topic"]
|
| 350 |
topic_md = """
|
|
@@ -359,10 +363,10 @@ with gr.Blocks(
|
|
| 359 |
gr.Markdown(topic_md)
|
| 360 |
else:
|
| 361 |
gr.Markdown("*Topic classification metrics not available.*")
|
| 362 |
-
|
| 363 |
# Emotion Detection Metrics
|
| 364 |
gr.Markdown("#### Emotion Detection")
|
| 365 |
-
|
| 366 |
if METRICS.get("emotion"):
|
| 367 |
emotion = METRICS["emotion"]
|
| 368 |
emotion_md = """
|
|
@@ -374,17 +378,19 @@ with gr.Blocks(
|
|
| 374 |
|
| 375 |
*28-label multi-label classification from GoEmotions.*
|
| 376 |
""".format(
|
| 377 |
-
sample_f1=emotion.get(
|
|
|
|
|
|
|
| 378 |
macro_f1=emotion.get("macro_f1", 0),
|
| 379 |
micro_f1=emotion.get("micro_f1", 0),
|
| 380 |
)
|
| 381 |
gr.Markdown(emotion_md)
|
| 382 |
else:
|
| 383 |
gr.Markdown("*Emotion detection metrics not available.*")
|
| 384 |
-
|
| 385 |
# Dataset Statistics
|
| 386 |
gr.Markdown("#### Dataset Statistics")
|
| 387 |
-
|
| 388 |
gr.Markdown(f"""
|
| 389 |
| Statistic | Value |
|
| 390 |
|-----------|-------|
|
|
@@ -394,7 +400,7 @@ with gr.Blocks(
|
|
| 394 |
| Topics | {len(TOPICS)} |
|
| 395 |
| Emotions | {len(EMOTIONS)} |
|
| 396 |
""")
|
| 397 |
-
|
| 398 |
# ===================== TAB 5: ABOUT =====================
|
| 399 |
with gr.Tab("About"):
|
| 400 |
gr.Markdown(
|
|
@@ -420,4 +426,3 @@ with gr.Blocks(
|
|
| 420 |
|
| 421 |
if __name__ == "__main__":
|
| 422 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 423 |
-
|
|
|
|
| 27 |
ALL_ITEMS: list[dict[str, Any]] = [dict(row) for row in _dataset]
|
| 28 |
|
| 29 |
# Extract unique topics and emotions FROM THE DATASET (what model predicted)
|
| 30 |
+
DATASET_TOPICS: list[str] = sorted(
|
| 31 |
+
set(str(item["topic"]) for item in ALL_ITEMS if item.get("topic"))
|
| 32 |
+
)
|
| 33 |
+
DATASET_EMOTIONS: list[str] = sorted(
|
| 34 |
+
set(str(item["emotion"]) for item in ALL_ITEMS if item.get("emotion"))
|
| 35 |
+
)
|
| 36 |
|
| 37 |
# Load ALL possible labels from labels.json (what the model CAN predict)
|
| 38 |
_labels_path = Path(__file__).parent.parent / "artifacts" / "labels.json"
|
|
|
|
| 94 |
title = item.get("title", "Unknown")
|
| 95 |
source_type = item.get("source_type", "unknown")
|
| 96 |
dataset_name = item.get("dataset", "").title()
|
| 97 |
+
|
| 98 |
# Icon based on type
|
| 99 |
if source_type == "academic":
|
| 100 |
type_label = "Research Paper"
|
| 101 |
else:
|
| 102 |
type_label = "Literature"
|
| 103 |
+
|
| 104 |
# Topic and emotion with confidence
|
| 105 |
topic = item.get("topic", "Unknown")
|
| 106 |
topic_conf = item.get("topic_confidence", 0)
|
| 107 |
emotion = item.get("emotion", "Unknown")
|
| 108 |
emotion_conf = item.get("emotion_confidence", 0)
|
| 109 |
+
|
| 110 |
# Summary - check if using reference or generated
|
| 111 |
use_reference = item.get("use_reference_summary", False)
|
| 112 |
if use_reference or source_type == "literary":
|
|
|
|
| 115 |
else:
|
| 116 |
summary = item.get("generated_summary", "")
|
| 117 |
summary_label = "**AI-Generated Description:**"
|
| 118 |
+
|
| 119 |
if not summary:
|
| 120 |
summary = "No summary available."
|
| 121 |
+
|
| 122 |
# Truncate summary if too long
|
| 123 |
if len(summary) > 400:
|
| 124 |
+
summary = summary[:400].rsplit(" ", 1)[0] + "..."
|
| 125 |
+
|
| 126 |
# Preview of original text
|
| 127 |
+
text_preview = (
|
| 128 |
+
item.get("text", "")[:400] + "..."
|
| 129 |
+
if len(item.get("text", "")) > 400
|
| 130 |
+
else item.get("text", "")
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
return f"""### **{title}**
|
| 134 |
|
| 135 |
<small>*{type_label}* from {dataset_name}</small>
|
|
|
|
| 155 |
items = get_items_by_topic(topic)
|
| 156 |
if not items:
|
| 157 |
return "No items found for this topic."
|
| 158 |
+
|
| 159 |
# Group by type
|
| 160 |
literary = [i for i in items if i.get("source_type") == "literary"]
|
| 161 |
academic = [i for i in items if i.get("source_type") == "academic"]
|
| 162 |
+
|
| 163 |
result = f"## {topic if topic != 'All' else 'All Topics'}\n\n"
|
| 164 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 165 |
+
|
| 166 |
if literary:
|
| 167 |
result += "### Literary Works\n\n"
|
| 168 |
for item in literary[:25]: # Limit to avoid huge pages
|
| 169 |
result += format_item_card(item)
|
| 170 |
+
|
| 171 |
if academic:
|
| 172 |
result += "### Academic Papers\n\n"
|
| 173 |
for item in academic[:25]:
|
| 174 |
result += format_item_card(item)
|
| 175 |
+
|
| 176 |
return result
|
| 177 |
|
| 178 |
|
|
|
|
| 181 |
items = get_items_by_emotion(emotion)
|
| 182 |
if not items:
|
| 183 |
return "No items found for this emotion."
|
| 184 |
+
|
| 185 |
literary = [i for i in items if i.get("source_type") == "literary"]
|
| 186 |
academic = [i for i in items if i.get("source_type") == "academic"]
|
| 187 |
+
|
| 188 |
result = f"## Feeling {emotion.title() if emotion != 'All' else 'All Emotions'}?\n\n"
|
| 189 |
result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
|
| 190 |
+
|
| 191 |
if literary:
|
| 192 |
result += "### Literary Works\n\n"
|
| 193 |
for item in literary[:25]:
|
| 194 |
result += format_item_card(item)
|
| 195 |
+
|
| 196 |
if academic:
|
| 197 |
result += "### Academic Papers\n\n"
|
| 198 |
for item in academic[:25]:
|
| 199 |
result += format_item_card(item)
|
| 200 |
+
|
| 201 |
return result
|
| 202 |
|
| 203 |
|
|
|
|
| 205 |
"""Search items by text content."""
|
| 206 |
if not query or len(query) < 3:
|
| 207 |
return "Enter at least 3 characters to search."
|
| 208 |
+
|
| 209 |
query_lower = query.lower()
|
| 210 |
matches = [
|
| 211 |
+
item
|
| 212 |
+
for item in ALL_ITEMS
|
| 213 |
if query_lower in item.get("text", "").lower()
|
| 214 |
or query_lower in item.get("generated_summary", "").lower()
|
| 215 |
or query_lower in item.get("title", "").lower()
|
| 216 |
]
|
| 217 |
+
|
| 218 |
if not matches:
|
| 219 |
return f"No results found for '{query}'."
|
| 220 |
+
|
| 221 |
result = f"## Search Results for '{query}'\n\n"
|
| 222 |
result += f"*Found {len(matches)} matching items*\n\n"
|
| 223 |
+
|
| 224 |
for item in matches[:30]:
|
| 225 |
result += format_item_card(item)
|
| 226 |
+
|
| 227 |
return result
|
| 228 |
|
| 229 |
|
|
|
|
| 235 |
css="""
|
| 236 |
.result-box { max-height: 700px; overflow-y: auto; }
|
| 237 |
h3 { margin-top: 0.5em !important; }
|
| 238 |
+
""",
|
| 239 |
) as demo:
|
|
|
|
| 240 |
gr.Markdown(
|
| 241 |
"""
|
| 242 |
# LexiMind
|
|
|
|
| 245 |
Browse **{total_count}** texts — {lit_count} classic books and {paper_count} research papers — analyzed by a multi-task transformer.
|
| 246 |
|
| 247 |
---
|
| 248 |
+
""".format(total_count=len(ALL_ITEMS), lit_count=len(BOOKS), paper_count=len(PAPERS))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
)
|
| 250 |
+
|
| 251 |
with gr.Tabs():
|
| 252 |
# ===================== TAB 1: BROWSE BY TOPIC =====================
|
| 253 |
with gr.Tab("By Topic"):
|
| 254 |
gr.Markdown("*Select a topic to explore related books and papers*")
|
| 255 |
+
|
| 256 |
topic_dropdown = gr.Dropdown(
|
| 257 |
choices=["All"] + TOPICS,
|
| 258 |
value="All",
|
| 259 |
label="Select Topic",
|
| 260 |
interactive=True,
|
| 261 |
)
|
| 262 |
+
|
| 263 |
topic_results = gr.Markdown(
|
| 264 |
value=browse_by_topic("All"),
|
| 265 |
elem_classes=["result-box"],
|
| 266 |
)
|
| 267 |
+
|
| 268 |
topic_dropdown.change(
|
| 269 |
fn=browse_by_topic,
|
| 270 |
inputs=[topic_dropdown],
|
| 271 |
outputs=[topic_results],
|
| 272 |
)
|
| 273 |
+
|
| 274 |
# ===================== TAB 2: BROWSE BY EMOTION =====================
|
| 275 |
with gr.Tab("By Emotion"):
|
| 276 |
gr.Markdown("*Find books and papers that evoke specific emotions*")
|
| 277 |
+
|
| 278 |
emotion_dropdown = gr.Dropdown(
|
| 279 |
choices=["All"] + [e.title() for e in EMOTIONS],
|
| 280 |
value="All",
|
| 281 |
label="Select Emotion",
|
| 282 |
interactive=True,
|
| 283 |
)
|
| 284 |
+
|
| 285 |
emotion_results = gr.Markdown(
|
| 286 |
value=browse_by_emotion("All"),
|
| 287 |
elem_classes=["result-box"],
|
| 288 |
)
|
| 289 |
+
|
| 290 |
emotion_dropdown.change(
|
| 291 |
fn=lambda e: browse_by_emotion(e.lower() if e != "All" else "All"),
|
| 292 |
inputs=[emotion_dropdown],
|
| 293 |
outputs=[emotion_results],
|
| 294 |
)
|
| 295 |
+
|
| 296 |
# ===================== TAB 3: SEARCH =====================
|
| 297 |
with gr.Tab("Search"):
|
| 298 |
gr.Markdown("*Search through all books and papers by keyword*")
|
| 299 |
+
|
| 300 |
search_input = gr.Textbox(
|
| 301 |
placeholder="Enter keywords to search...",
|
| 302 |
label="Search",
|
| 303 |
interactive=True,
|
| 304 |
)
|
| 305 |
+
|
| 306 |
search_results = gr.Markdown(
|
| 307 |
value="Enter at least 3 characters to search.",
|
| 308 |
elem_classes=["result-box"],
|
| 309 |
)
|
| 310 |
+
|
| 311 |
search_input.change(
|
| 312 |
fn=search_items,
|
| 313 |
inputs=[search_input],
|
| 314 |
outputs=[search_results],
|
| 315 |
)
|
| 316 |
+
|
| 317 |
# ===================== TAB 4: METRICS =====================
|
| 318 |
with gr.Tab("Metrics"):
|
| 319 |
gr.Markdown(
|
|
|
|
| 323 |
Computed on held-out validation data.
|
| 324 |
"""
|
| 325 |
)
|
| 326 |
+
|
| 327 |
# Summarization Metrics
|
| 328 |
gr.Markdown("#### Summarization")
|
| 329 |
+
|
| 330 |
if METRICS.get("summarization"):
|
| 331 |
summ = METRICS["summarization"]
|
| 332 |
summ_md = """
|
|
|
|
| 345 |
gr.Markdown(summ_md)
|
| 346 |
else:
|
| 347 |
gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
|
| 348 |
+
|
| 349 |
# Topic Classification Metrics
|
| 350 |
gr.Markdown("#### Topic Classification")
|
| 351 |
+
|
| 352 |
if METRICS.get("topic"):
|
| 353 |
topic = METRICS["topic"]
|
| 354 |
topic_md = """
|
|
|
|
| 363 |
gr.Markdown(topic_md)
|
| 364 |
else:
|
| 365 |
gr.Markdown("*Topic classification metrics not available.*")
|
| 366 |
+
|
| 367 |
# Emotion Detection Metrics
|
| 368 |
gr.Markdown("#### Emotion Detection")
|
| 369 |
+
|
| 370 |
if METRICS.get("emotion"):
|
| 371 |
emotion = METRICS["emotion"]
|
| 372 |
emotion_md = """
|
|
|
|
| 378 |
|
| 379 |
*28-label multi-label classification from GoEmotions.*
|
| 380 |
""".format(
|
| 381 |
+
sample_f1=emotion.get(
|
| 382 |
+
"sample_avg_f1", emotion.get("f1", emotion.get("multilabel_f1", 0))
|
| 383 |
+
),
|
| 384 |
macro_f1=emotion.get("macro_f1", 0),
|
| 385 |
micro_f1=emotion.get("micro_f1", 0),
|
| 386 |
)
|
| 387 |
gr.Markdown(emotion_md)
|
| 388 |
else:
|
| 389 |
gr.Markdown("*Emotion detection metrics not available.*")
|
| 390 |
+
|
| 391 |
# Dataset Statistics
|
| 392 |
gr.Markdown("#### Dataset Statistics")
|
| 393 |
+
|
| 394 |
gr.Markdown(f"""
|
| 395 |
| Statistic | Value |
|
| 396 |
|-----------|-------|
|
|
|
|
| 400 |
| Topics | {len(TOPICS)} |
|
| 401 |
| Emotions | {len(EMOTIONS)} |
|
| 402 |
""")
|
| 403 |
+
|
| 404 |
# ===================== TAB 5: ABOUT =====================
|
| 405 |
with gr.Tab("About"):
|
| 406 |
gr.Markdown(
|
|
|
|
| 426 |
|
| 427 |
if __name__ == "__main__":
|
| 428 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
scripts/download_data.py
CHANGED
|
@@ -45,63 +45,128 @@ OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
|
|
| 45 |
|
| 46 |
# 28 emotions from GoEmotions - works for all text types
|
| 47 |
EMOTION_LABELS = [
|
| 48 |
-
"admiration",
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
]
|
| 54 |
|
| 55 |
# New topic labels for books + papers + blogs
|
| 56 |
TOPIC_LABELS = [
|
| 57 |
-
"Fiction",
|
| 58 |
-
"Science",
|
| 59 |
-
"Technology",
|
| 60 |
-
"Philosophy",
|
| 61 |
-
"History",
|
| 62 |
-
"Psychology",
|
| 63 |
-
"Business",
|
| 64 |
-
"Arts",
|
| 65 |
]
|
| 66 |
|
| 67 |
# arXiv category → our topic mapping
|
| 68 |
ARXIV_CATEGORY_MAP = {
|
| 69 |
# Computer Science
|
| 70 |
-
"cs.AI": "Technology",
|
| 71 |
-
"cs.
|
| 72 |
-
"cs.
|
| 73 |
-
"cs.
|
| 74 |
-
"cs.
|
| 75 |
-
"cs.
|
| 76 |
-
"cs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# Physics
|
| 78 |
-
"physics": "Science",
|
| 79 |
-
"
|
| 80 |
-
"
|
| 81 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
"quant-ph": "Science",
|
| 83 |
# Math
|
| 84 |
"math": "Science",
|
| 85 |
# Biology/Medicine
|
| 86 |
-
"q-bio": "Science",
|
|
|
|
| 87 |
# Economics/Finance
|
| 88 |
-
"econ": "Business",
|
|
|
|
| 89 |
# Electrical Engineering
|
| 90 |
"eess": "Technology",
|
| 91 |
}
|
| 92 |
|
| 93 |
# Gutenberg subject → our topic mapping
|
| 94 |
GUTENBERG_SUBJECT_MAP = {
|
| 95 |
-
"fiction": "Fiction",
|
| 96 |
-
"
|
| 97 |
-
"
|
| 98 |
-
"
|
| 99 |
-
"
|
| 100 |
-
"
|
| 101 |
-
"
|
| 102 |
-
"
|
| 103 |
-
"
|
| 104 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
}
|
| 106 |
|
| 107 |
|
|
@@ -118,12 +183,69 @@ def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing"
|
|
| 118 |
|
| 119 |
# Common English words for detection
|
| 120 |
ENGLISH_WORDS = {
|
| 121 |
-
"the",
|
| 122 |
-
"
|
| 123 |
-
"
|
| 124 |
-
"
|
| 125 |
-
"
|
| 126 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
}
|
| 128 |
|
| 129 |
# Non-English language patterns
|
|
@@ -144,72 +266,126 @@ NON_ENGLISH_PATTERNS = [
|
|
| 144 |
|
| 145 |
# Patterns that indicate garbage/metadata text
|
| 146 |
GARBAGE_PATTERNS = [
|
| 147 |
-
r"^Page \d+:",
|
| 148 |
-
r"changed to",
|
| 149 |
-
r"Punctuation has been",
|
| 150 |
-
r"^\[.*\]$",
|
| 151 |
-
r"^Note\.?[-—]",
|
| 152 |
-
r"^follows:",
|
| 153 |
-
r"CHAPTER [IVXLC]+\.",
|
| 154 |
-
r"^\*\*\*",
|
| 155 |
-
r"^End of.*Project",
|
| 156 |
-
r"^Produced by",
|
| 157 |
-
r"transcriber",
|
| 158 |
-
r"eBook",
|
| 159 |
-
r"©|copyright",
|
| 160 |
-
r"^INDEX",
|
| 161 |
r"^\d+\.\s+\w+,\s+\d+", # Index entries like "1. Name, 234"
|
| 162 |
-
r"(syn\.|var\.|sp\.)",
|
| 163 |
-
r"[A-Z][a-z]+aceae",
|
| 164 |
-
r"\(\s*syn\s+",
|
| 165 |
]
|
| 166 |
|
| 167 |
# Patterns that indicate technical manuals/instructions (not narrative)
|
| 168 |
TECHNICAL_PATTERNS = [
|
| 169 |
r"\d+\.\s+It\s+(is|has|can)", # Numbered features "1. It is a..."
|
| 170 |
-
r"^\d+(st|nd|rd|th)\.",
|
| 171 |
-
r"Mesh\.?\s*\d+",
|
| 172 |
r"\d+\s*(oz|lb|kg|g|ml|mm|cm|inch)", # Measurements
|
| 173 |
-
r"Parts?\s*:?\s*\d+",
|
| 174 |
-
r"Method of Using",
|
| 175 |
-
r"How to\s+\w+",
|
| 176 |
-
r"Step\s+\d+",
|
| 177 |
-
r"wire.*address",
|
| 178 |
-
r"orders?\s+should\s+be",
|
| 179 |
-
r"specifications?",
|
| 180 |
-
r"(Front|Back)\s+Focus",
|
| 181 |
-
r"Rack and Pinion",
|
| 182 |
]
|
| 183 |
|
| 184 |
# Shakespeare and plays to exclude (model hallucinates on Early Modern English)
|
| 185 |
EXCLUDED_TITLES = {
|
| 186 |
# Shakespeare
|
| 187 |
-
"King Lear",
|
| 188 |
-
"
|
| 189 |
-
"
|
| 190 |
-
"
|
| 191 |
-
"
|
| 192 |
-
"
|
| 193 |
-
"The
|
| 194 |
-
"
|
| 195 |
-
"
|
| 196 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
# French plays
|
| 198 |
-
"Tartuffe",
|
| 199 |
-
"
|
| 200 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
# Greek/Roman plays
|
| 202 |
-
"Oedipus Rex",
|
| 203 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
# Other classic plays
|
| 205 |
-
"The Importance of Being Earnest",
|
| 206 |
-
"
|
| 207 |
-
"
|
| 208 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# Verse/poetic epics
|
| 210 |
-
"Idylls of the King",
|
| 211 |
-
"
|
| 212 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
}
|
| 214 |
|
| 215 |
|
|
@@ -227,25 +403,25 @@ def is_quality_text(text: str) -> bool:
|
|
| 227 |
for pattern in GARBAGE_PATTERNS:
|
| 228 |
if re.search(pattern, text, re.IGNORECASE | re.MULTILINE):
|
| 229 |
return False
|
| 230 |
-
|
| 231 |
# Reject technical manuals/instructions
|
| 232 |
if is_technical_manual(text):
|
| 233 |
return False
|
| 234 |
-
|
| 235 |
# Must have reasonable length
|
| 236 |
if len(text) < 300:
|
| 237 |
return False
|
| 238 |
-
|
| 239 |
# Must have sentences (not just fragments)
|
| 240 |
-
sentences = re.split(r
|
| 241 |
if len(sentences) < 4:
|
| 242 |
return False
|
| 243 |
-
|
| 244 |
# Check for too many special characters
|
| 245 |
special_ratio = len(re.findall(r'[^\w\s.,!?\'"()-]', text)) / max(len(text), 1)
|
| 246 |
if special_ratio > 0.08:
|
| 247 |
return False
|
| 248 |
-
|
| 249 |
return True
|
| 250 |
|
| 251 |
|
|
@@ -263,7 +439,7 @@ def is_play_text(text: str) -> bool:
|
|
| 263 |
r"^[A-Z]{2,}\.\s", # Character names like "HAMLET."
|
| 264 |
r"Alarum|Flourish|Sennet", # Stage directions
|
| 265 |
]
|
| 266 |
-
lines = text.split(
|
| 267 |
play_indicators = 0
|
| 268 |
for line in lines:
|
| 269 |
for pattern in play_patterns:
|
|
@@ -275,182 +451,182 @@ def is_play_text(text: str) -> bool:
|
|
| 275 |
def is_english_text(text: str, min_ratio: float = 0.08, max_foreign: int = 5) -> bool:
|
| 276 |
"""
|
| 277 |
Check if text is primarily English.
|
| 278 |
-
|
| 279 |
Args:
|
| 280 |
text: Text to check
|
| 281 |
min_ratio: Minimum ratio of common English words
|
| 282 |
max_foreign: Maximum number of foreign word matches before rejecting
|
| 283 |
-
|
| 284 |
Returns:
|
| 285 |
True if text appears to be English
|
| 286 |
"""
|
| 287 |
if not text or len(text) < 100:
|
| 288 |
return False
|
| 289 |
-
|
| 290 |
text_lower = text.lower()
|
| 291 |
words = text_lower.split()
|
| 292 |
-
|
| 293 |
if len(words) < 20:
|
| 294 |
return False
|
| 295 |
-
|
| 296 |
# Check for excessive non-English words
|
| 297 |
for pattern in NON_ENGLISH_PATTERNS:
|
| 298 |
matches = len(re.findall(pattern, text_lower))
|
| 299 |
if matches > max_foreign:
|
| 300 |
return False
|
| 301 |
-
|
| 302 |
# Check for sufficient English words
|
| 303 |
english_count = sum(1 for w in words if w.strip(".,!?;:'\"") in ENGLISH_WORDS)
|
| 304 |
ratio = english_count / len(words)
|
| 305 |
-
|
| 306 |
return ratio >= min_ratio
|
| 307 |
|
| 308 |
|
| 309 |
def normalize_title(title: str) -> str:
|
| 310 |
"""Normalize a book title for matching."""
|
| 311 |
# Remove common prefixes/suffixes
|
| 312 |
-
title = re.sub(r
|
| 313 |
-
title = re.sub(r
|
| 314 |
-
title = re.sub(r
|
| 315 |
-
title = re.sub(r
|
| 316 |
return title.lower().strip()
|
| 317 |
|
| 318 |
|
| 319 |
# -------- SUMMARIZATION: BOOKS + ARXIV ----------
|
| 320 |
|
|
|
|
| 321 |
def download_goodreads_descriptions() -> dict[str, dict]:
|
| 322 |
"""
|
| 323 |
Download Goodreads book descriptions - back-cover style blurbs.
|
| 324 |
-
|
| 325 |
These are "what the book is about" descriptions, not plot summaries.
|
| 326 |
Returns dict mapping normalized title -> {title, description}
|
| 327 |
"""
|
| 328 |
print("\nLoading Goodreads book descriptions...")
|
| 329 |
-
|
| 330 |
descriptions = {}
|
| 331 |
-
|
| 332 |
# Try multiple sources
|
| 333 |
datasets_to_try = [
|
| 334 |
"booksouls/goodreads-book-descriptions",
|
| 335 |
"Skelebor/book_titles_and_descriptions_en_clean",
|
| 336 |
]
|
| 337 |
-
|
| 338 |
for ds_name in datasets_to_try:
|
| 339 |
try:
|
| 340 |
print(f" Loading {ds_name}...")
|
| 341 |
ds = load_dataset(ds_name, split="train")
|
| 342 |
-
|
| 343 |
for item in tqdm(ds, desc="Goodreads", leave=False):
|
| 344 |
title = item.get("title", "")
|
| 345 |
description = item.get("description", "")
|
| 346 |
-
|
| 347 |
if not title or not description:
|
| 348 |
continue
|
| 349 |
-
|
| 350 |
# Skip very short descriptions (not useful for training)
|
| 351 |
if len(description) < 100:
|
| 352 |
continue
|
| 353 |
-
|
| 354 |
# Skip very long descriptions (truncate later)
|
| 355 |
if len(description) > 2000:
|
| 356 |
description = description[:2000]
|
| 357 |
-
|
| 358 |
# Skip plays and excluded titles
|
| 359 |
if is_excluded_title(title):
|
| 360 |
continue
|
| 361 |
-
|
| 362 |
# Skip non-English descriptions
|
| 363 |
if not is_english_text(description):
|
| 364 |
continue
|
| 365 |
-
|
| 366 |
norm_title = normalize_title(title)
|
| 367 |
if norm_title and norm_title not in descriptions:
|
| 368 |
descriptions[norm_title] = {
|
| 369 |
"title": title,
|
| 370 |
"description": description,
|
| 371 |
}
|
| 372 |
-
|
| 373 |
print(f" Loaded {len(descriptions):,} descriptions from {ds_name}")
|
| 374 |
except Exception as e:
|
| 375 |
print(f" {ds_name} failed: {e}")
|
| 376 |
-
|
| 377 |
print(f" Total: {len(descriptions):,} unique book descriptions")
|
| 378 |
return descriptions
|
| 379 |
|
| 380 |
|
| 381 |
def download_book_descriptions(
|
| 382 |
-
goodreads_descriptions: dict[str, dict],
|
| 383 |
-
max_samples: int = 20000
|
| 384 |
) -> list[dict[str, Any]]:
|
| 385 |
"""
|
| 386 |
Download book description data by matching Gutenberg texts with Goodreads descriptions.
|
| 387 |
-
|
| 388 |
This gives us (book_excerpt, book_description) training pairs where descriptions
|
| 389 |
are back-cover style "what is this book about" blurbs, not plot summaries.
|
| 390 |
"""
|
| 391 |
print("\nMatching Gutenberg books with Goodreads descriptions...")
|
| 392 |
-
|
| 393 |
try:
|
| 394 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
| 395 |
except Exception:
|
| 396 |
gutenberg = load_dataset("pg19", split="train")
|
| 397 |
-
|
| 398 |
records: list[dict[str, Any]] = []
|
| 399 |
matched_titles = set()
|
| 400 |
skipped_quality = 0
|
| 401 |
skipped_play = 0
|
| 402 |
-
|
| 403 |
indices = list(range(len(gutenberg)))
|
| 404 |
random.shuffle(indices)
|
| 405 |
-
|
| 406 |
for i in tqdm(indices, desc="Matching books", leave=False):
|
| 407 |
if len(records) >= max_samples:
|
| 408 |
break
|
| 409 |
-
|
| 410 |
item = gutenberg[i]
|
| 411 |
text = item.get("TEXT", "") or item.get("text", "")
|
| 412 |
metadata_raw = item.get("METADATA", "") or "{}"
|
| 413 |
-
|
| 414 |
# Parse metadata
|
| 415 |
try:
|
| 416 |
metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
|
| 417 |
except (json.JSONDecodeError, TypeError):
|
| 418 |
metadata = {}
|
| 419 |
-
|
| 420 |
# Get title
|
| 421 |
title = metadata.get("title", "") if isinstance(metadata, dict) else ""
|
| 422 |
if not title:
|
| 423 |
continue
|
| 424 |
-
|
| 425 |
# Check if we have a Goodreads description for this book
|
| 426 |
norm_title = normalize_title(title)
|
| 427 |
if norm_title not in goodreads_descriptions:
|
| 428 |
continue
|
| 429 |
-
|
| 430 |
# Skip if already matched this book
|
| 431 |
if norm_title in matched_titles:
|
| 432 |
continue
|
| 433 |
-
|
| 434 |
goodreads_data = goodreads_descriptions[norm_title]
|
| 435 |
-
|
| 436 |
# Skip plays and excluded titles
|
| 437 |
if is_excluded_title(title):
|
| 438 |
skipped_play += 1
|
| 439 |
continue
|
| 440 |
-
|
| 441 |
if not text or len(text) < 2000:
|
| 442 |
continue
|
| 443 |
-
|
| 444 |
# Get a clean excerpt from the book (skip front matter)
|
| 445 |
-
paragraphs = re.split(r
|
| 446 |
excerpt_parts = []
|
| 447 |
total_len = 0
|
| 448 |
-
|
| 449 |
for para in paragraphs[10:]: # Skip front matter
|
| 450 |
para = para.strip()
|
| 451 |
if len(para) < 100:
|
| 452 |
continue
|
| 453 |
-
|
| 454 |
# Quality check on paragraph
|
| 455 |
if not is_english_text(para):
|
| 456 |
continue
|
|
@@ -460,112 +636,119 @@ def download_book_descriptions(
|
|
| 460 |
if not is_quality_text(para) and len(para) > 300:
|
| 461 |
skipped_quality += 1
|
| 462 |
continue
|
| 463 |
-
|
| 464 |
excerpt_parts.append(para)
|
| 465 |
total_len += len(para)
|
| 466 |
-
|
| 467 |
if total_len >= 3000:
|
| 468 |
break
|
| 469 |
-
|
| 470 |
if total_len < 1000:
|
| 471 |
continue
|
| 472 |
-
|
| 473 |
book_excerpt = "\n\n".join(excerpt_parts)[:4000]
|
| 474 |
matched_titles.add(norm_title)
|
| 475 |
-
|
| 476 |
-
records.append(
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
|
|
|
|
|
|
| 483 |
print(f" Matched {len(records):,} books with descriptions")
|
| 484 |
print(f" Skipped: {skipped_quality} quality, {skipped_play} plays")
|
| 485 |
-
|
| 486 |
return records
|
| 487 |
|
| 488 |
|
| 489 |
# Keep BookSum for additional literary training (chapter summaries are still useful)
|
| 490 |
def download_booksum(max_samples: int = 20000) -> list[dict[str, Any]]:
|
| 491 |
"""Download BookSum - literary chapter summarization (English only, quality filtered).
|
| 492 |
-
|
| 493 |
Note: These are chapter-level plot summaries, useful as supplementary training data.
|
| 494 |
The primary book training comes from Goodreads descriptions (back-cover style).
|
| 495 |
"""
|
| 496 |
print("\nLoading BookSum (supplementary literary data)...")
|
| 497 |
-
|
| 498 |
all_records: list[dict[str, Any]] = []
|
| 499 |
booksum = load_dataset("kmfoda/booksum")
|
| 500 |
-
|
| 501 |
for split_name in booksum.keys():
|
| 502 |
split = str(split_name)
|
| 503 |
data = booksum[split_name]
|
| 504 |
limit = max_samples if "train" in split else max_samples // 10
|
| 505 |
indices = random.sample(range(len(data)), min(len(data), limit))
|
| 506 |
-
|
| 507 |
records = []
|
| 508 |
skipped_language = 0
|
| 509 |
skipped_excluded = 0
|
| 510 |
skipped_play = 0
|
| 511 |
-
|
| 512 |
for i in tqdm(indices, desc=f"BookSum {split}", leave=False):
|
| 513 |
item = data[i]
|
| 514 |
chapter = item.get("chapter", "")
|
| 515 |
summary = item.get("summary_text") or item.get("summary", "")
|
| 516 |
-
|
| 517 |
# Extract book title from book_id (e.g., "The Last of the Mohicans.chapters 1-2")
|
| 518 |
book_id = item.get("book_id", "")
|
| 519 |
book_title = book_id.split(".")[0] if "." in book_id else book_id
|
| 520 |
chapter_name = item.get("summary_id", "") or item.get("summary_name", "")
|
| 521 |
-
|
| 522 |
if not (chapter and summary and len(chapter) > 300):
|
| 523 |
continue
|
| 524 |
-
|
| 525 |
# Filter: excluded titles (Shakespeare, plays, etc.)
|
| 526 |
if is_excluded_title(book_title):
|
| 527 |
skipped_excluded += 1
|
| 528 |
continue
|
| 529 |
-
|
| 530 |
# Filter: play text format
|
| 531 |
if is_play_text(chapter):
|
| 532 |
skipped_play += 1
|
| 533 |
continue
|
| 534 |
-
|
| 535 |
# Filter: English only
|
| 536 |
if not is_english_text(chapter):
|
| 537 |
skipped_language += 1
|
| 538 |
continue
|
| 539 |
-
|
| 540 |
# Filter: quality text
|
| 541 |
if not is_quality_text(chapter):
|
| 542 |
continue
|
| 543 |
-
|
| 544 |
-
records.append(
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
|
|
|
|
|
|
| 552 |
all_records.extend(records)
|
| 553 |
-
print(
|
| 554 |
-
|
|
|
|
|
|
|
| 555 |
return all_records
|
| 556 |
|
| 557 |
|
| 558 |
def clean_arxiv_text(text: str) -> str:
|
| 559 |
"""Clean arXiv LaTeX-style text to make it more readable."""
|
| 560 |
import re
|
|
|
|
| 561 |
# Remove LaTeX math placeholders
|
| 562 |
-
text = re.sub(r
|
| 563 |
-
text = re.sub(r
|
| 564 |
# Remove excessive whitespace
|
| 565 |
-
text = re.sub(r
|
| 566 |
# Remove LaTeX commands
|
| 567 |
-
text = re.sub(r
|
| 568 |
-
text = re.sub(r
|
| 569 |
return text.strip()
|
| 570 |
|
| 571 |
|
|
@@ -573,19 +756,19 @@ def extract_paper_title(abstract: str) -> str:
|
|
| 573 |
"""Extract a meaningful title from the first sentence of an abstract."""
|
| 574 |
# Clean the abstract first
|
| 575 |
abstract = clean_arxiv_text(abstract)
|
| 576 |
-
|
| 577 |
# Get the first sentence (up to first period, question mark, or newline)
|
| 578 |
-
first_sentence = re.split(r
|
| 579 |
-
|
| 580 |
# Truncate if too long
|
| 581 |
if len(first_sentence) > 100:
|
| 582 |
# Try to cut at a natural word boundary
|
| 583 |
-
first_sentence = first_sentence[:100].rsplit(
|
| 584 |
-
|
| 585 |
# Capitalize first letter
|
| 586 |
if first_sentence:
|
| 587 |
first_sentence = first_sentence[0].upper() + first_sentence[1:]
|
| 588 |
-
|
| 589 |
return first_sentence or "Untitled Paper"
|
| 590 |
|
| 591 |
|
|
@@ -593,202 +776,222 @@ def download_arxiv_summarization(max_samples: int = 50000) -> list[dict[str, Any
|
|
| 593 |
"""
|
| 594 |
Download arXiv papers for academic summarization only (English only).
|
| 595 |
Note: This dataset doesn't have categories, so can't be used for topic classification.
|
| 596 |
-
|
| 597 |
Returns: summarization_records
|
| 598 |
"""
|
| 599 |
print("\nLoading arXiv (academic papers for summarization)...")
|
| 600 |
-
|
| 601 |
print(" Loading dataset (this may take a minute)...")
|
| 602 |
arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
|
| 603 |
-
|
| 604 |
summ_records: list[dict[str, Any]] = []
|
| 605 |
skipped_language = 0
|
| 606 |
-
|
| 607 |
indices = list(range(len(arxiv)))
|
| 608 |
random.shuffle(indices)
|
| 609 |
-
|
| 610 |
print(" Processing papers...")
|
| 611 |
-
for i in tqdm(indices[:max_samples * 2], desc="arXiv", leave=False):
|
| 612 |
if len(summ_records) >= max_samples:
|
| 613 |
break
|
| 614 |
-
|
| 615 |
item = arxiv[i]
|
| 616 |
-
|
| 617 |
# Get abstract and article
|
| 618 |
abstract = item.get("abstract", "")
|
| 619 |
article = item.get("article", "")
|
| 620 |
-
|
| 621 |
if not abstract or len(abstract) < 100:
|
| 622 |
continue
|
| 623 |
-
|
| 624 |
# Clean LaTeX artifacts
|
| 625 |
abstract = clean_arxiv_text(abstract)
|
| 626 |
article = clean_arxiv_text(article)
|
| 627 |
-
|
| 628 |
# Skip if still has too many weird characters after cleaning
|
| 629 |
-
if
|
| 630 |
continue
|
| 631 |
-
|
| 632 |
# Filter: English only
|
| 633 |
if not is_english_text(article[:1000]):
|
| 634 |
skipped_language += 1
|
| 635 |
continue
|
| 636 |
-
|
| 637 |
# Summarization: article → abstract
|
| 638 |
if article and len(article) > 500:
|
| 639 |
# Extract title from abstract
|
| 640 |
paper_title = extract_paper_title(abstract)
|
| 641 |
-
|
| 642 |
-
summ_records.append(
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
|
|
|
|
|
|
| 649 |
print(f" Summarization: {len(summ_records):,} (skipped {skipped_language} non-English)")
|
| 650 |
-
|
| 651 |
return summ_records
|
| 652 |
|
| 653 |
|
| 654 |
def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, Any]]:
|
| 655 |
"""
|
| 656 |
Download topic classification data from multiple sources with real categories.
|
| 657 |
-
|
| 658 |
Sources:
|
| 659 |
- 20 Newsgroups (classic topic classification)
|
| 660 |
- Wikipedia (article categories)
|
| 661 |
"""
|
| 662 |
print("\nLoading topic classification datasets...")
|
| 663 |
-
|
| 664 |
records: list[dict[str, Any]] = []
|
| 665 |
-
|
| 666 |
# 20 Newsgroups - classic topic dataset
|
| 667 |
print(" Loading 20 Newsgroups...")
|
| 668 |
try:
|
| 669 |
newsgroups = load_dataset("SetFit/20_newsgroups", split="train")
|
| 670 |
-
|
| 671 |
# Map 20 newsgroups categories to our 8 topics
|
| 672 |
newsgroup_map = {
|
| 673 |
# Science
|
| 674 |
-
"sci.crypt": "Science",
|
| 675 |
-
"sci.
|
| 676 |
-
|
| 677 |
-
"
|
| 678 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
"comp.windows.x": "Technology",
|
| 680 |
# Philosophy/Religion
|
| 681 |
-
"alt.atheism": "Philosophy",
|
|
|
|
| 682 |
"talk.religion.misc": "Philosophy",
|
| 683 |
# History/Politics
|
| 684 |
-
"talk.politics.guns": "History",
|
|
|
|
| 685 |
"talk.politics.misc": "History",
|
| 686 |
# Business
|
| 687 |
"misc.forsale": "Business",
|
| 688 |
# Sports/Recreation
|
| 689 |
-
"rec.autos": "Arts",
|
| 690 |
-
"rec.
|
|
|
|
|
|
|
| 691 |
}
|
| 692 |
-
|
| 693 |
for item in tqdm(newsgroups, desc="20 Newsgroups", leave=False):
|
| 694 |
if len(records) >= max_samples:
|
| 695 |
break
|
| 696 |
label_name = item.get("label_text", "")
|
| 697 |
text = item.get("text", "")
|
| 698 |
-
|
| 699 |
if label_name in newsgroup_map and text and len(text) > 100:
|
| 700 |
-
records.append(
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
|
|
|
|
|
|
| 706 |
print(f" 20 Newsgroups: {len(records):,}")
|
| 707 |
except Exception as e:
|
| 708 |
print(f" 20 Newsgroups failed: {e}")
|
| 709 |
-
|
| 710 |
# Add from Gutenberg for Fiction
|
| 711 |
gutenberg_topics = download_gutenberg_topics(max_samples // 4)
|
| 712 |
records.extend(gutenberg_topics)
|
| 713 |
-
|
| 714 |
# Add from scientific papers abstract dataset for more Science/Tech
|
| 715 |
print(" Loading scientific papers...")
|
| 716 |
try:
|
| 717 |
sci_papers = load_dataset("scientific_papers", "arxiv", split="train", streaming=True)
|
| 718 |
sci_count = 0
|
| 719 |
-
for item in tqdm(sci_papers, desc="Scientific papers", leave=False, total=max_samples//4):
|
| 720 |
if sci_count >= max_samples // 4:
|
| 721 |
break
|
| 722 |
abstract = item.get("abstract", "")
|
| 723 |
if abstract and len(abstract) > 100:
|
| 724 |
# Alternate between Science and Technology
|
| 725 |
topic = "Science" if sci_count % 2 == 0 else "Technology"
|
| 726 |
-
records.append(
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
|
|
|
|
|
|
| 731 |
sci_count += 1
|
| 732 |
print(f" Scientific papers: {sci_count:,}")
|
| 733 |
except Exception as e:
|
| 734 |
print(f" Scientific papers failed: {e}")
|
| 735 |
-
|
| 736 |
return records
|
| 737 |
|
| 738 |
|
| 739 |
def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> None:
|
| 740 |
"""Download all summarization data (books + arxiv, NO news).
|
| 741 |
-
|
| 742 |
Book data now uses Goodreads descriptions (back-cover blurbs) instead of
|
| 743 |
plot summaries. This trains the model to describe "what the book is about"
|
| 744 |
rather than summarizing the plot.
|
| 745 |
"""
|
| 746 |
print("\nDownloading Summarization Data...")
|
| 747 |
out_dir = OUTPUT_DIR / "summarization"
|
| 748 |
-
|
| 749 |
all_records: list[dict[str, Any]] = []
|
| 750 |
-
|
| 751 |
# Goodreads descriptions - primary book training data (back-cover style)
|
| 752 |
goodreads_descriptions = download_goodreads_descriptions()
|
| 753 |
book_records = download_book_descriptions(goodreads_descriptions, max_books)
|
| 754 |
all_records.extend(book_records)
|
| 755 |
-
|
| 756 |
# Optional: Add some BookSum for additional literary variety
|
| 757 |
# These are chapter summaries, not back-cover style, so keep limited
|
| 758 |
# booksum_records = download_booksum(max_books // 4)
|
| 759 |
# all_records.extend(booksum_records)
|
| 760 |
-
|
| 761 |
# arXiv - academic (abstracts are already "what is this paper about")
|
| 762 |
arxiv_summ = download_arxiv_summarization(max_arxiv)
|
| 763 |
all_records.extend(arxiv_summ)
|
| 764 |
-
|
| 765 |
# Shuffle and split
|
| 766 |
random.shuffle(all_records)
|
| 767 |
-
|
| 768 |
# Split by original split if available, else 90/5/5
|
| 769 |
-
train_records = [
|
|
|
|
|
|
|
| 770 |
val_records = [r for r in all_records if r.get("split") == "validation"]
|
| 771 |
test_records = [r for r in all_records if r.get("split") == "test"]
|
| 772 |
-
|
| 773 |
# If no split info, do 90/5/5
|
| 774 |
if len(val_records) < 100:
|
| 775 |
n = len(train_records)
|
| 776 |
random.shuffle(train_records)
|
| 777 |
-
val_records = train_records[int(n*0.9):int(n*0.95)]
|
| 778 |
-
test_records = train_records[int(n*0.95):]
|
| 779 |
-
train_records = train_records[:int(n*0.9)]
|
| 780 |
-
|
| 781 |
# Remove split key before saving
|
| 782 |
for r in train_records + val_records + test_records:
|
| 783 |
r.pop("split", None)
|
| 784 |
-
|
| 785 |
write_jsonl(train_records, out_dir / "train.jsonl", "train")
|
| 786 |
write_jsonl(val_records, out_dir / "validation.jsonl", "val")
|
| 787 |
write_jsonl(test_records, out_dir / "test.jsonl", "test")
|
| 788 |
-
|
| 789 |
# Print breakdown
|
| 790 |
-
literary_count = sum(
|
| 791 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
print(f"\n Total summarization: {len(train_records) + len(val_records) + len(test_records):,}")
|
| 793 |
print(f" Literary (book descriptions): {literary_count:,}")
|
| 794 |
print(f" Academic (paper abstracts): {academic_count:,}")
|
|
@@ -796,10 +999,11 @@ def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> No
|
|
| 796 |
|
| 797 |
# ------------ TOPIC CLASSIFICATION ------------
|
| 798 |
|
|
|
|
| 799 |
def download_topics(max_samples: int = 50000) -> None:
|
| 800 |
"""
|
| 801 |
Download topic classification data from multiple sources.
|
| 802 |
-
|
| 803 |
Sources:
|
| 804 |
- 20 Newsgroups (classic topic dataset)
|
| 805 |
- Gutenberg books (Fiction)
|
|
@@ -807,49 +1011,49 @@ def download_topics(max_samples: int = 50000) -> None:
|
|
| 807 |
"""
|
| 808 |
print("\nDownloading Topic Classification...")
|
| 809 |
out_dir = OUTPUT_DIR / "topic"
|
| 810 |
-
|
| 811 |
# Get topic records from various sources
|
| 812 |
all_records = download_topics_from_datasets(max_samples)
|
| 813 |
-
|
| 814 |
# Balance topics
|
| 815 |
topic_counts: dict[str, list] = {t: [] for t in TOPIC_LABELS}
|
| 816 |
for r in all_records:
|
| 817 |
topic = r.get("topic")
|
| 818 |
if topic in topic_counts:
|
| 819 |
topic_counts[topic].append(r)
|
| 820 |
-
|
| 821 |
# Print distribution before balancing
|
| 822 |
print("\n Topic distribution (before balancing):")
|
| 823 |
for topic, records in topic_counts.items():
|
| 824 |
print(f" {topic}: {len(records):,}")
|
| 825 |
-
|
| 826 |
# Balance to min count (with some tolerance) - only from topics that have data
|
| 827 |
counts_with_data = [len(v) for v in topic_counts.values() if v]
|
| 828 |
if not counts_with_data:
|
| 829 |
print(" Warning: No topic data found!")
|
| 830 |
return
|
| 831 |
-
|
| 832 |
min_count = min(counts_with_data)
|
| 833 |
target_count = min(min_count, max_samples // len(TOPIC_LABELS))
|
| 834 |
-
|
| 835 |
balanced: list[dict[str, Any]] = []
|
| 836 |
for _topic, records in topic_counts.items():
|
| 837 |
if records:
|
| 838 |
random.shuffle(records)
|
| 839 |
balanced.extend(records[:target_count])
|
| 840 |
-
|
| 841 |
random.shuffle(balanced)
|
| 842 |
-
|
| 843 |
# Split 90/5/5
|
| 844 |
n = len(balanced)
|
| 845 |
-
train_records = balanced[:int(n*0.9)]
|
| 846 |
-
val_records = balanced[int(n*0.9):int(n*0.95)]
|
| 847 |
-
test_records = balanced[int(n*0.95):]
|
| 848 |
-
|
| 849 |
write_jsonl(train_records, out_dir / "train.jsonl", "train")
|
| 850 |
write_jsonl(val_records, out_dir / "validation.jsonl", "val")
|
| 851 |
write_jsonl(test_records, out_dir / "test.jsonl", "test")
|
| 852 |
-
|
| 853 |
# Save labels - only labels that have data
|
| 854 |
used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
|
| 855 |
(out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
|
|
@@ -859,82 +1063,85 @@ def download_topics(max_samples: int = 50000) -> None:
|
|
| 859 |
def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
|
| 860 |
"""Extract topic-labeled samples from Gutenberg books (English only)."""
|
| 861 |
print("\nLoading Gutenberg for topic classification...")
|
| 862 |
-
|
| 863 |
try:
|
| 864 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
| 865 |
except Exception:
|
| 866 |
print(" Trying pg19...")
|
| 867 |
gutenberg = load_dataset("pg19", split="train")
|
| 868 |
-
|
| 869 |
records: list[dict[str, Any]] = []
|
| 870 |
skipped_language = 0
|
| 871 |
-
|
| 872 |
indices = list(range(len(gutenberg)))
|
| 873 |
random.shuffle(indices)
|
| 874 |
-
|
| 875 |
for i in tqdm(indices, desc="Gutenberg topics", leave=False):
|
| 876 |
if len(records) >= max_samples:
|
| 877 |
break
|
| 878 |
-
|
| 879 |
item = gutenberg[i]
|
| 880 |
text = item.get("TEXT", "") or item.get("text", "")
|
| 881 |
metadata = item.get("METADATA", {}) or {}
|
| 882 |
-
|
| 883 |
if not text or len(text) < 1000:
|
| 884 |
continue
|
| 885 |
-
|
| 886 |
# Try to determine topic from metadata
|
| 887 |
subjects = ""
|
| 888 |
if isinstance(metadata, dict):
|
| 889 |
subjects = str(metadata.get("subjects", "")).lower()
|
| 890 |
subjects += " " + str(metadata.get("subject", "")).lower()
|
| 891 |
subjects += " " + str(metadata.get("category", "")).lower()
|
| 892 |
-
|
| 893 |
topic = None
|
| 894 |
for keyword, mapped_topic in GUTENBERG_SUBJECT_MAP.items():
|
| 895 |
if keyword in subjects:
|
| 896 |
topic = mapped_topic
|
| 897 |
break
|
| 898 |
-
|
| 899 |
# Default fiction for novels without clear subject
|
| 900 |
if not topic and ("novel" in subjects or not subjects.strip()):
|
| 901 |
topic = "Fiction"
|
| 902 |
-
|
| 903 |
if topic:
|
| 904 |
# Get a clean paragraph as sample
|
| 905 |
-
paragraphs = re.split(r
|
| 906 |
for para in paragraphs[5:]: # Skip front matter
|
| 907 |
para = para.strip()
|
| 908 |
-
if 200 < len(para) < 1500 and para.count(
|
| 909 |
# Filter: English only
|
| 910 |
if not is_english_text(para):
|
| 911 |
skipped_language += 1
|
| 912 |
break
|
| 913 |
-
|
| 914 |
-
records.append(
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
|
|
|
|
|
|
| 919 |
break
|
| 920 |
-
|
| 921 |
print(f" Gutenberg topics: {len(records):,} (skipped {skipped_language} non-English)")
|
| 922 |
return records
|
| 923 |
|
| 924 |
|
| 925 |
# ------------ EMOTIONS (unchanged) -------------
|
| 926 |
|
|
|
|
| 927 |
def download_emotions() -> None:
|
| 928 |
"""Download GoEmotions for emotion classification."""
|
| 929 |
print("\nDownloading Emotions (GoEmotions)...")
|
| 930 |
out_dir = OUTPUT_DIR / "emotion"
|
| 931 |
-
|
| 932 |
ds = load_dataset("google-research-datasets/go_emotions", "simplified")
|
| 933 |
-
|
| 934 |
for split_name in ds.keys():
|
| 935 |
split = str(split_name)
|
| 936 |
data = ds[split_name]
|
| 937 |
-
|
| 938 |
records: list[dict[str, Any]] = []
|
| 939 |
for item in tqdm(data, desc=split, leave=False):
|
| 940 |
text = item.get("text", "")
|
|
@@ -944,7 +1151,7 @@ def download_emotions() -> None:
|
|
| 944 |
if emotions:
|
| 945 |
records.append({"text": text, "emotions": emotions})
|
| 946 |
write_jsonl(records, out_dir / f"{split}.jsonl", split)
|
| 947 |
-
|
| 948 |
(out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
|
| 949 |
print(f" {len(EMOTION_LABELS)} emotion labels saved")
|
| 950 |
|
|
@@ -952,12 +1159,23 @@ def download_emotions() -> None:
|
|
| 952 |
# --------------- GUTENBERG BOOKS (for language modeling) ---------------
|
| 953 |
|
| 954 |
GUTENBERG_JUNK_PATTERNS = [
|
| 955 |
-
r"Project Gutenberg",
|
| 956 |
-
r"
|
| 957 |
-
r"
|
| 958 |
-
r"
|
| 959 |
-
r"^\
|
| 960 |
-
r"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 961 |
]
|
| 962 |
GUTENBERG_JUNK_REGEX = re.compile("|".join(GUTENBERG_JUNK_PATTERNS), re.IGNORECASE)
|
| 963 |
|
|
@@ -968,7 +1186,7 @@ def is_clean_prose(text: str) -> bool:
|
|
| 968 |
return False
|
| 969 |
if GUTENBERG_JUNK_REGEX.search(text):
|
| 970 |
return False
|
| 971 |
-
if text.count(
|
| 972 |
return False
|
| 973 |
uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
|
| 974 |
if uppercase_ratio > 0.3:
|
|
@@ -987,68 +1205,66 @@ def download_gutenberg(max_samples: int = 30000) -> None:
|
|
| 987 |
print("\nDownloading Gutenberg Books (English only)...")
|
| 988 |
out_dir = OUTPUT_DIR / "books"
|
| 989 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 990 |
-
|
| 991 |
try:
|
| 992 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
| 993 |
except Exception:
|
| 994 |
gutenberg = load_dataset("pg19", split="train")
|
| 995 |
-
|
| 996 |
records: list[dict[str, Any]] = []
|
| 997 |
indices = list(range(len(gutenberg)))
|
| 998 |
random.shuffle(indices)
|
| 999 |
-
|
| 1000 |
for i in tqdm(indices, desc="Books", leave=False):
|
| 1001 |
if len(records) >= max_samples:
|
| 1002 |
break
|
| 1003 |
-
|
| 1004 |
item = gutenberg[i]
|
| 1005 |
text = item.get("TEXT", "") or item.get("text", "")
|
| 1006 |
metadata_raw = item.get("METADATA", "") or "{}"
|
| 1007 |
-
|
| 1008 |
# Parse metadata - it's stored as JSON string
|
| 1009 |
try:
|
| 1010 |
metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
|
| 1011 |
except (json.JSONDecodeError, TypeError):
|
| 1012 |
metadata = {}
|
| 1013 |
-
|
| 1014 |
# Extract title and author
|
| 1015 |
title = metadata.get("title", "") if isinstance(metadata, dict) else ""
|
| 1016 |
author = metadata.get("author", "") if isinstance(metadata, dict) else ""
|
| 1017 |
if not title:
|
| 1018 |
title = item.get("title", f"Unknown Book #{i}")
|
| 1019 |
-
|
| 1020 |
if not text or len(text) < 1000:
|
| 1021 |
continue
|
| 1022 |
-
|
| 1023 |
-
paragraphs = re.split(r
|
| 1024 |
for para in paragraphs:
|
| 1025 |
para = para.strip()
|
| 1026 |
if is_clean_prose(para):
|
| 1027 |
-
records.append(
|
| 1028 |
-
"text": para,
|
| 1029 |
-
|
| 1030 |
-
"author": author,
|
| 1031 |
-
"type": "gutenberg"
|
| 1032 |
-
})
|
| 1033 |
if len(records) >= max_samples:
|
| 1034 |
break
|
| 1035 |
-
|
| 1036 |
random.shuffle(records)
|
| 1037 |
n = len(records)
|
| 1038 |
-
write_jsonl(records[:int(n*0.9)], out_dir / "train.jsonl", "train")
|
| 1039 |
-
write_jsonl(records[int(n*0.9):int(n*0.95)], out_dir / "validation.jsonl", "val")
|
| 1040 |
-
write_jsonl(records[int(n*0.95):], out_dir / "test.jsonl", "test")
|
| 1041 |
|
| 1042 |
|
| 1043 |
# ------------ MAIN ------------
|
| 1044 |
|
|
|
|
| 1045 |
def main() -> None:
|
| 1046 |
parser = argparse.ArgumentParser(description="Download LexiMind datasets")
|
| 1047 |
parser.add_argument(
|
| 1048 |
"--task",
|
| 1049 |
choices=["all", "summarization", "emotion", "topic", "gutenberg"],
|
| 1050 |
default="all",
|
| 1051 |
-
help="Dataset to download"
|
| 1052 |
)
|
| 1053 |
parser.add_argument("--max-books", type=int, default=40000, help="Max BookSum samples")
|
| 1054 |
parser.add_argument("--max-arxiv", type=int, default=50000, help="Max arXiv samples")
|
|
@@ -1056,14 +1272,14 @@ def main() -> None:
|
|
| 1056 |
parser.add_argument("--max-topics", type=int, default=50000, help="Max topic samples")
|
| 1057 |
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 1058 |
args = parser.parse_args()
|
| 1059 |
-
|
| 1060 |
random.seed(args.seed)
|
| 1061 |
-
|
| 1062 |
print("=" * 60)
|
| 1063 |
print("LexiMind Dataset Download")
|
| 1064 |
print("Books + Academic Papers + Topic Classification")
|
| 1065 |
print("=" * 60)
|
| 1066 |
-
|
| 1067 |
if args.task in ["all", "summarization"]:
|
| 1068 |
download_summarization(args.max_books, args.max_arxiv)
|
| 1069 |
if args.task in ["all", "emotion"]:
|
|
@@ -1072,7 +1288,7 @@ def main() -> None:
|
|
| 1072 |
download_topics(args.max_topics)
|
| 1073 |
if args.task in ["all", "gutenberg"]:
|
| 1074 |
download_gutenberg(args.max_gutenberg)
|
| 1075 |
-
|
| 1076 |
print("\n" + "=" * 60)
|
| 1077 |
print("Download complete!")
|
| 1078 |
print("=" * 60)
|
|
|
|
| 45 |
|
| 46 |
# 28 emotions from GoEmotions - works for all text types
|
| 47 |
EMOTION_LABELS = [
|
| 48 |
+
"admiration",
|
| 49 |
+
"amusement",
|
| 50 |
+
"anger",
|
| 51 |
+
"annoyance",
|
| 52 |
+
"approval",
|
| 53 |
+
"caring",
|
| 54 |
+
"confusion",
|
| 55 |
+
"curiosity",
|
| 56 |
+
"desire",
|
| 57 |
+
"disappointment",
|
| 58 |
+
"disapproval",
|
| 59 |
+
"disgust",
|
| 60 |
+
"embarrassment",
|
| 61 |
+
"excitement",
|
| 62 |
+
"fear",
|
| 63 |
+
"gratitude",
|
| 64 |
+
"grief",
|
| 65 |
+
"joy",
|
| 66 |
+
"love",
|
| 67 |
+
"nervousness",
|
| 68 |
+
"optimism",
|
| 69 |
+
"pride",
|
| 70 |
+
"realization",
|
| 71 |
+
"relief",
|
| 72 |
+
"remorse",
|
| 73 |
+
"sadness",
|
| 74 |
+
"surprise",
|
| 75 |
+
"neutral",
|
| 76 |
]
|
| 77 |
|
| 78 |
# New topic labels for books + papers + blogs
|
| 79 |
TOPIC_LABELS = [
|
| 80 |
+
"Fiction", # Novels, short stories, literary fiction
|
| 81 |
+
"Science", # Physics, chemistry, biology, nature
|
| 82 |
+
"Technology", # CS, engineering, programming, AI/ML
|
| 83 |
+
"Philosophy", # Ethics, logic, metaphysics, epistemology
|
| 84 |
+
"History", # Historical texts, biographies, memoirs
|
| 85 |
+
"Psychology", # Mind, behavior, self-help, mental health
|
| 86 |
+
"Business", # Economics, finance, entrepreneurship
|
| 87 |
+
"Arts", # Music, visual arts, film, architecture
|
| 88 |
]
|
| 89 |
|
| 90 |
# arXiv category → our topic mapping
|
| 91 |
ARXIV_CATEGORY_MAP = {
|
| 92 |
# Computer Science
|
| 93 |
+
"cs.AI": "Technology",
|
| 94 |
+
"cs.CL": "Technology",
|
| 95 |
+
"cs.CV": "Technology",
|
| 96 |
+
"cs.LG": "Technology",
|
| 97 |
+
"cs.NE": "Technology",
|
| 98 |
+
"cs.RO": "Technology",
|
| 99 |
+
"cs.SE": "Technology",
|
| 100 |
+
"cs.PL": "Technology",
|
| 101 |
+
"cs.DB": "Technology",
|
| 102 |
+
"cs.DS": "Technology",
|
| 103 |
+
"cs.CR": "Technology",
|
| 104 |
+
"cs.DC": "Technology",
|
| 105 |
+
"cs.HC": "Technology",
|
| 106 |
+
"cs.IR": "Technology",
|
| 107 |
+
"cs.IT": "Technology",
|
| 108 |
+
"cs.MA": "Technology",
|
| 109 |
+
"cs.MM": "Technology",
|
| 110 |
+
"cs.NI": "Technology",
|
| 111 |
+
"cs.OS": "Technology",
|
| 112 |
+
"cs.PF": "Technology",
|
| 113 |
+
"cs.SY": "Technology",
|
| 114 |
# Physics
|
| 115 |
+
"physics": "Science",
|
| 116 |
+
"astro-ph": "Science",
|
| 117 |
+
"cond-mat": "Science",
|
| 118 |
+
"gr-qc": "Science",
|
| 119 |
+
"hep-ex": "Science",
|
| 120 |
+
"hep-lat": "Science",
|
| 121 |
+
"hep-ph": "Science",
|
| 122 |
+
"hep-th": "Science",
|
| 123 |
+
"math-ph": "Science",
|
| 124 |
+
"nlin": "Science",
|
| 125 |
+
"nucl-ex": "Science",
|
| 126 |
+
"nucl-th": "Science",
|
| 127 |
"quant-ph": "Science",
|
| 128 |
# Math
|
| 129 |
"math": "Science",
|
| 130 |
# Biology/Medicine
|
| 131 |
+
"q-bio": "Science",
|
| 132 |
+
"stat": "Science",
|
| 133 |
# Economics/Finance
|
| 134 |
+
"econ": "Business",
|
| 135 |
+
"q-fin": "Business",
|
| 136 |
# Electrical Engineering
|
| 137 |
"eess": "Technology",
|
| 138 |
}
|
| 139 |
|
| 140 |
# Gutenberg subject → our topic mapping
|
| 141 |
GUTENBERG_SUBJECT_MAP = {
|
| 142 |
+
"fiction": "Fiction",
|
| 143 |
+
"novel": "Fiction",
|
| 144 |
+
"stories": "Fiction",
|
| 145 |
+
"poetry": "Arts",
|
| 146 |
+
"drama": "Arts",
|
| 147 |
+
"plays": "Arts",
|
| 148 |
+
"science": "Science",
|
| 149 |
+
"physics": "Science",
|
| 150 |
+
"chemistry": "Science",
|
| 151 |
+
"biology": "Science",
|
| 152 |
+
"nature": "Science",
|
| 153 |
+
"astronomy": "Science",
|
| 154 |
+
"philosophy": "Philosophy",
|
| 155 |
+
"ethics": "Philosophy",
|
| 156 |
+
"logic": "Philosophy",
|
| 157 |
+
"history": "History",
|
| 158 |
+
"biography": "History",
|
| 159 |
+
"memoir": "History",
|
| 160 |
+
"psychology": "Psychology",
|
| 161 |
+
"mind": "Psychology",
|
| 162 |
+
"economics": "Business",
|
| 163 |
+
"business": "Business",
|
| 164 |
+
"finance": "Business",
|
| 165 |
+
"art": "Arts",
|
| 166 |
+
"music": "Arts",
|
| 167 |
+
"architecture": "Arts",
|
| 168 |
+
"technology": "Technology",
|
| 169 |
+
"engineering": "Technology",
|
| 170 |
}
|
| 171 |
|
| 172 |
|
|
|
|
| 183 |
|
| 184 |
# Common English words for detection
|
| 185 |
ENGLISH_WORDS = {
|
| 186 |
+
"the",
|
| 187 |
+
"and",
|
| 188 |
+
"of",
|
| 189 |
+
"to",
|
| 190 |
+
"a",
|
| 191 |
+
"in",
|
| 192 |
+
"that",
|
| 193 |
+
"is",
|
| 194 |
+
"was",
|
| 195 |
+
"he",
|
| 196 |
+
"she",
|
| 197 |
+
"it",
|
| 198 |
+
"for",
|
| 199 |
+
"with",
|
| 200 |
+
"as",
|
| 201 |
+
"his",
|
| 202 |
+
"her",
|
| 203 |
+
"they",
|
| 204 |
+
"be",
|
| 205 |
+
"at",
|
| 206 |
+
"on",
|
| 207 |
+
"have",
|
| 208 |
+
"had",
|
| 209 |
+
"this",
|
| 210 |
+
"but",
|
| 211 |
+
"not",
|
| 212 |
+
"from",
|
| 213 |
+
"by",
|
| 214 |
+
"or",
|
| 215 |
+
"an",
|
| 216 |
+
"said",
|
| 217 |
+
"were",
|
| 218 |
+
"been",
|
| 219 |
+
"would",
|
| 220 |
+
"could",
|
| 221 |
+
"which",
|
| 222 |
+
"their",
|
| 223 |
+
"there",
|
| 224 |
+
"what",
|
| 225 |
+
"when",
|
| 226 |
+
"who",
|
| 227 |
+
"will",
|
| 228 |
+
"more",
|
| 229 |
+
"if",
|
| 230 |
+
"no",
|
| 231 |
+
"out",
|
| 232 |
+
"so",
|
| 233 |
+
"up",
|
| 234 |
+
"into",
|
| 235 |
+
"than",
|
| 236 |
+
"them",
|
| 237 |
+
"can",
|
| 238 |
+
"only",
|
| 239 |
+
"other",
|
| 240 |
+
"new",
|
| 241 |
+
"some",
|
| 242 |
+
"very",
|
| 243 |
+
"just",
|
| 244 |
+
"over",
|
| 245 |
+
"such",
|
| 246 |
+
"also",
|
| 247 |
+
"its",
|
| 248 |
+
"then",
|
| 249 |
}
|
| 250 |
|
| 251 |
# Non-English language patterns
|
|
|
|
| 266 |
|
| 267 |
# Patterns that indicate garbage/metadata text
|
| 268 |
GARBAGE_PATTERNS = [
|
| 269 |
+
r"^Page \d+:", # Page corrections
|
| 270 |
+
r"changed to", # Errata
|
| 271 |
+
r"Punctuation has been", # Editorial notes
|
| 272 |
+
r"^\[.*\]$", # Bracketed notes
|
| 273 |
+
r"^Note\.?[-—]", # Notes
|
| 274 |
+
r"^follows:", # "as follows:"
|
| 275 |
+
r"CHAPTER [IVXLC]+\.", # Chapter headers only
|
| 276 |
+
r"^\*\*\*", # Project Gutenberg markers
|
| 277 |
+
r"^End of.*Project", # End markers
|
| 278 |
+
r"^Produced by", # Production credits
|
| 279 |
+
r"transcriber", # Transcriber notes
|
| 280 |
+
r"eBook", # eBook references
|
| 281 |
+
r"©|copyright", # Copyright notices
|
| 282 |
+
r"^INDEX", # Index pages
|
| 283 |
r"^\d+\.\s+\w+,\s+\d+", # Index entries like "1. Name, 234"
|
| 284 |
+
r"(syn\.|var\.|sp\.)", # Botanical abbreviations
|
| 285 |
+
r"[A-Z][a-z]+aceae", # Botanical family names
|
| 286 |
+
r"\(\s*syn\s+", # Synonym references
|
| 287 |
]
|
| 288 |
|
| 289 |
# Patterns that indicate technical manuals/instructions (not narrative)
|
| 290 |
TECHNICAL_PATTERNS = [
|
| 291 |
r"\d+\.\s+It\s+(is|has|can)", # Numbered features "1. It is a..."
|
| 292 |
+
r"^\d+(st|nd|rd|th)\.", # "1st. 2nd. 3rd."
|
| 293 |
+
r"Mesh\.?\s*\d+", # Mesh sizes (pottery)
|
| 294 |
r"\d+\s*(oz|lb|kg|g|ml|mm|cm|inch)", # Measurements
|
| 295 |
+
r"Parts?\s*:?\s*\d+", # "Parts: 50"
|
| 296 |
+
r"Method of Using", # Instructions
|
| 297 |
+
r"How to\s+\w+", # How-to guides
|
| 298 |
+
r"Step\s+\d+", # Step-by-step
|
| 299 |
+
r"wire.*address", # Business instructions
|
| 300 |
+
r"orders?\s+should\s+be", # Order instructions
|
| 301 |
+
r"specifications?", # Technical specs
|
| 302 |
+
r"(Front|Back)\s+Focus", # Camera terms
|
| 303 |
+
r"Rack and Pinion", # Mechanical terms
|
| 304 |
]
|
| 305 |
|
| 306 |
# Shakespeare and plays to exclude (model hallucinates on Early Modern English)
|
| 307 |
EXCLUDED_TITLES = {
|
| 308 |
# Shakespeare
|
| 309 |
+
"King Lear",
|
| 310 |
+
"Hamlet",
|
| 311 |
+
"Macbeth",
|
| 312 |
+
"Othello",
|
| 313 |
+
"Romeo and Juliet",
|
| 314 |
+
"A Midsummer Night's Dream",
|
| 315 |
+
"The Tempest",
|
| 316 |
+
"Julius Caesar",
|
| 317 |
+
"The Merchant of Venice",
|
| 318 |
+
"Twelfth Night",
|
| 319 |
+
"Much Ado About Nothing",
|
| 320 |
+
"As You Like It",
|
| 321 |
+
"The Taming of the Shrew",
|
| 322 |
+
"Antony and Cleopatra",
|
| 323 |
+
"Coriolanus",
|
| 324 |
+
"Cymbeline",
|
| 325 |
+
"Timon of Athens",
|
| 326 |
+
"Troilus and Cressida",
|
| 327 |
+
"Measure for Measure",
|
| 328 |
+
"All's Well That Ends Well",
|
| 329 |
+
"Pericles",
|
| 330 |
+
"The Winter's Tale",
|
| 331 |
+
"The Comedy of Errors",
|
| 332 |
+
"Two Gentlemen of Verona",
|
| 333 |
+
"Love's Labour's Lost",
|
| 334 |
+
"The Merry Wives of Windsor",
|
| 335 |
+
"Henry IV",
|
| 336 |
+
"Henry V",
|
| 337 |
+
"Henry VI",
|
| 338 |
+
"Henry VIII",
|
| 339 |
+
"Richard II",
|
| 340 |
+
"Richard III",
|
| 341 |
+
"King John",
|
| 342 |
+
"Titus Andronicus",
|
| 343 |
# French plays
|
| 344 |
+
"Tartuffe",
|
| 345 |
+
"Phaedra",
|
| 346 |
+
"Cyrano de Bergerac",
|
| 347 |
+
"Cyrano De Bergerac",
|
| 348 |
+
"Le Misanthrope",
|
| 349 |
+
"The School for Wives",
|
| 350 |
+
"The Miser",
|
| 351 |
+
"The Imaginary Invalid",
|
| 352 |
+
"Andromaque",
|
| 353 |
+
"Britannicus",
|
| 354 |
+
"Bérénice",
|
| 355 |
+
"Le Cid",
|
| 356 |
# Greek/Roman plays
|
| 357 |
+
"Oedipus Rex",
|
| 358 |
+
"Oedipus the King",
|
| 359 |
+
"Antigone",
|
| 360 |
+
"Electra",
|
| 361 |
+
"Medea",
|
| 362 |
+
"The Bacchae",
|
| 363 |
+
"The Oresteia",
|
| 364 |
+
"Agamemnon",
|
| 365 |
+
"Prometheus Bound",
|
| 366 |
# Other classic plays
|
| 367 |
+
"The Importance of Being Earnest",
|
| 368 |
+
"Pygmalion",
|
| 369 |
+
"Doctor Faustus",
|
| 370 |
+
"Waiting for Godot",
|
| 371 |
+
"Death of a Salesman",
|
| 372 |
+
"A Streetcar Named Desire",
|
| 373 |
+
"The Glass Menagerie",
|
| 374 |
+
"Our Town",
|
| 375 |
+
"Long Day's Journey Into Night",
|
| 376 |
+
"Who's Afraid of Virginia Woolf",
|
| 377 |
+
"The Crucible",
|
| 378 |
+
"Cat on a Hot Tin Roof",
|
| 379 |
# Verse/poetic epics
|
| 380 |
+
"Idylls of the King",
|
| 381 |
+
"Paradise Lost",
|
| 382 |
+
"Paradise Regained",
|
| 383 |
+
"The Divine Comedy",
|
| 384 |
+
"Inferno",
|
| 385 |
+
"Purgatorio",
|
| 386 |
+
"Paradiso",
|
| 387 |
+
"The Faerie Queene",
|
| 388 |
+
"Beowulf",
|
| 389 |
}
|
| 390 |
|
| 391 |
|
|
|
|
| 403 |
for pattern in GARBAGE_PATTERNS:
|
| 404 |
if re.search(pattern, text, re.IGNORECASE | re.MULTILINE):
|
| 405 |
return False
|
| 406 |
+
|
| 407 |
# Reject technical manuals/instructions
|
| 408 |
if is_technical_manual(text):
|
| 409 |
return False
|
| 410 |
+
|
| 411 |
# Must have reasonable length
|
| 412 |
if len(text) < 300:
|
| 413 |
return False
|
| 414 |
+
|
| 415 |
# Must have sentences (not just fragments)
|
| 416 |
+
sentences = re.split(r"[.!?]+", text)
|
| 417 |
if len(sentences) < 4:
|
| 418 |
return False
|
| 419 |
+
|
| 420 |
# Check for too many special characters
|
| 421 |
special_ratio = len(re.findall(r'[^\w\s.,!?\'"()-]', text)) / max(len(text), 1)
|
| 422 |
if special_ratio > 0.08:
|
| 423 |
return False
|
| 424 |
+
|
| 425 |
return True
|
| 426 |
|
| 427 |
|
|
|
|
| 439 |
r"^[A-Z]{2,}\.\s", # Character names like "HAMLET."
|
| 440 |
r"Alarum|Flourish|Sennet", # Stage directions
|
| 441 |
]
|
| 442 |
+
lines = text.split("\n")[:10]
|
| 443 |
play_indicators = 0
|
| 444 |
for line in lines:
|
| 445 |
for pattern in play_patterns:
|
|
|
|
| 451 |
def is_english_text(text: str, min_ratio: float = 0.08, max_foreign: int = 5) -> bool:
|
| 452 |
"""
|
| 453 |
Check if text is primarily English.
|
| 454 |
+
|
| 455 |
Args:
|
| 456 |
text: Text to check
|
| 457 |
min_ratio: Minimum ratio of common English words
|
| 458 |
max_foreign: Maximum number of foreign word matches before rejecting
|
| 459 |
+
|
| 460 |
Returns:
|
| 461 |
True if text appears to be English
|
| 462 |
"""
|
| 463 |
if not text or len(text) < 100:
|
| 464 |
return False
|
| 465 |
+
|
| 466 |
text_lower = text.lower()
|
| 467 |
words = text_lower.split()
|
| 468 |
+
|
| 469 |
if len(words) < 20:
|
| 470 |
return False
|
| 471 |
+
|
| 472 |
# Check for excessive non-English words
|
| 473 |
for pattern in NON_ENGLISH_PATTERNS:
|
| 474 |
matches = len(re.findall(pattern, text_lower))
|
| 475 |
if matches > max_foreign:
|
| 476 |
return False
|
| 477 |
+
|
| 478 |
# Check for sufficient English words
|
| 479 |
english_count = sum(1 for w in words if w.strip(".,!?;:'\"") in ENGLISH_WORDS)
|
| 480 |
ratio = english_count / len(words)
|
| 481 |
+
|
| 482 |
return ratio >= min_ratio
|
| 483 |
|
| 484 |
|
| 485 |
def normalize_title(title: str) -> str:
|
| 486 |
"""Normalize a book title for matching."""
|
| 487 |
# Remove common prefixes/suffixes
|
| 488 |
+
title = re.sub(r"^(The|A|An)\s+", "", title, flags=re.IGNORECASE)
|
| 489 |
+
title = re.sub(r"\s*\([^)]*\)\s*", "", title) # Remove parentheticals
|
| 490 |
+
title = re.sub(r"\s*:.+$", "", title) # Remove subtitles
|
| 491 |
+
title = re.sub(r"[^\w\s]", "", title) # Remove punctuation
|
| 492 |
return title.lower().strip()
|
| 493 |
|
| 494 |
|
| 495 |
# -------- SUMMARIZATION: BOOKS + ARXIV ----------
|
| 496 |
|
| 497 |
+
|
| 498 |
def download_goodreads_descriptions() -> dict[str, dict]:
|
| 499 |
"""
|
| 500 |
Download Goodreads book descriptions - back-cover style blurbs.
|
| 501 |
+
|
| 502 |
These are "what the book is about" descriptions, not plot summaries.
|
| 503 |
Returns dict mapping normalized title -> {title, description}
|
| 504 |
"""
|
| 505 |
print("\nLoading Goodreads book descriptions...")
|
| 506 |
+
|
| 507 |
descriptions = {}
|
| 508 |
+
|
| 509 |
# Try multiple sources
|
| 510 |
datasets_to_try = [
|
| 511 |
"booksouls/goodreads-book-descriptions",
|
| 512 |
"Skelebor/book_titles_and_descriptions_en_clean",
|
| 513 |
]
|
| 514 |
+
|
| 515 |
for ds_name in datasets_to_try:
|
| 516 |
try:
|
| 517 |
print(f" Loading {ds_name}...")
|
| 518 |
ds = load_dataset(ds_name, split="train")
|
| 519 |
+
|
| 520 |
for item in tqdm(ds, desc="Goodreads", leave=False):
|
| 521 |
title = item.get("title", "")
|
| 522 |
description = item.get("description", "")
|
| 523 |
+
|
| 524 |
if not title or not description:
|
| 525 |
continue
|
| 526 |
+
|
| 527 |
# Skip very short descriptions (not useful for training)
|
| 528 |
if len(description) < 100:
|
| 529 |
continue
|
| 530 |
+
|
| 531 |
# Skip very long descriptions (truncate later)
|
| 532 |
if len(description) > 2000:
|
| 533 |
description = description[:2000]
|
| 534 |
+
|
| 535 |
# Skip plays and excluded titles
|
| 536 |
if is_excluded_title(title):
|
| 537 |
continue
|
| 538 |
+
|
| 539 |
# Skip non-English descriptions
|
| 540 |
if not is_english_text(description):
|
| 541 |
continue
|
| 542 |
+
|
| 543 |
norm_title = normalize_title(title)
|
| 544 |
if norm_title and norm_title not in descriptions:
|
| 545 |
descriptions[norm_title] = {
|
| 546 |
"title": title,
|
| 547 |
"description": description,
|
| 548 |
}
|
| 549 |
+
|
| 550 |
print(f" Loaded {len(descriptions):,} descriptions from {ds_name}")
|
| 551 |
except Exception as e:
|
| 552 |
print(f" {ds_name} failed: {e}")
|
| 553 |
+
|
| 554 |
print(f" Total: {len(descriptions):,} unique book descriptions")
|
| 555 |
return descriptions
|
| 556 |
|
| 557 |
|
| 558 |
def download_book_descriptions(
|
| 559 |
+
goodreads_descriptions: dict[str, dict], max_samples: int = 20000
|
|
|
|
| 560 |
) -> list[dict[str, Any]]:
|
| 561 |
"""
|
| 562 |
Download book description data by matching Gutenberg texts with Goodreads descriptions.
|
| 563 |
+
|
| 564 |
This gives us (book_excerpt, book_description) training pairs where descriptions
|
| 565 |
are back-cover style "what is this book about" blurbs, not plot summaries.
|
| 566 |
"""
|
| 567 |
print("\nMatching Gutenberg books with Goodreads descriptions...")
|
| 568 |
+
|
| 569 |
try:
|
| 570 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
| 571 |
except Exception:
|
| 572 |
gutenberg = load_dataset("pg19", split="train")
|
| 573 |
+
|
| 574 |
records: list[dict[str, Any]] = []
|
| 575 |
matched_titles = set()
|
| 576 |
skipped_quality = 0
|
| 577 |
skipped_play = 0
|
| 578 |
+
|
| 579 |
indices = list(range(len(gutenberg)))
|
| 580 |
random.shuffle(indices)
|
| 581 |
+
|
| 582 |
for i in tqdm(indices, desc="Matching books", leave=False):
|
| 583 |
if len(records) >= max_samples:
|
| 584 |
break
|
| 585 |
+
|
| 586 |
item = gutenberg[i]
|
| 587 |
text = item.get("TEXT", "") or item.get("text", "")
|
| 588 |
metadata_raw = item.get("METADATA", "") or "{}"
|
| 589 |
+
|
| 590 |
# Parse metadata
|
| 591 |
try:
|
| 592 |
metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
|
| 593 |
except (json.JSONDecodeError, TypeError):
|
| 594 |
metadata = {}
|
| 595 |
+
|
| 596 |
# Get title
|
| 597 |
title = metadata.get("title", "") if isinstance(metadata, dict) else ""
|
| 598 |
if not title:
|
| 599 |
continue
|
| 600 |
+
|
| 601 |
# Check if we have a Goodreads description for this book
|
| 602 |
norm_title = normalize_title(title)
|
| 603 |
if norm_title not in goodreads_descriptions:
|
| 604 |
continue
|
| 605 |
+
|
| 606 |
# Skip if already matched this book
|
| 607 |
if norm_title in matched_titles:
|
| 608 |
continue
|
| 609 |
+
|
| 610 |
goodreads_data = goodreads_descriptions[norm_title]
|
| 611 |
+
|
| 612 |
# Skip plays and excluded titles
|
| 613 |
if is_excluded_title(title):
|
| 614 |
skipped_play += 1
|
| 615 |
continue
|
| 616 |
+
|
| 617 |
if not text or len(text) < 2000:
|
| 618 |
continue
|
| 619 |
+
|
| 620 |
# Get a clean excerpt from the book (skip front matter)
|
| 621 |
+
paragraphs = re.split(r"\n\s*\n", text)
|
| 622 |
excerpt_parts = []
|
| 623 |
total_len = 0
|
| 624 |
+
|
| 625 |
for para in paragraphs[10:]: # Skip front matter
|
| 626 |
para = para.strip()
|
| 627 |
if len(para) < 100:
|
| 628 |
continue
|
| 629 |
+
|
| 630 |
# Quality check on paragraph
|
| 631 |
if not is_english_text(para):
|
| 632 |
continue
|
|
|
|
| 636 |
if not is_quality_text(para) and len(para) > 300:
|
| 637 |
skipped_quality += 1
|
| 638 |
continue
|
| 639 |
+
|
| 640 |
excerpt_parts.append(para)
|
| 641 |
total_len += len(para)
|
| 642 |
+
|
| 643 |
if total_len >= 3000:
|
| 644 |
break
|
| 645 |
+
|
| 646 |
if total_len < 1000:
|
| 647 |
continue
|
| 648 |
+
|
| 649 |
book_excerpt = "\n\n".join(excerpt_parts)[:4000]
|
| 650 |
matched_titles.add(norm_title)
|
| 651 |
+
|
| 652 |
+
records.append(
|
| 653 |
+
{
|
| 654 |
+
"source": book_excerpt,
|
| 655 |
+
"summary": goodreads_data["description"][:800], # Back-cover blurbs are shorter
|
| 656 |
+
"type": "literary",
|
| 657 |
+
"title": goodreads_data["title"],
|
| 658 |
+
}
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
print(f" Matched {len(records):,} books with descriptions")
|
| 662 |
print(f" Skipped: {skipped_quality} quality, {skipped_play} plays")
|
| 663 |
+
|
| 664 |
return records
|
| 665 |
|
| 666 |
|
| 667 |
# Keep BookSum for additional literary training (chapter summaries are still useful)
|
| 668 |
def download_booksum(max_samples: int = 20000) -> list[dict[str, Any]]:
|
| 669 |
"""Download BookSum - literary chapter summarization (English only, quality filtered).
|
| 670 |
+
|
| 671 |
Note: These are chapter-level plot summaries, useful as supplementary training data.
|
| 672 |
The primary book training comes from Goodreads descriptions (back-cover style).
|
| 673 |
"""
|
| 674 |
print("\nLoading BookSum (supplementary literary data)...")
|
| 675 |
+
|
| 676 |
all_records: list[dict[str, Any]] = []
|
| 677 |
booksum = load_dataset("kmfoda/booksum")
|
| 678 |
+
|
| 679 |
for split_name in booksum.keys():
|
| 680 |
split = str(split_name)
|
| 681 |
data = booksum[split_name]
|
| 682 |
limit = max_samples if "train" in split else max_samples // 10
|
| 683 |
indices = random.sample(range(len(data)), min(len(data), limit))
|
| 684 |
+
|
| 685 |
records = []
|
| 686 |
skipped_language = 0
|
| 687 |
skipped_excluded = 0
|
| 688 |
skipped_play = 0
|
| 689 |
+
|
| 690 |
for i in tqdm(indices, desc=f"BookSum {split}", leave=False):
|
| 691 |
item = data[i]
|
| 692 |
chapter = item.get("chapter", "")
|
| 693 |
summary = item.get("summary_text") or item.get("summary", "")
|
| 694 |
+
|
| 695 |
# Extract book title from book_id (e.g., "The Last of the Mohicans.chapters 1-2")
|
| 696 |
book_id = item.get("book_id", "")
|
| 697 |
book_title = book_id.split(".")[0] if "." in book_id else book_id
|
| 698 |
chapter_name = item.get("summary_id", "") or item.get("summary_name", "")
|
| 699 |
+
|
| 700 |
if not (chapter and summary and len(chapter) > 300):
|
| 701 |
continue
|
| 702 |
+
|
| 703 |
# Filter: excluded titles (Shakespeare, plays, etc.)
|
| 704 |
if is_excluded_title(book_title):
|
| 705 |
skipped_excluded += 1
|
| 706 |
continue
|
| 707 |
+
|
| 708 |
# Filter: play text format
|
| 709 |
if is_play_text(chapter):
|
| 710 |
skipped_play += 1
|
| 711 |
continue
|
| 712 |
+
|
| 713 |
# Filter: English only
|
| 714 |
if not is_english_text(chapter):
|
| 715 |
skipped_language += 1
|
| 716 |
continue
|
| 717 |
+
|
| 718 |
# Filter: quality text
|
| 719 |
if not is_quality_text(chapter):
|
| 720 |
continue
|
| 721 |
+
|
| 722 |
+
records.append(
|
| 723 |
+
{
|
| 724 |
+
"source": chapter[:4000],
|
| 725 |
+
"summary": summary,
|
| 726 |
+
"type": "literary",
|
| 727 |
+
"split": split,
|
| 728 |
+
"title": book_title,
|
| 729 |
+
"chapter": chapter_name,
|
| 730 |
+
}
|
| 731 |
+
)
|
| 732 |
all_records.extend(records)
|
| 733 |
+
print(
|
| 734 |
+
f" {split}: {len(records):,} (skipped {skipped_language} non-English, {skipped_excluded} excluded, {skipped_play} plays)"
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
return all_records
|
| 738 |
|
| 739 |
|
| 740 |
def clean_arxiv_text(text: str) -> str:
|
| 741 |
"""Clean arXiv LaTeX-style text to make it more readable."""
|
| 742 |
import re
|
| 743 |
+
|
| 744 |
# Remove LaTeX math placeholders
|
| 745 |
+
text = re.sub(r"@xmath\d+", "", text)
|
| 746 |
+
text = re.sub(r"@xcite", "", text)
|
| 747 |
# Remove excessive whitespace
|
| 748 |
+
text = re.sub(r"\s+", " ", text)
|
| 749 |
# Remove LaTeX commands
|
| 750 |
+
text = re.sub(r"\\[a-zA-Z]+\{[^}]*\}", "", text)
|
| 751 |
+
text = re.sub(r"\\[a-zA-Z]+", "", text)
|
| 752 |
return text.strip()
|
| 753 |
|
| 754 |
|
|
|
|
| 756 |
"""Extract a meaningful title from the first sentence of an abstract."""
|
| 757 |
# Clean the abstract first
|
| 758 |
abstract = clean_arxiv_text(abstract)
|
| 759 |
+
|
| 760 |
# Get the first sentence (up to first period, question mark, or newline)
|
| 761 |
+
first_sentence = re.split(r"[.!?\n]", abstract)[0].strip()
|
| 762 |
+
|
| 763 |
# Truncate if too long
|
| 764 |
if len(first_sentence) > 100:
|
| 765 |
# Try to cut at a natural word boundary
|
| 766 |
+
first_sentence = first_sentence[:100].rsplit(" ", 1)[0] + "..."
|
| 767 |
+
|
| 768 |
# Capitalize first letter
|
| 769 |
if first_sentence:
|
| 770 |
first_sentence = first_sentence[0].upper() + first_sentence[1:]
|
| 771 |
+
|
| 772 |
return first_sentence or "Untitled Paper"
|
| 773 |
|
| 774 |
|
|
|
|
| 776 |
"""
|
| 777 |
Download arXiv papers for academic summarization only (English only).
|
| 778 |
Note: This dataset doesn't have categories, so can't be used for topic classification.
|
| 779 |
+
|
| 780 |
Returns: summarization_records
|
| 781 |
"""
|
| 782 |
print("\nLoading arXiv (academic papers for summarization)...")
|
| 783 |
+
|
| 784 |
print(" Loading dataset (this may take a minute)...")
|
| 785 |
arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
|
| 786 |
+
|
| 787 |
summ_records: list[dict[str, Any]] = []
|
| 788 |
skipped_language = 0
|
| 789 |
+
|
| 790 |
indices = list(range(len(arxiv)))
|
| 791 |
random.shuffle(indices)
|
| 792 |
+
|
| 793 |
print(" Processing papers...")
|
| 794 |
+
for i in tqdm(indices[: max_samples * 2], desc="arXiv", leave=False):
|
| 795 |
if len(summ_records) >= max_samples:
|
| 796 |
break
|
| 797 |
+
|
| 798 |
item = arxiv[i]
|
| 799 |
+
|
| 800 |
# Get abstract and article
|
| 801 |
abstract = item.get("abstract", "")
|
| 802 |
article = item.get("article", "")
|
| 803 |
+
|
| 804 |
if not abstract or len(abstract) < 100:
|
| 805 |
continue
|
| 806 |
+
|
| 807 |
# Clean LaTeX artifacts
|
| 808 |
abstract = clean_arxiv_text(abstract)
|
| 809 |
article = clean_arxiv_text(article)
|
| 810 |
+
|
| 811 |
# Skip if still has too many weird characters after cleaning
|
| 812 |
+
if "@" in abstract or "@" in article[:500]:
|
| 813 |
continue
|
| 814 |
+
|
| 815 |
# Filter: English only
|
| 816 |
if not is_english_text(article[:1000]):
|
| 817 |
skipped_language += 1
|
| 818 |
continue
|
| 819 |
+
|
| 820 |
# Summarization: article → abstract
|
| 821 |
if article and len(article) > 500:
|
| 822 |
# Extract title from abstract
|
| 823 |
paper_title = extract_paper_title(abstract)
|
| 824 |
+
|
| 825 |
+
summ_records.append(
|
| 826 |
+
{
|
| 827 |
+
"source": article[:4000],
|
| 828 |
+
"summary": abstract,
|
| 829 |
+
"type": "academic",
|
| 830 |
+
"title": paper_title,
|
| 831 |
+
}
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
print(f" Summarization: {len(summ_records):,} (skipped {skipped_language} non-English)")
|
| 835 |
+
|
| 836 |
return summ_records
|
| 837 |
|
| 838 |
|
| 839 |
def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, Any]]:
|
| 840 |
"""
|
| 841 |
Download topic classification data from multiple sources with real categories.
|
| 842 |
+
|
| 843 |
Sources:
|
| 844 |
- 20 Newsgroups (classic topic classification)
|
| 845 |
- Wikipedia (article categories)
|
| 846 |
"""
|
| 847 |
print("\nLoading topic classification datasets...")
|
| 848 |
+
|
| 849 |
records: list[dict[str, Any]] = []
|
| 850 |
+
|
| 851 |
# 20 Newsgroups - classic topic dataset
|
| 852 |
print(" Loading 20 Newsgroups...")
|
| 853 |
try:
|
| 854 |
newsgroups = load_dataset("SetFit/20_newsgroups", split="train")
|
| 855 |
+
|
| 856 |
# Map 20 newsgroups categories to our 8 topics
|
| 857 |
newsgroup_map = {
|
| 858 |
# Science
|
| 859 |
+
"sci.crypt": "Science",
|
| 860 |
+
"sci.electronics": "Science",
|
| 861 |
+
"sci.med": "Science",
|
| 862 |
+
"sci.space": "Science",
|
| 863 |
+
# Technology
|
| 864 |
+
"comp.graphics": "Technology",
|
| 865 |
+
"comp.os.ms-windows.misc": "Technology",
|
| 866 |
+
"comp.sys.ibm.pc.hardware": "Technology",
|
| 867 |
+
"comp.sys.mac.hardware": "Technology",
|
| 868 |
"comp.windows.x": "Technology",
|
| 869 |
# Philosophy/Religion
|
| 870 |
+
"alt.atheism": "Philosophy",
|
| 871 |
+
"soc.religion.christian": "Philosophy",
|
| 872 |
"talk.religion.misc": "Philosophy",
|
| 873 |
# History/Politics
|
| 874 |
+
"talk.politics.guns": "History",
|
| 875 |
+
"talk.politics.mideast": "History",
|
| 876 |
"talk.politics.misc": "History",
|
| 877 |
# Business
|
| 878 |
"misc.forsale": "Business",
|
| 879 |
# Sports/Recreation
|
| 880 |
+
"rec.autos": "Arts",
|
| 881 |
+
"rec.motorcycles": "Arts",
|
| 882 |
+
"rec.sport.baseball": "Arts",
|
| 883 |
+
"rec.sport.hockey": "Arts",
|
| 884 |
}
|
| 885 |
+
|
| 886 |
for item in tqdm(newsgroups, desc="20 Newsgroups", leave=False):
|
| 887 |
if len(records) >= max_samples:
|
| 888 |
break
|
| 889 |
label_name = item.get("label_text", "")
|
| 890 |
text = item.get("text", "")
|
| 891 |
+
|
| 892 |
if label_name in newsgroup_map and text and len(text) > 100:
|
| 893 |
+
records.append(
|
| 894 |
+
{
|
| 895 |
+
"text": text[:1500],
|
| 896 |
+
"topic": newsgroup_map[label_name],
|
| 897 |
+
"source": "newsgroups",
|
| 898 |
+
}
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
print(f" 20 Newsgroups: {len(records):,}")
|
| 902 |
except Exception as e:
|
| 903 |
print(f" 20 Newsgroups failed: {e}")
|
| 904 |
+
|
| 905 |
# Add from Gutenberg for Fiction
|
| 906 |
gutenberg_topics = download_gutenberg_topics(max_samples // 4)
|
| 907 |
records.extend(gutenberg_topics)
|
| 908 |
+
|
| 909 |
# Add from scientific papers abstract dataset for more Science/Tech
|
| 910 |
print(" Loading scientific papers...")
|
| 911 |
try:
|
| 912 |
sci_papers = load_dataset("scientific_papers", "arxiv", split="train", streaming=True)
|
| 913 |
sci_count = 0
|
| 914 |
+
for item in tqdm(sci_papers, desc="Scientific papers", leave=False, total=max_samples // 4):
|
| 915 |
if sci_count >= max_samples // 4:
|
| 916 |
break
|
| 917 |
abstract = item.get("abstract", "")
|
| 918 |
if abstract and len(abstract) > 100:
|
| 919 |
# Alternate between Science and Technology
|
| 920 |
topic = "Science" if sci_count % 2 == 0 else "Technology"
|
| 921 |
+
records.append(
|
| 922 |
+
{
|
| 923 |
+
"text": abstract[:1500],
|
| 924 |
+
"topic": topic,
|
| 925 |
+
"source": "scientific_papers",
|
| 926 |
+
}
|
| 927 |
+
)
|
| 928 |
sci_count += 1
|
| 929 |
print(f" Scientific papers: {sci_count:,}")
|
| 930 |
except Exception as e:
|
| 931 |
print(f" Scientific papers failed: {e}")
|
| 932 |
+
|
| 933 |
return records
|
| 934 |
|
| 935 |
|
| 936 |
def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> None:
|
| 937 |
"""Download all summarization data (books + arxiv, NO news).
|
| 938 |
+
|
| 939 |
Book data now uses Goodreads descriptions (back-cover blurbs) instead of
|
| 940 |
plot summaries. This trains the model to describe "what the book is about"
|
| 941 |
rather than summarizing the plot.
|
| 942 |
"""
|
| 943 |
print("\nDownloading Summarization Data...")
|
| 944 |
out_dir = OUTPUT_DIR / "summarization"
|
| 945 |
+
|
| 946 |
all_records: list[dict[str, Any]] = []
|
| 947 |
+
|
| 948 |
# Goodreads descriptions - primary book training data (back-cover style)
|
| 949 |
goodreads_descriptions = download_goodreads_descriptions()
|
| 950 |
book_records = download_book_descriptions(goodreads_descriptions, max_books)
|
| 951 |
all_records.extend(book_records)
|
| 952 |
+
|
| 953 |
# Optional: Add some BookSum for additional literary variety
|
| 954 |
# These are chapter summaries, not back-cover style, so keep limited
|
| 955 |
# booksum_records = download_booksum(max_books // 4)
|
| 956 |
# all_records.extend(booksum_records)
|
| 957 |
+
|
| 958 |
# arXiv - academic (abstracts are already "what is this paper about")
|
| 959 |
arxiv_summ = download_arxiv_summarization(max_arxiv)
|
| 960 |
all_records.extend(arxiv_summ)
|
| 961 |
+
|
| 962 |
# Shuffle and split
|
| 963 |
random.shuffle(all_records)
|
| 964 |
+
|
| 965 |
# Split by original split if available, else 90/5/5
|
| 966 |
+
train_records = [
|
| 967 |
+
r for r in all_records if r.get("split", "train") == "train" or "split" not in r
|
| 968 |
+
]
|
| 969 |
val_records = [r for r in all_records if r.get("split") == "validation"]
|
| 970 |
test_records = [r for r in all_records if r.get("split") == "test"]
|
| 971 |
+
|
| 972 |
# If no split info, do 90/5/5
|
| 973 |
if len(val_records) < 100:
|
| 974 |
n = len(train_records)
|
| 975 |
random.shuffle(train_records)
|
| 976 |
+
val_records = train_records[int(n * 0.9) : int(n * 0.95)]
|
| 977 |
+
test_records = train_records[int(n * 0.95) :]
|
| 978 |
+
train_records = train_records[: int(n * 0.9)]
|
| 979 |
+
|
| 980 |
# Remove split key before saving
|
| 981 |
for r in train_records + val_records + test_records:
|
| 982 |
r.pop("split", None)
|
| 983 |
+
|
| 984 |
write_jsonl(train_records, out_dir / "train.jsonl", "train")
|
| 985 |
write_jsonl(val_records, out_dir / "validation.jsonl", "val")
|
| 986 |
write_jsonl(test_records, out_dir / "test.jsonl", "test")
|
| 987 |
+
|
| 988 |
# Print breakdown
|
| 989 |
+
literary_count = sum(
|
| 990 |
+
1 for r in train_records + val_records + test_records if r.get("type") == "literary"
|
| 991 |
+
)
|
| 992 |
+
academic_count = sum(
|
| 993 |
+
1 for r in train_records + val_records + test_records if r.get("type") == "academic"
|
| 994 |
+
)
|
| 995 |
print(f"\n Total summarization: {len(train_records) + len(val_records) + len(test_records):,}")
|
| 996 |
print(f" Literary (book descriptions): {literary_count:,}")
|
| 997 |
print(f" Academic (paper abstracts): {academic_count:,}")
|
|
|
|
| 999 |
|
| 1000 |
# ------------ TOPIC CLASSIFICATION ------------
|
| 1001 |
|
| 1002 |
+
|
| 1003 |
def download_topics(max_samples: int = 50000) -> None:
|
| 1004 |
"""
|
| 1005 |
Download topic classification data from multiple sources.
|
| 1006 |
+
|
| 1007 |
Sources:
|
| 1008 |
- 20 Newsgroups (classic topic dataset)
|
| 1009 |
- Gutenberg books (Fiction)
|
|
|
|
| 1011 |
"""
|
| 1012 |
print("\nDownloading Topic Classification...")
|
| 1013 |
out_dir = OUTPUT_DIR / "topic"
|
| 1014 |
+
|
| 1015 |
# Get topic records from various sources
|
| 1016 |
all_records = download_topics_from_datasets(max_samples)
|
| 1017 |
+
|
| 1018 |
# Balance topics
|
| 1019 |
topic_counts: dict[str, list] = {t: [] for t in TOPIC_LABELS}
|
| 1020 |
for r in all_records:
|
| 1021 |
topic = r.get("topic")
|
| 1022 |
if topic in topic_counts:
|
| 1023 |
topic_counts[topic].append(r)
|
| 1024 |
+
|
| 1025 |
# Print distribution before balancing
|
| 1026 |
print("\n Topic distribution (before balancing):")
|
| 1027 |
for topic, records in topic_counts.items():
|
| 1028 |
print(f" {topic}: {len(records):,}")
|
| 1029 |
+
|
| 1030 |
# Balance to min count (with some tolerance) - only from topics that have data
|
| 1031 |
counts_with_data = [len(v) for v in topic_counts.values() if v]
|
| 1032 |
if not counts_with_data:
|
| 1033 |
print(" Warning: No topic data found!")
|
| 1034 |
return
|
| 1035 |
+
|
| 1036 |
min_count = min(counts_with_data)
|
| 1037 |
target_count = min(min_count, max_samples // len(TOPIC_LABELS))
|
| 1038 |
+
|
| 1039 |
balanced: list[dict[str, Any]] = []
|
| 1040 |
for _topic, records in topic_counts.items():
|
| 1041 |
if records:
|
| 1042 |
random.shuffle(records)
|
| 1043 |
balanced.extend(records[:target_count])
|
| 1044 |
+
|
| 1045 |
random.shuffle(balanced)
|
| 1046 |
+
|
| 1047 |
# Split 90/5/5
|
| 1048 |
n = len(balanced)
|
| 1049 |
+
train_records = balanced[: int(n * 0.9)]
|
| 1050 |
+
val_records = balanced[int(n * 0.9) : int(n * 0.95)]
|
| 1051 |
+
test_records = balanced[int(n * 0.95) :]
|
| 1052 |
+
|
| 1053 |
write_jsonl(train_records, out_dir / "train.jsonl", "train")
|
| 1054 |
write_jsonl(val_records, out_dir / "validation.jsonl", "val")
|
| 1055 |
write_jsonl(test_records, out_dir / "test.jsonl", "test")
|
| 1056 |
+
|
| 1057 |
# Save labels - only labels that have data
|
| 1058 |
used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
|
| 1059 |
(out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
|
|
|
|
| 1063 |
def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
|
| 1064 |
"""Extract topic-labeled samples from Gutenberg books (English only)."""
|
| 1065 |
print("\nLoading Gutenberg for topic classification...")
|
| 1066 |
+
|
| 1067 |
try:
|
| 1068 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
| 1069 |
except Exception:
|
| 1070 |
print(" Trying pg19...")
|
| 1071 |
gutenberg = load_dataset("pg19", split="train")
|
| 1072 |
+
|
| 1073 |
records: list[dict[str, Any]] = []
|
| 1074 |
skipped_language = 0
|
| 1075 |
+
|
| 1076 |
indices = list(range(len(gutenberg)))
|
| 1077 |
random.shuffle(indices)
|
| 1078 |
+
|
| 1079 |
for i in tqdm(indices, desc="Gutenberg topics", leave=False):
|
| 1080 |
if len(records) >= max_samples:
|
| 1081 |
break
|
| 1082 |
+
|
| 1083 |
item = gutenberg[i]
|
| 1084 |
text = item.get("TEXT", "") or item.get("text", "")
|
| 1085 |
metadata = item.get("METADATA", {}) or {}
|
| 1086 |
+
|
| 1087 |
if not text or len(text) < 1000:
|
| 1088 |
continue
|
| 1089 |
+
|
| 1090 |
# Try to determine topic from metadata
|
| 1091 |
subjects = ""
|
| 1092 |
if isinstance(metadata, dict):
|
| 1093 |
subjects = str(metadata.get("subjects", "")).lower()
|
| 1094 |
subjects += " " + str(metadata.get("subject", "")).lower()
|
| 1095 |
subjects += " " + str(metadata.get("category", "")).lower()
|
| 1096 |
+
|
| 1097 |
topic = None
|
| 1098 |
for keyword, mapped_topic in GUTENBERG_SUBJECT_MAP.items():
|
| 1099 |
if keyword in subjects:
|
| 1100 |
topic = mapped_topic
|
| 1101 |
break
|
| 1102 |
+
|
| 1103 |
# Default fiction for novels without clear subject
|
| 1104 |
if not topic and ("novel" in subjects or not subjects.strip()):
|
| 1105 |
topic = "Fiction"
|
| 1106 |
+
|
| 1107 |
if topic:
|
| 1108 |
# Get a clean paragraph as sample
|
| 1109 |
+
paragraphs = re.split(r"\n\s*\n", text)
|
| 1110 |
for para in paragraphs[5:]: # Skip front matter
|
| 1111 |
para = para.strip()
|
| 1112 |
+
if 200 < len(para) < 1500 and para.count(".") >= 2:
|
| 1113 |
# Filter: English only
|
| 1114 |
if not is_english_text(para):
|
| 1115 |
skipped_language += 1
|
| 1116 |
break
|
| 1117 |
+
|
| 1118 |
+
records.append(
|
| 1119 |
+
{
|
| 1120 |
+
"text": para,
|
| 1121 |
+
"topic": topic,
|
| 1122 |
+
"source": "gutenberg",
|
| 1123 |
+
}
|
| 1124 |
+
)
|
| 1125 |
break
|
| 1126 |
+
|
| 1127 |
print(f" Gutenberg topics: {len(records):,} (skipped {skipped_language} non-English)")
|
| 1128 |
return records
|
| 1129 |
|
| 1130 |
|
| 1131 |
# ------------ EMOTIONS (unchanged) -------------
|
| 1132 |
|
| 1133 |
+
|
| 1134 |
def download_emotions() -> None:
|
| 1135 |
"""Download GoEmotions for emotion classification."""
|
| 1136 |
print("\nDownloading Emotions (GoEmotions)...")
|
| 1137 |
out_dir = OUTPUT_DIR / "emotion"
|
| 1138 |
+
|
| 1139 |
ds = load_dataset("google-research-datasets/go_emotions", "simplified")
|
| 1140 |
+
|
| 1141 |
for split_name in ds.keys():
|
| 1142 |
split = str(split_name)
|
| 1143 |
data = ds[split_name]
|
| 1144 |
+
|
| 1145 |
records: list[dict[str, Any]] = []
|
| 1146 |
for item in tqdm(data, desc=split, leave=False):
|
| 1147 |
text = item.get("text", "")
|
|
|
|
| 1151 |
if emotions:
|
| 1152 |
records.append({"text": text, "emotions": emotions})
|
| 1153 |
write_jsonl(records, out_dir / f"{split}.jsonl", split)
|
| 1154 |
+
|
| 1155 |
(out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
|
| 1156 |
print(f" {len(EMOTION_LABELS)} emotion labels saved")
|
| 1157 |
|
|
|
|
| 1159 |
# --------------- GUTENBERG BOOKS (for language modeling) ---------------
|
| 1160 |
|
| 1161 |
GUTENBERG_JUNK_PATTERNS = [
|
| 1162 |
+
r"Project Gutenberg",
|
| 1163 |
+
r"www\.gutenberg\.org",
|
| 1164 |
+
r"This ebook is for",
|
| 1165 |
+
r"Gutenberg License",
|
| 1166 |
+
r"^\*\*\* START OF",
|
| 1167 |
+
r"^\*\*\* END OF",
|
| 1168 |
+
r"Produced by",
|
| 1169 |
+
r"Transcriber's Note",
|
| 1170 |
+
r"TABLE OF CONTENTS",
|
| 1171 |
+
r"^\s*CHAPTER\s+[IVXLC\d]+",
|
| 1172 |
+
r"^\s*Chapter\s+[IVXLC\d]+",
|
| 1173 |
+
r"^\s*BOOK\s+[IVXLC\d]+",
|
| 1174 |
+
r"^\s*PREFACE\s*$",
|
| 1175 |
+
r"^\s*INTRODUCTION\s*$",
|
| 1176 |
+
r"E-text prepared by",
|
| 1177 |
+
r"Internet Archive",
|
| 1178 |
+
r"Distributed Proofreaders",
|
| 1179 |
]
|
| 1180 |
GUTENBERG_JUNK_REGEX = re.compile("|".join(GUTENBERG_JUNK_PATTERNS), re.IGNORECASE)
|
| 1181 |
|
|
|
|
| 1186 |
return False
|
| 1187 |
if GUTENBERG_JUNK_REGEX.search(text):
|
| 1188 |
return False
|
| 1189 |
+
if text.count(".") < 2:
|
| 1190 |
return False
|
| 1191 |
uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
|
| 1192 |
if uppercase_ratio > 0.3:
|
|
|
|
| 1205 |
print("\nDownloading Gutenberg Books (English only)...")
|
| 1206 |
out_dir = OUTPUT_DIR / "books"
|
| 1207 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 1208 |
+
|
| 1209 |
try:
|
| 1210 |
gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
|
| 1211 |
except Exception:
|
| 1212 |
gutenberg = load_dataset("pg19", split="train")
|
| 1213 |
+
|
| 1214 |
records: list[dict[str, Any]] = []
|
| 1215 |
indices = list(range(len(gutenberg)))
|
| 1216 |
random.shuffle(indices)
|
| 1217 |
+
|
| 1218 |
for i in tqdm(indices, desc="Books", leave=False):
|
| 1219 |
if len(records) >= max_samples:
|
| 1220 |
break
|
| 1221 |
+
|
| 1222 |
item = gutenberg[i]
|
| 1223 |
text = item.get("TEXT", "") or item.get("text", "")
|
| 1224 |
metadata_raw = item.get("METADATA", "") or "{}"
|
| 1225 |
+
|
| 1226 |
# Parse metadata - it's stored as JSON string
|
| 1227 |
try:
|
| 1228 |
metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
|
| 1229 |
except (json.JSONDecodeError, TypeError):
|
| 1230 |
metadata = {}
|
| 1231 |
+
|
| 1232 |
# Extract title and author
|
| 1233 |
title = metadata.get("title", "") if isinstance(metadata, dict) else ""
|
| 1234 |
author = metadata.get("author", "") if isinstance(metadata, dict) else ""
|
| 1235 |
if not title:
|
| 1236 |
title = item.get("title", f"Unknown Book #{i}")
|
| 1237 |
+
|
| 1238 |
if not text or len(text) < 1000:
|
| 1239 |
continue
|
| 1240 |
+
|
| 1241 |
+
paragraphs = re.split(r"\n\s*\n", text)
|
| 1242 |
for para in paragraphs:
|
| 1243 |
para = para.strip()
|
| 1244 |
if is_clean_prose(para):
|
| 1245 |
+
records.append(
|
| 1246 |
+
{"text": para, "title": title, "author": author, "type": "gutenberg"}
|
| 1247 |
+
)
|
|
|
|
|
|
|
|
|
|
| 1248 |
if len(records) >= max_samples:
|
| 1249 |
break
|
| 1250 |
+
|
| 1251 |
random.shuffle(records)
|
| 1252 |
n = len(records)
|
| 1253 |
+
write_jsonl(records[: int(n * 0.9)], out_dir / "train.jsonl", "train")
|
| 1254 |
+
write_jsonl(records[int(n * 0.9) : int(n * 0.95)], out_dir / "validation.jsonl", "val")
|
| 1255 |
+
write_jsonl(records[int(n * 0.95) :], out_dir / "test.jsonl", "test")
|
| 1256 |
|
| 1257 |
|
| 1258 |
# ------------ MAIN ------------
|
| 1259 |
|
| 1260 |
+
|
| 1261 |
def main() -> None:
|
| 1262 |
parser = argparse.ArgumentParser(description="Download LexiMind datasets")
|
| 1263 |
parser.add_argument(
|
| 1264 |
"--task",
|
| 1265 |
choices=["all", "summarization", "emotion", "topic", "gutenberg"],
|
| 1266 |
default="all",
|
| 1267 |
+
help="Dataset to download",
|
| 1268 |
)
|
| 1269 |
parser.add_argument("--max-books", type=int, default=40000, help="Max BookSum samples")
|
| 1270 |
parser.add_argument("--max-arxiv", type=int, default=50000, help="Max arXiv samples")
|
|
|
|
| 1272 |
parser.add_argument("--max-topics", type=int, default=50000, help="Max topic samples")
|
| 1273 |
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 1274 |
args = parser.parse_args()
|
| 1275 |
+
|
| 1276 |
random.seed(args.seed)
|
| 1277 |
+
|
| 1278 |
print("=" * 60)
|
| 1279 |
print("LexiMind Dataset Download")
|
| 1280 |
print("Books + Academic Papers + Topic Classification")
|
| 1281 |
print("=" * 60)
|
| 1282 |
+
|
| 1283 |
if args.task in ["all", "summarization"]:
|
| 1284 |
download_summarization(args.max_books, args.max_arxiv)
|
| 1285 |
if args.task in ["all", "emotion"]:
|
|
|
|
| 1288 |
download_topics(args.max_topics)
|
| 1289 |
if args.task in ["all", "gutenberg"]:
|
| 1290 |
download_gutenberg(args.max_gutenberg)
|
| 1291 |
+
|
| 1292 |
print("\n" + "=" * 60)
|
| 1293 |
print("Download complete!")
|
| 1294 |
print("=" * 60)
|
scripts/evaluate.py
CHANGED
|
@@ -65,34 +65,34 @@ def evaluate_summarization(
|
|
| 65 |
print("\n" + "=" * 60)
|
| 66 |
print("SUMMARIZATION EVALUATION")
|
| 67 |
print("=" * 60)
|
| 68 |
-
|
| 69 |
# Load data - try to get domain info from the raw JSONL
|
| 70 |
raw_data = []
|
| 71 |
with open(data_path) as f:
|
| 72 |
for line in f:
|
| 73 |
if line.strip():
|
| 74 |
raw_data.append(json.loads(line))
|
| 75 |
-
|
| 76 |
data = load_summarization_jsonl(str(data_path))
|
| 77 |
if max_samples:
|
| 78 |
data = data[:max_samples]
|
| 79 |
raw_data = raw_data[:max_samples]
|
| 80 |
print(f"Evaluating on {len(data)} samples...")
|
| 81 |
-
|
| 82 |
# Generate summaries
|
| 83 |
predictions = []
|
| 84 |
references = []
|
| 85 |
domains = [] # Track domain for per-domain breakdown
|
| 86 |
-
|
| 87 |
for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
|
| 88 |
-
batch = data[i:i + batch_size]
|
| 89 |
sources = [ex.source for ex in batch]
|
| 90 |
refs = [ex.summary for ex in batch]
|
| 91 |
-
|
| 92 |
preds = pipeline.summarize(sources)
|
| 93 |
predictions.extend(preds)
|
| 94 |
references.extend(refs)
|
| 95 |
-
|
| 96 |
# Track domain if available
|
| 97 |
for j in range(len(batch)):
|
| 98 |
idx = i + j
|
|
@@ -101,14 +101,14 @@ def evaluate_summarization(
|
|
| 101 |
domains.append(domain)
|
| 102 |
else:
|
| 103 |
domains.append("unknown")
|
| 104 |
-
|
| 105 |
# Calculate overall metrics
|
| 106 |
print("\nCalculating ROUGE scores...")
|
| 107 |
rouge_scores = calculate_rouge(predictions, references)
|
| 108 |
-
|
| 109 |
print("Calculating BLEU score...")
|
| 110 |
bleu = calculate_bleu(predictions, references)
|
| 111 |
-
|
| 112 |
metrics: dict = {
|
| 113 |
"rouge1": rouge_scores["rouge1"],
|
| 114 |
"rouge2": rouge_scores["rouge2"],
|
|
@@ -116,14 +116,14 @@ def evaluate_summarization(
|
|
| 116 |
"bleu4": bleu,
|
| 117 |
"num_samples": len(predictions),
|
| 118 |
}
|
| 119 |
-
|
| 120 |
if include_bertscore:
|
| 121 |
print("Calculating BERTScore (this may take a few minutes)...")
|
| 122 |
bert_scores = calculate_bertscore(predictions, references)
|
| 123 |
metrics["bertscore_precision"] = bert_scores["precision"]
|
| 124 |
metrics["bertscore_recall"] = bert_scores["recall"]
|
| 125 |
metrics["bertscore_f1"] = bert_scores["f1"]
|
| 126 |
-
|
| 127 |
# Per-domain breakdown
|
| 128 |
unique_domains = sorted(set(domains))
|
| 129 |
if len(unique_domains) > 1:
|
|
@@ -150,25 +150,26 @@ def evaluate_summarization(
|
|
| 150 |
dm["bertscore_f1"] = d_bert["f1"]
|
| 151 |
domain_metrics[domain] = dm
|
| 152 |
metrics["per_domain"] = domain_metrics
|
| 153 |
-
|
| 154 |
# Bootstrap confidence intervals
|
| 155 |
if compute_bootstrap:
|
| 156 |
try:
|
| 157 |
from rouge_score import rouge_scorer
|
| 158 |
-
|
|
|
|
| 159 |
per_sample_r1 = []
|
| 160 |
per_sample_rL = []
|
| 161 |
for pred, ref in zip(predictions, references, strict=True):
|
| 162 |
scores = scorer.score(ref, pred)
|
| 163 |
-
per_sample_r1.append(scores[
|
| 164 |
-
per_sample_rL.append(scores[
|
| 165 |
r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
|
| 166 |
rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
|
| 167 |
metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
|
| 168 |
metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
|
| 169 |
except ImportError:
|
| 170 |
pass
|
| 171 |
-
|
| 172 |
# Print results
|
| 173 |
print("\n" + "-" * 40)
|
| 174 |
print("SUMMARIZATION RESULTS:")
|
|
@@ -181,27 +182,29 @@ def evaluate_summarization(
|
|
| 181 |
print(f" BERTScore P: {metrics['bertscore_precision']:.4f}")
|
| 182 |
print(f" BERTScore R: {metrics['bertscore_recall']:.4f}")
|
| 183 |
print(f" BERTScore F: {metrics['bertscore_f1']:.4f}")
|
| 184 |
-
|
| 185 |
if "per_domain" in metrics:
|
| 186 |
print("\n Per-Domain Breakdown:")
|
| 187 |
for domain, dm in metrics["per_domain"].items():
|
| 188 |
bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
|
| 189 |
-
print(
|
| 190 |
-
|
|
|
|
|
|
|
| 191 |
if "rouge1_ci" in metrics:
|
| 192 |
ci = metrics["rouge1_ci"]
|
| 193 |
print(f"\n ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
|
| 194 |
-
|
| 195 |
# Show examples
|
| 196 |
print("\n" + "-" * 40)
|
| 197 |
print("SAMPLE OUTPUTS:")
|
| 198 |
print("-" * 40)
|
| 199 |
for i in range(min(3, len(predictions))):
|
| 200 |
-
print(f"\nExample {i+1}:")
|
| 201 |
print(f" Source: {data[i].source[:100]}...")
|
| 202 |
print(f" Generated: {predictions[i][:150]}...")
|
| 203 |
print(f" Reference: {references[i][:150]}...")
|
| 204 |
-
|
| 205 |
return metrics
|
| 206 |
|
| 207 |
|
|
@@ -214,62 +217,64 @@ def evaluate_emotion(
|
|
| 214 |
compute_bootstrap: bool = False,
|
| 215 |
) -> dict:
|
| 216 |
"""Evaluate emotion detection with comprehensive multi-label metrics.
|
| 217 |
-
|
| 218 |
Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
|
| 219 |
Optionally tunes per-class thresholds on the evaluation set.
|
| 220 |
"""
|
| 221 |
print("\n" + "=" * 60)
|
| 222 |
print("EMOTION DETECTION EVALUATION")
|
| 223 |
print("=" * 60)
|
| 224 |
-
|
| 225 |
# Load data (returns EmotionExample dataclass objects)
|
| 226 |
data = load_emotion_jsonl(str(data_path))
|
| 227 |
if max_samples:
|
| 228 |
data = data[:max_samples]
|
| 229 |
print(f"Evaluating on {len(data)} samples...")
|
| 230 |
-
|
| 231 |
# Get predictions - collect raw logits for threshold tuning
|
| 232 |
all_preds = []
|
| 233 |
all_refs = []
|
| 234 |
all_logits_list = []
|
| 235 |
-
|
| 236 |
for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
|
| 237 |
-
batch = data[i:i + batch_size]
|
| 238 |
texts = [ex.text for ex in batch]
|
| 239 |
refs = [set(ex.emotions) for ex in batch]
|
| 240 |
-
|
| 241 |
preds = pipeline.predict_emotions(texts)
|
| 242 |
pred_sets = [set(p.labels) for p in preds]
|
| 243 |
-
|
| 244 |
all_preds.extend(pred_sets)
|
| 245 |
all_refs.extend(refs)
|
| 246 |
-
|
| 247 |
# Also get raw logits for threshold tuning
|
| 248 |
if tune_thresholds:
|
| 249 |
encoded = pipeline.tokenizer.batch_encode(texts)
|
| 250 |
input_ids = encoded["input_ids"].to(pipeline.device)
|
| 251 |
attention_mask = encoded["attention_mask"].to(pipeline.device)
|
| 252 |
with torch.inference_mode():
|
| 253 |
-
logits = pipeline.model.forward(
|
|
|
|
|
|
|
| 254 |
all_logits_list.append(logits.cpu())
|
| 255 |
-
|
| 256 |
# Calculate metrics
|
| 257 |
all_emotions = sorted(pipeline.emotion_labels)
|
| 258 |
-
|
| 259 |
def to_binary(emotion_sets, labels):
|
| 260 |
return [[1 if e in es else 0 for e in labels] for es in emotion_sets]
|
| 261 |
-
|
| 262 |
pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
|
| 263 |
ref_binary = torch.tensor(to_binary(all_refs, all_emotions))
|
| 264 |
-
|
| 265 |
# Core metrics: sample-avg F1, macro F1, micro F1
|
| 266 |
sample_f1 = multilabel_f1(pred_binary, ref_binary)
|
| 267 |
macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
|
| 268 |
micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)
|
| 269 |
-
|
| 270 |
# Per-class metrics
|
| 271 |
per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)
|
| 272 |
-
|
| 273 |
metrics: dict = {
|
| 274 |
"sample_avg_f1": sample_f1,
|
| 275 |
"macro_f1": macro_f1,
|
|
@@ -278,7 +283,7 @@ def evaluate_emotion(
|
|
| 278 |
"num_classes": len(all_emotions),
|
| 279 |
"per_class": per_class,
|
| 280 |
}
|
| 281 |
-
|
| 282 |
# Per-class threshold tuning
|
| 283 |
if tune_thresholds and all_logits_list:
|
| 284 |
print("\nTuning per-class thresholds...")
|
|
@@ -288,7 +293,7 @@ def evaluate_emotion(
|
|
| 288 |
name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
|
| 289 |
}
|
| 290 |
metrics["tuned_macro_f1"] = tuned_macro_f1
|
| 291 |
-
|
| 292 |
# Also compute tuned sample-avg F1
|
| 293 |
probs = torch.sigmoid(all_logits)
|
| 294 |
tuned_preds = torch.zeros_like(probs)
|
|
@@ -296,7 +301,7 @@ def evaluate_emotion(
|
|
| 296 |
tuned_preds[:, c] = (probs[:, c] >= t).float()
|
| 297 |
metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
|
| 298 |
metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
|
| 299 |
-
|
| 300 |
# Bootstrap confidence intervals
|
| 301 |
if compute_bootstrap:
|
| 302 |
# Compute per-sample F1 for bootstrapping
|
|
@@ -313,7 +318,7 @@ def evaluate_emotion(
|
|
| 313 |
per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
|
| 314 |
mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
|
| 315 |
metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}
|
| 316 |
-
|
| 317 |
# Print results
|
| 318 |
print("\n" + "-" * 40)
|
| 319 |
print("EMOTION DETECTION RESULTS:")
|
|
@@ -322,23 +327,25 @@ def evaluate_emotion(
|
|
| 322 |
print(f" Macro F1: {metrics['macro_f1']:.4f}")
|
| 323 |
print(f" Micro F1: {metrics['micro_f1']:.4f}")
|
| 324 |
print(f" Num Classes: {metrics['num_classes']}")
|
| 325 |
-
|
| 326 |
if "tuned_macro_f1" in metrics:
|
| 327 |
print("\n After per-class threshold tuning:")
|
| 328 |
print(f" Tuned Macro F1: {metrics['tuned_macro_f1']:.4f}")
|
| 329 |
print(f" Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
|
| 330 |
print(f" Tuned Micro F1: {metrics['tuned_micro_f1']:.4f}")
|
| 331 |
-
|
| 332 |
if "sample_f1_ci" in metrics:
|
| 333 |
ci = metrics["sample_f1_ci"]
|
| 334 |
print(f"\n Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
|
| 335 |
-
|
| 336 |
# Print top-10 per-class performance
|
| 337 |
print("\n Per-class F1 (top 10 by support):")
|
| 338 |
sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
|
| 339 |
for name, m in sorted_classes[:10]:
|
| 340 |
-
print(
|
| 341 |
-
|
|
|
|
|
|
|
| 342 |
return metrics
|
| 343 |
|
| 344 |
|
|
@@ -353,61 +360,63 @@ def evaluate_topic(
|
|
| 353 |
print("\n" + "=" * 60)
|
| 354 |
print("TOPIC CLASSIFICATION EVALUATION")
|
| 355 |
print("=" * 60)
|
| 356 |
-
|
| 357 |
# Load data (returns TopicExample dataclass objects)
|
| 358 |
data = load_topic_jsonl(str(data_path))
|
| 359 |
if max_samples:
|
| 360 |
data = data[:max_samples]
|
| 361 |
print(f"Evaluating on {len(data)} samples...")
|
| 362 |
-
|
| 363 |
# Get predictions
|
| 364 |
all_preds = []
|
| 365 |
all_refs = []
|
| 366 |
-
|
| 367 |
for i in tqdm(range(0, len(data), batch_size), desc="Predicting topics"):
|
| 368 |
-
batch = data[i:i + batch_size]
|
| 369 |
texts = [ex.text for ex in batch]
|
| 370 |
refs = [ex.topic for ex in batch]
|
| 371 |
-
|
| 372 |
preds = pipeline.predict_topics(texts)
|
| 373 |
pred_labels = [p.label for p in preds]
|
| 374 |
-
|
| 375 |
all_preds.extend(pred_labels)
|
| 376 |
all_refs.extend(refs)
|
| 377 |
-
|
| 378 |
# Calculate metrics
|
| 379 |
accuracy = accuracy_score(all_refs, all_preds)
|
| 380 |
macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)
|
| 381 |
-
|
| 382 |
metrics: dict = {
|
| 383 |
"accuracy": accuracy,
|
| 384 |
"macro_f1": macro_f1,
|
| 385 |
"num_samples": len(all_preds),
|
| 386 |
}
|
| 387 |
-
|
| 388 |
# Bootstrap confidence intervals for accuracy
|
| 389 |
if compute_bootstrap:
|
| 390 |
-
per_sample_correct = [
|
|
|
|
|
|
|
| 391 |
mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
|
| 392 |
metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}
|
| 393 |
-
|
| 394 |
# Print results
|
| 395 |
print("\n" + "-" * 40)
|
| 396 |
print("TOPIC CLASSIFICATION RESULTS:")
|
| 397 |
print("-" * 40)
|
| 398 |
-
print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
|
| 399 |
print(f" Macro F1: {metrics['macro_f1']:.4f}")
|
| 400 |
-
|
| 401 |
if "accuracy_ci" in metrics:
|
| 402 |
ci = metrics["accuracy_ci"]
|
| 403 |
print(f" Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
|
| 404 |
-
|
| 405 |
# Classification report
|
| 406 |
print("\n" + "-" * 40)
|
| 407 |
print("PER-CLASS METRICS:")
|
| 408 |
print("-" * 40)
|
| 409 |
print(classification_report(all_refs, all_preds, zero_division=0))
|
| 410 |
-
|
| 411 |
return metrics
|
| 412 |
|
| 413 |
|
|
@@ -418,20 +427,28 @@ def main():
|
|
| 418 |
parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
|
| 419 |
parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
|
| 420 |
parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
|
| 421 |
-
parser.add_argument(
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
parser.add_argument("--summarization-only", action="store_true")
|
| 425 |
parser.add_argument("--emotion-only", action="store_true")
|
| 426 |
parser.add_argument("--topic-only", action="store_true")
|
| 427 |
args = parser.parse_args()
|
| 428 |
-
|
| 429 |
print("=" * 60)
|
| 430 |
print("LexiMind Evaluation")
|
| 431 |
print("=" * 60)
|
| 432 |
-
|
| 433 |
start_time = time.perf_counter()
|
| 434 |
-
|
| 435 |
# Load model
|
| 436 |
print(f"\nLoading model from {args.checkpoint}...")
|
| 437 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -443,12 +460,12 @@ def main():
|
|
| 443 |
print(f" Device: {device}")
|
| 444 |
print(f" Topics: {labels.topic}")
|
| 445 |
print(f" Emotions: {len(labels.emotion)} classes")
|
| 446 |
-
|
| 447 |
results = {}
|
| 448 |
-
|
| 449 |
# Determine which tasks to evaluate
|
| 450 |
eval_all = not (args.summarization_only or args.emotion_only or args.topic_only)
|
| 451 |
-
|
| 452 |
# Evaluate summarization
|
| 453 |
if eval_all or args.summarization_only:
|
| 454 |
val_path = args.data_dir / "summarization" / "validation.jsonl"
|
|
@@ -456,14 +473,15 @@ def main():
|
|
| 456 |
val_path = args.data_dir / "summarization" / "val.jsonl"
|
| 457 |
if val_path.exists():
|
| 458 |
results["summarization"] = evaluate_summarization(
|
| 459 |
-
pipeline,
|
|
|
|
| 460 |
max_samples=args.max_samples,
|
| 461 |
include_bertscore=args.include_bertscore,
|
| 462 |
compute_bootstrap=args.bootstrap,
|
| 463 |
)
|
| 464 |
else:
|
| 465 |
print("Warning: summarization validation data not found, skipping")
|
| 466 |
-
|
| 467 |
# Evaluate emotion
|
| 468 |
if eval_all or args.emotion_only:
|
| 469 |
val_path = args.data_dir / "emotion" / "validation.jsonl"
|
|
@@ -471,14 +489,15 @@ def main():
|
|
| 471 |
val_path = args.data_dir / "emotion" / "val.jsonl"
|
| 472 |
if val_path.exists():
|
| 473 |
results["emotion"] = evaluate_emotion(
|
| 474 |
-
pipeline,
|
|
|
|
| 475 |
max_samples=args.max_samples,
|
| 476 |
tune_thresholds=args.tune_thresholds,
|
| 477 |
compute_bootstrap=args.bootstrap,
|
| 478 |
)
|
| 479 |
else:
|
| 480 |
print("Warning: emotion validation data not found, skipping")
|
| 481 |
-
|
| 482 |
# Evaluate topic
|
| 483 |
if eval_all or args.topic_only:
|
| 484 |
val_path = args.data_dir / "topic" / "validation.jsonl"
|
|
@@ -486,30 +505,31 @@ def main():
|
|
| 486 |
val_path = args.data_dir / "topic" / "val.jsonl"
|
| 487 |
if val_path.exists():
|
| 488 |
results["topic"] = evaluate_topic(
|
| 489 |
-
pipeline,
|
|
|
|
| 490 |
max_samples=args.max_samples,
|
| 491 |
compute_bootstrap=args.bootstrap,
|
| 492 |
)
|
| 493 |
else:
|
| 494 |
print("Warning: topic validation data not found, skipping")
|
| 495 |
-
|
| 496 |
# Save results
|
| 497 |
print("\n" + "=" * 60)
|
| 498 |
print("SAVING RESULTS")
|
| 499 |
print("=" * 60)
|
| 500 |
-
|
| 501 |
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 502 |
with open(args.output, "w") as f:
|
| 503 |
json.dump(results, f, indent=2)
|
| 504 |
print(f" Saved to: {args.output}")
|
| 505 |
-
|
| 506 |
# Final summary
|
| 507 |
elapsed = time.perf_counter() - start_time
|
| 508 |
print("\n" + "=" * 60)
|
| 509 |
print("EVALUATION COMPLETE")
|
| 510 |
print("=" * 60)
|
| 511 |
-
print(f" Time: {elapsed/60:.1f} minutes")
|
| 512 |
-
|
| 513 |
if "summarization" in results:
|
| 514 |
s = results["summarization"]
|
| 515 |
print("\n Summarization:")
|
|
@@ -519,14 +539,14 @@ def main():
|
|
| 519 |
print(f" BLEU-4: {s['bleu4']:.4f}")
|
| 520 |
if "bertscore_f1" in s:
|
| 521 |
print(f" BERTScore F1: {s['bertscore_f1']:.4f}")
|
| 522 |
-
|
| 523 |
if "emotion" in results:
|
| 524 |
e = results["emotion"]
|
| 525 |
print("\n Emotion:")
|
| 526 |
print(f" Sample-avg F1: {e['sample_avg_f1']:.4f}")
|
| 527 |
print(f" Macro F1: {e['macro_f1']:.4f}")
|
| 528 |
print(f" Micro F1: {e['micro_f1']:.4f}")
|
| 529 |
-
|
| 530 |
if "topic" in results:
|
| 531 |
print("\n Topic:")
|
| 532 |
print(f" Accuracy: {results['topic']['accuracy']:.2%}")
|
|
|
|
| 65 |
print("\n" + "=" * 60)
|
| 66 |
print("SUMMARIZATION EVALUATION")
|
| 67 |
print("=" * 60)
|
| 68 |
+
|
| 69 |
# Load data - try to get domain info from the raw JSONL
|
| 70 |
raw_data = []
|
| 71 |
with open(data_path) as f:
|
| 72 |
for line in f:
|
| 73 |
if line.strip():
|
| 74 |
raw_data.append(json.loads(line))
|
| 75 |
+
|
| 76 |
data = load_summarization_jsonl(str(data_path))
|
| 77 |
if max_samples:
|
| 78 |
data = data[:max_samples]
|
| 79 |
raw_data = raw_data[:max_samples]
|
| 80 |
print(f"Evaluating on {len(data)} samples...")
|
| 81 |
+
|
| 82 |
# Generate summaries
|
| 83 |
predictions = []
|
| 84 |
references = []
|
| 85 |
domains = [] # Track domain for per-domain breakdown
|
| 86 |
+
|
| 87 |
for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
|
| 88 |
+
batch = data[i : i + batch_size]
|
| 89 |
sources = [ex.source for ex in batch]
|
| 90 |
refs = [ex.summary for ex in batch]
|
| 91 |
+
|
| 92 |
preds = pipeline.summarize(sources)
|
| 93 |
predictions.extend(preds)
|
| 94 |
references.extend(refs)
|
| 95 |
+
|
| 96 |
# Track domain if available
|
| 97 |
for j in range(len(batch)):
|
| 98 |
idx = i + j
|
|
|
|
| 101 |
domains.append(domain)
|
| 102 |
else:
|
| 103 |
domains.append("unknown")
|
| 104 |
+
|
| 105 |
# Calculate overall metrics
|
| 106 |
print("\nCalculating ROUGE scores...")
|
| 107 |
rouge_scores = calculate_rouge(predictions, references)
|
| 108 |
+
|
| 109 |
print("Calculating BLEU score...")
|
| 110 |
bleu = calculate_bleu(predictions, references)
|
| 111 |
+
|
| 112 |
metrics: dict = {
|
| 113 |
"rouge1": rouge_scores["rouge1"],
|
| 114 |
"rouge2": rouge_scores["rouge2"],
|
|
|
|
| 116 |
"bleu4": bleu,
|
| 117 |
"num_samples": len(predictions),
|
| 118 |
}
|
| 119 |
+
|
| 120 |
if include_bertscore:
|
| 121 |
print("Calculating BERTScore (this may take a few minutes)...")
|
| 122 |
bert_scores = calculate_bertscore(predictions, references)
|
| 123 |
metrics["bertscore_precision"] = bert_scores["precision"]
|
| 124 |
metrics["bertscore_recall"] = bert_scores["recall"]
|
| 125 |
metrics["bertscore_f1"] = bert_scores["f1"]
|
| 126 |
+
|
| 127 |
# Per-domain breakdown
|
| 128 |
unique_domains = sorted(set(domains))
|
| 129 |
if len(unique_domains) > 1:
|
|
|
|
| 150 |
dm["bertscore_f1"] = d_bert["f1"]
|
| 151 |
domain_metrics[domain] = dm
|
| 152 |
metrics["per_domain"] = domain_metrics
|
| 153 |
+
|
| 154 |
# Bootstrap confidence intervals
|
| 155 |
if compute_bootstrap:
|
| 156 |
try:
|
| 157 |
from rouge_score import rouge_scorer
|
| 158 |
+
|
| 159 |
+
scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
|
| 160 |
per_sample_r1 = []
|
| 161 |
per_sample_rL = []
|
| 162 |
for pred, ref in zip(predictions, references, strict=True):
|
| 163 |
scores = scorer.score(ref, pred)
|
| 164 |
+
per_sample_r1.append(scores["rouge1"].fmeasure)
|
| 165 |
+
per_sample_rL.append(scores["rougeL"].fmeasure)
|
| 166 |
r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
|
| 167 |
rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
|
| 168 |
metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
|
| 169 |
metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
|
| 170 |
except ImportError:
|
| 171 |
pass
|
| 172 |
+
|
| 173 |
# Print results
|
| 174 |
print("\n" + "-" * 40)
|
| 175 |
print("SUMMARIZATION RESULTS:")
|
|
|
|
| 182 |
print(f" BERTScore P: {metrics['bertscore_precision']:.4f}")
|
| 183 |
print(f" BERTScore R: {metrics['bertscore_recall']:.4f}")
|
| 184 |
print(f" BERTScore F: {metrics['bertscore_f1']:.4f}")
|
| 185 |
+
|
| 186 |
if "per_domain" in metrics:
|
| 187 |
print("\n Per-Domain Breakdown:")
|
| 188 |
for domain, dm in metrics["per_domain"].items():
|
| 189 |
bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
|
| 190 |
+
print(
|
| 191 |
+
f" {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
if "rouge1_ci" in metrics:
|
| 195 |
ci = metrics["rouge1_ci"]
|
| 196 |
print(f"\n ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
|
| 197 |
+
|
| 198 |
# Show examples
|
| 199 |
print("\n" + "-" * 40)
|
| 200 |
print("SAMPLE OUTPUTS:")
|
| 201 |
print("-" * 40)
|
| 202 |
for i in range(min(3, len(predictions))):
|
| 203 |
+
print(f"\nExample {i + 1}:")
|
| 204 |
print(f" Source: {data[i].source[:100]}...")
|
| 205 |
print(f" Generated: {predictions[i][:150]}...")
|
| 206 |
print(f" Reference: {references[i][:150]}...")
|
| 207 |
+
|
| 208 |
return metrics
|
| 209 |
|
| 210 |
|
|
|
|
| 217 |
compute_bootstrap: bool = False,
|
| 218 |
) -> dict:
|
| 219 |
"""Evaluate emotion detection with comprehensive multi-label metrics.
|
| 220 |
+
|
| 221 |
Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
|
| 222 |
Optionally tunes per-class thresholds on the evaluation set.
|
| 223 |
"""
|
| 224 |
print("\n" + "=" * 60)
|
| 225 |
print("EMOTION DETECTION EVALUATION")
|
| 226 |
print("=" * 60)
|
| 227 |
+
|
| 228 |
# Load data (returns EmotionExample dataclass objects)
|
| 229 |
data = load_emotion_jsonl(str(data_path))
|
| 230 |
if max_samples:
|
| 231 |
data = data[:max_samples]
|
| 232 |
print(f"Evaluating on {len(data)} samples...")
|
| 233 |
+
|
| 234 |
# Get predictions - collect raw logits for threshold tuning
|
| 235 |
all_preds = []
|
| 236 |
all_refs = []
|
| 237 |
all_logits_list = []
|
| 238 |
+
|
| 239 |
for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
|
| 240 |
+
batch = data[i : i + batch_size]
|
| 241 |
texts = [ex.text for ex in batch]
|
| 242 |
refs = [set(ex.emotions) for ex in batch]
|
| 243 |
+
|
| 244 |
preds = pipeline.predict_emotions(texts)
|
| 245 |
pred_sets = [set(p.labels) for p in preds]
|
| 246 |
+
|
| 247 |
all_preds.extend(pred_sets)
|
| 248 |
all_refs.extend(refs)
|
| 249 |
+
|
| 250 |
# Also get raw logits for threshold tuning
|
| 251 |
if tune_thresholds:
|
| 252 |
encoded = pipeline.tokenizer.batch_encode(texts)
|
| 253 |
input_ids = encoded["input_ids"].to(pipeline.device)
|
| 254 |
attention_mask = encoded["attention_mask"].to(pipeline.device)
|
| 255 |
with torch.inference_mode():
|
| 256 |
+
logits = pipeline.model.forward(
|
| 257 |
+
"emotion", {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 258 |
+
)
|
| 259 |
all_logits_list.append(logits.cpu())
|
| 260 |
+
|
| 261 |
# Calculate metrics
|
| 262 |
all_emotions = sorted(pipeline.emotion_labels)
|
| 263 |
+
|
| 264 |
def to_binary(emotion_sets, labels):
|
| 265 |
return [[1 if e in es else 0 for e in labels] for es in emotion_sets]
|
| 266 |
+
|
| 267 |
pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
|
| 268 |
ref_binary = torch.tensor(to_binary(all_refs, all_emotions))
|
| 269 |
+
|
| 270 |
# Core metrics: sample-avg F1, macro F1, micro F1
|
| 271 |
sample_f1 = multilabel_f1(pred_binary, ref_binary)
|
| 272 |
macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
|
| 273 |
micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)
|
| 274 |
+
|
| 275 |
# Per-class metrics
|
| 276 |
per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)
|
| 277 |
+
|
| 278 |
metrics: dict = {
|
| 279 |
"sample_avg_f1": sample_f1,
|
| 280 |
"macro_f1": macro_f1,
|
|
|
|
| 283 |
"num_classes": len(all_emotions),
|
| 284 |
"per_class": per_class,
|
| 285 |
}
|
| 286 |
+
|
| 287 |
# Per-class threshold tuning
|
| 288 |
if tune_thresholds and all_logits_list:
|
| 289 |
print("\nTuning per-class thresholds...")
|
|
|
|
| 293 |
name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
|
| 294 |
}
|
| 295 |
metrics["tuned_macro_f1"] = tuned_macro_f1
|
| 296 |
+
|
| 297 |
# Also compute tuned sample-avg F1
|
| 298 |
probs = torch.sigmoid(all_logits)
|
| 299 |
tuned_preds = torch.zeros_like(probs)
|
|
|
|
| 301 |
tuned_preds[:, c] = (probs[:, c] >= t).float()
|
| 302 |
metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
|
| 303 |
metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
|
| 304 |
+
|
| 305 |
# Bootstrap confidence intervals
|
| 306 |
if compute_bootstrap:
|
| 307 |
# Compute per-sample F1 for bootstrapping
|
|
|
|
| 318 |
per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
|
| 319 |
mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
|
| 320 |
metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}
|
| 321 |
+
|
| 322 |
# Print results
|
| 323 |
print("\n" + "-" * 40)
|
| 324 |
print("EMOTION DETECTION RESULTS:")
|
|
|
|
| 327 |
print(f" Macro F1: {metrics['macro_f1']:.4f}")
|
| 328 |
print(f" Micro F1: {metrics['micro_f1']:.4f}")
|
| 329 |
print(f" Num Classes: {metrics['num_classes']}")
|
| 330 |
+
|
| 331 |
if "tuned_macro_f1" in metrics:
|
| 332 |
print("\n After per-class threshold tuning:")
|
| 333 |
print(f" Tuned Macro F1: {metrics['tuned_macro_f1']:.4f}")
|
| 334 |
print(f" Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
|
| 335 |
print(f" Tuned Micro F1: {metrics['tuned_micro_f1']:.4f}")
|
| 336 |
+
|
| 337 |
if "sample_f1_ci" in metrics:
|
| 338 |
ci = metrics["sample_f1_ci"]
|
| 339 |
print(f"\n Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
|
| 340 |
+
|
| 341 |
# Print top-10 per-class performance
|
| 342 |
print("\n Per-class F1 (top 10 by support):")
|
| 343 |
sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
|
| 344 |
for name, m in sorted_classes[:10]:
|
| 345 |
+
print(
|
| 346 |
+
f" {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})"
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
return metrics
|
| 350 |
|
| 351 |
|
|
|
|
| 360 |
print("\n" + "=" * 60)
|
| 361 |
print("TOPIC CLASSIFICATION EVALUATION")
|
| 362 |
print("=" * 60)
|
| 363 |
+
|
| 364 |
# Load data (returns TopicExample dataclass objects)
|
| 365 |
data = load_topic_jsonl(str(data_path))
|
| 366 |
if max_samples:
|
| 367 |
data = data[:max_samples]
|
| 368 |
print(f"Evaluating on {len(data)} samples...")
|
| 369 |
+
|
| 370 |
# Get predictions
|
| 371 |
all_preds = []
|
| 372 |
all_refs = []
|
| 373 |
+
|
| 374 |
for i in tqdm(range(0, len(data), batch_size), desc="Predicting topics"):
|
| 375 |
+
batch = data[i : i + batch_size]
|
| 376 |
texts = [ex.text for ex in batch]
|
| 377 |
refs = [ex.topic for ex in batch]
|
| 378 |
+
|
| 379 |
preds = pipeline.predict_topics(texts)
|
| 380 |
pred_labels = [p.label for p in preds]
|
| 381 |
+
|
| 382 |
all_preds.extend(pred_labels)
|
| 383 |
all_refs.extend(refs)
|
| 384 |
+
|
| 385 |
# Calculate metrics
|
| 386 |
accuracy = accuracy_score(all_refs, all_preds)
|
| 387 |
macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)
|
| 388 |
+
|
| 389 |
metrics: dict = {
|
| 390 |
"accuracy": accuracy,
|
| 391 |
"macro_f1": macro_f1,
|
| 392 |
"num_samples": len(all_preds),
|
| 393 |
}
|
| 394 |
+
|
| 395 |
# Bootstrap confidence intervals for accuracy
|
| 396 |
if compute_bootstrap:
|
| 397 |
+
per_sample_correct = [
|
| 398 |
+
1.0 if p == r else 0.0 for p, r in zip(all_preds, all_refs, strict=True)
|
| 399 |
+
]
|
| 400 |
mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
|
| 401 |
metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}
|
| 402 |
+
|
| 403 |
# Print results
|
| 404 |
print("\n" + "-" * 40)
|
| 405 |
print("TOPIC CLASSIFICATION RESULTS:")
|
| 406 |
print("-" * 40)
|
| 407 |
+
print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.1f}%)")
|
| 408 |
print(f" Macro F1: {metrics['macro_f1']:.4f}")
|
| 409 |
+
|
| 410 |
if "accuracy_ci" in metrics:
|
| 411 |
ci = metrics["accuracy_ci"]
|
| 412 |
print(f" Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
|
| 413 |
+
|
| 414 |
# Classification report
|
| 415 |
print("\n" + "-" * 40)
|
| 416 |
print("PER-CLASS METRICS:")
|
| 417 |
print("-" * 40)
|
| 418 |
print(classification_report(all_refs, all_preds, zero_division=0))
|
| 419 |
+
|
| 420 |
return metrics
|
| 421 |
|
| 422 |
|
|
|
|
| 427 |
parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
|
| 428 |
parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
|
| 429 |
parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
|
| 430 |
+
parser.add_argument(
|
| 431 |
+
"--include-bertscore", action="store_true", help="Include BERTScore (slow, optional)"
|
| 432 |
+
)
|
| 433 |
+
parser.add_argument(
|
| 434 |
+
"--tune-thresholds",
|
| 435 |
+
action="store_true",
|
| 436 |
+
help="Tune per-class emotion thresholds on val set",
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument(
|
| 439 |
+
"--bootstrap", action="store_true", help="Compute bootstrap confidence intervals"
|
| 440 |
+
)
|
| 441 |
parser.add_argument("--summarization-only", action="store_true")
|
| 442 |
parser.add_argument("--emotion-only", action="store_true")
|
| 443 |
parser.add_argument("--topic-only", action="store_true")
|
| 444 |
args = parser.parse_args()
|
| 445 |
+
|
| 446 |
print("=" * 60)
|
| 447 |
print("LexiMind Evaluation")
|
| 448 |
print("=" * 60)
|
| 449 |
+
|
| 450 |
start_time = time.perf_counter()
|
| 451 |
+
|
| 452 |
# Load model
|
| 453 |
print(f"\nLoading model from {args.checkpoint}...")
|
| 454 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 460 |
print(f" Device: {device}")
|
| 461 |
print(f" Topics: {labels.topic}")
|
| 462 |
print(f" Emotions: {len(labels.emotion)} classes")
|
| 463 |
+
|
| 464 |
results = {}
|
| 465 |
+
|
| 466 |
# Determine which tasks to evaluate
|
| 467 |
eval_all = not (args.summarization_only or args.emotion_only or args.topic_only)
|
| 468 |
+
|
| 469 |
# Evaluate summarization
|
| 470 |
if eval_all or args.summarization_only:
|
| 471 |
val_path = args.data_dir / "summarization" / "validation.jsonl"
|
|
|
|
| 473 |
val_path = args.data_dir / "summarization" / "val.jsonl"
|
| 474 |
if val_path.exists():
|
| 475 |
results["summarization"] = evaluate_summarization(
|
| 476 |
+
pipeline,
|
| 477 |
+
val_path,
|
| 478 |
max_samples=args.max_samples,
|
| 479 |
include_bertscore=args.include_bertscore,
|
| 480 |
compute_bootstrap=args.bootstrap,
|
| 481 |
)
|
| 482 |
else:
|
| 483 |
print("Warning: summarization validation data not found, skipping")
|
| 484 |
+
|
| 485 |
# Evaluate emotion
|
| 486 |
if eval_all or args.emotion_only:
|
| 487 |
val_path = args.data_dir / "emotion" / "validation.jsonl"
|
|
|
|
| 489 |
val_path = args.data_dir / "emotion" / "val.jsonl"
|
| 490 |
if val_path.exists():
|
| 491 |
results["emotion"] = evaluate_emotion(
|
| 492 |
+
pipeline,
|
| 493 |
+
val_path,
|
| 494 |
max_samples=args.max_samples,
|
| 495 |
tune_thresholds=args.tune_thresholds,
|
| 496 |
compute_bootstrap=args.bootstrap,
|
| 497 |
)
|
| 498 |
else:
|
| 499 |
print("Warning: emotion validation data not found, skipping")
|
| 500 |
+
|
| 501 |
# Evaluate topic
|
| 502 |
if eval_all or args.topic_only:
|
| 503 |
val_path = args.data_dir / "topic" / "validation.jsonl"
|
|
|
|
| 505 |
val_path = args.data_dir / "topic" / "val.jsonl"
|
| 506 |
if val_path.exists():
|
| 507 |
results["topic"] = evaluate_topic(
|
| 508 |
+
pipeline,
|
| 509 |
+
val_path,
|
| 510 |
max_samples=args.max_samples,
|
| 511 |
compute_bootstrap=args.bootstrap,
|
| 512 |
)
|
| 513 |
else:
|
| 514 |
print("Warning: topic validation data not found, skipping")
|
| 515 |
+
|
| 516 |
# Save results
|
| 517 |
print("\n" + "=" * 60)
|
| 518 |
print("SAVING RESULTS")
|
| 519 |
print("=" * 60)
|
| 520 |
+
|
| 521 |
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 522 |
with open(args.output, "w") as f:
|
| 523 |
json.dump(results, f, indent=2)
|
| 524 |
print(f" Saved to: {args.output}")
|
| 525 |
+
|
| 526 |
# Final summary
|
| 527 |
elapsed = time.perf_counter() - start_time
|
| 528 |
print("\n" + "=" * 60)
|
| 529 |
print("EVALUATION COMPLETE")
|
| 530 |
print("=" * 60)
|
| 531 |
+
print(f" Time: {elapsed / 60:.1f} minutes")
|
| 532 |
+
|
| 533 |
if "summarization" in results:
|
| 534 |
s = results["summarization"]
|
| 535 |
print("\n Summarization:")
|
|
|
|
| 539 |
print(f" BLEU-4: {s['bleu4']:.4f}")
|
| 540 |
if "bertscore_f1" in s:
|
| 541 |
print(f" BERTScore F1: {s['bertscore_f1']:.4f}")
|
| 542 |
+
|
| 543 |
if "emotion" in results:
|
| 544 |
e = results["emotion"]
|
| 545 |
print("\n Emotion:")
|
| 546 |
print(f" Sample-avg F1: {e['sample_avg_f1']:.4f}")
|
| 547 |
print(f" Macro F1: {e['macro_f1']:.4f}")
|
| 548 |
print(f" Micro F1: {e['micro_f1']:.4f}")
|
| 549 |
+
|
| 550 |
if "topic" in results:
|
| 551 |
print("\n Topic:")
|
| 552 |
print(f" Accuracy: {results['topic']['accuracy']:.2%}")
|
scripts/profile_training.py
CHANGED
|
@@ -96,10 +96,12 @@ def main(cfg: DictConfig) -> None:
|
|
| 96 |
|
| 97 |
tok_cfg = data_cfg.get("tokenizer", {})
|
| 98 |
max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
|
| 99 |
-
tokenizer = Tokenizer(
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
summ_train = SummarizationDataset(summ_splits["train"])
|
| 105 |
emot_train = EmotionDataset(emot_splits["train"])
|
|
@@ -112,23 +114,42 @@ def main(cfg: DictConfig) -> None:
|
|
| 112 |
|
| 113 |
train_loaders = {
|
| 114 |
"summarization": build_summarization_dataloader(
|
| 115 |
-
summ_train,
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
),
|
| 119 |
"emotion": build_emotion_dataloader(
|
| 120 |
-
emot_train,
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
),
|
| 123 |
"topic": build_topic_dataloader(
|
| 124 |
-
topic_train,
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
),
|
| 127 |
}
|
| 128 |
|
| 129 |
# Build model
|
| 130 |
-
grad_ckpt = cfg.training.get(
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
model_cfg = ModelConfig(
|
| 134 |
d_model=cfg.model.d_model,
|
|
@@ -202,8 +223,10 @@ def main(cfg: DictConfig) -> None:
|
|
| 202 |
except StopIteration:
|
| 203 |
iterators[task] = iter(train_loaders[task])
|
| 204 |
batch = next(iterators[task])
|
| 205 |
-
return {
|
| 206 |
-
|
|
|
|
|
|
|
| 207 |
|
| 208 |
def training_step(step):
|
| 209 |
"""One training step across all tasks."""
|
|
@@ -219,7 +242,8 @@ def main(cfg: DictConfig) -> None:
|
|
| 219 |
loss = torch.nn.functional.cross_entropy(
|
| 220 |
logits.view(-1, logits.size(-1)),
|
| 221 |
batch["labels"].view(-1),
|
| 222 |
-
ignore_index=-100,
|
|
|
|
| 223 |
)
|
| 224 |
elif task == "emotion":
|
| 225 |
inputs = {"input_ids": batch["input_ids"]}
|
|
@@ -262,7 +286,10 @@ def main(cfg: DictConfig) -> None:
|
|
| 262 |
torch.profiler.ProfilerActivity.CUDA,
|
| 263 |
],
|
| 264 |
schedule=torch.profiler.schedule(
|
| 265 |
-
wait=1,
|
|
|
|
|
|
|
|
|
|
| 266 |
),
|
| 267 |
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
|
| 268 |
record_shapes=True,
|
|
|
|
| 96 |
|
| 97 |
tok_cfg = data_cfg.get("tokenizer", {})
|
| 98 |
max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
|
| 99 |
+
tokenizer = Tokenizer(
|
| 100 |
+
TokenizerConfig(
|
| 101 |
+
pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
|
| 102 |
+
max_length=max_len,
|
| 103 |
+
)
|
| 104 |
+
)
|
| 105 |
|
| 106 |
summ_train = SummarizationDataset(summ_splits["train"])
|
| 107 |
emot_train = EmotionDataset(emot_splits["train"])
|
|
|
|
| 114 |
|
| 115 |
train_loaders = {
|
| 116 |
"summarization": build_summarization_dataloader(
|
| 117 |
+
summ_train,
|
| 118 |
+
tokenizer,
|
| 119 |
+
shuffle=True,
|
| 120 |
+
max_source_length=max_len,
|
| 121 |
+
max_target_length=max_len,
|
| 122 |
+
batch_size=batch_size,
|
| 123 |
+
num_workers=num_workers,
|
| 124 |
+
pin_memory=True,
|
| 125 |
),
|
| 126 |
"emotion": build_emotion_dataloader(
|
| 127 |
+
emot_train,
|
| 128 |
+
tokenizer,
|
| 129 |
+
shuffle=True,
|
| 130 |
+
max_length=classification_max_len,
|
| 131 |
+
batch_size=batch_size,
|
| 132 |
+
num_workers=num_workers,
|
| 133 |
+
pin_memory=True,
|
| 134 |
),
|
| 135 |
"topic": build_topic_dataloader(
|
| 136 |
+
topic_train,
|
| 137 |
+
tokenizer,
|
| 138 |
+
shuffle=True,
|
| 139 |
+
max_length=classification_max_len,
|
| 140 |
+
batch_size=batch_size,
|
| 141 |
+
num_workers=num_workers,
|
| 142 |
+
pin_memory=True,
|
| 143 |
),
|
| 144 |
}
|
| 145 |
|
| 146 |
# Build model
|
| 147 |
+
grad_ckpt = cfg.training.get(
|
| 148 |
+
"gradient_checkpointing", cfg.model.get("gradient_checkpointing", False)
|
| 149 |
+
)
|
| 150 |
+
use_rel_pos = cfg.training.get(
|
| 151 |
+
"use_relative_position_bias", cfg.model.get("use_relative_position_bias", False)
|
| 152 |
+
)
|
| 153 |
|
| 154 |
model_cfg = ModelConfig(
|
| 155 |
d_model=cfg.model.d_model,
|
|
|
|
| 223 |
except StopIteration:
|
| 224 |
iterators[task] = iter(train_loaders[task])
|
| 225 |
batch = next(iterators[task])
|
| 226 |
+
return {
|
| 227 |
+
k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
|
| 228 |
+
for k, v in batch.items()
|
| 229 |
+
}
|
| 230 |
|
| 231 |
def training_step(step):
|
| 232 |
"""One training step across all tasks."""
|
|
|
|
| 242 |
loss = torch.nn.functional.cross_entropy(
|
| 243 |
logits.view(-1, logits.size(-1)),
|
| 244 |
batch["labels"].view(-1),
|
| 245 |
+
ignore_index=-100,
|
| 246 |
+
label_smoothing=0.1,
|
| 247 |
)
|
| 248 |
elif task == "emotion":
|
| 249 |
inputs = {"input_ids": batch["input_ids"]}
|
|
|
|
| 286 |
torch.profiler.ProfilerActivity.CUDA,
|
| 287 |
],
|
| 288 |
schedule=torch.profiler.schedule(
|
| 289 |
+
wait=1,
|
| 290 |
+
warmup=2,
|
| 291 |
+
active=active_steps - 3,
|
| 292 |
+
repeat=1,
|
| 293 |
),
|
| 294 |
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
|
| 295 |
record_shapes=True,
|
scripts/train.py
CHANGED
|
@@ -56,6 +56,7 @@ def set_seed(seed: int) -> None:
|
|
| 56 |
import random
|
| 57 |
|
| 58 |
import numpy as np
|
|
|
|
| 59 |
random.seed(seed)
|
| 60 |
np.random.seed(seed)
|
| 61 |
torch.manual_seed(seed)
|
|
@@ -78,20 +79,20 @@ def load_splits(data_dir: Path, loader_fn) -> Dict[str, list]:
|
|
| 78 |
def main(cfg: DictConfig) -> None:
|
| 79 |
"""Main training entry point."""
|
| 80 |
start_time = time.perf_counter()
|
| 81 |
-
|
| 82 |
print("=" * 60)
|
| 83 |
print("LexiMind Training")
|
| 84 |
print("=" * 60)
|
| 85 |
print(OmegaConf.to_yaml(cfg))
|
| 86 |
-
|
| 87 |
set_seed(cfg.seed)
|
| 88 |
device = torch.device(cfg.device)
|
| 89 |
-
|
| 90 |
# GPU optimizations for Ampere+
|
| 91 |
if device.type == "cuda":
|
| 92 |
# Enable cudnn benchmark for fixed-size inputs (10-20% speedup)
|
| 93 |
torch.backends.cudnn.benchmark = True
|
| 94 |
-
|
| 95 |
if torch.cuda.get_device_capability()[0] >= 8:
|
| 96 |
torch.set_float32_matmul_precision("high")
|
| 97 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
@@ -99,18 +100,18 @@ def main(cfg: DictConfig) -> None:
|
|
| 99 |
print(" TF32 + cudnn.benchmark enabled (Ampere GPU)")
|
| 100 |
else:
|
| 101 |
print(" cudnn.benchmark enabled")
|
| 102 |
-
|
| 103 |
# --------------- Load Data ---------------
|
| 104 |
-
|
| 105 |
print("\nLoading datasets...")
|
| 106 |
data_cfg = cfg.data
|
| 107 |
trainer_cfg = cfg.training.get("trainer", {})
|
| 108 |
-
|
| 109 |
# Load splits
|
| 110 |
summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
|
| 111 |
emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
|
| 112 |
topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
|
| 113 |
-
|
| 114 |
# Apply sample limits for dev runs
|
| 115 |
max_train = trainer_cfg.get("max_train_samples")
|
| 116 |
max_val = trainer_cfg.get("max_val_samples")
|
|
@@ -121,86 +122,130 @@ def main(cfg: DictConfig) -> None:
|
|
| 121 |
for splits in [summ_splits, emot_splits, topic_splits]:
|
| 122 |
if "val" in splits:
|
| 123 |
splits["val"] = splits["val"][:max_val]
|
| 124 |
-
|
| 125 |
-
print(
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# --------------- Tokenizer ---------------
|
| 130 |
-
|
| 131 |
tok_cfg = data_cfg.get("tokenizer", {})
|
| 132 |
max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
|
| 133 |
-
|
| 134 |
-
tokenizer = Tokenizer(
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
print(f" Tokenizer: {tokenizer.vocab_size:,} vocab, max_len={max_len}")
|
| 139 |
-
|
| 140 |
# --------------- Datasets ---------------
|
| 141 |
-
|
| 142 |
summ_train = SummarizationDataset(summ_splits["train"])
|
| 143 |
summ_val = SummarizationDataset(summ_splits.get("val", []))
|
| 144 |
emot_train = EmotionDataset(emot_splits["train"])
|
| 145 |
emot_val = EmotionDataset(emot_splits.get("val", []), binarizer=emot_train.binarizer)
|
| 146 |
topic_train = TopicDataset(topic_splits["train"])
|
| 147 |
topic_val = TopicDataset(topic_splits.get("val", []), encoder=topic_train.encoder)
|
| 148 |
-
|
| 149 |
print(f" Emotions: {len(emot_train.emotion_classes)} classes")
|
| 150 |
-
print(
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
# --------------- DataLoaders ---------------
|
| 153 |
-
|
| 154 |
dl_cfg = cfg.training.get("dataloader", {})
|
| 155 |
batch_size = int(dl_cfg.get("batch_size", 8))
|
| 156 |
num_workers = int(dl_cfg.get("num_workers", 4))
|
| 157 |
-
|
| 158 |
# Classification tasks don't need full 512 tokens - 256 is sufficient
|
| 159 |
# This speeds up emotion/topic forward passes significantly
|
| 160 |
classification_max_len = min(256, max_len)
|
| 161 |
-
|
| 162 |
train_loaders = {
|
| 163 |
"summarization": build_summarization_dataloader(
|
| 164 |
-
summ_train,
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
),
|
| 168 |
"emotion": build_emotion_dataloader(
|
| 169 |
-
emot_train,
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
),
|
| 172 |
"topic": build_topic_dataloader(
|
| 173 |
-
topic_train,
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
),
|
| 176 |
}
|
| 177 |
-
|
| 178 |
val_loaders = {}
|
| 179 |
if summ_val:
|
| 180 |
val_loaders["summarization"] = build_summarization_dataloader(
|
| 181 |
-
summ_val,
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
)
|
| 185 |
if emot_val:
|
| 186 |
val_loaders["emotion"] = build_emotion_dataloader(
|
| 187 |
-
emot_val,
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
if topic_val:
|
| 191 |
val_loaders["topic"] = build_topic_dataloader(
|
| 192 |
-
topic_val,
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
)
|
| 195 |
-
|
| 196 |
# --------------- Model ---------------
|
| 197 |
-
|
| 198 |
print("\nBuilding model...")
|
| 199 |
-
|
| 200 |
# Check for overrides in training config
|
| 201 |
-
grad_ckpt = cfg.training.get(
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
model_cfg = ModelConfig(
|
| 205 |
d_model=cfg.model.d_model,
|
| 206 |
vocab_size=getattr(cfg.model, "vocab_size", None),
|
|
@@ -215,42 +260,42 @@ def main(cfg: DictConfig) -> None:
|
|
| 215 |
use_relative_position_bias=use_rel_pos,
|
| 216 |
gradient_checkpointing=grad_ckpt,
|
| 217 |
)
|
| 218 |
-
|
| 219 |
if grad_ckpt:
|
| 220 |
print(" Gradient checkpointing: on")
|
| 221 |
if not use_rel_pos:
|
| 222 |
print(" FlashAttention: on (no relative position bias)")
|
| 223 |
-
|
| 224 |
model = build_multitask_model(
|
| 225 |
tokenizer,
|
| 226 |
num_emotions=len(emot_train.emotion_classes),
|
| 227 |
num_topics=len(topic_train.topic_classes),
|
| 228 |
config=model_cfg,
|
| 229 |
).to(device)
|
| 230 |
-
|
| 231 |
param_count = sum(p.numel() for p in model.parameters())
|
| 232 |
-
print(f" Parameters: {param_count:,} ({param_count/1e6:.1f}M)")
|
| 233 |
-
|
| 234 |
# Freeze lower encoder layers (keeps pretrained language understanding, adapts upper layers)
|
| 235 |
freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
|
| 236 |
if freeze_layers > 0:
|
| 237 |
frozen_params = 0
|
| 238 |
# Freeze embedding layer
|
| 239 |
-
if hasattr(model.encoder,
|
| 240 |
for p in model.encoder.embed_tokens.parameters():
|
| 241 |
p.requires_grad = False
|
| 242 |
frozen_params += p.numel()
|
| 243 |
# Freeze specified number of encoder layers
|
| 244 |
-
if hasattr(model.encoder,
|
| 245 |
for i, layer in enumerate(model.encoder.layers):
|
| 246 |
if i < freeze_layers:
|
| 247 |
for p in layer.parameters():
|
| 248 |
p.requires_grad = False
|
| 249 |
frozen_params += p.numel()
|
| 250 |
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 251 |
-
print(f" Frozen layers: 0-{freeze_layers-1} ({frozen_params/1e6:.1f}M params)")
|
| 252 |
-
print(f" Trainable: {trainable:,} ({trainable/1e6:.1f}M)")
|
| 253 |
-
|
| 254 |
# Resume from checkpoint?
|
| 255 |
start_epoch = 1
|
| 256 |
resume_path = cfg.get("resume_from")
|
|
@@ -258,10 +303,11 @@ def main(cfg: DictConfig) -> None:
|
|
| 258 |
print(f" Resuming from: {resume_path}")
|
| 259 |
load_state(model, str(resume_path))
|
| 260 |
import re
|
|
|
|
| 261 |
digits = re.findall(r"\d+", Path(resume_path).stem)
|
| 262 |
if digits:
|
| 263 |
start_epoch = int(digits[-1]) + 1
|
| 264 |
-
|
| 265 |
# Compile model for speed
|
| 266 |
# Note: "reduce-overhead" mode uses CUDA graphs which conflicts with gradient checkpointing
|
| 267 |
# Use "default" mode when checkpointing is enabled
|
|
@@ -272,13 +318,13 @@ def main(cfg: DictConfig) -> None:
|
|
| 272 |
if cfg.training.get("compile_decoder", True):
|
| 273 |
model.decoder = torch.compile(model.decoder, mode=compile_mode) # type: ignore[assignment]
|
| 274 |
print(f" Decoder compiled ({compile_mode})")
|
| 275 |
-
|
| 276 |
# --------------- Train ---------------
|
| 277 |
-
|
| 278 |
print("\nStarting training...")
|
| 279 |
opt_cfg = cfg.training.get("optimizer", {})
|
| 280 |
sched_cfg = cfg.training.get("scheduler", {})
|
| 281 |
-
|
| 282 |
# Use fused AdamW on CUDA for ~5-10% speedup
|
| 283 |
use_fused = device.type == "cuda" and "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
|
| 284 |
optimizer = torch.optim.AdamW(
|
|
@@ -289,7 +335,7 @@ def main(cfg: DictConfig) -> None:
|
|
| 289 |
)
|
| 290 |
if use_fused:
|
| 291 |
print(" Fused AdamW: on")
|
| 292 |
-
|
| 293 |
trainer = Trainer(
|
| 294 |
model=model,
|
| 295 |
optimizer=optimizer,
|
|
@@ -309,38 +355,38 @@ def main(cfg: DictConfig) -> None:
|
|
| 309 |
device=device,
|
| 310 |
tokenizer=tokenizer,
|
| 311 |
)
|
| 312 |
-
|
| 313 |
# Checkpoint callback
|
| 314 |
ckpt_dir = Path(cfg.checkpoint_out).parent
|
| 315 |
-
best_val_loss = float(
|
| 316 |
-
|
| 317 |
def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
|
| 318 |
nonlocal best_val_loss
|
| 319 |
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 320 |
-
|
| 321 |
# Save epoch checkpoint
|
| 322 |
save_state(model, str(ckpt_dir / f"epoch_{epoch}.pt"))
|
| 323 |
-
|
| 324 |
# Track best
|
| 325 |
val_key = f"val_epoch_{epoch}"
|
| 326 |
if val_key in history:
|
| 327 |
-
val_loss = history[val_key].get("total_loss", float(
|
| 328 |
if val_loss < best_val_loss:
|
| 329 |
best_val_loss = val_loss
|
| 330 |
save_state(model, str(ckpt_dir / "best.pt"))
|
| 331 |
print(f" New best model saved (val_loss={val_loss:.4f})")
|
| 332 |
-
|
| 333 |
history = trainer.fit(
|
| 334 |
train_loaders,
|
| 335 |
val_loaders if val_loaders else None,
|
| 336 |
checkpoint_callback=save_checkpoint,
|
| 337 |
start_epoch=start_epoch,
|
| 338 |
)
|
| 339 |
-
|
| 340 |
# --------------- Save Outputs ---------------
|
| 341 |
-
|
| 342 |
print("\nSaving outputs...")
|
| 343 |
-
|
| 344 |
# Labels
|
| 345 |
labels_path = Path(cfg.labels_out)
|
| 346 |
save_label_metadata(
|
|
@@ -348,17 +394,17 @@ def main(cfg: DictConfig) -> None:
|
|
| 348 |
labels_path,
|
| 349 |
)
|
| 350 |
print(f" Labels: {labels_path}")
|
| 351 |
-
|
| 352 |
# History
|
| 353 |
history_path = Path(cfg.history_out)
|
| 354 |
history_path.parent.mkdir(parents=True, exist_ok=True)
|
| 355 |
with history_path.open("w") as f:
|
| 356 |
json.dump(history, f, indent=2)
|
| 357 |
print(f" History: {history_path}")
|
| 358 |
-
|
| 359 |
total_time = time.perf_counter() - start_time
|
| 360 |
print(f"\n{'=' * 60}")
|
| 361 |
-
print(f"Training complete in {total_time/60:.1f} minutes")
|
| 362 |
print(f" Best checkpoint: {ckpt_dir / 'best.pt'}")
|
| 363 |
print(f"{'=' * 60}")
|
| 364 |
|
|
|
|
| 56 |
import random
|
| 57 |
|
| 58 |
import numpy as np
|
| 59 |
+
|
| 60 |
random.seed(seed)
|
| 61 |
np.random.seed(seed)
|
| 62 |
torch.manual_seed(seed)
|
|
|
|
| 79 |
def main(cfg: DictConfig) -> None:
|
| 80 |
"""Main training entry point."""
|
| 81 |
start_time = time.perf_counter()
|
| 82 |
+
|
| 83 |
print("=" * 60)
|
| 84 |
print("LexiMind Training")
|
| 85 |
print("=" * 60)
|
| 86 |
print(OmegaConf.to_yaml(cfg))
|
| 87 |
+
|
| 88 |
set_seed(cfg.seed)
|
| 89 |
device = torch.device(cfg.device)
|
| 90 |
+
|
| 91 |
# GPU optimizations for Ampere+
|
| 92 |
if device.type == "cuda":
|
| 93 |
# Enable cudnn benchmark for fixed-size inputs (10-20% speedup)
|
| 94 |
torch.backends.cudnn.benchmark = True
|
| 95 |
+
|
| 96 |
if torch.cuda.get_device_capability()[0] >= 8:
|
| 97 |
torch.set_float32_matmul_precision("high")
|
| 98 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
| 100 |
print(" TF32 + cudnn.benchmark enabled (Ampere GPU)")
|
| 101 |
else:
|
| 102 |
print(" cudnn.benchmark enabled")
|
| 103 |
+
|
| 104 |
# --------------- Load Data ---------------
|
| 105 |
+
|
| 106 |
print("\nLoading datasets...")
|
| 107 |
data_cfg = cfg.data
|
| 108 |
trainer_cfg = cfg.training.get("trainer", {})
|
| 109 |
+
|
| 110 |
# Load splits
|
| 111 |
summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
|
| 112 |
emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
|
| 113 |
topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
|
| 114 |
+
|
| 115 |
# Apply sample limits for dev runs
|
| 116 |
max_train = trainer_cfg.get("max_train_samples")
|
| 117 |
max_val = trainer_cfg.get("max_val_samples")
|
|
|
|
| 122 |
for splits in [summ_splits, emot_splits, topic_splits]:
|
| 123 |
if "val" in splits:
|
| 124 |
splits["val"] = splits["val"][:max_val]
|
| 125 |
+
|
| 126 |
+
print(
|
| 127 |
+
f" Summarization: {len(summ_splits['train']):,} train, {len(summ_splits.get('val', [])):,} val"
|
| 128 |
+
)
|
| 129 |
+
print(
|
| 130 |
+
f" Emotion: {len(emot_splits['train']):,} train, {len(emot_splits.get('val', [])):,} val"
|
| 131 |
+
)
|
| 132 |
+
print(
|
| 133 |
+
f" Topic: {len(topic_splits['train']):,} train, {len(topic_splits.get('val', [])):,} val"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
# --------------- Tokenizer ---------------
|
| 137 |
+
|
| 138 |
tok_cfg = data_cfg.get("tokenizer", {})
|
| 139 |
max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
|
| 140 |
+
|
| 141 |
+
tokenizer = Tokenizer(
|
| 142 |
+
TokenizerConfig(
|
| 143 |
+
pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
|
| 144 |
+
max_length=max_len,
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
print(f" Tokenizer: {tokenizer.vocab_size:,} vocab, max_len={max_len}")
|
| 148 |
+
|
| 149 |
# --------------- Datasets ---------------
|
| 150 |
+
|
| 151 |
summ_train = SummarizationDataset(summ_splits["train"])
|
| 152 |
summ_val = SummarizationDataset(summ_splits.get("val", []))
|
| 153 |
emot_train = EmotionDataset(emot_splits["train"])
|
| 154 |
emot_val = EmotionDataset(emot_splits.get("val", []), binarizer=emot_train.binarizer)
|
| 155 |
topic_train = TopicDataset(topic_splits["train"])
|
| 156 |
topic_val = TopicDataset(topic_splits.get("val", []), encoder=topic_train.encoder)
|
| 157 |
+
|
| 158 |
print(f" Emotions: {len(emot_train.emotion_classes)} classes")
|
| 159 |
+
print(
|
| 160 |
+
f" Topics: {len(topic_train.topic_classes)} classes → {list(map(str, topic_train.topic_classes))}"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
# --------------- DataLoaders ---------------
|
| 164 |
+
|
| 165 |
dl_cfg = cfg.training.get("dataloader", {})
|
| 166 |
batch_size = int(dl_cfg.get("batch_size", 8))
|
| 167 |
num_workers = int(dl_cfg.get("num_workers", 4))
|
| 168 |
+
|
| 169 |
# Classification tasks don't need full 512 tokens - 256 is sufficient
|
| 170 |
# This speeds up emotion/topic forward passes significantly
|
| 171 |
classification_max_len = min(256, max_len)
|
| 172 |
+
|
| 173 |
train_loaders = {
|
| 174 |
"summarization": build_summarization_dataloader(
|
| 175 |
+
summ_train,
|
| 176 |
+
tokenizer,
|
| 177 |
+
shuffle=True,
|
| 178 |
+
max_source_length=max_len,
|
| 179 |
+
max_target_length=max_len,
|
| 180 |
+
batch_size=batch_size,
|
| 181 |
+
num_workers=num_workers,
|
| 182 |
+
pin_memory=True,
|
| 183 |
),
|
| 184 |
"emotion": build_emotion_dataloader(
|
| 185 |
+
emot_train,
|
| 186 |
+
tokenizer,
|
| 187 |
+
shuffle=True,
|
| 188 |
+
max_length=classification_max_len,
|
| 189 |
+
batch_size=batch_size,
|
| 190 |
+
num_workers=num_workers,
|
| 191 |
+
pin_memory=True,
|
| 192 |
),
|
| 193 |
"topic": build_topic_dataloader(
|
| 194 |
+
topic_train,
|
| 195 |
+
tokenizer,
|
| 196 |
+
shuffle=True,
|
| 197 |
+
max_length=classification_max_len,
|
| 198 |
+
batch_size=batch_size,
|
| 199 |
+
num_workers=num_workers,
|
| 200 |
+
pin_memory=True,
|
| 201 |
),
|
| 202 |
}
|
| 203 |
+
|
| 204 |
val_loaders = {}
|
| 205 |
if summ_val:
|
| 206 |
val_loaders["summarization"] = build_summarization_dataloader(
|
| 207 |
+
summ_val,
|
| 208 |
+
tokenizer,
|
| 209 |
+
shuffle=False,
|
| 210 |
+
max_source_length=max_len,
|
| 211 |
+
max_target_length=max_len,
|
| 212 |
+
batch_size=batch_size,
|
| 213 |
+
num_workers=num_workers,
|
| 214 |
+
pin_memory=True,
|
| 215 |
)
|
| 216 |
if emot_val:
|
| 217 |
val_loaders["emotion"] = build_emotion_dataloader(
|
| 218 |
+
emot_val,
|
| 219 |
+
tokenizer,
|
| 220 |
+
shuffle=False,
|
| 221 |
+
max_length=classification_max_len,
|
| 222 |
+
batch_size=batch_size,
|
| 223 |
+
num_workers=num_workers,
|
| 224 |
+
pin_memory=True,
|
| 225 |
)
|
| 226 |
if topic_val:
|
| 227 |
val_loaders["topic"] = build_topic_dataloader(
|
| 228 |
+
topic_val,
|
| 229 |
+
tokenizer,
|
| 230 |
+
shuffle=False,
|
| 231 |
+
max_length=classification_max_len,
|
| 232 |
+
batch_size=batch_size,
|
| 233 |
+
num_workers=num_workers,
|
| 234 |
+
pin_memory=True,
|
| 235 |
)
|
| 236 |
+
|
| 237 |
# --------------- Model ---------------
|
| 238 |
+
|
| 239 |
print("\nBuilding model...")
|
| 240 |
+
|
| 241 |
# Check for overrides in training config
|
| 242 |
+
grad_ckpt = cfg.training.get(
|
| 243 |
+
"gradient_checkpointing", cfg.model.get("gradient_checkpointing", False)
|
| 244 |
+
)
|
| 245 |
+
use_rel_pos = cfg.training.get(
|
| 246 |
+
"use_relative_position_bias", cfg.model.get("use_relative_position_bias", False)
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
model_cfg = ModelConfig(
|
| 250 |
d_model=cfg.model.d_model,
|
| 251 |
vocab_size=getattr(cfg.model, "vocab_size", None),
|
|
|
|
| 260 |
use_relative_position_bias=use_rel_pos,
|
| 261 |
gradient_checkpointing=grad_ckpt,
|
| 262 |
)
|
| 263 |
+
|
| 264 |
if grad_ckpt:
|
| 265 |
print(" Gradient checkpointing: on")
|
| 266 |
if not use_rel_pos:
|
| 267 |
print(" FlashAttention: on (no relative position bias)")
|
| 268 |
+
|
| 269 |
model = build_multitask_model(
|
| 270 |
tokenizer,
|
| 271 |
num_emotions=len(emot_train.emotion_classes),
|
| 272 |
num_topics=len(topic_train.topic_classes),
|
| 273 |
config=model_cfg,
|
| 274 |
).to(device)
|
| 275 |
+
|
| 276 |
param_count = sum(p.numel() for p in model.parameters())
|
| 277 |
+
print(f" Parameters: {param_count:,} ({param_count / 1e6:.1f}M)")
|
| 278 |
+
|
| 279 |
# Freeze lower encoder layers (keeps pretrained language understanding, adapts upper layers)
|
| 280 |
freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
|
| 281 |
if freeze_layers > 0:
|
| 282 |
frozen_params = 0
|
| 283 |
# Freeze embedding layer
|
| 284 |
+
if hasattr(model.encoder, "embed_tokens"):
|
| 285 |
for p in model.encoder.embed_tokens.parameters():
|
| 286 |
p.requires_grad = False
|
| 287 |
frozen_params += p.numel()
|
| 288 |
# Freeze specified number of encoder layers
|
| 289 |
+
if hasattr(model.encoder, "layers"):
|
| 290 |
for i, layer in enumerate(model.encoder.layers):
|
| 291 |
if i < freeze_layers:
|
| 292 |
for p in layer.parameters():
|
| 293 |
p.requires_grad = False
|
| 294 |
frozen_params += p.numel()
|
| 295 |
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 296 |
+
print(f" Frozen layers: 0-{freeze_layers - 1} ({frozen_params / 1e6:.1f}M params)")
|
| 297 |
+
print(f" Trainable: {trainable:,} ({trainable / 1e6:.1f}M)")
|
| 298 |
+
|
| 299 |
# Resume from checkpoint?
|
| 300 |
start_epoch = 1
|
| 301 |
resume_path = cfg.get("resume_from")
|
|
|
|
| 303 |
print(f" Resuming from: {resume_path}")
|
| 304 |
load_state(model, str(resume_path))
|
| 305 |
import re
|
| 306 |
+
|
| 307 |
digits = re.findall(r"\d+", Path(resume_path).stem)
|
| 308 |
if digits:
|
| 309 |
start_epoch = int(digits[-1]) + 1
|
| 310 |
+
|
| 311 |
# Compile model for speed
|
| 312 |
# Note: "reduce-overhead" mode uses CUDA graphs which conflicts with gradient checkpointing
|
| 313 |
# Use "default" mode when checkpointing is enabled
|
|
|
|
| 318 |
if cfg.training.get("compile_decoder", True):
|
| 319 |
model.decoder = torch.compile(model.decoder, mode=compile_mode) # type: ignore[assignment]
|
| 320 |
print(f" Decoder compiled ({compile_mode})")
|
| 321 |
+
|
| 322 |
# --------------- Train ---------------
|
| 323 |
+
|
| 324 |
print("\nStarting training...")
|
| 325 |
opt_cfg = cfg.training.get("optimizer", {})
|
| 326 |
sched_cfg = cfg.training.get("scheduler", {})
|
| 327 |
+
|
| 328 |
# Use fused AdamW on CUDA for ~5-10% speedup
|
| 329 |
use_fused = device.type == "cuda" and "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
|
| 330 |
optimizer = torch.optim.AdamW(
|
|
|
|
| 335 |
)
|
| 336 |
if use_fused:
|
| 337 |
print(" Fused AdamW: on")
|
| 338 |
+
|
| 339 |
trainer = Trainer(
|
| 340 |
model=model,
|
| 341 |
optimizer=optimizer,
|
|
|
|
| 355 |
device=device,
|
| 356 |
tokenizer=tokenizer,
|
| 357 |
)
|
| 358 |
+
|
| 359 |
# Checkpoint callback
|
| 360 |
ckpt_dir = Path(cfg.checkpoint_out).parent
|
| 361 |
+
best_val_loss = float("inf")
|
| 362 |
+
|
| 363 |
def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
|
| 364 |
nonlocal best_val_loss
|
| 365 |
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 366 |
+
|
| 367 |
# Save epoch checkpoint
|
| 368 |
save_state(model, str(ckpt_dir / f"epoch_{epoch}.pt"))
|
| 369 |
+
|
| 370 |
# Track best
|
| 371 |
val_key = f"val_epoch_{epoch}"
|
| 372 |
if val_key in history:
|
| 373 |
+
val_loss = history[val_key].get("total_loss", float("inf"))
|
| 374 |
if val_loss < best_val_loss:
|
| 375 |
best_val_loss = val_loss
|
| 376 |
save_state(model, str(ckpt_dir / "best.pt"))
|
| 377 |
print(f" New best model saved (val_loss={val_loss:.4f})")
|
| 378 |
+
|
| 379 |
history = trainer.fit(
|
| 380 |
train_loaders,
|
| 381 |
val_loaders if val_loaders else None,
|
| 382 |
checkpoint_callback=save_checkpoint,
|
| 383 |
start_epoch=start_epoch,
|
| 384 |
)
|
| 385 |
+
|
| 386 |
# --------------- Save Outputs ---------------
|
| 387 |
+
|
| 388 |
print("\nSaving outputs...")
|
| 389 |
+
|
| 390 |
# Labels
|
| 391 |
labels_path = Path(cfg.labels_out)
|
| 392 |
save_label_metadata(
|
|
|
|
| 394 |
labels_path,
|
| 395 |
)
|
| 396 |
print(f" Labels: {labels_path}")
|
| 397 |
+
|
| 398 |
# History
|
| 399 |
history_path = Path(cfg.history_out)
|
| 400 |
history_path.parent.mkdir(parents=True, exist_ok=True)
|
| 401 |
with history_path.open("w") as f:
|
| 402 |
json.dump(history, f, indent=2)
|
| 403 |
print(f" History: {history_path}")
|
| 404 |
+
|
| 405 |
total_time = time.perf_counter() - start_time
|
| 406 |
print(f"\n{'=' * 60}")
|
| 407 |
+
print(f"Training complete in {total_time / 60:.1f} minutes")
|
| 408 |
print(f" Best checkpoint: {ckpt_dir / 'best.pt'}")
|
| 409 |
print(f"{'=' * 60}")
|
| 410 |
|
scripts/train_multiseed.py
CHANGED
|
@@ -30,7 +30,8 @@ def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
|
|
| 30 |
seed_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
|
| 32 |
cmd = [
|
| 33 |
-
sys.executable,
|
|
|
|
| 34 |
f"seed={seed}",
|
| 35 |
f"checkpoint_out={seed_dir}/checkpoints/best.pt",
|
| 36 |
f"history_out={seed_dir}/training_history.json",
|
|
@@ -39,9 +40,9 @@ def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
|
|
| 39 |
if config_overrides:
|
| 40 |
cmd.extend(config_overrides.split())
|
| 41 |
|
| 42 |
-
print(f"\n{'='*60}")
|
| 43 |
print(f"Training seed {seed}")
|
| 44 |
-
print(f"{'='*60}")
|
| 45 |
print(f" Command: {' '.join(cmd)}")
|
| 46 |
|
| 47 |
result = subprocess.run(cmd, capture_output=False)
|
|
@@ -69,7 +70,8 @@ def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = Non
|
|
| 69 |
return {}
|
| 70 |
|
| 71 |
cmd = [
|
| 72 |
-
sys.executable,
|
|
|
|
| 73 |
f"--checkpoint={checkpoint}",
|
| 74 |
f"--labels={labels}",
|
| 75 |
f"--output={output}",
|
|
@@ -105,7 +107,11 @@ def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
|
|
| 105 |
if not isinstance(task_metrics, dict):
|
| 106 |
continue
|
| 107 |
for metric_name, value in task_metrics.items():
|
| 108 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
key = f"{task}/{metric_name}"
|
| 110 |
metric_values.setdefault(key, []).append(float(value))
|
| 111 |
|
|
@@ -125,9 +131,9 @@ def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
|
|
| 125 |
|
| 126 |
def print_summary(aggregated: Dict, seeds: List[int]) -> None:
|
| 127 |
"""Print human-readable summary of multi-seed results."""
|
| 128 |
-
print(f"\n{'='*70}")
|
| 129 |
print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
|
| 130 |
-
print(f"{'='*70}")
|
| 131 |
|
| 132 |
# Group by task
|
| 133 |
tasks: Dict[str, Dict[str, Dict]] = {}
|
|
@@ -142,23 +148,32 @@ def print_summary(aggregated: Dict, seeds: List[int]) -> None:
|
|
| 142 |
std = stats["std"]
|
| 143 |
# Format based on metric type
|
| 144 |
if "accuracy" in metric:
|
| 145 |
-
print(f" {metric:25s}: {mean*100:.1f}% ± {std*100:.1f}%")
|
| 146 |
else:
|
| 147 |
print(f" {metric:25s}: {mean:.4f} ± {std:.4f}")
|
| 148 |
|
| 149 |
|
| 150 |
def main():
|
| 151 |
parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
|
| 152 |
-
parser.add_argument(
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
parser.add_argument(
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
args = parser.parse_args()
|
| 163 |
|
| 164 |
args.output_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -184,11 +199,15 @@ def main():
|
|
| 184 |
# Save aggregated results
|
| 185 |
output_path = args.output_dir / "aggregated_results.json"
|
| 186 |
with open(output_path, "w") as f:
|
| 187 |
-
json.dump(
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
print(f"\n Saved to: {output_path}")
|
| 193 |
else:
|
| 194 |
print("\nNo evaluation results to aggregate.")
|
|
|
|
| 30 |
seed_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
|
| 32 |
cmd = [
|
| 33 |
+
sys.executable,
|
| 34 |
+
"scripts/train.py",
|
| 35 |
f"seed={seed}",
|
| 36 |
f"checkpoint_out={seed_dir}/checkpoints/best.pt",
|
| 37 |
f"history_out={seed_dir}/training_history.json",
|
|
|
|
| 40 |
if config_overrides:
|
| 41 |
cmd.extend(config_overrides.split())
|
| 42 |
|
| 43 |
+
print(f"\n{'=' * 60}")
|
| 44 |
print(f"Training seed {seed}")
|
| 45 |
+
print(f"{'=' * 60}")
|
| 46 |
print(f" Command: {' '.join(cmd)}")
|
| 47 |
|
| 48 |
result = subprocess.run(cmd, capture_output=False)
|
|
|
|
| 70 |
return {}
|
| 71 |
|
| 72 |
cmd = [
|
| 73 |
+
sys.executable,
|
| 74 |
+
"scripts/evaluate.py",
|
| 75 |
f"--checkpoint={checkpoint}",
|
| 76 |
f"--labels={labels}",
|
| 77 |
f"--output={output}",
|
|
|
|
| 107 |
if not isinstance(task_metrics, dict):
|
| 108 |
continue
|
| 109 |
for metric_name, value in task_metrics.items():
|
| 110 |
+
if (
|
| 111 |
+
isinstance(value, (int, float))
|
| 112 |
+
and metric_name != "num_samples"
|
| 113 |
+
and metric_name != "num_classes"
|
| 114 |
+
):
|
| 115 |
key = f"{task}/{metric_name}"
|
| 116 |
metric_values.setdefault(key, []).append(float(value))
|
| 117 |
|
|
|
|
| 131 |
|
| 132 |
def print_summary(aggregated: Dict, seeds: List[int]) -> None:
|
| 133 |
"""Print human-readable summary of multi-seed results."""
|
| 134 |
+
print(f"\n{'=' * 70}")
|
| 135 |
print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
|
| 136 |
+
print(f"{'=' * 70}")
|
| 137 |
|
| 138 |
# Group by task
|
| 139 |
tasks: Dict[str, Dict[str, Dict]] = {}
|
|
|
|
| 148 |
std = stats["std"]
|
| 149 |
# Format based on metric type
|
| 150 |
if "accuracy" in metric:
|
| 151 |
+
print(f" {metric:25s}: {mean * 100:.1f}% ± {std * 100:.1f}%")
|
| 152 |
else:
|
| 153 |
print(f" {metric:25s}: {mean:.4f} ± {std:.4f}")
|
| 154 |
|
| 155 |
|
| 156 |
def main():
|
| 157 |
parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--seeds", nargs="+", type=int, default=[17, 42, 123], help="Random seeds to train with"
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--config", type=str, default="", help="Hydra config overrides (e.g., 'training=full')"
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--output-dir", type=Path, default=Path("outputs/multiseed"), help="Base output directory"
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--skip-training",
|
| 169 |
+
action="store_true",
|
| 170 |
+
help="Skip training, only aggregate existing results",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--skip-eval",
|
| 174 |
+
action="store_true",
|
| 175 |
+
help="Skip evaluation, only aggregate training histories",
|
| 176 |
+
)
|
| 177 |
args = parser.parse_args()
|
| 178 |
|
| 179 |
args.output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 199 |
# Save aggregated results
|
| 200 |
output_path = args.output_dir / "aggregated_results.json"
|
| 201 |
with open(output_path, "w") as f:
|
| 202 |
+
json.dump(
|
| 203 |
+
{
|
| 204 |
+
"seeds": args.seeds,
|
| 205 |
+
"per_seed": {str(k): v for k, v in all_eval_results.items()},
|
| 206 |
+
"aggregated": aggregated,
|
| 207 |
+
},
|
| 208 |
+
f,
|
| 209 |
+
indent=2,
|
| 210 |
+
)
|
| 211 |
print(f"\n Saved to: {output_path}")
|
| 212 |
else:
|
| 213 |
print("\nNo evaluation results to aggregate.")
|
scripts/visualize_training.py
CHANGED
|
@@ -81,31 +81,33 @@ ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
|
|
| 81 |
|
| 82 |
# Professional color palette (accessible + publication-ready)
|
| 83 |
COLORS = {
|
| 84 |
-
"primary": "#2E86AB",
|
| 85 |
-
"secondary": "#E94F37",
|
| 86 |
-
"accent": "#28A745",
|
| 87 |
-
"highlight": "#F7B801",
|
| 88 |
-
"dark": "#1E3A5F",
|
| 89 |
-
"light": "#F5F5F5",
|
| 90 |
-
"topic": "#8338EC",
|
| 91 |
-
"emotion": "#FF6B6B",
|
| 92 |
-
"summary": "#06D6A0",
|
| 93 |
}
|
| 94 |
|
| 95 |
# Style configuration
|
| 96 |
plt.style.use("seaborn-v0_8-whitegrid")
|
| 97 |
-
plt.rcParams.update(
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# Custom colormap for heatmaps
|
| 111 |
HEATMAP_CMAP = LinearSegmentedColormap.from_list(
|
|
@@ -115,12 +117,14 @@ HEATMAP_CMAP = LinearSegmentedColormap.from_list(
|
|
| 115 |
|
| 116 |
# MLflow Utilities
|
| 117 |
|
|
|
|
| 118 |
def get_mlflow_client():
|
| 119 |
"""Get MLflow client with correct tracking URI."""
|
| 120 |
if not HAS_MLFLOW:
|
| 121 |
raise ImportError("MLflow not installed. Install with: pip install mlflow")
|
| 122 |
import mlflow
|
| 123 |
import mlflow.tracking
|
|
|
|
| 124 |
# Use SQLite database (same as trainer.py)
|
| 125 |
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 126 |
return mlflow.tracking.MlflowClient()
|
|
@@ -153,6 +157,7 @@ def get_metric_history(run, metric_name: str) -> tuple[list, list]:
|
|
| 153 |
|
| 154 |
# Core Training Visualizations
|
| 155 |
|
|
|
|
| 156 |
def plot_loss_curves(run, interactive: bool = False) -> None:
|
| 157 |
"""
|
| 158 |
Plot training and validation loss over time.
|
|
@@ -164,37 +169,49 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
|
|
| 164 |
|
| 165 |
if interactive and HAS_PLOTLY:
|
| 166 |
import plotly.graph_objects as go
|
|
|
|
| 167 |
fig = go.Figure()
|
| 168 |
|
| 169 |
if train_values:
|
| 170 |
-
fig.add_trace(
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
if val_values:
|
| 177 |
-
fig.add_trace(
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
# Best point
|
| 184 |
best_idx = int(np.argmin(val_values))
|
| 185 |
-
fig.add_trace(
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
fig.update_layout(
|
| 193 |
title="Training Progress: Multi-Task Loss",
|
| 194 |
xaxis_title="Epoch",
|
| 195 |
yaxis_title="Loss",
|
| 196 |
template="plotly_white",
|
| 197 |
-
hovermode="x unified"
|
| 198 |
)
|
| 199 |
|
| 200 |
output_path = OUTPUTS_DIR / "training_loss_curve.html"
|
|
@@ -206,32 +223,62 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
|
|
| 206 |
fig, ax = plt.subplots(figsize=(12, 6))
|
| 207 |
|
| 208 |
if not train_values:
|
| 209 |
-
ax.text(
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
ax.set_xlim(0, 1)
|
| 212 |
ax.set_ylim(0, 1)
|
| 213 |
else:
|
| 214 |
# Training curve
|
| 215 |
-
ax.plot(
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
# Validation curve with best point
|
| 219 |
if val_values:
|
| 220 |
-
ax.plot(
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
best_idx = int(np.argmin(val_values))
|
| 224 |
-
ax.scatter(
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
# Annotate best point
|
| 230 |
-
ax.annotate(
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
ax.legend(fontsize=11, loc="upper right", framealpha=0.9)
|
| 237 |
ax.set_ylim(bottom=0)
|
|
@@ -265,11 +312,22 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
|
|
| 265 |
val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
|
| 266 |
|
| 267 |
if train_sum:
|
| 268 |
-
ax.plot(
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
if val_sum:
|
| 271 |
-
ax.plot(
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
ax.set_title("Summarization Loss")
|
| 275 |
ax.set_xlabel("Epoch")
|
|
@@ -286,20 +344,43 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
|
|
| 286 |
val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
|
| 287 |
|
| 288 |
if train_emo:
|
| 289 |
-
ax.plot(
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
if val_emo:
|
| 292 |
-
ax.plot(
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
# Secondary axis for F1
|
| 296 |
ax2 = ax.twinx()
|
| 297 |
if train_f1:
|
| 298 |
-
ax2.plot(
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
if val_f1:
|
| 301 |
-
ax2.plot(
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
ax2.set_ylim(0, 1)
|
| 304 |
|
| 305 |
ax.set_title("Emotion Detection (28 classes)")
|
|
@@ -320,19 +401,42 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
|
|
| 320 |
val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
|
| 321 |
|
| 322 |
if train_topic:
|
| 323 |
-
ax.plot(
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
if val_topic:
|
| 326 |
-
ax.plot(
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
ax2 = ax.twinx()
|
| 330 |
if train_acc:
|
| 331 |
-
ax2.plot(
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
if val_acc:
|
| 334 |
-
ax2.plot(
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
ax2.set_ylim(0, 1)
|
| 337 |
|
| 338 |
ax.set_title("Topic Classification (4 classes)")
|
|
@@ -350,9 +454,11 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
|
|
| 350 |
ax.axis("off")
|
| 351 |
|
| 352 |
# Get final metrics
|
| 353 |
-
summary_lines = [
|
| 354 |
-
|
| 355 |
-
|
|
|
|
|
|
|
| 356 |
|
| 357 |
if val_topic and val_acc:
|
| 358 |
summary_lines.append(f"| Topic Accuracy: {val_acc[-1].value:>6.1%} |")
|
|
@@ -363,8 +469,15 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
|
|
| 363 |
|
| 364 |
summary_lines.append("+--------------------------------------+")
|
| 365 |
|
| 366 |
-
ax.text(
|
| 367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
# Add model info
|
| 370 |
run_params = run.data.params
|
|
@@ -372,8 +485,7 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
|
|
| 372 |
model_info += f"Batch Size: {run_params.get('batch_size', 'N/A')}\n"
|
| 373 |
model_info += f"Learning Rate: {run_params.get('learning_rate', 'N/A')}"
|
| 374 |
|
| 375 |
-
ax.text(0.1, 0.15, model_info, fontsize=10, color="gray",
|
| 376 |
-
verticalalignment="center")
|
| 377 |
|
| 378 |
plt.tight_layout()
|
| 379 |
output_path = OUTPUTS_DIR / "task_metrics.png"
|
|
@@ -392,13 +504,13 @@ def plot_learning_rate(run) -> None:
|
|
| 392 |
if not lr_metrics or len(lr_metrics) < 2:
|
| 393 |
# No LR data logged - generate theoretical schedule from config
|
| 394 |
logger.info(" No LR metrics found - generating theoretical schedule...")
|
| 395 |
-
|
| 396 |
# Get config from run params
|
| 397 |
params = run.data.params
|
| 398 |
lr_max = float(params.get("learning_rate", params.get("lr", 5e-5)))
|
| 399 |
warmup_steps = int(params.get("warmup_steps", 500))
|
| 400 |
max_epochs = int(params.get("max_epochs", 5))
|
| 401 |
-
|
| 402 |
# Estimate total steps from training loss history
|
| 403 |
train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
|
| 404 |
if train_loss:
|
|
@@ -407,7 +519,7 @@ def plot_learning_rate(run) -> None:
|
|
| 407 |
total_steps = max_epochs * estimated_steps_per_epoch
|
| 408 |
else:
|
| 409 |
total_steps = 4000 # Default fallback
|
| 410 |
-
|
| 411 |
# Generate cosine schedule with warmup
|
| 412 |
steps = np.arange(0, total_steps)
|
| 413 |
values = []
|
|
@@ -418,25 +530,43 @@ def plot_learning_rate(run) -> None:
|
|
| 418 |
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
| 419 |
lr = lr_max * max(0.1, 0.5 * (1 + np.cos(np.pi * progress)))
|
| 420 |
values.append(lr)
|
| 421 |
-
|
| 422 |
ax.fill_between(steps, values, alpha=0.3, color=COLORS["primary"])
|
| 423 |
ax.plot(steps, values, linewidth=2.5, color=COLORS["primary"], label="Cosine + Warmup")
|
| 424 |
-
|
| 425 |
# Mark warmup region
|
| 426 |
-
ax.axvline(
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"])
|
| 429 |
-
|
| 430 |
# Add annotation
|
| 431 |
-
ax.annotate(
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
ax.legend(loc="upper right")
|
| 437 |
-
ax.text(
|
| 438 |
-
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
else:
|
| 441 |
steps = np.array([m.step for m in lr_metrics])
|
| 442 |
values = [m.value for m in lr_metrics]
|
|
@@ -449,10 +579,15 @@ def plot_learning_rate(run) -> None:
|
|
| 449 |
params = run.data.params
|
| 450 |
warmup_steps = int(params.get("warmup_steps", 500))
|
| 451 |
if warmup_steps < max(steps):
|
| 452 |
-
ax.axvline(
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
ax.legend(loc="upper right")
|
| 457 |
|
| 458 |
# Scientific notation for y-axis if needed
|
|
@@ -471,6 +606,7 @@ def plot_learning_rate(run) -> None:
|
|
| 471 |
|
| 472 |
# Advanced Visualizations
|
| 473 |
|
|
|
|
| 474 |
def plot_confusion_matrix(run, task: str = "topic") -> None:
|
| 475 |
"""
|
| 476 |
Plot confusion matrix for classification tasks.
|
|
@@ -482,8 +618,16 @@ def plot_confusion_matrix(run, task: str = "topic") -> None:
|
|
| 482 |
if task == "topic":
|
| 483 |
default_labels = ["World", "Sports", "Business", "Sci/Tech"]
|
| 484 |
else: # emotion - top 8 for visibility
|
| 485 |
-
default_labels = [
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
if labels_path.exists():
|
| 489 |
with open(labels_path) as f:
|
|
@@ -516,9 +660,16 @@ def plot_confusion_matrix(run, task: str = "topic") -> None:
|
|
| 516 |
# Plot
|
| 517 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 518 |
|
| 519 |
-
sns.heatmap(
|
| 520 |
-
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
ax.set_title(f"Confusion Matrix: {task.title()} Classification")
|
| 524 |
ax.set_xlabel("Predicted Label")
|
|
@@ -570,7 +721,7 @@ def plot_3d_loss_landscape(run) -> None:
|
|
| 570 |
|
| 571 |
# Synthetic loss surface (bowl shape with some local minima)
|
| 572 |
min_loss = min(val_loss) if val_loss else min(train_loss)
|
| 573 |
-
Z = min_loss + 0.3 * (X**2 + Y**2) + 0.1 * np.sin(3*X) * np.cos(3*Y)
|
| 574 |
|
| 575 |
# Add noise for realism
|
| 576 |
Z += np.random.normal(0, 0.02, Z.shape)
|
|
@@ -584,41 +735,57 @@ def plot_3d_loss_landscape(run) -> None:
|
|
| 584 |
fig = go.Figure()
|
| 585 |
|
| 586 |
# Loss surface
|
| 587 |
-
fig.add_trace(
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
|
| 595 |
# Training trajectory
|
| 596 |
-
fig.add_trace(
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
# Mark start and end
|
| 605 |
-
fig.add_trace(
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
|
| 623 |
fig.update_layout(
|
| 624 |
title="Loss Landscape & Optimization Trajectory",
|
|
@@ -626,7 +793,7 @@ def plot_3d_loss_landscape(run) -> None:
|
|
| 626 |
xaxis_title="Parameter Direction 1",
|
| 627 |
yaxis_title="Parameter Direction 2",
|
| 628 |
zaxis_title="Loss",
|
| 629 |
-
camera=dict(eye=dict(x=1.5, y=1.5, z=0.8))
|
| 630 |
),
|
| 631 |
width=900,
|
| 632 |
height=700,
|
|
@@ -658,26 +825,46 @@ def plot_3d_loss_landscape_static(run) -> None:
|
|
| 658 |
X, Y = np.meshgrid(x, y)
|
| 659 |
|
| 660 |
min_loss = min(train_loss)
|
| 661 |
-
Z = min_loss + 0.3 * (X**2 + Y**2) + 0.08 * np.sin(3*X) * np.cos(3*Y)
|
| 662 |
|
| 663 |
fig = plt.figure(figsize=(12, 8))
|
| 664 |
ax = fig.add_subplot(111, projection="3d")
|
| 665 |
|
| 666 |
# Surface
|
| 667 |
-
surf = ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.7,
|
| 668 |
-
linewidth=0, antialiased=True)
|
| 669 |
|
| 670 |
# Training path
|
| 671 |
path_x = np.linspace(-1.5, 0, len(train_loss))
|
| 672 |
path_y = np.linspace(1.2, 0, len(train_loss))
|
| 673 |
-
ax.plot(
|
| 674 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
|
| 676 |
# Start/end markers
|
| 677 |
-
ax.scatter(
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
|
| 682 |
ax.set_xlabel("θ₁ Direction")
|
| 683 |
ax.set_ylabel("θ₂ Direction")
|
|
@@ -722,7 +909,7 @@ def plot_embedding_space(run) -> None:
|
|
| 722 |
for i in range(n_clusters):
|
| 723 |
# Create cluster center
|
| 724 |
center = np.random.randn(64) * 0.5
|
| 725 |
-
center[i*16:(i+1)*16] += 3 # Make clusters separable
|
| 726 |
|
| 727 |
# Add samples around center
|
| 728 |
samples = center + np.random.randn(n_samples // n_clusters, 64) * 0.5
|
|
@@ -742,8 +929,14 @@ def plot_embedding_space(run) -> None:
|
|
| 742 |
|
| 743 |
for i in range(n_clusters):
|
| 744 |
mask = cluster_labels == i
|
| 745 |
-
ax.scatter(
|
| 746 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
|
| 748 |
ax.set_xlabel("t-SNE Dimension 1")
|
| 749 |
ax.set_ylabel("t-SNE Dimension 2")
|
|
@@ -787,14 +980,18 @@ def plot_training_dynamics(run) -> None:
|
|
| 787 |
# Smoothed loss (exponential moving average)
|
| 788 |
if len(train_loss) > 5:
|
| 789 |
window = min(5, len(train_loss) // 2)
|
| 790 |
-
smoothed = np.convolve(train_loss, np.ones(window)/window, mode="valid")
|
| 791 |
-
smoothed_steps = train_steps[window-1:]
|
| 792 |
-
ax.plot(
|
| 793 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
|
| 795 |
if val_loss:
|
| 796 |
-
ax.plot(val_steps, val_loss, color=COLORS["secondary"],
|
| 797 |
-
linewidth=2.5, label="Validation")
|
| 798 |
|
| 799 |
ax.set_title("Loss Convergence")
|
| 800 |
ax.set_xlabel("Epoch")
|
|
@@ -806,8 +1003,10 @@ def plot_training_dynamics(run) -> None:
|
|
| 806 |
ax = axes[0, 1]
|
| 807 |
|
| 808 |
if len(train_loss) > 1:
|
| 809 |
-
improvements = [
|
| 810 |
-
|
|
|
|
|
|
|
| 811 |
colors_bar = [COLORS["accent"] if imp > 0 else COLORS["secondary"] for imp in improvements]
|
| 812 |
ax.bar(train_steps[1:], improvements, color=colors_bar, alpha=0.7)
|
| 813 |
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
|
|
@@ -862,6 +1061,7 @@ def plot_training_dynamics(run) -> None:
|
|
| 862 |
|
| 863 |
# Dashboard Generator
|
| 864 |
|
|
|
|
| 865 |
def generate_dashboard(run) -> None:
|
| 866 |
"""
|
| 867 |
Generate an interactive HTML dashboard with all visualizations.
|
|
@@ -883,63 +1083,73 @@ def generate_dashboard(run) -> None:
|
|
| 883 |
|
| 884 |
# Create subplots
|
| 885 |
fig = make_subplots(
|
| 886 |
-
rows=2,
|
|
|
|
| 887 |
subplot_titles=("Total Loss", "Task Losses", "Learning Rate", "Metrics"),
|
| 888 |
-
specs=[[{}, {}], [{}, {}]]
|
| 889 |
)
|
| 890 |
|
| 891 |
# Total loss
|
| 892 |
if train_loss:
|
| 893 |
fig.add_trace(
|
| 894 |
-
go.Scatter(
|
| 895 |
-
|
| 896 |
-
|
|
|
|
|
|
|
| 897 |
)
|
| 898 |
if val_loss:
|
| 899 |
fig.add_trace(
|
| 900 |
-
go.Scatter(
|
| 901 |
-
|
| 902 |
-
|
|
|
|
|
|
|
| 903 |
)
|
| 904 |
|
| 905 |
# Per-task losses
|
| 906 |
-
for task, color in [
|
| 907 |
-
|
| 908 |
-
|
|
|
|
|
|
|
| 909 |
steps, values = get_metric_history(run, f"val_{task}_loss")
|
| 910 |
if values:
|
| 911 |
fig.add_trace(
|
| 912 |
-
go.Scatter(x=steps, y=values, name=f"{task.title()} Loss",
|
| 913 |
-
|
| 914 |
-
|
| 915 |
)
|
| 916 |
|
| 917 |
# Learning rate
|
| 918 |
lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
|
| 919 |
if lr_metrics:
|
| 920 |
fig.add_trace(
|
| 921 |
-
go.Scatter(
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 925 |
)
|
| 926 |
|
| 927 |
# Accuracy metrics
|
| 928 |
-
for metric, color in [("topic_accuracy", COLORS["topic"]),
|
| 929 |
-
("emotion_f1", COLORS["emotion"])]:
|
| 930 |
steps, values = get_metric_history(run, f"val_{metric}")
|
| 931 |
if values:
|
| 932 |
fig.add_trace(
|
| 933 |
-
go.Scatter(
|
| 934 |
-
|
| 935 |
-
|
|
|
|
|
|
|
| 936 |
)
|
| 937 |
|
| 938 |
fig.update_layout(
|
| 939 |
-
title="LexiMind Training Dashboard",
|
| 940 |
-
height=800,
|
| 941 |
-
template="plotly_white",
|
| 942 |
-
showlegend=True
|
| 943 |
)
|
| 944 |
|
| 945 |
output_path = OUTPUTS_DIR / "training_dashboard.html"
|
|
@@ -949,17 +1159,20 @@ def generate_dashboard(run) -> None:
|
|
| 949 |
|
| 950 |
# Main Entry Point
|
| 951 |
|
|
|
|
| 952 |
def main():
|
| 953 |
"""Generate all training visualizations."""
|
| 954 |
parser = argparse.ArgumentParser(description="LexiMind Visualization Suite")
|
| 955 |
-
parser.add_argument(
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
|
|
|
|
|
|
| 963 |
args = parser.parse_args()
|
| 964 |
|
| 965 |
logger.info("=" * 60)
|
|
|
|
| 81 |
|
| 82 |
# Professional color palette (accessible + publication-ready)
|
| 83 |
COLORS = {
|
| 84 |
+
"primary": "#2E86AB", # Deep blue - training
|
| 85 |
+
"secondary": "#E94F37", # Coral red - validation
|
| 86 |
+
"accent": "#28A745", # Green - best points
|
| 87 |
+
"highlight": "#F7B801", # Gold - highlights
|
| 88 |
+
"dark": "#1E3A5F", # Navy - text
|
| 89 |
+
"light": "#F5F5F5", # Light gray - background
|
| 90 |
+
"topic": "#8338EC", # Purple
|
| 91 |
+
"emotion": "#FF6B6B", # Salmon
|
| 92 |
+
"summary": "#06D6A0", # Teal
|
| 93 |
}
|
| 94 |
|
| 95 |
# Style configuration
|
| 96 |
plt.style.use("seaborn-v0_8-whitegrid")
|
| 97 |
+
plt.rcParams.update(
|
| 98 |
+
{
|
| 99 |
+
"font.family": "sans-serif",
|
| 100 |
+
"font.size": 11,
|
| 101 |
+
"axes.titlesize": 14,
|
| 102 |
+
"axes.titleweight": "bold",
|
| 103 |
+
"axes.labelsize": 12,
|
| 104 |
+
"legend.fontsize": 10,
|
| 105 |
+
"figure.titlesize": 16,
|
| 106 |
+
"figure.titleweight": "bold",
|
| 107 |
+
"savefig.dpi": 150,
|
| 108 |
+
"savefig.bbox": "tight",
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
|
| 112 |
# Custom colormap for heatmaps
|
| 113 |
HEATMAP_CMAP = LinearSegmentedColormap.from_list(
|
|
|
|
| 117 |
|
| 118 |
# MLflow Utilities
|
| 119 |
|
| 120 |
+
|
| 121 |
def get_mlflow_client():
|
| 122 |
"""Get MLflow client with correct tracking URI."""
|
| 123 |
if not HAS_MLFLOW:
|
| 124 |
raise ImportError("MLflow not installed. Install with: pip install mlflow")
|
| 125 |
import mlflow
|
| 126 |
import mlflow.tracking
|
| 127 |
+
|
| 128 |
# Use SQLite database (same as trainer.py)
|
| 129 |
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 130 |
return mlflow.tracking.MlflowClient()
|
|
|
|
| 157 |
|
| 158 |
# Core Training Visualizations
|
| 159 |
|
| 160 |
+
|
| 161 |
def plot_loss_curves(run, interactive: bool = False) -> None:
|
| 162 |
"""
|
| 163 |
Plot training and validation loss over time.
|
|
|
|
| 169 |
|
| 170 |
if interactive and HAS_PLOTLY:
|
| 171 |
import plotly.graph_objects as go
|
| 172 |
+
|
| 173 |
fig = go.Figure()
|
| 174 |
|
| 175 |
if train_values:
|
| 176 |
+
fig.add_trace(
|
| 177 |
+
go.Scatter(
|
| 178 |
+
x=train_steps,
|
| 179 |
+
y=train_values,
|
| 180 |
+
name="Training Loss",
|
| 181 |
+
mode="lines",
|
| 182 |
+
line=dict(color=COLORS["primary"], width=3),
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
|
| 186 |
if val_values:
|
| 187 |
+
fig.add_trace(
|
| 188 |
+
go.Scatter(
|
| 189 |
+
x=val_steps,
|
| 190 |
+
y=val_values,
|
| 191 |
+
name="Validation Loss",
|
| 192 |
+
mode="lines",
|
| 193 |
+
line=dict(color=COLORS["secondary"], width=3),
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
|
| 197 |
# Best point
|
| 198 |
best_idx = int(np.argmin(val_values))
|
| 199 |
+
fig.add_trace(
|
| 200 |
+
go.Scatter(
|
| 201 |
+
x=[val_steps[best_idx]],
|
| 202 |
+
y=[val_values[best_idx]],
|
| 203 |
+
name=f"Best: {val_values[best_idx]:.3f}",
|
| 204 |
+
mode="markers",
|
| 205 |
+
marker=dict(color=COLORS["accent"], size=15, symbol="star"),
|
| 206 |
+
)
|
| 207 |
+
)
|
| 208 |
|
| 209 |
fig.update_layout(
|
| 210 |
title="Training Progress: Multi-Task Loss",
|
| 211 |
xaxis_title="Epoch",
|
| 212 |
yaxis_title="Loss",
|
| 213 |
template="plotly_white",
|
| 214 |
+
hovermode="x unified",
|
| 215 |
)
|
| 216 |
|
| 217 |
output_path = OUTPUTS_DIR / "training_loss_curve.html"
|
|
|
|
| 223 |
fig, ax = plt.subplots(figsize=(12, 6))
|
| 224 |
|
| 225 |
if not train_values:
|
| 226 |
+
ax.text(
|
| 227 |
+
0.5,
|
| 228 |
+
0.5,
|
| 229 |
+
"No training data yet\n\nWaiting for first epoch...",
|
| 230 |
+
ha="center",
|
| 231 |
+
va="center",
|
| 232 |
+
fontsize=14,
|
| 233 |
+
color="gray",
|
| 234 |
+
)
|
| 235 |
ax.set_xlim(0, 1)
|
| 236 |
ax.set_ylim(0, 1)
|
| 237 |
else:
|
| 238 |
# Training curve
|
| 239 |
+
ax.plot(
|
| 240 |
+
train_steps,
|
| 241 |
+
train_values,
|
| 242 |
+
label="Training Loss",
|
| 243 |
+
linewidth=2.5,
|
| 244 |
+
color=COLORS["primary"],
|
| 245 |
+
alpha=0.9,
|
| 246 |
+
)
|
| 247 |
|
| 248 |
# Validation curve with best point
|
| 249 |
if val_values:
|
| 250 |
+
ax.plot(
|
| 251 |
+
val_steps,
|
| 252 |
+
val_values,
|
| 253 |
+
label="Validation Loss",
|
| 254 |
+
linewidth=2.5,
|
| 255 |
+
color=COLORS["secondary"],
|
| 256 |
+
alpha=0.9,
|
| 257 |
+
)
|
| 258 |
|
| 259 |
best_idx = int(np.argmin(val_values))
|
| 260 |
+
ax.scatter(
|
| 261 |
+
[val_steps[best_idx]],
|
| 262 |
+
[val_values[best_idx]],
|
| 263 |
+
s=200,
|
| 264 |
+
c=COLORS["accent"],
|
| 265 |
+
zorder=5,
|
| 266 |
+
marker="*",
|
| 267 |
+
edgecolors="white",
|
| 268 |
+
linewidth=2,
|
| 269 |
+
label=f"Best: {val_values[best_idx]:.3f}",
|
| 270 |
+
)
|
| 271 |
|
| 272 |
# Annotate best point
|
| 273 |
+
ax.annotate(
|
| 274 |
+
f"Epoch {val_steps[best_idx]}",
|
| 275 |
+
xy=(val_steps[best_idx], val_values[best_idx]),
|
| 276 |
+
xytext=(10, 20),
|
| 277 |
+
textcoords="offset points",
|
| 278 |
+
fontsize=10,
|
| 279 |
+
color=COLORS["accent"],
|
| 280 |
+
arrowprops=dict(arrowstyle="->", color=COLORS["accent"]),
|
| 281 |
+
)
|
| 282 |
|
| 283 |
ax.legend(fontsize=11, loc="upper right", framealpha=0.9)
|
| 284 |
ax.set_ylim(bottom=0)
|
|
|
|
| 312 |
val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
|
| 313 |
|
| 314 |
if train_sum:
|
| 315 |
+
ax.plot(
|
| 316 |
+
[m.step for m in train_sum],
|
| 317 |
+
[m.value for m in train_sum],
|
| 318 |
+
label="Train",
|
| 319 |
+
linewidth=2.5,
|
| 320 |
+
color=COLORS["summary"],
|
| 321 |
+
)
|
| 322 |
if val_sum:
|
| 323 |
+
ax.plot(
|
| 324 |
+
[m.step for m in val_sum],
|
| 325 |
+
[m.value for m in val_sum],
|
| 326 |
+
label="Validation",
|
| 327 |
+
linewidth=2.5,
|
| 328 |
+
color=COLORS["secondary"],
|
| 329 |
+
linestyle="--",
|
| 330 |
+
)
|
| 331 |
|
| 332 |
ax.set_title("Summarization Loss")
|
| 333 |
ax.set_xlabel("Epoch")
|
|
|
|
| 344 |
val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
|
| 345 |
|
| 346 |
if train_emo:
|
| 347 |
+
ax.plot(
|
| 348 |
+
[m.step for m in train_emo],
|
| 349 |
+
[m.value for m in train_emo],
|
| 350 |
+
label="Train Loss",
|
| 351 |
+
linewidth=2.5,
|
| 352 |
+
color=COLORS["emotion"],
|
| 353 |
+
)
|
| 354 |
if val_emo:
|
| 355 |
+
ax.plot(
|
| 356 |
+
[m.step for m in val_emo],
|
| 357 |
+
[m.value for m in val_emo],
|
| 358 |
+
label="Val Loss",
|
| 359 |
+
linewidth=2.5,
|
| 360 |
+
color=COLORS["secondary"],
|
| 361 |
+
linestyle="--",
|
| 362 |
+
)
|
| 363 |
|
| 364 |
# Secondary axis for F1
|
| 365 |
ax2 = ax.twinx()
|
| 366 |
if train_f1:
|
| 367 |
+
ax2.plot(
|
| 368 |
+
[m.step for m in train_f1],
|
| 369 |
+
[m.value for m in train_f1],
|
| 370 |
+
label="Train F1",
|
| 371 |
+
linewidth=2,
|
| 372 |
+
color=COLORS["accent"],
|
| 373 |
+
alpha=0.7,
|
| 374 |
+
)
|
| 375 |
if val_f1:
|
| 376 |
+
ax2.plot(
|
| 377 |
+
[m.step for m in val_f1],
|
| 378 |
+
[m.value for m in val_f1],
|
| 379 |
+
label="Val F1",
|
| 380 |
+
linewidth=2,
|
| 381 |
+
color=COLORS["highlight"],
|
| 382 |
+
alpha=0.7,
|
| 383 |
+
)
|
| 384 |
ax2.set_ylim(0, 1)
|
| 385 |
|
| 386 |
ax.set_title("Emotion Detection (28 classes)")
|
|
|
|
| 401 |
val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
|
| 402 |
|
| 403 |
if train_topic:
|
| 404 |
+
ax.plot(
|
| 405 |
+
[m.step for m in train_topic],
|
| 406 |
+
[m.value for m in train_topic],
|
| 407 |
+
label="Train Loss",
|
| 408 |
+
linewidth=2.5,
|
| 409 |
+
color=COLORS["topic"],
|
| 410 |
+
)
|
| 411 |
if val_topic:
|
| 412 |
+
ax.plot(
|
| 413 |
+
[m.step for m in val_topic],
|
| 414 |
+
[m.value for m in val_topic],
|
| 415 |
+
label="Val Loss",
|
| 416 |
+
linewidth=2.5,
|
| 417 |
+
color=COLORS["secondary"],
|
| 418 |
+
linestyle="--",
|
| 419 |
+
)
|
| 420 |
|
| 421 |
ax2 = ax.twinx()
|
| 422 |
if train_acc:
|
| 423 |
+
ax2.plot(
|
| 424 |
+
[m.step for m in train_acc],
|
| 425 |
+
[m.value for m in train_acc],
|
| 426 |
+
label="Train Acc",
|
| 427 |
+
linewidth=2,
|
| 428 |
+
color=COLORS["accent"],
|
| 429 |
+
alpha=0.7,
|
| 430 |
+
)
|
| 431 |
if val_acc:
|
| 432 |
+
ax2.plot(
|
| 433 |
+
[m.step for m in val_acc],
|
| 434 |
+
[m.value for m in val_acc],
|
| 435 |
+
label="Val Acc",
|
| 436 |
+
linewidth=2,
|
| 437 |
+
color=COLORS["highlight"],
|
| 438 |
+
alpha=0.7,
|
| 439 |
+
)
|
| 440 |
ax2.set_ylim(0, 1)
|
| 441 |
|
| 442 |
ax.set_title("Topic Classification (4 classes)")
|
|
|
|
| 454 |
ax.axis("off")
|
| 455 |
|
| 456 |
# Get final metrics
|
| 457 |
+
summary_lines = [
|
| 458 |
+
"+--------------------------------------+",
|
| 459 |
+
"| FINAL METRICS (Last Epoch) |",
|
| 460 |
+
"+--------------------------------------+",
|
| 461 |
+
]
|
| 462 |
|
| 463 |
if val_topic and val_acc:
|
| 464 |
summary_lines.append(f"| Topic Accuracy: {val_acc[-1].value:>6.1%} |")
|
|
|
|
| 469 |
|
| 470 |
summary_lines.append("+--------------------------------------+")
|
| 471 |
|
| 472 |
+
ax.text(
|
| 473 |
+
0.1,
|
| 474 |
+
0.6,
|
| 475 |
+
"\n".join(summary_lines),
|
| 476 |
+
fontsize=11,
|
| 477 |
+
family="monospace",
|
| 478 |
+
verticalalignment="center",
|
| 479 |
+
bbox=dict(boxstyle="round", facecolor=COLORS["light"]),
|
| 480 |
+
)
|
| 481 |
|
| 482 |
# Add model info
|
| 483 |
run_params = run.data.params
|
|
|
|
| 485 |
model_info += f"Batch Size: {run_params.get('batch_size', 'N/A')}\n"
|
| 486 |
model_info += f"Learning Rate: {run_params.get('learning_rate', 'N/A')}"
|
| 487 |
|
| 488 |
+
ax.text(0.1, 0.15, model_info, fontsize=10, color="gray", verticalalignment="center")
|
|
|
|
| 489 |
|
| 490 |
plt.tight_layout()
|
| 491 |
output_path = OUTPUTS_DIR / "task_metrics.png"
|
|
|
|
| 504 |
if not lr_metrics or len(lr_metrics) < 2:
|
| 505 |
# No LR data logged - generate theoretical schedule from config
|
| 506 |
logger.info(" No LR metrics found - generating theoretical schedule...")
|
| 507 |
+
|
| 508 |
# Get config from run params
|
| 509 |
params = run.data.params
|
| 510 |
lr_max = float(params.get("learning_rate", params.get("lr", 5e-5)))
|
| 511 |
warmup_steps = int(params.get("warmup_steps", 500))
|
| 512 |
max_epochs = int(params.get("max_epochs", 5))
|
| 513 |
+
|
| 514 |
# Estimate total steps from training loss history
|
| 515 |
train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
|
| 516 |
if train_loss:
|
|
|
|
| 519 |
total_steps = max_epochs * estimated_steps_per_epoch
|
| 520 |
else:
|
| 521 |
total_steps = 4000 # Default fallback
|
| 522 |
+
|
| 523 |
# Generate cosine schedule with warmup
|
| 524 |
steps = np.arange(0, total_steps)
|
| 525 |
values = []
|
|
|
|
| 530 |
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
| 531 |
lr = lr_max * max(0.1, 0.5 * (1 + np.cos(np.pi * progress)))
|
| 532 |
values.append(lr)
|
| 533 |
+
|
| 534 |
ax.fill_between(steps, values, alpha=0.3, color=COLORS["primary"])
|
| 535 |
ax.plot(steps, values, linewidth=2.5, color=COLORS["primary"], label="Cosine + Warmup")
|
| 536 |
+
|
| 537 |
# Mark warmup region
|
| 538 |
+
ax.axvline(
|
| 539 |
+
warmup_steps,
|
| 540 |
+
color=COLORS["secondary"],
|
| 541 |
+
linestyle="--",
|
| 542 |
+
alpha=0.7,
|
| 543 |
+
linewidth=2,
|
| 544 |
+
label=f"Warmup End ({warmup_steps})",
|
| 545 |
+
)
|
| 546 |
ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"])
|
| 547 |
+
|
| 548 |
# Add annotation
|
| 549 |
+
ax.annotate(
|
| 550 |
+
f"Peak LR: {lr_max:.1e}",
|
| 551 |
+
xy=(warmup_steps, lr_max),
|
| 552 |
+
xytext=(warmup_steps + 200, lr_max * 0.9),
|
| 553 |
+
fontsize=10,
|
| 554 |
+
color=COLORS["dark"],
|
| 555 |
+
arrowprops=dict(arrowstyle="->", color=COLORS["dark"], alpha=0.5),
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
ax.legend(loc="upper right")
|
| 559 |
+
ax.text(
|
| 560 |
+
0.98,
|
| 561 |
+
0.02,
|
| 562 |
+
"(Theoretical - actual LR not logged)",
|
| 563 |
+
transform=ax.transAxes,
|
| 564 |
+
ha="right",
|
| 565 |
+
va="bottom",
|
| 566 |
+
fontsize=9,
|
| 567 |
+
color="gray",
|
| 568 |
+
style="italic",
|
| 569 |
+
)
|
| 570 |
else:
|
| 571 |
steps = np.array([m.step for m in lr_metrics])
|
| 572 |
values = [m.value for m in lr_metrics]
|
|
|
|
| 579 |
params = run.data.params
|
| 580 |
warmup_steps = int(params.get("warmup_steps", 500))
|
| 581 |
if warmup_steps < max(steps):
|
| 582 |
+
ax.axvline(
|
| 583 |
+
warmup_steps,
|
| 584 |
+
color=COLORS["secondary"],
|
| 585 |
+
linestyle="--",
|
| 586 |
+
alpha=0.7,
|
| 587 |
+
linewidth=2,
|
| 588 |
+
label="Warmup End",
|
| 589 |
+
)
|
| 590 |
+
ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"], label="Warmup Phase")
|
| 591 |
ax.legend(loc="upper right")
|
| 592 |
|
| 593 |
# Scientific notation for y-axis if needed
|
|
|
|
| 606 |
|
| 607 |
# Advanced Visualizations
|
| 608 |
|
| 609 |
+
|
| 610 |
def plot_confusion_matrix(run, task: str = "topic") -> None:
|
| 611 |
"""
|
| 612 |
Plot confusion matrix for classification tasks.
|
|
|
|
| 618 |
if task == "topic":
|
| 619 |
default_labels = ["World", "Sports", "Business", "Sci/Tech"]
|
| 620 |
else: # emotion - top 8 for visibility
|
| 621 |
+
default_labels = [
|
| 622 |
+
"admiration",
|
| 623 |
+
"amusement",
|
| 624 |
+
"anger",
|
| 625 |
+
"annoyance",
|
| 626 |
+
"approval",
|
| 627 |
+
"caring",
|
| 628 |
+
"curiosity",
|
| 629 |
+
"desire",
|
| 630 |
+
]
|
| 631 |
|
| 632 |
if labels_path.exists():
|
| 633 |
with open(labels_path) as f:
|
|
|
|
| 660 |
# Plot
|
| 661 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 662 |
|
| 663 |
+
sns.heatmap(
|
| 664 |
+
cm_normalized,
|
| 665 |
+
annot=True,
|
| 666 |
+
fmt=".2f",
|
| 667 |
+
cmap=HEATMAP_CMAP,
|
| 668 |
+
xticklabels=labels[:n_classes],
|
| 669 |
+
yticklabels=labels[:n_classes],
|
| 670 |
+
ax=ax,
|
| 671 |
+
cbar_kws={"label": "Proportion"},
|
| 672 |
+
)
|
| 673 |
|
| 674 |
ax.set_title(f"Confusion Matrix: {task.title()} Classification")
|
| 675 |
ax.set_xlabel("Predicted Label")
|
|
|
|
| 721 |
|
| 722 |
# Synthetic loss surface (bowl shape with some local minima)
|
| 723 |
min_loss = min(val_loss) if val_loss else min(train_loss)
|
| 724 |
+
Z = min_loss + 0.3 * (X**2 + Y**2) + 0.1 * np.sin(3 * X) * np.cos(3 * Y)
|
| 725 |
|
| 726 |
# Add noise for realism
|
| 727 |
Z += np.random.normal(0, 0.02, Z.shape)
|
|
|
|
| 735 |
fig = go.Figure()
|
| 736 |
|
| 737 |
# Loss surface
|
| 738 |
+
fig.add_trace(
|
| 739 |
+
go.Surface(
|
| 740 |
+
x=X,
|
| 741 |
+
y=Y,
|
| 742 |
+
z=Z,
|
| 743 |
+
colorscale=[[0, COLORS["accent"]], [0.5, COLORS["primary"]], [1, COLORS["secondary"]]],
|
| 744 |
+
opacity=0.8,
|
| 745 |
+
showscale=True,
|
| 746 |
+
colorbar=dict(title="Loss", x=1.02),
|
| 747 |
+
)
|
| 748 |
+
)
|
| 749 |
|
| 750 |
# Training trajectory
|
| 751 |
+
fig.add_trace(
|
| 752 |
+
go.Scatter3d(
|
| 753 |
+
x=trajectory_x,
|
| 754 |
+
y=trajectory_y,
|
| 755 |
+
z=trajectory_z,
|
| 756 |
+
mode="lines+markers",
|
| 757 |
+
line=dict(color=COLORS["highlight"], width=5),
|
| 758 |
+
marker=dict(size=4, color=COLORS["highlight"]),
|
| 759 |
+
name="Training Path",
|
| 760 |
+
)
|
| 761 |
+
)
|
| 762 |
|
| 763 |
# Mark start and end
|
| 764 |
+
fig.add_trace(
|
| 765 |
+
go.Scatter3d(
|
| 766 |
+
x=[trajectory_x[0]],
|
| 767 |
+
y=[trajectory_y[0]],
|
| 768 |
+
z=[trajectory_z[0]],
|
| 769 |
+
mode="markers+text",
|
| 770 |
+
marker=dict(size=10, color="red", symbol="circle"),
|
| 771 |
+
text=["Start"],
|
| 772 |
+
textposition="top center",
|
| 773 |
+
name="Start",
|
| 774 |
+
)
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
fig.add_trace(
|
| 778 |
+
go.Scatter3d(
|
| 779 |
+
x=[trajectory_x[-1]],
|
| 780 |
+
y=[trajectory_y[-1]],
|
| 781 |
+
z=[trajectory_z[-1]],
|
| 782 |
+
mode="markers+text",
|
| 783 |
+
marker=dict(size=10, color="green", symbol="diamond"),
|
| 784 |
+
text=["Converged"],
|
| 785 |
+
textposition="top center",
|
| 786 |
+
name="Converged",
|
| 787 |
+
)
|
| 788 |
+
)
|
| 789 |
|
| 790 |
fig.update_layout(
|
| 791 |
title="Loss Landscape & Optimization Trajectory",
|
|
|
|
| 793 |
xaxis_title="Parameter Direction 1",
|
| 794 |
yaxis_title="Parameter Direction 2",
|
| 795 |
zaxis_title="Loss",
|
| 796 |
+
camera=dict(eye=dict(x=1.5, y=1.5, z=0.8)),
|
| 797 |
),
|
| 798 |
width=900,
|
| 799 |
height=700,
|
|
|
|
| 825 |
X, Y = np.meshgrid(x, y)
|
| 826 |
|
| 827 |
min_loss = min(train_loss)
|
| 828 |
+
Z = min_loss + 0.3 * (X**2 + Y**2) + 0.08 * np.sin(3 * X) * np.cos(3 * Y)
|
| 829 |
|
| 830 |
fig = plt.figure(figsize=(12, 8))
|
| 831 |
ax = fig.add_subplot(111, projection="3d")
|
| 832 |
|
| 833 |
# Surface
|
| 834 |
+
surf = ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.7, linewidth=0, antialiased=True)
|
|
|
|
| 835 |
|
| 836 |
# Training path
|
| 837 |
path_x = np.linspace(-1.5, 0, len(train_loss))
|
| 838 |
path_y = np.linspace(1.2, 0, len(train_loss))
|
| 839 |
+
ax.plot(
|
| 840 |
+
path_x,
|
| 841 |
+
path_y,
|
| 842 |
+
train_loss,
|
| 843 |
+
color=COLORS["secondary"],
|
| 844 |
+
linewidth=3,
|
| 845 |
+
label="Training Path",
|
| 846 |
+
zorder=10,
|
| 847 |
+
)
|
| 848 |
|
| 849 |
# Start/end markers
|
| 850 |
+
ax.scatter(
|
| 851 |
+
[path_x[0]],
|
| 852 |
+
[path_y[0]],
|
| 853 |
+
train_loss[0], # type: ignore[arg-type]
|
| 854 |
+
c="red",
|
| 855 |
+
s=100,
|
| 856 |
+
marker="o",
|
| 857 |
+
label="Start",
|
| 858 |
+
)
|
| 859 |
+
ax.scatter(
|
| 860 |
+
[path_x[-1]],
|
| 861 |
+
[path_y[-1]],
|
| 862 |
+
train_loss[-1], # type: ignore[arg-type]
|
| 863 |
+
c="green",
|
| 864 |
+
s=100,
|
| 865 |
+
marker="*",
|
| 866 |
+
label="Converged",
|
| 867 |
+
)
|
| 868 |
|
| 869 |
ax.set_xlabel("θ₁ Direction")
|
| 870 |
ax.set_ylabel("θ₂ Direction")
|
|
|
|
| 909 |
for i in range(n_clusters):
|
| 910 |
# Create cluster center
|
| 911 |
center = np.random.randn(64) * 0.5
|
| 912 |
+
center[i * 16 : (i + 1) * 16] += 3 # Make clusters separable
|
| 913 |
|
| 914 |
# Add samples around center
|
| 915 |
samples = center + np.random.randn(n_samples // n_clusters, 64) * 0.5
|
|
|
|
| 929 |
|
| 930 |
for i in range(n_clusters):
|
| 931 |
mask = cluster_labels == i
|
| 932 |
+
ax.scatter(
|
| 933 |
+
embeddings_2d[mask, 0],
|
| 934 |
+
embeddings_2d[mask, 1],
|
| 935 |
+
c=colors[i],
|
| 936 |
+
label=labels[i],
|
| 937 |
+
alpha=0.6,
|
| 938 |
+
s=30,
|
| 939 |
+
)
|
| 940 |
|
| 941 |
ax.set_xlabel("t-SNE Dimension 1")
|
| 942 |
ax.set_ylabel("t-SNE Dimension 2")
|
|
|
|
| 980 |
# Smoothed loss (exponential moving average)
|
| 981 |
if len(train_loss) > 5:
|
| 982 |
window = min(5, len(train_loss) // 2)
|
| 983 |
+
smoothed = np.convolve(train_loss, np.ones(window) / window, mode="valid")
|
| 984 |
+
smoothed_steps = train_steps[window - 1 :]
|
| 985 |
+
ax.plot(
|
| 986 |
+
smoothed_steps,
|
| 987 |
+
smoothed,
|
| 988 |
+
color=COLORS["primary"],
|
| 989 |
+
linewidth=2.5,
|
| 990 |
+
label="Training (smoothed)",
|
| 991 |
+
)
|
| 992 |
|
| 993 |
if val_loss:
|
| 994 |
+
ax.plot(val_steps, val_loss, color=COLORS["secondary"], linewidth=2.5, label="Validation")
|
|
|
|
| 995 |
|
| 996 |
ax.set_title("Loss Convergence")
|
| 997 |
ax.set_xlabel("Epoch")
|
|
|
|
| 1003 |
ax = axes[0, 1]
|
| 1004 |
|
| 1005 |
if len(train_loss) > 1:
|
| 1006 |
+
improvements = [
|
| 1007 |
+
-(train_loss[i] - train_loss[i - 1]) / train_loss[i - 1] * 100
|
| 1008 |
+
for i in range(1, len(train_loss))
|
| 1009 |
+
]
|
| 1010 |
colors_bar = [COLORS["accent"] if imp > 0 else COLORS["secondary"] for imp in improvements]
|
| 1011 |
ax.bar(train_steps[1:], improvements, color=colors_bar, alpha=0.7)
|
| 1012 |
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
|
|
|
|
| 1061 |
|
| 1062 |
# Dashboard Generator
|
| 1063 |
|
| 1064 |
+
|
| 1065 |
def generate_dashboard(run) -> None:
|
| 1066 |
"""
|
| 1067 |
Generate an interactive HTML dashboard with all visualizations.
|
|
|
|
| 1083 |
|
| 1084 |
# Create subplots
|
| 1085 |
fig = make_subplots(
|
| 1086 |
+
rows=2,
|
| 1087 |
+
cols=2,
|
| 1088 |
subplot_titles=("Total Loss", "Task Losses", "Learning Rate", "Metrics"),
|
| 1089 |
+
specs=[[{}, {}], [{}, {}]],
|
| 1090 |
)
|
| 1091 |
|
| 1092 |
# Total loss
|
| 1093 |
if train_loss:
|
| 1094 |
fig.add_trace(
|
| 1095 |
+
go.Scatter(
|
| 1096 |
+
x=train_steps, y=train_loss, name="Train Loss", line=dict(color=COLORS["primary"])
|
| 1097 |
+
),
|
| 1098 |
+
row=1,
|
| 1099 |
+
col=1,
|
| 1100 |
)
|
| 1101 |
if val_loss:
|
| 1102 |
fig.add_trace(
|
| 1103 |
+
go.Scatter(
|
| 1104 |
+
x=val_steps, y=val_loss, name="Val Loss", line=dict(color=COLORS["secondary"])
|
| 1105 |
+
),
|
| 1106 |
+
row=1,
|
| 1107 |
+
col=1,
|
| 1108 |
)
|
| 1109 |
|
| 1110 |
# Per-task losses
|
| 1111 |
+
for task, color in [
|
| 1112 |
+
("summarization", COLORS["summary"]),
|
| 1113 |
+
("emotion", COLORS["emotion"]),
|
| 1114 |
+
("topic", COLORS["topic"]),
|
| 1115 |
+
]:
|
| 1116 |
steps, values = get_metric_history(run, f"val_{task}_loss")
|
| 1117 |
if values:
|
| 1118 |
fig.add_trace(
|
| 1119 |
+
go.Scatter(x=steps, y=values, name=f"{task.title()} Loss", line=dict(color=color)),
|
| 1120 |
+
row=1,
|
| 1121 |
+
col=2,
|
| 1122 |
)
|
| 1123 |
|
| 1124 |
# Learning rate
|
| 1125 |
lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
|
| 1126 |
if lr_metrics:
|
| 1127 |
fig.add_trace(
|
| 1128 |
+
go.Scatter(
|
| 1129 |
+
x=[m.step for m in lr_metrics],
|
| 1130 |
+
y=[m.value for m in lr_metrics],
|
| 1131 |
+
name="Learning Rate",
|
| 1132 |
+
fill="tozeroy",
|
| 1133 |
+
line=dict(color=COLORS["primary"]),
|
| 1134 |
+
),
|
| 1135 |
+
row=2,
|
| 1136 |
+
col=1,
|
| 1137 |
)
|
| 1138 |
|
| 1139 |
# Accuracy metrics
|
| 1140 |
+
for metric, color in [("topic_accuracy", COLORS["topic"]), ("emotion_f1", COLORS["emotion"])]:
|
|
|
|
| 1141 |
steps, values = get_metric_history(run, f"val_{metric}")
|
| 1142 |
if values:
|
| 1143 |
fig.add_trace(
|
| 1144 |
+
go.Scatter(
|
| 1145 |
+
x=steps, y=values, name=metric.replace("_", " ").title(), line=dict(color=color)
|
| 1146 |
+
),
|
| 1147 |
+
row=2,
|
| 1148 |
+
col=2,
|
| 1149 |
)
|
| 1150 |
|
| 1151 |
fig.update_layout(
|
| 1152 |
+
title="LexiMind Training Dashboard", height=800, template="plotly_white", showlegend=True
|
|
|
|
|
|
|
|
|
|
| 1153 |
)
|
| 1154 |
|
| 1155 |
output_path = OUTPUTS_DIR / "training_dashboard.html"
|
|
|
|
| 1159 |
|
| 1160 |
# Main Entry Point
|
| 1161 |
|
| 1162 |
+
|
| 1163 |
def main():
|
| 1164 |
"""Generate all training visualizations."""
|
| 1165 |
parser = argparse.ArgumentParser(description="LexiMind Visualization Suite")
|
| 1166 |
+
parser.add_argument(
|
| 1167 |
+
"--interactive",
|
| 1168 |
+
action="store_true",
|
| 1169 |
+
help="Generate interactive HTML plots (requires plotly)",
|
| 1170 |
+
)
|
| 1171 |
+
parser.add_argument(
|
| 1172 |
+
"--landscape", action="store_true", help="Include 3D loss landscape visualization"
|
| 1173 |
+
)
|
| 1174 |
+
parser.add_argument("--dashboard", action="store_true", help="Generate interactive dashboard")
|
| 1175 |
+
parser.add_argument("--all", action="store_true", help="Generate all visualizations")
|
| 1176 |
args = parser.parse_args()
|
| 1177 |
|
| 1178 |
logger.info("=" * 60)
|
src/data/dataset.py
CHANGED
|
@@ -24,6 +24,7 @@ from torch.utils.data import Dataset
|
|
| 24 |
@dataclass
|
| 25 |
class SummarizationExample:
|
| 26 |
"""Container for abstractive summarization samples."""
|
|
|
|
| 27 |
source: str
|
| 28 |
summary: str
|
| 29 |
|
|
@@ -31,6 +32,7 @@ class SummarizationExample:
|
|
| 31 |
@dataclass
|
| 32 |
class EmotionExample:
|
| 33 |
"""Container for multi-label emotion classification samples."""
|
|
|
|
| 34 |
text: str
|
| 35 |
emotions: Sequence[str]
|
| 36 |
|
|
@@ -38,12 +40,14 @@ class EmotionExample:
|
|
| 38 |
@dataclass
|
| 39 |
class TopicExample:
|
| 40 |
"""Container for topic clustering / classification samples."""
|
|
|
|
| 41 |
text: str
|
| 42 |
topic: str
|
| 43 |
|
| 44 |
|
| 45 |
class SummarizationDataset(Dataset[SummarizationExample]):
|
| 46 |
"""Dataset yielding encoder-decoder training pairs."""
|
|
|
|
| 47 |
def __init__(self, examples: Iterable[SummarizationExample]) -> None:
|
| 48 |
self._examples = list(examples)
|
| 49 |
|
|
@@ -56,6 +60,7 @@ class SummarizationDataset(Dataset[SummarizationExample]):
|
|
| 56 |
|
| 57 |
class EmotionDataset(Dataset[EmotionExample]):
|
| 58 |
"""Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
|
|
|
|
| 59 |
def __init__(
|
| 60 |
self,
|
| 61 |
examples: Iterable[EmotionExample],
|
|
@@ -91,6 +96,7 @@ class EmotionDataset(Dataset[EmotionExample]):
|
|
| 91 |
|
| 92 |
class TopicDataset(Dataset[TopicExample]):
|
| 93 |
"""Dataset that owns a LabelEncoder for topic ids."""
|
|
|
|
| 94 |
def __init__(
|
| 95 |
self,
|
| 96 |
examples: Iterable[TopicExample],
|
|
@@ -241,7 +247,7 @@ def load_topic_jsonl(path: str) -> List[TopicExample]:
|
|
| 241 |
|
| 242 |
def _text_fingerprint(text: str, n_chars: int = 200) -> str:
|
| 243 |
"""Create a stable fingerprint from the first N characters of text.
|
| 244 |
-
|
| 245 |
Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
|
| 246 |
to detect document-level overlap across tasks.
|
| 247 |
"""
|
|
@@ -255,28 +261,28 @@ def deduplicate_across_tasks(
|
|
| 255 |
emotion_examples: List[EmotionExample] | None = None,
|
| 256 |
) -> Dict[str, int]:
|
| 257 |
"""Detect and report cross-task document overlap.
|
| 258 |
-
|
| 259 |
Checks whether texts appearing in the summarization dataset also appear
|
| 260 |
in the topic or emotion datasets, which could create data leakage in MTL.
|
| 261 |
-
|
| 262 |
Returns:
|
| 263 |
Dict with overlap counts between task pairs.
|
| 264 |
"""
|
| 265 |
summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
|
| 266 |
topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
|
| 267 |
-
|
| 268 |
overlap: Dict[str, int] = {
|
| 269 |
"summ_topic_overlap": len(summ_fps & topic_fps),
|
| 270 |
"summ_total": len(summ_fps),
|
| 271 |
"topic_total": len(topic_fps),
|
| 272 |
}
|
| 273 |
-
|
| 274 |
if emotion_examples:
|
| 275 |
emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
|
| 276 |
overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
|
| 277 |
overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
|
| 278 |
overlap["emotion_total"] = len(emot_fps)
|
| 279 |
-
|
| 280 |
return overlap
|
| 281 |
|
| 282 |
|
|
@@ -286,20 +292,20 @@ def remove_overlapping_examples(
|
|
| 286 |
split: str = "val",
|
| 287 |
) -> tuple[List[TopicExample], int]:
|
| 288 |
"""Remove topic examples whose texts overlap with summarization data.
|
| 289 |
-
|
| 290 |
-
This prevents cross-task data leakage where a document seen during
|
| 291 |
summarization training could boost topic classification on validation/test.
|
| 292 |
-
|
| 293 |
Args:
|
| 294 |
primary_examples: Topic examples to filter
|
| 295 |
reference_examples: Summarization examples to check against
|
| 296 |
split: Name of split being processed (for logging)
|
| 297 |
-
|
| 298 |
Returns:
|
| 299 |
Tuple of (filtered_examples, num_removed)
|
| 300 |
"""
|
| 301 |
ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
|
| 302 |
-
|
| 303 |
filtered = []
|
| 304 |
removed = 0
|
| 305 |
for ex in primary_examples:
|
|
@@ -308,8 +314,8 @@ def remove_overlapping_examples(
|
|
| 308 |
removed += 1
|
| 309 |
else:
|
| 310 |
filtered.append(ex)
|
| 311 |
-
|
| 312 |
if removed > 0:
|
| 313 |
print(f" Dedup: removed {removed} overlapping examples from topic {split}")
|
| 314 |
-
|
| 315 |
return filtered, removed
|
|
|
|
| 24 |
@dataclass
|
| 25 |
class SummarizationExample:
|
| 26 |
"""Container for abstractive summarization samples."""
|
| 27 |
+
|
| 28 |
source: str
|
| 29 |
summary: str
|
| 30 |
|
|
|
|
| 32 |
@dataclass
|
| 33 |
class EmotionExample:
|
| 34 |
"""Container for multi-label emotion classification samples."""
|
| 35 |
+
|
| 36 |
text: str
|
| 37 |
emotions: Sequence[str]
|
| 38 |
|
|
|
|
| 40 |
@dataclass
|
| 41 |
class TopicExample:
|
| 42 |
"""Container for topic clustering / classification samples."""
|
| 43 |
+
|
| 44 |
text: str
|
| 45 |
topic: str
|
| 46 |
|
| 47 |
|
| 48 |
class SummarizationDataset(Dataset[SummarizationExample]):
|
| 49 |
"""Dataset yielding encoder-decoder training pairs."""
|
| 50 |
+
|
| 51 |
def __init__(self, examples: Iterable[SummarizationExample]) -> None:
|
| 52 |
self._examples = list(examples)
|
| 53 |
|
|
|
|
| 60 |
|
| 61 |
class EmotionDataset(Dataset[EmotionExample]):
|
| 62 |
"""Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
|
| 63 |
+
|
| 64 |
def __init__(
|
| 65 |
self,
|
| 66 |
examples: Iterable[EmotionExample],
|
|
|
|
| 96 |
|
| 97 |
class TopicDataset(Dataset[TopicExample]):
|
| 98 |
"""Dataset that owns a LabelEncoder for topic ids."""
|
| 99 |
+
|
| 100 |
def __init__(
|
| 101 |
self,
|
| 102 |
examples: Iterable[TopicExample],
|
|
|
|
| 247 |
|
| 248 |
def _text_fingerprint(text: str, n_chars: int = 200) -> str:
|
| 249 |
"""Create a stable fingerprint from the first N characters of text.
|
| 250 |
+
|
| 251 |
Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
|
| 252 |
to detect document-level overlap across tasks.
|
| 253 |
"""
|
|
|
|
| 261 |
emotion_examples: List[EmotionExample] | None = None,
|
| 262 |
) -> Dict[str, int]:
|
| 263 |
"""Detect and report cross-task document overlap.
|
| 264 |
+
|
| 265 |
Checks whether texts appearing in the summarization dataset also appear
|
| 266 |
in the topic or emotion datasets, which could create data leakage in MTL.
|
| 267 |
+
|
| 268 |
Returns:
|
| 269 |
Dict with overlap counts between task pairs.
|
| 270 |
"""
|
| 271 |
summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
|
| 272 |
topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
|
| 273 |
+
|
| 274 |
overlap: Dict[str, int] = {
|
| 275 |
"summ_topic_overlap": len(summ_fps & topic_fps),
|
| 276 |
"summ_total": len(summ_fps),
|
| 277 |
"topic_total": len(topic_fps),
|
| 278 |
}
|
| 279 |
+
|
| 280 |
if emotion_examples:
|
| 281 |
emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
|
| 282 |
overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
|
| 283 |
overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
|
| 284 |
overlap["emotion_total"] = len(emot_fps)
|
| 285 |
+
|
| 286 |
return overlap
|
| 287 |
|
| 288 |
|
|
|
|
| 292 |
split: str = "val",
|
| 293 |
) -> tuple[List[TopicExample], int]:
|
| 294 |
"""Remove topic examples whose texts overlap with summarization data.
|
| 295 |
+
|
| 296 |
+
This prevents cross-task data leakage where a document seen during
|
| 297 |
summarization training could boost topic classification on validation/test.
|
| 298 |
+
|
| 299 |
Args:
|
| 300 |
primary_examples: Topic examples to filter
|
| 301 |
reference_examples: Summarization examples to check against
|
| 302 |
split: Name of split being processed (for logging)
|
| 303 |
+
|
| 304 |
Returns:
|
| 305 |
Tuple of (filtered_examples, num_removed)
|
| 306 |
"""
|
| 307 |
ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
|
| 308 |
+
|
| 309 |
filtered = []
|
| 310 |
removed = 0
|
| 311 |
for ex in primary_examples:
|
|
|
|
| 314 |
removed += 1
|
| 315 |
else:
|
| 316 |
filtered.append(ex)
|
| 317 |
+
|
| 318 |
if removed > 0:
|
| 319 |
print(f" Dedup: removed {removed} overlapping examples from topic {split}")
|
| 320 |
+
|
| 321 |
return filtered, removed
|
src/models/decoder.py
CHANGED
|
@@ -327,7 +327,6 @@ class TransformerDecoder(nn.Module):
|
|
| 327 |
elif tgt_mask.dim() == 3:
|
| 328 |
tgt_mask = tgt_mask.unsqueeze(1)
|
| 329 |
|
| 330 |
-
|
| 331 |
# Normalize memory_mask dtype/device and expand simple shapes
|
| 332 |
if memory_mask is not None:
|
| 333 |
memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
|
|
@@ -355,7 +354,15 @@ class TransformerDecoder(nn.Module):
|
|
| 355 |
# Gradient checkpointing requires the inputs to require grad
|
| 356 |
def create_custom_forward(module):
|
| 357 |
def custom_forward(*inputs):
|
| 358 |
-
return module(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
return custom_forward
|
| 360 |
|
| 361 |
x, attn = cast(
|
|
@@ -450,7 +457,7 @@ class TransformerDecoder(nn.Module):
|
|
| 450 |
) -> torch.Tensor:
|
| 451 |
"""
|
| 452 |
Greedy decoding with KV caching for O(N) complexity.
|
| 453 |
-
|
| 454 |
Args:
|
| 455 |
length_penalty: Values > 1.0 encourage shorter sequences by boosting EOS probability
|
| 456 |
as sequence length increases. Default 1.0 (no penalty).
|
|
|
|
| 327 |
elif tgt_mask.dim() == 3:
|
| 328 |
tgt_mask = tgt_mask.unsqueeze(1)
|
| 329 |
|
|
|
|
| 330 |
# Normalize memory_mask dtype/device and expand simple shapes
|
| 331 |
if memory_mask is not None:
|
| 332 |
memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
|
|
|
|
| 354 |
# Gradient checkpointing requires the inputs to require grad
|
| 355 |
def create_custom_forward(module):
|
| 356 |
def custom_forward(*inputs):
|
| 357 |
+
return module(
|
| 358 |
+
*inputs,
|
| 359 |
+
tgt_mask=tgt_mask,
|
| 360 |
+
memory_mask=memory_mask,
|
| 361 |
+
collect_attn=collect_attn,
|
| 362 |
+
self_attn_position_bias=self_position_bias,
|
| 363 |
+
cross_attn_position_bias=cross_position_bias,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
return custom_forward
|
| 367 |
|
| 368 |
x, attn = cast(
|
|
|
|
| 457 |
) -> torch.Tensor:
|
| 458 |
"""
|
| 459 |
Greedy decoding with KV caching for O(N) complexity.
|
| 460 |
+
|
| 461 |
Args:
|
| 462 |
length_penalty: Values > 1.0 encourage shorter sequences by boosting EOS probability
|
| 463 |
as sequence length increases. Default 1.0 (no penalty).
|
src/models/encoder.py
CHANGED
|
@@ -291,7 +291,13 @@ class TransformerEncoder(nn.Module):
|
|
| 291 |
# We use a lambda to pass keyword arguments
|
| 292 |
def create_custom_forward(module):
|
| 293 |
def custom_forward(*inputs):
|
| 294 |
-
return module(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
return custom_forward
|
| 296 |
|
| 297 |
x, attn = cast(
|
|
@@ -303,8 +309,10 @@ class TransformerEncoder(nn.Module):
|
|
| 303 |
),
|
| 304 |
)
|
| 305 |
else:
|
| 306 |
-
x, attn = layer(
|
| 307 |
-
|
|
|
|
|
|
|
| 308 |
if collect_attn:
|
| 309 |
attn_weights_per_layer.append(attn)
|
| 310 |
|
|
|
|
| 291 |
# We use a lambda to pass keyword arguments
|
| 292 |
def create_custom_forward(module):
|
| 293 |
def custom_forward(*inputs):
|
| 294 |
+
return module(
|
| 295 |
+
*inputs,
|
| 296 |
+
mask=mask,
|
| 297 |
+
collect_attn=collect_attn,
|
| 298 |
+
position_bias=position_bias,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
return custom_forward
|
| 302 |
|
| 303 |
x, attn = cast(
|
|
|
|
| 309 |
),
|
| 310 |
)
|
| 311 |
else:
|
| 312 |
+
x, attn = layer(
|
| 313 |
+
x, mask=mask, collect_attn=collect_attn, position_bias=position_bias
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
if collect_attn:
|
| 317 |
attn_weights_per_layer.append(attn)
|
| 318 |
|
src/models/factory.py
CHANGED
|
@@ -208,7 +208,9 @@ def _load_pretrained_weights(
|
|
| 208 |
if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
|
| 209 |
print("Transferring encoder relative position bias...")
|
| 210 |
t5_enc_rel_bias = (
|
| 211 |
-
cast(Any, t5_encoder.block[0])
|
|
|
|
|
|
|
| 212 |
)
|
| 213 |
encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
|
| 214 |
|
|
@@ -285,7 +287,9 @@ def _load_pretrained_weights(
|
|
| 285 |
):
|
| 286 |
print("Transferring decoder self-attention relative position bias...")
|
| 287 |
t5_dec_self_rel_bias = (
|
| 288 |
-
cast(Any, t5_decoder.block[0])
|
|
|
|
|
|
|
| 289 |
)
|
| 290 |
decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
|
| 291 |
t5_dec_self_rel_bias
|
|
@@ -298,7 +302,9 @@ def _load_pretrained_weights(
|
|
| 298 |
print("Transferring decoder cross-attention relative position bias...")
|
| 299 |
# Cross-attention relative position bias is in EncDecAttention of first block
|
| 300 |
t5_dec_cross_rel_bias = (
|
| 301 |
-
cast(Any, t5_decoder.block[0])
|
|
|
|
|
|
|
| 302 |
)
|
| 303 |
decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
|
| 304 |
t5_dec_cross_rel_bias
|
|
@@ -554,9 +560,9 @@ def build_multitask_model(
|
|
| 554 |
model.add_head(
|
| 555 |
"emotion",
|
| 556 |
ClassificationHead(
|
| 557 |
-
d_model=cfg.d_model,
|
| 558 |
-
num_labels=num_emotions,
|
| 559 |
-
pooler="attention",
|
| 560 |
dropout=cfg.dropout,
|
| 561 |
hidden_dim=cfg.d_model // 2, # 384-dim hidden layer
|
| 562 |
),
|
|
|
|
| 208 |
if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
|
| 209 |
print("Transferring encoder relative position bias...")
|
| 210 |
t5_enc_rel_bias = (
|
| 211 |
+
cast(Any, t5_encoder.block[0])
|
| 212 |
+
.layer[0]
|
| 213 |
+
.SelfAttention.relative_attention_bias.weight.data
|
| 214 |
)
|
| 215 |
encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
|
| 216 |
|
|
|
|
| 287 |
):
|
| 288 |
print("Transferring decoder self-attention relative position bias...")
|
| 289 |
t5_dec_self_rel_bias = (
|
| 290 |
+
cast(Any, t5_decoder.block[0])
|
| 291 |
+
.layer[0]
|
| 292 |
+
.SelfAttention.relative_attention_bias.weight.data
|
| 293 |
)
|
| 294 |
decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
|
| 295 |
t5_dec_self_rel_bias
|
|
|
|
| 302 |
print("Transferring decoder cross-attention relative position bias...")
|
| 303 |
# Cross-attention relative position bias is in EncDecAttention of first block
|
| 304 |
t5_dec_cross_rel_bias = (
|
| 305 |
+
cast(Any, t5_decoder.block[0])
|
| 306 |
+
.layer[1]
|
| 307 |
+
.EncDecAttention.relative_attention_bias.weight.data
|
| 308 |
)
|
| 309 |
decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
|
| 310 |
t5_dec_cross_rel_bias
|
|
|
|
| 560 |
model.add_head(
|
| 561 |
"emotion",
|
| 562 |
ClassificationHead(
|
| 563 |
+
d_model=cfg.d_model,
|
| 564 |
+
num_labels=num_emotions,
|
| 565 |
+
pooler="attention",
|
| 566 |
dropout=cfg.dropout,
|
| 567 |
hidden_dim=cfg.d_model // 2, # 384-dim hidden layer
|
| 568 |
),
|
src/models/heads.py
CHANGED
|
@@ -66,13 +66,15 @@ class ClassificationHead(nn.Module):
|
|
| 66 |
hidden_dim: Optional[int] = None,
|
| 67 |
):
|
| 68 |
super().__init__()
|
| 69 |
-
assert pooler in ("mean", "cls", "max", "attention"),
|
|
|
|
|
|
|
| 70 |
self.pooler = pooler
|
| 71 |
self.dropout = nn.Dropout(dropout)
|
| 72 |
|
| 73 |
if pooler == "attention":
|
| 74 |
self.attn_pool = AttentionPooling(d_model)
|
| 75 |
-
|
| 76 |
# Optional 2-layer MLP for more capacity (useful for multi-label)
|
| 77 |
if hidden_dim is not None:
|
| 78 |
self.out_proj = nn.Sequential(
|
|
|
|
| 66 |
hidden_dim: Optional[int] = None,
|
| 67 |
):
|
| 68 |
super().__init__()
|
| 69 |
+
assert pooler in ("mean", "cls", "max", "attention"), (
|
| 70 |
+
"pooler must be 'mean'|'cls'|'max'|'attention'"
|
| 71 |
+
)
|
| 72 |
self.pooler = pooler
|
| 73 |
self.dropout = nn.Dropout(dropout)
|
| 74 |
|
| 75 |
if pooler == "attention":
|
| 76 |
self.attn_pool = AttentionPooling(d_model)
|
| 77 |
+
|
| 78 |
# Optional 2-layer MLP for more capacity (useful for multi-label)
|
| 79 |
if hidden_dim is not None:
|
| 80 |
self.out_proj = nn.Sequential(
|
src/training/metrics.py
CHANGED
|
@@ -72,33 +72,33 @@ def calculate_bertscore(
|
|
| 72 |
) -> Dict[str, float]:
|
| 73 |
"""
|
| 74 |
Calculate BERTScore for semantic similarity between predictions and references.
|
| 75 |
-
|
| 76 |
BERTScore measures semantic similarity using contextual embeddings, making it
|
| 77 |
more robust than n-gram based metrics like ROUGE for paraphrased content.
|
| 78 |
-
|
| 79 |
Args:
|
| 80 |
predictions: Generated summaries/descriptions
|
| 81 |
references: Reference summaries/descriptions
|
| 82 |
model_type: BERT model to use (default: deberta-xlarge-mnli for best quality)
|
| 83 |
batch_size: Batch size for encoding
|
| 84 |
device: Device to use (auto-detected if None)
|
| 85 |
-
|
| 86 |
Returns:
|
| 87 |
Dict with 'precision', 'recall', 'f1' BERTScore averages
|
| 88 |
"""
|
| 89 |
if not predictions or not references:
|
| 90 |
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
| 91 |
-
|
| 92 |
try:
|
| 93 |
from bert_score import score as bert_score # type: ignore[import-not-found]
|
| 94 |
except ImportError:
|
| 95 |
print("Warning: bert-score not installed. Run: pip install bert-score")
|
| 96 |
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
| 97 |
-
|
| 98 |
# Auto-detect device
|
| 99 |
if device is None:
|
| 100 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 101 |
-
|
| 102 |
# Calculate BERTScore
|
| 103 |
P, R, F1 = bert_score(
|
| 104 |
list(predictions),
|
|
@@ -108,7 +108,7 @@ def calculate_bertscore(
|
|
| 108 |
device=device,
|
| 109 |
verbose=False,
|
| 110 |
)
|
| 111 |
-
|
| 112 |
return {
|
| 113 |
"precision": float(P.mean().item()), # type: ignore[union-attr]
|
| 114 |
"recall": float(R.mean().item()), # type: ignore[union-attr]
|
|
@@ -122,35 +122,35 @@ def calculate_rouge(
|
|
| 122 |
) -> Dict[str, float]:
|
| 123 |
"""
|
| 124 |
Calculate proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L).
|
| 125 |
-
|
| 126 |
Args:
|
| 127 |
predictions: Generated summaries
|
| 128 |
references: Reference summaries
|
| 129 |
-
|
| 130 |
Returns:
|
| 131 |
Dict with rouge1, rouge2, rougeL F1 scores
|
| 132 |
"""
|
| 133 |
if not predictions or not references:
|
| 134 |
return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
|
| 135 |
-
|
| 136 |
try:
|
| 137 |
from rouge_score import rouge_scorer
|
| 138 |
except ImportError:
|
| 139 |
print("Warning: rouge-score not installed. Run: pip install rouge-score")
|
| 140 |
return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
|
| 141 |
-
|
| 142 |
-
scorer = rouge_scorer.RougeScorer([
|
| 143 |
-
|
| 144 |
rouge1_scores = []
|
| 145 |
rouge2_scores = []
|
| 146 |
rougeL_scores = []
|
| 147 |
-
|
| 148 |
for pred, ref in zip(predictions, references, strict=False):
|
| 149 |
scores = scorer.score(ref, pred)
|
| 150 |
-
rouge1_scores.append(scores[
|
| 151 |
-
rouge2_scores.append(scores[
|
| 152 |
-
rougeL_scores.append(scores[
|
| 153 |
-
|
| 154 |
return {
|
| 155 |
"rouge1": sum(rouge1_scores) / len(rouge1_scores),
|
| 156 |
"rouge2": sum(rouge2_scores) / len(rouge2_scores),
|
|
@@ -166,37 +166,35 @@ def calculate_all_summarization_metrics(
|
|
| 166 |
) -> Dict[str, float]:
|
| 167 |
"""
|
| 168 |
Calculate comprehensive summarization metrics for research paper reporting.
|
| 169 |
-
|
| 170 |
Includes:
|
| 171 |
- ROUGE-1, ROUGE-2, ROUGE-L (lexical overlap)
|
| 172 |
- BLEU-4 (n-gram precision)
|
| 173 |
- BERTScore (semantic similarity)
|
| 174 |
-
|
| 175 |
Args:
|
| 176 |
predictions: Generated summaries/descriptions
|
| 177 |
references: Reference summaries/descriptions
|
| 178 |
include_bertscore: Whether to compute BERTScore (slower but valuable)
|
| 179 |
bertscore_model: Model for BERTScore computation
|
| 180 |
-
|
| 181 |
Returns:
|
| 182 |
Dict with all metric scores
|
| 183 |
"""
|
| 184 |
metrics: Dict[str, float] = {}
|
| 185 |
-
|
| 186 |
# ROUGE scores
|
| 187 |
rouge_scores = calculate_rouge(predictions, references)
|
| 188 |
metrics.update({f"rouge_{k}": v for k, v in rouge_scores.items()})
|
| 189 |
-
|
| 190 |
# BLEU score
|
| 191 |
metrics["bleu4"] = calculate_bleu(predictions, references)
|
| 192 |
-
|
| 193 |
# BERTScore (semantic similarity - important for back-cover style descriptions)
|
| 194 |
if include_bertscore:
|
| 195 |
-
bert_scores = calculate_bertscore(
|
| 196 |
-
predictions, references, model_type=bertscore_model
|
| 197 |
-
)
|
| 198 |
metrics.update({f"bertscore_{k}": v for k, v in bert_scores.items()})
|
| 199 |
-
|
| 200 |
return metrics
|
| 201 |
|
| 202 |
|
|
@@ -246,22 +244,22 @@ def get_confusion_matrix(
|
|
| 246 |
|
| 247 |
def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
|
| 248 |
"""Compute macro F1: average F1 per class (as in GoEmotions paper).
|
| 249 |
-
|
| 250 |
-
This averages F1 across labels, giving equal weight to each emotion class
|
| 251 |
regardless of prevalence. Directly comparable to GoEmotions baselines.
|
| 252 |
"""
|
| 253 |
preds = predictions.float()
|
| 254 |
gold = targets.float()
|
| 255 |
-
|
| 256 |
# Per-class TP, FP, FN
|
| 257 |
tp = (preds * gold).sum(dim=0)
|
| 258 |
fp = (preds * (1 - gold)).sum(dim=0)
|
| 259 |
fn = ((1 - preds) * gold).sum(dim=0)
|
| 260 |
-
|
| 261 |
precision = tp / (tp + fp).clamp(min=1e-8)
|
| 262 |
recall = tp / (tp + fn).clamp(min=1e-8)
|
| 263 |
f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
|
| 264 |
-
|
| 265 |
# Zero out F1 for classes with no support in either predictions or targets
|
| 266 |
mask = (tp + fp + fn) > 0
|
| 267 |
if mask.sum() == 0:
|
|
@@ -271,16 +269,16 @@ def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> flo
|
|
| 271 |
|
| 272 |
def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
|
| 273 |
"""Compute micro F1: aggregate TP/FP/FN across all classes.
|
| 274 |
-
|
| 275 |
This gives more weight to frequent classes. Useful when class distribution matters.
|
| 276 |
"""
|
| 277 |
preds = predictions.float()
|
| 278 |
gold = targets.float()
|
| 279 |
-
|
| 280 |
tp = (preds * gold).sum()
|
| 281 |
fp = (preds * (1 - gold)).sum()
|
| 282 |
fn = ((1 - preds) * gold).sum()
|
| 283 |
-
|
| 284 |
precision = tp / (tp + fp).clamp(min=1e-8)
|
| 285 |
recall = tp / (tp + fn).clamp(min=1e-8)
|
| 286 |
f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
|
|
@@ -293,17 +291,17 @@ def multilabel_per_class_metrics(
|
|
| 293 |
class_names: Sequence[str] | None = None,
|
| 294 |
) -> Dict[str, Dict[str, float]]:
|
| 295 |
"""Compute per-class precision, recall, F1 for multi-label classification.
|
| 296 |
-
|
| 297 |
Returns a dict mapping class name/index to its metrics.
|
| 298 |
"""
|
| 299 |
preds = predictions.float()
|
| 300 |
gold = targets.float()
|
| 301 |
num_classes = preds.shape[1]
|
| 302 |
-
|
| 303 |
tp = (preds * gold).sum(dim=0)
|
| 304 |
fp = (preds * (1 - gold)).sum(dim=0)
|
| 305 |
fn = ((1 - preds) * gold).sum(dim=0)
|
| 306 |
-
|
| 307 |
report: Dict[str, Dict[str, float]] = {}
|
| 308 |
for i in range(num_classes):
|
| 309 |
name = class_names[i] if class_names else str(i)
|
|
@@ -325,26 +323,26 @@ def tune_per_class_thresholds(
|
|
| 325 |
thresholds: Sequence[float] | None = None,
|
| 326 |
) -> tuple[List[float], float]:
|
| 327 |
"""Tune per-class thresholds on validation set to maximize macro F1.
|
| 328 |
-
|
| 329 |
-
For each class, tries multiple thresholds and selects the one that
|
| 330 |
-
maximizes that class's F1 score. This is standard practice for multi-label
|
| 331 |
classification (used in the original GoEmotions paper).
|
| 332 |
-
|
| 333 |
Args:
|
| 334 |
logits: Raw model logits (batch, num_classes)
|
| 335 |
targets: Binary target labels (batch, num_classes)
|
| 336 |
thresholds: Candidate thresholds to try (default: 0.1 to 0.9 by 0.05)
|
| 337 |
-
|
| 338 |
Returns:
|
| 339 |
Tuple of (best_thresholds_per_class, resulting_macro_f1)
|
| 340 |
"""
|
| 341 |
if thresholds is None:
|
| 342 |
thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
|
| 343 |
-
|
| 344 |
probs = torch.sigmoid(logits)
|
| 345 |
num_classes = probs.shape[1]
|
| 346 |
gold = targets.float()
|
| 347 |
-
|
| 348 |
best_thresholds: List[float] = []
|
| 349 |
for c in range(num_classes):
|
| 350 |
best_f1 = -1.0
|
|
@@ -364,13 +362,13 @@ def tune_per_class_thresholds(
|
|
| 364 |
best_f1 = f1
|
| 365 |
best_t = t
|
| 366 |
best_thresholds.append(best_t)
|
| 367 |
-
|
| 368 |
# Compute resulting macro F1 with tuned thresholds
|
| 369 |
tuned_preds = torch.zeros_like(probs)
|
| 370 |
for c in range(num_classes):
|
| 371 |
tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
|
| 372 |
macro_f1 = multilabel_macro_f1(tuned_preds, targets)
|
| 373 |
-
|
| 374 |
return best_thresholds, macro_f1
|
| 375 |
|
| 376 |
|
|
@@ -384,30 +382,30 @@ def bootstrap_confidence_interval(
|
|
| 384 |
seed: int = 42,
|
| 385 |
) -> tuple[float, float, float]:
|
| 386 |
"""Compute bootstrap confidence interval for a metric.
|
| 387 |
-
|
| 388 |
Args:
|
| 389 |
scores: Per-sample metric values
|
| 390 |
n_bootstrap: Number of bootstrap resamples
|
| 391 |
confidence: Confidence level (default 95%)
|
| 392 |
seed: Random seed for reproducibility
|
| 393 |
-
|
| 394 |
Returns:
|
| 395 |
Tuple of (mean, lower_bound, upper_bound)
|
| 396 |
"""
|
| 397 |
rng = np.random.default_rng(seed)
|
| 398 |
scores_arr = np.array(scores)
|
| 399 |
n = len(scores_arr)
|
| 400 |
-
|
| 401 |
bootstrap_means = []
|
| 402 |
for _ in range(n_bootstrap):
|
| 403 |
sample = rng.choice(scores_arr, size=n, replace=True)
|
| 404 |
bootstrap_means.append(float(np.mean(sample)))
|
| 405 |
-
|
| 406 |
bootstrap_means.sort()
|
| 407 |
alpha = 1 - confidence
|
| 408 |
lower_idx = int(alpha / 2 * n_bootstrap)
|
| 409 |
upper_idx = int((1 - alpha / 2) * n_bootstrap)
|
| 410 |
-
|
| 411 |
return (
|
| 412 |
float(np.mean(scores_arr)),
|
| 413 |
bootstrap_means[lower_idx],
|
|
@@ -422,15 +420,15 @@ def paired_bootstrap_test(
|
|
| 422 |
seed: int = 42,
|
| 423 |
) -> float:
|
| 424 |
"""Paired bootstrap significance test between two systems.
|
| 425 |
-
|
| 426 |
Tests if system B is significantly better than system A.
|
| 427 |
-
|
| 428 |
Args:
|
| 429 |
scores_a: Per-sample scores from system A
|
| 430 |
scores_b: Per-sample scores from system B
|
| 431 |
n_bootstrap: Number of bootstrap iterations
|
| 432 |
seed: Random seed
|
| 433 |
-
|
| 434 |
Returns:
|
| 435 |
p-value (probability that B is not better than A)
|
| 436 |
"""
|
|
@@ -438,14 +436,14 @@ def paired_bootstrap_test(
|
|
| 438 |
a = np.array(scores_a)
|
| 439 |
b = np.array(scores_b)
|
| 440 |
assert len(a) == len(b), "Both score lists must have the same length"
|
| 441 |
-
|
| 442 |
n = len(a)
|
| 443 |
-
|
| 444 |
count = 0
|
| 445 |
for _ in range(n_bootstrap):
|
| 446 |
idx = rng.choice(n, size=n, replace=True)
|
| 447 |
diff = float(np.mean(b[idx]) - np.mean(a[idx]))
|
| 448 |
if diff <= 0:
|
| 449 |
count += 1
|
| 450 |
-
|
| 451 |
return count / n_bootstrap
|
|
|
|
| 72 |
) -> Dict[str, float]:
|
| 73 |
"""
|
| 74 |
Calculate BERTScore for semantic similarity between predictions and references.
|
| 75 |
+
|
| 76 |
BERTScore measures semantic similarity using contextual embeddings, making it
|
| 77 |
more robust than n-gram based metrics like ROUGE for paraphrased content.
|
| 78 |
+
|
| 79 |
Args:
|
| 80 |
predictions: Generated summaries/descriptions
|
| 81 |
references: Reference summaries/descriptions
|
| 82 |
model_type: BERT model to use (default: deberta-xlarge-mnli for best quality)
|
| 83 |
batch_size: Batch size for encoding
|
| 84 |
device: Device to use (auto-detected if None)
|
| 85 |
+
|
| 86 |
Returns:
|
| 87 |
Dict with 'precision', 'recall', 'f1' BERTScore averages
|
| 88 |
"""
|
| 89 |
if not predictions or not references:
|
| 90 |
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
| 91 |
+
|
| 92 |
try:
|
| 93 |
from bert_score import score as bert_score # type: ignore[import-not-found]
|
| 94 |
except ImportError:
|
| 95 |
print("Warning: bert-score not installed. Run: pip install bert-score")
|
| 96 |
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
| 97 |
+
|
| 98 |
# Auto-detect device
|
| 99 |
if device is None:
|
| 100 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 101 |
+
|
| 102 |
# Calculate BERTScore
|
| 103 |
P, R, F1 = bert_score(
|
| 104 |
list(predictions),
|
|
|
|
| 108 |
device=device,
|
| 109 |
verbose=False,
|
| 110 |
)
|
| 111 |
+
|
| 112 |
return {
|
| 113 |
"precision": float(P.mean().item()), # type: ignore[union-attr]
|
| 114 |
"recall": float(R.mean().item()), # type: ignore[union-attr]
|
|
|
|
| 122 |
) -> Dict[str, float]:
|
| 123 |
"""
|
| 124 |
Calculate proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L).
|
| 125 |
+
|
| 126 |
Args:
|
| 127 |
predictions: Generated summaries
|
| 128 |
references: Reference summaries
|
| 129 |
+
|
| 130 |
Returns:
|
| 131 |
Dict with rouge1, rouge2, rougeL F1 scores
|
| 132 |
"""
|
| 133 |
if not predictions or not references:
|
| 134 |
return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
|
| 135 |
+
|
| 136 |
try:
|
| 137 |
from rouge_score import rouge_scorer
|
| 138 |
except ImportError:
|
| 139 |
print("Warning: rouge-score not installed. Run: pip install rouge-score")
|
| 140 |
return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
|
| 141 |
+
|
| 142 |
+
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
| 143 |
+
|
| 144 |
rouge1_scores = []
|
| 145 |
rouge2_scores = []
|
| 146 |
rougeL_scores = []
|
| 147 |
+
|
| 148 |
for pred, ref in zip(predictions, references, strict=False):
|
| 149 |
scores = scorer.score(ref, pred)
|
| 150 |
+
rouge1_scores.append(scores["rouge1"].fmeasure)
|
| 151 |
+
rouge2_scores.append(scores["rouge2"].fmeasure)
|
| 152 |
+
rougeL_scores.append(scores["rougeL"].fmeasure)
|
| 153 |
+
|
| 154 |
return {
|
| 155 |
"rouge1": sum(rouge1_scores) / len(rouge1_scores),
|
| 156 |
"rouge2": sum(rouge2_scores) / len(rouge2_scores),
|
|
|
|
| 166 |
) -> Dict[str, float]:
|
| 167 |
"""
|
| 168 |
Calculate comprehensive summarization metrics for research paper reporting.
|
| 169 |
+
|
| 170 |
Includes:
|
| 171 |
- ROUGE-1, ROUGE-2, ROUGE-L (lexical overlap)
|
| 172 |
- BLEU-4 (n-gram precision)
|
| 173 |
- BERTScore (semantic similarity)
|
| 174 |
+
|
| 175 |
Args:
|
| 176 |
predictions: Generated summaries/descriptions
|
| 177 |
references: Reference summaries/descriptions
|
| 178 |
include_bertscore: Whether to compute BERTScore (slower but valuable)
|
| 179 |
bertscore_model: Model for BERTScore computation
|
| 180 |
+
|
| 181 |
Returns:
|
| 182 |
Dict with all metric scores
|
| 183 |
"""
|
| 184 |
metrics: Dict[str, float] = {}
|
| 185 |
+
|
| 186 |
# ROUGE scores
|
| 187 |
rouge_scores = calculate_rouge(predictions, references)
|
| 188 |
metrics.update({f"rouge_{k}": v for k, v in rouge_scores.items()})
|
| 189 |
+
|
| 190 |
# BLEU score
|
| 191 |
metrics["bleu4"] = calculate_bleu(predictions, references)
|
| 192 |
+
|
| 193 |
# BERTScore (semantic similarity - important for back-cover style descriptions)
|
| 194 |
if include_bertscore:
|
| 195 |
+
bert_scores = calculate_bertscore(predictions, references, model_type=bertscore_model)
|
|
|
|
|
|
|
| 196 |
metrics.update({f"bertscore_{k}": v for k, v in bert_scores.items()})
|
| 197 |
+
|
| 198 |
return metrics
|
| 199 |
|
| 200 |
|
|
|
|
| 244 |
|
| 245 |
def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
|
| 246 |
"""Compute macro F1: average F1 per class (as in GoEmotions paper).
|
| 247 |
+
|
| 248 |
+
This averages F1 across labels, giving equal weight to each emotion class
|
| 249 |
regardless of prevalence. Directly comparable to GoEmotions baselines.
|
| 250 |
"""
|
| 251 |
preds = predictions.float()
|
| 252 |
gold = targets.float()
|
| 253 |
+
|
| 254 |
# Per-class TP, FP, FN
|
| 255 |
tp = (preds * gold).sum(dim=0)
|
| 256 |
fp = (preds * (1 - gold)).sum(dim=0)
|
| 257 |
fn = ((1 - preds) * gold).sum(dim=0)
|
| 258 |
+
|
| 259 |
precision = tp / (tp + fp).clamp(min=1e-8)
|
| 260 |
recall = tp / (tp + fn).clamp(min=1e-8)
|
| 261 |
f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
|
| 262 |
+
|
| 263 |
# Zero out F1 for classes with no support in either predictions or targets
|
| 264 |
mask = (tp + fp + fn) > 0
|
| 265 |
if mask.sum() == 0:
|
|
|
|
| 269 |
|
| 270 |
def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
|
| 271 |
"""Compute micro F1: aggregate TP/FP/FN across all classes.
|
| 272 |
+
|
| 273 |
This gives more weight to frequent classes. Useful when class distribution matters.
|
| 274 |
"""
|
| 275 |
preds = predictions.float()
|
| 276 |
gold = targets.float()
|
| 277 |
+
|
| 278 |
tp = (preds * gold).sum()
|
| 279 |
fp = (preds * (1 - gold)).sum()
|
| 280 |
fn = ((1 - preds) * gold).sum()
|
| 281 |
+
|
| 282 |
precision = tp / (tp + fp).clamp(min=1e-8)
|
| 283 |
recall = tp / (tp + fn).clamp(min=1e-8)
|
| 284 |
f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
|
|
|
|
| 291 |
class_names: Sequence[str] | None = None,
|
| 292 |
) -> Dict[str, Dict[str, float]]:
|
| 293 |
"""Compute per-class precision, recall, F1 for multi-label classification.
|
| 294 |
+
|
| 295 |
Returns a dict mapping class name/index to its metrics.
|
| 296 |
"""
|
| 297 |
preds = predictions.float()
|
| 298 |
gold = targets.float()
|
| 299 |
num_classes = preds.shape[1]
|
| 300 |
+
|
| 301 |
tp = (preds * gold).sum(dim=0)
|
| 302 |
fp = (preds * (1 - gold)).sum(dim=0)
|
| 303 |
fn = ((1 - preds) * gold).sum(dim=0)
|
| 304 |
+
|
| 305 |
report: Dict[str, Dict[str, float]] = {}
|
| 306 |
for i in range(num_classes):
|
| 307 |
name = class_names[i] if class_names else str(i)
|
|
|
|
| 323 |
thresholds: Sequence[float] | None = None,
|
| 324 |
) -> tuple[List[float], float]:
|
| 325 |
"""Tune per-class thresholds on validation set to maximize macro F1.
|
| 326 |
+
|
| 327 |
+
For each class, tries multiple thresholds and selects the one that
|
| 328 |
+
maximizes that class's F1 score. This is standard practice for multi-label
|
| 329 |
classification (used in the original GoEmotions paper).
|
| 330 |
+
|
| 331 |
Args:
|
| 332 |
logits: Raw model logits (batch, num_classes)
|
| 333 |
targets: Binary target labels (batch, num_classes)
|
| 334 |
thresholds: Candidate thresholds to try (default: 0.1 to 0.9 by 0.05)
|
| 335 |
+
|
| 336 |
Returns:
|
| 337 |
Tuple of (best_thresholds_per_class, resulting_macro_f1)
|
| 338 |
"""
|
| 339 |
if thresholds is None:
|
| 340 |
thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
|
| 341 |
+
|
| 342 |
probs = torch.sigmoid(logits)
|
| 343 |
num_classes = probs.shape[1]
|
| 344 |
gold = targets.float()
|
| 345 |
+
|
| 346 |
best_thresholds: List[float] = []
|
| 347 |
for c in range(num_classes):
|
| 348 |
best_f1 = -1.0
|
|
|
|
| 362 |
best_f1 = f1
|
| 363 |
best_t = t
|
| 364 |
best_thresholds.append(best_t)
|
| 365 |
+
|
| 366 |
# Compute resulting macro F1 with tuned thresholds
|
| 367 |
tuned_preds = torch.zeros_like(probs)
|
| 368 |
for c in range(num_classes):
|
| 369 |
tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
|
| 370 |
macro_f1 = multilabel_macro_f1(tuned_preds, targets)
|
| 371 |
+
|
| 372 |
return best_thresholds, macro_f1
|
| 373 |
|
| 374 |
|
|
|
|
| 382 |
seed: int = 42,
|
| 383 |
) -> tuple[float, float, float]:
|
| 384 |
"""Compute bootstrap confidence interval for a metric.
|
| 385 |
+
|
| 386 |
Args:
|
| 387 |
scores: Per-sample metric values
|
| 388 |
n_bootstrap: Number of bootstrap resamples
|
| 389 |
confidence: Confidence level (default 95%)
|
| 390 |
seed: Random seed for reproducibility
|
| 391 |
+
|
| 392 |
Returns:
|
| 393 |
Tuple of (mean, lower_bound, upper_bound)
|
| 394 |
"""
|
| 395 |
rng = np.random.default_rng(seed)
|
| 396 |
scores_arr = np.array(scores)
|
| 397 |
n = len(scores_arr)
|
| 398 |
+
|
| 399 |
bootstrap_means = []
|
| 400 |
for _ in range(n_bootstrap):
|
| 401 |
sample = rng.choice(scores_arr, size=n, replace=True)
|
| 402 |
bootstrap_means.append(float(np.mean(sample)))
|
| 403 |
+
|
| 404 |
bootstrap_means.sort()
|
| 405 |
alpha = 1 - confidence
|
| 406 |
lower_idx = int(alpha / 2 * n_bootstrap)
|
| 407 |
upper_idx = int((1 - alpha / 2) * n_bootstrap)
|
| 408 |
+
|
| 409 |
return (
|
| 410 |
float(np.mean(scores_arr)),
|
| 411 |
bootstrap_means[lower_idx],
|
|
|
|
| 420 |
seed: int = 42,
|
| 421 |
) -> float:
|
| 422 |
"""Paired bootstrap significance test between two systems.
|
| 423 |
+
|
| 424 |
Tests if system B is significantly better than system A.
|
| 425 |
+
|
| 426 |
Args:
|
| 427 |
scores_a: Per-sample scores from system A
|
| 428 |
scores_b: Per-sample scores from system B
|
| 429 |
n_bootstrap: Number of bootstrap iterations
|
| 430 |
seed: Random seed
|
| 431 |
+
|
| 432 |
Returns:
|
| 433 |
p-value (probability that B is not better than A)
|
| 434 |
"""
|
|
|
|
| 436 |
a = np.array(scores_a)
|
| 437 |
b = np.array(scores_b)
|
| 438 |
assert len(a) == len(b), "Both score lists must have the same length"
|
| 439 |
+
|
| 440 |
n = len(a)
|
| 441 |
+
|
| 442 |
count = 0
|
| 443 |
for _ in range(n_bootstrap):
|
| 444 |
idx = rng.choice(n, size=n, replace=True)
|
| 445 |
diff = float(np.mean(b[idx]) - np.mean(a[idx]))
|
| 446 |
if diff <= 0:
|
| 447 |
count += 1
|
| 448 |
+
|
| 449 |
return count / n_bootstrap
|
src/training/trainer.py
CHANGED
|
@@ -48,24 +48,24 @@ class TrainerConfig:
|
|
| 48 |
validation_max_length: int = 128
|
| 49 |
label_smoothing: float = 0.1
|
| 50 |
gradient_accumulation_steps: int = 1
|
| 51 |
-
|
| 52 |
# LR scheduler
|
| 53 |
scheduler_type: str = "cosine"
|
| 54 |
warmup_steps: int = 500
|
| 55 |
-
|
| 56 |
# Early stopping
|
| 57 |
early_stopping_patience: int | None = 5
|
| 58 |
-
|
| 59 |
# Task sampling strategy: "round_robin" or "temperature"
|
| 60 |
# Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
|
| 61 |
# alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
|
| 62 |
task_sampling: str = "temperature"
|
| 63 |
task_sampling_alpha: float = 0.5
|
| 64 |
-
|
| 65 |
# Gradient conflict diagnostics
|
| 66 |
# Compute inter-task gradient cosine similarity every N steps (0 = disabled)
|
| 67 |
gradient_conflict_frequency: int = 0
|
| 68 |
-
|
| 69 |
# MLflow
|
| 70 |
experiment_name: str = "LexiMind"
|
| 71 |
run_name: str | None = None
|
|
@@ -76,13 +76,13 @@ class TrainerConfig:
|
|
| 76 |
|
| 77 |
class EarlyStopping:
|
| 78 |
"""Stop training when validation loss stops improving."""
|
| 79 |
-
|
| 80 |
def __init__(self, patience: int = 5, min_delta: float = 0.001):
|
| 81 |
self.patience = patience
|
| 82 |
self.min_delta = min_delta
|
| 83 |
self.counter = 0
|
| 84 |
-
self.best_value = float(
|
| 85 |
-
|
| 86 |
def __call__(self, val_loss: float) -> bool:
|
| 87 |
"""Returns True if training should stop."""
|
| 88 |
if val_loss < self.best_value - self.min_delta:
|
|
@@ -155,7 +155,9 @@ class Trainer:
|
|
| 155 |
|
| 156 |
pbar = tqdm(
|
| 157 |
range(start_epoch, self.config.max_epochs + 1),
|
| 158 |
-
desc="Training",
|
|
|
|
|
|
|
| 159 |
)
|
| 160 |
|
| 161 |
for epoch in pbar:
|
|
@@ -178,10 +180,12 @@ class Trainer:
|
|
| 178 |
|
| 179 |
# Early stopping
|
| 180 |
if self.early_stopping:
|
| 181 |
-
val_loss = val_metrics.get("total_loss", float(
|
| 182 |
if self.early_stopping(val_loss):
|
| 183 |
-
tqdm.write(
|
| 184 |
-
|
|
|
|
|
|
|
| 185 |
break
|
| 186 |
|
| 187 |
# Checkpoint
|
|
@@ -190,11 +194,11 @@ class Trainer:
|
|
| 190 |
|
| 191 |
# Update progress
|
| 192 |
epoch_time = time.perf_counter() - epoch_start
|
| 193 |
-
loss = train_metrics.get(
|
| 194 |
pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})
|
| 195 |
|
| 196 |
total_time = time.perf_counter() - total_start
|
| 197 |
-
print(f"\nTraining complete in {total_time/60:.1f} minutes")
|
| 198 |
return history
|
| 199 |
|
| 200 |
def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
|
|
@@ -203,7 +207,9 @@ class Trainer:
|
|
| 203 |
self.scheduler = None
|
| 204 |
return
|
| 205 |
|
| 206 |
-
steps_per_epoch = max(len(loader) for loader in loaders.values()) // max(
|
|
|
|
|
|
|
| 207 |
total_steps = steps_per_epoch * (self.config.max_epochs - start_epoch + 1)
|
| 208 |
warmup = self.config.warmup_steps
|
| 209 |
|
|
@@ -238,10 +244,12 @@ class Trainer:
|
|
| 238 |
if self.config.task_sampling == "temperature" and len(task_names) > 1:
|
| 239 |
sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64) # type: ignore[arg-type]
|
| 240 |
alpha = self.config.task_sampling_alpha
|
| 241 |
-
probs = sizes
|
| 242 |
probs = probs / probs.sum()
|
| 243 |
-
tqdm.write(
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
else:
|
| 246 |
probs = None
|
| 247 |
|
|
@@ -253,7 +261,9 @@ class Trainer:
|
|
| 253 |
# Select tasks for this step
|
| 254 |
if probs is not None and train:
|
| 255 |
# Temperature sampling: sample tasks based on dataset size
|
| 256 |
-
selected_tasks = list(
|
|
|
|
|
|
|
| 257 |
else:
|
| 258 |
# Round-robin: all tasks every step
|
| 259 |
selected_tasks = task_names
|
|
@@ -288,8 +298,11 @@ class Trainer:
|
|
| 288 |
scaled.backward()
|
| 289 |
|
| 290 |
# Gradient conflict diagnostics
|
| 291 |
-
if (
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
| 293 |
conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
|
| 294 |
for k, v in conflict_stats.items():
|
| 295 |
metrics[f"grad_{k}"].append(v)
|
|
@@ -316,8 +329,10 @@ class Trainer:
|
|
| 316 |
|
| 317 |
# Average metrics
|
| 318 |
averaged = {k: sum(v) / len(v) for k, v in metrics.items() if v}
|
| 319 |
-
tqdm.write(
|
| 320 |
-
|
|
|
|
|
|
|
| 321 |
return averaged
|
| 322 |
|
| 323 |
def _get_batch(self, iterators: Dict, loader: DataLoader, task: str) -> Dict | None:
|
|
@@ -330,8 +345,10 @@ class Trainer:
|
|
| 330 |
batch = next(iterators[task])
|
| 331 |
except StopIteration:
|
| 332 |
return None
|
| 333 |
-
return {
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
|
| 336 |
def _forward_task(self, task: str, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
|
| 337 |
"""Route to task-specific forward pass."""
|
|
@@ -360,10 +377,10 @@ class Trainer:
|
|
| 360 |
# Decode predictions and references
|
| 361 |
preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
|
| 362 |
refs = self._decode_labels(batch["labels"])
|
| 363 |
-
|
| 364 |
# Calculate comprehensive metrics
|
| 365 |
metrics = {"rouge_like": rouge_like(preds, refs)}
|
| 366 |
-
|
| 367 |
# Proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L)
|
| 368 |
try:
|
| 369 |
rouge_scores = calculate_rouge(preds, refs)
|
|
@@ -372,13 +389,13 @@ class Trainer:
|
|
| 372 |
metrics["rougeL"] = rouge_scores["rougeL"]
|
| 373 |
except Exception:
|
| 374 |
pass # Fall back to rouge_like only if rouge-score not installed
|
| 375 |
-
|
| 376 |
# BLEU-4 score
|
| 377 |
try:
|
| 378 |
metrics["bleu4"] = calculate_bleu(preds, refs)
|
| 379 |
except Exception:
|
| 380 |
pass
|
| 381 |
-
|
| 382 |
return loss, metrics
|
| 383 |
|
| 384 |
def _forward_emotion(self, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
|
|
@@ -423,8 +440,10 @@ class Trainer:
|
|
| 423 |
if i >= n:
|
| 424 |
break
|
| 425 |
|
| 426 |
-
batch = {
|
| 427 |
-
|
|
|
|
|
|
|
| 428 |
src_ids = batch["src_ids"][:1]
|
| 429 |
src_mask = batch.get("src_mask", None)
|
| 430 |
if src_mask is not None:
|
|
@@ -432,7 +451,9 @@ class Trainer:
|
|
| 432 |
|
| 433 |
# Generate with anti-repetition
|
| 434 |
model: Any = self.model
|
| 435 |
-
enc_mask =
|
|
|
|
|
|
|
| 436 |
memory = model.encoder(src_ids, mask=enc_mask)
|
| 437 |
generated = model.decoder.greedy_decode(
|
| 438 |
memory=memory,
|
|
@@ -463,27 +484,27 @@ class Trainer:
|
|
| 463 |
iterators: Dict,
|
| 464 |
) -> Dict[str, float]:
|
| 465 |
"""Compute inter-task gradient cosine similarity to diagnose conflicts.
|
| 466 |
-
|
| 467 |
Returns cosine similarity between gradient vectors for each task pair.
|
| 468 |
Negative values indicate conflicting gradients (negative transfer risk).
|
| 469 |
"""
|
| 470 |
task_grads: Dict[str, torch.Tensor] = {}
|
| 471 |
-
|
| 472 |
for task, loader in loaders.items():
|
| 473 |
self.optimizer.zero_grad()
|
| 474 |
batch = self._get_batch(iterators, loader, task)
|
| 475 |
if batch is None:
|
| 476 |
continue
|
| 477 |
-
|
| 478 |
dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
|
| 479 |
with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
|
| 480 |
loss, _ = self._forward_task(task, batch)
|
| 481 |
-
|
| 482 |
if torch.isnan(loss):
|
| 483 |
continue
|
| 484 |
-
|
| 485 |
loss.backward()
|
| 486 |
-
|
| 487 |
# Flatten all gradients into a single vector
|
| 488 |
grad_vec = []
|
| 489 |
for p in self.model.parameters():
|
|
@@ -491,9 +512,9 @@ class Trainer:
|
|
| 491 |
grad_vec.append(p.grad.detach().clone().flatten())
|
| 492 |
if grad_vec:
|
| 493 |
task_grads[task] = torch.cat(grad_vec)
|
| 494 |
-
|
| 495 |
self.optimizer.zero_grad()
|
| 496 |
-
|
| 497 |
# Compute pairwise cosine similarity
|
| 498 |
stats: Dict[str, float] = {}
|
| 499 |
tasks = list(task_grads.keys())
|
|
@@ -504,20 +525,22 @@ class Trainer:
|
|
| 504 |
cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
|
| 505 |
stats[f"cos_sim_{t1}_{t2}"] = cos_sim
|
| 506 |
stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
|
| 507 |
-
|
| 508 |
return stats
|
| 509 |
|
| 510 |
def _log_config(self) -> None:
|
| 511 |
"""Log config to MLflow."""
|
| 512 |
-
mlflow.log_params(
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
|
|
|
|
|
|
| 521 |
|
| 522 |
def _log_metrics(self, metrics: Dict[str, float], prefix: str, epoch: int) -> None:
|
| 523 |
"""Log metrics to MLflow."""
|
|
|
|
| 48 |
validation_max_length: int = 128
|
| 49 |
label_smoothing: float = 0.1
|
| 50 |
gradient_accumulation_steps: int = 1
|
| 51 |
+
|
| 52 |
# LR scheduler
|
| 53 |
scheduler_type: str = "cosine"
|
| 54 |
warmup_steps: int = 500
|
| 55 |
+
|
| 56 |
# Early stopping
|
| 57 |
early_stopping_patience: int | None = 5
|
| 58 |
+
|
| 59 |
# Task sampling strategy: "round_robin" or "temperature"
|
| 60 |
# Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
|
| 61 |
# alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
|
| 62 |
task_sampling: str = "temperature"
|
| 63 |
task_sampling_alpha: float = 0.5
|
| 64 |
+
|
| 65 |
# Gradient conflict diagnostics
|
| 66 |
# Compute inter-task gradient cosine similarity every N steps (0 = disabled)
|
| 67 |
gradient_conflict_frequency: int = 0
|
| 68 |
+
|
| 69 |
# MLflow
|
| 70 |
experiment_name: str = "LexiMind"
|
| 71 |
run_name: str | None = None
|
|
|
|
| 76 |
|
| 77 |
class EarlyStopping:
|
| 78 |
"""Stop training when validation loss stops improving."""
|
| 79 |
+
|
| 80 |
def __init__(self, patience: int = 5, min_delta: float = 0.001):
|
| 81 |
self.patience = patience
|
| 82 |
self.min_delta = min_delta
|
| 83 |
self.counter = 0
|
| 84 |
+
self.best_value = float("inf")
|
| 85 |
+
|
| 86 |
def __call__(self, val_loss: float) -> bool:
|
| 87 |
"""Returns True if training should stop."""
|
| 88 |
if val_loss < self.best_value - self.min_delta:
|
|
|
|
| 155 |
|
| 156 |
pbar = tqdm(
|
| 157 |
range(start_epoch, self.config.max_epochs + 1),
|
| 158 |
+
desc="Training",
|
| 159 |
+
unit="epoch",
|
| 160 |
+
file=sys.stderr,
|
| 161 |
)
|
| 162 |
|
| 163 |
for epoch in pbar:
|
|
|
|
| 180 |
|
| 181 |
# Early stopping
|
| 182 |
if self.early_stopping:
|
| 183 |
+
val_loss = val_metrics.get("total_loss", float("inf"))
|
| 184 |
if self.early_stopping(val_loss):
|
| 185 |
+
tqdm.write(
|
| 186 |
+
f"\nEarly stopping at epoch {epoch} (best loss: {self.early_stopping.best_value:.4f})"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
break
|
| 190 |
|
| 191 |
# Checkpoint
|
|
|
|
| 194 |
|
| 195 |
# Update progress
|
| 196 |
epoch_time = time.perf_counter() - epoch_start
|
| 197 |
+
loss = train_metrics.get("total_loss", 0)
|
| 198 |
pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})
|
| 199 |
|
| 200 |
total_time = time.perf_counter() - total_start
|
| 201 |
+
print(f"\nTraining complete in {total_time / 60:.1f} minutes")
|
| 202 |
return history
|
| 203 |
|
| 204 |
def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
|
|
|
|
| 207 |
self.scheduler = None
|
| 208 |
return
|
| 209 |
|
| 210 |
+
steps_per_epoch = max(len(loader) for loader in loaders.values()) // max(
|
| 211 |
+
1, self.config.gradient_accumulation_steps
|
| 212 |
+
)
|
| 213 |
total_steps = steps_per_epoch * (self.config.max_epochs - start_epoch + 1)
|
| 214 |
warmup = self.config.warmup_steps
|
| 215 |
|
|
|
|
| 244 |
if self.config.task_sampling == "temperature" and len(task_names) > 1:
|
| 245 |
sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64) # type: ignore[arg-type]
|
| 246 |
alpha = self.config.task_sampling_alpha
|
| 247 |
+
probs = sizes**alpha
|
| 248 |
probs = probs / probs.sum()
|
| 249 |
+
tqdm.write(
|
| 250 |
+
f" Temperature sampling (α={alpha}): "
|
| 251 |
+
+ ", ".join(f"{t}={p:.2%}" for t, p in zip(task_names, probs, strict=True))
|
| 252 |
+
)
|
| 253 |
else:
|
| 254 |
probs = None
|
| 255 |
|
|
|
|
| 261 |
# Select tasks for this step
|
| 262 |
if probs is not None and train:
|
| 263 |
# Temperature sampling: sample tasks based on dataset size
|
| 264 |
+
selected_tasks = list(
|
| 265 |
+
np.random.choice(task_names, size=len(task_names), replace=True, p=probs)
|
| 266 |
+
)
|
| 267 |
else:
|
| 268 |
# Round-robin: all tasks every step
|
| 269 |
selected_tasks = task_names
|
|
|
|
| 298 |
scaled.backward()
|
| 299 |
|
| 300 |
# Gradient conflict diagnostics
|
| 301 |
+
if (
|
| 302 |
+
train
|
| 303 |
+
and self.config.gradient_conflict_frequency > 0
|
| 304 |
+
and (step + 1) % self.config.gradient_conflict_frequency == 0
|
| 305 |
+
):
|
| 306 |
conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
|
| 307 |
for k, v in conflict_stats.items():
|
| 308 |
metrics[f"grad_{k}"].append(v)
|
|
|
|
| 329 |
|
| 330 |
# Average metrics
|
| 331 |
averaged = {k: sum(v) / len(v) for k, v in metrics.items() if v}
|
| 332 |
+
tqdm.write(
|
| 333 |
+
f"[{phase.lower()}] epoch {epoch}: "
|
| 334 |
+
+ ", ".join(f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch")
|
| 335 |
+
)
|
| 336 |
return averaged
|
| 337 |
|
| 338 |
def _get_batch(self, iterators: Dict, loader: DataLoader, task: str) -> Dict | None:
|
|
|
|
| 345 |
batch = next(iterators[task])
|
| 346 |
except StopIteration:
|
| 347 |
return None
|
| 348 |
+
return {
|
| 349 |
+
k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
|
| 350 |
+
for k, v in batch.items()
|
| 351 |
+
}
|
| 352 |
|
| 353 |
def _forward_task(self, task: str, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
|
| 354 |
"""Route to task-specific forward pass."""
|
|
|
|
| 377 |
# Decode predictions and references
|
| 378 |
preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
|
| 379 |
refs = self._decode_labels(batch["labels"])
|
| 380 |
+
|
| 381 |
# Calculate comprehensive metrics
|
| 382 |
metrics = {"rouge_like": rouge_like(preds, refs)}
|
| 383 |
+
|
| 384 |
# Proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L)
|
| 385 |
try:
|
| 386 |
rouge_scores = calculate_rouge(preds, refs)
|
|
|
|
| 389 |
metrics["rougeL"] = rouge_scores["rougeL"]
|
| 390 |
except Exception:
|
| 391 |
pass # Fall back to rouge_like only if rouge-score not installed
|
| 392 |
+
|
| 393 |
# BLEU-4 score
|
| 394 |
try:
|
| 395 |
metrics["bleu4"] = calculate_bleu(preds, refs)
|
| 396 |
except Exception:
|
| 397 |
pass
|
| 398 |
+
|
| 399 |
return loss, metrics
|
| 400 |
|
| 401 |
def _forward_emotion(self, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
|
|
|
|
| 440 |
if i >= n:
|
| 441 |
break
|
| 442 |
|
| 443 |
+
batch = {
|
| 444 |
+
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
| 445 |
+
for k, v in batch.items()
|
| 446 |
+
}
|
| 447 |
src_ids = batch["src_ids"][:1]
|
| 448 |
src_mask = batch.get("src_mask", None)
|
| 449 |
if src_mask is not None:
|
|
|
|
| 451 |
|
| 452 |
# Generate with anti-repetition
|
| 453 |
model: Any = self.model
|
| 454 |
+
enc_mask = (
|
| 455 |
+
src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
|
| 456 |
+
)
|
| 457 |
memory = model.encoder(src_ids, mask=enc_mask)
|
| 458 |
generated = model.decoder.greedy_decode(
|
| 459 |
memory=memory,
|
|
|
|
| 484 |
iterators: Dict,
|
| 485 |
) -> Dict[str, float]:
|
| 486 |
"""Compute inter-task gradient cosine similarity to diagnose conflicts.
|
| 487 |
+
|
| 488 |
Returns cosine similarity between gradient vectors for each task pair.
|
| 489 |
Negative values indicate conflicting gradients (negative transfer risk).
|
| 490 |
"""
|
| 491 |
task_grads: Dict[str, torch.Tensor] = {}
|
| 492 |
+
|
| 493 |
for task, loader in loaders.items():
|
| 494 |
self.optimizer.zero_grad()
|
| 495 |
batch = self._get_batch(iterators, loader, task)
|
| 496 |
if batch is None:
|
| 497 |
continue
|
| 498 |
+
|
| 499 |
dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
|
| 500 |
with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
|
| 501 |
loss, _ = self._forward_task(task, batch)
|
| 502 |
+
|
| 503 |
if torch.isnan(loss):
|
| 504 |
continue
|
| 505 |
+
|
| 506 |
loss.backward()
|
| 507 |
+
|
| 508 |
# Flatten all gradients into a single vector
|
| 509 |
grad_vec = []
|
| 510 |
for p in self.model.parameters():
|
|
|
|
| 512 |
grad_vec.append(p.grad.detach().clone().flatten())
|
| 513 |
if grad_vec:
|
| 514 |
task_grads[task] = torch.cat(grad_vec)
|
| 515 |
+
|
| 516 |
self.optimizer.zero_grad()
|
| 517 |
+
|
| 518 |
# Compute pairwise cosine similarity
|
| 519 |
stats: Dict[str, float] = {}
|
| 520 |
tasks = list(task_grads.keys())
|
|
|
|
| 525 |
cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
|
| 526 |
stats[f"cos_sim_{t1}_{t2}"] = cos_sim
|
| 527 |
stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
|
| 528 |
+
|
| 529 |
return stats
|
| 530 |
|
| 531 |
def _log_config(self) -> None:
|
| 532 |
"""Log config to MLflow."""
|
| 533 |
+
mlflow.log_params(
|
| 534 |
+
{
|
| 535 |
+
"max_epochs": self.config.max_epochs,
|
| 536 |
+
"gradient_clip_norm": self.config.gradient_clip_norm,
|
| 537 |
+
"label_smoothing": self.config.label_smoothing,
|
| 538 |
+
"task_weights": str(self.config.task_weights),
|
| 539 |
+
"warmup_steps": self.config.warmup_steps,
|
| 540 |
+
"scheduler_type": self.config.scheduler_type,
|
| 541 |
+
"learning_rate": self.optimizer.param_groups[0]["lr"],
|
| 542 |
+
}
|
| 543 |
+
)
|
| 544 |
|
| 545 |
def _log_metrics(self, metrics: Dict[str, float], prefix: str, epoch: int) -> None:
|
| 546 |
"""Log metrics to MLflow."""
|
src/utils/__init__.py
CHANGED
|
@@ -14,9 +14,16 @@ from .io import load_state, save_state
|
|
| 14 |
from .labels import load_label_metadata, save_label_metadata
|
| 15 |
|
| 16 |
__all__ = [
|
| 17 |
-
"save_checkpoint",
|
| 18 |
-
"
|
| 19 |
-
"
|
| 20 |
-
"
|
| 21 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
]
|
|
|
|
| 14 |
from .labels import load_label_metadata, save_label_metadata
|
| 15 |
|
| 16 |
__all__ = [
|
| 17 |
+
"save_checkpoint",
|
| 18 |
+
"load_checkpoint",
|
| 19 |
+
"save_state",
|
| 20 |
+
"load_state",
|
| 21 |
+
"LabelMetadata",
|
| 22 |
+
"load_labels",
|
| 23 |
+
"save_labels",
|
| 24 |
+
"load_label_metadata",
|
| 25 |
+
"save_label_metadata",
|
| 26 |
+
"set_seed",
|
| 27 |
+
"Config",
|
| 28 |
+
"load_yaml",
|
| 29 |
]
|
src/utils/core.py
CHANGED
|
@@ -28,7 +28,7 @@ def save_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
|
|
| 28 |
"""Save model state dict, handling torch.compile artifacts."""
|
| 29 |
path = Path(path)
|
| 30 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 31 |
-
|
| 32 |
# Strip '_orig_mod.' prefix from compiled models
|
| 33 |
state_dict = {k.replace("_orig_mod.", ""): v for k, v in model.state_dict().items()}
|
| 34 |
torch.save(state_dict, path)
|
|
@@ -47,7 +47,7 @@ def load_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
|
|
| 47 |
@dataclass
|
| 48 |
class LabelMetadata:
|
| 49 |
"""Container for emotion and topic label vocabularies."""
|
| 50 |
-
|
| 51 |
emotion: List[str]
|
| 52 |
topic: List[str]
|
| 53 |
|
|
@@ -65,16 +65,16 @@ def load_labels(path: str | Path) -> LabelMetadata:
|
|
| 65 |
path = Path(path)
|
| 66 |
if not path.exists():
|
| 67 |
raise FileNotFoundError(f"Labels not found: {path}")
|
| 68 |
-
|
| 69 |
with path.open("r", encoding="utf-8") as f:
|
| 70 |
data = json.load(f)
|
| 71 |
-
|
| 72 |
emotion = data.get("emotion") or data.get("emotions", [])
|
| 73 |
topic = data.get("topic") or data.get("topics", [])
|
| 74 |
-
|
| 75 |
if not emotion or not topic:
|
| 76 |
raise ValueError("Labels file must contain 'emotion' and 'topic' lists")
|
| 77 |
-
|
| 78 |
return LabelMetadata(emotion=emotion, topic=topic)
|
| 79 |
|
| 80 |
|
|
@@ -82,7 +82,7 @@ def save_labels(labels: LabelMetadata, path: str | Path) -> None:
|
|
| 82 |
"""Save label metadata to JSON file."""
|
| 83 |
path = Path(path)
|
| 84 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 85 |
-
|
| 86 |
with path.open("w", encoding="utf-8") as f:
|
| 87 |
json.dump({"emotion": labels.emotion, "topic": labels.topic}, f, indent=2)
|
| 88 |
|
|
@@ -105,12 +105,14 @@ def set_seed(seed: int) -> None:
|
|
| 105 |
@dataclass
|
| 106 |
class Config:
|
| 107 |
"""Simple config wrapper."""
|
|
|
|
| 108 |
data: dict
|
| 109 |
|
| 110 |
|
| 111 |
def load_yaml(path: str | Path) -> Config:
|
| 112 |
"""Load YAML configuration file."""
|
| 113 |
import yaml
|
|
|
|
| 114 |
with Path(path).open("r", encoding="utf-8") as f:
|
| 115 |
content = yaml.safe_load(f)
|
| 116 |
if not isinstance(content, dict):
|
|
|
|
| 28 |
"""Save model state dict, handling torch.compile artifacts."""
|
| 29 |
path = Path(path)
|
| 30 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
|
| 32 |
# Strip '_orig_mod.' prefix from compiled models
|
| 33 |
state_dict = {k.replace("_orig_mod.", ""): v for k, v in model.state_dict().items()}
|
| 34 |
torch.save(state_dict, path)
|
|
|
|
| 47 |
@dataclass
|
| 48 |
class LabelMetadata:
|
| 49 |
"""Container for emotion and topic label vocabularies."""
|
| 50 |
+
|
| 51 |
emotion: List[str]
|
| 52 |
topic: List[str]
|
| 53 |
|
|
|
|
| 65 |
path = Path(path)
|
| 66 |
if not path.exists():
|
| 67 |
raise FileNotFoundError(f"Labels not found: {path}")
|
| 68 |
+
|
| 69 |
with path.open("r", encoding="utf-8") as f:
|
| 70 |
data = json.load(f)
|
| 71 |
+
|
| 72 |
emotion = data.get("emotion") or data.get("emotions", [])
|
| 73 |
topic = data.get("topic") or data.get("topics", [])
|
| 74 |
+
|
| 75 |
if not emotion or not topic:
|
| 76 |
raise ValueError("Labels file must contain 'emotion' and 'topic' lists")
|
| 77 |
+
|
| 78 |
return LabelMetadata(emotion=emotion, topic=topic)
|
| 79 |
|
| 80 |
|
|
|
|
| 82 |
"""Save label metadata to JSON file."""
|
| 83 |
path = Path(path)
|
| 84 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 85 |
+
|
| 86 |
with path.open("w", encoding="utf-8") as f:
|
| 87 |
json.dump({"emotion": labels.emotion, "topic": labels.topic}, f, indent=2)
|
| 88 |
|
|
|
|
| 105 |
@dataclass
|
| 106 |
class Config:
|
| 107 |
"""Simple config wrapper."""
|
| 108 |
+
|
| 109 |
data: dict
|
| 110 |
|
| 111 |
|
| 112 |
def load_yaml(path: str | Path) -> Config:
|
| 113 |
"""Load YAML configuration file."""
|
| 114 |
import yaml
|
| 115 |
+
|
| 116 |
with Path(path).open("r", encoding="utf-8") as f:
|
| 117 |
content = yaml.safe_load(f)
|
| 118 |
if not isinstance(content, dict):
|
tests/test_training/test_trainer.py
CHANGED
|
@@ -111,8 +111,9 @@ class TestGradientFlow(unittest.TestCase):
|
|
| 111 |
loss = nn.CrossEntropyLoss()(logits, batch["labels"])
|
| 112 |
loss.backward()
|
| 113 |
|
| 114 |
-
has_grads = any(
|
| 115 |
-
|
|
|
|
| 116 |
self.assertTrue(has_grads, "No gradients found")
|
| 117 |
|
| 118 |
def test_emotion_gradients(self):
|
|
@@ -130,8 +131,9 @@ class TestGradientFlow(unittest.TestCase):
|
|
| 130 |
loss = nn.BCEWithLogitsLoss()(logits, batch["labels"])
|
| 131 |
loss.backward()
|
| 132 |
|
| 133 |
-
has_grads = any(
|
| 134 |
-
|
|
|
|
| 135 |
self.assertTrue(has_grads, "No gradients found")
|
| 136 |
|
| 137 |
def test_summarization_gradients(self):
|
|
@@ -145,14 +147,12 @@ class TestGradientFlow(unittest.TestCase):
|
|
| 145 |
self.model.zero_grad()
|
| 146 |
logits = self.model.forward("summarization", batch)
|
| 147 |
# Flatten for cross entropy: (B*T, vocab) vs (B*T,)
|
| 148 |
-
loss = nn.CrossEntropyLoss()(
|
| 149 |
-
logits.view(-1, 100),
|
| 150 |
-
batch["labels"].view(-1)
|
| 151 |
-
)
|
| 152 |
loss.backward()
|
| 153 |
|
| 154 |
-
has_grads = any(
|
| 155 |
-
|
|
|
|
| 156 |
self.assertTrue(has_grads, "No gradients found")
|
| 157 |
|
| 158 |
|
|
|
|
| 111 |
loss = nn.CrossEntropyLoss()(logits, batch["labels"])
|
| 112 |
loss.backward()
|
| 113 |
|
| 114 |
+
has_grads = any(
|
| 115 |
+
p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters()
|
| 116 |
+
)
|
| 117 |
self.assertTrue(has_grads, "No gradients found")
|
| 118 |
|
| 119 |
def test_emotion_gradients(self):
|
|
|
|
| 131 |
loss = nn.BCEWithLogitsLoss()(logits, batch["labels"])
|
| 132 |
loss.backward()
|
| 133 |
|
| 134 |
+
has_grads = any(
|
| 135 |
+
p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters()
|
| 136 |
+
)
|
| 137 |
self.assertTrue(has_grads, "No gradients found")
|
| 138 |
|
| 139 |
def test_summarization_gradients(self):
|
|
|
|
| 147 |
self.model.zero_grad()
|
| 148 |
logits = self.model.forward("summarization", batch)
|
| 149 |
# Flatten for cross entropy: (B*T, vocab) vs (B*T,)
|
| 150 |
+
loss = nn.CrossEntropyLoss()(logits.view(-1, 100), batch["labels"].view(-1))
|
|
|
|
|
|
|
|
|
|
| 151 |
loss.backward()
|
| 152 |
|
| 153 |
+
has_grads = any(
|
| 154 |
+
p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters()
|
| 155 |
+
)
|
| 156 |
self.assertTrue(has_grads, "No gradients found")
|
| 157 |
|
| 158 |
|