OliverPerrin commited on
Commit
df3ebbd
·
1 Parent(s): c472a19

updated readme, ruff formatted all files

Browse files
README.md CHANGED
@@ -11,56 +11,76 @@ pinned: false
11
  <!-- markdownlint-disable MD025 -->
12
  # LexiMind
13
 
14
- A multi-task NLP system for literary and academic text understanding. LexiMind performs **abstractive summarization**, **topic classification**, and **emotion detection** using a single encoder-decoder transformer initialized from [FLAN-T5-base](https://huggingface.co/google/flan-t5-base) (272M parameters).
15
 
16
  **[Live Demo](https://huggingface.co/spaces/OliverPerrin/LexiMind)** · **[Model](https://huggingface.co/OliverPerrin/LexiMind-Model)** · **[Discovery Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)** · **[Research Paper](docs/research_paper.tex)**
17
 
18
- ## What It Does
19
 
20
- | Task | Description | Metric |
21
- | ------ | ------------- | -------- |
22
- | **Summarization** | Generates back-cover style book descriptions and paper abstracts from source text | BERTScore F1: **0.830** |
23
- | **Topic Classification** | Classifies passages into 7 categories | Accuracy: **85.2%** |
24
- | **Emotion Detection** | Identifies emotions from 28 fine-grained labels (multi-label) | Sample-avg F1: **0.199** |
 
 
 
25
 
26
- **Topic labels:** Arts · Business · Fiction · History · Philosophy · Science · Technology
27
 
28
- The model is trained on literary text (Project Gutenberg + Goodreads descriptions), academic papers (arXiv), and emotion-annotated Reddit comments (GoEmotions). For summarization, it learns to produce descriptive summaries—what a book *is about*—rather than plot recaps, by pairing Gutenberg full texts with Goodreads descriptions and arXiv bodies with their abstracts.
 
 
 
 
 
 
 
 
29
 
30
  ## Architecture
31
 
32
- LexiMind is a **custom Transformer implementation** that loads pre-trained weights from FLAN-T5-base via a factory module. The architecture is reimplemented from scratch for transparency, not wrapped from HuggingFace.
33
 
34
  | Component | Detail |
35
- | ----------- | -------- |
36
  | Backbone | Encoder-Decoder Transformer (272M params) |
37
- | Encoder / Decoder | 12 layers each |
38
- | Hidden Dim | 768, 12 attention heads |
39
- | Position Encoding | T5-style relative position bias |
40
- | Normalization | RMSNorm (Pre-LN) |
41
- | Attention | FlashAttention via PyTorch 2.0 SDPA |
42
- | Summarization Head | Full decoder with language modeling head |
43
- | Classification Heads | Linear layers on mean-pooled encoder states |
44
 
45
  ### Multi-Task Training
46
 
47
- All three tasks share the encoder. Summarization uses the full encoder-decoder; topic and emotion classification branch off the encoder with lightweight linear heads. Training uses round-robin scheduling (one batch per task per step), fixed loss weights (summarization=1.0, emotion=1.0, topic=0.3), and early stopping.
 
 
 
 
 
 
 
 
48
 
49
  ## Training Data
50
 
51
- | Task | Source | Train Samples |
52
- | ------ | -------- | --------------- |
53
- | Summarization | Gutenberg + Goodreads (literary) | ~4K |
54
  | Summarization | arXiv body → abstract (academic) | ~45K |
55
- | Topic | 20 Newsgroups + Gutenberg + arXiv metadata | 3,402 |
56
- | Emotion | GoEmotions (Reddit comments, 28 labels) | 43,410 |
 
 
57
 
58
  ## Getting Started
59
 
60
  ### Prerequisites
61
 
62
  - Python 3.10+
63
- - [Poetry](https://python-poetry.org/) for dependency management
64
  - NVIDIA GPU with CUDA (for training; CPU works for inference)
65
 
66
  ### Installation
@@ -68,59 +88,50 @@ All three tasks share the encoder. Summarization uses the full encoder-decoder;
68
  ```bash
69
  git clone https://github.com/OliverPerrin/LexiMind.git
70
  cd LexiMind
71
- poetry install
72
  ```
73
 
74
- ### Download Data
75
-
76
- ```bash
77
- poetry run python scripts/download_data.py
78
- ```
79
-
80
- Downloads Goodreads descriptions, arXiv papers, GoEmotions, 20 Newsgroups, and Gutenberg texts.
81
-
82
  ### Training
83
 
84
  ```bash
85
- # Full training (~45-60 min on RTX 4070 12GB)
86
- poetry run python scripts/train.py training=full
87
-
88
- # Quick dev run (~10-15 min)
89
- poetry run python scripts/train.py training=dev
90
 
91
- # Medium run (~30-45 min)
92
- poetry run python scripts/train.py training=medium
93
 
94
  # Override parameters
95
- poetry run python scripts/train.py training.optimizer.lr=5e-5
96
 
97
  # Resume from checkpoint
98
- poetry run python scripts/train.py training=full resume_from=checkpoints/epoch_5.pt
99
  ```
100
 
101
- Training uses BFloat16 mixed precision, gradient checkpointing, `torch.compile`, and cosine LR decay with warmup. Experiments are tracked with MLflow (`mlflow ui` to browse).
102
 
103
  ### Evaluation
104
 
105
  ```bash
106
- # Full evaluation (ROUGE, BERTScore, topic accuracy, emotion F1)
107
- poetry run python scripts/evaluate.py
108
-
109
- # Skip BERTScore for faster runs
110
- poetry run python scripts/evaluate.py --skip-bertscore
111
-
112
- # Single task
113
- poetry run python scripts/evaluate.py --summarization-only
114
  ```
115
 
116
  ### Inference
117
 
118
  ```bash
119
  # Command-line
120
- poetry run python scripts/inference.py "Your text to analyze"
121
 
122
  # Gradio web demo
123
- poetry run python scripts/demo_gradio.py
 
 
 
 
 
 
 
124
  ```
125
 
126
  ### Docker
@@ -133,76 +144,39 @@ docker run -p 7860:7860 leximind
133
  ## Project Structure
134
 
135
  ```text
136
- configs/
137
- ├── config.yaml # Main Hydra config
138
- ├── data/datasets.yaml # Dataset paths and tokenizer settings
139
- ├── model/ # Architecture configs (base, small, large)
140
- └── training/ # Training configs (dev, medium, full)
141
-
142
  src/
143
- ├── models/
144
- ├── encoder.py # Transformer Encoder with Pre-LN RMSNorm
145
- ├── decoder.py # Transformer Decoder with KV-cache
146
- ├── attention.py # Multi-Head Attention + T5 relative position bias
147
- ├── feedforward.py # Gated feed-forward network
148
- │ ├── positional_encoding.py # Sinusoidal & learned position encodings
149
- │ ├── t5_layer_norm.py # T5-style RMSNorm
150
- │ ├── heads.py # Task-specific classification heads
151
- │ ├── multitask.py # Multi-task model combining all components
152
- │ └── factory.py # Model builder with FLAN-T5 weight loading
153
- ├── data/
154
- │ ├── dataset.py # Dataset classes for all tasks
155
- │ ├── dataloader.py # Multi-task dataloader with round-robin sampling
156
- │ └── tokenization.py # Tokenizer wrapper
157
- ├── training/
158
- │ ├── trainer.py # Training loop with AMP, grad accumulation, early stopping
159
- │ ├── metrics.py # ROUGE, BERTScore, F1, accuracy computation
160
- │ └── utils.py # Checkpointing, logging utilities
161
- ├── inference/
162
- │ ├── pipeline.py # End-to-end inference pipeline
163
- │ └── factory.py # Model loading for inference
164
- ├── api/ # FastAPI REST endpoint
165
- └── utils/ # Shared utilities
166
 
167
  scripts/
168
- ├── train.py # Training entry point
169
- ├── evaluate.py # Evaluation with all metrics
170
- ├── inference.py # CLI inference
171
- ├── demo_gradio.py # Gradio web UI
172
- ├── download_data.py # Dataset downloader
173
- ├── export_model.py # Model export utilities
174
- ├── export_tokenizer.py # Tokenizer export
175
- ├── preprocess_data.py # Data preprocessing
176
- ── process_books.py # Gutenberg text processing
177
- ├── eval_rouge.py # ROUGE-only evaluation
178
- └── visualize_training.py # Training curve plotting
179
-
180
- tests/ # Pytest suite (data, models, training, inference, utils)
181
- docs/ # Research paper and architecture notes
182
- artifacts/ # Tokenizer files and label definitions
183
- checkpoints/ # Saved model checkpoints
184
  ```
185
 
186
  ## Code Quality
187
 
188
  ```bash
189
- poetry run ruff check . # Linting
190
- poetry run mypy . # Type checking
191
- poetry run pytest # Test suite
192
- poetry run pre-commit run --all-files # All checks
193
  ```
194
 
195
- ## Key Results
196
-
197
- From the research paper ([docs/research_paper.tex](docs/research_paper.tex)):
198
-
199
- - **Multi-task learning helps topic classification** (+3.2% accuracy over single-task) because the small topic dataset (3.4K) benefits from shared encoder representations trained on the larger summarization corpus (49K).
200
- - **Summarization is robust to MTL**—quality stays comparable whether trained alone or jointly.
201
- - **Emotion detection shows slight negative transfer** (−0.02 F1), likely due to domain mismatch between Reddit-sourced emotion labels and literary/academic text.
202
- - **FLAN-T5 pre-training is essential**—random initialization produces dramatically worse results on all tasks.
203
-
204
- See the paper for full ablations, per-class breakdowns, and discussion of limitations.
205
-
206
  ## License
207
 
208
  GPL-3.0 — see [LICENSE](LICENSE) for details.
 
11
  <!-- markdownlint-disable MD025 -->
12
  # LexiMind
13
 
14
+ A multi-task NLP system for literary and academic text understanding. LexiMind jointly performs **abstractive summarization**, **topic classification**, and **multi-label emotion detection** using a single encoder-decoder transformer initialized from [FLAN-T5-base](https://huggingface.co/google/flan-t5-base) (272M parameters).
15
 
16
  **[Live Demo](https://huggingface.co/spaces/OliverPerrin/LexiMind)** · **[Model](https://huggingface.co/OliverPerrin/LexiMind-Model)** · **[Discovery Dataset](https://huggingface.co/datasets/OliverPerrin/LexiMind-Discovery)** · **[Research Paper](docs/research_paper.tex)**
17
 
18
+ ## Results
19
 
20
+ | Task | Metric | Score |
21
+ | ---- | ------ | ----- |
22
+ | Summarization | ROUGE-1 / ROUGE-L | 0.309 / 0.185 |
23
+ | Summarization (academic) | ROUGE-1 | 0.319 |
24
+ | Summarization (literary) | ROUGE-1 | 0.206 |
25
+ | Topic Classification | Accuracy (95% CI) | 85.7% (80.4–91.0%) |
26
+ | Emotion Detection | Sample-avg F1 | 0.352 |
27
+ | Emotion Detection (tuned thresholds) | Sample-avg F1 / Macro F1 | 0.503 / 0.294 |
28
 
29
+ Trained for 8 epochs on an RTX 4070 12GB (~9 hours) with BFloat16 mixed precision, `torch.compile`, and cosine LR decay.
30
 
31
+ ## Key Findings
32
+
33
+ From the [research paper](docs/research_paper.tex):
34
+
35
+ - **Naive MTL produces mixed results**: topic classification benefits (+3.7% accuracy), but emotion detection suffers negative transfer (−0.02 F1) under mean pooling with round-robin scheduling.
36
+ - **Learned attention pooling + temperature sampling eliminates negative transfer entirely**: emotion F1 improves from 0.199 → 0.352 (+77%), surpassing the single-task baseline (0.218).
37
+ - **Summarization is robust to MTL** — quality remains stable across configurations.
38
+ - **FLAN-T5 pre-training is essential** — random initialization produces dramatically worse results on all tasks.
39
+ - **Domain gap matters**: academic summaries (ROUGE-1: 0.319) substantially outperform literary (0.206), driven by an 11:1 training data imbalance.
40
 
41
  ## Architecture
42
 
43
+ LexiMind is a **from-scratch PyTorch Transformer** that loads pre-trained FLAN-T5-base weights layer by layer via a custom factory module no HuggingFace model wrappers.
44
 
45
  | Component | Detail |
46
+ | --------- | ------ |
47
  | Backbone | Encoder-Decoder Transformer (272M params) |
48
+ | Encoder / Decoder | 12 layers each, 768d, 12 attention heads |
49
+ | Normalization | RMSNorm (Pre-LN, T5-style) |
50
+ | Attention | FlashAttention via PyTorch SDPA + T5 relative position bias |
51
+ | FFN | Gated-GELU (wi\_0, wi\_1, wo) |
52
+ | Summarization | Full decoder language modeling head |
53
+ | Emotion (28-class multi-label) | Learned attention pooling linear head |
54
+ | Topic (7-class) | Mean pooling linear head |
55
 
56
  ### Multi-Task Training
57
 
58
+ All three tasks share the encoder. Summarization uses the full encoder-decoder; classification heads branch off the encoder output. Key training details:
59
+
60
+ - **Temperature-based task sampling** (α=0.5): allocates training steps proportional to dataset size, preventing large tasks from dominating
61
+ - **Attention pooling** for emotion: a learned query attends over encoder outputs, focusing on emotionally salient tokens rather than averaging the full sequence
62
+ - **Fixed loss weights**: summarization=1.0, emotion=1.0, topic=0.3 (reduced to prevent overfitting on the small topic dataset)
63
+ - **Frozen encoder layers 0–3**: preserves FLAN-T5's language understanding in lower layers
64
+ - **Gradient conflict diagnostics**: optional inter-task gradient cosine similarity monitoring
65
+
66
+ See [docs/architecture.md](docs/architecture.md) for full implementation details, weight loading tables, and training configuration rationale.
67
 
68
  ## Training Data
69
 
70
+ | Task | Source | Samples |
71
+ | ---- | ------ | ------- |
72
+ | Summarization | Gutenberg + Goodreads descriptions (literary) | ~4K |
73
  | Summarization | arXiv body → abstract (academic) | ~45K |
74
+ | Topic | Gutenberg + arXiv metadata → 7 categories | 3,402 |
75
+ | Emotion | GoEmotions Reddit comments, 28 labels | 43,410 |
76
+
77
+ For summarization, the model learns to produce descriptive summaries — what a book *is about* — rather than plot recaps, by pairing Gutenberg full texts with Goodreads descriptions and arXiv papers with their abstracts.
78
 
79
  ## Getting Started
80
 
81
  ### Prerequisites
82
 
83
  - Python 3.10+
 
84
  - NVIDIA GPU with CUDA (for training; CPU works for inference)
85
 
86
  ### Installation
 
88
  ```bash
89
  git clone https://github.com/OliverPerrin/LexiMind.git
90
  cd LexiMind
91
+ pip install -r requirements.txt
92
  ```
93
 
 
 
 
 
 
 
 
 
94
  ### Training
95
 
96
  ```bash
97
+ # Full training (~9 hours on RTX 4070 12GB)
98
+ python scripts/train.py training=full
 
 
 
99
 
100
+ # Quick dev run
101
+ python scripts/train.py training=dev
102
 
103
  # Override parameters
104
+ python scripts/train.py training=full training.optimizer.lr=5e-5
105
 
106
  # Resume from checkpoint
107
+ python scripts/train.py training=full resume_from=checkpoints/epoch_5.pt
108
  ```
109
 
110
+ Experiments are tracked with MLflow (`mlflow ui` to browse).
111
 
112
  ### Evaluation
113
 
114
  ```bash
115
+ python scripts/evaluate.py
116
+ python scripts/evaluate.py --skip-bertscore # faster
117
+ python scripts/evaluate.py --tune-thresholds # per-class threshold tuning
 
 
 
 
 
118
  ```
119
 
120
  ### Inference
121
 
122
  ```bash
123
  # Command-line
124
+ python scripts/inference.py "Your text to analyze"
125
 
126
  # Gradio web demo
127
+ python scripts/demo_gradio.py
128
+ ```
129
+
130
+ ### Profiling
131
+
132
+ ```bash
133
+ # Profile GPU usage (CUDA kernels, memory, Chrome trace)
134
+ python scripts/profile_training.py
135
  ```
136
 
137
  ### Docker
 
144
  ## Project Structure
145
 
146
  ```text
 
 
 
 
 
 
147
  src/
148
+ ├── models/ # Encoder, decoder, attention, FFN, heads, factory
149
+ ├── data/ # Datasets, dataloaders, tokenization, cross-task dedup
150
+ ├── training/ # Trainer (AMP, grad accum, temperature sampling), metrics
151
+ ├── inference/ # Pipeline + factory for checkpoint loading
152
+ ├── api/ # FastAPI REST endpoint
153
+ ── utils/ # Device detection, checkpointing, label I/O
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  scripts/
156
+ ├── train.py # Hydra training entry point
157
+ ├── evaluate.py # Full evaluation suite
158
+ ├── inference.py # CLI inference
159
+ ├── demo_gradio.py # Gradio discovery demo
160
+ ├── profile_training.py # PyTorch profiler
161
+ ├── train_multiseed.py # Multi-seed training with aggregation
162
+ ├── visualize_training.py # Training curve visualization
163
+ ├── download_data.py # Dataset downloader
164
+ ── build_discovery_dataset.py # Pre-compute discovery dataset
165
+
166
+ configs/ # Hydra configs (model, training, data)
167
+ docs/ # Research paper + architecture documentation
168
+ tests/ # Pytest suite
 
 
 
169
  ```
170
 
171
  ## Code Quality
172
 
173
  ```bash
174
+ ruff check . # Linting
175
+ mypy src/ scripts/ tests/ # Type checking
176
+ pytest # Tests
177
+ pre-commit run --all-files # All checks
178
  ```
179
 
 
 
 
 
 
 
 
 
 
 
 
180
  ## License
181
 
182
  GPL-3.0 — see [LICENSE](LICENSE) for details.
scripts/build_discovery_dataset.py CHANGED
@@ -29,134 +29,140 @@ from src.inference.factory import create_inference_pipeline # noqa: E402
29
 
30
  # --------------- Data Loading ---------------
31
 
 
32
  def load_academic_papers(data_dir: Path, max_samples: int = 300) -> list[dict]:
33
  """Load academic paper samples from the training data."""
34
  summ_file = data_dir / "summarization" / "train.jsonl"
35
-
36
  if not summ_file.exists():
37
  print(f" Warning: {summ_file} not found")
38
  return []
39
-
40
  academic = []
41
  with open(summ_file) as f:
42
  for line in f:
43
  item = json.loads(line)
44
  if item.get("type") != "academic":
45
  continue
46
-
47
  text = item.get("source", "")
48
  if len(text) < 500:
49
  continue
50
-
51
  # Use title from data
52
  title = item.get("title", "Research Paper")
53
-
54
- academic.append({
55
- "text": text[:2000],
56
- "title": title,
57
- "reference_summary": item.get("summary", "")[:500]
58
- })
59
-
 
 
60
  random.seed(42)
61
  samples = random.sample(academic, min(max_samples, len(academic)))
62
-
63
  results = []
64
  for i, item in enumerate(samples):
65
- results.append({
66
- "id": f"paper_{i}",
67
- "title": item["title"],
68
- "text": item["text"],
69
- "source_type": "academic",
70
- "dataset": "arxiv",
71
- "reference_summary": item["reference_summary"]
72
- })
73
-
 
 
74
  print(f" Loaded {len(results)} academic papers")
75
  return results
76
 
77
 
78
  def load_literary(data_dir: Path, max_samples: int = 300) -> list[dict]:
79
  """Load literary samples from the training data.
80
-
81
  Training data now contains Goodreads descriptions (back-cover style)
82
  instead of plot summaries.
83
  """
84
  summ_file = data_dir / "summarization" / "train.jsonl"
85
-
86
  if not summ_file.exists():
87
  print(f" Warning: {summ_file} not found")
88
  return []
89
-
90
  literary = []
91
  seen_titles = set()
92
-
93
  with open(summ_file) as f:
94
  for line in f:
95
  item = json.loads(line)
96
  if item.get("type") != "literary":
97
  continue
98
-
99
  title = item.get("title", "")
100
  if not title or title in seen_titles:
101
  continue
102
-
103
  text = item.get("source", "")
104
  summary = item.get("summary", "")
105
-
106
  if len(text) < 500 or len(summary) < 50:
107
  continue
108
-
109
  seen_titles.add(title)
110
- literary.append({
111
- "text": text[:2000],
112
- "title": title,
113
- "reference_summary": summary[:600]
114
- })
115
-
116
  random.seed(42)
117
  samples = random.sample(literary, min(max_samples, len(literary)))
118
-
119
  results = []
120
  for i, item in enumerate(samples):
121
- results.append({
122
- "id": f"literary_{i}",
123
- "title": item["title"],
124
- "text": item["text"],
125
- "source_type": "literary",
126
- "dataset": "goodreads",
127
- "reference_summary": item["reference_summary"],
128
- })
129
-
 
 
130
  print(f" Loaded {len(results)} literary works (unique titles)")
131
  return results
132
 
133
 
134
  # --------------- Inference ---------------
135
 
 
136
  def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
137
  """Run model inference on all samples."""
138
  results = []
139
-
140
  for sample in tqdm(samples, desc="Running inference"):
141
  text = sample["text"]
142
-
143
  # Get model predictions using correct pipeline methods
144
  summaries = pipeline.summarize([text])
145
  topics = pipeline.predict_topics([text])
146
  emotions = pipeline.predict_emotions([text])
147
-
148
  # Extract first result from each list
149
  summary = summaries[0] if summaries else ""
150
  topic = topics[0] if topics else None
151
  emotion = emotions[0] if emotions else None
152
-
153
  # Get primary emotion (highest confidence if any detected)
154
  primary_emotion = "neutral"
155
  emotion_confidence = 0.0
156
  if emotion and emotion.labels:
157
  primary_emotion = emotion.labels[0]
158
  emotion_confidence = emotion.scores[0]
159
-
160
  result = {
161
  "id": sample["id"],
162
  "title": sample["title"],
@@ -170,24 +176,25 @@ def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
170
  "generated_summary": summary,
171
  "reference_summary": sample.get("reference_summary", ""),
172
  }
173
-
174
  results.append(result)
175
-
176
  # Print distribution stats
177
  topic_dist: dict[str, int] = defaultdict(int)
178
  emotion_dist: dict[str, int] = defaultdict(int)
179
  for r in results:
180
  topic_dist[r["topic"]] += 1
181
  emotion_dist[r["emotion"]] += 1
182
-
183
  print(f"\nTopic distribution: {dict(topic_dist)}")
184
  print(f"Emotion distribution: {dict(emotion_dist)}")
185
-
186
  return results
187
 
188
 
189
  def main():
190
  import argparse
 
191
  parser = argparse.ArgumentParser(description="Build discovery dataset for HuggingFace Space")
192
  parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
193
  parser.add_argument("--checkpoint", type=Path, default=Path("checkpoints/best.pt"))
@@ -197,41 +204,39 @@ def main():
197
  parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
198
  parser.add_argument("--hub-repo", type=str, default="OliverPerrin/LexiMind-Discovery")
199
  args = parser.parse_args()
200
-
201
  print("Loading data samples from training data...")
202
  print("(Data has already been filtered by download_data.py)")
203
-
204
  # Load samples from training data
205
  papers = load_academic_papers(args.data_dir, args.num_papers)
206
  literary = load_literary(args.data_dir, args.num_literary)
207
-
208
  all_samples = papers + literary
209
  print(f"\nTotal samples: {len(all_samples)} ({len(papers)} papers, {len(literary)} literary)")
210
-
211
  if not all_samples:
212
  print("ERROR: No samples loaded! Check if data/processed exists and has data.")
213
  print("Run: python scripts/download_data.py --task summarization")
214
  return
215
-
216
  # Load model and run inference
217
  print(f"\nLoading model from {args.checkpoint}...")
218
  labels_path = Path("artifacts/labels.json")
219
  pipeline, labels = create_inference_pipeline(
220
- args.checkpoint,
221
- labels_path,
222
- device="cuda" if torch.cuda.is_available() else "cpu"
223
  )
224
-
225
  print("Running inference on all samples...")
226
  results = run_inference(pipeline, all_samples)
227
-
228
  # Save locally
229
  print(f"\nSaving to {args.output}...")
230
  args.output.parent.mkdir(parents=True, exist_ok=True)
231
  with open(args.output, "w") as f:
232
  for item in results:
233
  f.write(json.dumps(item) + "\n")
234
-
235
  # Push to HuggingFace Hub
236
  if args.push_to_hub:
237
  print(f"\nPushing to HuggingFace Hub: {args.hub_repo}")
@@ -239,10 +244,10 @@ def main():
239
  dataset.push_to_hub(
240
  args.hub_repo,
241
  private=False,
242
- commit_message="Rebuild with Goodreads descriptions (back-cover style)"
243
  )
244
  print(f"Dataset available at: https://huggingface.co/datasets/{args.hub_repo}")
245
-
246
  print("\nDone!")
247
 
248
 
 
29
 
30
  # --------------- Data Loading ---------------
31
 
32
+
33
  def load_academic_papers(data_dir: Path, max_samples: int = 300) -> list[dict]:
34
  """Load academic paper samples from the training data."""
35
  summ_file = data_dir / "summarization" / "train.jsonl"
36
+
37
  if not summ_file.exists():
38
  print(f" Warning: {summ_file} not found")
39
  return []
40
+
41
  academic = []
42
  with open(summ_file) as f:
43
  for line in f:
44
  item = json.loads(line)
45
  if item.get("type") != "academic":
46
  continue
47
+
48
  text = item.get("source", "")
49
  if len(text) < 500:
50
  continue
51
+
52
  # Use title from data
53
  title = item.get("title", "Research Paper")
54
+
55
+ academic.append(
56
+ {
57
+ "text": text[:2000],
58
+ "title": title,
59
+ "reference_summary": item.get("summary", "")[:500],
60
+ }
61
+ )
62
+
63
  random.seed(42)
64
  samples = random.sample(academic, min(max_samples, len(academic)))
65
+
66
  results = []
67
  for i, item in enumerate(samples):
68
+ results.append(
69
+ {
70
+ "id": f"paper_{i}",
71
+ "title": item["title"],
72
+ "text": item["text"],
73
+ "source_type": "academic",
74
+ "dataset": "arxiv",
75
+ "reference_summary": item["reference_summary"],
76
+ }
77
+ )
78
+
79
  print(f" Loaded {len(results)} academic papers")
80
  return results
81
 
82
 
83
  def load_literary(data_dir: Path, max_samples: int = 300) -> list[dict]:
84
  """Load literary samples from the training data.
85
+
86
  Training data now contains Goodreads descriptions (back-cover style)
87
  instead of plot summaries.
88
  """
89
  summ_file = data_dir / "summarization" / "train.jsonl"
90
+
91
  if not summ_file.exists():
92
  print(f" Warning: {summ_file} not found")
93
  return []
94
+
95
  literary = []
96
  seen_titles = set()
97
+
98
  with open(summ_file) as f:
99
  for line in f:
100
  item = json.loads(line)
101
  if item.get("type") != "literary":
102
  continue
103
+
104
  title = item.get("title", "")
105
  if not title or title in seen_titles:
106
  continue
107
+
108
  text = item.get("source", "")
109
  summary = item.get("summary", "")
110
+
111
  if len(text) < 500 or len(summary) < 50:
112
  continue
113
+
114
  seen_titles.add(title)
115
+ literary.append(
116
+ {"text": text[:2000], "title": title, "reference_summary": summary[:600]}
117
+ )
118
+
 
 
119
  random.seed(42)
120
  samples = random.sample(literary, min(max_samples, len(literary)))
121
+
122
  results = []
123
  for i, item in enumerate(samples):
124
+ results.append(
125
+ {
126
+ "id": f"literary_{i}",
127
+ "title": item["title"],
128
+ "text": item["text"],
129
+ "source_type": "literary",
130
+ "dataset": "goodreads",
131
+ "reference_summary": item["reference_summary"],
132
+ }
133
+ )
134
+
135
  print(f" Loaded {len(results)} literary works (unique titles)")
136
  return results
137
 
138
 
139
  # --------------- Inference ---------------
140
 
141
+
142
  def run_inference(pipeline: Any, samples: list[dict]) -> list[dict]:
143
  """Run model inference on all samples."""
144
  results = []
145
+
146
  for sample in tqdm(samples, desc="Running inference"):
147
  text = sample["text"]
148
+
149
  # Get model predictions using correct pipeline methods
150
  summaries = pipeline.summarize([text])
151
  topics = pipeline.predict_topics([text])
152
  emotions = pipeline.predict_emotions([text])
153
+
154
  # Extract first result from each list
155
  summary = summaries[0] if summaries else ""
156
  topic = topics[0] if topics else None
157
  emotion = emotions[0] if emotions else None
158
+
159
  # Get primary emotion (highest confidence if any detected)
160
  primary_emotion = "neutral"
161
  emotion_confidence = 0.0
162
  if emotion and emotion.labels:
163
  primary_emotion = emotion.labels[0]
164
  emotion_confidence = emotion.scores[0]
165
+
166
  result = {
167
  "id": sample["id"],
168
  "title": sample["title"],
 
176
  "generated_summary": summary,
177
  "reference_summary": sample.get("reference_summary", ""),
178
  }
179
+
180
  results.append(result)
181
+
182
  # Print distribution stats
183
  topic_dist: dict[str, int] = defaultdict(int)
184
  emotion_dist: dict[str, int] = defaultdict(int)
185
  for r in results:
186
  topic_dist[r["topic"]] += 1
187
  emotion_dist[r["emotion"]] += 1
188
+
189
  print(f"\nTopic distribution: {dict(topic_dist)}")
190
  print(f"Emotion distribution: {dict(emotion_dist)}")
191
+
192
  return results
193
 
194
 
195
  def main():
196
  import argparse
197
+
198
  parser = argparse.ArgumentParser(description="Build discovery dataset for HuggingFace Space")
199
  parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
200
  parser.add_argument("--checkpoint", type=Path, default=Path("checkpoints/best.pt"))
 
204
  parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
205
  parser.add_argument("--hub-repo", type=str, default="OliverPerrin/LexiMind-Discovery")
206
  args = parser.parse_args()
207
+
208
  print("Loading data samples from training data...")
209
  print("(Data has already been filtered by download_data.py)")
210
+
211
  # Load samples from training data
212
  papers = load_academic_papers(args.data_dir, args.num_papers)
213
  literary = load_literary(args.data_dir, args.num_literary)
214
+
215
  all_samples = papers + literary
216
  print(f"\nTotal samples: {len(all_samples)} ({len(papers)} papers, {len(literary)} literary)")
217
+
218
  if not all_samples:
219
  print("ERROR: No samples loaded! Check if data/processed exists and has data.")
220
  print("Run: python scripts/download_data.py --task summarization")
221
  return
222
+
223
  # Load model and run inference
224
  print(f"\nLoading model from {args.checkpoint}...")
225
  labels_path = Path("artifacts/labels.json")
226
  pipeline, labels = create_inference_pipeline(
227
+ args.checkpoint, labels_path, device="cuda" if torch.cuda.is_available() else "cpu"
 
 
228
  )
229
+
230
  print("Running inference on all samples...")
231
  results = run_inference(pipeline, all_samples)
232
+
233
  # Save locally
234
  print(f"\nSaving to {args.output}...")
235
  args.output.parent.mkdir(parents=True, exist_ok=True)
236
  with open(args.output, "w") as f:
237
  for item in results:
238
  f.write(json.dumps(item) + "\n")
239
+
240
  # Push to HuggingFace Hub
241
  if args.push_to_hub:
242
  print(f"\nPushing to HuggingFace Hub: {args.hub_repo}")
 
244
  dataset.push_to_hub(
245
  args.hub_repo,
246
  private=False,
247
+ commit_message="Rebuild with Goodreads descriptions (back-cover style)",
248
  )
249
  print(f"Dataset available at: https://huggingface.co/datasets/{args.hub_repo}")
250
+
251
  print("\nDone!")
252
 
253
 
scripts/demo_gradio.py CHANGED
@@ -27,8 +27,12 @@ print(f"Loaded {len(_dataset)} items")
27
  ALL_ITEMS: list[dict[str, Any]] = [dict(row) for row in _dataset]
28
 
29
  # Extract unique topics and emotions FROM THE DATASET (what model predicted)
30
- DATASET_TOPICS: list[str] = sorted(set(str(item["topic"]) for item in ALL_ITEMS if item.get("topic")))
31
- DATASET_EMOTIONS: list[str] = sorted(set(str(item["emotion"]) for item in ALL_ITEMS if item.get("emotion")))
 
 
 
 
32
 
33
  # Load ALL possible labels from labels.json (what the model CAN predict)
34
  _labels_path = Path(__file__).parent.parent / "artifacts" / "labels.json"
@@ -90,19 +94,19 @@ def format_item_card(item: dict) -> str:
90
  title = item.get("title", "Unknown")
91
  source_type = item.get("source_type", "unknown")
92
  dataset_name = item.get("dataset", "").title()
93
-
94
  # Icon based on type
95
  if source_type == "academic":
96
  type_label = "Research Paper"
97
  else:
98
  type_label = "Literature"
99
-
100
  # Topic and emotion with confidence
101
  topic = item.get("topic", "Unknown")
102
  topic_conf = item.get("topic_confidence", 0)
103
  emotion = item.get("emotion", "Unknown")
104
  emotion_conf = item.get("emotion_confidence", 0)
105
-
106
  # Summary - check if using reference or generated
107
  use_reference = item.get("use_reference_summary", False)
108
  if use_reference or source_type == "literary":
@@ -111,17 +115,21 @@ def format_item_card(item: dict) -> str:
111
  else:
112
  summary = item.get("generated_summary", "")
113
  summary_label = "**AI-Generated Description:**"
114
-
115
  if not summary:
116
  summary = "No summary available."
117
-
118
  # Truncate summary if too long
119
  if len(summary) > 400:
120
- summary = summary[:400].rsplit(' ', 1)[0] + "..."
121
-
122
  # Preview of original text
123
- text_preview = item.get("text", "")[:400] + "..." if len(item.get("text", "")) > 400 else item.get("text", "")
124
-
 
 
 
 
125
  return f"""### **{title}**
126
 
127
  <small>*{type_label}* from {dataset_name}</small>
@@ -147,24 +155,24 @@ def browse_by_topic(topic: str) -> str:
147
  items = get_items_by_topic(topic)
148
  if not items:
149
  return "No items found for this topic."
150
-
151
  # Group by type
152
  literary = [i for i in items if i.get("source_type") == "literary"]
153
  academic = [i for i in items if i.get("source_type") == "academic"]
154
-
155
  result = f"## {topic if topic != 'All' else 'All Topics'}\n\n"
156
  result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
157
-
158
  if literary:
159
  result += "### Literary Works\n\n"
160
  for item in literary[:25]: # Limit to avoid huge pages
161
  result += format_item_card(item)
162
-
163
  if academic:
164
  result += "### Academic Papers\n\n"
165
  for item in academic[:25]:
166
  result += format_item_card(item)
167
-
168
  return result
169
 
170
 
@@ -173,23 +181,23 @@ def browse_by_emotion(emotion: str) -> str:
173
  items = get_items_by_emotion(emotion)
174
  if not items:
175
  return "No items found for this emotion."
176
-
177
  literary = [i for i in items if i.get("source_type") == "literary"]
178
  academic = [i for i in items if i.get("source_type") == "academic"]
179
-
180
  result = f"## Feeling {emotion.title() if emotion != 'All' else 'All Emotions'}?\n\n"
181
  result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
182
-
183
  if literary:
184
  result += "### Literary Works\n\n"
185
  for item in literary[:25]:
186
  result += format_item_card(item)
187
-
188
  if academic:
189
  result += "### Academic Papers\n\n"
190
  for item in academic[:25]:
191
  result += format_item_card(item)
192
-
193
  return result
194
 
195
 
@@ -197,24 +205,25 @@ def search_items(query: str) -> str:
197
  """Search items by text content."""
198
  if not query or len(query) < 3:
199
  return "Enter at least 3 characters to search."
200
-
201
  query_lower = query.lower()
202
  matches = [
203
- item for item in ALL_ITEMS
 
204
  if query_lower in item.get("text", "").lower()
205
  or query_lower in item.get("generated_summary", "").lower()
206
  or query_lower in item.get("title", "").lower()
207
  ]
208
-
209
  if not matches:
210
  return f"No results found for '{query}'."
211
-
212
  result = f"## Search Results for '{query}'\n\n"
213
  result += f"*Found {len(matches)} matching items*\n\n"
214
-
215
  for item in matches[:30]:
216
  result += format_item_card(item)
217
-
218
  return result
219
 
220
 
@@ -226,9 +235,8 @@ with gr.Blocks(
226
  css="""
227
  .result-box { max-height: 700px; overflow-y: auto; }
228
  h3 { margin-top: 0.5em !important; }
229
- """
230
  ) as demo:
231
-
232
  gr.Markdown(
233
  """
234
  # LexiMind
@@ -237,79 +245,75 @@ with gr.Blocks(
237
  Browse **{total_count}** texts — {lit_count} classic books and {paper_count} research papers — analyzed by a multi-task transformer.
238
 
239
  ---
240
- """.format(
241
- total_count=len(ALL_ITEMS),
242
- lit_count=len(BOOKS),
243
- paper_count=len(PAPERS)
244
- )
245
  )
246
-
247
  with gr.Tabs():
248
  # ===================== TAB 1: BROWSE BY TOPIC =====================
249
  with gr.Tab("By Topic"):
250
  gr.Markdown("*Select a topic to explore related books and papers*")
251
-
252
  topic_dropdown = gr.Dropdown(
253
  choices=["All"] + TOPICS,
254
  value="All",
255
  label="Select Topic",
256
  interactive=True,
257
  )
258
-
259
  topic_results = gr.Markdown(
260
  value=browse_by_topic("All"),
261
  elem_classes=["result-box"],
262
  )
263
-
264
  topic_dropdown.change(
265
  fn=browse_by_topic,
266
  inputs=[topic_dropdown],
267
  outputs=[topic_results],
268
  )
269
-
270
  # ===================== TAB 2: BROWSE BY EMOTION =====================
271
  with gr.Tab("By Emotion"):
272
  gr.Markdown("*Find books and papers that evoke specific emotions*")
273
-
274
  emotion_dropdown = gr.Dropdown(
275
  choices=["All"] + [e.title() for e in EMOTIONS],
276
  value="All",
277
  label="Select Emotion",
278
  interactive=True,
279
  )
280
-
281
  emotion_results = gr.Markdown(
282
  value=browse_by_emotion("All"),
283
  elem_classes=["result-box"],
284
  )
285
-
286
  emotion_dropdown.change(
287
  fn=lambda e: browse_by_emotion(e.lower() if e != "All" else "All"),
288
  inputs=[emotion_dropdown],
289
  outputs=[emotion_results],
290
  )
291
-
292
  # ===================== TAB 3: SEARCH =====================
293
  with gr.Tab("Search"):
294
  gr.Markdown("*Search through all books and papers by keyword*")
295
-
296
  search_input = gr.Textbox(
297
  placeholder="Enter keywords to search...",
298
  label="Search",
299
  interactive=True,
300
  )
301
-
302
  search_results = gr.Markdown(
303
  value="Enter at least 3 characters to search.",
304
  elem_classes=["result-box"],
305
  )
306
-
307
  search_input.change(
308
  fn=search_items,
309
  inputs=[search_input],
310
  outputs=[search_results],
311
  )
312
-
313
  # ===================== TAB 4: METRICS =====================
314
  with gr.Tab("Metrics"):
315
  gr.Markdown(
@@ -319,10 +323,10 @@ with gr.Blocks(
319
  Computed on held-out validation data.
320
  """
321
  )
322
-
323
  # Summarization Metrics
324
  gr.Markdown("#### Summarization")
325
-
326
  if METRICS.get("summarization"):
327
  summ = METRICS["summarization"]
328
  summ_md = """
@@ -341,10 +345,10 @@ with gr.Blocks(
341
  gr.Markdown(summ_md)
342
  else:
343
  gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
344
-
345
  # Topic Classification Metrics
346
  gr.Markdown("#### Topic Classification")
347
-
348
  if METRICS.get("topic"):
349
  topic = METRICS["topic"]
350
  topic_md = """
@@ -359,10 +363,10 @@ with gr.Blocks(
359
  gr.Markdown(topic_md)
360
  else:
361
  gr.Markdown("*Topic classification metrics not available.*")
362
-
363
  # Emotion Detection Metrics
364
  gr.Markdown("#### Emotion Detection")
365
-
366
  if METRICS.get("emotion"):
367
  emotion = METRICS["emotion"]
368
  emotion_md = """
@@ -374,17 +378,19 @@ with gr.Blocks(
374
 
375
  *28-label multi-label classification from GoEmotions.*
376
  """.format(
377
- sample_f1=emotion.get("sample_avg_f1", emotion.get("f1", emotion.get("multilabel_f1", 0))),
 
 
378
  macro_f1=emotion.get("macro_f1", 0),
379
  micro_f1=emotion.get("micro_f1", 0),
380
  )
381
  gr.Markdown(emotion_md)
382
  else:
383
  gr.Markdown("*Emotion detection metrics not available.*")
384
-
385
  # Dataset Statistics
386
  gr.Markdown("#### Dataset Statistics")
387
-
388
  gr.Markdown(f"""
389
  | Statistic | Value |
390
  |-----------|-------|
@@ -394,7 +400,7 @@ with gr.Blocks(
394
  | Topics | {len(TOPICS)} |
395
  | Emotions | {len(EMOTIONS)} |
396
  """)
397
-
398
  # ===================== TAB 5: ABOUT =====================
399
  with gr.Tab("About"):
400
  gr.Markdown(
@@ -420,4 +426,3 @@ with gr.Blocks(
420
 
421
  if __name__ == "__main__":
422
  demo.launch(server_name="0.0.0.0", server_port=7860)
423
-
 
27
  ALL_ITEMS: list[dict[str, Any]] = [dict(row) for row in _dataset]
28
 
29
  # Extract unique topics and emotions FROM THE DATASET (what model predicted)
30
+ DATASET_TOPICS: list[str] = sorted(
31
+ set(str(item["topic"]) for item in ALL_ITEMS if item.get("topic"))
32
+ )
33
+ DATASET_EMOTIONS: list[str] = sorted(
34
+ set(str(item["emotion"]) for item in ALL_ITEMS if item.get("emotion"))
35
+ )
36
 
37
  # Load ALL possible labels from labels.json (what the model CAN predict)
38
  _labels_path = Path(__file__).parent.parent / "artifacts" / "labels.json"
 
94
  title = item.get("title", "Unknown")
95
  source_type = item.get("source_type", "unknown")
96
  dataset_name = item.get("dataset", "").title()
97
+
98
  # Icon based on type
99
  if source_type == "academic":
100
  type_label = "Research Paper"
101
  else:
102
  type_label = "Literature"
103
+
104
  # Topic and emotion with confidence
105
  topic = item.get("topic", "Unknown")
106
  topic_conf = item.get("topic_confidence", 0)
107
  emotion = item.get("emotion", "Unknown")
108
  emotion_conf = item.get("emotion_confidence", 0)
109
+
110
  # Summary - check if using reference or generated
111
  use_reference = item.get("use_reference_summary", False)
112
  if use_reference or source_type == "literary":
 
115
  else:
116
  summary = item.get("generated_summary", "")
117
  summary_label = "**AI-Generated Description:**"
118
+
119
  if not summary:
120
  summary = "No summary available."
121
+
122
  # Truncate summary if too long
123
  if len(summary) > 400:
124
+ summary = summary[:400].rsplit(" ", 1)[0] + "..."
125
+
126
  # Preview of original text
127
+ text_preview = (
128
+ item.get("text", "")[:400] + "..."
129
+ if len(item.get("text", "")) > 400
130
+ else item.get("text", "")
131
+ )
132
+
133
  return f"""### **{title}**
134
 
135
  <small>*{type_label}* from {dataset_name}</small>
 
155
  items = get_items_by_topic(topic)
156
  if not items:
157
  return "No items found for this topic."
158
+
159
  # Group by type
160
  literary = [i for i in items if i.get("source_type") == "literary"]
161
  academic = [i for i in items if i.get("source_type") == "academic"]
162
+
163
  result = f"## {topic if topic != 'All' else 'All Topics'}\n\n"
164
  result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
165
+
166
  if literary:
167
  result += "### Literary Works\n\n"
168
  for item in literary[:25]: # Limit to avoid huge pages
169
  result += format_item_card(item)
170
+
171
  if academic:
172
  result += "### Academic Papers\n\n"
173
  for item in academic[:25]:
174
  result += format_item_card(item)
175
+
176
  return result
177
 
178
 
 
181
  items = get_items_by_emotion(emotion)
182
  if not items:
183
  return "No items found for this emotion."
184
+
185
  literary = [i for i in items if i.get("source_type") == "literary"]
186
  academic = [i for i in items if i.get("source_type") == "academic"]
187
+
188
  result = f"## Feeling {emotion.title() if emotion != 'All' else 'All Emotions'}?\n\n"
189
  result += f"*Found {len(items)} items ({len(literary)} literary, {len(academic)} academic)*\n\n"
190
+
191
  if literary:
192
  result += "### Literary Works\n\n"
193
  for item in literary[:25]:
194
  result += format_item_card(item)
195
+
196
  if academic:
197
  result += "### Academic Papers\n\n"
198
  for item in academic[:25]:
199
  result += format_item_card(item)
200
+
201
  return result
202
 
203
 
 
205
  """Search items by text content."""
206
  if not query or len(query) < 3:
207
  return "Enter at least 3 characters to search."
208
+
209
  query_lower = query.lower()
210
  matches = [
211
+ item
212
+ for item in ALL_ITEMS
213
  if query_lower in item.get("text", "").lower()
214
  or query_lower in item.get("generated_summary", "").lower()
215
  or query_lower in item.get("title", "").lower()
216
  ]
217
+
218
  if not matches:
219
  return f"No results found for '{query}'."
220
+
221
  result = f"## Search Results for '{query}'\n\n"
222
  result += f"*Found {len(matches)} matching items*\n\n"
223
+
224
  for item in matches[:30]:
225
  result += format_item_card(item)
226
+
227
  return result
228
 
229
 
 
235
  css="""
236
  .result-box { max-height: 700px; overflow-y: auto; }
237
  h3 { margin-top: 0.5em !important; }
238
+ """,
239
  ) as demo:
 
240
  gr.Markdown(
241
  """
242
  # LexiMind
 
245
  Browse **{total_count}** texts — {lit_count} classic books and {paper_count} research papers — analyzed by a multi-task transformer.
246
 
247
  ---
248
+ """.format(total_count=len(ALL_ITEMS), lit_count=len(BOOKS), paper_count=len(PAPERS))
 
 
 
 
249
  )
250
+
251
  with gr.Tabs():
252
  # ===================== TAB 1: BROWSE BY TOPIC =====================
253
  with gr.Tab("By Topic"):
254
  gr.Markdown("*Select a topic to explore related books and papers*")
255
+
256
  topic_dropdown = gr.Dropdown(
257
  choices=["All"] + TOPICS,
258
  value="All",
259
  label="Select Topic",
260
  interactive=True,
261
  )
262
+
263
  topic_results = gr.Markdown(
264
  value=browse_by_topic("All"),
265
  elem_classes=["result-box"],
266
  )
267
+
268
  topic_dropdown.change(
269
  fn=browse_by_topic,
270
  inputs=[topic_dropdown],
271
  outputs=[topic_results],
272
  )
273
+
274
  # ===================== TAB 2: BROWSE BY EMOTION =====================
275
  with gr.Tab("By Emotion"):
276
  gr.Markdown("*Find books and papers that evoke specific emotions*")
277
+
278
  emotion_dropdown = gr.Dropdown(
279
  choices=["All"] + [e.title() for e in EMOTIONS],
280
  value="All",
281
  label="Select Emotion",
282
  interactive=True,
283
  )
284
+
285
  emotion_results = gr.Markdown(
286
  value=browse_by_emotion("All"),
287
  elem_classes=["result-box"],
288
  )
289
+
290
  emotion_dropdown.change(
291
  fn=lambda e: browse_by_emotion(e.lower() if e != "All" else "All"),
292
  inputs=[emotion_dropdown],
293
  outputs=[emotion_results],
294
  )
295
+
296
  # ===================== TAB 3: SEARCH =====================
297
  with gr.Tab("Search"):
298
  gr.Markdown("*Search through all books and papers by keyword*")
299
+
300
  search_input = gr.Textbox(
301
  placeholder="Enter keywords to search...",
302
  label="Search",
303
  interactive=True,
304
  )
305
+
306
  search_results = gr.Markdown(
307
  value="Enter at least 3 characters to search.",
308
  elem_classes=["result-box"],
309
  )
310
+
311
  search_input.change(
312
  fn=search_items,
313
  inputs=[search_input],
314
  outputs=[search_results],
315
  )
316
+
317
  # ===================== TAB 4: METRICS =====================
318
  with gr.Tab("Metrics"):
319
  gr.Markdown(
 
323
  Computed on held-out validation data.
324
  """
325
  )
326
+
327
  # Summarization Metrics
328
  gr.Markdown("#### Summarization")
329
+
330
  if METRICS.get("summarization"):
331
  summ = METRICS["summarization"]
332
  summ_md = """
 
345
  gr.Markdown(summ_md)
346
  else:
347
  gr.Markdown("*Summarization metrics not available. Run evaluation script.*")
348
+
349
  # Topic Classification Metrics
350
  gr.Markdown("#### Topic Classification")
351
+
352
  if METRICS.get("topic"):
353
  topic = METRICS["topic"]
354
  topic_md = """
 
363
  gr.Markdown(topic_md)
364
  else:
365
  gr.Markdown("*Topic classification metrics not available.*")
366
+
367
  # Emotion Detection Metrics
368
  gr.Markdown("#### Emotion Detection")
369
+
370
  if METRICS.get("emotion"):
371
  emotion = METRICS["emotion"]
372
  emotion_md = """
 
378
 
379
  *28-label multi-label classification from GoEmotions.*
380
  """.format(
381
+ sample_f1=emotion.get(
382
+ "sample_avg_f1", emotion.get("f1", emotion.get("multilabel_f1", 0))
383
+ ),
384
  macro_f1=emotion.get("macro_f1", 0),
385
  micro_f1=emotion.get("micro_f1", 0),
386
  )
387
  gr.Markdown(emotion_md)
388
  else:
389
  gr.Markdown("*Emotion detection metrics not available.*")
390
+
391
  # Dataset Statistics
392
  gr.Markdown("#### Dataset Statistics")
393
+
394
  gr.Markdown(f"""
395
  | Statistic | Value |
396
  |-----------|-------|
 
400
  | Topics | {len(TOPICS)} |
401
  | Emotions | {len(EMOTIONS)} |
402
  """)
403
+
404
  # ===================== TAB 5: ABOUT =====================
405
  with gr.Tab("About"):
406
  gr.Markdown(
 
426
 
427
  if __name__ == "__main__":
428
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
scripts/download_data.py CHANGED
@@ -45,63 +45,128 @@ OUTPUT_DIR = Path(__file__).parent.parent / "data" / "processed"
45
 
46
  # 28 emotions from GoEmotions - works for all text types
47
  EMOTION_LABELS = [
48
- "admiration", "amusement", "anger", "annoyance", "approval", "caring",
49
- "confusion", "curiosity", "desire", "disappointment", "disapproval",
50
- "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
51
- "joy", "love", "nervousness", "optimism", "pride", "realization",
52
- "relief", "remorse", "sadness", "surprise", "neutral",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  ]
54
 
55
  # New topic labels for books + papers + blogs
56
  TOPIC_LABELS = [
57
- "Fiction", # Novels, short stories, literary fiction
58
- "Science", # Physics, chemistry, biology, nature
59
- "Technology", # CS, engineering, programming, AI/ML
60
- "Philosophy", # Ethics, logic, metaphysics, epistemology
61
- "History", # Historical texts, biographies, memoirs
62
- "Psychology", # Mind, behavior, self-help, mental health
63
- "Business", # Economics, finance, entrepreneurship
64
- "Arts", # Music, visual arts, film, architecture
65
  ]
66
 
67
  # arXiv category → our topic mapping
68
  ARXIV_CATEGORY_MAP = {
69
  # Computer Science
70
- "cs.AI": "Technology", "cs.CL": "Technology", "cs.CV": "Technology",
71
- "cs.LG": "Technology", "cs.NE": "Technology", "cs.RO": "Technology",
72
- "cs.SE": "Technology", "cs.PL": "Technology", "cs.DB": "Technology",
73
- "cs.DS": "Technology", "cs.CR": "Technology", "cs.DC": "Technology",
74
- "cs.HC": "Technology", "cs.IR": "Technology", "cs.IT": "Technology",
75
- "cs.MA": "Technology", "cs.MM": "Technology", "cs.NI": "Technology",
76
- "cs.OS": "Technology", "cs.PF": "Technology", "cs.SY": "Technology",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Physics
78
- "physics": "Science", "astro-ph": "Science", "cond-mat": "Science",
79
- "gr-qc": "Science", "hep-ex": "Science", "hep-lat": "Science",
80
- "hep-ph": "Science", "hep-th": "Science", "math-ph": "Science",
81
- "nlin": "Science", "nucl-ex": "Science", "nucl-th": "Science",
 
 
 
 
 
 
 
 
82
  "quant-ph": "Science",
83
  # Math
84
  "math": "Science",
85
  # Biology/Medicine
86
- "q-bio": "Science", "stat": "Science",
 
87
  # Economics/Finance
88
- "econ": "Business", "q-fin": "Business",
 
89
  # Electrical Engineering
90
  "eess": "Technology",
91
  }
92
 
93
  # Gutenberg subject → our topic mapping
94
  GUTENBERG_SUBJECT_MAP = {
95
- "fiction": "Fiction", "novel": "Fiction", "stories": "Fiction",
96
- "poetry": "Arts", "drama": "Arts", "plays": "Arts",
97
- "science": "Science", "physics": "Science", "chemistry": "Science",
98
- "biology": "Science", "nature": "Science", "astronomy": "Science",
99
- "philosophy": "Philosophy", "ethics": "Philosophy", "logic": "Philosophy",
100
- "history": "History", "biography": "History", "memoir": "History",
101
- "psychology": "Psychology", "mind": "Psychology",
102
- "economics": "Business", "business": "Business", "finance": "Business",
103
- "art": "Arts", "music": "Arts", "architecture": "Arts",
104
- "technology": "Technology", "engineering": "Technology",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  }
106
 
107
 
@@ -118,12 +183,69 @@ def write_jsonl(records: list[dict[str, Any]], path: Path, desc: str = "Writing"
118
 
119
  # Common English words for detection
120
  ENGLISH_WORDS = {
121
- "the", "and", "of", "to", "a", "in", "that", "is", "was", "he", "she", "it",
122
- "for", "with", "as", "his", "her", "they", "be", "at", "on", "have", "had",
123
- "this", "but", "not", "from", "by", "or", "an", "said", "were", "been",
124
- "would", "could", "which", "their", "there", "what", "when", "who", "will",
125
- "more", "if", "no", "out", "so", "up", "into", "than", "them", "can", "only",
126
- "other", "new", "some", "very", "just", "over", "such", "also", "its", "then",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  }
128
 
129
  # Non-English language patterns
@@ -144,72 +266,126 @@ NON_ENGLISH_PATTERNS = [
144
 
145
  # Patterns that indicate garbage/metadata text
146
  GARBAGE_PATTERNS = [
147
- r"^Page \d+:", # Page corrections
148
- r"changed to", # Errata
149
- r"Punctuation has been", # Editorial notes
150
- r"^\[.*\]$", # Bracketed notes
151
- r"^Note\.?[-—]", # Notes
152
- r"^follows:", # "as follows:"
153
- r"CHAPTER [IVXLC]+\.", # Chapter headers only
154
- r"^\*\*\*", # Project Gutenberg markers
155
- r"^End of.*Project", # End markers
156
- r"^Produced by", # Production credits
157
- r"transcriber", # Transcriber notes
158
- r"eBook", # eBook references
159
- r"©|copyright", # Copyright notices
160
- r"^INDEX", # Index pages
161
  r"^\d+\.\s+\w+,\s+\d+", # Index entries like "1. Name, 234"
162
- r"(syn\.|var\.|sp\.)", # Botanical abbreviations
163
- r"[A-Z][a-z]+aceae", # Botanical family names
164
- r"\(\s*syn\s+", # Synonym references
165
  ]
166
 
167
  # Patterns that indicate technical manuals/instructions (not narrative)
168
  TECHNICAL_PATTERNS = [
169
  r"\d+\.\s+It\s+(is|has|can)", # Numbered features "1. It is a..."
170
- r"^\d+(st|nd|rd|th)\.", # "1st. 2nd. 3rd."
171
- r"Mesh\.?\s*\d+", # Mesh sizes (pottery)
172
  r"\d+\s*(oz|lb|kg|g|ml|mm|cm|inch)", # Measurements
173
- r"Parts?\s*:?\s*\d+", # "Parts: 50"
174
- r"Method of Using", # Instructions
175
- r"How to\s+\w+", # How-to guides
176
- r"Step\s+\d+", # Step-by-step
177
- r"wire.*address", # Business instructions
178
- r"orders?\s+should\s+be", # Order instructions
179
- r"specifications?", # Technical specs
180
- r"(Front|Back)\s+Focus", # Camera terms
181
- r"Rack and Pinion", # Mechanical terms
182
  ]
183
 
184
  # Shakespeare and plays to exclude (model hallucinates on Early Modern English)
185
  EXCLUDED_TITLES = {
186
  # Shakespeare
187
- "King Lear", "Hamlet", "Macbeth", "Othello", "Romeo and Juliet",
188
- "A Midsummer Night's Dream", "The Tempest", "Julius Caesar",
189
- "The Merchant of Venice", "Twelfth Night", "Much Ado About Nothing",
190
- "As You Like It", "The Taming of the Shrew", "Antony and Cleopatra",
191
- "Coriolanus", "Cymbeline", "Timon of Athens", "Troilus and Cressida",
192
- "Measure for Measure", "All's Well That Ends Well", "Pericles",
193
- "The Winter's Tale", "The Comedy of Errors", "Two Gentlemen of Verona",
194
- "Love's Labour's Lost", "The Merry Wives of Windsor", "Henry IV",
195
- "Henry V", "Henry VI", "Henry VIII", "Richard II", "Richard III",
196
- "King John", "Titus Andronicus",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  # French plays
198
- "Tartuffe", "Phaedra", "Cyrano de Bergerac", "Cyrano De Bergerac",
199
- "Le Misanthrope", "The School for Wives", "The Miser", "The Imaginary Invalid",
200
- "Andromaque", "Britannicus", "Bérénice", "Le Cid",
 
 
 
 
 
 
 
 
 
201
  # Greek/Roman plays
202
- "Oedipus Rex", "Oedipus the King", "Antigone", "Electra", "Medea",
203
- "The Bacchae", "The Oresteia", "Agamemnon", "Prometheus Bound",
 
 
 
 
 
 
 
204
  # Other classic plays
205
- "The Importance of Being Earnest", "Pygmalion", "Doctor Faustus",
206
- "Waiting for Godot", "Death of a Salesman", "A Streetcar Named Desire",
207
- "The Glass Menagerie", "Our Town", "Long Day's Journey Into Night",
208
- "Who's Afraid of Virginia Woolf", "The Crucible", "Cat on a Hot Tin Roof",
 
 
 
 
 
 
 
 
209
  # Verse/poetic epics
210
- "Idylls of the King", "Paradise Lost", "Paradise Regained",
211
- "The Divine Comedy", "Inferno", "Purgatorio", "Paradiso",
212
- "The Faerie Queene", "Beowulf",
 
 
 
 
 
 
213
  }
214
 
215
 
@@ -227,25 +403,25 @@ def is_quality_text(text: str) -> bool:
227
  for pattern in GARBAGE_PATTERNS:
228
  if re.search(pattern, text, re.IGNORECASE | re.MULTILINE):
229
  return False
230
-
231
  # Reject technical manuals/instructions
232
  if is_technical_manual(text):
233
  return False
234
-
235
  # Must have reasonable length
236
  if len(text) < 300:
237
  return False
238
-
239
  # Must have sentences (not just fragments)
240
- sentences = re.split(r'[.!?]+', text)
241
  if len(sentences) < 4:
242
  return False
243
-
244
  # Check for too many special characters
245
  special_ratio = len(re.findall(r'[^\w\s.,!?\'"()-]', text)) / max(len(text), 1)
246
  if special_ratio > 0.08:
247
  return False
248
-
249
  return True
250
 
251
 
@@ -263,7 +439,7 @@ def is_play_text(text: str) -> bool:
263
  r"^[A-Z]{2,}\.\s", # Character names like "HAMLET."
264
  r"Alarum|Flourish|Sennet", # Stage directions
265
  ]
266
- lines = text.split('\n')[:10]
267
  play_indicators = 0
268
  for line in lines:
269
  for pattern in play_patterns:
@@ -275,182 +451,182 @@ def is_play_text(text: str) -> bool:
275
  def is_english_text(text: str, min_ratio: float = 0.08, max_foreign: int = 5) -> bool:
276
  """
277
  Check if text is primarily English.
278
-
279
  Args:
280
  text: Text to check
281
  min_ratio: Minimum ratio of common English words
282
  max_foreign: Maximum number of foreign word matches before rejecting
283
-
284
  Returns:
285
  True if text appears to be English
286
  """
287
  if not text or len(text) < 100:
288
  return False
289
-
290
  text_lower = text.lower()
291
  words = text_lower.split()
292
-
293
  if len(words) < 20:
294
  return False
295
-
296
  # Check for excessive non-English words
297
  for pattern in NON_ENGLISH_PATTERNS:
298
  matches = len(re.findall(pattern, text_lower))
299
  if matches > max_foreign:
300
  return False
301
-
302
  # Check for sufficient English words
303
  english_count = sum(1 for w in words if w.strip(".,!?;:'\"") in ENGLISH_WORDS)
304
  ratio = english_count / len(words)
305
-
306
  return ratio >= min_ratio
307
 
308
 
309
  def normalize_title(title: str) -> str:
310
  """Normalize a book title for matching."""
311
  # Remove common prefixes/suffixes
312
- title = re.sub(r'^(The|A|An)\s+', '', title, flags=re.IGNORECASE)
313
- title = re.sub(r'\s*\([^)]*\)\s*', '', title) # Remove parentheticals
314
- title = re.sub(r'\s*:.+$', '', title) # Remove subtitles
315
- title = re.sub(r'[^\w\s]', '', title) # Remove punctuation
316
  return title.lower().strip()
317
 
318
 
319
  # -------- SUMMARIZATION: BOOKS + ARXIV ----------
320
 
 
321
  def download_goodreads_descriptions() -> dict[str, dict]:
322
  """
323
  Download Goodreads book descriptions - back-cover style blurbs.
324
-
325
  These are "what the book is about" descriptions, not plot summaries.
326
  Returns dict mapping normalized title -> {title, description}
327
  """
328
  print("\nLoading Goodreads book descriptions...")
329
-
330
  descriptions = {}
331
-
332
  # Try multiple sources
333
  datasets_to_try = [
334
  "booksouls/goodreads-book-descriptions",
335
  "Skelebor/book_titles_and_descriptions_en_clean",
336
  ]
337
-
338
  for ds_name in datasets_to_try:
339
  try:
340
  print(f" Loading {ds_name}...")
341
  ds = load_dataset(ds_name, split="train")
342
-
343
  for item in tqdm(ds, desc="Goodreads", leave=False):
344
  title = item.get("title", "")
345
  description = item.get("description", "")
346
-
347
  if not title or not description:
348
  continue
349
-
350
  # Skip very short descriptions (not useful for training)
351
  if len(description) < 100:
352
  continue
353
-
354
  # Skip very long descriptions (truncate later)
355
  if len(description) > 2000:
356
  description = description[:2000]
357
-
358
  # Skip plays and excluded titles
359
  if is_excluded_title(title):
360
  continue
361
-
362
  # Skip non-English descriptions
363
  if not is_english_text(description):
364
  continue
365
-
366
  norm_title = normalize_title(title)
367
  if norm_title and norm_title not in descriptions:
368
  descriptions[norm_title] = {
369
  "title": title,
370
  "description": description,
371
  }
372
-
373
  print(f" Loaded {len(descriptions):,} descriptions from {ds_name}")
374
  except Exception as e:
375
  print(f" {ds_name} failed: {e}")
376
-
377
  print(f" Total: {len(descriptions):,} unique book descriptions")
378
  return descriptions
379
 
380
 
381
  def download_book_descriptions(
382
- goodreads_descriptions: dict[str, dict],
383
- max_samples: int = 20000
384
  ) -> list[dict[str, Any]]:
385
  """
386
  Download book description data by matching Gutenberg texts with Goodreads descriptions.
387
-
388
  This gives us (book_excerpt, book_description) training pairs where descriptions
389
  are back-cover style "what is this book about" blurbs, not plot summaries.
390
  """
391
  print("\nMatching Gutenberg books with Goodreads descriptions...")
392
-
393
  try:
394
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
395
  except Exception:
396
  gutenberg = load_dataset("pg19", split="train")
397
-
398
  records: list[dict[str, Any]] = []
399
  matched_titles = set()
400
  skipped_quality = 0
401
  skipped_play = 0
402
-
403
  indices = list(range(len(gutenberg)))
404
  random.shuffle(indices)
405
-
406
  for i in tqdm(indices, desc="Matching books", leave=False):
407
  if len(records) >= max_samples:
408
  break
409
-
410
  item = gutenberg[i]
411
  text = item.get("TEXT", "") or item.get("text", "")
412
  metadata_raw = item.get("METADATA", "") or "{}"
413
-
414
  # Parse metadata
415
  try:
416
  metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
417
  except (json.JSONDecodeError, TypeError):
418
  metadata = {}
419
-
420
  # Get title
421
  title = metadata.get("title", "") if isinstance(metadata, dict) else ""
422
  if not title:
423
  continue
424
-
425
  # Check if we have a Goodreads description for this book
426
  norm_title = normalize_title(title)
427
  if norm_title not in goodreads_descriptions:
428
  continue
429
-
430
  # Skip if already matched this book
431
  if norm_title in matched_titles:
432
  continue
433
-
434
  goodreads_data = goodreads_descriptions[norm_title]
435
-
436
  # Skip plays and excluded titles
437
  if is_excluded_title(title):
438
  skipped_play += 1
439
  continue
440
-
441
  if not text or len(text) < 2000:
442
  continue
443
-
444
  # Get a clean excerpt from the book (skip front matter)
445
- paragraphs = re.split(r'\n\s*\n', text)
446
  excerpt_parts = []
447
  total_len = 0
448
-
449
  for para in paragraphs[10:]: # Skip front matter
450
  para = para.strip()
451
  if len(para) < 100:
452
  continue
453
-
454
  # Quality check on paragraph
455
  if not is_english_text(para):
456
  continue
@@ -460,112 +636,119 @@ def download_book_descriptions(
460
  if not is_quality_text(para) and len(para) > 300:
461
  skipped_quality += 1
462
  continue
463
-
464
  excerpt_parts.append(para)
465
  total_len += len(para)
466
-
467
  if total_len >= 3000:
468
  break
469
-
470
  if total_len < 1000:
471
  continue
472
-
473
  book_excerpt = "\n\n".join(excerpt_parts)[:4000]
474
  matched_titles.add(norm_title)
475
-
476
- records.append({
477
- "source": book_excerpt,
478
- "summary": goodreads_data["description"][:800], # Back-cover blurbs are shorter
479
- "type": "literary",
480
- "title": goodreads_data["title"],
481
- })
482
-
 
 
483
  print(f" Matched {len(records):,} books with descriptions")
484
  print(f" Skipped: {skipped_quality} quality, {skipped_play} plays")
485
-
486
  return records
487
 
488
 
489
  # Keep BookSum for additional literary training (chapter summaries are still useful)
490
  def download_booksum(max_samples: int = 20000) -> list[dict[str, Any]]:
491
  """Download BookSum - literary chapter summarization (English only, quality filtered).
492
-
493
  Note: These are chapter-level plot summaries, useful as supplementary training data.
494
  The primary book training comes from Goodreads descriptions (back-cover style).
495
  """
496
  print("\nLoading BookSum (supplementary literary data)...")
497
-
498
  all_records: list[dict[str, Any]] = []
499
  booksum = load_dataset("kmfoda/booksum")
500
-
501
  for split_name in booksum.keys():
502
  split = str(split_name)
503
  data = booksum[split_name]
504
  limit = max_samples if "train" in split else max_samples // 10
505
  indices = random.sample(range(len(data)), min(len(data), limit))
506
-
507
  records = []
508
  skipped_language = 0
509
  skipped_excluded = 0
510
  skipped_play = 0
511
-
512
  for i in tqdm(indices, desc=f"BookSum {split}", leave=False):
513
  item = data[i]
514
  chapter = item.get("chapter", "")
515
  summary = item.get("summary_text") or item.get("summary", "")
516
-
517
  # Extract book title from book_id (e.g., "The Last of the Mohicans.chapters 1-2")
518
  book_id = item.get("book_id", "")
519
  book_title = book_id.split(".")[0] if "." in book_id else book_id
520
  chapter_name = item.get("summary_id", "") or item.get("summary_name", "")
521
-
522
  if not (chapter and summary and len(chapter) > 300):
523
  continue
524
-
525
  # Filter: excluded titles (Shakespeare, plays, etc.)
526
  if is_excluded_title(book_title):
527
  skipped_excluded += 1
528
  continue
529
-
530
  # Filter: play text format
531
  if is_play_text(chapter):
532
  skipped_play += 1
533
  continue
534
-
535
  # Filter: English only
536
  if not is_english_text(chapter):
537
  skipped_language += 1
538
  continue
539
-
540
  # Filter: quality text
541
  if not is_quality_text(chapter):
542
  continue
543
-
544
- records.append({
545
- "source": chapter[:4000],
546
- "summary": summary,
547
- "type": "literary",
548
- "split": split,
549
- "title": book_title,
550
- "chapter": chapter_name,
551
- })
 
 
552
  all_records.extend(records)
553
- print(f" {split}: {len(records):,} (skipped {skipped_language} non-English, {skipped_excluded} excluded, {skipped_play} plays)")
554
-
 
 
555
  return all_records
556
 
557
 
558
  def clean_arxiv_text(text: str) -> str:
559
  """Clean arXiv LaTeX-style text to make it more readable."""
560
  import re
 
561
  # Remove LaTeX math placeholders
562
- text = re.sub(r'@xmath\d+', '', text)
563
- text = re.sub(r'@xcite', '', text)
564
  # Remove excessive whitespace
565
- text = re.sub(r'\s+', ' ', text)
566
  # Remove LaTeX commands
567
- text = re.sub(r'\\[a-zA-Z]+\{[^}]*\}', '', text)
568
- text = re.sub(r'\\[a-zA-Z]+', '', text)
569
  return text.strip()
570
 
571
 
@@ -573,19 +756,19 @@ def extract_paper_title(abstract: str) -> str:
573
  """Extract a meaningful title from the first sentence of an abstract."""
574
  # Clean the abstract first
575
  abstract = clean_arxiv_text(abstract)
576
-
577
  # Get the first sentence (up to first period, question mark, or newline)
578
- first_sentence = re.split(r'[.!?\n]', abstract)[0].strip()
579
-
580
  # Truncate if too long
581
  if len(first_sentence) > 100:
582
  # Try to cut at a natural word boundary
583
- first_sentence = first_sentence[:100].rsplit(' ', 1)[0] + '...'
584
-
585
  # Capitalize first letter
586
  if first_sentence:
587
  first_sentence = first_sentence[0].upper() + first_sentence[1:]
588
-
589
  return first_sentence or "Untitled Paper"
590
 
591
 
@@ -593,202 +776,222 @@ def download_arxiv_summarization(max_samples: int = 50000) -> list[dict[str, Any
593
  """
594
  Download arXiv papers for academic summarization only (English only).
595
  Note: This dataset doesn't have categories, so can't be used for topic classification.
596
-
597
  Returns: summarization_records
598
  """
599
  print("\nLoading arXiv (academic papers for summarization)...")
600
-
601
  print(" Loading dataset (this may take a minute)...")
602
  arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
603
-
604
  summ_records: list[dict[str, Any]] = []
605
  skipped_language = 0
606
-
607
  indices = list(range(len(arxiv)))
608
  random.shuffle(indices)
609
-
610
  print(" Processing papers...")
611
- for i in tqdm(indices[:max_samples * 2], desc="arXiv", leave=False):
612
  if len(summ_records) >= max_samples:
613
  break
614
-
615
  item = arxiv[i]
616
-
617
  # Get abstract and article
618
  abstract = item.get("abstract", "")
619
  article = item.get("article", "")
620
-
621
  if not abstract or len(abstract) < 100:
622
  continue
623
-
624
  # Clean LaTeX artifacts
625
  abstract = clean_arxiv_text(abstract)
626
  article = clean_arxiv_text(article)
627
-
628
  # Skip if still has too many weird characters after cleaning
629
- if '@' in abstract or '@' in article[:500]:
630
  continue
631
-
632
  # Filter: English only
633
  if not is_english_text(article[:1000]):
634
  skipped_language += 1
635
  continue
636
-
637
  # Summarization: article → abstract
638
  if article and len(article) > 500:
639
  # Extract title from abstract
640
  paper_title = extract_paper_title(abstract)
641
-
642
- summ_records.append({
643
- "source": article[:4000],
644
- "summary": abstract,
645
- "type": "academic",
646
- "title": paper_title,
647
- })
648
-
 
 
649
  print(f" Summarization: {len(summ_records):,} (skipped {skipped_language} non-English)")
650
-
651
  return summ_records
652
 
653
 
654
  def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, Any]]:
655
  """
656
  Download topic classification data from multiple sources with real categories.
657
-
658
  Sources:
659
  - 20 Newsgroups (classic topic classification)
660
  - Wikipedia (article categories)
661
  """
662
  print("\nLoading topic classification datasets...")
663
-
664
  records: list[dict[str, Any]] = []
665
-
666
  # 20 Newsgroups - classic topic dataset
667
  print(" Loading 20 Newsgroups...")
668
  try:
669
  newsgroups = load_dataset("SetFit/20_newsgroups", split="train")
670
-
671
  # Map 20 newsgroups categories to our 8 topics
672
  newsgroup_map = {
673
  # Science
674
- "sci.crypt": "Science", "sci.electronics": "Science",
675
- "sci.med": "Science", "sci.space": "Science",
676
- # Technology
677
- "comp.graphics": "Technology", "comp.os.ms-windows.misc": "Technology",
678
- "comp.sys.ibm.pc.hardware": "Technology", "comp.sys.mac.hardware": "Technology",
 
 
 
 
679
  "comp.windows.x": "Technology",
680
  # Philosophy/Religion
681
- "alt.atheism": "Philosophy", "soc.religion.christian": "Philosophy",
 
682
  "talk.religion.misc": "Philosophy",
683
  # History/Politics
684
- "talk.politics.guns": "History", "talk.politics.mideast": "History",
 
685
  "talk.politics.misc": "History",
686
  # Business
687
  "misc.forsale": "Business",
688
  # Sports/Recreation
689
- "rec.autos": "Arts", "rec.motorcycles": "Arts",
690
- "rec.sport.baseball": "Arts", "rec.sport.hockey": "Arts",
 
 
691
  }
692
-
693
  for item in tqdm(newsgroups, desc="20 Newsgroups", leave=False):
694
  if len(records) >= max_samples:
695
  break
696
  label_name = item.get("label_text", "")
697
  text = item.get("text", "")
698
-
699
  if label_name in newsgroup_map and text and len(text) > 100:
700
- records.append({
701
- "text": text[:1500],
702
- "topic": newsgroup_map[label_name],
703
- "source": "newsgroups",
704
- })
705
-
 
 
706
  print(f" 20 Newsgroups: {len(records):,}")
707
  except Exception as e:
708
  print(f" 20 Newsgroups failed: {e}")
709
-
710
  # Add from Gutenberg for Fiction
711
  gutenberg_topics = download_gutenberg_topics(max_samples // 4)
712
  records.extend(gutenberg_topics)
713
-
714
  # Add from scientific papers abstract dataset for more Science/Tech
715
  print(" Loading scientific papers...")
716
  try:
717
  sci_papers = load_dataset("scientific_papers", "arxiv", split="train", streaming=True)
718
  sci_count = 0
719
- for item in tqdm(sci_papers, desc="Scientific papers", leave=False, total=max_samples//4):
720
  if sci_count >= max_samples // 4:
721
  break
722
  abstract = item.get("abstract", "")
723
  if abstract and len(abstract) > 100:
724
  # Alternate between Science and Technology
725
  topic = "Science" if sci_count % 2 == 0 else "Technology"
726
- records.append({
727
- "text": abstract[:1500],
728
- "topic": topic,
729
- "source": "scientific_papers",
730
- })
 
 
731
  sci_count += 1
732
  print(f" Scientific papers: {sci_count:,}")
733
  except Exception as e:
734
  print(f" Scientific papers failed: {e}")
735
-
736
  return records
737
 
738
 
739
  def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> None:
740
  """Download all summarization data (books + arxiv, NO news).
741
-
742
  Book data now uses Goodreads descriptions (back-cover blurbs) instead of
743
  plot summaries. This trains the model to describe "what the book is about"
744
  rather than summarizing the plot.
745
  """
746
  print("\nDownloading Summarization Data...")
747
  out_dir = OUTPUT_DIR / "summarization"
748
-
749
  all_records: list[dict[str, Any]] = []
750
-
751
  # Goodreads descriptions - primary book training data (back-cover style)
752
  goodreads_descriptions = download_goodreads_descriptions()
753
  book_records = download_book_descriptions(goodreads_descriptions, max_books)
754
  all_records.extend(book_records)
755
-
756
  # Optional: Add some BookSum for additional literary variety
757
  # These are chapter summaries, not back-cover style, so keep limited
758
  # booksum_records = download_booksum(max_books // 4)
759
  # all_records.extend(booksum_records)
760
-
761
  # arXiv - academic (abstracts are already "what is this paper about")
762
  arxiv_summ = download_arxiv_summarization(max_arxiv)
763
  all_records.extend(arxiv_summ)
764
-
765
  # Shuffle and split
766
  random.shuffle(all_records)
767
-
768
  # Split by original split if available, else 90/5/5
769
- train_records = [r for r in all_records if r.get("split", "train") == "train" or "split" not in r]
 
 
770
  val_records = [r for r in all_records if r.get("split") == "validation"]
771
  test_records = [r for r in all_records if r.get("split") == "test"]
772
-
773
  # If no split info, do 90/5/5
774
  if len(val_records) < 100:
775
  n = len(train_records)
776
  random.shuffle(train_records)
777
- val_records = train_records[int(n*0.9):int(n*0.95)]
778
- test_records = train_records[int(n*0.95):]
779
- train_records = train_records[:int(n*0.9)]
780
-
781
  # Remove split key before saving
782
  for r in train_records + val_records + test_records:
783
  r.pop("split", None)
784
-
785
  write_jsonl(train_records, out_dir / "train.jsonl", "train")
786
  write_jsonl(val_records, out_dir / "validation.jsonl", "val")
787
  write_jsonl(test_records, out_dir / "test.jsonl", "test")
788
-
789
  # Print breakdown
790
- literary_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "literary")
791
- academic_count = sum(1 for r in train_records + val_records + test_records if r.get("type") == "academic")
 
 
 
 
792
  print(f"\n Total summarization: {len(train_records) + len(val_records) + len(test_records):,}")
793
  print(f" Literary (book descriptions): {literary_count:,}")
794
  print(f" Academic (paper abstracts): {academic_count:,}")
@@ -796,10 +999,11 @@ def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> No
796
 
797
  # ------------ TOPIC CLASSIFICATION ------------
798
 
 
799
  def download_topics(max_samples: int = 50000) -> None:
800
  """
801
  Download topic classification data from multiple sources.
802
-
803
  Sources:
804
  - 20 Newsgroups (classic topic dataset)
805
  - Gutenberg books (Fiction)
@@ -807,49 +1011,49 @@ def download_topics(max_samples: int = 50000) -> None:
807
  """
808
  print("\nDownloading Topic Classification...")
809
  out_dir = OUTPUT_DIR / "topic"
810
-
811
  # Get topic records from various sources
812
  all_records = download_topics_from_datasets(max_samples)
813
-
814
  # Balance topics
815
  topic_counts: dict[str, list] = {t: [] for t in TOPIC_LABELS}
816
  for r in all_records:
817
  topic = r.get("topic")
818
  if topic in topic_counts:
819
  topic_counts[topic].append(r)
820
-
821
  # Print distribution before balancing
822
  print("\n Topic distribution (before balancing):")
823
  for topic, records in topic_counts.items():
824
  print(f" {topic}: {len(records):,}")
825
-
826
  # Balance to min count (with some tolerance) - only from topics that have data
827
  counts_with_data = [len(v) for v in topic_counts.values() if v]
828
  if not counts_with_data:
829
  print(" Warning: No topic data found!")
830
  return
831
-
832
  min_count = min(counts_with_data)
833
  target_count = min(min_count, max_samples // len(TOPIC_LABELS))
834
-
835
  balanced: list[dict[str, Any]] = []
836
  for _topic, records in topic_counts.items():
837
  if records:
838
  random.shuffle(records)
839
  balanced.extend(records[:target_count])
840
-
841
  random.shuffle(balanced)
842
-
843
  # Split 90/5/5
844
  n = len(balanced)
845
- train_records = balanced[:int(n*0.9)]
846
- val_records = balanced[int(n*0.9):int(n*0.95)]
847
- test_records = balanced[int(n*0.95):]
848
-
849
  write_jsonl(train_records, out_dir / "train.jsonl", "train")
850
  write_jsonl(val_records, out_dir / "validation.jsonl", "val")
851
  write_jsonl(test_records, out_dir / "test.jsonl", "test")
852
-
853
  # Save labels - only labels that have data
854
  used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
855
  (out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
@@ -859,82 +1063,85 @@ def download_topics(max_samples: int = 50000) -> None:
859
  def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
860
  """Extract topic-labeled samples from Gutenberg books (English only)."""
861
  print("\nLoading Gutenberg for topic classification...")
862
-
863
  try:
864
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
865
  except Exception:
866
  print(" Trying pg19...")
867
  gutenberg = load_dataset("pg19", split="train")
868
-
869
  records: list[dict[str, Any]] = []
870
  skipped_language = 0
871
-
872
  indices = list(range(len(gutenberg)))
873
  random.shuffle(indices)
874
-
875
  for i in tqdm(indices, desc="Gutenberg topics", leave=False):
876
  if len(records) >= max_samples:
877
  break
878
-
879
  item = gutenberg[i]
880
  text = item.get("TEXT", "") or item.get("text", "")
881
  metadata = item.get("METADATA", {}) or {}
882
-
883
  if not text or len(text) < 1000:
884
  continue
885
-
886
  # Try to determine topic from metadata
887
  subjects = ""
888
  if isinstance(metadata, dict):
889
  subjects = str(metadata.get("subjects", "")).lower()
890
  subjects += " " + str(metadata.get("subject", "")).lower()
891
  subjects += " " + str(metadata.get("category", "")).lower()
892
-
893
  topic = None
894
  for keyword, mapped_topic in GUTENBERG_SUBJECT_MAP.items():
895
  if keyword in subjects:
896
  topic = mapped_topic
897
  break
898
-
899
  # Default fiction for novels without clear subject
900
  if not topic and ("novel" in subjects or not subjects.strip()):
901
  topic = "Fiction"
902
-
903
  if topic:
904
  # Get a clean paragraph as sample
905
- paragraphs = re.split(r'\n\s*\n', text)
906
  for para in paragraphs[5:]: # Skip front matter
907
  para = para.strip()
908
- if 200 < len(para) < 1500 and para.count('.') >= 2:
909
  # Filter: English only
910
  if not is_english_text(para):
911
  skipped_language += 1
912
  break
913
-
914
- records.append({
915
- "text": para,
916
- "topic": topic,
917
- "source": "gutenberg",
918
- })
 
 
919
  break
920
-
921
  print(f" Gutenberg topics: {len(records):,} (skipped {skipped_language} non-English)")
922
  return records
923
 
924
 
925
  # ------------ EMOTIONS (unchanged) -------------
926
 
 
927
  def download_emotions() -> None:
928
  """Download GoEmotions for emotion classification."""
929
  print("\nDownloading Emotions (GoEmotions)...")
930
  out_dir = OUTPUT_DIR / "emotion"
931
-
932
  ds = load_dataset("google-research-datasets/go_emotions", "simplified")
933
-
934
  for split_name in ds.keys():
935
  split = str(split_name)
936
  data = ds[split_name]
937
-
938
  records: list[dict[str, Any]] = []
939
  for item in tqdm(data, desc=split, leave=False):
940
  text = item.get("text", "")
@@ -944,7 +1151,7 @@ def download_emotions() -> None:
944
  if emotions:
945
  records.append({"text": text, "emotions": emotions})
946
  write_jsonl(records, out_dir / f"{split}.jsonl", split)
947
-
948
  (out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
949
  print(f" {len(EMOTION_LABELS)} emotion labels saved")
950
 
@@ -952,12 +1159,23 @@ def download_emotions() -> None:
952
  # --------------- GUTENBERG BOOKS (for language modeling) ---------------
953
 
954
  GUTENBERG_JUNK_PATTERNS = [
955
- r"Project Gutenberg", r"www\.gutenberg\.org", r"This ebook is for",
956
- r"Gutenberg License", r"^\*\*\* START OF", r"^\*\*\* END OF",
957
- r"Produced by", r"Transcriber's Note", r"TABLE OF CONTENTS",
958
- r"^\s*CHAPTER\s+[IVXLC\d]+", r"^\s*Chapter\s+[IVXLC\d]+",
959
- r"^\s*BOOK\s+[IVXLC\d]+", r"^\s*PREFACE\s*$", r"^\s*INTRODUCTION\s*$",
960
- r"E-text prepared by", r"Internet Archive", r"Distributed Proofreaders",
 
 
 
 
 
 
 
 
 
 
 
961
  ]
962
  GUTENBERG_JUNK_REGEX = re.compile("|".join(GUTENBERG_JUNK_PATTERNS), re.IGNORECASE)
963
 
@@ -968,7 +1186,7 @@ def is_clean_prose(text: str) -> bool:
968
  return False
969
  if GUTENBERG_JUNK_REGEX.search(text):
970
  return False
971
- if text.count('.') < 2:
972
  return False
973
  uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
974
  if uppercase_ratio > 0.3:
@@ -987,68 +1205,66 @@ def download_gutenberg(max_samples: int = 30000) -> None:
987
  print("\nDownloading Gutenberg Books (English only)...")
988
  out_dir = OUTPUT_DIR / "books"
989
  out_dir.mkdir(parents=True, exist_ok=True)
990
-
991
  try:
992
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
993
  except Exception:
994
  gutenberg = load_dataset("pg19", split="train")
995
-
996
  records: list[dict[str, Any]] = []
997
  indices = list(range(len(gutenberg)))
998
  random.shuffle(indices)
999
-
1000
  for i in tqdm(indices, desc="Books", leave=False):
1001
  if len(records) >= max_samples:
1002
  break
1003
-
1004
  item = gutenberg[i]
1005
  text = item.get("TEXT", "") or item.get("text", "")
1006
  metadata_raw = item.get("METADATA", "") or "{}"
1007
-
1008
  # Parse metadata - it's stored as JSON string
1009
  try:
1010
  metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
1011
  except (json.JSONDecodeError, TypeError):
1012
  metadata = {}
1013
-
1014
  # Extract title and author
1015
  title = metadata.get("title", "") if isinstance(metadata, dict) else ""
1016
  author = metadata.get("author", "") if isinstance(metadata, dict) else ""
1017
  if not title:
1018
  title = item.get("title", f"Unknown Book #{i}")
1019
-
1020
  if not text or len(text) < 1000:
1021
  continue
1022
-
1023
- paragraphs = re.split(r'\n\s*\n', text)
1024
  for para in paragraphs:
1025
  para = para.strip()
1026
  if is_clean_prose(para):
1027
- records.append({
1028
- "text": para,
1029
- "title": title,
1030
- "author": author,
1031
- "type": "gutenberg"
1032
- })
1033
  if len(records) >= max_samples:
1034
  break
1035
-
1036
  random.shuffle(records)
1037
  n = len(records)
1038
- write_jsonl(records[:int(n*0.9)], out_dir / "train.jsonl", "train")
1039
- write_jsonl(records[int(n*0.9):int(n*0.95)], out_dir / "validation.jsonl", "val")
1040
- write_jsonl(records[int(n*0.95):], out_dir / "test.jsonl", "test")
1041
 
1042
 
1043
  # ------------ MAIN ------------
1044
 
 
1045
  def main() -> None:
1046
  parser = argparse.ArgumentParser(description="Download LexiMind datasets")
1047
  parser.add_argument(
1048
  "--task",
1049
  choices=["all", "summarization", "emotion", "topic", "gutenberg"],
1050
  default="all",
1051
- help="Dataset to download"
1052
  )
1053
  parser.add_argument("--max-books", type=int, default=40000, help="Max BookSum samples")
1054
  parser.add_argument("--max-arxiv", type=int, default=50000, help="Max arXiv samples")
@@ -1056,14 +1272,14 @@ def main() -> None:
1056
  parser.add_argument("--max-topics", type=int, default=50000, help="Max topic samples")
1057
  parser.add_argument("--seed", type=int, default=42, help="Random seed")
1058
  args = parser.parse_args()
1059
-
1060
  random.seed(args.seed)
1061
-
1062
  print("=" * 60)
1063
  print("LexiMind Dataset Download")
1064
  print("Books + Academic Papers + Topic Classification")
1065
  print("=" * 60)
1066
-
1067
  if args.task in ["all", "summarization"]:
1068
  download_summarization(args.max_books, args.max_arxiv)
1069
  if args.task in ["all", "emotion"]:
@@ -1072,7 +1288,7 @@ def main() -> None:
1072
  download_topics(args.max_topics)
1073
  if args.task in ["all", "gutenberg"]:
1074
  download_gutenberg(args.max_gutenberg)
1075
-
1076
  print("\n" + "=" * 60)
1077
  print("Download complete!")
1078
  print("=" * 60)
 
45
 
46
  # 28 emotions from GoEmotions - works for all text types
47
  EMOTION_LABELS = [
48
+ "admiration",
49
+ "amusement",
50
+ "anger",
51
+ "annoyance",
52
+ "approval",
53
+ "caring",
54
+ "confusion",
55
+ "curiosity",
56
+ "desire",
57
+ "disappointment",
58
+ "disapproval",
59
+ "disgust",
60
+ "embarrassment",
61
+ "excitement",
62
+ "fear",
63
+ "gratitude",
64
+ "grief",
65
+ "joy",
66
+ "love",
67
+ "nervousness",
68
+ "optimism",
69
+ "pride",
70
+ "realization",
71
+ "relief",
72
+ "remorse",
73
+ "sadness",
74
+ "surprise",
75
+ "neutral",
76
  ]
77
 
78
  # New topic labels for books + papers + blogs
79
  TOPIC_LABELS = [
80
+ "Fiction", # Novels, short stories, literary fiction
81
+ "Science", # Physics, chemistry, biology, nature
82
+ "Technology", # CS, engineering, programming, AI/ML
83
+ "Philosophy", # Ethics, logic, metaphysics, epistemology
84
+ "History", # Historical texts, biographies, memoirs
85
+ "Psychology", # Mind, behavior, self-help, mental health
86
+ "Business", # Economics, finance, entrepreneurship
87
+ "Arts", # Music, visual arts, film, architecture
88
  ]
89
 
90
  # arXiv category → our topic mapping
91
  ARXIV_CATEGORY_MAP = {
92
  # Computer Science
93
+ "cs.AI": "Technology",
94
+ "cs.CL": "Technology",
95
+ "cs.CV": "Technology",
96
+ "cs.LG": "Technology",
97
+ "cs.NE": "Technology",
98
+ "cs.RO": "Technology",
99
+ "cs.SE": "Technology",
100
+ "cs.PL": "Technology",
101
+ "cs.DB": "Technology",
102
+ "cs.DS": "Technology",
103
+ "cs.CR": "Technology",
104
+ "cs.DC": "Technology",
105
+ "cs.HC": "Technology",
106
+ "cs.IR": "Technology",
107
+ "cs.IT": "Technology",
108
+ "cs.MA": "Technology",
109
+ "cs.MM": "Technology",
110
+ "cs.NI": "Technology",
111
+ "cs.OS": "Technology",
112
+ "cs.PF": "Technology",
113
+ "cs.SY": "Technology",
114
  # Physics
115
+ "physics": "Science",
116
+ "astro-ph": "Science",
117
+ "cond-mat": "Science",
118
+ "gr-qc": "Science",
119
+ "hep-ex": "Science",
120
+ "hep-lat": "Science",
121
+ "hep-ph": "Science",
122
+ "hep-th": "Science",
123
+ "math-ph": "Science",
124
+ "nlin": "Science",
125
+ "nucl-ex": "Science",
126
+ "nucl-th": "Science",
127
  "quant-ph": "Science",
128
  # Math
129
  "math": "Science",
130
  # Biology/Medicine
131
+ "q-bio": "Science",
132
+ "stat": "Science",
133
  # Economics/Finance
134
+ "econ": "Business",
135
+ "q-fin": "Business",
136
  # Electrical Engineering
137
  "eess": "Technology",
138
  }
139
 
140
  # Gutenberg subject → our topic mapping
141
  GUTENBERG_SUBJECT_MAP = {
142
+ "fiction": "Fiction",
143
+ "novel": "Fiction",
144
+ "stories": "Fiction",
145
+ "poetry": "Arts",
146
+ "drama": "Arts",
147
+ "plays": "Arts",
148
+ "science": "Science",
149
+ "physics": "Science",
150
+ "chemistry": "Science",
151
+ "biology": "Science",
152
+ "nature": "Science",
153
+ "astronomy": "Science",
154
+ "philosophy": "Philosophy",
155
+ "ethics": "Philosophy",
156
+ "logic": "Philosophy",
157
+ "history": "History",
158
+ "biography": "History",
159
+ "memoir": "History",
160
+ "psychology": "Psychology",
161
+ "mind": "Psychology",
162
+ "economics": "Business",
163
+ "business": "Business",
164
+ "finance": "Business",
165
+ "art": "Arts",
166
+ "music": "Arts",
167
+ "architecture": "Arts",
168
+ "technology": "Technology",
169
+ "engineering": "Technology",
170
  }
171
 
172
 
 
183
 
184
  # Common English words for detection
185
  ENGLISH_WORDS = {
186
+ "the",
187
+ "and",
188
+ "of",
189
+ "to",
190
+ "a",
191
+ "in",
192
+ "that",
193
+ "is",
194
+ "was",
195
+ "he",
196
+ "she",
197
+ "it",
198
+ "for",
199
+ "with",
200
+ "as",
201
+ "his",
202
+ "her",
203
+ "they",
204
+ "be",
205
+ "at",
206
+ "on",
207
+ "have",
208
+ "had",
209
+ "this",
210
+ "but",
211
+ "not",
212
+ "from",
213
+ "by",
214
+ "or",
215
+ "an",
216
+ "said",
217
+ "were",
218
+ "been",
219
+ "would",
220
+ "could",
221
+ "which",
222
+ "their",
223
+ "there",
224
+ "what",
225
+ "when",
226
+ "who",
227
+ "will",
228
+ "more",
229
+ "if",
230
+ "no",
231
+ "out",
232
+ "so",
233
+ "up",
234
+ "into",
235
+ "than",
236
+ "them",
237
+ "can",
238
+ "only",
239
+ "other",
240
+ "new",
241
+ "some",
242
+ "very",
243
+ "just",
244
+ "over",
245
+ "such",
246
+ "also",
247
+ "its",
248
+ "then",
249
  }
250
 
251
  # Non-English language patterns
 
266
 
267
  # Patterns that indicate garbage/metadata text
268
  GARBAGE_PATTERNS = [
269
+ r"^Page \d+:", # Page corrections
270
+ r"changed to", # Errata
271
+ r"Punctuation has been", # Editorial notes
272
+ r"^\[.*\]$", # Bracketed notes
273
+ r"^Note\.?[-—]", # Notes
274
+ r"^follows:", # "as follows:"
275
+ r"CHAPTER [IVXLC]+\.", # Chapter headers only
276
+ r"^\*\*\*", # Project Gutenberg markers
277
+ r"^End of.*Project", # End markers
278
+ r"^Produced by", # Production credits
279
+ r"transcriber", # Transcriber notes
280
+ r"eBook", # eBook references
281
+ r"©|copyright", # Copyright notices
282
+ r"^INDEX", # Index pages
283
  r"^\d+\.\s+\w+,\s+\d+", # Index entries like "1. Name, 234"
284
+ r"(syn\.|var\.|sp\.)", # Botanical abbreviations
285
+ r"[A-Z][a-z]+aceae", # Botanical family names
286
+ r"\(\s*syn\s+", # Synonym references
287
  ]
288
 
289
  # Patterns that indicate technical manuals/instructions (not narrative)
290
  TECHNICAL_PATTERNS = [
291
  r"\d+\.\s+It\s+(is|has|can)", # Numbered features "1. It is a..."
292
+ r"^\d+(st|nd|rd|th)\.", # "1st. 2nd. 3rd."
293
+ r"Mesh\.?\s*\d+", # Mesh sizes (pottery)
294
  r"\d+\s*(oz|lb|kg|g|ml|mm|cm|inch)", # Measurements
295
+ r"Parts?\s*:?\s*\d+", # "Parts: 50"
296
+ r"Method of Using", # Instructions
297
+ r"How to\s+\w+", # How-to guides
298
+ r"Step\s+\d+", # Step-by-step
299
+ r"wire.*address", # Business instructions
300
+ r"orders?\s+should\s+be", # Order instructions
301
+ r"specifications?", # Technical specs
302
+ r"(Front|Back)\s+Focus", # Camera terms
303
+ r"Rack and Pinion", # Mechanical terms
304
  ]
305
 
306
  # Shakespeare and plays to exclude (model hallucinates on Early Modern English)
307
  EXCLUDED_TITLES = {
308
  # Shakespeare
309
+ "King Lear",
310
+ "Hamlet",
311
+ "Macbeth",
312
+ "Othello",
313
+ "Romeo and Juliet",
314
+ "A Midsummer Night's Dream",
315
+ "The Tempest",
316
+ "Julius Caesar",
317
+ "The Merchant of Venice",
318
+ "Twelfth Night",
319
+ "Much Ado About Nothing",
320
+ "As You Like It",
321
+ "The Taming of the Shrew",
322
+ "Antony and Cleopatra",
323
+ "Coriolanus",
324
+ "Cymbeline",
325
+ "Timon of Athens",
326
+ "Troilus and Cressida",
327
+ "Measure for Measure",
328
+ "All's Well That Ends Well",
329
+ "Pericles",
330
+ "The Winter's Tale",
331
+ "The Comedy of Errors",
332
+ "Two Gentlemen of Verona",
333
+ "Love's Labour's Lost",
334
+ "The Merry Wives of Windsor",
335
+ "Henry IV",
336
+ "Henry V",
337
+ "Henry VI",
338
+ "Henry VIII",
339
+ "Richard II",
340
+ "Richard III",
341
+ "King John",
342
+ "Titus Andronicus",
343
  # French plays
344
+ "Tartuffe",
345
+ "Phaedra",
346
+ "Cyrano de Bergerac",
347
+ "Cyrano De Bergerac",
348
+ "Le Misanthrope",
349
+ "The School for Wives",
350
+ "The Miser",
351
+ "The Imaginary Invalid",
352
+ "Andromaque",
353
+ "Britannicus",
354
+ "Bérénice",
355
+ "Le Cid",
356
  # Greek/Roman plays
357
+ "Oedipus Rex",
358
+ "Oedipus the King",
359
+ "Antigone",
360
+ "Electra",
361
+ "Medea",
362
+ "The Bacchae",
363
+ "The Oresteia",
364
+ "Agamemnon",
365
+ "Prometheus Bound",
366
  # Other classic plays
367
+ "The Importance of Being Earnest",
368
+ "Pygmalion",
369
+ "Doctor Faustus",
370
+ "Waiting for Godot",
371
+ "Death of a Salesman",
372
+ "A Streetcar Named Desire",
373
+ "The Glass Menagerie",
374
+ "Our Town",
375
+ "Long Day's Journey Into Night",
376
+ "Who's Afraid of Virginia Woolf",
377
+ "The Crucible",
378
+ "Cat on a Hot Tin Roof",
379
  # Verse/poetic epics
380
+ "Idylls of the King",
381
+ "Paradise Lost",
382
+ "Paradise Regained",
383
+ "The Divine Comedy",
384
+ "Inferno",
385
+ "Purgatorio",
386
+ "Paradiso",
387
+ "The Faerie Queene",
388
+ "Beowulf",
389
  }
390
 
391
 
 
403
  for pattern in GARBAGE_PATTERNS:
404
  if re.search(pattern, text, re.IGNORECASE | re.MULTILINE):
405
  return False
406
+
407
  # Reject technical manuals/instructions
408
  if is_technical_manual(text):
409
  return False
410
+
411
  # Must have reasonable length
412
  if len(text) < 300:
413
  return False
414
+
415
  # Must have sentences (not just fragments)
416
+ sentences = re.split(r"[.!?]+", text)
417
  if len(sentences) < 4:
418
  return False
419
+
420
  # Check for too many special characters
421
  special_ratio = len(re.findall(r'[^\w\s.,!?\'"()-]', text)) / max(len(text), 1)
422
  if special_ratio > 0.08:
423
  return False
424
+
425
  return True
426
 
427
 
 
439
  r"^[A-Z]{2,}\.\s", # Character names like "HAMLET."
440
  r"Alarum|Flourish|Sennet", # Stage directions
441
  ]
442
+ lines = text.split("\n")[:10]
443
  play_indicators = 0
444
  for line in lines:
445
  for pattern in play_patterns:
 
451
  def is_english_text(text: str, min_ratio: float = 0.08, max_foreign: int = 5) -> bool:
452
  """
453
  Check if text is primarily English.
454
+
455
  Args:
456
  text: Text to check
457
  min_ratio: Minimum ratio of common English words
458
  max_foreign: Maximum number of foreign word matches before rejecting
459
+
460
  Returns:
461
  True if text appears to be English
462
  """
463
  if not text or len(text) < 100:
464
  return False
465
+
466
  text_lower = text.lower()
467
  words = text_lower.split()
468
+
469
  if len(words) < 20:
470
  return False
471
+
472
  # Check for excessive non-English words
473
  for pattern in NON_ENGLISH_PATTERNS:
474
  matches = len(re.findall(pattern, text_lower))
475
  if matches > max_foreign:
476
  return False
477
+
478
  # Check for sufficient English words
479
  english_count = sum(1 for w in words if w.strip(".,!?;:'\"") in ENGLISH_WORDS)
480
  ratio = english_count / len(words)
481
+
482
  return ratio >= min_ratio
483
 
484
 
485
  def normalize_title(title: str) -> str:
486
  """Normalize a book title for matching."""
487
  # Remove common prefixes/suffixes
488
+ title = re.sub(r"^(The|A|An)\s+", "", title, flags=re.IGNORECASE)
489
+ title = re.sub(r"\s*\([^)]*\)\s*", "", title) # Remove parentheticals
490
+ title = re.sub(r"\s*:.+$", "", title) # Remove subtitles
491
+ title = re.sub(r"[^\w\s]", "", title) # Remove punctuation
492
  return title.lower().strip()
493
 
494
 
495
  # -------- SUMMARIZATION: BOOKS + ARXIV ----------
496
 
497
+
498
  def download_goodreads_descriptions() -> dict[str, dict]:
499
  """
500
  Download Goodreads book descriptions - back-cover style blurbs.
501
+
502
  These are "what the book is about" descriptions, not plot summaries.
503
  Returns dict mapping normalized title -> {title, description}
504
  """
505
  print("\nLoading Goodreads book descriptions...")
506
+
507
  descriptions = {}
508
+
509
  # Try multiple sources
510
  datasets_to_try = [
511
  "booksouls/goodreads-book-descriptions",
512
  "Skelebor/book_titles_and_descriptions_en_clean",
513
  ]
514
+
515
  for ds_name in datasets_to_try:
516
  try:
517
  print(f" Loading {ds_name}...")
518
  ds = load_dataset(ds_name, split="train")
519
+
520
  for item in tqdm(ds, desc="Goodreads", leave=False):
521
  title = item.get("title", "")
522
  description = item.get("description", "")
523
+
524
  if not title or not description:
525
  continue
526
+
527
  # Skip very short descriptions (not useful for training)
528
  if len(description) < 100:
529
  continue
530
+
531
  # Skip very long descriptions (truncate later)
532
  if len(description) > 2000:
533
  description = description[:2000]
534
+
535
  # Skip plays and excluded titles
536
  if is_excluded_title(title):
537
  continue
538
+
539
  # Skip non-English descriptions
540
  if not is_english_text(description):
541
  continue
542
+
543
  norm_title = normalize_title(title)
544
  if norm_title and norm_title not in descriptions:
545
  descriptions[norm_title] = {
546
  "title": title,
547
  "description": description,
548
  }
549
+
550
  print(f" Loaded {len(descriptions):,} descriptions from {ds_name}")
551
  except Exception as e:
552
  print(f" {ds_name} failed: {e}")
553
+
554
  print(f" Total: {len(descriptions):,} unique book descriptions")
555
  return descriptions
556
 
557
 
558
  def download_book_descriptions(
559
+ goodreads_descriptions: dict[str, dict], max_samples: int = 20000
 
560
  ) -> list[dict[str, Any]]:
561
  """
562
  Download book description data by matching Gutenberg texts with Goodreads descriptions.
563
+
564
  This gives us (book_excerpt, book_description) training pairs where descriptions
565
  are back-cover style "what is this book about" blurbs, not plot summaries.
566
  """
567
  print("\nMatching Gutenberg books with Goodreads descriptions...")
568
+
569
  try:
570
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
571
  except Exception:
572
  gutenberg = load_dataset("pg19", split="train")
573
+
574
  records: list[dict[str, Any]] = []
575
  matched_titles = set()
576
  skipped_quality = 0
577
  skipped_play = 0
578
+
579
  indices = list(range(len(gutenberg)))
580
  random.shuffle(indices)
581
+
582
  for i in tqdm(indices, desc="Matching books", leave=False):
583
  if len(records) >= max_samples:
584
  break
585
+
586
  item = gutenberg[i]
587
  text = item.get("TEXT", "") or item.get("text", "")
588
  metadata_raw = item.get("METADATA", "") or "{}"
589
+
590
  # Parse metadata
591
  try:
592
  metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
593
  except (json.JSONDecodeError, TypeError):
594
  metadata = {}
595
+
596
  # Get title
597
  title = metadata.get("title", "") if isinstance(metadata, dict) else ""
598
  if not title:
599
  continue
600
+
601
  # Check if we have a Goodreads description for this book
602
  norm_title = normalize_title(title)
603
  if norm_title not in goodreads_descriptions:
604
  continue
605
+
606
  # Skip if already matched this book
607
  if norm_title in matched_titles:
608
  continue
609
+
610
  goodreads_data = goodreads_descriptions[norm_title]
611
+
612
  # Skip plays and excluded titles
613
  if is_excluded_title(title):
614
  skipped_play += 1
615
  continue
616
+
617
  if not text or len(text) < 2000:
618
  continue
619
+
620
  # Get a clean excerpt from the book (skip front matter)
621
+ paragraphs = re.split(r"\n\s*\n", text)
622
  excerpt_parts = []
623
  total_len = 0
624
+
625
  for para in paragraphs[10:]: # Skip front matter
626
  para = para.strip()
627
  if len(para) < 100:
628
  continue
629
+
630
  # Quality check on paragraph
631
  if not is_english_text(para):
632
  continue
 
636
  if not is_quality_text(para) and len(para) > 300:
637
  skipped_quality += 1
638
  continue
639
+
640
  excerpt_parts.append(para)
641
  total_len += len(para)
642
+
643
  if total_len >= 3000:
644
  break
645
+
646
  if total_len < 1000:
647
  continue
648
+
649
  book_excerpt = "\n\n".join(excerpt_parts)[:4000]
650
  matched_titles.add(norm_title)
651
+
652
+ records.append(
653
+ {
654
+ "source": book_excerpt,
655
+ "summary": goodreads_data["description"][:800], # Back-cover blurbs are shorter
656
+ "type": "literary",
657
+ "title": goodreads_data["title"],
658
+ }
659
+ )
660
+
661
  print(f" Matched {len(records):,} books with descriptions")
662
  print(f" Skipped: {skipped_quality} quality, {skipped_play} plays")
663
+
664
  return records
665
 
666
 
667
  # Keep BookSum for additional literary training (chapter summaries are still useful)
668
  def download_booksum(max_samples: int = 20000) -> list[dict[str, Any]]:
669
  """Download BookSum - literary chapter summarization (English only, quality filtered).
670
+
671
  Note: These are chapter-level plot summaries, useful as supplementary training data.
672
  The primary book training comes from Goodreads descriptions (back-cover style).
673
  """
674
  print("\nLoading BookSum (supplementary literary data)...")
675
+
676
  all_records: list[dict[str, Any]] = []
677
  booksum = load_dataset("kmfoda/booksum")
678
+
679
  for split_name in booksum.keys():
680
  split = str(split_name)
681
  data = booksum[split_name]
682
  limit = max_samples if "train" in split else max_samples // 10
683
  indices = random.sample(range(len(data)), min(len(data), limit))
684
+
685
  records = []
686
  skipped_language = 0
687
  skipped_excluded = 0
688
  skipped_play = 0
689
+
690
  for i in tqdm(indices, desc=f"BookSum {split}", leave=False):
691
  item = data[i]
692
  chapter = item.get("chapter", "")
693
  summary = item.get("summary_text") or item.get("summary", "")
694
+
695
  # Extract book title from book_id (e.g., "The Last of the Mohicans.chapters 1-2")
696
  book_id = item.get("book_id", "")
697
  book_title = book_id.split(".")[0] if "." in book_id else book_id
698
  chapter_name = item.get("summary_id", "") or item.get("summary_name", "")
699
+
700
  if not (chapter and summary and len(chapter) > 300):
701
  continue
702
+
703
  # Filter: excluded titles (Shakespeare, plays, etc.)
704
  if is_excluded_title(book_title):
705
  skipped_excluded += 1
706
  continue
707
+
708
  # Filter: play text format
709
  if is_play_text(chapter):
710
  skipped_play += 1
711
  continue
712
+
713
  # Filter: English only
714
  if not is_english_text(chapter):
715
  skipped_language += 1
716
  continue
717
+
718
  # Filter: quality text
719
  if not is_quality_text(chapter):
720
  continue
721
+
722
+ records.append(
723
+ {
724
+ "source": chapter[:4000],
725
+ "summary": summary,
726
+ "type": "literary",
727
+ "split": split,
728
+ "title": book_title,
729
+ "chapter": chapter_name,
730
+ }
731
+ )
732
  all_records.extend(records)
733
+ print(
734
+ f" {split}: {len(records):,} (skipped {skipped_language} non-English, {skipped_excluded} excluded, {skipped_play} plays)"
735
+ )
736
+
737
  return all_records
738
 
739
 
740
  def clean_arxiv_text(text: str) -> str:
741
  """Clean arXiv LaTeX-style text to make it more readable."""
742
  import re
743
+
744
  # Remove LaTeX math placeholders
745
+ text = re.sub(r"@xmath\d+", "", text)
746
+ text = re.sub(r"@xcite", "", text)
747
  # Remove excessive whitespace
748
+ text = re.sub(r"\s+", " ", text)
749
  # Remove LaTeX commands
750
+ text = re.sub(r"\\[a-zA-Z]+\{[^}]*\}", "", text)
751
+ text = re.sub(r"\\[a-zA-Z]+", "", text)
752
  return text.strip()
753
 
754
 
 
756
  """Extract a meaningful title from the first sentence of an abstract."""
757
  # Clean the abstract first
758
  abstract = clean_arxiv_text(abstract)
759
+
760
  # Get the first sentence (up to first period, question mark, or newline)
761
+ first_sentence = re.split(r"[.!?\n]", abstract)[0].strip()
762
+
763
  # Truncate if too long
764
  if len(first_sentence) > 100:
765
  # Try to cut at a natural word boundary
766
+ first_sentence = first_sentence[:100].rsplit(" ", 1)[0] + "..."
767
+
768
  # Capitalize first letter
769
  if first_sentence:
770
  first_sentence = first_sentence[0].upper() + first_sentence[1:]
771
+
772
  return first_sentence or "Untitled Paper"
773
 
774
 
 
776
  """
777
  Download arXiv papers for academic summarization only (English only).
778
  Note: This dataset doesn't have categories, so can't be used for topic classification.
779
+
780
  Returns: summarization_records
781
  """
782
  print("\nLoading arXiv (academic papers for summarization)...")
783
+
784
  print(" Loading dataset (this may take a minute)...")
785
  arxiv = load_dataset("ccdv/arxiv-summarization", split="train")
786
+
787
  summ_records: list[dict[str, Any]] = []
788
  skipped_language = 0
789
+
790
  indices = list(range(len(arxiv)))
791
  random.shuffle(indices)
792
+
793
  print(" Processing papers...")
794
+ for i in tqdm(indices[: max_samples * 2], desc="arXiv", leave=False):
795
  if len(summ_records) >= max_samples:
796
  break
797
+
798
  item = arxiv[i]
799
+
800
  # Get abstract and article
801
  abstract = item.get("abstract", "")
802
  article = item.get("article", "")
803
+
804
  if not abstract or len(abstract) < 100:
805
  continue
806
+
807
  # Clean LaTeX artifacts
808
  abstract = clean_arxiv_text(abstract)
809
  article = clean_arxiv_text(article)
810
+
811
  # Skip if still has too many weird characters after cleaning
812
+ if "@" in abstract or "@" in article[:500]:
813
  continue
814
+
815
  # Filter: English only
816
  if not is_english_text(article[:1000]):
817
  skipped_language += 1
818
  continue
819
+
820
  # Summarization: article → abstract
821
  if article and len(article) > 500:
822
  # Extract title from abstract
823
  paper_title = extract_paper_title(abstract)
824
+
825
+ summ_records.append(
826
+ {
827
+ "source": article[:4000],
828
+ "summary": abstract,
829
+ "type": "academic",
830
+ "title": paper_title,
831
+ }
832
+ )
833
+
834
  print(f" Summarization: {len(summ_records):,} (skipped {skipped_language} non-English)")
835
+
836
  return summ_records
837
 
838
 
839
  def download_topics_from_datasets(max_samples: int = 50000) -> list[dict[str, Any]]:
840
  """
841
  Download topic classification data from multiple sources with real categories.
842
+
843
  Sources:
844
  - 20 Newsgroups (classic topic classification)
845
  - Wikipedia (article categories)
846
  """
847
  print("\nLoading topic classification datasets...")
848
+
849
  records: list[dict[str, Any]] = []
850
+
851
  # 20 Newsgroups - classic topic dataset
852
  print(" Loading 20 Newsgroups...")
853
  try:
854
  newsgroups = load_dataset("SetFit/20_newsgroups", split="train")
855
+
856
  # Map 20 newsgroups categories to our 8 topics
857
  newsgroup_map = {
858
  # Science
859
+ "sci.crypt": "Science",
860
+ "sci.electronics": "Science",
861
+ "sci.med": "Science",
862
+ "sci.space": "Science",
863
+ # Technology
864
+ "comp.graphics": "Technology",
865
+ "comp.os.ms-windows.misc": "Technology",
866
+ "comp.sys.ibm.pc.hardware": "Technology",
867
+ "comp.sys.mac.hardware": "Technology",
868
  "comp.windows.x": "Technology",
869
  # Philosophy/Religion
870
+ "alt.atheism": "Philosophy",
871
+ "soc.religion.christian": "Philosophy",
872
  "talk.religion.misc": "Philosophy",
873
  # History/Politics
874
+ "talk.politics.guns": "History",
875
+ "talk.politics.mideast": "History",
876
  "talk.politics.misc": "History",
877
  # Business
878
  "misc.forsale": "Business",
879
  # Sports/Recreation
880
+ "rec.autos": "Arts",
881
+ "rec.motorcycles": "Arts",
882
+ "rec.sport.baseball": "Arts",
883
+ "rec.sport.hockey": "Arts",
884
  }
885
+
886
  for item in tqdm(newsgroups, desc="20 Newsgroups", leave=False):
887
  if len(records) >= max_samples:
888
  break
889
  label_name = item.get("label_text", "")
890
  text = item.get("text", "")
891
+
892
  if label_name in newsgroup_map and text and len(text) > 100:
893
+ records.append(
894
+ {
895
+ "text": text[:1500],
896
+ "topic": newsgroup_map[label_name],
897
+ "source": "newsgroups",
898
+ }
899
+ )
900
+
901
  print(f" 20 Newsgroups: {len(records):,}")
902
  except Exception as e:
903
  print(f" 20 Newsgroups failed: {e}")
904
+
905
  # Add from Gutenberg for Fiction
906
  gutenberg_topics = download_gutenberg_topics(max_samples // 4)
907
  records.extend(gutenberg_topics)
908
+
909
  # Add from scientific papers abstract dataset for more Science/Tech
910
  print(" Loading scientific papers...")
911
  try:
912
  sci_papers = load_dataset("scientific_papers", "arxiv", split="train", streaming=True)
913
  sci_count = 0
914
+ for item in tqdm(sci_papers, desc="Scientific papers", leave=False, total=max_samples // 4):
915
  if sci_count >= max_samples // 4:
916
  break
917
  abstract = item.get("abstract", "")
918
  if abstract and len(abstract) > 100:
919
  # Alternate between Science and Technology
920
  topic = "Science" if sci_count % 2 == 0 else "Technology"
921
+ records.append(
922
+ {
923
+ "text": abstract[:1500],
924
+ "topic": topic,
925
+ "source": "scientific_papers",
926
+ }
927
+ )
928
  sci_count += 1
929
  print(f" Scientific papers: {sci_count:,}")
930
  except Exception as e:
931
  print(f" Scientific papers failed: {e}")
932
+
933
  return records
934
 
935
 
936
  def download_summarization(max_books: int = 20000, max_arxiv: int = 50000) -> None:
937
  """Download all summarization data (books + arxiv, NO news).
938
+
939
  Book data now uses Goodreads descriptions (back-cover blurbs) instead of
940
  plot summaries. This trains the model to describe "what the book is about"
941
  rather than summarizing the plot.
942
  """
943
  print("\nDownloading Summarization Data...")
944
  out_dir = OUTPUT_DIR / "summarization"
945
+
946
  all_records: list[dict[str, Any]] = []
947
+
948
  # Goodreads descriptions - primary book training data (back-cover style)
949
  goodreads_descriptions = download_goodreads_descriptions()
950
  book_records = download_book_descriptions(goodreads_descriptions, max_books)
951
  all_records.extend(book_records)
952
+
953
  # Optional: Add some BookSum for additional literary variety
954
  # These are chapter summaries, not back-cover style, so keep limited
955
  # booksum_records = download_booksum(max_books // 4)
956
  # all_records.extend(booksum_records)
957
+
958
  # arXiv - academic (abstracts are already "what is this paper about")
959
  arxiv_summ = download_arxiv_summarization(max_arxiv)
960
  all_records.extend(arxiv_summ)
961
+
962
  # Shuffle and split
963
  random.shuffle(all_records)
964
+
965
  # Split by original split if available, else 90/5/5
966
+ train_records = [
967
+ r for r in all_records if r.get("split", "train") == "train" or "split" not in r
968
+ ]
969
  val_records = [r for r in all_records if r.get("split") == "validation"]
970
  test_records = [r for r in all_records if r.get("split") == "test"]
971
+
972
  # If no split info, do 90/5/5
973
  if len(val_records) < 100:
974
  n = len(train_records)
975
  random.shuffle(train_records)
976
+ val_records = train_records[int(n * 0.9) : int(n * 0.95)]
977
+ test_records = train_records[int(n * 0.95) :]
978
+ train_records = train_records[: int(n * 0.9)]
979
+
980
  # Remove split key before saving
981
  for r in train_records + val_records + test_records:
982
  r.pop("split", None)
983
+
984
  write_jsonl(train_records, out_dir / "train.jsonl", "train")
985
  write_jsonl(val_records, out_dir / "validation.jsonl", "val")
986
  write_jsonl(test_records, out_dir / "test.jsonl", "test")
987
+
988
  # Print breakdown
989
+ literary_count = sum(
990
+ 1 for r in train_records + val_records + test_records if r.get("type") == "literary"
991
+ )
992
+ academic_count = sum(
993
+ 1 for r in train_records + val_records + test_records if r.get("type") == "academic"
994
+ )
995
  print(f"\n Total summarization: {len(train_records) + len(val_records) + len(test_records):,}")
996
  print(f" Literary (book descriptions): {literary_count:,}")
997
  print(f" Academic (paper abstracts): {academic_count:,}")
 
999
 
1000
  # ------------ TOPIC CLASSIFICATION ------------
1001
 
1002
+
1003
  def download_topics(max_samples: int = 50000) -> None:
1004
  """
1005
  Download topic classification data from multiple sources.
1006
+
1007
  Sources:
1008
  - 20 Newsgroups (classic topic dataset)
1009
  - Gutenberg books (Fiction)
 
1011
  """
1012
  print("\nDownloading Topic Classification...")
1013
  out_dir = OUTPUT_DIR / "topic"
1014
+
1015
  # Get topic records from various sources
1016
  all_records = download_topics_from_datasets(max_samples)
1017
+
1018
  # Balance topics
1019
  topic_counts: dict[str, list] = {t: [] for t in TOPIC_LABELS}
1020
  for r in all_records:
1021
  topic = r.get("topic")
1022
  if topic in topic_counts:
1023
  topic_counts[topic].append(r)
1024
+
1025
  # Print distribution before balancing
1026
  print("\n Topic distribution (before balancing):")
1027
  for topic, records in topic_counts.items():
1028
  print(f" {topic}: {len(records):,}")
1029
+
1030
  # Balance to min count (with some tolerance) - only from topics that have data
1031
  counts_with_data = [len(v) for v in topic_counts.values() if v]
1032
  if not counts_with_data:
1033
  print(" Warning: No topic data found!")
1034
  return
1035
+
1036
  min_count = min(counts_with_data)
1037
  target_count = min(min_count, max_samples // len(TOPIC_LABELS))
1038
+
1039
  balanced: list[dict[str, Any]] = []
1040
  for _topic, records in topic_counts.items():
1041
  if records:
1042
  random.shuffle(records)
1043
  balanced.extend(records[:target_count])
1044
+
1045
  random.shuffle(balanced)
1046
+
1047
  # Split 90/5/5
1048
  n = len(balanced)
1049
+ train_records = balanced[: int(n * 0.9)]
1050
+ val_records = balanced[int(n * 0.9) : int(n * 0.95)]
1051
+ test_records = balanced[int(n * 0.95) :]
1052
+
1053
  write_jsonl(train_records, out_dir / "train.jsonl", "train")
1054
  write_jsonl(val_records, out_dir / "validation.jsonl", "val")
1055
  write_jsonl(test_records, out_dir / "test.jsonl", "test")
1056
+
1057
  # Save labels - only labels that have data
1058
  used_labels = [t for t in TOPIC_LABELS if topic_counts.get(t)]
1059
  (out_dir / "labels.json").write_text(json.dumps(used_labels, indent=2))
 
1063
  def download_gutenberg_topics(max_samples: int = 30000) -> list[dict[str, Any]]:
1064
  """Extract topic-labeled samples from Gutenberg books (English only)."""
1065
  print("\nLoading Gutenberg for topic classification...")
1066
+
1067
  try:
1068
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
1069
  except Exception:
1070
  print(" Trying pg19...")
1071
  gutenberg = load_dataset("pg19", split="train")
1072
+
1073
  records: list[dict[str, Any]] = []
1074
  skipped_language = 0
1075
+
1076
  indices = list(range(len(gutenberg)))
1077
  random.shuffle(indices)
1078
+
1079
  for i in tqdm(indices, desc="Gutenberg topics", leave=False):
1080
  if len(records) >= max_samples:
1081
  break
1082
+
1083
  item = gutenberg[i]
1084
  text = item.get("TEXT", "") or item.get("text", "")
1085
  metadata = item.get("METADATA", {}) or {}
1086
+
1087
  if not text or len(text) < 1000:
1088
  continue
1089
+
1090
  # Try to determine topic from metadata
1091
  subjects = ""
1092
  if isinstance(metadata, dict):
1093
  subjects = str(metadata.get("subjects", "")).lower()
1094
  subjects += " " + str(metadata.get("subject", "")).lower()
1095
  subjects += " " + str(metadata.get("category", "")).lower()
1096
+
1097
  topic = None
1098
  for keyword, mapped_topic in GUTENBERG_SUBJECT_MAP.items():
1099
  if keyword in subjects:
1100
  topic = mapped_topic
1101
  break
1102
+
1103
  # Default fiction for novels without clear subject
1104
  if not topic and ("novel" in subjects or not subjects.strip()):
1105
  topic = "Fiction"
1106
+
1107
  if topic:
1108
  # Get a clean paragraph as sample
1109
+ paragraphs = re.split(r"\n\s*\n", text)
1110
  for para in paragraphs[5:]: # Skip front matter
1111
  para = para.strip()
1112
+ if 200 < len(para) < 1500 and para.count(".") >= 2:
1113
  # Filter: English only
1114
  if not is_english_text(para):
1115
  skipped_language += 1
1116
  break
1117
+
1118
+ records.append(
1119
+ {
1120
+ "text": para,
1121
+ "topic": topic,
1122
+ "source": "gutenberg",
1123
+ }
1124
+ )
1125
  break
1126
+
1127
  print(f" Gutenberg topics: {len(records):,} (skipped {skipped_language} non-English)")
1128
  return records
1129
 
1130
 
1131
  # ------------ EMOTIONS (unchanged) -------------
1132
 
1133
+
1134
  def download_emotions() -> None:
1135
  """Download GoEmotions for emotion classification."""
1136
  print("\nDownloading Emotions (GoEmotions)...")
1137
  out_dir = OUTPUT_DIR / "emotion"
1138
+
1139
  ds = load_dataset("google-research-datasets/go_emotions", "simplified")
1140
+
1141
  for split_name in ds.keys():
1142
  split = str(split_name)
1143
  data = ds[split_name]
1144
+
1145
  records: list[dict[str, Any]] = []
1146
  for item in tqdm(data, desc=split, leave=False):
1147
  text = item.get("text", "")
 
1151
  if emotions:
1152
  records.append({"text": text, "emotions": emotions})
1153
  write_jsonl(records, out_dir / f"{split}.jsonl", split)
1154
+
1155
  (out_dir / "labels.json").write_text(json.dumps(EMOTION_LABELS, indent=2))
1156
  print(f" {len(EMOTION_LABELS)} emotion labels saved")
1157
 
 
1159
  # --------------- GUTENBERG BOOKS (for language modeling) ---------------
1160
 
1161
  GUTENBERG_JUNK_PATTERNS = [
1162
+ r"Project Gutenberg",
1163
+ r"www\.gutenberg\.org",
1164
+ r"This ebook is for",
1165
+ r"Gutenberg License",
1166
+ r"^\*\*\* START OF",
1167
+ r"^\*\*\* END OF",
1168
+ r"Produced by",
1169
+ r"Transcriber's Note",
1170
+ r"TABLE OF CONTENTS",
1171
+ r"^\s*CHAPTER\s+[IVXLC\d]+",
1172
+ r"^\s*Chapter\s+[IVXLC\d]+",
1173
+ r"^\s*BOOK\s+[IVXLC\d]+",
1174
+ r"^\s*PREFACE\s*$",
1175
+ r"^\s*INTRODUCTION\s*$",
1176
+ r"E-text prepared by",
1177
+ r"Internet Archive",
1178
+ r"Distributed Proofreaders",
1179
  ]
1180
  GUTENBERG_JUNK_REGEX = re.compile("|".join(GUTENBERG_JUNK_PATTERNS), re.IGNORECASE)
1181
 
 
1186
  return False
1187
  if GUTENBERG_JUNK_REGEX.search(text):
1188
  return False
1189
+ if text.count(".") < 2:
1190
  return False
1191
  uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
1192
  if uppercase_ratio > 0.3:
 
1205
  print("\nDownloading Gutenberg Books (English only)...")
1206
  out_dir = OUTPUT_DIR / "books"
1207
  out_dir.mkdir(parents=True, exist_ok=True)
1208
+
1209
  try:
1210
  gutenberg = load_dataset("sedthh/gutenberg_english", split="train")
1211
  except Exception:
1212
  gutenberg = load_dataset("pg19", split="train")
1213
+
1214
  records: list[dict[str, Any]] = []
1215
  indices = list(range(len(gutenberg)))
1216
  random.shuffle(indices)
1217
+
1218
  for i in tqdm(indices, desc="Books", leave=False):
1219
  if len(records) >= max_samples:
1220
  break
1221
+
1222
  item = gutenberg[i]
1223
  text = item.get("TEXT", "") or item.get("text", "")
1224
  metadata_raw = item.get("METADATA", "") or "{}"
1225
+
1226
  # Parse metadata - it's stored as JSON string
1227
  try:
1228
  metadata = json.loads(metadata_raw) if isinstance(metadata_raw, str) else metadata_raw
1229
  except (json.JSONDecodeError, TypeError):
1230
  metadata = {}
1231
+
1232
  # Extract title and author
1233
  title = metadata.get("title", "") if isinstance(metadata, dict) else ""
1234
  author = metadata.get("author", "") if isinstance(metadata, dict) else ""
1235
  if not title:
1236
  title = item.get("title", f"Unknown Book #{i}")
1237
+
1238
  if not text or len(text) < 1000:
1239
  continue
1240
+
1241
+ paragraphs = re.split(r"\n\s*\n", text)
1242
  for para in paragraphs:
1243
  para = para.strip()
1244
  if is_clean_prose(para):
1245
+ records.append(
1246
+ {"text": para, "title": title, "author": author, "type": "gutenberg"}
1247
+ )
 
 
 
1248
  if len(records) >= max_samples:
1249
  break
1250
+
1251
  random.shuffle(records)
1252
  n = len(records)
1253
+ write_jsonl(records[: int(n * 0.9)], out_dir / "train.jsonl", "train")
1254
+ write_jsonl(records[int(n * 0.9) : int(n * 0.95)], out_dir / "validation.jsonl", "val")
1255
+ write_jsonl(records[int(n * 0.95) :], out_dir / "test.jsonl", "test")
1256
 
1257
 
1258
  # ------------ MAIN ------------
1259
 
1260
+
1261
  def main() -> None:
1262
  parser = argparse.ArgumentParser(description="Download LexiMind datasets")
1263
  parser.add_argument(
1264
  "--task",
1265
  choices=["all", "summarization", "emotion", "topic", "gutenberg"],
1266
  default="all",
1267
+ help="Dataset to download",
1268
  )
1269
  parser.add_argument("--max-books", type=int, default=40000, help="Max BookSum samples")
1270
  parser.add_argument("--max-arxiv", type=int, default=50000, help="Max arXiv samples")
 
1272
  parser.add_argument("--max-topics", type=int, default=50000, help="Max topic samples")
1273
  parser.add_argument("--seed", type=int, default=42, help="Random seed")
1274
  args = parser.parse_args()
1275
+
1276
  random.seed(args.seed)
1277
+
1278
  print("=" * 60)
1279
  print("LexiMind Dataset Download")
1280
  print("Books + Academic Papers + Topic Classification")
1281
  print("=" * 60)
1282
+
1283
  if args.task in ["all", "summarization"]:
1284
  download_summarization(args.max_books, args.max_arxiv)
1285
  if args.task in ["all", "emotion"]:
 
1288
  download_topics(args.max_topics)
1289
  if args.task in ["all", "gutenberg"]:
1290
  download_gutenberg(args.max_gutenberg)
1291
+
1292
  print("\n" + "=" * 60)
1293
  print("Download complete!")
1294
  print("=" * 60)
scripts/evaluate.py CHANGED
@@ -65,34 +65,34 @@ def evaluate_summarization(
65
  print("\n" + "=" * 60)
66
  print("SUMMARIZATION EVALUATION")
67
  print("=" * 60)
68
-
69
  # Load data - try to get domain info from the raw JSONL
70
  raw_data = []
71
  with open(data_path) as f:
72
  for line in f:
73
  if line.strip():
74
  raw_data.append(json.loads(line))
75
-
76
  data = load_summarization_jsonl(str(data_path))
77
  if max_samples:
78
  data = data[:max_samples]
79
  raw_data = raw_data[:max_samples]
80
  print(f"Evaluating on {len(data)} samples...")
81
-
82
  # Generate summaries
83
  predictions = []
84
  references = []
85
  domains = [] # Track domain for per-domain breakdown
86
-
87
  for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
88
- batch = data[i:i + batch_size]
89
  sources = [ex.source for ex in batch]
90
  refs = [ex.summary for ex in batch]
91
-
92
  preds = pipeline.summarize(sources)
93
  predictions.extend(preds)
94
  references.extend(refs)
95
-
96
  # Track domain if available
97
  for j in range(len(batch)):
98
  idx = i + j
@@ -101,14 +101,14 @@ def evaluate_summarization(
101
  domains.append(domain)
102
  else:
103
  domains.append("unknown")
104
-
105
  # Calculate overall metrics
106
  print("\nCalculating ROUGE scores...")
107
  rouge_scores = calculate_rouge(predictions, references)
108
-
109
  print("Calculating BLEU score...")
110
  bleu = calculate_bleu(predictions, references)
111
-
112
  metrics: dict = {
113
  "rouge1": rouge_scores["rouge1"],
114
  "rouge2": rouge_scores["rouge2"],
@@ -116,14 +116,14 @@ def evaluate_summarization(
116
  "bleu4": bleu,
117
  "num_samples": len(predictions),
118
  }
119
-
120
  if include_bertscore:
121
  print("Calculating BERTScore (this may take a few minutes)...")
122
  bert_scores = calculate_bertscore(predictions, references)
123
  metrics["bertscore_precision"] = bert_scores["precision"]
124
  metrics["bertscore_recall"] = bert_scores["recall"]
125
  metrics["bertscore_f1"] = bert_scores["f1"]
126
-
127
  # Per-domain breakdown
128
  unique_domains = sorted(set(domains))
129
  if len(unique_domains) > 1:
@@ -150,25 +150,26 @@ def evaluate_summarization(
150
  dm["bertscore_f1"] = d_bert["f1"]
151
  domain_metrics[domain] = dm
152
  metrics["per_domain"] = domain_metrics
153
-
154
  # Bootstrap confidence intervals
155
  if compute_bootstrap:
156
  try:
157
  from rouge_score import rouge_scorer
158
- scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
 
159
  per_sample_r1 = []
160
  per_sample_rL = []
161
  for pred, ref in zip(predictions, references, strict=True):
162
  scores = scorer.score(ref, pred)
163
- per_sample_r1.append(scores['rouge1'].fmeasure)
164
- per_sample_rL.append(scores['rougeL'].fmeasure)
165
  r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
166
  rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
167
  metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
168
  metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
169
  except ImportError:
170
  pass
171
-
172
  # Print results
173
  print("\n" + "-" * 40)
174
  print("SUMMARIZATION RESULTS:")
@@ -181,27 +182,29 @@ def evaluate_summarization(
181
  print(f" BERTScore P: {metrics['bertscore_precision']:.4f}")
182
  print(f" BERTScore R: {metrics['bertscore_recall']:.4f}")
183
  print(f" BERTScore F: {metrics['bertscore_f1']:.4f}")
184
-
185
  if "per_domain" in metrics:
186
  print("\n Per-Domain Breakdown:")
187
  for domain, dm in metrics["per_domain"].items():
188
  bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
189
- print(f" {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}")
190
-
 
 
191
  if "rouge1_ci" in metrics:
192
  ci = metrics["rouge1_ci"]
193
  print(f"\n ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
194
-
195
  # Show examples
196
  print("\n" + "-" * 40)
197
  print("SAMPLE OUTPUTS:")
198
  print("-" * 40)
199
  for i in range(min(3, len(predictions))):
200
- print(f"\nExample {i+1}:")
201
  print(f" Source: {data[i].source[:100]}...")
202
  print(f" Generated: {predictions[i][:150]}...")
203
  print(f" Reference: {references[i][:150]}...")
204
-
205
  return metrics
206
 
207
 
@@ -214,62 +217,64 @@ def evaluate_emotion(
214
  compute_bootstrap: bool = False,
215
  ) -> dict:
216
  """Evaluate emotion detection with comprehensive multi-label metrics.
217
-
218
  Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
219
  Optionally tunes per-class thresholds on the evaluation set.
220
  """
221
  print("\n" + "=" * 60)
222
  print("EMOTION DETECTION EVALUATION")
223
  print("=" * 60)
224
-
225
  # Load data (returns EmotionExample dataclass objects)
226
  data = load_emotion_jsonl(str(data_path))
227
  if max_samples:
228
  data = data[:max_samples]
229
  print(f"Evaluating on {len(data)} samples...")
230
-
231
  # Get predictions - collect raw logits for threshold tuning
232
  all_preds = []
233
  all_refs = []
234
  all_logits_list = []
235
-
236
  for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
237
- batch = data[i:i + batch_size]
238
  texts = [ex.text for ex in batch]
239
  refs = [set(ex.emotions) for ex in batch]
240
-
241
  preds = pipeline.predict_emotions(texts)
242
  pred_sets = [set(p.labels) for p in preds]
243
-
244
  all_preds.extend(pred_sets)
245
  all_refs.extend(refs)
246
-
247
  # Also get raw logits for threshold tuning
248
  if tune_thresholds:
249
  encoded = pipeline.tokenizer.batch_encode(texts)
250
  input_ids = encoded["input_ids"].to(pipeline.device)
251
  attention_mask = encoded["attention_mask"].to(pipeline.device)
252
  with torch.inference_mode():
253
- logits = pipeline.model.forward("emotion", {"input_ids": input_ids, "attention_mask": attention_mask})
 
 
254
  all_logits_list.append(logits.cpu())
255
-
256
  # Calculate metrics
257
  all_emotions = sorted(pipeline.emotion_labels)
258
-
259
  def to_binary(emotion_sets, labels):
260
  return [[1 if e in es else 0 for e in labels] for es in emotion_sets]
261
-
262
  pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
263
  ref_binary = torch.tensor(to_binary(all_refs, all_emotions))
264
-
265
  # Core metrics: sample-avg F1, macro F1, micro F1
266
  sample_f1 = multilabel_f1(pred_binary, ref_binary)
267
  macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
268
  micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)
269
-
270
  # Per-class metrics
271
  per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)
272
-
273
  metrics: dict = {
274
  "sample_avg_f1": sample_f1,
275
  "macro_f1": macro_f1,
@@ -278,7 +283,7 @@ def evaluate_emotion(
278
  "num_classes": len(all_emotions),
279
  "per_class": per_class,
280
  }
281
-
282
  # Per-class threshold tuning
283
  if tune_thresholds and all_logits_list:
284
  print("\nTuning per-class thresholds...")
@@ -288,7 +293,7 @@ def evaluate_emotion(
288
  name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
289
  }
290
  metrics["tuned_macro_f1"] = tuned_macro_f1
291
-
292
  # Also compute tuned sample-avg F1
293
  probs = torch.sigmoid(all_logits)
294
  tuned_preds = torch.zeros_like(probs)
@@ -296,7 +301,7 @@ def evaluate_emotion(
296
  tuned_preds[:, c] = (probs[:, c] >= t).float()
297
  metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
298
  metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
299
-
300
  # Bootstrap confidence intervals
301
  if compute_bootstrap:
302
  # Compute per-sample F1 for bootstrapping
@@ -313,7 +318,7 @@ def evaluate_emotion(
313
  per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
314
  mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
315
  metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}
316
-
317
  # Print results
318
  print("\n" + "-" * 40)
319
  print("EMOTION DETECTION RESULTS:")
@@ -322,23 +327,25 @@ def evaluate_emotion(
322
  print(f" Macro F1: {metrics['macro_f1']:.4f}")
323
  print(f" Micro F1: {metrics['micro_f1']:.4f}")
324
  print(f" Num Classes: {metrics['num_classes']}")
325
-
326
  if "tuned_macro_f1" in metrics:
327
  print("\n After per-class threshold tuning:")
328
  print(f" Tuned Macro F1: {metrics['tuned_macro_f1']:.4f}")
329
  print(f" Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
330
  print(f" Tuned Micro F1: {metrics['tuned_micro_f1']:.4f}")
331
-
332
  if "sample_f1_ci" in metrics:
333
  ci = metrics["sample_f1_ci"]
334
  print(f"\n Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
335
-
336
  # Print top-10 per-class performance
337
  print("\n Per-class F1 (top 10 by support):")
338
  sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
339
  for name, m in sorted_classes[:10]:
340
- print(f" {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})")
341
-
 
 
342
  return metrics
343
 
344
 
@@ -353,61 +360,63 @@ def evaluate_topic(
353
  print("\n" + "=" * 60)
354
  print("TOPIC CLASSIFICATION EVALUATION")
355
  print("=" * 60)
356
-
357
  # Load data (returns TopicExample dataclass objects)
358
  data = load_topic_jsonl(str(data_path))
359
  if max_samples:
360
  data = data[:max_samples]
361
  print(f"Evaluating on {len(data)} samples...")
362
-
363
  # Get predictions
364
  all_preds = []
365
  all_refs = []
366
-
367
  for i in tqdm(range(0, len(data), batch_size), desc="Predicting topics"):
368
- batch = data[i:i + batch_size]
369
  texts = [ex.text for ex in batch]
370
  refs = [ex.topic for ex in batch]
371
-
372
  preds = pipeline.predict_topics(texts)
373
  pred_labels = [p.label for p in preds]
374
-
375
  all_preds.extend(pred_labels)
376
  all_refs.extend(refs)
377
-
378
  # Calculate metrics
379
  accuracy = accuracy_score(all_refs, all_preds)
380
  macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)
381
-
382
  metrics: dict = {
383
  "accuracy": accuracy,
384
  "macro_f1": macro_f1,
385
  "num_samples": len(all_preds),
386
  }
387
-
388
  # Bootstrap confidence intervals for accuracy
389
  if compute_bootstrap:
390
- per_sample_correct = [1.0 if p == r else 0.0 for p, r in zip(all_preds, all_refs, strict=True)]
 
 
391
  mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
392
  metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}
393
-
394
  # Print results
395
  print("\n" + "-" * 40)
396
  print("TOPIC CLASSIFICATION RESULTS:")
397
  print("-" * 40)
398
- print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
399
  print(f" Macro F1: {metrics['macro_f1']:.4f}")
400
-
401
  if "accuracy_ci" in metrics:
402
  ci = metrics["accuracy_ci"]
403
  print(f" Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
404
-
405
  # Classification report
406
  print("\n" + "-" * 40)
407
  print("PER-CLASS METRICS:")
408
  print("-" * 40)
409
  print(classification_report(all_refs, all_preds, zero_division=0))
410
-
411
  return metrics
412
 
413
 
@@ -418,20 +427,28 @@ def main():
418
  parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
419
  parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
420
  parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
421
- parser.add_argument("--include-bertscore", action="store_true", help="Include BERTScore (slow, optional)")
422
- parser.add_argument("--tune-thresholds", action="store_true", help="Tune per-class emotion thresholds on val set")
423
- parser.add_argument("--bootstrap", action="store_true", help="Compute bootstrap confidence intervals")
 
 
 
 
 
 
 
 
424
  parser.add_argument("--summarization-only", action="store_true")
425
  parser.add_argument("--emotion-only", action="store_true")
426
  parser.add_argument("--topic-only", action="store_true")
427
  args = parser.parse_args()
428
-
429
  print("=" * 60)
430
  print("LexiMind Evaluation")
431
  print("=" * 60)
432
-
433
  start_time = time.perf_counter()
434
-
435
  # Load model
436
  print(f"\nLoading model from {args.checkpoint}...")
437
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -443,12 +460,12 @@ def main():
443
  print(f" Device: {device}")
444
  print(f" Topics: {labels.topic}")
445
  print(f" Emotions: {len(labels.emotion)} classes")
446
-
447
  results = {}
448
-
449
  # Determine which tasks to evaluate
450
  eval_all = not (args.summarization_only or args.emotion_only or args.topic_only)
451
-
452
  # Evaluate summarization
453
  if eval_all or args.summarization_only:
454
  val_path = args.data_dir / "summarization" / "validation.jsonl"
@@ -456,14 +473,15 @@ def main():
456
  val_path = args.data_dir / "summarization" / "val.jsonl"
457
  if val_path.exists():
458
  results["summarization"] = evaluate_summarization(
459
- pipeline, val_path,
 
460
  max_samples=args.max_samples,
461
  include_bertscore=args.include_bertscore,
462
  compute_bootstrap=args.bootstrap,
463
  )
464
  else:
465
  print("Warning: summarization validation data not found, skipping")
466
-
467
  # Evaluate emotion
468
  if eval_all or args.emotion_only:
469
  val_path = args.data_dir / "emotion" / "validation.jsonl"
@@ -471,14 +489,15 @@ def main():
471
  val_path = args.data_dir / "emotion" / "val.jsonl"
472
  if val_path.exists():
473
  results["emotion"] = evaluate_emotion(
474
- pipeline, val_path,
 
475
  max_samples=args.max_samples,
476
  tune_thresholds=args.tune_thresholds,
477
  compute_bootstrap=args.bootstrap,
478
  )
479
  else:
480
  print("Warning: emotion validation data not found, skipping")
481
-
482
  # Evaluate topic
483
  if eval_all or args.topic_only:
484
  val_path = args.data_dir / "topic" / "validation.jsonl"
@@ -486,30 +505,31 @@ def main():
486
  val_path = args.data_dir / "topic" / "val.jsonl"
487
  if val_path.exists():
488
  results["topic"] = evaluate_topic(
489
- pipeline, val_path,
 
490
  max_samples=args.max_samples,
491
  compute_bootstrap=args.bootstrap,
492
  )
493
  else:
494
  print("Warning: topic validation data not found, skipping")
495
-
496
  # Save results
497
  print("\n" + "=" * 60)
498
  print("SAVING RESULTS")
499
  print("=" * 60)
500
-
501
  args.output.parent.mkdir(parents=True, exist_ok=True)
502
  with open(args.output, "w") as f:
503
  json.dump(results, f, indent=2)
504
  print(f" Saved to: {args.output}")
505
-
506
  # Final summary
507
  elapsed = time.perf_counter() - start_time
508
  print("\n" + "=" * 60)
509
  print("EVALUATION COMPLETE")
510
  print("=" * 60)
511
- print(f" Time: {elapsed/60:.1f} minutes")
512
-
513
  if "summarization" in results:
514
  s = results["summarization"]
515
  print("\n Summarization:")
@@ -519,14 +539,14 @@ def main():
519
  print(f" BLEU-4: {s['bleu4']:.4f}")
520
  if "bertscore_f1" in s:
521
  print(f" BERTScore F1: {s['bertscore_f1']:.4f}")
522
-
523
  if "emotion" in results:
524
  e = results["emotion"]
525
  print("\n Emotion:")
526
  print(f" Sample-avg F1: {e['sample_avg_f1']:.4f}")
527
  print(f" Macro F1: {e['macro_f1']:.4f}")
528
  print(f" Micro F1: {e['micro_f1']:.4f}")
529
-
530
  if "topic" in results:
531
  print("\n Topic:")
532
  print(f" Accuracy: {results['topic']['accuracy']:.2%}")
 
65
  print("\n" + "=" * 60)
66
  print("SUMMARIZATION EVALUATION")
67
  print("=" * 60)
68
+
69
  # Load data - try to get domain info from the raw JSONL
70
  raw_data = []
71
  with open(data_path) as f:
72
  for line in f:
73
  if line.strip():
74
  raw_data.append(json.loads(line))
75
+
76
  data = load_summarization_jsonl(str(data_path))
77
  if max_samples:
78
  data = data[:max_samples]
79
  raw_data = raw_data[:max_samples]
80
  print(f"Evaluating on {len(data)} samples...")
81
+
82
  # Generate summaries
83
  predictions = []
84
  references = []
85
  domains = [] # Track domain for per-domain breakdown
86
+
87
  for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
88
+ batch = data[i : i + batch_size]
89
  sources = [ex.source for ex in batch]
90
  refs = [ex.summary for ex in batch]
91
+
92
  preds = pipeline.summarize(sources)
93
  predictions.extend(preds)
94
  references.extend(refs)
95
+
96
  # Track domain if available
97
  for j in range(len(batch)):
98
  idx = i + j
 
101
  domains.append(domain)
102
  else:
103
  domains.append("unknown")
104
+
105
  # Calculate overall metrics
106
  print("\nCalculating ROUGE scores...")
107
  rouge_scores = calculate_rouge(predictions, references)
108
+
109
  print("Calculating BLEU score...")
110
  bleu = calculate_bleu(predictions, references)
111
+
112
  metrics: dict = {
113
  "rouge1": rouge_scores["rouge1"],
114
  "rouge2": rouge_scores["rouge2"],
 
116
  "bleu4": bleu,
117
  "num_samples": len(predictions),
118
  }
119
+
120
  if include_bertscore:
121
  print("Calculating BERTScore (this may take a few minutes)...")
122
  bert_scores = calculate_bertscore(predictions, references)
123
  metrics["bertscore_precision"] = bert_scores["precision"]
124
  metrics["bertscore_recall"] = bert_scores["recall"]
125
  metrics["bertscore_f1"] = bert_scores["f1"]
126
+
127
  # Per-domain breakdown
128
  unique_domains = sorted(set(domains))
129
  if len(unique_domains) > 1:
 
150
  dm["bertscore_f1"] = d_bert["f1"]
151
  domain_metrics[domain] = dm
152
  metrics["per_domain"] = domain_metrics
153
+
154
  # Bootstrap confidence intervals
155
  if compute_bootstrap:
156
  try:
157
  from rouge_score import rouge_scorer
158
+
159
+ scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
160
  per_sample_r1 = []
161
  per_sample_rL = []
162
  for pred, ref in zip(predictions, references, strict=True):
163
  scores = scorer.score(ref, pred)
164
+ per_sample_r1.append(scores["rouge1"].fmeasure)
165
+ per_sample_rL.append(scores["rougeL"].fmeasure)
166
  r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
167
  rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
168
  metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
169
  metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
170
  except ImportError:
171
  pass
172
+
173
  # Print results
174
  print("\n" + "-" * 40)
175
  print("SUMMARIZATION RESULTS:")
 
182
  print(f" BERTScore P: {metrics['bertscore_precision']:.4f}")
183
  print(f" BERTScore R: {metrics['bertscore_recall']:.4f}")
184
  print(f" BERTScore F: {metrics['bertscore_f1']:.4f}")
185
+
186
  if "per_domain" in metrics:
187
  print("\n Per-Domain Breakdown:")
188
  for domain, dm in metrics["per_domain"].items():
189
  bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
190
+ print(
191
+ f" {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}"
192
+ )
193
+
194
  if "rouge1_ci" in metrics:
195
  ci = metrics["rouge1_ci"]
196
  print(f"\n ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
197
+
198
  # Show examples
199
  print("\n" + "-" * 40)
200
  print("SAMPLE OUTPUTS:")
201
  print("-" * 40)
202
  for i in range(min(3, len(predictions))):
203
+ print(f"\nExample {i + 1}:")
204
  print(f" Source: {data[i].source[:100]}...")
205
  print(f" Generated: {predictions[i][:150]}...")
206
  print(f" Reference: {references[i][:150]}...")
207
+
208
  return metrics
209
 
210
 
 
217
  compute_bootstrap: bool = False,
218
  ) -> dict:
219
  """Evaluate emotion detection with comprehensive multi-label metrics.
220
+
221
  Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
222
  Optionally tunes per-class thresholds on the evaluation set.
223
  """
224
  print("\n" + "=" * 60)
225
  print("EMOTION DETECTION EVALUATION")
226
  print("=" * 60)
227
+
228
  # Load data (returns EmotionExample dataclass objects)
229
  data = load_emotion_jsonl(str(data_path))
230
  if max_samples:
231
  data = data[:max_samples]
232
  print(f"Evaluating on {len(data)} samples...")
233
+
234
  # Get predictions - collect raw logits for threshold tuning
235
  all_preds = []
236
  all_refs = []
237
  all_logits_list = []
238
+
239
  for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
240
+ batch = data[i : i + batch_size]
241
  texts = [ex.text for ex in batch]
242
  refs = [set(ex.emotions) for ex in batch]
243
+
244
  preds = pipeline.predict_emotions(texts)
245
  pred_sets = [set(p.labels) for p in preds]
246
+
247
  all_preds.extend(pred_sets)
248
  all_refs.extend(refs)
249
+
250
  # Also get raw logits for threshold tuning
251
  if tune_thresholds:
252
  encoded = pipeline.tokenizer.batch_encode(texts)
253
  input_ids = encoded["input_ids"].to(pipeline.device)
254
  attention_mask = encoded["attention_mask"].to(pipeline.device)
255
  with torch.inference_mode():
256
+ logits = pipeline.model.forward(
257
+ "emotion", {"input_ids": input_ids, "attention_mask": attention_mask}
258
+ )
259
  all_logits_list.append(logits.cpu())
260
+
261
  # Calculate metrics
262
  all_emotions = sorted(pipeline.emotion_labels)
263
+
264
  def to_binary(emotion_sets, labels):
265
  return [[1 if e in es else 0 for e in labels] for es in emotion_sets]
266
+
267
  pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
268
  ref_binary = torch.tensor(to_binary(all_refs, all_emotions))
269
+
270
  # Core metrics: sample-avg F1, macro F1, micro F1
271
  sample_f1 = multilabel_f1(pred_binary, ref_binary)
272
  macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
273
  micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)
274
+
275
  # Per-class metrics
276
  per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)
277
+
278
  metrics: dict = {
279
  "sample_avg_f1": sample_f1,
280
  "macro_f1": macro_f1,
 
283
  "num_classes": len(all_emotions),
284
  "per_class": per_class,
285
  }
286
+
287
  # Per-class threshold tuning
288
  if tune_thresholds and all_logits_list:
289
  print("\nTuning per-class thresholds...")
 
293
  name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
294
  }
295
  metrics["tuned_macro_f1"] = tuned_macro_f1
296
+
297
  # Also compute tuned sample-avg F1
298
  probs = torch.sigmoid(all_logits)
299
  tuned_preds = torch.zeros_like(probs)
 
301
  tuned_preds[:, c] = (probs[:, c] >= t).float()
302
  metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
303
  metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
304
+
305
  # Bootstrap confidence intervals
306
  if compute_bootstrap:
307
  # Compute per-sample F1 for bootstrapping
 
318
  per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
319
  mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
320
  metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}
321
+
322
  # Print results
323
  print("\n" + "-" * 40)
324
  print("EMOTION DETECTION RESULTS:")
 
327
  print(f" Macro F1: {metrics['macro_f1']:.4f}")
328
  print(f" Micro F1: {metrics['micro_f1']:.4f}")
329
  print(f" Num Classes: {metrics['num_classes']}")
330
+
331
  if "tuned_macro_f1" in metrics:
332
  print("\n After per-class threshold tuning:")
333
  print(f" Tuned Macro F1: {metrics['tuned_macro_f1']:.4f}")
334
  print(f" Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
335
  print(f" Tuned Micro F1: {metrics['tuned_micro_f1']:.4f}")
336
+
337
  if "sample_f1_ci" in metrics:
338
  ci = metrics["sample_f1_ci"]
339
  print(f"\n Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
340
+
341
  # Print top-10 per-class performance
342
  print("\n Per-class F1 (top 10 by support):")
343
  sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
344
  for name, m in sorted_classes[:10]:
345
+ print(
346
+ f" {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})"
347
+ )
348
+
349
  return metrics
350
 
351
 
 
360
  print("\n" + "=" * 60)
361
  print("TOPIC CLASSIFICATION EVALUATION")
362
  print("=" * 60)
363
+
364
  # Load data (returns TopicExample dataclass objects)
365
  data = load_topic_jsonl(str(data_path))
366
  if max_samples:
367
  data = data[:max_samples]
368
  print(f"Evaluating on {len(data)} samples...")
369
+
370
  # Get predictions
371
  all_preds = []
372
  all_refs = []
373
+
374
  for i in tqdm(range(0, len(data), batch_size), desc="Predicting topics"):
375
+ batch = data[i : i + batch_size]
376
  texts = [ex.text for ex in batch]
377
  refs = [ex.topic for ex in batch]
378
+
379
  preds = pipeline.predict_topics(texts)
380
  pred_labels = [p.label for p in preds]
381
+
382
  all_preds.extend(pred_labels)
383
  all_refs.extend(refs)
384
+
385
  # Calculate metrics
386
  accuracy = accuracy_score(all_refs, all_preds)
387
  macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)
388
+
389
  metrics: dict = {
390
  "accuracy": accuracy,
391
  "macro_f1": macro_f1,
392
  "num_samples": len(all_preds),
393
  }
394
+
395
  # Bootstrap confidence intervals for accuracy
396
  if compute_bootstrap:
397
+ per_sample_correct = [
398
+ 1.0 if p == r else 0.0 for p, r in zip(all_preds, all_refs, strict=True)
399
+ ]
400
  mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
401
  metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}
402
+
403
  # Print results
404
  print("\n" + "-" * 40)
405
  print("TOPIC CLASSIFICATION RESULTS:")
406
  print("-" * 40)
407
+ print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy'] * 100:.1f}%)")
408
  print(f" Macro F1: {metrics['macro_f1']:.4f}")
409
+
410
  if "accuracy_ci" in metrics:
411
  ci = metrics["accuracy_ci"]
412
  print(f" Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
413
+
414
  # Classification report
415
  print("\n" + "-" * 40)
416
  print("PER-CLASS METRICS:")
417
  print("-" * 40)
418
  print(classification_report(all_refs, all_preds, zero_division=0))
419
+
420
  return metrics
421
 
422
 
 
427
  parser.add_argument("--data-dir", type=Path, default=Path("data/processed"))
428
  parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
429
  parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
430
+ parser.add_argument(
431
+ "--include-bertscore", action="store_true", help="Include BERTScore (slow, optional)"
432
+ )
433
+ parser.add_argument(
434
+ "--tune-thresholds",
435
+ action="store_true",
436
+ help="Tune per-class emotion thresholds on val set",
437
+ )
438
+ parser.add_argument(
439
+ "--bootstrap", action="store_true", help="Compute bootstrap confidence intervals"
440
+ )
441
  parser.add_argument("--summarization-only", action="store_true")
442
  parser.add_argument("--emotion-only", action="store_true")
443
  parser.add_argument("--topic-only", action="store_true")
444
  args = parser.parse_args()
445
+
446
  print("=" * 60)
447
  print("LexiMind Evaluation")
448
  print("=" * 60)
449
+
450
  start_time = time.perf_counter()
451
+
452
  # Load model
453
  print(f"\nLoading model from {args.checkpoint}...")
454
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
460
  print(f" Device: {device}")
461
  print(f" Topics: {labels.topic}")
462
  print(f" Emotions: {len(labels.emotion)} classes")
463
+
464
  results = {}
465
+
466
  # Determine which tasks to evaluate
467
  eval_all = not (args.summarization_only or args.emotion_only or args.topic_only)
468
+
469
  # Evaluate summarization
470
  if eval_all or args.summarization_only:
471
  val_path = args.data_dir / "summarization" / "validation.jsonl"
 
473
  val_path = args.data_dir / "summarization" / "val.jsonl"
474
  if val_path.exists():
475
  results["summarization"] = evaluate_summarization(
476
+ pipeline,
477
+ val_path,
478
  max_samples=args.max_samples,
479
  include_bertscore=args.include_bertscore,
480
  compute_bootstrap=args.bootstrap,
481
  )
482
  else:
483
  print("Warning: summarization validation data not found, skipping")
484
+
485
  # Evaluate emotion
486
  if eval_all or args.emotion_only:
487
  val_path = args.data_dir / "emotion" / "validation.jsonl"
 
489
  val_path = args.data_dir / "emotion" / "val.jsonl"
490
  if val_path.exists():
491
  results["emotion"] = evaluate_emotion(
492
+ pipeline,
493
+ val_path,
494
  max_samples=args.max_samples,
495
  tune_thresholds=args.tune_thresholds,
496
  compute_bootstrap=args.bootstrap,
497
  )
498
  else:
499
  print("Warning: emotion validation data not found, skipping")
500
+
501
  # Evaluate topic
502
  if eval_all or args.topic_only:
503
  val_path = args.data_dir / "topic" / "validation.jsonl"
 
505
  val_path = args.data_dir / "topic" / "val.jsonl"
506
  if val_path.exists():
507
  results["topic"] = evaluate_topic(
508
+ pipeline,
509
+ val_path,
510
  max_samples=args.max_samples,
511
  compute_bootstrap=args.bootstrap,
512
  )
513
  else:
514
  print("Warning: topic validation data not found, skipping")
515
+
516
  # Save results
517
  print("\n" + "=" * 60)
518
  print("SAVING RESULTS")
519
  print("=" * 60)
520
+
521
  args.output.parent.mkdir(parents=True, exist_ok=True)
522
  with open(args.output, "w") as f:
523
  json.dump(results, f, indent=2)
524
  print(f" Saved to: {args.output}")
525
+
526
  # Final summary
527
  elapsed = time.perf_counter() - start_time
528
  print("\n" + "=" * 60)
529
  print("EVALUATION COMPLETE")
530
  print("=" * 60)
531
+ print(f" Time: {elapsed / 60:.1f} minutes")
532
+
533
  if "summarization" in results:
534
  s = results["summarization"]
535
  print("\n Summarization:")
 
539
  print(f" BLEU-4: {s['bleu4']:.4f}")
540
  if "bertscore_f1" in s:
541
  print(f" BERTScore F1: {s['bertscore_f1']:.4f}")
542
+
543
  if "emotion" in results:
544
  e = results["emotion"]
545
  print("\n Emotion:")
546
  print(f" Sample-avg F1: {e['sample_avg_f1']:.4f}")
547
  print(f" Macro F1: {e['macro_f1']:.4f}")
548
  print(f" Micro F1: {e['micro_f1']:.4f}")
549
+
550
  if "topic" in results:
551
  print("\n Topic:")
552
  print(f" Accuracy: {results['topic']['accuracy']:.2%}")
scripts/profile_training.py CHANGED
@@ -96,10 +96,12 @@ def main(cfg: DictConfig) -> None:
96
 
97
  tok_cfg = data_cfg.get("tokenizer", {})
98
  max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
99
- tokenizer = Tokenizer(TokenizerConfig(
100
- pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
101
- max_length=max_len,
102
- ))
 
 
103
 
104
  summ_train = SummarizationDataset(summ_splits["train"])
105
  emot_train = EmotionDataset(emot_splits["train"])
@@ -112,23 +114,42 @@ def main(cfg: DictConfig) -> None:
112
 
113
  train_loaders = {
114
  "summarization": build_summarization_dataloader(
115
- summ_train, tokenizer, shuffle=True,
116
- max_source_length=max_len, max_target_length=max_len,
117
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
118
  ),
119
  "emotion": build_emotion_dataloader(
120
- emot_train, tokenizer, shuffle=True, max_length=classification_max_len,
121
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
122
  ),
123
  "topic": build_topic_dataloader(
124
- topic_train, tokenizer, shuffle=True, max_length=classification_max_len,
125
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
126
  ),
127
  }
128
 
129
  # Build model
130
- grad_ckpt = cfg.training.get("gradient_checkpointing", cfg.model.get("gradient_checkpointing", False))
131
- use_rel_pos = cfg.training.get("use_relative_position_bias", cfg.model.get("use_relative_position_bias", False))
 
 
 
 
132
 
133
  model_cfg = ModelConfig(
134
  d_model=cfg.model.d_model,
@@ -202,8 +223,10 @@ def main(cfg: DictConfig) -> None:
202
  except StopIteration:
203
  iterators[task] = iter(train_loaders[task])
204
  batch = next(iterators[task])
205
- return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
206
- for k, v in batch.items()}
 
 
207
 
208
  def training_step(step):
209
  """One training step across all tasks."""
@@ -219,7 +242,8 @@ def main(cfg: DictConfig) -> None:
219
  loss = torch.nn.functional.cross_entropy(
220
  logits.view(-1, logits.size(-1)),
221
  batch["labels"].view(-1),
222
- ignore_index=-100, label_smoothing=0.1,
 
223
  )
224
  elif task == "emotion":
225
  inputs = {"input_ids": batch["input_ids"]}
@@ -262,7 +286,10 @@ def main(cfg: DictConfig) -> None:
262
  torch.profiler.ProfilerActivity.CUDA,
263
  ],
264
  schedule=torch.profiler.schedule(
265
- wait=1, warmup=2, active=active_steps - 3, repeat=1,
 
 
 
266
  ),
267
  on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
268
  record_shapes=True,
 
96
 
97
  tok_cfg = data_cfg.get("tokenizer", {})
98
  max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
99
+ tokenizer = Tokenizer(
100
+ TokenizerConfig(
101
+ pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
102
+ max_length=max_len,
103
+ )
104
+ )
105
 
106
  summ_train = SummarizationDataset(summ_splits["train"])
107
  emot_train = EmotionDataset(emot_splits["train"])
 
114
 
115
  train_loaders = {
116
  "summarization": build_summarization_dataloader(
117
+ summ_train,
118
+ tokenizer,
119
+ shuffle=True,
120
+ max_source_length=max_len,
121
+ max_target_length=max_len,
122
+ batch_size=batch_size,
123
+ num_workers=num_workers,
124
+ pin_memory=True,
125
  ),
126
  "emotion": build_emotion_dataloader(
127
+ emot_train,
128
+ tokenizer,
129
+ shuffle=True,
130
+ max_length=classification_max_len,
131
+ batch_size=batch_size,
132
+ num_workers=num_workers,
133
+ pin_memory=True,
134
  ),
135
  "topic": build_topic_dataloader(
136
+ topic_train,
137
+ tokenizer,
138
+ shuffle=True,
139
+ max_length=classification_max_len,
140
+ batch_size=batch_size,
141
+ num_workers=num_workers,
142
+ pin_memory=True,
143
  ),
144
  }
145
 
146
  # Build model
147
+ grad_ckpt = cfg.training.get(
148
+ "gradient_checkpointing", cfg.model.get("gradient_checkpointing", False)
149
+ )
150
+ use_rel_pos = cfg.training.get(
151
+ "use_relative_position_bias", cfg.model.get("use_relative_position_bias", False)
152
+ )
153
 
154
  model_cfg = ModelConfig(
155
  d_model=cfg.model.d_model,
 
223
  except StopIteration:
224
  iterators[task] = iter(train_loaders[task])
225
  batch = next(iterators[task])
226
+ return {
227
+ k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
228
+ for k, v in batch.items()
229
+ }
230
 
231
  def training_step(step):
232
  """One training step across all tasks."""
 
242
  loss = torch.nn.functional.cross_entropy(
243
  logits.view(-1, logits.size(-1)),
244
  batch["labels"].view(-1),
245
+ ignore_index=-100,
246
+ label_smoothing=0.1,
247
  )
248
  elif task == "emotion":
249
  inputs = {"input_ids": batch["input_ids"]}
 
286
  torch.profiler.ProfilerActivity.CUDA,
287
  ],
288
  schedule=torch.profiler.schedule(
289
+ wait=1,
290
+ warmup=2,
291
+ active=active_steps - 3,
292
+ repeat=1,
293
  ),
294
  on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
295
  record_shapes=True,
scripts/train.py CHANGED
@@ -56,6 +56,7 @@ def set_seed(seed: int) -> None:
56
  import random
57
 
58
  import numpy as np
 
59
  random.seed(seed)
60
  np.random.seed(seed)
61
  torch.manual_seed(seed)
@@ -78,20 +79,20 @@ def load_splits(data_dir: Path, loader_fn) -> Dict[str, list]:
78
  def main(cfg: DictConfig) -> None:
79
  """Main training entry point."""
80
  start_time = time.perf_counter()
81
-
82
  print("=" * 60)
83
  print("LexiMind Training")
84
  print("=" * 60)
85
  print(OmegaConf.to_yaml(cfg))
86
-
87
  set_seed(cfg.seed)
88
  device = torch.device(cfg.device)
89
-
90
  # GPU optimizations for Ampere+
91
  if device.type == "cuda":
92
  # Enable cudnn benchmark for fixed-size inputs (10-20% speedup)
93
  torch.backends.cudnn.benchmark = True
94
-
95
  if torch.cuda.get_device_capability()[0] >= 8:
96
  torch.set_float32_matmul_precision("high")
97
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -99,18 +100,18 @@ def main(cfg: DictConfig) -> None:
99
  print(" TF32 + cudnn.benchmark enabled (Ampere GPU)")
100
  else:
101
  print(" cudnn.benchmark enabled")
102
-
103
  # --------------- Load Data ---------------
104
-
105
  print("\nLoading datasets...")
106
  data_cfg = cfg.data
107
  trainer_cfg = cfg.training.get("trainer", {})
108
-
109
  # Load splits
110
  summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
111
  emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
112
  topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
113
-
114
  # Apply sample limits for dev runs
115
  max_train = trainer_cfg.get("max_train_samples")
116
  max_val = trainer_cfg.get("max_val_samples")
@@ -121,86 +122,130 @@ def main(cfg: DictConfig) -> None:
121
  for splits in [summ_splits, emot_splits, topic_splits]:
122
  if "val" in splits:
123
  splits["val"] = splits["val"][:max_val]
124
-
125
- print(f" Summarization: {len(summ_splits['train']):,} train, {len(summ_splits.get('val', [])):,} val")
126
- print(f" Emotion: {len(emot_splits['train']):,} train, {len(emot_splits.get('val', [])):,} val")
127
- print(f" Topic: {len(topic_splits['train']):,} train, {len(topic_splits.get('val', [])):,} val")
128
-
 
 
 
 
 
 
129
  # --------------- Tokenizer ---------------
130
-
131
  tok_cfg = data_cfg.get("tokenizer", {})
132
  max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
133
-
134
- tokenizer = Tokenizer(TokenizerConfig(
135
- pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
136
- max_length=max_len,
137
- ))
 
 
138
  print(f" Tokenizer: {tokenizer.vocab_size:,} vocab, max_len={max_len}")
139
-
140
  # --------------- Datasets ---------------
141
-
142
  summ_train = SummarizationDataset(summ_splits["train"])
143
  summ_val = SummarizationDataset(summ_splits.get("val", []))
144
  emot_train = EmotionDataset(emot_splits["train"])
145
  emot_val = EmotionDataset(emot_splits.get("val", []), binarizer=emot_train.binarizer)
146
  topic_train = TopicDataset(topic_splits["train"])
147
  topic_val = TopicDataset(topic_splits.get("val", []), encoder=topic_train.encoder)
148
-
149
  print(f" Emotions: {len(emot_train.emotion_classes)} classes")
150
- print(f" Topics: {len(topic_train.topic_classes)} classes → {list(map(str, topic_train.topic_classes))}")
151
-
 
 
152
  # --------------- DataLoaders ---------------
153
-
154
  dl_cfg = cfg.training.get("dataloader", {})
155
  batch_size = int(dl_cfg.get("batch_size", 8))
156
  num_workers = int(dl_cfg.get("num_workers", 4))
157
-
158
  # Classification tasks don't need full 512 tokens - 256 is sufficient
159
  # This speeds up emotion/topic forward passes significantly
160
  classification_max_len = min(256, max_len)
161
-
162
  train_loaders = {
163
  "summarization": build_summarization_dataloader(
164
- summ_train, tokenizer, shuffle=True,
165
- max_source_length=max_len, max_target_length=max_len,
166
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
167
  ),
168
  "emotion": build_emotion_dataloader(
169
- emot_train, tokenizer, shuffle=True, max_length=classification_max_len,
170
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
171
  ),
172
  "topic": build_topic_dataloader(
173
- topic_train, tokenizer, shuffle=True, max_length=classification_max_len,
174
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
175
  ),
176
  }
177
-
178
  val_loaders = {}
179
  if summ_val:
180
  val_loaders["summarization"] = build_summarization_dataloader(
181
- summ_val, tokenizer, shuffle=False,
182
- max_source_length=max_len, max_target_length=max_len,
183
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
184
  )
185
  if emot_val:
186
  val_loaders["emotion"] = build_emotion_dataloader(
187
- emot_val, tokenizer, shuffle=False, max_length=classification_max_len,
188
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
189
  )
190
  if topic_val:
191
  val_loaders["topic"] = build_topic_dataloader(
192
- topic_val, tokenizer, shuffle=False, max_length=classification_max_len,
193
- batch_size=batch_size, num_workers=num_workers, pin_memory=True,
 
 
 
 
 
194
  )
195
-
196
  # --------------- Model ---------------
197
-
198
  print("\nBuilding model...")
199
-
200
  # Check for overrides in training config
201
- grad_ckpt = cfg.training.get("gradient_checkpointing", cfg.model.get("gradient_checkpointing", False))
202
- use_rel_pos = cfg.training.get("use_relative_position_bias", cfg.model.get("use_relative_position_bias", False))
203
-
 
 
 
 
204
  model_cfg = ModelConfig(
205
  d_model=cfg.model.d_model,
206
  vocab_size=getattr(cfg.model, "vocab_size", None),
@@ -215,42 +260,42 @@ def main(cfg: DictConfig) -> None:
215
  use_relative_position_bias=use_rel_pos,
216
  gradient_checkpointing=grad_ckpt,
217
  )
218
-
219
  if grad_ckpt:
220
  print(" Gradient checkpointing: on")
221
  if not use_rel_pos:
222
  print(" FlashAttention: on (no relative position bias)")
223
-
224
  model = build_multitask_model(
225
  tokenizer,
226
  num_emotions=len(emot_train.emotion_classes),
227
  num_topics=len(topic_train.topic_classes),
228
  config=model_cfg,
229
  ).to(device)
230
-
231
  param_count = sum(p.numel() for p in model.parameters())
232
- print(f" Parameters: {param_count:,} ({param_count/1e6:.1f}M)")
233
-
234
  # Freeze lower encoder layers (keeps pretrained language understanding, adapts upper layers)
235
  freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
236
  if freeze_layers > 0:
237
  frozen_params = 0
238
  # Freeze embedding layer
239
- if hasattr(model.encoder, 'embed_tokens'):
240
  for p in model.encoder.embed_tokens.parameters():
241
  p.requires_grad = False
242
  frozen_params += p.numel()
243
  # Freeze specified number of encoder layers
244
- if hasattr(model.encoder, 'layers'):
245
  for i, layer in enumerate(model.encoder.layers):
246
  if i < freeze_layers:
247
  for p in layer.parameters():
248
  p.requires_grad = False
249
  frozen_params += p.numel()
250
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
251
- print(f" Frozen layers: 0-{freeze_layers-1} ({frozen_params/1e6:.1f}M params)")
252
- print(f" Trainable: {trainable:,} ({trainable/1e6:.1f}M)")
253
-
254
  # Resume from checkpoint?
255
  start_epoch = 1
256
  resume_path = cfg.get("resume_from")
@@ -258,10 +303,11 @@ def main(cfg: DictConfig) -> None:
258
  print(f" Resuming from: {resume_path}")
259
  load_state(model, str(resume_path))
260
  import re
 
261
  digits = re.findall(r"\d+", Path(resume_path).stem)
262
  if digits:
263
  start_epoch = int(digits[-1]) + 1
264
-
265
  # Compile model for speed
266
  # Note: "reduce-overhead" mode uses CUDA graphs which conflicts with gradient checkpointing
267
  # Use "default" mode when checkpointing is enabled
@@ -272,13 +318,13 @@ def main(cfg: DictConfig) -> None:
272
  if cfg.training.get("compile_decoder", True):
273
  model.decoder = torch.compile(model.decoder, mode=compile_mode) # type: ignore[assignment]
274
  print(f" Decoder compiled ({compile_mode})")
275
-
276
  # --------------- Train ---------------
277
-
278
  print("\nStarting training...")
279
  opt_cfg = cfg.training.get("optimizer", {})
280
  sched_cfg = cfg.training.get("scheduler", {})
281
-
282
  # Use fused AdamW on CUDA for ~5-10% speedup
283
  use_fused = device.type == "cuda" and "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
284
  optimizer = torch.optim.AdamW(
@@ -289,7 +335,7 @@ def main(cfg: DictConfig) -> None:
289
  )
290
  if use_fused:
291
  print(" Fused AdamW: on")
292
-
293
  trainer = Trainer(
294
  model=model,
295
  optimizer=optimizer,
@@ -309,38 +355,38 @@ def main(cfg: DictConfig) -> None:
309
  device=device,
310
  tokenizer=tokenizer,
311
  )
312
-
313
  # Checkpoint callback
314
  ckpt_dir = Path(cfg.checkpoint_out).parent
315
- best_val_loss = float('inf')
316
-
317
  def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
318
  nonlocal best_val_loss
319
  ckpt_dir.mkdir(parents=True, exist_ok=True)
320
-
321
  # Save epoch checkpoint
322
  save_state(model, str(ckpt_dir / f"epoch_{epoch}.pt"))
323
-
324
  # Track best
325
  val_key = f"val_epoch_{epoch}"
326
  if val_key in history:
327
- val_loss = history[val_key].get("total_loss", float('inf'))
328
  if val_loss < best_val_loss:
329
  best_val_loss = val_loss
330
  save_state(model, str(ckpt_dir / "best.pt"))
331
  print(f" New best model saved (val_loss={val_loss:.4f})")
332
-
333
  history = trainer.fit(
334
  train_loaders,
335
  val_loaders if val_loaders else None,
336
  checkpoint_callback=save_checkpoint,
337
  start_epoch=start_epoch,
338
  )
339
-
340
  # --------------- Save Outputs ---------------
341
-
342
  print("\nSaving outputs...")
343
-
344
  # Labels
345
  labels_path = Path(cfg.labels_out)
346
  save_label_metadata(
@@ -348,17 +394,17 @@ def main(cfg: DictConfig) -> None:
348
  labels_path,
349
  )
350
  print(f" Labels: {labels_path}")
351
-
352
  # History
353
  history_path = Path(cfg.history_out)
354
  history_path.parent.mkdir(parents=True, exist_ok=True)
355
  with history_path.open("w") as f:
356
  json.dump(history, f, indent=2)
357
  print(f" History: {history_path}")
358
-
359
  total_time = time.perf_counter() - start_time
360
  print(f"\n{'=' * 60}")
361
- print(f"Training complete in {total_time/60:.1f} minutes")
362
  print(f" Best checkpoint: {ckpt_dir / 'best.pt'}")
363
  print(f"{'=' * 60}")
364
 
 
56
  import random
57
 
58
  import numpy as np
59
+
60
  random.seed(seed)
61
  np.random.seed(seed)
62
  torch.manual_seed(seed)
 
79
  def main(cfg: DictConfig) -> None:
80
  """Main training entry point."""
81
  start_time = time.perf_counter()
82
+
83
  print("=" * 60)
84
  print("LexiMind Training")
85
  print("=" * 60)
86
  print(OmegaConf.to_yaml(cfg))
87
+
88
  set_seed(cfg.seed)
89
  device = torch.device(cfg.device)
90
+
91
  # GPU optimizations for Ampere+
92
  if device.type == "cuda":
93
  # Enable cudnn benchmark for fixed-size inputs (10-20% speedup)
94
  torch.backends.cudnn.benchmark = True
95
+
96
  if torch.cuda.get_device_capability()[0] >= 8:
97
  torch.set_float32_matmul_precision("high")
98
  torch.backends.cuda.matmul.allow_tf32 = True
 
100
  print(" TF32 + cudnn.benchmark enabled (Ampere GPU)")
101
  else:
102
  print(" cudnn.benchmark enabled")
103
+
104
  # --------------- Load Data ---------------
105
+
106
  print("\nLoading datasets...")
107
  data_cfg = cfg.data
108
  trainer_cfg = cfg.training.get("trainer", {})
109
+
110
  # Load splits
111
  summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
112
  emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
113
  topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
114
+
115
  # Apply sample limits for dev runs
116
  max_train = trainer_cfg.get("max_train_samples")
117
  max_val = trainer_cfg.get("max_val_samples")
 
122
  for splits in [summ_splits, emot_splits, topic_splits]:
123
  if "val" in splits:
124
  splits["val"] = splits["val"][:max_val]
125
+
126
+ print(
127
+ f" Summarization: {len(summ_splits['train']):,} train, {len(summ_splits.get('val', [])):,} val"
128
+ )
129
+ print(
130
+ f" Emotion: {len(emot_splits['train']):,} train, {len(emot_splits.get('val', [])):,} val"
131
+ )
132
+ print(
133
+ f" Topic: {len(topic_splits['train']):,} train, {len(topic_splits.get('val', [])):,} val"
134
+ )
135
+
136
  # --------------- Tokenizer ---------------
137
+
138
  tok_cfg = data_cfg.get("tokenizer", {})
139
  max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512))
140
+
141
+ tokenizer = Tokenizer(
142
+ TokenizerConfig(
143
+ pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
144
+ max_length=max_len,
145
+ )
146
+ )
147
  print(f" Tokenizer: {tokenizer.vocab_size:,} vocab, max_len={max_len}")
148
+
149
  # --------------- Datasets ---------------
150
+
151
  summ_train = SummarizationDataset(summ_splits["train"])
152
  summ_val = SummarizationDataset(summ_splits.get("val", []))
153
  emot_train = EmotionDataset(emot_splits["train"])
154
  emot_val = EmotionDataset(emot_splits.get("val", []), binarizer=emot_train.binarizer)
155
  topic_train = TopicDataset(topic_splits["train"])
156
  topic_val = TopicDataset(topic_splits.get("val", []), encoder=topic_train.encoder)
157
+
158
  print(f" Emotions: {len(emot_train.emotion_classes)} classes")
159
+ print(
160
+ f" Topics: {len(topic_train.topic_classes)} classes → {list(map(str, topic_train.topic_classes))}"
161
+ )
162
+
163
  # --------------- DataLoaders ---------------
164
+
165
  dl_cfg = cfg.training.get("dataloader", {})
166
  batch_size = int(dl_cfg.get("batch_size", 8))
167
  num_workers = int(dl_cfg.get("num_workers", 4))
168
+
169
  # Classification tasks don't need full 512 tokens - 256 is sufficient
170
  # This speeds up emotion/topic forward passes significantly
171
  classification_max_len = min(256, max_len)
172
+
173
  train_loaders = {
174
  "summarization": build_summarization_dataloader(
175
+ summ_train,
176
+ tokenizer,
177
+ shuffle=True,
178
+ max_source_length=max_len,
179
+ max_target_length=max_len,
180
+ batch_size=batch_size,
181
+ num_workers=num_workers,
182
+ pin_memory=True,
183
  ),
184
  "emotion": build_emotion_dataloader(
185
+ emot_train,
186
+ tokenizer,
187
+ shuffle=True,
188
+ max_length=classification_max_len,
189
+ batch_size=batch_size,
190
+ num_workers=num_workers,
191
+ pin_memory=True,
192
  ),
193
  "topic": build_topic_dataloader(
194
+ topic_train,
195
+ tokenizer,
196
+ shuffle=True,
197
+ max_length=classification_max_len,
198
+ batch_size=batch_size,
199
+ num_workers=num_workers,
200
+ pin_memory=True,
201
  ),
202
  }
203
+
204
  val_loaders = {}
205
  if summ_val:
206
  val_loaders["summarization"] = build_summarization_dataloader(
207
+ summ_val,
208
+ tokenizer,
209
+ shuffle=False,
210
+ max_source_length=max_len,
211
+ max_target_length=max_len,
212
+ batch_size=batch_size,
213
+ num_workers=num_workers,
214
+ pin_memory=True,
215
  )
216
  if emot_val:
217
  val_loaders["emotion"] = build_emotion_dataloader(
218
+ emot_val,
219
+ tokenizer,
220
+ shuffle=False,
221
+ max_length=classification_max_len,
222
+ batch_size=batch_size,
223
+ num_workers=num_workers,
224
+ pin_memory=True,
225
  )
226
  if topic_val:
227
  val_loaders["topic"] = build_topic_dataloader(
228
+ topic_val,
229
+ tokenizer,
230
+ shuffle=False,
231
+ max_length=classification_max_len,
232
+ batch_size=batch_size,
233
+ num_workers=num_workers,
234
+ pin_memory=True,
235
  )
236
+
237
  # --------------- Model ---------------
238
+
239
  print("\nBuilding model...")
240
+
241
  # Check for overrides in training config
242
+ grad_ckpt = cfg.training.get(
243
+ "gradient_checkpointing", cfg.model.get("gradient_checkpointing", False)
244
+ )
245
+ use_rel_pos = cfg.training.get(
246
+ "use_relative_position_bias", cfg.model.get("use_relative_position_bias", False)
247
+ )
248
+
249
  model_cfg = ModelConfig(
250
  d_model=cfg.model.d_model,
251
  vocab_size=getattr(cfg.model, "vocab_size", None),
 
260
  use_relative_position_bias=use_rel_pos,
261
  gradient_checkpointing=grad_ckpt,
262
  )
263
+
264
  if grad_ckpt:
265
  print(" Gradient checkpointing: on")
266
  if not use_rel_pos:
267
  print(" FlashAttention: on (no relative position bias)")
268
+
269
  model = build_multitask_model(
270
  tokenizer,
271
  num_emotions=len(emot_train.emotion_classes),
272
  num_topics=len(topic_train.topic_classes),
273
  config=model_cfg,
274
  ).to(device)
275
+
276
  param_count = sum(p.numel() for p in model.parameters())
277
+ print(f" Parameters: {param_count:,} ({param_count / 1e6:.1f}M)")
278
+
279
  # Freeze lower encoder layers (keeps pretrained language understanding, adapts upper layers)
280
  freeze_layers = cfg.training.get("freeze_encoder_layers", 0)
281
  if freeze_layers > 0:
282
  frozen_params = 0
283
  # Freeze embedding layer
284
+ if hasattr(model.encoder, "embed_tokens"):
285
  for p in model.encoder.embed_tokens.parameters():
286
  p.requires_grad = False
287
  frozen_params += p.numel()
288
  # Freeze specified number of encoder layers
289
+ if hasattr(model.encoder, "layers"):
290
  for i, layer in enumerate(model.encoder.layers):
291
  if i < freeze_layers:
292
  for p in layer.parameters():
293
  p.requires_grad = False
294
  frozen_params += p.numel()
295
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
296
+ print(f" Frozen layers: 0-{freeze_layers - 1} ({frozen_params / 1e6:.1f}M params)")
297
+ print(f" Trainable: {trainable:,} ({trainable / 1e6:.1f}M)")
298
+
299
  # Resume from checkpoint?
300
  start_epoch = 1
301
  resume_path = cfg.get("resume_from")
 
303
  print(f" Resuming from: {resume_path}")
304
  load_state(model, str(resume_path))
305
  import re
306
+
307
  digits = re.findall(r"\d+", Path(resume_path).stem)
308
  if digits:
309
  start_epoch = int(digits[-1]) + 1
310
+
311
  # Compile model for speed
312
  # Note: "reduce-overhead" mode uses CUDA graphs which conflicts with gradient checkpointing
313
  # Use "default" mode when checkpointing is enabled
 
318
  if cfg.training.get("compile_decoder", True):
319
  model.decoder = torch.compile(model.decoder, mode=compile_mode) # type: ignore[assignment]
320
  print(f" Decoder compiled ({compile_mode})")
321
+
322
  # --------------- Train ---------------
323
+
324
  print("\nStarting training...")
325
  opt_cfg = cfg.training.get("optimizer", {})
326
  sched_cfg = cfg.training.get("scheduler", {})
327
+
328
  # Use fused AdamW on CUDA for ~5-10% speedup
329
  use_fused = device.type == "cuda" and "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
330
  optimizer = torch.optim.AdamW(
 
335
  )
336
  if use_fused:
337
  print(" Fused AdamW: on")
338
+
339
  trainer = Trainer(
340
  model=model,
341
  optimizer=optimizer,
 
355
  device=device,
356
  tokenizer=tokenizer,
357
  )
358
+
359
  # Checkpoint callback
360
  ckpt_dir = Path(cfg.checkpoint_out).parent
361
+ best_val_loss = float("inf")
362
+
363
  def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
364
  nonlocal best_val_loss
365
  ckpt_dir.mkdir(parents=True, exist_ok=True)
366
+
367
  # Save epoch checkpoint
368
  save_state(model, str(ckpt_dir / f"epoch_{epoch}.pt"))
369
+
370
  # Track best
371
  val_key = f"val_epoch_{epoch}"
372
  if val_key in history:
373
+ val_loss = history[val_key].get("total_loss", float("inf"))
374
  if val_loss < best_val_loss:
375
  best_val_loss = val_loss
376
  save_state(model, str(ckpt_dir / "best.pt"))
377
  print(f" New best model saved (val_loss={val_loss:.4f})")
378
+
379
  history = trainer.fit(
380
  train_loaders,
381
  val_loaders if val_loaders else None,
382
  checkpoint_callback=save_checkpoint,
383
  start_epoch=start_epoch,
384
  )
385
+
386
  # --------------- Save Outputs ---------------
387
+
388
  print("\nSaving outputs...")
389
+
390
  # Labels
391
  labels_path = Path(cfg.labels_out)
392
  save_label_metadata(
 
394
  labels_path,
395
  )
396
  print(f" Labels: {labels_path}")
397
+
398
  # History
399
  history_path = Path(cfg.history_out)
400
  history_path.parent.mkdir(parents=True, exist_ok=True)
401
  with history_path.open("w") as f:
402
  json.dump(history, f, indent=2)
403
  print(f" History: {history_path}")
404
+
405
  total_time = time.perf_counter() - start_time
406
  print(f"\n{'=' * 60}")
407
+ print(f"Training complete in {total_time / 60:.1f} minutes")
408
  print(f" Best checkpoint: {ckpt_dir / 'best.pt'}")
409
  print(f"{'=' * 60}")
410
 
scripts/train_multiseed.py CHANGED
@@ -30,7 +30,8 @@ def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
30
  seed_dir.mkdir(parents=True, exist_ok=True)
31
 
32
  cmd = [
33
- sys.executable, "scripts/train.py",
 
34
  f"seed={seed}",
35
  f"checkpoint_out={seed_dir}/checkpoints/best.pt",
36
  f"history_out={seed_dir}/training_history.json",
@@ -39,9 +40,9 @@ def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
39
  if config_overrides:
40
  cmd.extend(config_overrides.split())
41
 
42
- print(f"\n{'='*60}")
43
  print(f"Training seed {seed}")
44
- print(f"{'='*60}")
45
  print(f" Command: {' '.join(cmd)}")
46
 
47
  result = subprocess.run(cmd, capture_output=False)
@@ -69,7 +70,8 @@ def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = Non
69
  return {}
70
 
71
  cmd = [
72
- sys.executable, "scripts/evaluate.py",
 
73
  f"--checkpoint={checkpoint}",
74
  f"--labels={labels}",
75
  f"--output={output}",
@@ -105,7 +107,11 @@ def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
105
  if not isinstance(task_metrics, dict):
106
  continue
107
  for metric_name, value in task_metrics.items():
108
- if isinstance(value, (int, float)) and metric_name != "num_samples" and metric_name != "num_classes":
 
 
 
 
109
  key = f"{task}/{metric_name}"
110
  metric_values.setdefault(key, []).append(float(value))
111
 
@@ -125,9 +131,9 @@ def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
125
 
126
  def print_summary(aggregated: Dict, seeds: List[int]) -> None:
127
  """Print human-readable summary of multi-seed results."""
128
- print(f"\n{'='*70}")
129
  print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
130
- print(f"{'='*70}")
131
 
132
  # Group by task
133
  tasks: Dict[str, Dict[str, Dict]] = {}
@@ -142,23 +148,32 @@ def print_summary(aggregated: Dict, seeds: List[int]) -> None:
142
  std = stats["std"]
143
  # Format based on metric type
144
  if "accuracy" in metric:
145
- print(f" {metric:25s}: {mean*100:.1f}% ± {std*100:.1f}%")
146
  else:
147
  print(f" {metric:25s}: {mean:.4f} ± {std:.4f}")
148
 
149
 
150
  def main():
151
  parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
152
- parser.add_argument("--seeds", nargs="+", type=int, default=[17, 42, 123],
153
- help="Random seeds to train with")
154
- parser.add_argument("--config", type=str, default="",
155
- help="Hydra config overrides (e.g., 'training=full')")
156
- parser.add_argument("--output-dir", type=Path, default=Path("outputs/multiseed"),
157
- help="Base output directory")
158
- parser.add_argument("--skip-training", action="store_true",
159
- help="Skip training, only aggregate existing results")
160
- parser.add_argument("--skip-eval", action="store_true",
161
- help="Skip evaluation, only aggregate training histories")
 
 
 
 
 
 
 
 
 
162
  args = parser.parse_args()
163
 
164
  args.output_dir.mkdir(parents=True, exist_ok=True)
@@ -184,11 +199,15 @@ def main():
184
  # Save aggregated results
185
  output_path = args.output_dir / "aggregated_results.json"
186
  with open(output_path, "w") as f:
187
- json.dump({
188
- "seeds": args.seeds,
189
- "per_seed": {str(k): v for k, v in all_eval_results.items()},
190
- "aggregated": aggregated,
191
- }, f, indent=2)
 
 
 
 
192
  print(f"\n Saved to: {output_path}")
193
  else:
194
  print("\nNo evaluation results to aggregate.")
 
30
  seed_dir.mkdir(parents=True, exist_ok=True)
31
 
32
  cmd = [
33
+ sys.executable,
34
+ "scripts/train.py",
35
  f"seed={seed}",
36
  f"checkpoint_out={seed_dir}/checkpoints/best.pt",
37
  f"history_out={seed_dir}/training_history.json",
 
40
  if config_overrides:
41
  cmd.extend(config_overrides.split())
42
 
43
+ print(f"\n{'=' * 60}")
44
  print(f"Training seed {seed}")
45
+ print(f"{'=' * 60}")
46
  print(f" Command: {' '.join(cmd)}")
47
 
48
  result = subprocess.run(cmd, capture_output=False)
 
70
  return {}
71
 
72
  cmd = [
73
+ sys.executable,
74
+ "scripts/evaluate.py",
75
  f"--checkpoint={checkpoint}",
76
  f"--labels={labels}",
77
  f"--output={output}",
 
107
  if not isinstance(task_metrics, dict):
108
  continue
109
  for metric_name, value in task_metrics.items():
110
+ if (
111
+ isinstance(value, (int, float))
112
+ and metric_name != "num_samples"
113
+ and metric_name != "num_classes"
114
+ ):
115
  key = f"{task}/{metric_name}"
116
  metric_values.setdefault(key, []).append(float(value))
117
 
 
131
 
132
  def print_summary(aggregated: Dict, seeds: List[int]) -> None:
133
  """Print human-readable summary of multi-seed results."""
134
+ print(f"\n{'=' * 70}")
135
  print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
136
+ print(f"{'=' * 70}")
137
 
138
  # Group by task
139
  tasks: Dict[str, Dict[str, Dict]] = {}
 
148
  std = stats["std"]
149
  # Format based on metric type
150
  if "accuracy" in metric:
151
+ print(f" {metric:25s}: {mean * 100:.1f}% ± {std * 100:.1f}%")
152
  else:
153
  print(f" {metric:25s}: {mean:.4f} ± {std:.4f}")
154
 
155
 
156
  def main():
157
  parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
158
+ parser.add_argument(
159
+ "--seeds", nargs="+", type=int, default=[17, 42, 123], help="Random seeds to train with"
160
+ )
161
+ parser.add_argument(
162
+ "--config", type=str, default="", help="Hydra config overrides (e.g., 'training=full')"
163
+ )
164
+ parser.add_argument(
165
+ "--output-dir", type=Path, default=Path("outputs/multiseed"), help="Base output directory"
166
+ )
167
+ parser.add_argument(
168
+ "--skip-training",
169
+ action="store_true",
170
+ help="Skip training, only aggregate existing results",
171
+ )
172
+ parser.add_argument(
173
+ "--skip-eval",
174
+ action="store_true",
175
+ help="Skip evaluation, only aggregate training histories",
176
+ )
177
  args = parser.parse_args()
178
 
179
  args.output_dir.mkdir(parents=True, exist_ok=True)
 
199
  # Save aggregated results
200
  output_path = args.output_dir / "aggregated_results.json"
201
  with open(output_path, "w") as f:
202
+ json.dump(
203
+ {
204
+ "seeds": args.seeds,
205
+ "per_seed": {str(k): v for k, v in all_eval_results.items()},
206
+ "aggregated": aggregated,
207
+ },
208
+ f,
209
+ indent=2,
210
+ )
211
  print(f"\n Saved to: {output_path}")
212
  else:
213
  print("\nNo evaluation results to aggregate.")
scripts/visualize_training.py CHANGED
@@ -81,31 +81,33 @@ ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
81
 
82
  # Professional color palette (accessible + publication-ready)
83
  COLORS = {
84
- "primary": "#2E86AB", # Deep blue - training
85
- "secondary": "#E94F37", # Coral red - validation
86
- "accent": "#28A745", # Green - best points
87
- "highlight": "#F7B801", # Gold - highlights
88
- "dark": "#1E3A5F", # Navy - text
89
- "light": "#F5F5F5", # Light gray - background
90
- "topic": "#8338EC", # Purple
91
- "emotion": "#FF6B6B", # Salmon
92
- "summary": "#06D6A0", # Teal
93
  }
94
 
95
  # Style configuration
96
  plt.style.use("seaborn-v0_8-whitegrid")
97
- plt.rcParams.update({
98
- "font.family": "sans-serif",
99
- "font.size": 11,
100
- "axes.titlesize": 14,
101
- "axes.titleweight": "bold",
102
- "axes.labelsize": 12,
103
- "legend.fontsize": 10,
104
- "figure.titlesize": 16,
105
- "figure.titleweight": "bold",
106
- "savefig.dpi": 150,
107
- "savefig.bbox": "tight",
108
- })
 
 
109
 
110
  # Custom colormap for heatmaps
111
  HEATMAP_CMAP = LinearSegmentedColormap.from_list(
@@ -115,12 +117,14 @@ HEATMAP_CMAP = LinearSegmentedColormap.from_list(
115
 
116
  # MLflow Utilities
117
 
 
118
  def get_mlflow_client():
119
  """Get MLflow client with correct tracking URI."""
120
  if not HAS_MLFLOW:
121
  raise ImportError("MLflow not installed. Install with: pip install mlflow")
122
  import mlflow
123
  import mlflow.tracking
 
124
  # Use SQLite database (same as trainer.py)
125
  mlflow.set_tracking_uri("sqlite:///mlruns.db")
126
  return mlflow.tracking.MlflowClient()
@@ -153,6 +157,7 @@ def get_metric_history(run, metric_name: str) -> tuple[list, list]:
153
 
154
  # Core Training Visualizations
155
 
 
156
  def plot_loss_curves(run, interactive: bool = False) -> None:
157
  """
158
  Plot training and validation loss over time.
@@ -164,37 +169,49 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
164
 
165
  if interactive and HAS_PLOTLY:
166
  import plotly.graph_objects as go
 
167
  fig = go.Figure()
168
 
169
  if train_values:
170
- fig.add_trace(go.Scatter(
171
- x=train_steps, y=train_values,
172
- name="Training Loss", mode="lines",
173
- line=dict(color=COLORS["primary"], width=3)
174
- ))
 
 
 
 
175
 
176
  if val_values:
177
- fig.add_trace(go.Scatter(
178
- x=val_steps, y=val_values,
179
- name="Validation Loss", mode="lines",
180
- line=dict(color=COLORS["secondary"], width=3)
181
- ))
 
 
 
 
182
 
183
  # Best point
184
  best_idx = int(np.argmin(val_values))
185
- fig.add_trace(go.Scatter(
186
- x=[val_steps[best_idx]], y=[val_values[best_idx]],
187
- name=f"Best: {val_values[best_idx]:.3f}",
188
- mode="markers",
189
- marker=dict(color=COLORS["accent"], size=15, symbol="star")
190
- ))
 
 
 
191
 
192
  fig.update_layout(
193
  title="Training Progress: Multi-Task Loss",
194
  xaxis_title="Epoch",
195
  yaxis_title="Loss",
196
  template="plotly_white",
197
- hovermode="x unified"
198
  )
199
 
200
  output_path = OUTPUTS_DIR / "training_loss_curve.html"
@@ -206,32 +223,62 @@ def plot_loss_curves(run, interactive: bool = False) -> None:
206
  fig, ax = plt.subplots(figsize=(12, 6))
207
 
208
  if not train_values:
209
- ax.text(0.5, 0.5, "No training data yet\n\nWaiting for first epoch...",
210
- ha="center", va="center", fontsize=14, color="gray")
 
 
 
 
 
 
 
211
  ax.set_xlim(0, 1)
212
  ax.set_ylim(0, 1)
213
  else:
214
  # Training curve
215
- ax.plot(train_steps, train_values, label="Training Loss", linewidth=2.5,
216
- color=COLORS["primary"], alpha=0.9)
 
 
 
 
 
 
217
 
218
  # Validation curve with best point
219
  if val_values:
220
- ax.plot(val_steps, val_values, label="Validation Loss", linewidth=2.5,
221
- color=COLORS["secondary"], alpha=0.9)
 
 
 
 
 
 
222
 
223
  best_idx = int(np.argmin(val_values))
224
- ax.scatter([val_steps[best_idx]], [val_values[best_idx]],
225
- s=200, c=COLORS["accent"], zorder=5, marker="*",
226
- edgecolors="white", linewidth=2,
227
- label=f"Best: {val_values[best_idx]:.3f}")
 
 
 
 
 
 
 
228
 
229
  # Annotate best point
230
- ax.annotate(f"Epoch {val_steps[best_idx]}",
231
- xy=(val_steps[best_idx], val_values[best_idx]),
232
- xytext=(10, 20), textcoords="offset points",
233
- fontsize=10, color=COLORS["accent"],
234
- arrowprops=dict(arrowstyle="->", color=COLORS["accent"]))
 
 
 
 
235
 
236
  ax.legend(fontsize=11, loc="upper right", framealpha=0.9)
237
  ax.set_ylim(bottom=0)
@@ -265,11 +312,22 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
265
  val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
266
 
267
  if train_sum:
268
- ax.plot([m.step for m in train_sum], [m.value for m in train_sum],
269
- label="Train", linewidth=2.5, color=COLORS["summary"])
 
 
 
 
 
270
  if val_sum:
271
- ax.plot([m.step for m in val_sum], [m.value for m in val_sum],
272
- label="Validation", linewidth=2.5, color=COLORS["secondary"], linestyle="--")
 
 
 
 
 
 
273
 
274
  ax.set_title("Summarization Loss")
275
  ax.set_xlabel("Epoch")
@@ -286,20 +344,43 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
286
  val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
287
 
288
  if train_emo:
289
- ax.plot([m.step for m in train_emo], [m.value for m in train_emo],
290
- label="Train Loss", linewidth=2.5, color=COLORS["emotion"])
 
 
 
 
 
291
  if val_emo:
292
- ax.plot([m.step for m in val_emo], [m.value for m in val_emo],
293
- label="Val Loss", linewidth=2.5, color=COLORS["secondary"], linestyle="--")
 
 
 
 
 
 
294
 
295
  # Secondary axis for F1
296
  ax2 = ax.twinx()
297
  if train_f1:
298
- ax2.plot([m.step for m in train_f1], [m.value for m in train_f1],
299
- label="Train F1", linewidth=2, color=COLORS["accent"], alpha=0.7)
 
 
 
 
 
 
300
  if val_f1:
301
- ax2.plot([m.step for m in val_f1], [m.value for m in val_f1],
302
- label="Val F1", linewidth=2, color=COLORS["highlight"], alpha=0.7)
 
 
 
 
 
 
303
  ax2.set_ylim(0, 1)
304
 
305
  ax.set_title("Emotion Detection (28 classes)")
@@ -320,19 +401,42 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
320
  val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
321
 
322
  if train_topic:
323
- ax.plot([m.step for m in train_topic], [m.value for m in train_topic],
324
- label="Train Loss", linewidth=2.5, color=COLORS["topic"])
 
 
 
 
 
325
  if val_topic:
326
- ax.plot([m.step for m in val_topic], [m.value for m in val_topic],
327
- label="Val Loss", linewidth=2.5, color=COLORS["secondary"], linestyle="--")
 
 
 
 
 
 
328
 
329
  ax2 = ax.twinx()
330
  if train_acc:
331
- ax2.plot([m.step for m in train_acc], [m.value for m in train_acc],
332
- label="Train Acc", linewidth=2, color=COLORS["accent"], alpha=0.7)
 
 
 
 
 
 
333
  if val_acc:
334
- ax2.plot([m.step for m in val_acc], [m.value for m in val_acc],
335
- label="Val Acc", linewidth=2, color=COLORS["highlight"], alpha=0.7)
 
 
 
 
 
 
336
  ax2.set_ylim(0, 1)
337
 
338
  ax.set_title("Topic Classification (4 classes)")
@@ -350,9 +454,11 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
350
  ax.axis("off")
351
 
352
  # Get final metrics
353
- summary_lines = ["+--------------------------------------+",
354
- "| FINAL METRICS (Last Epoch) |",
355
- "+--------------------------------------+"]
 
 
356
 
357
  if val_topic and val_acc:
358
  summary_lines.append(f"| Topic Accuracy: {val_acc[-1].value:>6.1%} |")
@@ -363,8 +469,15 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
363
 
364
  summary_lines.append("+--------------------------------------+")
365
 
366
- ax.text(0.1, 0.6, "\n".join(summary_lines), fontsize=11, family="monospace",
367
- verticalalignment="center", bbox=dict(boxstyle="round", facecolor=COLORS["light"]))
 
 
 
 
 
 
 
368
 
369
  # Add model info
370
  run_params = run.data.params
@@ -372,8 +485,7 @@ def plot_task_metrics(run, interactive: bool = False) -> None:
372
  model_info += f"Batch Size: {run_params.get('batch_size', 'N/A')}\n"
373
  model_info += f"Learning Rate: {run_params.get('learning_rate', 'N/A')}"
374
 
375
- ax.text(0.1, 0.15, model_info, fontsize=10, color="gray",
376
- verticalalignment="center")
377
 
378
  plt.tight_layout()
379
  output_path = OUTPUTS_DIR / "task_metrics.png"
@@ -392,13 +504,13 @@ def plot_learning_rate(run) -> None:
392
  if not lr_metrics or len(lr_metrics) < 2:
393
  # No LR data logged - generate theoretical schedule from config
394
  logger.info(" No LR metrics found - generating theoretical schedule...")
395
-
396
  # Get config from run params
397
  params = run.data.params
398
  lr_max = float(params.get("learning_rate", params.get("lr", 5e-5)))
399
  warmup_steps = int(params.get("warmup_steps", 500))
400
  max_epochs = int(params.get("max_epochs", 5))
401
-
402
  # Estimate total steps from training loss history
403
  train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
404
  if train_loss:
@@ -407,7 +519,7 @@ def plot_learning_rate(run) -> None:
407
  total_steps = max_epochs * estimated_steps_per_epoch
408
  else:
409
  total_steps = 4000 # Default fallback
410
-
411
  # Generate cosine schedule with warmup
412
  steps = np.arange(0, total_steps)
413
  values = []
@@ -418,25 +530,43 @@ def plot_learning_rate(run) -> None:
418
  progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
419
  lr = lr_max * max(0.1, 0.5 * (1 + np.cos(np.pi * progress)))
420
  values.append(lr)
421
-
422
  ax.fill_between(steps, values, alpha=0.3, color=COLORS["primary"])
423
  ax.plot(steps, values, linewidth=2.5, color=COLORS["primary"], label="Cosine + Warmup")
424
-
425
  # Mark warmup region
426
- ax.axvline(warmup_steps, color=COLORS["secondary"], linestyle="--",
427
- alpha=0.7, linewidth=2, label=f"Warmup End ({warmup_steps})")
 
 
 
 
 
 
428
  ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"])
429
-
430
  # Add annotation
431
- ax.annotate(f"Peak LR: {lr_max:.1e}", xy=(warmup_steps, lr_max),
432
- xytext=(warmup_steps + 200, lr_max * 0.9),
433
- fontsize=10, color=COLORS["dark"],
434
- arrowprops=dict(arrowstyle="->", color=COLORS["dark"], alpha=0.5))
435
-
 
 
 
 
436
  ax.legend(loc="upper right")
437
- ax.text(0.98, 0.02, "(Theoretical - actual LR not logged)",
438
- transform=ax.transAxes, ha="right", va="bottom",
439
- fontsize=9, color="gray", style="italic")
 
 
 
 
 
 
 
 
440
  else:
441
  steps = np.array([m.step for m in lr_metrics])
442
  values = [m.value for m in lr_metrics]
@@ -449,10 +579,15 @@ def plot_learning_rate(run) -> None:
449
  params = run.data.params
450
  warmup_steps = int(params.get("warmup_steps", 500))
451
  if warmup_steps < max(steps):
452
- ax.axvline(warmup_steps, color=COLORS["secondary"], linestyle="--",
453
- alpha=0.7, linewidth=2, label="Warmup End")
454
- ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"],
455
- label="Warmup Phase")
 
 
 
 
 
456
  ax.legend(loc="upper right")
457
 
458
  # Scientific notation for y-axis if needed
@@ -471,6 +606,7 @@ def plot_learning_rate(run) -> None:
471
 
472
  # Advanced Visualizations
473
 
 
474
  def plot_confusion_matrix(run, task: str = "topic") -> None:
475
  """
476
  Plot confusion matrix for classification tasks.
@@ -482,8 +618,16 @@ def plot_confusion_matrix(run, task: str = "topic") -> None:
482
  if task == "topic":
483
  default_labels = ["World", "Sports", "Business", "Sci/Tech"]
484
  else: # emotion - top 8 for visibility
485
- default_labels = ["admiration", "amusement", "anger", "annoyance",
486
- "approval", "caring", "curiosity", "desire"]
 
 
 
 
 
 
 
 
487
 
488
  if labels_path.exists():
489
  with open(labels_path) as f:
@@ -516,9 +660,16 @@ def plot_confusion_matrix(run, task: str = "topic") -> None:
516
  # Plot
517
  fig, ax = plt.subplots(figsize=(10, 8))
518
 
519
- sns.heatmap(cm_normalized, annot=True, fmt=".2f", cmap=HEATMAP_CMAP,
520
- xticklabels=labels[:n_classes], yticklabels=labels[:n_classes],
521
- ax=ax, cbar_kws={"label": "Proportion"})
 
 
 
 
 
 
 
522
 
523
  ax.set_title(f"Confusion Matrix: {task.title()} Classification")
524
  ax.set_xlabel("Predicted Label")
@@ -570,7 +721,7 @@ def plot_3d_loss_landscape(run) -> None:
570
 
571
  # Synthetic loss surface (bowl shape with some local minima)
572
  min_loss = min(val_loss) if val_loss else min(train_loss)
573
- Z = min_loss + 0.3 * (X**2 + Y**2) + 0.1 * np.sin(3*X) * np.cos(3*Y)
574
 
575
  # Add noise for realism
576
  Z += np.random.normal(0, 0.02, Z.shape)
@@ -584,41 +735,57 @@ def plot_3d_loss_landscape(run) -> None:
584
  fig = go.Figure()
585
 
586
  # Loss surface
587
- fig.add_trace(go.Surface(
588
- x=X, y=Y, z=Z,
589
- colorscale=[[0, COLORS["accent"]], [0.5, COLORS["primary"]], [1, COLORS["secondary"]]],
590
- opacity=0.8,
591
- showscale=True,
592
- colorbar=dict(title="Loss", x=1.02)
593
- ))
 
 
 
 
594
 
595
  # Training trajectory
596
- fig.add_trace(go.Scatter3d(
597
- x=trajectory_x, y=trajectory_y, z=trajectory_z,
598
- mode="lines+markers",
599
- line=dict(color=COLORS["highlight"], width=5),
600
- marker=dict(size=4, color=COLORS["highlight"]),
601
- name="Training Path"
602
- ))
 
 
 
 
603
 
604
  # Mark start and end
605
- fig.add_trace(go.Scatter3d(
606
- x=[trajectory_x[0]], y=[trajectory_y[0]], z=[trajectory_z[0]],
607
- mode="markers+text",
608
- marker=dict(size=10, color="red", symbol="circle"),
609
- text=["Start"],
610
- textposition="top center",
611
- name="Start"
612
- ))
613
-
614
- fig.add_trace(go.Scatter3d(
615
- x=[trajectory_x[-1]], y=[trajectory_y[-1]], z=[trajectory_z[-1]],
616
- mode="markers+text",
617
- marker=dict(size=10, color="green", symbol="diamond"),
618
- text=["Converged"],
619
- textposition="top center",
620
- name="Converged"
621
- ))
 
 
 
 
 
 
 
 
622
 
623
  fig.update_layout(
624
  title="Loss Landscape & Optimization Trajectory",
@@ -626,7 +793,7 @@ def plot_3d_loss_landscape(run) -> None:
626
  xaxis_title="Parameter Direction 1",
627
  yaxis_title="Parameter Direction 2",
628
  zaxis_title="Loss",
629
- camera=dict(eye=dict(x=1.5, y=1.5, z=0.8))
630
  ),
631
  width=900,
632
  height=700,
@@ -658,26 +825,46 @@ def plot_3d_loss_landscape_static(run) -> None:
658
  X, Y = np.meshgrid(x, y)
659
 
660
  min_loss = min(train_loss)
661
- Z = min_loss + 0.3 * (X**2 + Y**2) + 0.08 * np.sin(3*X) * np.cos(3*Y)
662
 
663
  fig = plt.figure(figsize=(12, 8))
664
  ax = fig.add_subplot(111, projection="3d")
665
 
666
  # Surface
667
- surf = ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.7,
668
- linewidth=0, antialiased=True)
669
 
670
  # Training path
671
  path_x = np.linspace(-1.5, 0, len(train_loss))
672
  path_y = np.linspace(1.2, 0, len(train_loss))
673
- ax.plot(path_x, path_y, train_loss, color=COLORS["secondary"],
674
- linewidth=3, label="Training Path", zorder=10)
 
 
 
 
 
 
 
675
 
676
  # Start/end markers
677
- ax.scatter([path_x[0]], [path_y[0]], train_loss[0], # type: ignore[arg-type]
678
- c="red", s=100, marker="o", label="Start")
679
- ax.scatter([path_x[-1]], [path_y[-1]], train_loss[-1], # type: ignore[arg-type]
680
- c="green", s=100, marker="*", label="Converged")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681
 
682
  ax.set_xlabel("θ₁ Direction")
683
  ax.set_ylabel("θ₂ Direction")
@@ -722,7 +909,7 @@ def plot_embedding_space(run) -> None:
722
  for i in range(n_clusters):
723
  # Create cluster center
724
  center = np.random.randn(64) * 0.5
725
- center[i*16:(i+1)*16] += 3 # Make clusters separable
726
 
727
  # Add samples around center
728
  samples = center + np.random.randn(n_samples // n_clusters, 64) * 0.5
@@ -742,8 +929,14 @@ def plot_embedding_space(run) -> None:
742
 
743
  for i in range(n_clusters):
744
  mask = cluster_labels == i
745
- ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
746
- c=colors[i], label=labels[i], alpha=0.6, s=30)
 
 
 
 
 
 
747
 
748
  ax.set_xlabel("t-SNE Dimension 1")
749
  ax.set_ylabel("t-SNE Dimension 2")
@@ -787,14 +980,18 @@ def plot_training_dynamics(run) -> None:
787
  # Smoothed loss (exponential moving average)
788
  if len(train_loss) > 5:
789
  window = min(5, len(train_loss) // 2)
790
- smoothed = np.convolve(train_loss, np.ones(window)/window, mode="valid")
791
- smoothed_steps = train_steps[window-1:]
792
- ax.plot(smoothed_steps, smoothed, color=COLORS["primary"],
793
- linewidth=2.5, label="Training (smoothed)")
 
 
 
 
 
794
 
795
  if val_loss:
796
- ax.plot(val_steps, val_loss, color=COLORS["secondary"],
797
- linewidth=2.5, label="Validation")
798
 
799
  ax.set_title("Loss Convergence")
800
  ax.set_xlabel("Epoch")
@@ -806,8 +1003,10 @@ def plot_training_dynamics(run) -> None:
806
  ax = axes[0, 1]
807
 
808
  if len(train_loss) > 1:
809
- improvements = [-(train_loss[i] - train_loss[i-1])/train_loss[i-1] * 100
810
- for i in range(1, len(train_loss))]
 
 
811
  colors_bar = [COLORS["accent"] if imp > 0 else COLORS["secondary"] for imp in improvements]
812
  ax.bar(train_steps[1:], improvements, color=colors_bar, alpha=0.7)
813
  ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
@@ -862,6 +1061,7 @@ def plot_training_dynamics(run) -> None:
862
 
863
  # Dashboard Generator
864
 
 
865
  def generate_dashboard(run) -> None:
866
  """
867
  Generate an interactive HTML dashboard with all visualizations.
@@ -883,63 +1083,73 @@ def generate_dashboard(run) -> None:
883
 
884
  # Create subplots
885
  fig = make_subplots(
886
- rows=2, cols=2,
 
887
  subplot_titles=("Total Loss", "Task Losses", "Learning Rate", "Metrics"),
888
- specs=[[{}, {}], [{}, {}]]
889
  )
890
 
891
  # Total loss
892
  if train_loss:
893
  fig.add_trace(
894
- go.Scatter(x=train_steps, y=train_loss, name="Train Loss",
895
- line=dict(color=COLORS["primary"])),
896
- row=1, col=1
 
 
897
  )
898
  if val_loss:
899
  fig.add_trace(
900
- go.Scatter(x=val_steps, y=val_loss, name="Val Loss",
901
- line=dict(color=COLORS["secondary"])),
902
- row=1, col=1
 
 
903
  )
904
 
905
  # Per-task losses
906
- for task, color in [("summarization", COLORS["summary"]),
907
- ("emotion", COLORS["emotion"]),
908
- ("topic", COLORS["topic"])]:
 
 
909
  steps, values = get_metric_history(run, f"val_{task}_loss")
910
  if values:
911
  fig.add_trace(
912
- go.Scatter(x=steps, y=values, name=f"{task.title()} Loss",
913
- line=dict(color=color)),
914
- row=1, col=2
915
  )
916
 
917
  # Learning rate
918
  lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
919
  if lr_metrics:
920
  fig.add_trace(
921
- go.Scatter(x=[m.step for m in lr_metrics], y=[m.value for m in lr_metrics],
922
- name="Learning Rate", fill="tozeroy",
923
- line=dict(color=COLORS["primary"])),
924
- row=2, col=1
 
 
 
 
 
925
  )
926
 
927
  # Accuracy metrics
928
- for metric, color in [("topic_accuracy", COLORS["topic"]),
929
- ("emotion_f1", COLORS["emotion"])]:
930
  steps, values = get_metric_history(run, f"val_{metric}")
931
  if values:
932
  fig.add_trace(
933
- go.Scatter(x=steps, y=values, name=metric.replace("_", " ").title(),
934
- line=dict(color=color)),
935
- row=2, col=2
 
 
936
  )
937
 
938
  fig.update_layout(
939
- title="LexiMind Training Dashboard",
940
- height=800,
941
- template="plotly_white",
942
- showlegend=True
943
  )
944
 
945
  output_path = OUTPUTS_DIR / "training_dashboard.html"
@@ -949,17 +1159,20 @@ def generate_dashboard(run) -> None:
949
 
950
  # Main Entry Point
951
 
 
952
  def main():
953
  """Generate all training visualizations."""
954
  parser = argparse.ArgumentParser(description="LexiMind Visualization Suite")
955
- parser.add_argument("--interactive", action="store_true",
956
- help="Generate interactive HTML plots (requires plotly)")
957
- parser.add_argument("--landscape", action="store_true",
958
- help="Include 3D loss landscape visualization")
959
- parser.add_argument("--dashboard", action="store_true",
960
- help="Generate interactive dashboard")
961
- parser.add_argument("--all", action="store_true",
962
- help="Generate all visualizations")
 
 
963
  args = parser.parse_args()
964
 
965
  logger.info("=" * 60)
 
81
 
82
  # Professional color palette (accessible + publication-ready)
83
  COLORS = {
84
+ "primary": "#2E86AB", # Deep blue - training
85
+ "secondary": "#E94F37", # Coral red - validation
86
+ "accent": "#28A745", # Green - best points
87
+ "highlight": "#F7B801", # Gold - highlights
88
+ "dark": "#1E3A5F", # Navy - text
89
+ "light": "#F5F5F5", # Light gray - background
90
+ "topic": "#8338EC", # Purple
91
+ "emotion": "#FF6B6B", # Salmon
92
+ "summary": "#06D6A0", # Teal
93
  }
94
 
95
  # Style configuration
96
  plt.style.use("seaborn-v0_8-whitegrid")
97
+ plt.rcParams.update(
98
+ {
99
+ "font.family": "sans-serif",
100
+ "font.size": 11,
101
+ "axes.titlesize": 14,
102
+ "axes.titleweight": "bold",
103
+ "axes.labelsize": 12,
104
+ "legend.fontsize": 10,
105
+ "figure.titlesize": 16,
106
+ "figure.titleweight": "bold",
107
+ "savefig.dpi": 150,
108
+ "savefig.bbox": "tight",
109
+ }
110
+ )
111
 
112
  # Custom colormap for heatmaps
113
  HEATMAP_CMAP = LinearSegmentedColormap.from_list(
 
117
 
118
  # MLflow Utilities
119
 
120
+
121
  def get_mlflow_client():
122
  """Get MLflow client with correct tracking URI."""
123
  if not HAS_MLFLOW:
124
  raise ImportError("MLflow not installed. Install with: pip install mlflow")
125
  import mlflow
126
  import mlflow.tracking
127
+
128
  # Use SQLite database (same as trainer.py)
129
  mlflow.set_tracking_uri("sqlite:///mlruns.db")
130
  return mlflow.tracking.MlflowClient()
 
157
 
158
  # Core Training Visualizations
159
 
160
+
161
  def plot_loss_curves(run, interactive: bool = False) -> None:
162
  """
163
  Plot training and validation loss over time.
 
169
 
170
  if interactive and HAS_PLOTLY:
171
  import plotly.graph_objects as go
172
+
173
  fig = go.Figure()
174
 
175
  if train_values:
176
+ fig.add_trace(
177
+ go.Scatter(
178
+ x=train_steps,
179
+ y=train_values,
180
+ name="Training Loss",
181
+ mode="lines",
182
+ line=dict(color=COLORS["primary"], width=3),
183
+ )
184
+ )
185
 
186
  if val_values:
187
+ fig.add_trace(
188
+ go.Scatter(
189
+ x=val_steps,
190
+ y=val_values,
191
+ name="Validation Loss",
192
+ mode="lines",
193
+ line=dict(color=COLORS["secondary"], width=3),
194
+ )
195
+ )
196
 
197
  # Best point
198
  best_idx = int(np.argmin(val_values))
199
+ fig.add_trace(
200
+ go.Scatter(
201
+ x=[val_steps[best_idx]],
202
+ y=[val_values[best_idx]],
203
+ name=f"Best: {val_values[best_idx]:.3f}",
204
+ mode="markers",
205
+ marker=dict(color=COLORS["accent"], size=15, symbol="star"),
206
+ )
207
+ )
208
 
209
  fig.update_layout(
210
  title="Training Progress: Multi-Task Loss",
211
  xaxis_title="Epoch",
212
  yaxis_title="Loss",
213
  template="plotly_white",
214
+ hovermode="x unified",
215
  )
216
 
217
  output_path = OUTPUTS_DIR / "training_loss_curve.html"
 
223
  fig, ax = plt.subplots(figsize=(12, 6))
224
 
225
  if not train_values:
226
+ ax.text(
227
+ 0.5,
228
+ 0.5,
229
+ "No training data yet\n\nWaiting for first epoch...",
230
+ ha="center",
231
+ va="center",
232
+ fontsize=14,
233
+ color="gray",
234
+ )
235
  ax.set_xlim(0, 1)
236
  ax.set_ylim(0, 1)
237
  else:
238
  # Training curve
239
+ ax.plot(
240
+ train_steps,
241
+ train_values,
242
+ label="Training Loss",
243
+ linewidth=2.5,
244
+ color=COLORS["primary"],
245
+ alpha=0.9,
246
+ )
247
 
248
  # Validation curve with best point
249
  if val_values:
250
+ ax.plot(
251
+ val_steps,
252
+ val_values,
253
+ label="Validation Loss",
254
+ linewidth=2.5,
255
+ color=COLORS["secondary"],
256
+ alpha=0.9,
257
+ )
258
 
259
  best_idx = int(np.argmin(val_values))
260
+ ax.scatter(
261
+ [val_steps[best_idx]],
262
+ [val_values[best_idx]],
263
+ s=200,
264
+ c=COLORS["accent"],
265
+ zorder=5,
266
+ marker="*",
267
+ edgecolors="white",
268
+ linewidth=2,
269
+ label=f"Best: {val_values[best_idx]:.3f}",
270
+ )
271
 
272
  # Annotate best point
273
+ ax.annotate(
274
+ f"Epoch {val_steps[best_idx]}",
275
+ xy=(val_steps[best_idx], val_values[best_idx]),
276
+ xytext=(10, 20),
277
+ textcoords="offset points",
278
+ fontsize=10,
279
+ color=COLORS["accent"],
280
+ arrowprops=dict(arrowstyle="->", color=COLORS["accent"]),
281
+ )
282
 
283
  ax.legend(fontsize=11, loc="upper right", framealpha=0.9)
284
  ax.set_ylim(bottom=0)
 
312
  val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
313
 
314
  if train_sum:
315
+ ax.plot(
316
+ [m.step for m in train_sum],
317
+ [m.value for m in train_sum],
318
+ label="Train",
319
+ linewidth=2.5,
320
+ color=COLORS["summary"],
321
+ )
322
  if val_sum:
323
+ ax.plot(
324
+ [m.step for m in val_sum],
325
+ [m.value for m in val_sum],
326
+ label="Validation",
327
+ linewidth=2.5,
328
+ color=COLORS["secondary"],
329
+ linestyle="--",
330
+ )
331
 
332
  ax.set_title("Summarization Loss")
333
  ax.set_xlabel("Epoch")
 
344
  val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
345
 
346
  if train_emo:
347
+ ax.plot(
348
+ [m.step for m in train_emo],
349
+ [m.value for m in train_emo],
350
+ label="Train Loss",
351
+ linewidth=2.5,
352
+ color=COLORS["emotion"],
353
+ )
354
  if val_emo:
355
+ ax.plot(
356
+ [m.step for m in val_emo],
357
+ [m.value for m in val_emo],
358
+ label="Val Loss",
359
+ linewidth=2.5,
360
+ color=COLORS["secondary"],
361
+ linestyle="--",
362
+ )
363
 
364
  # Secondary axis for F1
365
  ax2 = ax.twinx()
366
  if train_f1:
367
+ ax2.plot(
368
+ [m.step for m in train_f1],
369
+ [m.value for m in train_f1],
370
+ label="Train F1",
371
+ linewidth=2,
372
+ color=COLORS["accent"],
373
+ alpha=0.7,
374
+ )
375
  if val_f1:
376
+ ax2.plot(
377
+ [m.step for m in val_f1],
378
+ [m.value for m in val_f1],
379
+ label="Val F1",
380
+ linewidth=2,
381
+ color=COLORS["highlight"],
382
+ alpha=0.7,
383
+ )
384
  ax2.set_ylim(0, 1)
385
 
386
  ax.set_title("Emotion Detection (28 classes)")
 
401
  val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
402
 
403
  if train_topic:
404
+ ax.plot(
405
+ [m.step for m in train_topic],
406
+ [m.value for m in train_topic],
407
+ label="Train Loss",
408
+ linewidth=2.5,
409
+ color=COLORS["topic"],
410
+ )
411
  if val_topic:
412
+ ax.plot(
413
+ [m.step for m in val_topic],
414
+ [m.value for m in val_topic],
415
+ label="Val Loss",
416
+ linewidth=2.5,
417
+ color=COLORS["secondary"],
418
+ linestyle="--",
419
+ )
420
 
421
  ax2 = ax.twinx()
422
  if train_acc:
423
+ ax2.plot(
424
+ [m.step for m in train_acc],
425
+ [m.value for m in train_acc],
426
+ label="Train Acc",
427
+ linewidth=2,
428
+ color=COLORS["accent"],
429
+ alpha=0.7,
430
+ )
431
  if val_acc:
432
+ ax2.plot(
433
+ [m.step for m in val_acc],
434
+ [m.value for m in val_acc],
435
+ label="Val Acc",
436
+ linewidth=2,
437
+ color=COLORS["highlight"],
438
+ alpha=0.7,
439
+ )
440
  ax2.set_ylim(0, 1)
441
 
442
  ax.set_title("Topic Classification (4 classes)")
 
454
  ax.axis("off")
455
 
456
  # Get final metrics
457
+ summary_lines = [
458
+ "+--------------------------------------+",
459
+ "| FINAL METRICS (Last Epoch) |",
460
+ "+--------------------------------------+",
461
+ ]
462
 
463
  if val_topic and val_acc:
464
  summary_lines.append(f"| Topic Accuracy: {val_acc[-1].value:>6.1%} |")
 
469
 
470
  summary_lines.append("+--------------------------------------+")
471
 
472
+ ax.text(
473
+ 0.1,
474
+ 0.6,
475
+ "\n".join(summary_lines),
476
+ fontsize=11,
477
+ family="monospace",
478
+ verticalalignment="center",
479
+ bbox=dict(boxstyle="round", facecolor=COLORS["light"]),
480
+ )
481
 
482
  # Add model info
483
  run_params = run.data.params
 
485
  model_info += f"Batch Size: {run_params.get('batch_size', 'N/A')}\n"
486
  model_info += f"Learning Rate: {run_params.get('learning_rate', 'N/A')}"
487
 
488
+ ax.text(0.1, 0.15, model_info, fontsize=10, color="gray", verticalalignment="center")
 
489
 
490
  plt.tight_layout()
491
  output_path = OUTPUTS_DIR / "task_metrics.png"
 
504
  if not lr_metrics or len(lr_metrics) < 2:
505
  # No LR data logged - generate theoretical schedule from config
506
  logger.info(" No LR metrics found - generating theoretical schedule...")
507
+
508
  # Get config from run params
509
  params = run.data.params
510
  lr_max = float(params.get("learning_rate", params.get("lr", 5e-5)))
511
  warmup_steps = int(params.get("warmup_steps", 500))
512
  max_epochs = int(params.get("max_epochs", 5))
513
+
514
  # Estimate total steps from training loss history
515
  train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
516
  if train_loss:
 
519
  total_steps = max_epochs * estimated_steps_per_epoch
520
  else:
521
  total_steps = 4000 # Default fallback
522
+
523
  # Generate cosine schedule with warmup
524
  steps = np.arange(0, total_steps)
525
  values = []
 
530
  progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
531
  lr = lr_max * max(0.1, 0.5 * (1 + np.cos(np.pi * progress)))
532
  values.append(lr)
533
+
534
  ax.fill_between(steps, values, alpha=0.3, color=COLORS["primary"])
535
  ax.plot(steps, values, linewidth=2.5, color=COLORS["primary"], label="Cosine + Warmup")
536
+
537
  # Mark warmup region
538
+ ax.axvline(
539
+ warmup_steps,
540
+ color=COLORS["secondary"],
541
+ linestyle="--",
542
+ alpha=0.7,
543
+ linewidth=2,
544
+ label=f"Warmup End ({warmup_steps})",
545
+ )
546
  ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"])
547
+
548
  # Add annotation
549
+ ax.annotate(
550
+ f"Peak LR: {lr_max:.1e}",
551
+ xy=(warmup_steps, lr_max),
552
+ xytext=(warmup_steps + 200, lr_max * 0.9),
553
+ fontsize=10,
554
+ color=COLORS["dark"],
555
+ arrowprops=dict(arrowstyle="->", color=COLORS["dark"], alpha=0.5),
556
+ )
557
+
558
  ax.legend(loc="upper right")
559
+ ax.text(
560
+ 0.98,
561
+ 0.02,
562
+ "(Theoretical - actual LR not logged)",
563
+ transform=ax.transAxes,
564
+ ha="right",
565
+ va="bottom",
566
+ fontsize=9,
567
+ color="gray",
568
+ style="italic",
569
+ )
570
  else:
571
  steps = np.array([m.step for m in lr_metrics])
572
  values = [m.value for m in lr_metrics]
 
579
  params = run.data.params
580
  warmup_steps = int(params.get("warmup_steps", 500))
581
  if warmup_steps < max(steps):
582
+ ax.axvline(
583
+ warmup_steps,
584
+ color=COLORS["secondary"],
585
+ linestyle="--",
586
+ alpha=0.7,
587
+ linewidth=2,
588
+ label="Warmup End",
589
+ )
590
+ ax.axvspan(0, warmup_steps, alpha=0.1, color=COLORS["highlight"], label="Warmup Phase")
591
  ax.legend(loc="upper right")
592
 
593
  # Scientific notation for y-axis if needed
 
606
 
607
  # Advanced Visualizations
608
 
609
+
610
  def plot_confusion_matrix(run, task: str = "topic") -> None:
611
  """
612
  Plot confusion matrix for classification tasks.
 
618
  if task == "topic":
619
  default_labels = ["World", "Sports", "Business", "Sci/Tech"]
620
  else: # emotion - top 8 for visibility
621
+ default_labels = [
622
+ "admiration",
623
+ "amusement",
624
+ "anger",
625
+ "annoyance",
626
+ "approval",
627
+ "caring",
628
+ "curiosity",
629
+ "desire",
630
+ ]
631
 
632
  if labels_path.exists():
633
  with open(labels_path) as f:
 
660
  # Plot
661
  fig, ax = plt.subplots(figsize=(10, 8))
662
 
663
+ sns.heatmap(
664
+ cm_normalized,
665
+ annot=True,
666
+ fmt=".2f",
667
+ cmap=HEATMAP_CMAP,
668
+ xticklabels=labels[:n_classes],
669
+ yticklabels=labels[:n_classes],
670
+ ax=ax,
671
+ cbar_kws={"label": "Proportion"},
672
+ )
673
 
674
  ax.set_title(f"Confusion Matrix: {task.title()} Classification")
675
  ax.set_xlabel("Predicted Label")
 
721
 
722
  # Synthetic loss surface (bowl shape with some local minima)
723
  min_loss = min(val_loss) if val_loss else min(train_loss)
724
+ Z = min_loss + 0.3 * (X**2 + Y**2) + 0.1 * np.sin(3 * X) * np.cos(3 * Y)
725
 
726
  # Add noise for realism
727
  Z += np.random.normal(0, 0.02, Z.shape)
 
735
  fig = go.Figure()
736
 
737
  # Loss surface
738
+ fig.add_trace(
739
+ go.Surface(
740
+ x=X,
741
+ y=Y,
742
+ z=Z,
743
+ colorscale=[[0, COLORS["accent"]], [0.5, COLORS["primary"]], [1, COLORS["secondary"]]],
744
+ opacity=0.8,
745
+ showscale=True,
746
+ colorbar=dict(title="Loss", x=1.02),
747
+ )
748
+ )
749
 
750
  # Training trajectory
751
+ fig.add_trace(
752
+ go.Scatter3d(
753
+ x=trajectory_x,
754
+ y=trajectory_y,
755
+ z=trajectory_z,
756
+ mode="lines+markers",
757
+ line=dict(color=COLORS["highlight"], width=5),
758
+ marker=dict(size=4, color=COLORS["highlight"]),
759
+ name="Training Path",
760
+ )
761
+ )
762
 
763
  # Mark start and end
764
+ fig.add_trace(
765
+ go.Scatter3d(
766
+ x=[trajectory_x[0]],
767
+ y=[trajectory_y[0]],
768
+ z=[trajectory_z[0]],
769
+ mode="markers+text",
770
+ marker=dict(size=10, color="red", symbol="circle"),
771
+ text=["Start"],
772
+ textposition="top center",
773
+ name="Start",
774
+ )
775
+ )
776
+
777
+ fig.add_trace(
778
+ go.Scatter3d(
779
+ x=[trajectory_x[-1]],
780
+ y=[trajectory_y[-1]],
781
+ z=[trajectory_z[-1]],
782
+ mode="markers+text",
783
+ marker=dict(size=10, color="green", symbol="diamond"),
784
+ text=["Converged"],
785
+ textposition="top center",
786
+ name="Converged",
787
+ )
788
+ )
789
 
790
  fig.update_layout(
791
  title="Loss Landscape & Optimization Trajectory",
 
793
  xaxis_title="Parameter Direction 1",
794
  yaxis_title="Parameter Direction 2",
795
  zaxis_title="Loss",
796
+ camera=dict(eye=dict(x=1.5, y=1.5, z=0.8)),
797
  ),
798
  width=900,
799
  height=700,
 
825
  X, Y = np.meshgrid(x, y)
826
 
827
  min_loss = min(train_loss)
828
+ Z = min_loss + 0.3 * (X**2 + Y**2) + 0.08 * np.sin(3 * X) * np.cos(3 * Y)
829
 
830
  fig = plt.figure(figsize=(12, 8))
831
  ax = fig.add_subplot(111, projection="3d")
832
 
833
  # Surface
834
+ surf = ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.7, linewidth=0, antialiased=True)
 
835
 
836
  # Training path
837
  path_x = np.linspace(-1.5, 0, len(train_loss))
838
  path_y = np.linspace(1.2, 0, len(train_loss))
839
+ ax.plot(
840
+ path_x,
841
+ path_y,
842
+ train_loss,
843
+ color=COLORS["secondary"],
844
+ linewidth=3,
845
+ label="Training Path",
846
+ zorder=10,
847
+ )
848
 
849
  # Start/end markers
850
+ ax.scatter(
851
+ [path_x[0]],
852
+ [path_y[0]],
853
+ train_loss[0], # type: ignore[arg-type]
854
+ c="red",
855
+ s=100,
856
+ marker="o",
857
+ label="Start",
858
+ )
859
+ ax.scatter(
860
+ [path_x[-1]],
861
+ [path_y[-1]],
862
+ train_loss[-1], # type: ignore[arg-type]
863
+ c="green",
864
+ s=100,
865
+ marker="*",
866
+ label="Converged",
867
+ )
868
 
869
  ax.set_xlabel("θ₁ Direction")
870
  ax.set_ylabel("θ₂ Direction")
 
909
  for i in range(n_clusters):
910
  # Create cluster center
911
  center = np.random.randn(64) * 0.5
912
+ center[i * 16 : (i + 1) * 16] += 3 # Make clusters separable
913
 
914
  # Add samples around center
915
  samples = center + np.random.randn(n_samples // n_clusters, 64) * 0.5
 
929
 
930
  for i in range(n_clusters):
931
  mask = cluster_labels == i
932
+ ax.scatter(
933
+ embeddings_2d[mask, 0],
934
+ embeddings_2d[mask, 1],
935
+ c=colors[i],
936
+ label=labels[i],
937
+ alpha=0.6,
938
+ s=30,
939
+ )
940
 
941
  ax.set_xlabel("t-SNE Dimension 1")
942
  ax.set_ylabel("t-SNE Dimension 2")
 
980
  # Smoothed loss (exponential moving average)
981
  if len(train_loss) > 5:
982
  window = min(5, len(train_loss) // 2)
983
+ smoothed = np.convolve(train_loss, np.ones(window) / window, mode="valid")
984
+ smoothed_steps = train_steps[window - 1 :]
985
+ ax.plot(
986
+ smoothed_steps,
987
+ smoothed,
988
+ color=COLORS["primary"],
989
+ linewidth=2.5,
990
+ label="Training (smoothed)",
991
+ )
992
 
993
  if val_loss:
994
+ ax.plot(val_steps, val_loss, color=COLORS["secondary"], linewidth=2.5, label="Validation")
 
995
 
996
  ax.set_title("Loss Convergence")
997
  ax.set_xlabel("Epoch")
 
1003
  ax = axes[0, 1]
1004
 
1005
  if len(train_loss) > 1:
1006
+ improvements = [
1007
+ -(train_loss[i] - train_loss[i - 1]) / train_loss[i - 1] * 100
1008
+ for i in range(1, len(train_loss))
1009
+ ]
1010
  colors_bar = [COLORS["accent"] if imp > 0 else COLORS["secondary"] for imp in improvements]
1011
  ax.bar(train_steps[1:], improvements, color=colors_bar, alpha=0.7)
1012
  ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
 
1061
 
1062
  # Dashboard Generator
1063
 
1064
+
1065
  def generate_dashboard(run) -> None:
1066
  """
1067
  Generate an interactive HTML dashboard with all visualizations.
 
1083
 
1084
  # Create subplots
1085
  fig = make_subplots(
1086
+ rows=2,
1087
+ cols=2,
1088
  subplot_titles=("Total Loss", "Task Losses", "Learning Rate", "Metrics"),
1089
+ specs=[[{}, {}], [{}, {}]],
1090
  )
1091
 
1092
  # Total loss
1093
  if train_loss:
1094
  fig.add_trace(
1095
+ go.Scatter(
1096
+ x=train_steps, y=train_loss, name="Train Loss", line=dict(color=COLORS["primary"])
1097
+ ),
1098
+ row=1,
1099
+ col=1,
1100
  )
1101
  if val_loss:
1102
  fig.add_trace(
1103
+ go.Scatter(
1104
+ x=val_steps, y=val_loss, name="Val Loss", line=dict(color=COLORS["secondary"])
1105
+ ),
1106
+ row=1,
1107
+ col=1,
1108
  )
1109
 
1110
  # Per-task losses
1111
+ for task, color in [
1112
+ ("summarization", COLORS["summary"]),
1113
+ ("emotion", COLORS["emotion"]),
1114
+ ("topic", COLORS["topic"]),
1115
+ ]:
1116
  steps, values = get_metric_history(run, f"val_{task}_loss")
1117
  if values:
1118
  fig.add_trace(
1119
+ go.Scatter(x=steps, y=values, name=f"{task.title()} Loss", line=dict(color=color)),
1120
+ row=1,
1121
+ col=2,
1122
  )
1123
 
1124
  # Learning rate
1125
  lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
1126
  if lr_metrics:
1127
  fig.add_trace(
1128
+ go.Scatter(
1129
+ x=[m.step for m in lr_metrics],
1130
+ y=[m.value for m in lr_metrics],
1131
+ name="Learning Rate",
1132
+ fill="tozeroy",
1133
+ line=dict(color=COLORS["primary"]),
1134
+ ),
1135
+ row=2,
1136
+ col=1,
1137
  )
1138
 
1139
  # Accuracy metrics
1140
+ for metric, color in [("topic_accuracy", COLORS["topic"]), ("emotion_f1", COLORS["emotion"])]:
 
1141
  steps, values = get_metric_history(run, f"val_{metric}")
1142
  if values:
1143
  fig.add_trace(
1144
+ go.Scatter(
1145
+ x=steps, y=values, name=metric.replace("_", " ").title(), line=dict(color=color)
1146
+ ),
1147
+ row=2,
1148
+ col=2,
1149
  )
1150
 
1151
  fig.update_layout(
1152
+ title="LexiMind Training Dashboard", height=800, template="plotly_white", showlegend=True
 
 
 
1153
  )
1154
 
1155
  output_path = OUTPUTS_DIR / "training_dashboard.html"
 
1159
 
1160
  # Main Entry Point
1161
 
1162
+
1163
  def main():
1164
  """Generate all training visualizations."""
1165
  parser = argparse.ArgumentParser(description="LexiMind Visualization Suite")
1166
+ parser.add_argument(
1167
+ "--interactive",
1168
+ action="store_true",
1169
+ help="Generate interactive HTML plots (requires plotly)",
1170
+ )
1171
+ parser.add_argument(
1172
+ "--landscape", action="store_true", help="Include 3D loss landscape visualization"
1173
+ )
1174
+ parser.add_argument("--dashboard", action="store_true", help="Generate interactive dashboard")
1175
+ parser.add_argument("--all", action="store_true", help="Generate all visualizations")
1176
  args = parser.parse_args()
1177
 
1178
  logger.info("=" * 60)
src/data/dataset.py CHANGED
@@ -24,6 +24,7 @@ from torch.utils.data import Dataset
24
  @dataclass
25
  class SummarizationExample:
26
  """Container for abstractive summarization samples."""
 
27
  source: str
28
  summary: str
29
 
@@ -31,6 +32,7 @@ class SummarizationExample:
31
  @dataclass
32
  class EmotionExample:
33
  """Container for multi-label emotion classification samples."""
 
34
  text: str
35
  emotions: Sequence[str]
36
 
@@ -38,12 +40,14 @@ class EmotionExample:
38
  @dataclass
39
  class TopicExample:
40
  """Container for topic clustering / classification samples."""
 
41
  text: str
42
  topic: str
43
 
44
 
45
  class SummarizationDataset(Dataset[SummarizationExample]):
46
  """Dataset yielding encoder-decoder training pairs."""
 
47
  def __init__(self, examples: Iterable[SummarizationExample]) -> None:
48
  self._examples = list(examples)
49
 
@@ -56,6 +60,7 @@ class SummarizationDataset(Dataset[SummarizationExample]):
56
 
57
  class EmotionDataset(Dataset[EmotionExample]):
58
  """Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
 
59
  def __init__(
60
  self,
61
  examples: Iterable[EmotionExample],
@@ -91,6 +96,7 @@ class EmotionDataset(Dataset[EmotionExample]):
91
 
92
  class TopicDataset(Dataset[TopicExample]):
93
  """Dataset that owns a LabelEncoder for topic ids."""
 
94
  def __init__(
95
  self,
96
  examples: Iterable[TopicExample],
@@ -241,7 +247,7 @@ def load_topic_jsonl(path: str) -> List[TopicExample]:
241
 
242
  def _text_fingerprint(text: str, n_chars: int = 200) -> str:
243
  """Create a stable fingerprint from the first N characters of text.
244
-
245
  Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
246
  to detect document-level overlap across tasks.
247
  """
@@ -255,28 +261,28 @@ def deduplicate_across_tasks(
255
  emotion_examples: List[EmotionExample] | None = None,
256
  ) -> Dict[str, int]:
257
  """Detect and report cross-task document overlap.
258
-
259
  Checks whether texts appearing in the summarization dataset also appear
260
  in the topic or emotion datasets, which could create data leakage in MTL.
261
-
262
  Returns:
263
  Dict with overlap counts between task pairs.
264
  """
265
  summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
266
  topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
267
-
268
  overlap: Dict[str, int] = {
269
  "summ_topic_overlap": len(summ_fps & topic_fps),
270
  "summ_total": len(summ_fps),
271
  "topic_total": len(topic_fps),
272
  }
273
-
274
  if emotion_examples:
275
  emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
276
  overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
277
  overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
278
  overlap["emotion_total"] = len(emot_fps)
279
-
280
  return overlap
281
 
282
 
@@ -286,20 +292,20 @@ def remove_overlapping_examples(
286
  split: str = "val",
287
  ) -> tuple[List[TopicExample], int]:
288
  """Remove topic examples whose texts overlap with summarization data.
289
-
290
- This prevents cross-task data leakage where a document seen during
291
  summarization training could boost topic classification on validation/test.
292
-
293
  Args:
294
  primary_examples: Topic examples to filter
295
  reference_examples: Summarization examples to check against
296
  split: Name of split being processed (for logging)
297
-
298
  Returns:
299
  Tuple of (filtered_examples, num_removed)
300
  """
301
  ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
302
-
303
  filtered = []
304
  removed = 0
305
  for ex in primary_examples:
@@ -308,8 +314,8 @@ def remove_overlapping_examples(
308
  removed += 1
309
  else:
310
  filtered.append(ex)
311
-
312
  if removed > 0:
313
  print(f" Dedup: removed {removed} overlapping examples from topic {split}")
314
-
315
  return filtered, removed
 
24
  @dataclass
25
  class SummarizationExample:
26
  """Container for abstractive summarization samples."""
27
+
28
  source: str
29
  summary: str
30
 
 
32
  @dataclass
33
  class EmotionExample:
34
  """Container for multi-label emotion classification samples."""
35
+
36
  text: str
37
  emotions: Sequence[str]
38
 
 
40
  @dataclass
41
  class TopicExample:
42
  """Container for topic clustering / classification samples."""
43
+
44
  text: str
45
  topic: str
46
 
47
 
48
  class SummarizationDataset(Dataset[SummarizationExample]):
49
  """Dataset yielding encoder-decoder training pairs."""
50
+
51
  def __init__(self, examples: Iterable[SummarizationExample]) -> None:
52
  self._examples = list(examples)
53
 
 
60
 
61
  class EmotionDataset(Dataset[EmotionExample]):
62
  """Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
63
+
64
  def __init__(
65
  self,
66
  examples: Iterable[EmotionExample],
 
96
 
97
  class TopicDataset(Dataset[TopicExample]):
98
  """Dataset that owns a LabelEncoder for topic ids."""
99
+
100
  def __init__(
101
  self,
102
  examples: Iterable[TopicExample],
 
247
 
248
  def _text_fingerprint(text: str, n_chars: int = 200) -> str:
249
  """Create a stable fingerprint from the first N characters of text.
250
+
251
  Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
252
  to detect document-level overlap across tasks.
253
  """
 
261
  emotion_examples: List[EmotionExample] | None = None,
262
  ) -> Dict[str, int]:
263
  """Detect and report cross-task document overlap.
264
+
265
  Checks whether texts appearing in the summarization dataset also appear
266
  in the topic or emotion datasets, which could create data leakage in MTL.
267
+
268
  Returns:
269
  Dict with overlap counts between task pairs.
270
  """
271
  summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
272
  topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
273
+
274
  overlap: Dict[str, int] = {
275
  "summ_topic_overlap": len(summ_fps & topic_fps),
276
  "summ_total": len(summ_fps),
277
  "topic_total": len(topic_fps),
278
  }
279
+
280
  if emotion_examples:
281
  emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
282
  overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
283
  overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
284
  overlap["emotion_total"] = len(emot_fps)
285
+
286
  return overlap
287
 
288
 
 
292
  split: str = "val",
293
  ) -> tuple[List[TopicExample], int]:
294
  """Remove topic examples whose texts overlap with summarization data.
295
+
296
+ This prevents cross-task data leakage where a document seen during
297
  summarization training could boost topic classification on validation/test.
298
+
299
  Args:
300
  primary_examples: Topic examples to filter
301
  reference_examples: Summarization examples to check against
302
  split: Name of split being processed (for logging)
303
+
304
  Returns:
305
  Tuple of (filtered_examples, num_removed)
306
  """
307
  ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
308
+
309
  filtered = []
310
  removed = 0
311
  for ex in primary_examples:
 
314
  removed += 1
315
  else:
316
  filtered.append(ex)
317
+
318
  if removed > 0:
319
  print(f" Dedup: removed {removed} overlapping examples from topic {split}")
320
+
321
  return filtered, removed
src/models/decoder.py CHANGED
@@ -327,7 +327,6 @@ class TransformerDecoder(nn.Module):
327
  elif tgt_mask.dim() == 3:
328
  tgt_mask = tgt_mask.unsqueeze(1)
329
 
330
-
331
  # Normalize memory_mask dtype/device and expand simple shapes
332
  if memory_mask is not None:
333
  memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
@@ -355,7 +354,15 @@ class TransformerDecoder(nn.Module):
355
  # Gradient checkpointing requires the inputs to require grad
356
  def create_custom_forward(module):
357
  def custom_forward(*inputs):
358
- return module(*inputs, tgt_mask=tgt_mask, memory_mask=memory_mask, collect_attn=collect_attn, self_attn_position_bias=self_position_bias, cross_attn_position_bias=cross_position_bias)
 
 
 
 
 
 
 
 
359
  return custom_forward
360
 
361
  x, attn = cast(
@@ -450,7 +457,7 @@ class TransformerDecoder(nn.Module):
450
  ) -> torch.Tensor:
451
  """
452
  Greedy decoding with KV caching for O(N) complexity.
453
-
454
  Args:
455
  length_penalty: Values > 1.0 encourage shorter sequences by boosting EOS probability
456
  as sequence length increases. Default 1.0 (no penalty).
 
327
  elif tgt_mask.dim() == 3:
328
  tgt_mask = tgt_mask.unsqueeze(1)
329
 
 
330
  # Normalize memory_mask dtype/device and expand simple shapes
331
  if memory_mask is not None:
332
  memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
 
354
  # Gradient checkpointing requires the inputs to require grad
355
  def create_custom_forward(module):
356
  def custom_forward(*inputs):
357
+ return module(
358
+ *inputs,
359
+ tgt_mask=tgt_mask,
360
+ memory_mask=memory_mask,
361
+ collect_attn=collect_attn,
362
+ self_attn_position_bias=self_position_bias,
363
+ cross_attn_position_bias=cross_position_bias,
364
+ )
365
+
366
  return custom_forward
367
 
368
  x, attn = cast(
 
457
  ) -> torch.Tensor:
458
  """
459
  Greedy decoding with KV caching for O(N) complexity.
460
+
461
  Args:
462
  length_penalty: Values > 1.0 encourage shorter sequences by boosting EOS probability
463
  as sequence length increases. Default 1.0 (no penalty).
src/models/encoder.py CHANGED
@@ -291,7 +291,13 @@ class TransformerEncoder(nn.Module):
291
  # We use a lambda to pass keyword arguments
292
  def create_custom_forward(module):
293
  def custom_forward(*inputs):
294
- return module(*inputs, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
 
 
 
 
 
 
295
  return custom_forward
296
 
297
  x, attn = cast(
@@ -303,8 +309,10 @@ class TransformerEncoder(nn.Module):
303
  ),
304
  )
305
  else:
306
- x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
307
-
 
 
308
  if collect_attn:
309
  attn_weights_per_layer.append(attn)
310
 
 
291
  # We use a lambda to pass keyword arguments
292
  def create_custom_forward(module):
293
  def custom_forward(*inputs):
294
+ return module(
295
+ *inputs,
296
+ mask=mask,
297
+ collect_attn=collect_attn,
298
+ position_bias=position_bias,
299
+ )
300
+
301
  return custom_forward
302
 
303
  x, attn = cast(
 
309
  ),
310
  )
311
  else:
312
+ x, attn = layer(
313
+ x, mask=mask, collect_attn=collect_attn, position_bias=position_bias
314
+ )
315
+
316
  if collect_attn:
317
  attn_weights_per_layer.append(attn)
318
 
src/models/factory.py CHANGED
@@ -208,7 +208,9 @@ def _load_pretrained_weights(
208
  if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
209
  print("Transferring encoder relative position bias...")
210
  t5_enc_rel_bias = (
211
- cast(Any, t5_encoder.block[0]).layer[0].SelfAttention.relative_attention_bias.weight.data
 
 
212
  )
213
  encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
214
 
@@ -285,7 +287,9 @@ def _load_pretrained_weights(
285
  ):
286
  print("Transferring decoder self-attention relative position bias...")
287
  t5_dec_self_rel_bias = (
288
- cast(Any, t5_decoder.block[0]).layer[0].SelfAttention.relative_attention_bias.weight.data
 
 
289
  )
290
  decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
291
  t5_dec_self_rel_bias
@@ -298,7 +302,9 @@ def _load_pretrained_weights(
298
  print("Transferring decoder cross-attention relative position bias...")
299
  # Cross-attention relative position bias is in EncDecAttention of first block
300
  t5_dec_cross_rel_bias = (
301
- cast(Any, t5_decoder.block[0]).layer[1].EncDecAttention.relative_attention_bias.weight.data
 
 
302
  )
303
  decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
304
  t5_dec_cross_rel_bias
@@ -554,9 +560,9 @@ def build_multitask_model(
554
  model.add_head(
555
  "emotion",
556
  ClassificationHead(
557
- d_model=cfg.d_model,
558
- num_labels=num_emotions,
559
- pooler="attention",
560
  dropout=cfg.dropout,
561
  hidden_dim=cfg.d_model // 2, # 384-dim hidden layer
562
  ),
 
208
  if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
209
  print("Transferring encoder relative position bias...")
210
  t5_enc_rel_bias = (
211
+ cast(Any, t5_encoder.block[0])
212
+ .layer[0]
213
+ .SelfAttention.relative_attention_bias.weight.data
214
  )
215
  encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
216
 
 
287
  ):
288
  print("Transferring decoder self-attention relative position bias...")
289
  t5_dec_self_rel_bias = (
290
+ cast(Any, t5_decoder.block[0])
291
+ .layer[0]
292
+ .SelfAttention.relative_attention_bias.weight.data
293
  )
294
  decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
295
  t5_dec_self_rel_bias
 
302
  print("Transferring decoder cross-attention relative position bias...")
303
  # Cross-attention relative position bias is in EncDecAttention of first block
304
  t5_dec_cross_rel_bias = (
305
+ cast(Any, t5_decoder.block[0])
306
+ .layer[1]
307
+ .EncDecAttention.relative_attention_bias.weight.data
308
  )
309
  decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
310
  t5_dec_cross_rel_bias
 
560
  model.add_head(
561
  "emotion",
562
  ClassificationHead(
563
+ d_model=cfg.d_model,
564
+ num_labels=num_emotions,
565
+ pooler="attention",
566
  dropout=cfg.dropout,
567
  hidden_dim=cfg.d_model // 2, # 384-dim hidden layer
568
  ),
src/models/heads.py CHANGED
@@ -66,13 +66,15 @@ class ClassificationHead(nn.Module):
66
  hidden_dim: Optional[int] = None,
67
  ):
68
  super().__init__()
69
- assert pooler in ("mean", "cls", "max", "attention"), "pooler must be 'mean'|'cls'|'max'|'attention'"
 
 
70
  self.pooler = pooler
71
  self.dropout = nn.Dropout(dropout)
72
 
73
  if pooler == "attention":
74
  self.attn_pool = AttentionPooling(d_model)
75
-
76
  # Optional 2-layer MLP for more capacity (useful for multi-label)
77
  if hidden_dim is not None:
78
  self.out_proj = nn.Sequential(
 
66
  hidden_dim: Optional[int] = None,
67
  ):
68
  super().__init__()
69
+ assert pooler in ("mean", "cls", "max", "attention"), (
70
+ "pooler must be 'mean'|'cls'|'max'|'attention'"
71
+ )
72
  self.pooler = pooler
73
  self.dropout = nn.Dropout(dropout)
74
 
75
  if pooler == "attention":
76
  self.attn_pool = AttentionPooling(d_model)
77
+
78
  # Optional 2-layer MLP for more capacity (useful for multi-label)
79
  if hidden_dim is not None:
80
  self.out_proj = nn.Sequential(
src/training/metrics.py CHANGED
@@ -72,33 +72,33 @@ def calculate_bertscore(
72
  ) -> Dict[str, float]:
73
  """
74
  Calculate BERTScore for semantic similarity between predictions and references.
75
-
76
  BERTScore measures semantic similarity using contextual embeddings, making it
77
  more robust than n-gram based metrics like ROUGE for paraphrased content.
78
-
79
  Args:
80
  predictions: Generated summaries/descriptions
81
  references: Reference summaries/descriptions
82
  model_type: BERT model to use (default: deberta-xlarge-mnli for best quality)
83
  batch_size: Batch size for encoding
84
  device: Device to use (auto-detected if None)
85
-
86
  Returns:
87
  Dict with 'precision', 'recall', 'f1' BERTScore averages
88
  """
89
  if not predictions or not references:
90
  return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
91
-
92
  try:
93
  from bert_score import score as bert_score # type: ignore[import-not-found]
94
  except ImportError:
95
  print("Warning: bert-score not installed. Run: pip install bert-score")
96
  return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
97
-
98
  # Auto-detect device
99
  if device is None:
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
-
102
  # Calculate BERTScore
103
  P, R, F1 = bert_score(
104
  list(predictions),
@@ -108,7 +108,7 @@ def calculate_bertscore(
108
  device=device,
109
  verbose=False,
110
  )
111
-
112
  return {
113
  "precision": float(P.mean().item()), # type: ignore[union-attr]
114
  "recall": float(R.mean().item()), # type: ignore[union-attr]
@@ -122,35 +122,35 @@ def calculate_rouge(
122
  ) -> Dict[str, float]:
123
  """
124
  Calculate proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L).
125
-
126
  Args:
127
  predictions: Generated summaries
128
  references: Reference summaries
129
-
130
  Returns:
131
  Dict with rouge1, rouge2, rougeL F1 scores
132
  """
133
  if not predictions or not references:
134
  return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
135
-
136
  try:
137
  from rouge_score import rouge_scorer
138
  except ImportError:
139
  print("Warning: rouge-score not installed. Run: pip install rouge-score")
140
  return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
141
-
142
- scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
143
-
144
  rouge1_scores = []
145
  rouge2_scores = []
146
  rougeL_scores = []
147
-
148
  for pred, ref in zip(predictions, references, strict=False):
149
  scores = scorer.score(ref, pred)
150
- rouge1_scores.append(scores['rouge1'].fmeasure)
151
- rouge2_scores.append(scores['rouge2'].fmeasure)
152
- rougeL_scores.append(scores['rougeL'].fmeasure)
153
-
154
  return {
155
  "rouge1": sum(rouge1_scores) / len(rouge1_scores),
156
  "rouge2": sum(rouge2_scores) / len(rouge2_scores),
@@ -166,37 +166,35 @@ def calculate_all_summarization_metrics(
166
  ) -> Dict[str, float]:
167
  """
168
  Calculate comprehensive summarization metrics for research paper reporting.
169
-
170
  Includes:
171
  - ROUGE-1, ROUGE-2, ROUGE-L (lexical overlap)
172
  - BLEU-4 (n-gram precision)
173
  - BERTScore (semantic similarity)
174
-
175
  Args:
176
  predictions: Generated summaries/descriptions
177
  references: Reference summaries/descriptions
178
  include_bertscore: Whether to compute BERTScore (slower but valuable)
179
  bertscore_model: Model for BERTScore computation
180
-
181
  Returns:
182
  Dict with all metric scores
183
  """
184
  metrics: Dict[str, float] = {}
185
-
186
  # ROUGE scores
187
  rouge_scores = calculate_rouge(predictions, references)
188
  metrics.update({f"rouge_{k}": v for k, v in rouge_scores.items()})
189
-
190
  # BLEU score
191
  metrics["bleu4"] = calculate_bleu(predictions, references)
192
-
193
  # BERTScore (semantic similarity - important for back-cover style descriptions)
194
  if include_bertscore:
195
- bert_scores = calculate_bertscore(
196
- predictions, references, model_type=bertscore_model
197
- )
198
  metrics.update({f"bertscore_{k}": v for k, v in bert_scores.items()})
199
-
200
  return metrics
201
 
202
 
@@ -246,22 +244,22 @@ def get_confusion_matrix(
246
 
247
  def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
248
  """Compute macro F1: average F1 per class (as in GoEmotions paper).
249
-
250
- This averages F1 across labels, giving equal weight to each emotion class
251
  regardless of prevalence. Directly comparable to GoEmotions baselines.
252
  """
253
  preds = predictions.float()
254
  gold = targets.float()
255
-
256
  # Per-class TP, FP, FN
257
  tp = (preds * gold).sum(dim=0)
258
  fp = (preds * (1 - gold)).sum(dim=0)
259
  fn = ((1 - preds) * gold).sum(dim=0)
260
-
261
  precision = tp / (tp + fp).clamp(min=1e-8)
262
  recall = tp / (tp + fn).clamp(min=1e-8)
263
  f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
264
-
265
  # Zero out F1 for classes with no support in either predictions or targets
266
  mask = (tp + fp + fn) > 0
267
  if mask.sum() == 0:
@@ -271,16 +269,16 @@ def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> flo
271
 
272
  def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
273
  """Compute micro F1: aggregate TP/FP/FN across all classes.
274
-
275
  This gives more weight to frequent classes. Useful when class distribution matters.
276
  """
277
  preds = predictions.float()
278
  gold = targets.float()
279
-
280
  tp = (preds * gold).sum()
281
  fp = (preds * (1 - gold)).sum()
282
  fn = ((1 - preds) * gold).sum()
283
-
284
  precision = tp / (tp + fp).clamp(min=1e-8)
285
  recall = tp / (tp + fn).clamp(min=1e-8)
286
  f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
@@ -293,17 +291,17 @@ def multilabel_per_class_metrics(
293
  class_names: Sequence[str] | None = None,
294
  ) -> Dict[str, Dict[str, float]]:
295
  """Compute per-class precision, recall, F1 for multi-label classification.
296
-
297
  Returns a dict mapping class name/index to its metrics.
298
  """
299
  preds = predictions.float()
300
  gold = targets.float()
301
  num_classes = preds.shape[1]
302
-
303
  tp = (preds * gold).sum(dim=0)
304
  fp = (preds * (1 - gold)).sum(dim=0)
305
  fn = ((1 - preds) * gold).sum(dim=0)
306
-
307
  report: Dict[str, Dict[str, float]] = {}
308
  for i in range(num_classes):
309
  name = class_names[i] if class_names else str(i)
@@ -325,26 +323,26 @@ def tune_per_class_thresholds(
325
  thresholds: Sequence[float] | None = None,
326
  ) -> tuple[List[float], float]:
327
  """Tune per-class thresholds on validation set to maximize macro F1.
328
-
329
- For each class, tries multiple thresholds and selects the one that
330
- maximizes that class's F1 score. This is standard practice for multi-label
331
  classification (used in the original GoEmotions paper).
332
-
333
  Args:
334
  logits: Raw model logits (batch, num_classes)
335
  targets: Binary target labels (batch, num_classes)
336
  thresholds: Candidate thresholds to try (default: 0.1 to 0.9 by 0.05)
337
-
338
  Returns:
339
  Tuple of (best_thresholds_per_class, resulting_macro_f1)
340
  """
341
  if thresholds is None:
342
  thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
343
-
344
  probs = torch.sigmoid(logits)
345
  num_classes = probs.shape[1]
346
  gold = targets.float()
347
-
348
  best_thresholds: List[float] = []
349
  for c in range(num_classes):
350
  best_f1 = -1.0
@@ -364,13 +362,13 @@ def tune_per_class_thresholds(
364
  best_f1 = f1
365
  best_t = t
366
  best_thresholds.append(best_t)
367
-
368
  # Compute resulting macro F1 with tuned thresholds
369
  tuned_preds = torch.zeros_like(probs)
370
  for c in range(num_classes):
371
  tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
372
  macro_f1 = multilabel_macro_f1(tuned_preds, targets)
373
-
374
  return best_thresholds, macro_f1
375
 
376
 
@@ -384,30 +382,30 @@ def bootstrap_confidence_interval(
384
  seed: int = 42,
385
  ) -> tuple[float, float, float]:
386
  """Compute bootstrap confidence interval for a metric.
387
-
388
  Args:
389
  scores: Per-sample metric values
390
  n_bootstrap: Number of bootstrap resamples
391
  confidence: Confidence level (default 95%)
392
  seed: Random seed for reproducibility
393
-
394
  Returns:
395
  Tuple of (mean, lower_bound, upper_bound)
396
  """
397
  rng = np.random.default_rng(seed)
398
  scores_arr = np.array(scores)
399
  n = len(scores_arr)
400
-
401
  bootstrap_means = []
402
  for _ in range(n_bootstrap):
403
  sample = rng.choice(scores_arr, size=n, replace=True)
404
  bootstrap_means.append(float(np.mean(sample)))
405
-
406
  bootstrap_means.sort()
407
  alpha = 1 - confidence
408
  lower_idx = int(alpha / 2 * n_bootstrap)
409
  upper_idx = int((1 - alpha / 2) * n_bootstrap)
410
-
411
  return (
412
  float(np.mean(scores_arr)),
413
  bootstrap_means[lower_idx],
@@ -422,15 +420,15 @@ def paired_bootstrap_test(
422
  seed: int = 42,
423
  ) -> float:
424
  """Paired bootstrap significance test between two systems.
425
-
426
  Tests if system B is significantly better than system A.
427
-
428
  Args:
429
  scores_a: Per-sample scores from system A
430
  scores_b: Per-sample scores from system B
431
  n_bootstrap: Number of bootstrap iterations
432
  seed: Random seed
433
-
434
  Returns:
435
  p-value (probability that B is not better than A)
436
  """
@@ -438,14 +436,14 @@ def paired_bootstrap_test(
438
  a = np.array(scores_a)
439
  b = np.array(scores_b)
440
  assert len(a) == len(b), "Both score lists must have the same length"
441
-
442
  n = len(a)
443
-
444
  count = 0
445
  for _ in range(n_bootstrap):
446
  idx = rng.choice(n, size=n, replace=True)
447
  diff = float(np.mean(b[idx]) - np.mean(a[idx]))
448
  if diff <= 0:
449
  count += 1
450
-
451
  return count / n_bootstrap
 
72
  ) -> Dict[str, float]:
73
  """
74
  Calculate BERTScore for semantic similarity between predictions and references.
75
+
76
  BERTScore measures semantic similarity using contextual embeddings, making it
77
  more robust than n-gram based metrics like ROUGE for paraphrased content.
78
+
79
  Args:
80
  predictions: Generated summaries/descriptions
81
  references: Reference summaries/descriptions
82
  model_type: BERT model to use (default: deberta-xlarge-mnli for best quality)
83
  batch_size: Batch size for encoding
84
  device: Device to use (auto-detected if None)
85
+
86
  Returns:
87
  Dict with 'precision', 'recall', 'f1' BERTScore averages
88
  """
89
  if not predictions or not references:
90
  return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
91
+
92
  try:
93
  from bert_score import score as bert_score # type: ignore[import-not-found]
94
  except ImportError:
95
  print("Warning: bert-score not installed. Run: pip install bert-score")
96
  return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
97
+
98
  # Auto-detect device
99
  if device is None:
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
+
102
  # Calculate BERTScore
103
  P, R, F1 = bert_score(
104
  list(predictions),
 
108
  device=device,
109
  verbose=False,
110
  )
111
+
112
  return {
113
  "precision": float(P.mean().item()), # type: ignore[union-attr]
114
  "recall": float(R.mean().item()), # type: ignore[union-attr]
 
122
  ) -> Dict[str, float]:
123
  """
124
  Calculate proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L).
125
+
126
  Args:
127
  predictions: Generated summaries
128
  references: Reference summaries
129
+
130
  Returns:
131
  Dict with rouge1, rouge2, rougeL F1 scores
132
  """
133
  if not predictions or not references:
134
  return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
135
+
136
  try:
137
  from rouge_score import rouge_scorer
138
  except ImportError:
139
  print("Warning: rouge-score not installed. Run: pip install rouge-score")
140
  return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
141
+
142
+ scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
143
+
144
  rouge1_scores = []
145
  rouge2_scores = []
146
  rougeL_scores = []
147
+
148
  for pred, ref in zip(predictions, references, strict=False):
149
  scores = scorer.score(ref, pred)
150
+ rouge1_scores.append(scores["rouge1"].fmeasure)
151
+ rouge2_scores.append(scores["rouge2"].fmeasure)
152
+ rougeL_scores.append(scores["rougeL"].fmeasure)
153
+
154
  return {
155
  "rouge1": sum(rouge1_scores) / len(rouge1_scores),
156
  "rouge2": sum(rouge2_scores) / len(rouge2_scores),
 
166
  ) -> Dict[str, float]:
167
  """
168
  Calculate comprehensive summarization metrics for research paper reporting.
169
+
170
  Includes:
171
  - ROUGE-1, ROUGE-2, ROUGE-L (lexical overlap)
172
  - BLEU-4 (n-gram precision)
173
  - BERTScore (semantic similarity)
174
+
175
  Args:
176
  predictions: Generated summaries/descriptions
177
  references: Reference summaries/descriptions
178
  include_bertscore: Whether to compute BERTScore (slower but valuable)
179
  bertscore_model: Model for BERTScore computation
180
+
181
  Returns:
182
  Dict with all metric scores
183
  """
184
  metrics: Dict[str, float] = {}
185
+
186
  # ROUGE scores
187
  rouge_scores = calculate_rouge(predictions, references)
188
  metrics.update({f"rouge_{k}": v for k, v in rouge_scores.items()})
189
+
190
  # BLEU score
191
  metrics["bleu4"] = calculate_bleu(predictions, references)
192
+
193
  # BERTScore (semantic similarity - important for back-cover style descriptions)
194
  if include_bertscore:
195
+ bert_scores = calculate_bertscore(predictions, references, model_type=bertscore_model)
 
 
196
  metrics.update({f"bertscore_{k}": v for k, v in bert_scores.items()})
197
+
198
  return metrics
199
 
200
 
 
244
 
245
  def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
246
  """Compute macro F1: average F1 per class (as in GoEmotions paper).
247
+
248
+ This averages F1 across labels, giving equal weight to each emotion class
249
  regardless of prevalence. Directly comparable to GoEmotions baselines.
250
  """
251
  preds = predictions.float()
252
  gold = targets.float()
253
+
254
  # Per-class TP, FP, FN
255
  tp = (preds * gold).sum(dim=0)
256
  fp = (preds * (1 - gold)).sum(dim=0)
257
  fn = ((1 - preds) * gold).sum(dim=0)
258
+
259
  precision = tp / (tp + fp).clamp(min=1e-8)
260
  recall = tp / (tp + fn).clamp(min=1e-8)
261
  f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
262
+
263
  # Zero out F1 for classes with no support in either predictions or targets
264
  mask = (tp + fp + fn) > 0
265
  if mask.sum() == 0:
 
269
 
270
  def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
271
  """Compute micro F1: aggregate TP/FP/FN across all classes.
272
+
273
  This gives more weight to frequent classes. Useful when class distribution matters.
274
  """
275
  preds = predictions.float()
276
  gold = targets.float()
277
+
278
  tp = (preds * gold).sum()
279
  fp = (preds * (1 - gold)).sum()
280
  fn = ((1 - preds) * gold).sum()
281
+
282
  precision = tp / (tp + fp).clamp(min=1e-8)
283
  recall = tp / (tp + fn).clamp(min=1e-8)
284
  f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
 
291
  class_names: Sequence[str] | None = None,
292
  ) -> Dict[str, Dict[str, float]]:
293
  """Compute per-class precision, recall, F1 for multi-label classification.
294
+
295
  Returns a dict mapping class name/index to its metrics.
296
  """
297
  preds = predictions.float()
298
  gold = targets.float()
299
  num_classes = preds.shape[1]
300
+
301
  tp = (preds * gold).sum(dim=0)
302
  fp = (preds * (1 - gold)).sum(dim=0)
303
  fn = ((1 - preds) * gold).sum(dim=0)
304
+
305
  report: Dict[str, Dict[str, float]] = {}
306
  for i in range(num_classes):
307
  name = class_names[i] if class_names else str(i)
 
323
  thresholds: Sequence[float] | None = None,
324
  ) -> tuple[List[float], float]:
325
  """Tune per-class thresholds on validation set to maximize macro F1.
326
+
327
+ For each class, tries multiple thresholds and selects the one that
328
+ maximizes that class's F1 score. This is standard practice for multi-label
329
  classification (used in the original GoEmotions paper).
330
+
331
  Args:
332
  logits: Raw model logits (batch, num_classes)
333
  targets: Binary target labels (batch, num_classes)
334
  thresholds: Candidate thresholds to try (default: 0.1 to 0.9 by 0.05)
335
+
336
  Returns:
337
  Tuple of (best_thresholds_per_class, resulting_macro_f1)
338
  """
339
  if thresholds is None:
340
  thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
341
+
342
  probs = torch.sigmoid(logits)
343
  num_classes = probs.shape[1]
344
  gold = targets.float()
345
+
346
  best_thresholds: List[float] = []
347
  for c in range(num_classes):
348
  best_f1 = -1.0
 
362
  best_f1 = f1
363
  best_t = t
364
  best_thresholds.append(best_t)
365
+
366
  # Compute resulting macro F1 with tuned thresholds
367
  tuned_preds = torch.zeros_like(probs)
368
  for c in range(num_classes):
369
  tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
370
  macro_f1 = multilabel_macro_f1(tuned_preds, targets)
371
+
372
  return best_thresholds, macro_f1
373
 
374
 
 
382
  seed: int = 42,
383
  ) -> tuple[float, float, float]:
384
  """Compute bootstrap confidence interval for a metric.
385
+
386
  Args:
387
  scores: Per-sample metric values
388
  n_bootstrap: Number of bootstrap resamples
389
  confidence: Confidence level (default 95%)
390
  seed: Random seed for reproducibility
391
+
392
  Returns:
393
  Tuple of (mean, lower_bound, upper_bound)
394
  """
395
  rng = np.random.default_rng(seed)
396
  scores_arr = np.array(scores)
397
  n = len(scores_arr)
398
+
399
  bootstrap_means = []
400
  for _ in range(n_bootstrap):
401
  sample = rng.choice(scores_arr, size=n, replace=True)
402
  bootstrap_means.append(float(np.mean(sample)))
403
+
404
  bootstrap_means.sort()
405
  alpha = 1 - confidence
406
  lower_idx = int(alpha / 2 * n_bootstrap)
407
  upper_idx = int((1 - alpha / 2) * n_bootstrap)
408
+
409
  return (
410
  float(np.mean(scores_arr)),
411
  bootstrap_means[lower_idx],
 
420
  seed: int = 42,
421
  ) -> float:
422
  """Paired bootstrap significance test between two systems.
423
+
424
  Tests if system B is significantly better than system A.
425
+
426
  Args:
427
  scores_a: Per-sample scores from system A
428
  scores_b: Per-sample scores from system B
429
  n_bootstrap: Number of bootstrap iterations
430
  seed: Random seed
431
+
432
  Returns:
433
  p-value (probability that B is not better than A)
434
  """
 
436
  a = np.array(scores_a)
437
  b = np.array(scores_b)
438
  assert len(a) == len(b), "Both score lists must have the same length"
439
+
440
  n = len(a)
441
+
442
  count = 0
443
  for _ in range(n_bootstrap):
444
  idx = rng.choice(n, size=n, replace=True)
445
  diff = float(np.mean(b[idx]) - np.mean(a[idx]))
446
  if diff <= 0:
447
  count += 1
448
+
449
  return count / n_bootstrap
src/training/trainer.py CHANGED
@@ -48,24 +48,24 @@ class TrainerConfig:
48
  validation_max_length: int = 128
49
  label_smoothing: float = 0.1
50
  gradient_accumulation_steps: int = 1
51
-
52
  # LR scheduler
53
  scheduler_type: str = "cosine"
54
  warmup_steps: int = 500
55
-
56
  # Early stopping
57
  early_stopping_patience: int | None = 5
58
-
59
  # Task sampling strategy: "round_robin" or "temperature"
60
  # Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
61
  # alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
62
  task_sampling: str = "temperature"
63
  task_sampling_alpha: float = 0.5
64
-
65
  # Gradient conflict diagnostics
66
  # Compute inter-task gradient cosine similarity every N steps (0 = disabled)
67
  gradient_conflict_frequency: int = 0
68
-
69
  # MLflow
70
  experiment_name: str = "LexiMind"
71
  run_name: str | None = None
@@ -76,13 +76,13 @@ class TrainerConfig:
76
 
77
  class EarlyStopping:
78
  """Stop training when validation loss stops improving."""
79
-
80
  def __init__(self, patience: int = 5, min_delta: float = 0.001):
81
  self.patience = patience
82
  self.min_delta = min_delta
83
  self.counter = 0
84
- self.best_value = float('inf')
85
-
86
  def __call__(self, val_loss: float) -> bool:
87
  """Returns True if training should stop."""
88
  if val_loss < self.best_value - self.min_delta:
@@ -155,7 +155,9 @@ class Trainer:
155
 
156
  pbar = tqdm(
157
  range(start_epoch, self.config.max_epochs + 1),
158
- desc="Training", unit="epoch", file=sys.stderr
 
 
159
  )
160
 
161
  for epoch in pbar:
@@ -178,10 +180,12 @@ class Trainer:
178
 
179
  # Early stopping
180
  if self.early_stopping:
181
- val_loss = val_metrics.get("total_loss", float('inf'))
182
  if self.early_stopping(val_loss):
183
- tqdm.write(f"\nEarly stopping at epoch {epoch} (best loss: {self.early_stopping.best_value:.4f})")
184
-
 
 
185
  break
186
 
187
  # Checkpoint
@@ -190,11 +194,11 @@ class Trainer:
190
 
191
  # Update progress
192
  epoch_time = time.perf_counter() - epoch_start
193
- loss = train_metrics.get('total_loss', 0)
194
  pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})
195
 
196
  total_time = time.perf_counter() - total_start
197
- print(f"\nTraining complete in {total_time/60:.1f} minutes")
198
  return history
199
 
200
  def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
@@ -203,7 +207,9 @@ class Trainer:
203
  self.scheduler = None
204
  return
205
 
206
- steps_per_epoch = max(len(loader) for loader in loaders.values()) // max(1, self.config.gradient_accumulation_steps)
 
 
207
  total_steps = steps_per_epoch * (self.config.max_epochs - start_epoch + 1)
208
  warmup = self.config.warmup_steps
209
 
@@ -238,10 +244,12 @@ class Trainer:
238
  if self.config.task_sampling == "temperature" and len(task_names) > 1:
239
  sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64) # type: ignore[arg-type]
240
  alpha = self.config.task_sampling_alpha
241
- probs = sizes ** alpha
242
  probs = probs / probs.sum()
243
- tqdm.write(f" Temperature sampling (α={alpha}): " +
244
- ", ".join(f"{t}={p:.2%}" for t, p in zip(task_names, probs, strict=True)))
 
 
245
  else:
246
  probs = None
247
 
@@ -253,7 +261,9 @@ class Trainer:
253
  # Select tasks for this step
254
  if probs is not None and train:
255
  # Temperature sampling: sample tasks based on dataset size
256
- selected_tasks = list(np.random.choice(task_names, size=len(task_names), replace=True, p=probs))
 
 
257
  else:
258
  # Round-robin: all tasks every step
259
  selected_tasks = task_names
@@ -288,8 +298,11 @@ class Trainer:
288
  scaled.backward()
289
 
290
  # Gradient conflict diagnostics
291
- if (train and self.config.gradient_conflict_frequency > 0
292
- and (step + 1) % self.config.gradient_conflict_frequency == 0):
 
 
 
293
  conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
294
  for k, v in conflict_stats.items():
295
  metrics[f"grad_{k}"].append(v)
@@ -316,8 +329,10 @@ class Trainer:
316
 
317
  # Average metrics
318
  averaged = {k: sum(v) / len(v) for k, v in metrics.items() if v}
319
- tqdm.write(f"[{phase.lower()}] epoch {epoch}: " +
320
- ", ".join(f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch"))
 
 
321
  return averaged
322
 
323
  def _get_batch(self, iterators: Dict, loader: DataLoader, task: str) -> Dict | None:
@@ -330,8 +345,10 @@ class Trainer:
330
  batch = next(iterators[task])
331
  except StopIteration:
332
  return None
333
- return {k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
334
- for k, v in batch.items()}
 
 
335
 
336
  def _forward_task(self, task: str, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
337
  """Route to task-specific forward pass."""
@@ -360,10 +377,10 @@ class Trainer:
360
  # Decode predictions and references
361
  preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
362
  refs = self._decode_labels(batch["labels"])
363
-
364
  # Calculate comprehensive metrics
365
  metrics = {"rouge_like": rouge_like(preds, refs)}
366
-
367
  # Proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L)
368
  try:
369
  rouge_scores = calculate_rouge(preds, refs)
@@ -372,13 +389,13 @@ class Trainer:
372
  metrics["rougeL"] = rouge_scores["rougeL"]
373
  except Exception:
374
  pass # Fall back to rouge_like only if rouge-score not installed
375
-
376
  # BLEU-4 score
377
  try:
378
  metrics["bleu4"] = calculate_bleu(preds, refs)
379
  except Exception:
380
  pass
381
-
382
  return loss, metrics
383
 
384
  def _forward_emotion(self, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
@@ -423,8 +440,10 @@ class Trainer:
423
  if i >= n:
424
  break
425
 
426
- batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
427
- for k, v in batch.items()}
 
 
428
  src_ids = batch["src_ids"][:1]
429
  src_mask = batch.get("src_mask", None)
430
  if src_mask is not None:
@@ -432,7 +451,9 @@ class Trainer:
432
 
433
  # Generate with anti-repetition
434
  model: Any = self.model
435
- enc_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
 
 
436
  memory = model.encoder(src_ids, mask=enc_mask)
437
  generated = model.decoder.greedy_decode(
438
  memory=memory,
@@ -463,27 +484,27 @@ class Trainer:
463
  iterators: Dict,
464
  ) -> Dict[str, float]:
465
  """Compute inter-task gradient cosine similarity to diagnose conflicts.
466
-
467
  Returns cosine similarity between gradient vectors for each task pair.
468
  Negative values indicate conflicting gradients (negative transfer risk).
469
  """
470
  task_grads: Dict[str, torch.Tensor] = {}
471
-
472
  for task, loader in loaders.items():
473
  self.optimizer.zero_grad()
474
  batch = self._get_batch(iterators, loader, task)
475
  if batch is None:
476
  continue
477
-
478
  dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
479
  with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
480
  loss, _ = self._forward_task(task, batch)
481
-
482
  if torch.isnan(loss):
483
  continue
484
-
485
  loss.backward()
486
-
487
  # Flatten all gradients into a single vector
488
  grad_vec = []
489
  for p in self.model.parameters():
@@ -491,9 +512,9 @@ class Trainer:
491
  grad_vec.append(p.grad.detach().clone().flatten())
492
  if grad_vec:
493
  task_grads[task] = torch.cat(grad_vec)
494
-
495
  self.optimizer.zero_grad()
496
-
497
  # Compute pairwise cosine similarity
498
  stats: Dict[str, float] = {}
499
  tasks = list(task_grads.keys())
@@ -504,20 +525,22 @@ class Trainer:
504
  cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
505
  stats[f"cos_sim_{t1}_{t2}"] = cos_sim
506
  stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
507
-
508
  return stats
509
 
510
  def _log_config(self) -> None:
511
  """Log config to MLflow."""
512
- mlflow.log_params({
513
- "max_epochs": self.config.max_epochs,
514
- "gradient_clip_norm": self.config.gradient_clip_norm,
515
- "label_smoothing": self.config.label_smoothing,
516
- "task_weights": str(self.config.task_weights),
517
- "warmup_steps": self.config.warmup_steps,
518
- "scheduler_type": self.config.scheduler_type,
519
- "learning_rate": self.optimizer.param_groups[0]["lr"],
520
- })
 
 
521
 
522
  def _log_metrics(self, metrics: Dict[str, float], prefix: str, epoch: int) -> None:
523
  """Log metrics to MLflow."""
 
48
  validation_max_length: int = 128
49
  label_smoothing: float = 0.1
50
  gradient_accumulation_steps: int = 1
51
+
52
  # LR scheduler
53
  scheduler_type: str = "cosine"
54
  warmup_steps: int = 500
55
+
56
  # Early stopping
57
  early_stopping_patience: int | None = 5
58
+
59
  # Task sampling strategy: "round_robin" or "temperature"
60
  # Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
61
  # alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
62
  task_sampling: str = "temperature"
63
  task_sampling_alpha: float = 0.5
64
+
65
  # Gradient conflict diagnostics
66
  # Compute inter-task gradient cosine similarity every N steps (0 = disabled)
67
  gradient_conflict_frequency: int = 0
68
+
69
  # MLflow
70
  experiment_name: str = "LexiMind"
71
  run_name: str | None = None
 
76
 
77
  class EarlyStopping:
78
  """Stop training when validation loss stops improving."""
79
+
80
  def __init__(self, patience: int = 5, min_delta: float = 0.001):
81
  self.patience = patience
82
  self.min_delta = min_delta
83
  self.counter = 0
84
+ self.best_value = float("inf")
85
+
86
  def __call__(self, val_loss: float) -> bool:
87
  """Returns True if training should stop."""
88
  if val_loss < self.best_value - self.min_delta:
 
155
 
156
  pbar = tqdm(
157
  range(start_epoch, self.config.max_epochs + 1),
158
+ desc="Training",
159
+ unit="epoch",
160
+ file=sys.stderr,
161
  )
162
 
163
  for epoch in pbar:
 
180
 
181
  # Early stopping
182
  if self.early_stopping:
183
+ val_loss = val_metrics.get("total_loss", float("inf"))
184
  if self.early_stopping(val_loss):
185
+ tqdm.write(
186
+ f"\nEarly stopping at epoch {epoch} (best loss: {self.early_stopping.best_value:.4f})"
187
+ )
188
+
189
  break
190
 
191
  # Checkpoint
 
194
 
195
  # Update progress
196
  epoch_time = time.perf_counter() - epoch_start
197
+ loss = train_metrics.get("total_loss", 0)
198
  pbar.set_postfix({"loss": f"{loss:.3f}", "time": f"{epoch_time:.0f}s"})
199
 
200
  total_time = time.perf_counter() - total_start
201
+ print(f"\nTraining complete in {total_time / 60:.1f} minutes")
202
  return history
203
 
204
  def _setup_scheduler(self, loaders: Dict[str, DataLoader], start_epoch: int) -> None:
 
207
  self.scheduler = None
208
  return
209
 
210
+ steps_per_epoch = max(len(loader) for loader in loaders.values()) // max(
211
+ 1, self.config.gradient_accumulation_steps
212
+ )
213
  total_steps = steps_per_epoch * (self.config.max_epochs - start_epoch + 1)
214
  warmup = self.config.warmup_steps
215
 
 
244
  if self.config.task_sampling == "temperature" and len(task_names) > 1:
245
  sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64) # type: ignore[arg-type]
246
  alpha = self.config.task_sampling_alpha
247
+ probs = sizes**alpha
248
  probs = probs / probs.sum()
249
+ tqdm.write(
250
+ f" Temperature sampling (α={alpha}): "
251
+ + ", ".join(f"{t}={p:.2%}" for t, p in zip(task_names, probs, strict=True))
252
+ )
253
  else:
254
  probs = None
255
 
 
261
  # Select tasks for this step
262
  if probs is not None and train:
263
  # Temperature sampling: sample tasks based on dataset size
264
+ selected_tasks = list(
265
+ np.random.choice(task_names, size=len(task_names), replace=True, p=probs)
266
+ )
267
  else:
268
  # Round-robin: all tasks every step
269
  selected_tasks = task_names
 
298
  scaled.backward()
299
 
300
  # Gradient conflict diagnostics
301
+ if (
302
+ train
303
+ and self.config.gradient_conflict_frequency > 0
304
+ and (step + 1) % self.config.gradient_conflict_frequency == 0
305
+ ):
306
  conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
307
  for k, v in conflict_stats.items():
308
  metrics[f"grad_{k}"].append(v)
 
329
 
330
  # Average metrics
331
  averaged = {k: sum(v) / len(v) for k, v in metrics.items() if v}
332
+ tqdm.write(
333
+ f"[{phase.lower()}] epoch {epoch}: "
334
+ + ", ".join(f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch")
335
+ )
336
  return averaged
337
 
338
  def _get_batch(self, iterators: Dict, loader: DataLoader, task: str) -> Dict | None:
 
345
  batch = next(iterators[task])
346
  except StopIteration:
347
  return None
348
+ return {
349
+ k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
350
+ for k, v in batch.items()
351
+ }
352
 
353
  def _forward_task(self, task: str, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
354
  """Route to task-specific forward pass."""
 
377
  # Decode predictions and references
378
  preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
379
  refs = self._decode_labels(batch["labels"])
380
+
381
  # Calculate comprehensive metrics
382
  metrics = {"rouge_like": rouge_like(preds, refs)}
383
+
384
  # Proper ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L)
385
  try:
386
  rouge_scores = calculate_rouge(preds, refs)
 
389
  metrics["rougeL"] = rouge_scores["rougeL"]
390
  except Exception:
391
  pass # Fall back to rouge_like only if rouge-score not installed
392
+
393
  # BLEU-4 score
394
  try:
395
  metrics["bleu4"] = calculate_bleu(preds, refs)
396
  except Exception:
397
  pass
398
+
399
  return loss, metrics
400
 
401
  def _forward_emotion(self, batch: Dict) -> tuple[torch.Tensor, Dict[str, float]]:
 
440
  if i >= n:
441
  break
442
 
443
+ batch = {
444
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
445
+ for k, v in batch.items()
446
+ }
447
  src_ids = batch["src_ids"][:1]
448
  src_mask = batch.get("src_mask", None)
449
  if src_mask is not None:
 
451
 
452
  # Generate with anti-repetition
453
  model: Any = self.model
454
+ enc_mask = (
455
+ src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
456
+ )
457
  memory = model.encoder(src_ids, mask=enc_mask)
458
  generated = model.decoder.greedy_decode(
459
  memory=memory,
 
484
  iterators: Dict,
485
  ) -> Dict[str, float]:
486
  """Compute inter-task gradient cosine similarity to diagnose conflicts.
487
+
488
  Returns cosine similarity between gradient vectors for each task pair.
489
  Negative values indicate conflicting gradients (negative transfer risk).
490
  """
491
  task_grads: Dict[str, torch.Tensor] = {}
492
+
493
  for task, loader in loaders.items():
494
  self.optimizer.zero_grad()
495
  batch = self._get_batch(iterators, loader, task)
496
  if batch is None:
497
  continue
498
+
499
  dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
500
  with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
501
  loss, _ = self._forward_task(task, batch)
502
+
503
  if torch.isnan(loss):
504
  continue
505
+
506
  loss.backward()
507
+
508
  # Flatten all gradients into a single vector
509
  grad_vec = []
510
  for p in self.model.parameters():
 
512
  grad_vec.append(p.grad.detach().clone().flatten())
513
  if grad_vec:
514
  task_grads[task] = torch.cat(grad_vec)
515
+
516
  self.optimizer.zero_grad()
517
+
518
  # Compute pairwise cosine similarity
519
  stats: Dict[str, float] = {}
520
  tasks = list(task_grads.keys())
 
525
  cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
526
  stats[f"cos_sim_{t1}_{t2}"] = cos_sim
527
  stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
528
+
529
  return stats
530
 
531
  def _log_config(self) -> None:
532
  """Log config to MLflow."""
533
+ mlflow.log_params(
534
+ {
535
+ "max_epochs": self.config.max_epochs,
536
+ "gradient_clip_norm": self.config.gradient_clip_norm,
537
+ "label_smoothing": self.config.label_smoothing,
538
+ "task_weights": str(self.config.task_weights),
539
+ "warmup_steps": self.config.warmup_steps,
540
+ "scheduler_type": self.config.scheduler_type,
541
+ "learning_rate": self.optimizer.param_groups[0]["lr"],
542
+ }
543
+ )
544
 
545
  def _log_metrics(self, metrics: Dict[str, float], prefix: str, epoch: int) -> None:
546
  """Log metrics to MLflow."""
src/utils/__init__.py CHANGED
@@ -14,9 +14,16 @@ from .io import load_state, save_state
14
  from .labels import load_label_metadata, save_label_metadata
15
 
16
  __all__ = [
17
- "save_checkpoint", "load_checkpoint",
18
- "save_state", "load_state",
19
- "LabelMetadata", "load_labels", "save_labels",
20
- "load_label_metadata", "save_label_metadata",
21
- "set_seed", "Config", "load_yaml",
 
 
 
 
 
 
 
22
  ]
 
14
  from .labels import load_label_metadata, save_label_metadata
15
 
16
  __all__ = [
17
+ "save_checkpoint",
18
+ "load_checkpoint",
19
+ "save_state",
20
+ "load_state",
21
+ "LabelMetadata",
22
+ "load_labels",
23
+ "save_labels",
24
+ "load_label_metadata",
25
+ "save_label_metadata",
26
+ "set_seed",
27
+ "Config",
28
+ "load_yaml",
29
  ]
src/utils/core.py CHANGED
@@ -28,7 +28,7 @@ def save_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
28
  """Save model state dict, handling torch.compile artifacts."""
29
  path = Path(path)
30
  path.parent.mkdir(parents=True, exist_ok=True)
31
-
32
  # Strip '_orig_mod.' prefix from compiled models
33
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in model.state_dict().items()}
34
  torch.save(state_dict, path)
@@ -47,7 +47,7 @@ def load_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
47
  @dataclass
48
  class LabelMetadata:
49
  """Container for emotion and topic label vocabularies."""
50
-
51
  emotion: List[str]
52
  topic: List[str]
53
 
@@ -65,16 +65,16 @@ def load_labels(path: str | Path) -> LabelMetadata:
65
  path = Path(path)
66
  if not path.exists():
67
  raise FileNotFoundError(f"Labels not found: {path}")
68
-
69
  with path.open("r", encoding="utf-8") as f:
70
  data = json.load(f)
71
-
72
  emotion = data.get("emotion") or data.get("emotions", [])
73
  topic = data.get("topic") or data.get("topics", [])
74
-
75
  if not emotion or not topic:
76
  raise ValueError("Labels file must contain 'emotion' and 'topic' lists")
77
-
78
  return LabelMetadata(emotion=emotion, topic=topic)
79
 
80
 
@@ -82,7 +82,7 @@ def save_labels(labels: LabelMetadata, path: str | Path) -> None:
82
  """Save label metadata to JSON file."""
83
  path = Path(path)
84
  path.parent.mkdir(parents=True, exist_ok=True)
85
-
86
  with path.open("w", encoding="utf-8") as f:
87
  json.dump({"emotion": labels.emotion, "topic": labels.topic}, f, indent=2)
88
 
@@ -105,12 +105,14 @@ def set_seed(seed: int) -> None:
105
  @dataclass
106
  class Config:
107
  """Simple config wrapper."""
 
108
  data: dict
109
 
110
 
111
  def load_yaml(path: str | Path) -> Config:
112
  """Load YAML configuration file."""
113
  import yaml
 
114
  with Path(path).open("r", encoding="utf-8") as f:
115
  content = yaml.safe_load(f)
116
  if not isinstance(content, dict):
 
28
  """Save model state dict, handling torch.compile artifacts."""
29
  path = Path(path)
30
  path.parent.mkdir(parents=True, exist_ok=True)
31
+
32
  # Strip '_orig_mod.' prefix from compiled models
33
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in model.state_dict().items()}
34
  torch.save(state_dict, path)
 
47
  @dataclass
48
  class LabelMetadata:
49
  """Container for emotion and topic label vocabularies."""
50
+
51
  emotion: List[str]
52
  topic: List[str]
53
 
 
65
  path = Path(path)
66
  if not path.exists():
67
  raise FileNotFoundError(f"Labels not found: {path}")
68
+
69
  with path.open("r", encoding="utf-8") as f:
70
  data = json.load(f)
71
+
72
  emotion = data.get("emotion") or data.get("emotions", [])
73
  topic = data.get("topic") or data.get("topics", [])
74
+
75
  if not emotion or not topic:
76
  raise ValueError("Labels file must contain 'emotion' and 'topic' lists")
77
+
78
  return LabelMetadata(emotion=emotion, topic=topic)
79
 
80
 
 
82
  """Save label metadata to JSON file."""
83
  path = Path(path)
84
  path.parent.mkdir(parents=True, exist_ok=True)
85
+
86
  with path.open("w", encoding="utf-8") as f:
87
  json.dump({"emotion": labels.emotion, "topic": labels.topic}, f, indent=2)
88
 
 
105
  @dataclass
106
  class Config:
107
  """Simple config wrapper."""
108
+
109
  data: dict
110
 
111
 
112
  def load_yaml(path: str | Path) -> Config:
113
  """Load YAML configuration file."""
114
  import yaml
115
+
116
  with Path(path).open("r", encoding="utf-8") as f:
117
  content = yaml.safe_load(f)
118
  if not isinstance(content, dict):
tests/test_training/test_trainer.py CHANGED
@@ -111,8 +111,9 @@ class TestGradientFlow(unittest.TestCase):
111
  loss = nn.CrossEntropyLoss()(logits, batch["labels"])
112
  loss.backward()
113
 
114
- has_grads = any(p.grad is not None and p.grad.abs().sum() > 0
115
- for p in self.model.parameters())
 
116
  self.assertTrue(has_grads, "No gradients found")
117
 
118
  def test_emotion_gradients(self):
@@ -130,8 +131,9 @@ class TestGradientFlow(unittest.TestCase):
130
  loss = nn.BCEWithLogitsLoss()(logits, batch["labels"])
131
  loss.backward()
132
 
133
- has_grads = any(p.grad is not None and p.grad.abs().sum() > 0
134
- for p in self.model.parameters())
 
135
  self.assertTrue(has_grads, "No gradients found")
136
 
137
  def test_summarization_gradients(self):
@@ -145,14 +147,12 @@ class TestGradientFlow(unittest.TestCase):
145
  self.model.zero_grad()
146
  logits = self.model.forward("summarization", batch)
147
  # Flatten for cross entropy: (B*T, vocab) vs (B*T,)
148
- loss = nn.CrossEntropyLoss()(
149
- logits.view(-1, 100),
150
- batch["labels"].view(-1)
151
- )
152
  loss.backward()
153
 
154
- has_grads = any(p.grad is not None and p.grad.abs().sum() > 0
155
- for p in self.model.parameters())
 
156
  self.assertTrue(has_grads, "No gradients found")
157
 
158
 
 
111
  loss = nn.CrossEntropyLoss()(logits, batch["labels"])
112
  loss.backward()
113
 
114
+ has_grads = any(
115
+ p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters()
116
+ )
117
  self.assertTrue(has_grads, "No gradients found")
118
 
119
  def test_emotion_gradients(self):
 
131
  loss = nn.BCEWithLogitsLoss()(logits, batch["labels"])
132
  loss.backward()
133
 
134
+ has_grads = any(
135
+ p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters()
136
+ )
137
  self.assertTrue(has_grads, "No gradients found")
138
 
139
  def test_summarization_gradients(self):
 
147
  self.model.zero_grad()
148
  logits = self.model.forward("summarization", batch)
149
  # Flatten for cross entropy: (B*T, vocab) vs (B*T,)
150
+ loss = nn.CrossEntropyLoss()(logits.view(-1, 100), batch["labels"].view(-1))
 
 
 
151
  loss.backward()
152
 
153
+ has_grads = any(
154
+ p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters()
155
+ )
156
  self.assertTrue(has_grads, "No gradients found")
157
 
158