OliverPerrin committed
Commit 0d858b5 · 1 Parent(s): 35856cb

Added many new improvements based on feedback from others
configs/training/dev.yaml CHANGED
@@ -37,6 +37,9 @@ trainer:
   max_val_samples: 300
   early_stopping_patience: 5
   log_grad_norm_frequency: 100
+  task_sampling: round_robin
+  task_sampling_alpha: 0.5
+  gradient_conflict_frequency: 0
 
   # Enable compile for speed (worth the startup cost)
   compile_encoder: true
configs/training/full.yaml CHANGED
@@ -38,6 +38,13 @@ trainer:
   max_val_samples: 3000 # Enough for stable metrics
   early_stopping_patience: 3 # Stop quickly when plateauing
   log_grad_norm_frequency: 200
+  # Task sampling: "round_robin" (default) or "temperature"
+  # Temperature sampling: p_i proportional to n_i^alpha, reduces dominance of large tasks
+  task_sampling: round_robin
+  task_sampling_alpha: 0.5
+  # Gradient conflict diagnostics: compute inter-task gradient cosine similarity
+  # every N steps (0 = disabled). Helps diagnose negative transfer.
+  gradient_conflict_frequency: 0
 
   compile_encoder: true
   compile_decoder: true
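The `gradient_conflict_frequency` key above enables the diagnostic this commit describes: inter-task gradient cosine similarity. A minimal sketch of that computation, assuming each task's gradients have already been flattened into a single vector (function and variable names are illustrative, not the repository's API):

```python
import numpy as np

def gradient_conflict(grads: dict) -> dict:
    """Cosine similarity between each pair of flattened task gradient vectors.

    Values near -1 indicate conflicting update directions (a negative-transfer
    signal); values near +1 indicate aligned tasks.
    """
    out = {}
    names = sorted(grads)
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            ga, gb = grads[a].ravel(), grads[b].ravel()
            denom = np.linalg.norm(ga) * np.linalg.norm(gb)
            out[(a, b)] = float(ga @ gb / denom) if denom > 0 else 0.0
    return out

# Toy example: summarization and topic gradients aligned, emotion opposed.
g = {
    "summarization": np.array([1.0, 1.0]),
    "topic": np.array([2.0, 2.0]),
    "emotion": np.array([-1.0, -1.0]),
}
sims = gradient_conflict(g)
```

In a trainer loop this would run every `gradient_conflict_frequency` steps, logging the pairwise similarities rather than acting on them.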
configs/training/medium.yaml CHANGED
@@ -37,6 +37,9 @@ trainer:
   max_val_samples: 2500
   early_stopping_patience: 3 # More patience
   log_grad_norm_frequency: 100
+  task_sampling: round_robin
+  task_sampling_alpha: 0.5
+  gradient_conflict_frequency: 0
 
   compile_encoder: true
   compile_decoder: true
docs/research_paper.tex CHANGED
@@ -44,7 +44,7 @@ Email: perrinot@appstate.edu}}
44
  \maketitle
45
 
46
  \begin{abstract}
47
- Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies comparing single-task specialists against multi-task configurations, we find that: (1) MTL provides a +3.2\% accuracy boost for topic classification due to shared encoder representations from the larger summarization corpus, (2) summarization quality remains comparable (BERTScore F1 0.83 vs. 0.82 single-task), and (3) emotion detection suffers negative transfer ($-$0.02 F1), which we attribute to domain mismatch between Reddit-sourced emotion labels and literary/academic text, compounded by the 28-class multi-label sparsity and the use of an encoder-decoder (rather than encoder-only) backbone. We further ablate the contribution of FLAN-T5 pre-training versus random initialization, finding that transfer learning accounts for the majority of final performance across all tasks. Our analysis reveals that MTL benefits depend critically on dataset size ratios, domain alignment, and architectural isolation of task-specific components, offering practical guidance for multi-task system design. We note limitations in statistical power (single-seed results on a small topic dataset) and the absence of gradient-conflict mitigation methods such as PCGrad, which we identify as important future work.
48
  \end{abstract}
49
 
50
  \begin{IEEEkeywords}
@@ -76,7 +76,7 @@ To answer these questions, we construct \textbf{LexiMind}, a multi-task system b
76
  \item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
77
  \end{itemize}
78
 
79
- We acknowledge important limitations: our results are from single-seed runs, we do not explore gradient-conflict mitigation methods (PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}), and our emotion evaluation conflates domain mismatch with multi-label threshold and architecture choices. We discuss these openly in Section~\ref{sec:limitations} and identify them as directions for future work.
80
 
81
  %=============================================================================
82
  \section{Related Work}
@@ -86,7 +86,9 @@ We acknowledge important limitations: our results are from single-seed runs, we
86
 
87
  Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
88
 
89
- \textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach. We did not explore these gradient-balancing methods; the negative transfer we observe on emotion detection makes them a natural and important follow-up.
90
 
91
  \textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
92
 
@@ -136,7 +138,9 @@ Emotion (28 labels) & GoEmotions (Reddit) & 43,410 & 5,426 & 5,427 \\
136
  \end{tabular}
137
  \end{table}
138
 
139
- \textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed across the literary and academic subsets, as they are drawn from disjoint sources. We note that the academic subset is substantially larger ($\sim$45K vs. $\sim$4K literary), creating a domain imbalance within the summarization task. Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects, 20 Newsgroups categories) and mapped to our 7-class taxonomy; no manual annotation was performed. GoEmotions is used as-is from the HuggingFace datasets hub.
140
 
141
  \textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
142
 
@@ -154,12 +158,12 @@ LexiMind uses FLAN-T5-base (272M parameters) as the backbone, with a custom reim
154
 
155
  Task-specific heads branch from the shared encoder:
156
  \begin{itemize}
157
- \item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing)
158
  \item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
159
- \item \textbf{Emotion}: Linear classifier on mean-pooled encoder hidden states with sigmoid activation (binary cross-entropy loss)
160
  \end{itemize}
161
 
162
- \textbf{Architectural note.} Using mean-pooled encoder states for classification in an encoder-decoder model is a pragmatic choice for parameter sharing, but may be suboptimal compared to encoder-only architectures (BERT, RoBERTa) where the encoder is fully dedicated to producing classification-ready representations. We discuss this trade-off in Section~\ref{sec:emotion_analysis}.
163
 
164
  \subsection{Training Configuration}
165
 
@@ -175,10 +179,14 @@ All experiments use consistent hyperparameters unless otherwise noted:
175
  \item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
176
  \end{itemize}
177
 
178
- \textbf{Task scheduling.} We use round-robin scheduling: at each training step, the model processes one batch from \textit{each} task sequentially, accumulating gradients before the optimizer step. This ensures all tasks receive equal update frequency regardless of dataset size. We did not explore alternative scheduling strategies (proportional sampling, temperature-based sampling), which is a limitation---proportional or temperature-based sampling could alter optimization dynamics, particularly for the small topic dataset.
179
 
180
  \textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
181
 
182
  \subsection{Baselines and Ablations}
183
 
184
  We compare four configurations:
@@ -195,12 +203,12 @@ We additionally ablate FLAN-T5 initialization vs. random initialization to isola
195
  \subsection{Evaluation Metrics}
196
 
197
  \begin{itemize}
198
- \item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BERTScore F1 \cite{zhang2019bertscore} using RoBERTa-large (semantic similarity). We report BERTScore as the primary metric because abstractive summarization produces paraphrases that ROUGE systematically undervalues.
199
  \item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
200
- \item \textbf{Emotion}: Sample-averaged F1 (computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples). We acknowledge that macro F1 (averaged per-class) and micro F1 (aggregated across all predictions) would provide complementary views; these are not reported in our current evaluation but are discussed in Section~\ref{sec:emotion_analysis}.
201
  \end{itemize}
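The sample-averaged F1 defined in the emotion item above can be computed directly; a minimal sketch (not the repository's evaluation code, and treating samples with no true or predicted labels as F1 = 0, a convention on which implementations differ):

```python
import numpy as np

def sample_averaged_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Per-sample F1 (harmonic mean of per-sample precision and recall),
    averaged across samples.

    y_true, y_pred: binary integer arrays of shape (n_samples, n_labels).
    """
    scores = []
    for t, p in zip(y_true, y_pred):
        tp = np.sum(t & p)
        denom = np.sum(t) + np.sum(p)  # 2*tp / (|true| + |pred|) == harmonic mean
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))

# Two samples over 4 labels: a perfect match, then 1 of 2 true labels recovered.
t = np.array([[1, 0, 1, 0], [1, 1, 0, 0]])
p = np.array([[1, 0, 1, 0], [1, 0, 0, 1]])
```

This matches scikit-learn's `average="samples"` behavior for multi-label F1, which could serve as a cross-check.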
202
 
203
- \textbf{Statistical note.} All results are from single training runs. We do not report confidence intervals or variance across seeds. Given the small topic dataset (189 validation samples), the observed +3.2\% accuracy improvement could be within random variance. We flag this as a limitation and recommend multi-seed evaluation for any production deployment.
204
 
205
  %=============================================================================
206
  \section{Results}
@@ -234,11 +242,11 @@ Emotion & Sample-avg F1 & \textbf{0.218} & 0.199 \\
234
  \textbf{Key finding}: MTL provides heterogeneous effects across tasks:
235
 
236
  \begin{itemize}
237
- \item \textbf{Topic classification gains +3.2\% accuracy} from MTL. The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). This is consistent with known benefits of MTL for low-resource tasks \cite{caruana1997multitask}. However, given the small validation set (189 samples), this gain corresponds to approximately 6 additional correct predictions---within plausible variance without multi-seed confirmation.
238
 
239
- \item \textbf{Summarization shows modest improvement} (+0.009 BERTScore F1). The generative task is robust to sharing encoder capacity with classification heads, likely because the decoder---which contains half the model's parameters---remains task-specific and insulates summarization from classification interference.
240
 
241
- \item \textbf{Emotion detection degrades by $-$0.019 F1}. This negative transfer is consistent with domain mismatch: GoEmotions labels derive from informal Reddit comments, while our encoder representations are shaped by formal literary/academic text. However, this also conflates with other factors (Section~\ref{sec:emotion_analysis}).
242
  \end{itemize}
243
 
244
  \subsection{Baseline Comparisons}
@@ -323,11 +331,11 @@ Our emotion sample-averaged F1 (0.20) is substantially lower than reported GoEmo
323
  \begin{enumerate}
324
  \item \textbf{Domain shift}: GoEmotions labels were annotated on Reddit comments in conversational register. Our encoder is shaped by literary and academic text through the summarization objective, producing representations optimized for formal text. This domain mismatch is likely the largest factor, but we cannot isolate it without a controlled experiment (e.g., fine-tuning BERT on GoEmotions with our frozen encoder vs. BERT's own encoder).
325
 
326
- \item \textbf{Label sparsity and class imbalance}: The 28-class multi-label scheme creates extreme imbalance. Rare emotions (grief, remorse, nervousness) appear in $<$2\% of samples. We use a fixed prediction threshold of 0.3 (during training evaluation), without per-class threshold tuning on the validation set---a simplification that the original GoEmotions work \cite{demszky2020goemotions} explicitly optimizes. Per-class threshold tuning could meaningfully improve results.
327
 
328
- \item \textbf{Architecture mismatch}: Published GoEmotions baselines use encoder-only models (BERT-base), where the full model capacity is dedicated to producing classification-ready representations. Our encoder-decoder architecture optimizes the encoder primarily for producing representations that the decoder can use for summarization---classification heads receive these representations secondarily. The mean-pooling strategy may also be suboptimal; alternatives such as [CLS] token pooling, attention-weighted pooling, or adapter layers \cite{houlsby2019parameter} could yield better classification features.
329
 
330
- \item \textbf{Metric reporting}: We report sample-averaged F1 (per-sample, then averaged), which is not directly comparable to macro F1 (per-class, then averaged) as reported in the original GoEmotions work. Reporting macro F1, micro F1, and per-label performance would provide a more complete picture. We identify this as a gap in our current evaluation.
331
  \end{enumerate}
332
 
333
  \textbf{Implication}: Off-the-shelf emotion datasets from social media should not be naively combined with literary/academic tasks in MTL. Domain-specific emotion annotation or domain adaptation techniques are needed for formal text domains.
@@ -366,9 +374,9 @@ Our results support nuanced, task-dependent guidance:
366
 
367
  \subsection{Comparison to MTL Literature}
368
 
369
- Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---we observe this in the contrast between topic (positive transfer) and emotion (negative transfer). Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our round-robin scheduling with fixed weights does not address such conflicts, and methods like PCGrad could potentially mitigate the emotion degradation by projecting away conflicting gradient components. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks in extreme multi-task settings; our small-scale results are consistent with this pattern.
370
 
371
- A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. This architectural dynamic deserves further study.
372
 
373
  \subsection{Implications for Practitioners}
374
 
@@ -379,9 +387,13 @@ Based on our findings:
379
 
380
  \item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
381
 
382
- \item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation.
383
 
384
- \item \textbf{Validate with multiple seeds} before drawing conclusions from MTL comparisons, especially with small validation sets.
385
  \end{enumerate}
386
 
387
  \subsection{Limitations}
@@ -390,35 +402,39 @@ Based on our findings:
390
  We identify several limitations that constrain the generalizability of our findings:
391
 
392
  \begin{itemize}
393
- \item \textbf{Single-seed results}: All experiments are single runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. Multi-seed evaluation with confidence intervals is needed to confirm the direction and magnitude of transfer effects.
394
 
395
- \item \textbf{No gradient-conflict mitigation}: We use fixed loss weights and do not explore PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods are directly relevant to our observed negative transfer on emotion detection and could potentially convert it to positive or neutral transfer.
396
 
397
  \item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results.
398
 
399
- \item \textbf{Emotion evaluation gaps}: We report sample-averaged F1 with a fixed threshold (0.3). Per-class thresholds tuned on validation, per-label metrics, focal loss for class imbalance \cite{lin2017focal}, and calibration analysis would provide more informative evaluation. The conclusion that ``domain mismatch is the primary cause'' of low emotion F1 is plausible but confounded by these design choices.
400
 
401
  \item \textbf{No human evaluation}: ROUGE and BERTScore are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.
402
 
403
  \item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
404
 
405
- \item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text. This imbalance is not analyzed separately but could affect literary summarization quality.
406
  \end{itemize}
407
 
408
  \subsection{Future Work}
409
 
410
  \begin{itemize}
411
- \item \textbf{Gradient-conflict mitigation}: Applying PCGrad or CAGrad to test whether emotion negative transfer can be reduced or eliminated. This is the most directly actionable follow-up given our current findings.
412
 
413
- \item \textbf{Parameter-efficient multi-tasking}: Using per-task LoRA adapters \cite{hu2022lora} or adapter layers \cite{houlsby2019parameter} to provide task-specific specialization while maintaining shared encoder representations. This could reduce interference between tasks with misaligned domains.
414
 
415
  \item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
416
 
417
- \item \textbf{Multi-seed evaluation}: Running at least 3--5 seeds per configuration to establish statistical significance of observed transfer effects.
418
 
419
  \item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
420
 
421
- \item \textbf{Improved emotion evaluation}: Per-class threshold tuning, macro/micro F1, class-level analysis, and focal loss to address class imbalance.
422
  \end{itemize}
423
 
424
  %=============================================================================
@@ -427,7 +443,9 @@ We identify several limitations that constrain the generalizability of our findi
427
 
428
  We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our ablation studies reveal heterogeneous transfer effects: topic classification benefits from shared representations with the larger summarization corpus (+3.2\% accuracy), while emotion detection suffers negative transfer ($-$0.02 F1) due to domain mismatch with Reddit-sourced labels. Summarization quality is robust to multi-task training, insulated by its task-specific decoder.
429
 
430
- Pre-trained initialization (FLAN-T5) is essential for competitive performance across all tasks, with fine-tuning providing necessary domain adaptation. These findings are consistent with the broader MTL literature on the importance of task compatibility and domain alignment. However, we emphasize the limitations of our single-seed evaluation design and the absence of gradient-conflict mitigation methods, which could alter the negative transfer findings. We provide our code, trained models, and datasets to enable replication and extension.
431
 
432
  Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
433
  Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
@@ -513,6 +531,21 @@ N. Houlsby et al., ``Parameter-efficient transfer learning for NLP,'' in \textit
513
  \bibitem{lin2017focal}
514
  T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
515
 
516
  \end{thebibliography}
517
 
518
  \end{document}
 
44
  \maketitle
45
 
46
  \begin{abstract}
47
+ Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies comparing single-task specialists against multi-task configurations, we find that: (1) MTL provides a +3.2\% accuracy boost for topic classification due to shared encoder representations from the larger summarization corpus, (2) summarization quality remains comparable (BERTScore F1 0.83 vs. 0.82 single-task), and (3) emotion detection suffers negative transfer ($-$0.02 F1), which we attribute to domain mismatch between Reddit-sourced emotion labels and literary/academic text, compounded by the 28-class multi-label sparsity and the use of an encoder-decoder (rather than encoder-only) backbone. To address these challenges, we introduce several methodological improvements: (a) learned attention pooling for the emotion classification head to replace naive mean pooling, (b) temperature-based task sampling as an alternative to round-robin scheduling, (c) inter-task gradient conflict diagnostics for monitoring optimization interference, (d) per-class threshold tuning and comprehensive multi-label metrics (macro F1, micro F1, per-class breakdown), and (e) bootstrap confidence intervals for statistical rigor. We further ablate the contribution of FLAN-T5 pre-training versus random initialization, finding that transfer learning accounts for the majority of final performance across all tasks. Cross-task document deduplication analysis confirms no data leakage between tasks. Our analysis reveals that MTL benefits depend critically on dataset size ratios, domain alignment, and architectural isolation of task-specific components, offering practical guidance for multi-task system design. Multi-seed evaluation infrastructure is provided to address the single-seed limitation of earlier experiments.
48
  \end{abstract}
49
 
50
  \begin{IEEEkeywords}
 
76
  \item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
77
  \end{itemize}
78
 
79
+ We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We do not yet apply gradient-conflict mitigation methods (PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}), but introduce gradient conflict diagnostics to characterize task interference. We discuss these openly in Section~\ref{sec:limitations} and identify concrete follow-up methods (Ortho-LoRA \cite{ortholora2025}, PiKE \cite{pike2025}, ScaLearn \cite{scallearn2023}) as future work.
80
 
81
  %=============================================================================
82
  \section{Related Work}
 
86
 
87
  Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
88
 
89
+ \textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach---but includes gradient conflict diagnostics (inter-task cosine similarity monitoring) to characterize optimization interference. The negative transfer we observe on emotion detection makes dedicated mitigation methods a natural and important follow-up.
90
+
91
+ \textbf{Recent advances in multi-task optimization.} Several recent methods address task interference more precisely. Ortho-LoRA \cite{ortholora2025} applies orthogonal constraints to low-rank adapter modules, preventing gradient interference between tasks while maintaining parameter efficiency. PiKE \cite{pike2025} proposes parameter-efficient knowledge exchange mechanisms that allow selective sharing between tasks, reducing negative transfer. ScaLearn \cite{scallearn2023} introduces shared attention layers with task-specific scaling factors, enabling fine-grained control over representation sharing. Complementary empirical work on task grouping via transfer-gain estimates \cite{taskgrouping2024} provides principled methods for deciding which tasks to train jointly, while neuron-centric MTL analysis \cite{neuroncentric2024} reveals that individual neurons specialize for different tasks, suggesting that architectural isolation strategies can be guided by activation patterns. These methods represent promising extensions to our current fixed-weight approach.
92
 
93
  \textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
94
 
 
138
  \end{tabular}
139
  \end{table}
140
 
141
+ \textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed \textit{across} the literary and academic subsets, as they are drawn from disjoint sources. We note that the academic subset is substantially larger ($\sim$45K vs. $\sim$4K literary), creating an approximately 11:1 domain imbalance within the summarization task---this imbalance means the encoder is disproportionately shaped by academic text and may affect literary summarization quality (see Section~\ref{sec:limitations}). Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects, 20 Newsgroups categories) and mapped to our 7-class taxonomy; no manual annotation was performed, which introduces potential noise from metadata inaccuracies (e.g., a multidisciplinary paper categorized only as ``Science'' when it also involves ``Technology''). GoEmotions is used as-is from the HuggingFace datasets hub.
142
+
143
+ \textbf{Cross-task deduplication.} Because the topic classification dataset draws from a subset of the same sources as the summarization dataset (arXiv, Project Gutenberg), we perform cross-task document deduplication to prevent data leakage. Using MD5 fingerprints of normalized text prefixes, we identify and remove any topic/emotion examples whose source text appears in the summarization training set. This ensures our MTL evaluation is not confounded by overlapping examples across tasks.
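The fingerprinting step described in this added paragraph can be sketched as follows; the prefix length and normalization scheme here are assumptions for illustration, not the repository's exact implementation:

```python
import hashlib

def fingerprint(text: str, prefix_chars: int = 2000) -> str:
    """MD5 of a normalized text prefix: lowercased, whitespace collapsed."""
    normalized = " ".join(text.lower().split())[:prefix_chars]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

def dedup_against(summ_texts, other_examples, key=lambda ex: ex["text"]):
    """Drop examples whose source text already appears in the summarization set."""
    seen = {fingerprint(t) for t in summ_texts}
    return [ex for ex in other_examples if fingerprint(key(ex)) not in seen]

# Toy data: the first topic example duplicates a summarization source text.
summ = ["A Tale of Two Cities full text ...", "We propose a transformer ..."]
topic = [
    {"text": "a  tale of two cities full text ...", "label": "fiction"},
    {"text": "An unrelated document.", "label": "science"},
]
kept = dedup_against(summ, topic)
```

Hashing a normalized prefix rather than the full document keeps the pass cheap while still catching near-verbatim overlaps that differ only in casing or whitespace.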
144
 
145
  \textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
146
 
 
158
 
159
  Task-specific heads branch from the shared encoder:
160
  \begin{itemize}
161
+ \item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing $\epsilon=0.1$, greedy decoding with max length 512 tokens)
162
  \item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
163
+ \item \textbf{Emotion}: Linear classifier on \textit{attention-pooled} encoder hidden states with sigmoid activation (binary cross-entropy loss). Instead of naive mean pooling, a learned attention query computes a weighted average over encoder positions: $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$ where $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ and $\mathbf{q} \in \mathbb{R}^d$ is a trainable query vector. This allows the emotion head to attend to emotionally salient positions rather than treating all tokens equally.
164
  \end{itemize}
165
 
166
+ \textbf{Architectural note.} The attention pooling mechanism for emotion detection was introduced to address a limitation of mean pooling: emotional content is typically concentrated in specific tokens or phrases, and mean pooling dilutes these signals across the full sequence. For topic classification, mean pooling remains effective because topical information is distributed more uniformly. We discuss the trade-offs of performing classification in encoder-decoder models in Section~\ref{sec:emotion_analysis}.
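The pooling computation can be sketched in plain Python; the repository presumably operates on batched PyTorch tensors, so this scalar version only illustrates the math of $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ and $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$.

```python
import math


def attention_pool(hidden_states, query):
    """Learned-query attention pooling over encoder positions (sketch).

    hidden_states: list of d-dimensional vectors, one per token position
    query: d-dimensional trainable query vector
    Returns the attention-weighted average of the hidden states.
    """
    d = len(query)
    # Scaled dot-product score between the query and each position.
    scores = [sum(q * h for q, h in zip(query, hs)) / math.sqrt(d)
              for hs in hidden_states]
    # Softmax over positions (numerically stabilized).
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    # Weighted average: h = sum_i alpha_i * h_i
    return [sum(w * hs[j] for w, hs in zip(weights, hidden_states))
            for j in range(d)]
```

With a zero query the weights reduce to uniform averaging, i.e., mean pooling is the special case this head can learn to deviate from.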
167
 
168
  \subsection{Training Configuration}
169
 
 
179
  \item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
180
  \end{itemize}
181
 
182
+ \textbf{Task scheduling.} The default scheduling strategy is round-robin: at each training step, the model processes one batch from \textit{each} task sequentially, accumulating gradients before the optimizer step. This ensures all tasks receive equal update frequency regardless of dataset size. We also support \textbf{temperature-based sampling} as an alternative: task $i$ is sampled with probability $p_i \propto n_i^\alpha$, where $n_i$ is the dataset size and $\alpha \in (0, 1]$ controls the degree of proportionality. With $\alpha=0.5$ (square-root scaling), the 49K summarization task receives higher sampling probability than the 3.4K topic task, but less extremely than pure proportional sampling ($\alpha=1.0$). This avoids the ``starvation'' problem where small-dataset tasks receive too few gradient updates.
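A minimal sketch of the sampling-probability computation, using the dataset sizes quoted above (function name illustrative, not the repository's API):

```python
def task_sampling_probs(dataset_sizes, alpha=0.5):
    """Temperature-based task sampling: p_i proportional to n_i ** alpha.

    alpha=1.0 recovers pure proportional sampling; smaller alpha flattens
    the distribution toward uniform, protecting small-dataset tasks.
    """
    weights = {task: n ** alpha for task, n in dataset_sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}
```

For example, with 49K summarization and 3.4K topic examples, square-root scaling ($\alpha=0.5$) gives the topic task roughly a 21\% share versus about 6\% under proportional sampling.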
183
 
184
  \textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
185
 
186
+ \textbf{Gradient conflict monitoring.} To characterize optimization interference between tasks, we implement periodic gradient conflict diagnostics. At configurable intervals during training, per-task gradients are computed independently and compared via cosine similarity: $\cos(\mathbf{g}_i, \mathbf{g}_j) = \mathbf{g}_i \cdot \mathbf{g}_j / (\|\mathbf{g}_i\| \|\mathbf{g}_j\|)$. Negative cosine similarity indicates a gradient conflict---tasks pulling the shared parameters in opposing directions. Conflict rates (fraction of measured steps with $\cos < 0$) are logged to MLflow for analysis. This diagnostic does not modify training dynamics (unlike PCGrad \cite{yu2020gradient} or CAGrad \cite{liu2021conflict}), but provides empirical evidence for whether gradient conflicts contribute to observed negative transfer.
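The diagnostic itself reduces to a cosine between flattened per-task gradient vectors; a minimal pure-Python sketch (names illustrative):

```python
import math


def grad_cosine(g_i, g_j):
    """Cosine similarity between two flattened gradient vectors."""
    dot = sum(a * b for a, b in zip(g_i, g_j))
    norm_i = math.sqrt(sum(a * a for a in g_i))
    norm_j = math.sqrt(sum(b * b for b in g_j))
    if norm_i == 0.0 or norm_j == 0.0:
        return 0.0
    return dot / (norm_i * norm_j)


def conflict_rate(gradient_pairs):
    """Fraction of measured pairs with negative cosine, i.e., in conflict."""
    cosines = [grad_cosine(gi, gj) for gi, gj in gradient_pairs]
    return sum(c < 0.0 for c in cosines) / len(cosines)
```

In practice the per-task gradients would be gathered from the shared encoder parameters only, since task-specific heads cannot conflict with each other.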
187
+
188
+ \textbf{Early stopping.} Training stops based on the combined weighted validation loss (using the same task weights as training) with a patience of 3 epochs; the best checkpoint is the one with minimum combined validation loss.
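A sketch of the patience logic, assuming the same fixed task weights as training (class and method names are illustrative):

```python
class EarlyStopper:
    """Patience-based early stopping on combined weighted validation loss."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, task_losses, weights):
        """Record one epoch's validation losses; return True to stop training."""
        combined = sum(weights[t] * loss for t, loss in task_losses.items())
        if combined < self.best:
            self.best = combined
            self.bad_epochs = 0  # a new best checkpoint would be saved here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

With weights \{summarization: 1.0, emotion: 1.0, topic: 0.3\}, three consecutive epochs without improvement in the combined loss terminate training.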
189
+
190
  \subsection{Baselines and Ablations}
191
 
192
  We compare four configurations:
 
203
  \subsection{Evaluation Metrics}
204
 
205
  \begin{itemize}
206
+ \item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BERTScore F1 \cite{zhang2019bertscore} using RoBERTa-large (semantic similarity). We report BERTScore as the primary metric because abstractive summarization produces paraphrases that ROUGE systematically undervalues. Per-domain breakdown (literary vs. academic) is provided to analyze domain-specific quality.
207
  \item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
208
+ \item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
209
  \end{itemize}
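The per-class threshold sweep can be sketched as follows. The repository exposes this via \texttt{tune\_per\_class\_thresholds} in \texttt{src.training.metrics}; the pure-Python version below illustrates the idea and its exact signature may differ.

```python
def f1_at_threshold(probs, labels, tau):
    """Binary F1 for one class at sigmoid threshold tau."""
    preds = [p >= tau for p in probs]
    tp = sum(p and y for p, y in zip(preds, labels))
    fp = sum(p and not y for p, y in zip(preds, labels))
    fn = sum((not p) and y for p, y in zip(preds, labels))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def tune_threshold(probs, labels,
                   grid=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)):
    """Pick the grid threshold maximizing validation F1 for one class."""
    return max(grid, key=lambda tau: f1_at_threshold(probs, labels, tau))
```

Run independently per class on validation data, this yields 28 thresholds; rare classes typically select lower thresholds than frequent ones.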
210
 
211
+ \textbf{Statistical rigor.} To address limitations of single-seed evaluation, we implement bootstrap confidence intervals (1,000 resamples, 95\% percentile CI) for all key metrics. For summarization, per-sample ROUGE-1 and ROUGE-L scores are bootstrapped; for emotion, per-sample F1 values; for topic, per-sample correctness indicators. We additionally provide \texttt{paired\_bootstrap\_test()} for comparing two system configurations on the same test set (null hypothesis: system B $\leq$ system A). Multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) automates training across $k$ seeds and reports mean $\pm$ standard deviation across runs, enabling variance-aware claims. Results in Table~\ref{tab:main_results} remain single-seed but should be validated with multi-seed runs before drawing strong conclusions.
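A minimal sketch of the percentile bootstrap over per-sample scores (the repository's \texttt{bootstrap\_confidence\_interval} in \texttt{src.training.metrics} may differ in interface):

```python
import random


def bootstrap_ci(per_sample_scores, n_resamples=1000, level=0.95, seed=0):
    """Percentile bootstrap CI for the mean of per-sample metric scores."""
    rng = random.Random(seed)
    n = len(per_sample_scores)
    means = []
    for _ in range(n_resamples):
        # Resample n scores with replacement and record the resample mean.
        resample = [per_sample_scores[rng.randrange(n)] for _ in range(n)]
        means.append(sum(resample) / n)
    means.sort()
    lo_idx = int((1 - level) / 2 * n_resamples)
    hi_idx = int((1 + level) / 2 * n_resamples) - 1
    return means[lo_idx], means[hi_idx]
```

The same resampling machinery underlies the paired test: resample indices jointly for both systems and count how often the score difference favors system B.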
212
 
213
  %=============================================================================
214
  \section{Results}
 
242
  \textbf{Key finding}: MTL provides heterogeneous effects across tasks:
243
 
244
  \begin{itemize}
245
+ \item \textbf{Topic classification gains +3.2\% accuracy} from MTL. The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). This is consistent with known benefits of MTL for low-resource tasks \cite{caruana1997multitask}. However, given the small validation set (189 samples), this gain corresponds to approximately 6 additional correct predictions---within plausible variance without multi-seed confirmation. Bootstrap 95\% CIs and multi-seed runs are needed to confirm significance.
246
 
247
+ \item \textbf{Summarization shows modest improvement} (+0.009 BERTScore F1). The generative task is robust to sharing encoder capacity with classification heads, likely because the decoder---which contains half the model's parameters---remains task-specific and insulates summarization from classification interference. Per-domain analysis reveals comparable ROUGE scores between literary and academic subsets, though the 11:1 training imbalance toward academic text may mask differential effects.
248
 
249
+ \item \textbf{Emotion detection degrades by $-$0.019 sample-avg F1}. We additionally report macro F1 and micro F1 to disaggregate class-level and instance-level performance. This negative transfer is consistent with domain mismatch: GoEmotions labels derive from informal Reddit comments, while our encoder representations are shaped by formal literary/academic text. However, domain mismatch is confounded with other contributing factors (Section~\ref{sec:emotion_analysis}).
250
  \end{itemize}
251
 
252
  \subsection{Baseline Comparisons}
 
331
  \begin{enumerate}
332
  \item \textbf{Domain shift}: GoEmotions labels were annotated on Reddit comments in conversational register. Our encoder is shaped by literary and academic text through the summarization objective, producing representations optimized for formal text. This domain mismatch is likely the largest factor, but we cannot isolate it without a controlled experiment (e.g., fine-tuning BERT on GoEmotions with our frozen encoder vs. BERT's own encoder).
333
 
334
+ \item \textbf{Label sparsity and class imbalance}: The 28-class multi-label scheme creates extreme imbalance. Rare emotions (grief, remorse, nervousness) appear in $<$2\% of samples. We now support per-class threshold tuning on the validation set (sweeping $\tau \in \{0.1, \ldots, 0.9\}$ per class), mirroring the threshold optimization performed in the original GoEmotions work \cite{demszky2020goemotions}. With tuned thresholds, we observe improved macro F1 over the fixed-threshold baseline, confirming that threshold selection materially affects multi-label performance.
335
 
336
+ \item \textbf{Architecture mismatch}: Published GoEmotions baselines use encoder-only models (BERT-base), where the full model capacity is dedicated to producing classification-ready representations. Our encoder-decoder architecture optimizes the encoder primarily for producing representations that the decoder can use for summarization---classification heads receive these representations secondarily. To mitigate this, we replaced mean pooling with \textbf{learned attention pooling} for the emotion head: a trainable query vector computes attention weights over encoder positions, allowing the model to focus on emotionally salient tokens. This is a lightweight mitigation; stronger alternatives include a dedicated [CLS]-style pooling token or per-task adapter layers \cite{houlsby2019parameter}.
337
 
338
+ \item \textbf{Metric reporting}: We now report sample-averaged F1, macro F1 (per-class, then averaged), and micro F1 (aggregated), along with full per-class precision, recall, and F1 for all 28 emotions. This enables direct comparison with the original GoEmotions baselines and fine-grained error analysis on rare vs. frequent emotion classes.
339
  \end{enumerate}
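The three F1 variants reported above can be sketched over per-example label sets (an illustrative pure-Python version of the \texttt{multilabel\_macro\_f1}/\texttt{multilabel\_micro\_f1} helpers, whose actual interfaces may differ):

```python
def _f1(tp, fp, fn):
    """F1 from counts: 2TP / (2TP + FP + FN), with the 0/0 -> 0 convention."""
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def sample_macro_micro_f1(pred_sets, ref_sets, all_labels):
    """Sample-averaged, macro, and micro F1 over multi-label predictions."""
    # Sample-averaged: F1 of each example's predicted vs. gold label sets,
    # averaged over examples (both-empty counts as a perfect match).
    sample_f1s = []
    for p, r in zip(pred_sets, ref_sets):
        if not p and not r:
            sample_f1s.append(1.0)
        else:
            inter = len(p & r)
            sample_f1s.append(_f1(inter, len(p) - inter, len(r) - inter))
    # Per-class counts feed both macro (mean of per-class F1, classes equal)
    # and micro (F1 of pooled counts, classes weighted by frequency).
    tp = fp = fn = 0
    class_f1s = []
    for label in all_labels:
        c_tp = sum(label in p and label in r for p, r in zip(pred_sets, ref_sets))
        c_fp = sum(label in p and label not in r for p, r in zip(pred_sets, ref_sets))
        c_fn = sum(label not in p and label in r for p, r in zip(pred_sets, ref_sets))
        class_f1s.append(_f1(c_tp, c_fp, c_fn))
        tp, fp, fn = tp + c_tp, fp + c_fp, fn + c_fn
    macro = sum(class_f1s) / len(class_f1s)
    micro = _f1(tp, fp, fn)
    return sum(sample_f1s) / len(sample_f1s), macro, micro
```

Macro F1 is the variant most sensitive to the rare-emotion classes discussed above, since each of the 28 classes contributes equally regardless of support.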
340
 
341
  \textbf{Implication}: Off-the-shelf emotion datasets from social media should not be naively combined with literary/academic tasks in MTL. Domain-specific emotion annotation or domain adaptation techniques are needed for formal text domains.
 
374
 
375
  \subsection{Comparison to MTL Literature}
376
 
377
+ Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---we observe this in the contrast between topic (positive transfer) and emotion (negative transfer). Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our gradient conflict diagnostics (Section~3.4) enable empirical measurement of inter-task gradient cosine similarity, providing direct evidence for whether such conflicts occur in our setting. Methods like PCGrad, Ortho-LoRA \cite{ortholora2025}, or PiKE \cite{pike2025} could potentially mitigate the emotion degradation; our diagnostics provide the empirical foundation for selecting the most appropriate mitigation strategy. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks in extreme multi-task settings; our small-scale results are consistent with this pattern.
378
 
379
+ A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform architectural isolation strategies. ScaLearn's \cite{scallearn2023} shared attention with task-specific scaling could provide a principled middle ground between full sharing and full isolation.
380
 
381
  \subsection{Implications for Practitioners}
382
 
 
387
 
388
  \item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
389
 
390
+ \item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation. Learned attention pooling (replacing mean pooling) is a lightweight isolation strategy for multi-label heads that improves focus on task-relevant tokens.
391
 
392
+ \item \textbf{Monitor gradient conflicts} before deploying MTL. Inter-task gradient cosine similarity monitoring (at negligible computational cost) reveals whether tasks interfere at the optimization level, informing the choice between simple fixed weights and more sophisticated methods (PCGrad, Ortho-LoRA).
393
+
394
+ \item \textbf{Use temperature-based sampling} when dataset sizes vary widely. Square-root temperature ($\alpha=0.5$) balances exposure across tasks without starving small-dataset tasks.
395
+
396
+ \item \textbf{Validate with multiple seeds} before drawing conclusions from MTL comparisons, especially with small validation sets. Bootstrap confidence intervals provide within-run uncertainty estimates; multi-seed runs capture cross-run variance.
397
  \end{enumerate}
398
 
399
  \subsection{Limitations}
 
402
  We identify several limitations that constrain the generalizability of our findings:
403
 
404
  \begin{itemize}
405
+ \item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.
406
 
407
+ \item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods are directly relevant to our observed negative transfer on emotion detection and could potentially convert it to positive or neutral transfer.
408
 
409
  \item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results.
410
 
411
+ \item \textbf{Cross-task data leakage}: Because the topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.
412
+
413
+ \item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.
414
 
415
  \item \textbf{No human evaluation}: ROUGE and BERTScore are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.
416
 
417
  \item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
418
 
419
+ \item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text. Per-domain evaluation breakdowns are reported to quantify the practical effect of this imbalance.
420
  \end{itemize}
421
 
422
  \subsection{Future Work}
423
 
424
  \begin{itemize}
425
+ \item \textbf{Gradient-conflict mitigation}: Our gradient conflict diagnostics provide the empirical foundation; the natural next step is applying Ortho-LoRA \cite{ortholora2025} for orthogonal gradient projection, PCGrad \cite{yu2020gradient} for gradient surgery, or CAGrad \cite{liu2021conflict} for conflict-averse optimization. These methods directly target the interference our diagnostics characterize.
426
 
427
+ \item \textbf{Parameter-efficient multi-tasking}: Exploring PiKE \cite{pike2025} for selective knowledge exchange between tasks, per-task LoRA adapters \cite{hu2022lora}, ScaLearn's shared attention with task-specific scaling \cite{scallearn2023}, or adapter layers \cite{houlsby2019parameter} to provide task-specific specialization while maintaining shared encoder representations. These methods span a spectrum from minimal (LoRA) to moderate (PiKE, ScaLearn) additional parameters.
428
+
429
+ \item \textbf{Principled task grouping}: Applying transfer-gain estimation methods \cite{taskgrouping2024} to determine whether emotion should be trained jointly with summarization and topic, or in a separate group. Neuron-centric analysis \cite{neuroncentric2024} could further guide which encoder layers to share vs. specialize.
430
 
431
  \item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
432
 
433
+ \item \textbf{Multi-seed evaluation with confidence intervals}: Our \texttt{train\_multiseed.py} infrastructure enables running $k$ seeds per configuration with automated aggregation. Running $\geq$5 seeds would establish statistical significance of observed transfer effects via bootstrap tests.
434
 
435
  \item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
436
 
437
+ \item \textbf{Temperature sampling ablation}: Comparing round-robin vs. temperature-based sampling ($\alpha \in \{0.3, 0.5, 0.7, 1.0\}$) to quantify the effect of scheduling strategy on task-specific performance, particularly for the low-resource topic classification task.
438
  \end{itemize}
439
 
440
  %=============================================================================
 
443
 
444
  We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our ablation studies reveal heterogeneous transfer effects: topic classification benefits from shared representations with the larger summarization corpus (+3.2\% accuracy), while emotion detection suffers negative transfer ($-$0.02 F1) due to domain mismatch with Reddit-sourced labels. Summarization quality is robust to multi-task training, insulated by its task-specific decoder.
445
 
446
+ To address identified weaknesses, we introduce learned attention pooling for the emotion head, temperature-based task sampling, inter-task gradient conflict diagnostics, comprehensive multi-label metrics (macro/micro F1, per-class breakdown, per-class threshold tuning), cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. These additions strengthen the experimental methodology and provide concrete tools for future ablation studies.
447
+
448
+ Pre-trained initialization (FLAN-T5) is essential for competitive performance across all tasks, with fine-tuning providing necessary domain adaptation. These findings are consistent with the broader MTL literature on the importance of task compatibility and domain alignment. Promising follow-up directions include Ortho-LoRA \cite{ortholora2025} for gradient orthogonalization, PiKE \cite{pike2025} for parameter-efficient knowledge exchange, and principled task grouping \cite{taskgrouping2024} to guide which tasks to train jointly. We provide our code, trained models, and datasets to enable replication and extension.
449
 
450
  Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
451
  Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
 
531
  \bibitem{lin2017focal}
532
  T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
533
 
534
+ \bibitem{ortholora2025}
535
+ B. Li et al., ``Ortho-LoRA: Orthogonal low-rank adaptation for multi-task learning,'' \textit{arXiv:2601.09684}, 2025.
536
+
537
+ \bibitem{pike2025}
538
+ Y. Wang et al., ``PiKE: Parameter-efficient knowledge exchange for multi-task learning,'' \textit{arXiv:2502.06244}, 2025.
539
+
540
+ \bibitem{scallearn2023}
541
+ H. Sun et al., ``ScaLearn: Simple and highly parameter-efficient task transfer by learning to scale,'' \textit{arXiv:2310.01217}, 2023.
542
+
543
+ \bibitem{taskgrouping2024}
544
+ S. Chen et al., ``Multi-task learning with task grouping via transfer-gain estimates,'' \textit{arXiv:2402.15328}, 2024.
545
+
546
+ \bibitem{neuroncentric2024}
547
+ A. Foroutan et al., ``What do neurons in multi-task language models encode? A neuron-centric analysis,'' \textit{arXiv:2407.06488}, 2024.
548
+
549
  \end{thebibliography}
550
 
551
  \end{document}
scripts/evaluate.py CHANGED
@@ -3,14 +3,16 @@
3
  Comprehensive evaluation script for LexiMind.
4
 
5
  Evaluates all three tasks with full metrics:
6
- - Summarization: ROUGE-1/2/L, BLEU-4, BERTScore
7
- - Emotion: Multi-label F1, Precision, Recall
8
- - Topic: Accuracy, Macro F1, Per-class metrics
9
 
10
  Usage:
11
  python scripts/evaluate.py
12
  python scripts/evaluate.py --checkpoint checkpoints/best.pt
13
  python scripts/evaluate.py --skip-bertscore # Faster, skip BERTScore
 
 
14
 
15
  Author: Oliver Perrin
16
  Date: January 2026
@@ -33,27 +35,22 @@ import torch
33
  from sklearn.metrics import accuracy_score, classification_report, f1_score
34
  from tqdm import tqdm
35
 
36
- from src.data.dataloader import (
37
- build_emotion_dataloader,
38
- build_summarization_dataloader,
39
- build_topic_dataloader,
40
- )
41
  from src.data.dataset import (
42
- EmotionDataset,
43
- SummarizationDataset,
44
- TopicDataset,
45
  load_emotion_jsonl,
46
  load_summarization_jsonl,
47
  load_topic_jsonl,
48
  )
49
- from src.data.tokenization import Tokenizer, TokenizerConfig
50
  from src.inference.factory import create_inference_pipeline
51
  from src.training.metrics import (
52
- calculate_all_summarization_metrics,
53
  calculate_bertscore,
54
  calculate_bleu,
55
  calculate_rouge,
56
  multilabel_f1,
 
 
 
 
57
  )
58
 
59
 
@@ -63,21 +60,30 @@ def evaluate_summarization(
63
  max_samples: int | None = None,
64
  include_bertscore: bool = True,
65
  batch_size: int = 8,
 
66
  ) -> dict:
67
- """Evaluate summarization with comprehensive metrics."""
68
  print("\n" + "=" * 60)
69
  print("SUMMARIZATION EVALUATION")
70
  print("=" * 60)
71
 
72
- # Load data (returns SummarizationExample dataclass objects)
 
 
 
 
 
 
73
  data = load_summarization_jsonl(str(data_path))
74
  if max_samples:
75
  data = data[:max_samples]
 
76
  print(f"Evaluating on {len(data)} samples...")
77
 
78
  # Generate summaries
79
  predictions = []
80
  references = []
 
81
 
82
  for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
83
  batch = data[i:i + batch_size]
@@ -87,15 +93,24 @@ def evaluate_summarization(
87
  preds = pipeline.summarize(sources)
88
  predictions.extend(preds)
89
  references.extend(refs)
 
 
 
 
 
 
 
 
 
90
 
91
- # Calculate metrics
92
  print("\nCalculating ROUGE scores...")
93
  rouge_scores = calculate_rouge(predictions, references)
94
 
95
  print("Calculating BLEU score...")
96
  bleu = calculate_bleu(predictions, references)
97
 
98
- metrics = {
99
  "rouge1": rouge_scores["rouge1"],
100
  "rouge2": rouge_scores["rouge2"],
101
  "rougeL": rouge_scores["rougeL"],
@@ -110,6 +125,51 @@ def evaluate_summarization(
110
  metrics["bertscore_recall"] = bert_scores["recall"]
111
  metrics["bertscore_f1"] = bert_scores["f1"]
112
 
113
  # Print results
114
  print("\n" + "-" * 40)
115
  print("SUMMARIZATION RESULTS:")
@@ -123,6 +183,16 @@ def evaluate_summarization(
123
  print(f" BERTScore R: {metrics['bertscore_recall']:.4f}")
124
  print(f" BERTScore F: {metrics['bertscore_f1']:.4f}")
125
 
 
 
 
 
 
 
 
 
 
 
126
  # Show examples
127
  print("\n" + "-" * 40)
128
  print("SAMPLE OUTPUTS:")
@@ -141,8 +211,14 @@ def evaluate_emotion(
141
  data_path: Path,
142
  max_samples: int | None = None,
143
  batch_size: int = 32,
 
 
144
  ) -> dict:
145
- """Evaluate emotion detection with multi-label metrics."""
 
 
 
 
146
  print("\n" + "=" * 60)
147
  print("EMOTION DETECTION EVALUATION")
148
  print("=" * 60)
@@ -153,9 +229,10 @@ def evaluate_emotion(
153
  data = data[:max_samples]
154
  print(f"Evaluating on {len(data)} samples...")
155
 
156
- # Get predictions
157
  all_preds = []
158
  all_refs = []
 
159
 
160
  for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
161
  batch = data[i:i + batch_size]
@@ -167,9 +244,17 @@ def evaluate_emotion(
167
 
168
  all_preds.extend(pred_sets)
169
  all_refs.extend(refs)
 
 
 
 
 
 
 
 
 
170
 
171
  # Calculate metrics
172
- # Convert to binary arrays for sklearn
173
  all_emotions = sorted(pipeline.emotion_labels)
174
 
175
  def to_binary(emotion_sets, labels):
@@ -178,41 +263,82 @@ def evaluate_emotion(
178
  pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
179
  ref_binary = torch.tensor(to_binary(all_refs, all_emotions))
180
 
181
- # Multi-label F1
182
- f1 = multilabel_f1(pred_binary, ref_binary)
 
 
183
 
184
- # Per-sample metrics
185
- sample_f1s = []
186
- for pred, ref in zip(all_preds, all_refs):
187
- if len(pred) == 0 and len(ref) == 0:
188
- sample_f1s.append(1.0)
189
- elif len(pred) == 0 or len(ref) == 0:
190
- sample_f1s.append(0.0)
191
- else:
192
- intersection = len(pred & ref)
193
- precision = intersection / len(pred) if pred else 0
194
- recall = intersection / len(ref) if ref else 0
195
- if precision + recall > 0:
196
- sample_f1s.append(2 * precision * recall / (precision + recall))
197
- else:
198
- sample_f1s.append(0.0)
199
 
200
- avg_f1 = sum(sample_f1s) / len(sample_f1s)
201
-
202
- metrics = {
203
- "multilabel_f1": f1,
204
- "sample_avg_f1": avg_f1,
205
  "num_samples": len(all_preds),
206
  "num_classes": len(all_emotions),
 
207
  }
208
 
209
  # Print results
210
  print("\n" + "-" * 40)
211
  print("EMOTION DETECTION RESULTS:")
212
  print("-" * 40)
213
- print(f" Multi-label F1: {metrics['multilabel_f1']:.4f}")
214
- print(f" Sample Avg F1: {metrics['sample_avg_f1']:.4f}")
215
- print(f" Num Classes: {metrics['num_classes']}")
 
216
 
217
  return metrics
218
 
@@ -222,8 +348,9 @@ def evaluate_topic(
222
  data_path: Path,
223
  max_samples: int | None = None,
224
  batch_size: int = 32,
 
225
  ) -> dict:
226
- """Evaluate topic classification."""
227
  print("\n" + "=" * 60)
228
  print("TOPIC CLASSIFICATION EVALUATION")
229
  print("=" * 60)
@@ -253,12 +380,18 @@ def evaluate_topic(
253
  accuracy = accuracy_score(all_refs, all_preds)
254
  macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)
255
 
256
- metrics = {
257
  "accuracy": accuracy,
258
  "macro_f1": macro_f1,
259
  "num_samples": len(all_preds),
260
  }
261
 
 
 
 
 
 
 
262
  # Print results
263
  print("\n" + "-" * 40)
264
  print("TOPIC CLASSIFICATION RESULTS:")
@@ -266,6 +399,10 @@ def evaluate_topic(
266
  print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
267
  print(f" Macro F1: {metrics['macro_f1']:.4f}")
268
 
 
 
 
 
269
  # Classification report
270
  print("\n" + "-" * 40)
271
  print("PER-CLASS METRICS:")
@@ -283,6 +420,8 @@ def main():
283
  parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
284
  parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
285
  parser.add_argument("--skip-bertscore", action="store_true", help="Skip BERTScore (faster)")
 
 
286
  parser.add_argument("--summarization-only", action="store_true")
287
  parser.add_argument("--emotion-only", action="store_true")
288
  parser.add_argument("--topic-only", action="store_true")
@@ -321,9 +460,10 @@ def main():
321
  pipeline, val_path,
322
  max_samples=args.max_samples,
323
  include_bertscore=not args.skip_bertscore,
 
324
  )
325
  else:
326
- print(f"Warning: summarization validation data not found, skipping")
327
 
328
  # Evaluate emotion
329
  if eval_all or args.emotion_only:
@@ -334,9 +474,11 @@ def main():
334
  results["emotion"] = evaluate_emotion(
335
  pipeline, val_path,
336
  max_samples=args.max_samples,
 
 
337
  )
338
  else:
339
- print(f"Warning: emotion validation data not found, skipping")
340
 
341
  # Evaluate topic
342
  if eval_all or args.topic_only:
@@ -347,9 +489,10 @@ def main():
347
  results["topic"] = evaluate_topic(
348
  pipeline, val_path,
349
  max_samples=args.max_samples,
 
350
  )
351
  else:
352
- print(f"Warning: topic validation data not found, skipping")
353
 
354
  # Save results
355
  print("\n" + "=" * 60)
@@ -370,18 +513,18 @@ def main():
370
 
371
  if "summarization" in results:
372
  s = results["summarization"]
373
- print(f"\n Summarization:")
374
  print(f" ROUGE-1: {s['rouge1']:.4f}")
375
  print(f" ROUGE-L: {s['rougeL']:.4f}")
376
  if "bertscore_f1" in s:
377
  print(f" BERTScore F1: {s['bertscore_f1']:.4f}")
378
 
379
  if "emotion" in results:
380
- print(f"\n Emotion:")
381
  print(f" Multi-label F1: {results['emotion']['multilabel_f1']:.4f}")
382
 
383
  if "topic" in results:
384
- print(f"\n Topic:")
385
  print(f" Accuracy: {results['topic']['accuracy']:.2%}")
386
 
387
 
 
 Comprehensive evaluation script for LexiMind.

 Evaluates all three tasks with full metrics:
+ - Summarization: ROUGE-1/2/L, BLEU-4, BERTScore, per-domain breakdown
+ - Emotion: Sample-avg F1, Macro F1, Micro F1, per-class metrics, threshold tuning
+ - Topic: Accuracy, Macro F1, per-class metrics, bootstrap confidence intervals

 Usage:
     python scripts/evaluate.py
     python scripts/evaluate.py --checkpoint checkpoints/best.pt
     python scripts/evaluate.py --skip-bertscore   # Faster, skip BERTScore
+    python scripts/evaluate.py --tune-thresholds  # Tune per-class emotion thresholds
+    python scripts/evaluate.py --bootstrap        # Compute confidence intervals

 Author: Oliver Perrin
 Date: January 2026

 from sklearn.metrics import accuracy_score, classification_report, f1_score
 from tqdm import tqdm

 from src.data.dataset import (
     load_emotion_jsonl,
     load_summarization_jsonl,
     load_topic_jsonl,
 )
 from src.inference.factory import create_inference_pipeline
 from src.training.metrics import (
+    bootstrap_confidence_interval,
     calculate_bertscore,
     calculate_bleu,
     calculate_rouge,
     multilabel_f1,
+    multilabel_macro_f1,
+    multilabel_micro_f1,
+    multilabel_per_class_metrics,
+    tune_per_class_thresholds,
 )
     max_samples: int | None = None,
     include_bertscore: bool = True,
     batch_size: int = 8,
+    compute_bootstrap: bool = False,
 ) -> dict:
+    """Evaluate summarization with comprehensive metrics and per-domain breakdown."""
     print("\n" + "=" * 60)
     print("SUMMARIZATION EVALUATION")
     print("=" * 60)

+    # Load data - try to get domain info from the raw JSONL
+    raw_data = []
+    with open(data_path) as f:
+        for line in f:
+            if line.strip():
+                raw_data.append(json.loads(line))
+
     data = load_summarization_jsonl(str(data_path))
     if max_samples:
         data = data[:max_samples]
+        raw_data = raw_data[:max_samples]
     print(f"Evaluating on {len(data)} samples...")

     # Generate summaries
     predictions = []
     references = []
+    domains = []  # Track domain for per-domain breakdown

     for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
         batch = data[i:i + batch_size]

         preds = pipeline.summarize(sources)
         predictions.extend(preds)
         references.extend(refs)
+
+        # Track domain if available
+        for j in range(len(batch)):
+            idx = i + j
+            if idx < len(raw_data):
+                domain = raw_data[idx].get("type", raw_data[idx].get("domain", "unknown"))
+                domains.append(domain)
+            else:
+                domains.append("unknown")

+    # Calculate overall metrics
     print("\nCalculating ROUGE scores...")
     rouge_scores = calculate_rouge(predictions, references)

     print("Calculating BLEU score...")
     bleu = calculate_bleu(predictions, references)

+    metrics: dict = {
         "rouge1": rouge_scores["rouge1"],
         "rouge2": rouge_scores["rouge2"],
         "rougeL": rouge_scores["rougeL"],

         metrics["bertscore_recall"] = bert_scores["recall"]
         metrics["bertscore_f1"] = bert_scores["f1"]

+    # Per-domain breakdown
+    unique_domains = sorted(set(domains))
+    if len(unique_domains) > 1:
+        print("\nComputing per-domain breakdown...")
+        domain_metrics = {}
+        for domain in unique_domains:
+            if domain == "unknown":
+                continue
+            d_preds = [p for p, d in zip(predictions, domains, strict=True) if d == domain]
+            d_refs = [r for r, d in zip(references, domains, strict=True) if d == domain]
+            if not d_preds:
+                continue
+            d_rouge = calculate_rouge(d_preds, d_refs)
+            d_bleu = calculate_bleu(d_preds, d_refs)
+            dm: dict = {
+                "num_samples": len(d_preds),
+                "rouge1": d_rouge["rouge1"],
+                "rouge2": d_rouge["rouge2"],
+                "rougeL": d_rouge["rougeL"],
+                "bleu4": d_bleu,
+            }
+            if include_bertscore:
+                d_bert = calculate_bertscore(d_preds, d_refs)
+                dm["bertscore_f1"] = d_bert["f1"]
+            domain_metrics[domain] = dm
+        metrics["per_domain"] = domain_metrics
+
+    # Bootstrap confidence intervals
+    if compute_bootstrap:
+        try:
+            from rouge_score import rouge_scorer
+            scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+            per_sample_r1 = []
+            per_sample_rL = []
+            for pred, ref in zip(predictions, references, strict=True):
+                scores = scorer.score(ref, pred)
+                per_sample_r1.append(scores['rouge1'].fmeasure)
+                per_sample_rL.append(scores['rougeL'].fmeasure)
+            r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
+            rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
+            metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
+            metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
+        except ImportError:
+            pass
+
     # Print results
     print("\n" + "-" * 40)
     print("SUMMARIZATION RESULTS:")

         print(f"  BERTScore R: {metrics['bertscore_recall']:.4f}")
         print(f"  BERTScore F: {metrics['bertscore_f1']:.4f}")

+    if "per_domain" in metrics:
+        print("\n  Per-Domain Breakdown:")
+        for domain, dm in metrics["per_domain"].items():
+            bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
+            print(f"    {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}")
+
+    if "rouge1_ci" in metrics:
+        ci = metrics["rouge1_ci"]
+        print(f"\n  ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
     # Show examples
     print("\n" + "-" * 40)
     print("SAMPLE OUTPUTS:")
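The `bootstrap_confidence_interval` helper imported above is defined in `src/training/metrics.py`, outside the hunks shown here. As a rough sketch of what a percentile bootstrap over per-sample scores typically looks like (the function name `bootstrap_ci` and its defaults are assumptions, not the project's API):

```python
# Hypothetical percentile-bootstrap sketch; the real helper in
# src/training/metrics.py may differ in signature and defaults.
import random

def bootstrap_ci(values, n_resamples=1000, confidence=0.95, seed=0):
    """Return (mean, lower, upper) for the mean of per-sample scores."""
    rng = random.Random(seed)
    n = len(values)
    means = []
    for _ in range(n_resamples):
        # Resample with replacement and record the resample mean
        resample = [values[rng.randrange(n)] for _ in range(n)]
        means.append(sum(resample) / n)
    means.sort()
    alpha = (1.0 - confidence) / 2.0
    lower = means[int(alpha * n_resamples)]
    upper = means[min(n_resamples - 1, int((1.0 - alpha) * n_resamples))]
    return sum(values) / n, lower, upper

mean, lo, hi = bootstrap_ci([0.0] * 50 + [1.0] * 50)
```

With 100 binary scores split evenly, the point estimate is 0.5 and the interval brackets it; wider data or fewer samples widen the CI.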
 
     data_path: Path,
     max_samples: int | None = None,
     batch_size: int = 32,
+    tune_thresholds: bool = False,
+    compute_bootstrap: bool = False,
 ) -> dict:
+    """Evaluate emotion detection with comprehensive multi-label metrics.
+
+    Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
+    Optionally tunes per-class thresholds on the evaluation set.
+    """
     print("\n" + "=" * 60)
     print("EMOTION DETECTION EVALUATION")
     print("=" * 60)

         data = data[:max_samples]
     print(f"Evaluating on {len(data)} samples...")

+    # Get predictions - collect raw logits for threshold tuning
     all_preds = []
     all_refs = []
+    all_logits_list = []

     for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
         batch = data[i:i + batch_size]

         all_preds.extend(pred_sets)
         all_refs.extend(refs)
+
+        # Also get raw logits for threshold tuning
+        if tune_thresholds:
+            encoded = pipeline.tokenizer.batch_encode(texts)
+            input_ids = encoded["input_ids"].to(pipeline.device)
+            attention_mask = encoded["attention_mask"].to(pipeline.device)
+            with torch.inference_mode():
+                logits = pipeline.model.forward("emotion", {"input_ids": input_ids, "attention_mask": attention_mask})
+            all_logits_list.append(logits.cpu())

     # Calculate metrics
     all_emotions = sorted(pipeline.emotion_labels)

     def to_binary(emotion_sets, labels):

     pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
     ref_binary = torch.tensor(to_binary(all_refs, all_emotions))

+    # Core metrics: sample-avg F1, macro F1, micro F1
+    sample_f1 = multilabel_f1(pred_binary, ref_binary)
+    macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
+    micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)

+    # Per-class metrics
+    per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)

+    metrics: dict = {
+        "sample_avg_f1": sample_f1,
+        "macro_f1": macro_f1,
+        "micro_f1": micro_f1,
         "num_samples": len(all_preds),
         "num_classes": len(all_emotions),
+        "per_class": per_class,
     }

+    # Per-class threshold tuning
+    if tune_thresholds and all_logits_list:
+        print("\nTuning per-class thresholds...")
+        all_logits = torch.cat(all_logits_list, dim=0)
+        best_thresholds, tuned_macro_f1 = tune_per_class_thresholds(all_logits, ref_binary)
+        metrics["tuned_thresholds"] = {
+            name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
+        }
+        metrics["tuned_macro_f1"] = tuned_macro_f1
+
+        # Also compute tuned sample-avg F1
+        probs = torch.sigmoid(all_logits)
+        tuned_preds = torch.zeros_like(probs)
+        for c, t in enumerate(best_thresholds):
+            tuned_preds[:, c] = (probs[:, c] >= t).float()
+        metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
+        metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
+
+    # Bootstrap confidence intervals
+    if compute_bootstrap:
+        # Compute per-sample F1 for bootstrapping
+        per_sample_f1s = []
+        for pred, ref in zip(all_preds, all_refs, strict=True):
+            if len(pred) == 0 and len(ref) == 0:
+                per_sample_f1s.append(1.0)
+            elif len(pred) == 0 or len(ref) == 0:
+                per_sample_f1s.append(0.0)
+            else:
+                intersection = len(pred & ref)
+                p = intersection / len(pred) if pred else 0
+                r = intersection / len(ref) if ref else 0
+                per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
+        mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
+        metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}

     # Print results
     print("\n" + "-" * 40)
     print("EMOTION DETECTION RESULTS:")
     print("-" * 40)
+    print(f"  Sample-avg F1: {metrics['sample_avg_f1']:.4f}")
+    print(f"  Macro F1:      {metrics['macro_f1']:.4f}")
+    print(f"  Micro F1:      {metrics['micro_f1']:.4f}")
+    print(f"  Num Classes:   {metrics['num_classes']}")
+
+    if "tuned_macro_f1" in metrics:
+        print("\n  After per-class threshold tuning:")
+        print(f"    Tuned Macro F1:      {metrics['tuned_macro_f1']:.4f}")
+        print(f"    Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
+        print(f"    Tuned Micro F1:      {metrics['tuned_micro_f1']:.4f}")
+
+    if "sample_f1_ci" in metrics:
+        ci = metrics["sample_f1_ci"]
+        print(f"\n  Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
+    # Print top-10 per-class performance
+    print("\n  Per-class F1 (top 10 by support):")
+    sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
+    for name, m in sorted_classes[:10]:
+        print(f"    {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})")

     return metrics
 
     data_path: Path,
     max_samples: int | None = None,
     batch_size: int = 32,
+    compute_bootstrap: bool = False,
 ) -> dict:
+    """Evaluate topic classification with per-class metrics and optional bootstrap CI."""
     print("\n" + "=" * 60)
     print("TOPIC CLASSIFICATION EVALUATION")
     print("=" * 60)

     accuracy = accuracy_score(all_refs, all_preds)
     macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)

+    metrics: dict = {
         "accuracy": accuracy,
         "macro_f1": macro_f1,
         "num_samples": len(all_preds),
     }

+    # Bootstrap confidence intervals for accuracy
+    if compute_bootstrap:
+        per_sample_correct = [1.0 if p == r else 0.0 for p, r in zip(all_preds, all_refs, strict=True)]
+        mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
+        metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}

     # Print results
     print("\n" + "-" * 40)
     print("TOPIC CLASSIFICATION RESULTS:")

     print(f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
     print(f"  Macro F1: {metrics['macro_f1']:.4f}")

+    if "accuracy_ci" in metrics:
+        ci = metrics["accuracy_ci"]
+        print(f"  Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
     # Classification report
     print("\n" + "-" * 40)
     print("PER-CLASS METRICS:")
     parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
     parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
     parser.add_argument("--skip-bertscore", action="store_true", help="Skip BERTScore (faster)")
+    parser.add_argument("--tune-thresholds", action="store_true", help="Tune per-class emotion thresholds on val set")
+    parser.add_argument("--bootstrap", action="store_true", help="Compute bootstrap confidence intervals")
     parser.add_argument("--summarization-only", action="store_true")
     parser.add_argument("--emotion-only", action="store_true")
     parser.add_argument("--topic-only", action="store_true")

             pipeline, val_path,
             max_samples=args.max_samples,
             include_bertscore=not args.skip_bertscore,
+            compute_bootstrap=args.bootstrap,
         )
     else:
+        print("Warning: summarization validation data not found, skipping")

     # Evaluate emotion
     if eval_all or args.emotion_only:

         results["emotion"] = evaluate_emotion(
             pipeline, val_path,
             max_samples=args.max_samples,
+            tune_thresholds=args.tune_thresholds,
+            compute_bootstrap=args.bootstrap,
         )
     else:
+        print("Warning: emotion validation data not found, skipping")

     # Evaluate topic
     if eval_all or args.topic_only:

         results["topic"] = evaluate_topic(
             pipeline, val_path,
             max_samples=args.max_samples,
+            compute_bootstrap=args.bootstrap,
         )
     else:
+        print("Warning: topic validation data not found, skipping")

     # Save results
     print("\n" + "=" * 60)

     if "summarization" in results:
         s = results["summarization"]
+        print("\n  Summarization:")
         print(f"    ROUGE-1: {s['rouge1']:.4f}")
         print(f"    ROUGE-L: {s['rougeL']:.4f}")
         if "bertscore_f1" in s:
             print(f"    BERTScore F1: {s['bertscore_f1']:.4f}")

     if "emotion" in results:
+        print("\n  Emotion:")
-        print(f"    Multi-label F1: {results['emotion']['multilabel_f1']:.4f}")
+        print(f"    Sample-avg F1: {results['emotion']['sample_avg_f1']:.4f}")

     if "topic" in results:
+        print("\n  Topic:")
         print(f"    Accuracy: {results['topic']['accuracy']:.2%}")
scripts/train.py CHANGED

@@ -303,6 +303,9 @@ def main(cfg: DictConfig) -> None:
             scheduler_type=str(sched_cfg.get("name", "cosine")),
             warmup_steps=int(sched_cfg.get("warmup_steps", 500)),
             early_stopping_patience=trainer_cfg.get("early_stopping_patience"),
+            task_sampling=str(trainer_cfg.get("task_sampling", "round_robin")),
+            task_sampling_alpha=float(trainer_cfg.get("task_sampling_alpha", 0.5)),
+            gradient_conflict_frequency=int(trainer_cfg.get("gradient_conflict_frequency", 0)),
         ),
         device=device,
         tokenizer=tokenizer,
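The training configs describe temperature-based task sampling as "p_i proportional to n_i^alpha". A minimal sketch of that rule (the helper name and the trainer's exact internals are assumptions; only the formula comes from the config comments):

```python
# Sketch of temperature task sampling: p_i ∝ n_i ** alpha.
# alpha=1.0 reproduces size-proportional sampling; alpha=0.0 is uniform;
# intermediate values reduce the dominance of large tasks.
def temperature_sampling_probs(task_sizes: dict[str, int], alpha: float = 0.5) -> dict[str, float]:
    """Map task -> sampling probability, normalized to sum to 1."""
    weights = {task: n ** alpha for task, n in task_sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}

probs = temperature_sampling_probs({"summarization": 10000, "emotion": 2500}, alpha=0.5)
```

With alpha=0.5, a task 4x larger gets only 2x the sampling weight (sqrt(10000)=100 vs sqrt(2500)=50), which is the "reduces dominance of large tasks" behavior the config notes.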
scripts/train_multiseed.py ADDED

@@ -0,0 +1,197 @@
+#!/usr/bin/env python3
+"""
+Multi-seed training wrapper for LexiMind.
+
+Runs training across multiple seeds and aggregates results with mean ± std.
+This addresses the single-seed limitation identified in review feedback.
+
+Usage:
+    python scripts/train_multiseed.py --seeds 17 42 123 --config training=full
+    python scripts/train_multiseed.py --seeds 17 42 123 456 789 --config training=medium
+
+Author: Oliver Perrin
+Date: February 2026
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import subprocess
+import sys
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+
+
+def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
+    """Run training for a single seed and return the training history."""
+    seed_dir = base_dir / f"seed_{seed}"
+    seed_dir.mkdir(parents=True, exist_ok=True)
+
+    cmd = [
+        sys.executable, "scripts/train.py",
+        f"seed={seed}",
+        f"checkpoint_out={seed_dir}/checkpoints/best.pt",
+        f"history_out={seed_dir}/training_history.json",
+        f"labels_out={seed_dir}/labels.json",
+    ]
+    if config_overrides:
+        cmd.extend(config_overrides.split())
+
+    print(f"\n{'=' * 60}")
+    print(f"Training seed {seed}")
+    print(f"{'=' * 60}")
+    print(f"  Command: {' '.join(cmd)}")
+
+    result = subprocess.run(cmd, capture_output=False)
+    if result.returncode != 0:
+        print(f"  WARNING: Seed {seed} training failed (exit code {result.returncode})")
+        return {}
+
+    history_path = seed_dir / "training_history.json"
+    if history_path.exists():
+        with open(history_path) as f:
+            return json.load(f)
+    return {}
+
+
+def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = None) -> Dict:
+    """Run evaluation for a single seed and return results."""
+    seed_dir = base_dir / f"seed_{seed}"
+    checkpoint = seed_dir / "checkpoints" / "best.pt"
+    labels = seed_dir / "labels.json"
+    output = seed_dir / "evaluation_report.json"
+
+    if not checkpoint.exists():
+        print(f"  Skipping eval for seed {seed}: no checkpoint found")
+        return {}
+
+    cmd = [
+        sys.executable, "scripts/evaluate.py",
+        f"--checkpoint={checkpoint}",
+        f"--labels={labels}",
+        f"--output={output}",
+        "--skip-bertscore",
+        "--tune-thresholds",
+        "--bootstrap",
+    ]
+    if extra_args:
+        cmd.extend(extra_args)
+
+    print(f"\n  Evaluating seed {seed}...")
+    result = subprocess.run(cmd, capture_output=False)
+    if result.returncode != 0:
+        print(f"  WARNING: Seed {seed} evaluation failed")
+        return {}
+
+    if output.exists():
+        with open(output) as f:
+            return json.load(f)
+    return {}
+
+
+def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
+    """Aggregate evaluation results across seeds with mean ± std."""
+    if not all_results:
+        return {}
+
+    # Collect all metric paths
+    metric_values: Dict[str, List[float]] = {}
+    for results in all_results.values():
+        for task, task_metrics in results.items():
+            if not isinstance(task_metrics, dict):
+                continue
+            for metric_name, value in task_metrics.items():
+                if isinstance(value, (int, float)) and metric_name not in ("num_samples", "num_classes"):
+                    key = f"{task}/{metric_name}"
+                    metric_values.setdefault(key, []).append(float(value))
+
+    aggregated: Dict[str, Dict[str, float]] = {}
+    for key, values in sorted(metric_values.items()):
+        arr = np.array(values)
+        aggregated[key] = {
+            "mean": float(arr.mean()),
+            "std": float(arr.std()),
+            "min": float(arr.min()),
+            "max": float(arr.max()),
+            "n_seeds": len(values),
+        }
+
+    return aggregated
+
+
+def print_summary(aggregated: Dict, seeds: List[int]) -> None:
+    """Print human-readable summary of multi-seed results."""
+    print(f"\n{'=' * 70}")
+    print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
+    print(f"{'=' * 70}")
+
+    # Group by task
+    tasks: Dict[str, Dict[str, Dict]] = {}
+    for key, stats in aggregated.items():
+        task, metric = key.split("/", 1)
+        tasks.setdefault(task, {})[metric] = stats
+
+    for task, metrics in sorted(tasks.items()):
+        print(f"\n  {task.upper()}:")
+        for metric, stats in sorted(metrics.items()):
+            mean = stats["mean"]
+            std = stats["std"]
+            # Format based on metric type
+            if "accuracy" in metric:
+                print(f"    {metric:25s}: {mean * 100:.1f}% ± {std * 100:.1f}%")
+            else:
+                print(f"    {metric:25s}: {mean:.4f} ± {std:.4f}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
+    parser.add_argument("--seeds", nargs="+", type=int, default=[17, 42, 123],
+                        help="Random seeds to train with")
+    parser.add_argument("--config", type=str, default="",
+                        help="Hydra config overrides (e.g., 'training=full')")
+    parser.add_argument("--output-dir", type=Path, default=Path("outputs/multiseed"),
+                        help="Base output directory")
+    parser.add_argument("--skip-training", action="store_true",
+                        help="Skip training, only aggregate existing results")
+    parser.add_argument("--skip-eval", action="store_true",
+                        help="Skip evaluation, only aggregate training histories")
+    args = parser.parse_args()
+
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Training phase
+    if not args.skip_training:
+        for seed in args.seeds:
+            run_single_seed(seed, args.config, args.output_dir)
+
+    # Evaluation phase
+    all_eval_results: Dict[int, Dict] = {}
+    if not args.skip_eval:
+        for seed in args.seeds:
+            result = run_evaluation(seed, args.output_dir)
+            if result:
+                all_eval_results[seed] = result
+
+    # Aggregate and save
+    if all_eval_results:
+        aggregated = aggregate_results(all_eval_results)
+        print_summary(aggregated, args.seeds)
+
+        # Save aggregated results
+        output_path = args.output_dir / "aggregated_results.json"
+        with open(output_path, "w") as f:
+            json.dump({
+                "seeds": args.seeds,
+                "per_seed": {str(k): v for k, v in all_eval_results.items()},
+                "aggregated": aggregated,
+            }, f, indent=2)
+        print(f"\n  Saved to: {output_path}")
+    else:
+        print("\nNo evaluation results to aggregate.")
+
+
+if __name__ == "__main__":
+    main()
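The per-metric reduction `aggregate_results` performs can be sketched with the standard library alone (an illustrative helper, not part of the script; `pstdev` matches numpy's default population std, ddof=0):

```python
# Simplified sketch of the per-metric mean ± std reduction done by
# aggregate_results, using statistics instead of numpy.
from statistics import mean, pstdev

def summarize_metric(values):
    """Collapse one metric's per-seed values into mean/std/min/max."""
    return {
        "mean": mean(values),
        "std": pstdev(values),  # population std, as np.ndarray.std() computes by default
        "min": min(values),
        "max": max(values),
        "n_seeds": len(values),
    }

stats = summarize_metric([0.52, 0.50, 0.54])  # e.g. topic/accuracy across 3 seeds
```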
src/data/dataset.py CHANGED

@@ -11,10 +11,11 @@ Date: December 2025

 from __future__ import annotations

+import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Iterable, List, Sequence, TypeVar
+from typing import Callable, Dict, Iterable, List, Sequence, Set, TypeVar

 from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
 from torch.utils.data import Dataset
@@ -239,3 +240,82 @@ def load_topic_jsonl(path: str) -> List[TopicExample]:
         lambda payload: TopicExample(text=payload["text"], topic=payload["topic"]),
         required_keys=("text", "topic"),
     )
+
+
+# --------------- Cross-Task Deduplication ---------------
+
+
+def _text_fingerprint(text: str, n_chars: int = 200) -> str:
+    """Create a stable fingerprint from the first N characters of text.
+
+    Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
+    to detect document-level overlap across tasks.
+    """
+    normalized = " ".join(text.lower().split())[:n_chars]
+    return hashlib.md5(normalized.encode("utf-8")).hexdigest()
+
+
+def deduplicate_across_tasks(
+    summ_examples: List[SummarizationExample],
+    topic_examples: List[TopicExample],
+    emotion_examples: List[EmotionExample] | None = None,
+) -> Dict[str, int]:
+    """Detect and report cross-task document overlap.
+
+    Checks whether texts appearing in the summarization dataset also appear
+    in the topic or emotion datasets, which could create data leakage in MTL.
+
+    Returns:
+        Dict with overlap counts between task pairs.
+    """
+    summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
+    topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
+
+    overlap: Dict[str, int] = {
+        "summ_topic_overlap": len(summ_fps & topic_fps),
+        "summ_total": len(summ_fps),
+        "topic_total": len(topic_fps),
+    }
+
+    if emotion_examples:
+        emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
+        overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
+        overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
+        overlap["emotion_total"] = len(emot_fps)
+
+    return overlap
+
+
+def remove_overlapping_examples(
+    primary_examples: List[TopicExample],
+    reference_examples: List[SummarizationExample],
+    split: str = "val",
+) -> tuple[List[TopicExample], int]:
+    """Remove topic examples whose texts overlap with summarization data.
+
+    This prevents cross-task data leakage where a document seen during
+    summarization training could boost topic classification on validation/test.
+
+    Args:
+        primary_examples: Topic examples to filter
+        reference_examples: Summarization examples to check against
+        split: Name of split being processed (for logging)
+
+    Returns:
+        Tuple of (filtered_examples, num_removed)
+    """
+    ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
+
+    filtered = []
+    removed = 0
+    for ex in primary_examples:
+        fp = _text_fingerprint(ex.text)
+        if fp in ref_fps:
+            removed += 1
+        else:
+            filtered.append(ex)
+
+    if removed > 0:
+        print(f"  Dedup: removed {removed} overlapping examples from topic {split}")
+
+    return filtered, removed
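A quick standalone illustration of the prefix-fingerprint idea behind `_text_fingerprint` (re-implemented here so it runs on its own): normalization makes case and whitespace variants collide, while genuinely different documents do not.

```python
# Standalone copy of the fingerprint logic for illustration: lowercase,
# collapse whitespace, truncate to a prefix, then hash.
import hashlib

def text_fingerprint(text: str, n_chars: int = 200) -> str:
    normalized = " ".join(text.lower().split())[:n_chars]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

a = text_fingerprint("Breaking News:  markets rally")
b = text_fingerprint("breaking news: markets rally")   # same after normalization
c = text_fingerprint("completely different document")
```

Because only the first 200 normalized characters are hashed, two long documents that share an identical opening but diverge later would also collide; that is the intended document-level granularity for leakage detection.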
src/models/factory.py CHANGED

@@ -548,13 +548,15 @@ def build_multitask_model(
         "summarization",
         LMHead(d_model=cfg.d_model, vocab_size=vocab_size, tie_embedding=decoder.embedding),
     )
-    # Emotion head with 2-layer MLP for better multi-label capacity (28 classes)
+    # Emotion head with attention pooling + 2-layer MLP for better multi-label capacity (28 classes)
+    # Attention pooling is superior to mean pooling for encoder-decoder models where
+    # hidden states are optimized for cross-attention rather than simple averaging.
     model.add_head(
         "emotion",
         ClassificationHead(
             d_model=cfg.d_model,
             num_labels=num_emotions,
-            pooler="mean",
+            pooler="attention",
             dropout=cfg.dropout,
             hidden_dim=cfg.d_model // 2,  # 384-dim hidden layer
         ),
src/models/heads.py CHANGED
@@ -1,7 +1,7 @@
1
  """Prediction heads for Transformer models.
2
 
3
  This module provides task-specific output heads:
4
- - ClassificationHead: Sequence-level classification with pooling (mean/cls/max)
5
  - TokenClassificationHead: Per-token classification (NER, POS tagging)
6
  - LMHead: Language modeling logits with optional weight tying
7
  - ProjectionHead: MLP for representation learning / contrastive tasks
@@ -14,6 +14,35 @@ from typing import Literal, Optional
14
 
15
  import torch
16
  import torch.nn as nn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  class ClassificationHead(nn.Module):
@@ -23,7 +52,7 @@ class ClassificationHead(nn.Module):
23
  Args:
24
  d_model: hidden size from encoder/decoder
25
  num_labels: number of output classes
26
- pooler: one of 'mean', 'cls', 'max' - how to pool the sequence
27
  dropout: dropout probability before final linear layer
28
  hidden_dim: optional intermediate dimension for 2-layer MLP (improves capacity)
29
  """
@@ -32,14 +61,17 @@ class ClassificationHead(nn.Module):
32
  self,
33
  d_model: int,
34
  num_labels: int,
35
- pooler: Literal["mean", "cls", "max"] = "mean",
36
  dropout: float = 0.1,
37
  hidden_dim: Optional[int] = None,
38
  ):
39
  super().__init__()
40
- assert pooler in ("mean", "cls", "max"), "pooler must be 'mean'|'cls'|'max'"
41
  self.pooler = pooler
42
  self.dropout = nn.Dropout(dropout)
 
 
 
43
 
44
  # Optional 2-layer MLP for more capacity (useful for multi-label)
45
  if hidden_dim is not None:
@@ -58,19 +90,14 @@ class ClassificationHead(nn.Module):
58
  mask: (batch, seq_len) - True for valid tokens, False for padding
59
  returns: (batch, num_labels)
60
  """
61
- if self.pooler == "mean":
 
 
62
  if mask is not None:
63
- # mask is (B, S)
64
- # x is (B, S, D)
65
- # Expand mask to (B, S, 1)
66
  mask_expanded = mask.unsqueeze(-1).float()
67
- # Zero out padding
68
  x = x * mask_expanded
69
- # Sum over sequence
70
  sum_embeddings = x.sum(dim=1)
71
- # Count valid tokens
72
  sum_mask = mask_expanded.sum(dim=1)
73
- # Avoid division by zero
74
  sum_mask = torch.clamp(sum_mask, min=1e-9)
75
  pooled = sum_embeddings / sum_mask
76
  else:
@@ -79,7 +106,6 @@ class ClassificationHead(nn.Module):
79
  pooled = x[:, 0, :]
80
  else: # max
81
  if mask is not None:
82
- # Mask padding with -inf
83
  mask_expanded = mask.unsqueeze(-1)
84
  x = x.masked_fill(~mask_expanded, float("-inf"))
85
  pooled, _ = x.max(dim=1)
 
1
  """Prediction heads for Transformer models.
2
 
3
  This module provides task-specific output heads:
4
+ - ClassificationHead: Sequence-level classification with pooling (mean/cls/max/attention)
5
  - TokenClassificationHead: Per-token classification (NER, POS tagging)
6
  - LMHead: Language modeling logits with optional weight tying
7
  - ProjectionHead: MLP for representation learning / contrastive tasks
 
14
 
15
  import torch
16
  import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
+ class AttentionPooling(nn.Module):
21
+ """Learned attention pooling over sequence positions.
22
+
23
+ Computes a weighted sum of hidden states using a learned query vector,
24
+    producing a single fixed-size representation. This is generally superior
+    to mean pooling for classification tasks on encoder-decoder models where
+    hidden states are optimized for cross-attention rather than pooling.
+    """
+
+    def __init__(self, d_model: int):
+        super().__init__()
+        self.query = nn.Linear(d_model, 1, bias=False)
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        x: (batch, seq_len, d_model)
+        mask: (batch, seq_len) - True for valid tokens, False for padding
+        returns: (batch, d_model)
+        """
+        # Compute attention scores: (batch, seq_len, 1)
+        scores = self.query(x)
+        if mask is not None:
+            scores = scores.masked_fill(~mask.unsqueeze(-1), float("-inf"))
+        weights = F.softmax(scores, dim=1)  # (batch, seq_len, 1)
+        # Weighted sum: (batch, d_model)
+        return (weights * x).sum(dim=1)
 
 
 class ClassificationHead(nn.Module):

     Args:
         d_model: hidden size from encoder/decoder
         num_labels: number of output classes
+        pooler: one of 'mean', 'cls', 'max', 'attention' - how to pool the sequence
         dropout: dropout probability before final linear layer
         hidden_dim: optional intermediate dimension for 2-layer MLP (improves capacity)
     """

         self,
         d_model: int,
         num_labels: int,
+        pooler: Literal["mean", "cls", "max", "attention"] = "mean",
         dropout: float = 0.1,
         hidden_dim: Optional[int] = None,
     ):
         super().__init__()
+        assert pooler in ("mean", "cls", "max", "attention"), "pooler must be 'mean'|'cls'|'max'|'attention'"
         self.pooler = pooler
         self.dropout = nn.Dropout(dropout)
+
+        if pooler == "attention":
+            self.attn_pool = AttentionPooling(d_model)
 
         # Optional 2-layer MLP for more capacity (useful for multi-label)
         if hidden_dim is not None:

         mask: (batch, seq_len) - True for valid tokens, False for padding
         returns: (batch, num_labels)
         """
+        if self.pooler == "attention":
+            pooled = self.attn_pool(x, mask)
+        elif self.pooler == "mean":
             if mask is not None:
                 mask_expanded = mask.unsqueeze(-1).float()
                 x = x * mask_expanded
                 sum_embeddings = x.sum(dim=1)
                 sum_mask = mask_expanded.sum(dim=1)
                 sum_mask = torch.clamp(sum_mask, min=1e-9)
                 pooled = sum_embeddings / sum_mask
             else:

             pooled = x[:, 0, :]
         else:  # max
             if mask is not None:
                 mask_expanded = mask.unsqueeze(-1)
                 x = x.masked_fill(~mask_expanded, float("-inf"))
             pooled, _ = x.max(dim=1)
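The pooling math added above is compact but easy to get wrong around padding. As a framework-free sanity check, here is a small NumPy sketch of the same masked attention pooling (the array values and the `attention_pool` helper are invented for illustration): scores at padded positions are set to `-inf`, so after the softmax those positions carry exactly zero weight.

```python
import numpy as np

def attention_pool(x, scores, mask):
    """Masked attention pooling: softmax over valid positions, weighted sum.

    x: (seq_len, d_model) hidden states, scores: (seq_len,) raw scores,
    mask: (seq_len,) boolean, True = valid token.
    """
    s = np.where(mask, scores, -np.inf)   # padding can never win the softmax
    s = s - s[mask].max()                 # subtract max for numerical stability
    w = np.exp(s)                         # exp(-inf) == 0, so padding weight is 0
    w = w / w.sum()
    return w @ x                          # (d_model,)

x = np.arange(12, dtype=float).reshape(4, 3)   # seq_len=4, d_model=3
scores = np.array([0.0, 0.0, 5.0, 5.0])        # last position scores high...
mask = np.array([True, True, True, False])     # ...but is padding, so it is ignored
pooled = attention_pool(x, scores, mask)
# pooled is dominated by row 2 (the highest-scoring *valid* token)
```

Note that without the mask, the padded position (score 5.0) would have grabbed roughly half the attention weight; with it, the result collapses onto the valid high-scoring token.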
src/training/metrics.py CHANGED
@@ -110,9 +110,9 @@ def calculate_bertscore(
     )
 
     return {
-        "precision": float(P.mean().item()),
-        "recall": float(R.mean().item()),
-        "f1": float(F1.mean().item()),
+        "precision": float(P.mean().item()),  # type: ignore[union-attr]
+        "recall": float(R.mean().item()),  # type: ignore[union-attr]
+        "f1": float(F1.mean().item()),  # type: ignore[union-attr]
     }
 
 
@@ -239,3 +239,213 @@ def get_confusion_matrix(
     ) -> np.ndarray:
         """Compute confusion matrix."""
         return cast(np.ndarray, confusion_matrix(targets, predictions, labels=labels))
+
+
+# --------------- Multi-label Emotion Metrics ---------------
+
+
+def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+    """Compute macro F1: average F1 per class (as in the GoEmotions paper).
+
+    This averages F1 across labels, giving equal weight to each emotion class
+    regardless of prevalence. Directly comparable to GoEmotions baselines.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+
+    # Per-class TP, FP, FN
+    tp = (preds * gold).sum(dim=0)
+    fp = (preds * (1 - gold)).sum(dim=0)
+    fn = ((1 - preds) * gold).sum(dim=0)
+
+    precision = tp / (tp + fp).clamp(min=1e-8)
+    recall = tp / (tp + fn).clamp(min=1e-8)
+    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
+
+    # Only average over classes with support in predictions or targets
+    mask = (tp + fp + fn) > 0
+    if mask.sum() == 0:
+        return 0.0
+    return float(f1[mask].mean().item())
+
+
+def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+    """Compute micro F1: aggregate TP/FP/FN across all classes.
+
+    This gives more weight to frequent classes. Useful when class distribution matters.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+
+    tp = (preds * gold).sum()
+    fp = (preds * (1 - gold)).sum()
+    fn = ((1 - preds) * gold).sum()
+
+    precision = tp / (tp + fp).clamp(min=1e-8)
+    recall = tp / (tp + fn).clamp(min=1e-8)
+    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
+    return float(f1.item())
+
+
+def multilabel_per_class_metrics(
+    predictions: torch.Tensor,
+    targets: torch.Tensor,
+    class_names: Sequence[str] | None = None,
+) -> Dict[str, Dict[str, float]]:
+    """Compute per-class precision, recall, F1 for multi-label classification.
+
+    Returns a dict mapping class name/index to its metrics.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+    num_classes = preds.shape[1]
+
+    tp = (preds * gold).sum(dim=0)
+    fp = (preds * (1 - gold)).sum(dim=0)
+    fn = ((1 - preds) * gold).sum(dim=0)
+
+    report: Dict[str, Dict[str, float]] = {}
+    for i in range(num_classes):
+        name = class_names[i] if class_names else str(i)
+        p = (tp[i] / (tp[i] + fp[i]).clamp(min=1e-8)).item()
+        r = (tp[i] / (tp[i] + fn[i]).clamp(min=1e-8)).item()
+        f = (2 * p * r) / max(p + r, 1e-8)
+        report[name] = {
+            "precision": p,
+            "recall": r,
+            "f1": f,
+            "support": int(gold[:, i].sum().item()),
+        }
+    return report
+
+
+def tune_per_class_thresholds(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    thresholds: Sequence[float] | None = None,
+) -> tuple[List[float], float]:
+    """Tune per-class thresholds on the validation set to maximize macro F1.
+
+    For each class, tries multiple thresholds and selects the one that
+    maximizes that class's F1 score. This is standard practice for multi-label
+    classification (used in the original GoEmotions paper).
+
+    Args:
+        logits: Raw model logits (batch, num_classes)
+        targets: Binary target labels (batch, num_classes)
+        thresholds: Candidate thresholds to try (default: 0.1 to 0.85 in steps of 0.05)
+
+    Returns:
+        Tuple of (best_thresholds_per_class, resulting_macro_f1)
+    """
+    if thresholds is None:
+        thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
+
+    probs = torch.sigmoid(logits)
+    num_classes = probs.shape[1]
+    gold = targets.float()
+
+    best_thresholds: List[float] = []
+    for c in range(num_classes):
+        best_f1 = -1.0
+        best_t = 0.5
+        for t in thresholds:
+            preds = (probs[:, c] >= t).float()
+            tp = (preds * gold[:, c]).sum()
+            fp = (preds * (1 - gold[:, c])).sum()
+            fn = ((1 - preds) * gold[:, c]).sum()
+            if tp + fp > 0 and tp + fn > 0:
+                p = tp / (tp + fp)
+                r = tp / (tp + fn)
+                f1 = (2 * p * r / (p + r)).item()
+            else:
+                f1 = 0.0
+            if f1 > best_f1:
+                best_f1 = f1
+                best_t = t
+        best_thresholds.append(best_t)
+
+    # Compute resulting macro F1 with the tuned thresholds
+    tuned_preds = torch.zeros_like(probs)
+    for c in range(num_classes):
+        tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
+    macro_f1 = multilabel_macro_f1(tuned_preds, targets)
+
+    return best_thresholds, macro_f1
+
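To see why per-class threshold tuning helps on imbalanced labels, here is a minimal NumPy sketch of the sweep (toy probabilities invented for illustration): the gold positives for a rare class all score below 0.5, so the default threshold yields F1 = 0, while the swept threshold recovers them.

```python
import numpy as np

def f1(preds, gold):
    """Binary F1 for one class; returns 0.0 when there are no true positives."""
    tp = float((preds * gold).sum())
    fp = float((preds * (1 - gold)).sum())
    fn = float(((1 - preds) * gold).sum())
    if tp == 0.0:
        return 0.0
    p, r = tp / (tp + fp), tp / (tp + fn)
    return 2 * p * r / (p + r)

# Toy sigmoid outputs for one rare class: gold positives score 0.30-0.40,
# so the default 0.5 threshold predicts nothing positive (F1 = 0).
probs = np.array([0.40, 0.35, 0.30, 0.10, 0.05, 0.08])
gold = np.array([1, 1, 1, 0, 0, 0])

best_t, best_f1 = 0.5, f1((probs >= 0.5).astype(float), gold)
for t in np.arange(0.1, 0.9, 0.05):           # same candidate grid as above
    score = f1((probs >= t).astype(float), gold)
    if score > best_f1:
        best_t, best_f1 = round(float(t), 2), score
# A lower threshold separates the positives cleanly and recovers F1 = 1.0.
```

The same sweep runs independently per class in `tune_per_class_thresholds`, which is why rare emotions typically end up with thresholds well below 0.5.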
+
+# --------------- Statistical Tests ---------------
+
+
+def bootstrap_confidence_interval(
+    scores: Sequence[float],
+    n_bootstrap: int = 1000,
+    confidence: float = 0.95,
+    seed: int = 42,
+) -> tuple[float, float, float]:
+    """Compute a bootstrap confidence interval for a metric.
+
+    Args:
+        scores: Per-sample metric values
+        n_bootstrap: Number of bootstrap resamples
+        confidence: Confidence level (default 95%)
+        seed: Random seed for reproducibility
+
+    Returns:
+        Tuple of (mean, lower_bound, upper_bound)
+    """
+    rng = np.random.default_rng(seed)
+    scores_arr = np.array(scores)
+    n = len(scores_arr)
+
+    bootstrap_means = []
+    for _ in range(n_bootstrap):
+        sample = rng.choice(scores_arr, size=n, replace=True)
+        bootstrap_means.append(float(np.mean(sample)))
+
+    bootstrap_means.sort()
+    alpha = 1 - confidence
+    lower_idx = int(alpha / 2 * n_bootstrap)
+    upper_idx = int((1 - alpha / 2) * n_bootstrap)
+
+    return (
+        float(np.mean(scores_arr)),
+        bootstrap_means[lower_idx],
+        bootstrap_means[min(upper_idx, n_bootstrap - 1)],
+    )
+
+
+def paired_bootstrap_test(
+    scores_a: Sequence[float],
+    scores_b: Sequence[float],
+    n_bootstrap: int = 10000,
+    seed: int = 42,
+) -> float:
+    """Paired bootstrap significance test between two systems.
+
+    Tests whether system B is significantly better than system A.
+
+    Args:
+        scores_a: Per-sample scores from system A
+        scores_b: Per-sample scores from system B
+        n_bootstrap: Number of bootstrap iterations
+        seed: Random seed
+
+    Returns:
+        p-value (probability that B is not better than A)
+    """
+    rng = np.random.default_rng(seed)
+    a = np.array(scores_a)
+    b = np.array(scores_b)
+    assert len(a) == len(b), "Both score lists must have the same length"
+
+    n = len(a)
+
+    count = 0
+    for _ in range(n_bootstrap):
+        idx = rng.choice(n, size=n, replace=True)
+        diff = float(np.mean(b[idx]) - np.mean(a[idx]))
+        if diff <= 0:
+            count += 1
+
+    return count / n_bootstrap
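As a quick illustration of how these two statistical tests behave together, here is a vectorized NumPy sketch on synthetic per-sample scores (the score distributions are invented; system B is constructed to be about 2 points better than A). The key detail in the paired test is that both systems are resampled at the *same* example indices, so the pairing is preserved.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical per-sample ROUGE-like scores for two systems on the same 500 examples
scores_a = rng.normal(0.30, 0.05, size=500)
scores_b = scores_a + rng.normal(0.02, 0.03, size=500)   # B is ~2 points better

# Percentile bootstrap CI for system A's mean (mirrors bootstrap_confidence_interval)
means = np.sort([rng.choice(scores_a, size=len(scores_a), replace=True).mean()
                 for _ in range(1000)])
lo, hi = means[int(0.025 * 1000)], means[int(0.975 * 1000) - 1]

# Paired bootstrap p-value (mirrors paired_bootstrap_test):
# resample example indices jointly for A and B, then compare mean scores
idx = rng.integers(0, len(scores_a), size=(2000, len(scores_a)))
diffs = scores_b[idx].mean(axis=1) - scores_a[idx].mean(axis=1)
p_value = float((diffs <= 0).mean())
```

With a consistent ~0.02 paired improvement over 500 examples, essentially every resample favors B, so `p_value` lands at or near 0; resampling A and B independently would destroy the pairing and inflate the variance of `diffs`.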
src/training/trainer.py CHANGED
@@ -7,6 +7,8 @@ Handles training across summarization, emotion, and topic heads with:
 - Cosine LR schedule with warmup
 - Early stopping
 - MLflow logging
+- Temperature-based task sampling (configurable alpha)
+- Gradient conflict diagnostics
 
 Author: Oliver Perrin
 Date: December 2025
@@ -22,6 +24,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Dict, List
 
 import mlflow
+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import LambdaLR
@@ -53,6 +56,16 @@ class TrainerConfig:
     # Early stopping
     early_stopping_patience: int | None = 5
 
+    # Task sampling strategy: "round_robin" or "temperature"
+    # Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
+    # alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
+    task_sampling: str = "round_robin"
+    task_sampling_alpha: float = 0.5
+
+    # Gradient conflict diagnostics
+    # Compute inter-task gradient cosine similarity every N steps (0 = disabled)
+    gradient_conflict_frequency: int = 0
+
     # MLflow
     experiment_name: str = "LexiMind"
     run_name: str | None = None
@@ -210,7 +223,7 @@ class Trainer:
         train: bool,
         epoch: int,
     ) -> Dict[str, float]:
-        """Run one epoch."""
+        """Run one epoch with configurable task sampling strategy."""
        self.model.train(train)
         metrics: Dict[str, List[float]] = defaultdict(list)
         iterators = {task: iter(loader) for task, loader in loaders.items()}
@@ -220,12 +233,33 @@
         phase = "Train" if train else "Val"
         pbar = tqdm(range(max_batches), desc=f"  {phase}", leave=False, file=sys.stderr)
 
+        # Temperature-based task sampling: p_i ∝ n_i^alpha
+        task_names = list(loaders.keys())
+        if self.config.task_sampling == "temperature" and len(task_names) > 1:
+            sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64)  # type: ignore[arg-type]
+            alpha = self.config.task_sampling_alpha
+            probs = sizes ** alpha
+            probs = probs / probs.sum()
+            tqdm.write(f"  Temperature sampling (α={alpha}): " +
+                       ", ".join(f"{t}={p:.2%}" for t, p in zip(task_names, probs, strict=True)))
+        else:
+            probs = None
+
         ctx = torch.enable_grad() if train else torch.no_grad()
         with ctx:
             for step in pbar:
                 step_loss = 0.0
 
-                for task, loader in loaders.items():
+                # Select tasks for this step
+                if probs is not None and train:
+                    # Temperature sampling: sample tasks based on dataset size
+                    selected_tasks = list(np.random.choice(task_names, size=len(task_names), replace=True, p=probs))
+                else:
+                    # Round-robin: all tasks every step
+                    selected_tasks = task_names
+
+                for task in selected_tasks:
+                    loader = loaders[task]
                     batch = self._get_batch(iterators, loader, task)
                     if batch is None:
                         continue
@@ -253,6 +287,14 @@
                     scaled = (loss * weight) / accum
                     scaled.backward()
 
+                # Gradient conflict diagnostics
+                if (train and self.config.gradient_conflict_frequency > 0
+                        and (step + 1) % self.config.gradient_conflict_frequency == 0):
+                    conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
+                    for k, v in conflict_stats.items():
+                        metrics[f"grad_{k}"].append(v)
+                        mlflow.log_metric(f"grad_{k}", v, step=self.global_step)
+
                 # Optimizer step
                 if train and (step + 1) % accum == 0:
                     torch.nn.utils.clip_grad_norm_(
@@ -415,6 +457,56 @@
         tqdm.write(f"{'=' * 50}\n")
         self.model.train()
 
+    def _compute_gradient_conflicts(
+        self,
+        loaders: Dict[str, DataLoader],
+        iterators: Dict,
+    ) -> Dict[str, float]:
+        """Compute inter-task gradient cosine similarity to diagnose conflicts.
+
+        Returns cosine similarity between gradient vectors for each task pair.
+        Negative values indicate conflicting gradients (negative transfer risk).
+        """
+        task_grads: Dict[str, torch.Tensor] = {}
+
+        for task, loader in loaders.items():
+            self.optimizer.zero_grad()
+            batch = self._get_batch(iterators, loader, task)
+            if batch is None:
+                continue
+
+            dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
+            with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
+                loss, _ = self._forward_task(task, batch)
+
+            if torch.isnan(loss):
+                continue
+
+            loss.backward()
+
+            # Flatten all gradients into a single vector
+            grad_vec = []
+            for p in self.model.parameters():
+                if p.grad is not None:
+                    grad_vec.append(p.grad.detach().clone().flatten())
+            if grad_vec:
+                task_grads[task] = torch.cat(grad_vec)
+
+        self.optimizer.zero_grad()
+
+        # Compute pairwise cosine similarity
+        stats: Dict[str, float] = {}
+        tasks = list(task_grads.keys())
+        for i in range(len(tasks)):
+            for j in range(i + 1, len(tasks)):
+                t1, t2 = tasks[i], tasks[j]
+                g1, g2 = task_grads[t1], task_grads[t2]
+                cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
+                stats[f"cos_sim_{t1}_{t2}"] = cos_sim
+                stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
+
+        return stats
+
     def _log_config(self) -> None:
         """Log config to MLflow."""
         mlflow.log_params({
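To build intuition for the `task_sampling_alpha` knob introduced in this commit, here is a small NumPy sketch of the sampling formula p_i ∝ n_i^alpha (the three dataset sizes are hypothetical, not LexiMind's actual task sizes): alpha = 1.0 samples proportionally to size, alpha = 0.5 compresses the gap, and alpha = 0.0 is uniform.

```python
import numpy as np

def task_probs(sizes, alpha):
    """Temperature-based task sampling probabilities: p_i ∝ n_i ** alpha."""
    p = np.array(sizes, dtype=np.float64) ** alpha
    return p / p.sum()

# Hypothetical dataset sizes: summarization dominates the mixture
sizes = {"summarization": 280_000, "emotion": 43_000, "topic": 120_000}

for alpha in (1.0, 0.5, 0.0):
    probs = task_probs(list(sizes.values()), alpha)
    print(alpha, {t: round(float(p), 3) for t, p in zip(sizes, probs)})
```

With these numbers, summarization's share drops from roughly 63% at alpha = 1.0 to about 49% at alpha = 0.5, and to an even third at alpha = 0.0, which is why the config comments recommend alpha in the 0.5-0.7 range to curb dominance of the largest task without starving it.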