OliverPerrin committed
Commit · 0d858b5
Parent(s): 35856cb

Added many new improvements based on feedback from others
Files changed:
- configs/training/dev.yaml +3 -0
- configs/training/full.yaml +7 -0
- configs/training/medium.yaml +3 -0
- docs/research_paper.tex +63 -30
- scripts/evaluate.py +196 -53
- scripts/train.py +3 -0
- scripts/train_multiseed.py +197 -0
- src/data/dataset.py +81 -1
- src/models/factory.py +4 -2
- src/models/heads.py +39 -13
- src/training/metrics.py +213 -3
- src/training/trainer.py +94 -2
configs/training/dev.yaml
@@ -37,6 +37,9 @@ trainer:
 max_val_samples: 300
 early_stopping_patience: 5
 log_grad_norm_frequency: 100
+task_sampling: round_robin
+task_sampling_alpha: 0.5
+gradient_conflict_frequency: 0

 # Enable compile for speed (worth the startup cost)
 compile_encoder: true
configs/training/full.yaml
@@ -38,6 +38,13 @@ trainer:
 max_val_samples: 3000  # Enough for stable metrics
 early_stopping_patience: 3  # Stop quickly when plateauing
 log_grad_norm_frequency: 200
+# Task sampling: "round_robin" (default) or "temperature"
+# Temperature sampling: p_i proportional to n_i^alpha, reduces dominance of large tasks
+task_sampling: round_robin
+task_sampling_alpha: 0.5
+# Gradient conflict diagnostics: compute inter-task gradient cosine similarity
+# every N steps (0 = disabled). Helps diagnose negative transfer.
+gradient_conflict_frequency: 0

 compile_encoder: true
 compile_decoder: true
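The comments added above describe temperature-based task sampling: task i is drawn with probability p_i proportional to n_i^alpha. A minimal sketch of that rule, with illustrative function and task names (not the repository's actual API):

```python
# Temperature-based task sampling: p_i ~ n_i ** alpha. With alpha < 1,
# large datasets are still sampled more often, but less extremely than
# under pure proportional sampling (alpha = 1.0).

def task_sampling_probs(sizes: dict[str, int], alpha: float = 0.5) -> dict[str, float]:
    """Map dataset sizes to sampling probabilities p_i ~ n_i ** alpha."""
    weights = {task: n ** alpha for task, n in sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}

# Dataset sizes from the paper (49K summarization, 43K emotion, 3.4K topic).
sizes = {"summarization": 49_000, "emotion": 43_410, "topic": 3_400}
probs = task_sampling_probs(sizes, alpha=0.5)
```

With alpha=0.5 the 49K summarization task is sampled roughly sqrt(49000/3400) ≈ 3.8 times as often as the 3.4K topic task, versus ≈ 14.4 times under alpha=1.0 — which is the "starvation" problem the config comment alludes to.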
configs/training/medium.yaml
@@ -37,6 +37,9 @@ trainer:
 max_val_samples: 2500
 early_stopping_patience: 3  # More patience
 log_grad_norm_frequency: 100
+task_sampling: round_robin
+task_sampling_alpha: 0.5
+gradient_conflict_frequency: 0

 compile_encoder: true
 compile_decoder: true
docs/research_paper.tex
@@ -44,7 +44,7 @@ Email: perrinot@appstate.edu}}
 \maketitle
 
 \begin{abstract}
-Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies comparing single-task specialists against multi-task configurations, we find that: (1) MTL provides a +3.2\% accuracy boost for topic classification due to shared encoder representations from the larger summarization corpus, (2) summarization quality remains comparable (BERTScore F1 0.83 vs. 0.82 single-task), and (3) emotion detection suffers negative transfer ($-$0.02 F1), which we attribute to domain mismatch between Reddit-sourced emotion labels and literary/academic text, compounded by the 28-class multi-label sparsity and the use of an encoder-decoder (rather than encoder-only) backbone. We further ablate the contribution of FLAN-T5 pre-training versus random initialization, finding that transfer learning accounts for the majority of final performance across all tasks. Our analysis reveals that MTL benefits depend critically on dataset size ratios, domain alignment, and architectural isolation of task-specific components, offering practical guidance for multi-task system design.
 \end{abstract}
 
 \begin{IEEEkeywords}
@@ -76,7 +76,7 @@ To answer these questions, we construct \textbf{LexiMind}, a multi-task system b
 \item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
 \end{itemize}
 
-We acknowledge important limitations: our results are from single-seed runs, we do not
 
 %=============================================================================
 \section{Related Work}
@@ -86,7 +86,9 @@ We acknowledge important limitations: our results are from single-seed runs, we
 
 Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
 
-\textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach
 
 \textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
 
@@ -136,7 +138,9 @@ Emotion (28 labels) & GoEmotions (Reddit) & 43,410 & 5,426 & 5,427 \\
 \end{tabular}
 \end{table}
 
-\textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed
 
 \textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
 
@@ -154,12 +158,12 @@ LexiMind uses FLAN-T5-base (272M parameters) as the backbone, with a custom reim
 
 Task-specific heads branch from the shared encoder:
 \begin{itemize}
-\item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing)
 \item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
-\item \textbf{Emotion}: Linear classifier on
 \end{itemize}
 
-\textbf{Architectural note.}
 
 \subsection{Training Configuration}
 
@@ -175,10 +179,14 @@ All experiments use consistent hyperparameters unless otherwise noted:
 \item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
 \end{itemize}
 
-\textbf{Task scheduling.}
 
 \textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
 
 \subsection{Baselines and Ablations}
 
 We compare four configurations:
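The fixed loss weighting described in the hunk above (summarization=1.0, emotion=1.0, topic=0.3) amounts to a weighted sum of per-task losses. A minimal sketch with plain floats standing in for loss tensors; names are illustrative, not the trainer's actual code:

```python
# Fixed-weight loss combination: total = sum_t w_t * L_t, with the
# topic weight reduced to 0.3 to keep the small topic dataset from
# dominating gradients.

TASK_WEIGHTS = {"summarization": 1.0, "emotion": 1.0, "topic": 0.3}

def combined_loss(task_losses: dict[str, float]) -> float:
    """Weighted sum of per-task losses using the fixed weights above."""
    return sum(TASK_WEIGHTS[task] * loss for task, loss in task_losses.items())

total = combined_loss({"summarization": 2.1, "emotion": 0.7, "topic": 1.5})
# 1.0 * 2.1 + 1.0 * 0.7 + 0.3 * 1.5 = 3.25
```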
@@ -195,12 +203,12 @@ We additionally ablate FLAN-T5 initialization vs. random initialization to isola
 \subsection{Evaluation Metrics}
 
 \begin{itemize}
-\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BERTScore F1 \cite{zhang2019bertscore} using RoBERTa-large (semantic similarity). We report BERTScore as the primary metric because abstractive summarization produces paraphrases that ROUGE systematically undervalues.
 \item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
-\item \textbf{Emotion}: Sample-averaged F1
 \end{itemize}
 
-\textbf{Statistical
 
 %=============================================================================
 \section{Results}
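The metrics list above reports sample-averaged F1 for multi-label emotion detection. A hedged sketch of how that metric is conventionally computed (F1 per sample from its predicted and true label sets, then averaged over samples); function names are illustrative:

```python
# Sample-averaged F1 for multi-label classification. Each sample's F1
# is computed from its own predicted/true label sets, then averaged.

def sample_f1(pred: set[int], true: set[int]) -> float:
    """F1 for a single sample's predicted vs. true label sets."""
    if not pred and not true:
        return 1.0  # convention: empty prediction matching empty truth
    if not pred or not true:
        return 0.0
    tp = len(pred & true)
    precision = tp / len(pred)
    recall = tp / len(true)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def sample_averaged_f1(preds: list[set[int]], trues: list[set[int]]) -> float:
    """Mean of per-sample F1 scores over the dataset."""
    return sum(sample_f1(p, t) for p, t in zip(preds, trues)) / len(preds)
```

This differs from macro F1 (averaged per class) and micro F1 (aggregated over all predictions), which weight errors differently under the heavy class imbalance the paper describes.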
@@ -234,11 +242,11 @@ Emotion & Sample-avg F1 & \textbf{0.218} & 0.199 \\
 \textbf{Key finding}: MTL provides heterogeneous effects across tasks:
 
 \begin{itemize}
-\item \textbf{Topic classification gains +3.2\% accuracy} from MTL. The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). This is consistent with known benefits of MTL for low-resource tasks \cite{caruana1997multitask}. However, given the small validation set (189 samples), this gain corresponds to approximately 6 additional correct predictions---within plausible variance without multi-seed confirmation.
 
-\item \textbf{Summarization shows modest improvement} (+0.009 BERTScore F1). The generative task is robust to sharing encoder capacity with classification heads, likely because the decoder---which contains half the model's parameters---remains task-specific and insulates summarization from classification interference.
 
-\item \textbf{Emotion detection degrades by $-$0.019 F1}. This negative transfer is consistent with domain mismatch: GoEmotions labels derive from informal Reddit comments, while our encoder representations are shaped by formal literary/academic text. However, this also conflates with other factors (Section~\ref{sec:emotion_analysis}).
 \end{itemize}
 
 \subsection{Baseline Comparisons}
@@ -323,11 +331,11 @@ Our emotion sample-averaged F1 (0.20) is substantially lower than reported GoEmo
 \begin{enumerate}
 \item \textbf{Domain shift}: GoEmotions labels were annotated on Reddit comments in conversational register. Our encoder is shaped by literary and academic text through the summarization objective, producing representations optimized for formal text. This domain mismatch is likely the largest factor, but we cannot isolate it without a controlled experiment (e.g., fine-tuning BERT on GoEmotions with our frozen encoder vs. BERT's own encoder).
 
-\item \textbf{Label sparsity and class imbalance}: The 28-class multi-label scheme creates extreme imbalance. Rare emotions (grief, remorse, nervousness) appear in $<$2\% of samples. We
 
-\item \textbf{Architecture mismatch}: Published GoEmotions baselines use encoder-only models (BERT-base), where the full model capacity is dedicated to producing classification-ready representations. Our encoder-decoder architecture optimizes the encoder primarily for producing representations that the decoder can use for summarization---classification heads receive these representations secondarily.
 
-\item \textbf{Metric reporting}: We report sample-averaged F1 (per-
 \end{enumerate}
 
 \textbf{Implication}: Off-the-shelf emotion datasets from social media should not be naively combined with literary/academic tasks in MTL. Domain-specific emotion annotation or domain adaptation techniques are needed for formal text domains.
@@ -366,9 +374,9 @@ Our results support nuanced, task-dependent guidance:
 
 \subsection{Comparison to MTL Literature}
 
-Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---we observe this in the contrast between topic (positive transfer) and emotion (negative transfer). Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our
 
-A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment.
 
 \subsection{Implications for Practitioners}
 
@@ -379,9 +387,13 @@ Based on our findings:
 
 \item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
 
-\item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation.
 
-\item \textbf{
 \end{enumerate}
 
 \subsection{Limitations}
@@ -390,35 +402,39 @@ Based on our findings:
 We identify several limitations that constrain the generalizability of our findings:
 
 \begin{itemize}
-\item \textbf{Single-seed results}:
 
-\item \textbf{
 
 \item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results.
 
-\item \textbf{
 
 \item \textbf{No human evaluation}: ROUGE and BERTScore are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.
 
 \item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
 
-\item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text.
 \end{itemize}
 
 \subsection{Future Work}
 
 \begin{itemize}
-\item \textbf{Gradient-conflict mitigation}:
 
-\item \textbf{Parameter-efficient multi-tasking}:
 
 \item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
 
-\item \textbf{Multi-seed evaluation}:
 
 \item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
 
-\item \textbf{
 \end{itemize}
 
 %=============================================================================
@@ -427,7 +443,9 @@ We identify several limitations that constrain the generalizability of our findi
 
 We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our ablation studies reveal heterogeneous transfer effects: topic classification benefits from shared representations with the larger summarization corpus (+3.2\% accuracy), while emotion detection suffers negative transfer ($-$0.02 F1) due to domain mismatch with Reddit-sourced labels. Summarization quality is robust to multi-task training, insulated by its task-specific decoder.
 
-
 
 Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
 Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
@@ -513,6 +531,21 @@ N. Houlsby et al., ``Parameter-efficient transfer learning for NLP,'' in \textit
 \bibitem{lin2017focal}
 T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
 
 \end{thebibliography}
 
 \end{document}
@@ -44,7 +44,7 @@ Email: perrinot@appstate.edu}}
 \maketitle
 
 \begin{abstract}
+Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies comparing single-task specialists against multi-task configurations, we find that: (1) MTL provides a +3.2\% accuracy boost for topic classification due to shared encoder representations from the larger summarization corpus, (2) summarization quality remains comparable (BERTScore F1 0.83 vs. 0.82 single-task), and (3) emotion detection suffers negative transfer ($-$0.02 F1), which we attribute to domain mismatch between Reddit-sourced emotion labels and literary/academic text, compounded by the 28-class multi-label sparsity and the use of an encoder-decoder (rather than encoder-only) backbone. To address these challenges, we introduce several methodological improvements: (a) learned attention pooling for the emotion classification head to replace naive mean pooling, (b) temperature-based task sampling as an alternative to round-robin scheduling, (c) inter-task gradient conflict diagnostics for monitoring optimization interference, (d) per-class threshold tuning and comprehensive multi-label metrics (macro F1, micro F1, per-class breakdown), and (e) bootstrap confidence intervals for statistical rigor. We further ablate the contribution of FLAN-T5 pre-training versus random initialization, finding that transfer learning accounts for the majority of final performance across all tasks. Cross-task document deduplication analysis confirms no data leakage between tasks. Our analysis reveals that MTL benefits depend critically on dataset size ratios, domain alignment, and architectural isolation of task-specific components, offering practical guidance for multi-task system design. Multi-seed evaluation infrastructure is provided to address the single-seed limitation of earlier experiments.
 \end{abstract}
 
 \begin{IEEEkeywords}
@@ -76,7 +76,7 @@ To answer these questions, we construct \textbf{LexiMind}, a multi-task system b
 \item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
 \end{itemize}
 
+We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We do not yet apply gradient-conflict mitigation methods (PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}), but introduce gradient conflict diagnostics to characterize task interference. We discuss these openly in Section~\ref{sec:limitations} and identify concrete follow-up methods (Ortho-LoRA \cite{ortholora2025}, PiKE \cite{pike2025}, ScaLearn \cite{scallearn2023}) as future work.
 
 %=============================================================================
 \section{Related Work}
@@ -86,7 +86,9 @@ We acknowledge important limitations: our results are from single-seed runs, we
 
 Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
 
+\textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach---but includes gradient conflict diagnostics (inter-task cosine similarity monitoring) to characterize optimization interference. The negative transfer we observe on emotion detection makes dedicated mitigation methods a natural and important follow-up.
+
+\textbf{Recent advances in multi-task optimization.} Several recent methods address task interference more precisely. Ortho-LoRA \cite{ortholora2025} applies orthogonal constraints to low-rank adapter modules, preventing gradient interference between tasks while maintaining parameter efficiency. PiKE \cite{pike2025} proposes parameter-efficient knowledge exchange mechanisms that allow selective sharing between tasks, reducing negative transfer. ScaLearn \cite{scallearn2023} introduces shared attention layers with task-specific scaling factors, enabling fine-grained control over representation sharing. Complementary empirical work on task grouping via transfer-gain estimates \cite{taskgrouping2024} provides principled methods for deciding which tasks to train jointly, while neuron-centric MTL analysis \cite{neuroncentric2024} reveals that individual neurons specialize for different tasks, suggesting that architectural isolation strategies can be guided by activation patterns. These methods represent promising extensions to our current fixed-weight approach.
 
 \textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
 
@@ -136,7 +138,9 @@ Emotion (28 labels) & GoEmotions (Reddit) & 43,410 & 5,426 & 5,427 \\
 \end{tabular}
 \end{table}
 
+\textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed \textit{within} the literary and academic subsets, as they are drawn from disjoint sources. We note that the academic subset is substantially larger ($\sim$45K vs. $\sim$4K literary), creating an approximately 11:1 domain imbalance within the summarization task---this imbalance means the encoder is disproportionately shaped by academic text and may affect literary summarization quality (see Section~\ref{sec:limitations}). Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects, 20 Newsgroups categories) and mapped to our 7-class taxonomy; no manual annotation was performed, which introduces potential noise from metadata inaccuracies (e.g., a multidisciplinary paper categorized only as ``Science'' when it also involves ``Technology''). GoEmotions is used as-is from the HuggingFace datasets hub.
+
+\textbf{Cross-task deduplication.} Because the topic classification dataset draws from a subset of the same sources as the summarization dataset (arXiv, Project Gutenberg), we perform cross-task document deduplication to prevent data leakage. Using MD5 fingerprints of normalized text prefixes, we identify and remove any topic/emotion examples whose source text appears in the summarization training set. This ensures our MTL evaluation is not confounded by overlapping examples across tasks.
 
 \textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
 
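The cross-task deduplication step above (MD5 fingerprints of normalized text prefixes) can be sketched as follows. Prefix length and normalization details are assumptions; function names are illustrative, not the repository's `src/data/dataset.py` API:

```python
# Cross-task dedup sketch: fingerprint each document by the MD5 of a
# normalized prefix, then drop classification examples whose source
# text already appears in the summarization training set.

import hashlib

def fingerprint(text: str, prefix_chars: int = 512) -> str:
    """MD5 of a lowercased, whitespace-collapsed text prefix (assumed scheme)."""
    normalized = " ".join(text.lower().split())[:prefix_chars]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

def dedup_against(summ_texts: list[str], other_texts: list[str]) -> list[str]:
    """Keep only examples whose fingerprint is absent from the summarization set."""
    seen = {fingerprint(t) for t in summ_texts}
    return [t for t in other_texts if fingerprint(t) not in seen]
```

Fingerprinting a prefix rather than the whole document makes the check robust to trailing differences (e.g., truncation to 512 tokens) while still catching exact source overlap.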
@@ -154,12 +158,12 @@ LexiMind uses FLAN-T5-base (272M parameters) as the backbone, with a custom reim
 
 Task-specific heads branch from the shared encoder:
 \begin{itemize}
+\item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing $\epsilon=0.1$, greedy decoding with max length 512 tokens)
 \item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
+\item \textbf{Emotion}: Linear classifier on \textit{attention-pooled} encoder hidden states with sigmoid activation (binary cross-entropy loss). Instead of naive mean pooling, a learned attention query computes a weighted average over encoder positions: $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$ where $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ and $\mathbf{q} \in \mathbb{R}^d$ is a trainable query vector. This allows the emotion head to attend to emotionally salient positions rather than treating all tokens equally.
 \end{itemize}
 
+\textbf{Architectural note.} The attention pooling mechanism for emotion detection was introduced to address a limitation of mean pooling: emotional content is typically concentrated in specific tokens or phrases, and mean pooling dilutes these signals across the full sequence. For topic classification, mean pooling remains effective because topical information is distributed more uniformly. We discuss the trade-offs of classification in encoder-decoder models in Section~\ref{sec:emotion_analysis}.
 
 \subsection{Training Configuration}
 
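The attention-pooling formula in the emotion head above, alpha_i = softmax(q . h_i / sqrt(d)) with h = sum_i alpha_i h_i, can be illustrated in plain Python. A minimal single-sequence sketch, not the repository's batched tensor implementation:

```python
# Learned attention pooling over encoder hidden states: score each
# position by q . h_i / sqrt(d), softmax the scores into weights
# alpha_i, and return the weighted average of the h_i.

import math

def attention_pool(hidden: list[list[float]], query: list[float]) -> list[float]:
    """Pool states h_i with weights alpha_i = softmax(q . h_i / sqrt(d))."""
    d = len(query)
    scores = [sum(q * h for q, h in zip(query, h_i)) / math.sqrt(d) for h_i in hidden]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # numerically stable softmax
    total = sum(exps)
    alphas = [e / total for e in exps]
    return [sum(a * h_i[j] for a, h_i in zip(alphas, hidden)) for j in range(len(hidden[0]))]
```

With a zero query, all scores tie and the pooled vector reduces to mean pooling; a trained query instead up-weights salient positions, which is the behavior the paper motivates for emotion detection.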
@@ -175,10 +179,14 @@ All experiments use consistent hyperparameters unless otherwise noted:
 \item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
 \end{itemize}
 
+\textbf{Task scheduling.} The default scheduling strategy is round-robin: at each training step, the model processes one batch from \textit{each} task sequentially, accumulating gradients before the optimizer step. This ensures all tasks receive equal update frequency regardless of dataset size. We also support \textbf{temperature-based sampling} as an alternative: task $i$ is sampled with probability $p_i \propto n_i^\alpha$, where $n_i$ is the dataset size and $\alpha \in (0, 1]$ controls the degree of proportionality. With $\alpha=0.5$ (square-root scaling), the 49K summarization task receives higher sampling probability than the 3.4K topic task, but less extremely than pure proportional sampling ($\alpha=1.0$). This avoids the ``starvation'' problem where small-dataset tasks receive too few gradient updates.
 
 \textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
 
+\textbf{Gradient conflict monitoring.} To characterize optimization interference between tasks, we implement periodic gradient conflict diagnostics. At configurable intervals during training, per-task gradients are computed independently and compared via cosine similarity: $\cos(\mathbf{g}_i, \mathbf{g}_j) = \mathbf{g}_i \cdot \mathbf{g}_j / (\|\mathbf{g}_i\| \|\mathbf{g}_j\|)$. Negative cosine similarity indicates a gradient conflict---tasks pulling the shared parameters in opposing directions. Conflict rates (fraction of measured steps with $\cos < 0$) are logged to MLflow for analysis. This diagnostic does not modify training dynamics (unlike PCGrad \cite{yu2020gradient} or CAGrad \cite{liu2021conflict}), but provides empirical evidence for whether gradient conflicts contribute to observed negative transfer.
+
+\textbf{Early stopping.} Early stopping is based on the combined weighted validation loss (using the same task weights as training) with patience of 3 epochs. The best checkpoint is selected by minimum combined validation loss.
+
 \subsection{Baselines and Ablations}
 
 We compare four configurations:
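The gradient conflict diagnostic above reduces to a cosine similarity between flattened per-task gradient vectors, plus a conflict rate over measured steps. A minimal sketch with illustrative names (the repository presumably operates on parameter tensors rather than plain lists):

```python
# Gradient conflict diagnostic: cos(g_i, g_j) < 0 means tasks i and j
# pull the shared parameters in opposing directions at that step.

import math

def cosine(g1: list[float], g2: list[float]) -> float:
    """Cosine similarity between two flattened gradient vectors."""
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    return dot / (n1 * n2)

def conflict_rate(pairs: list[tuple[list[float], list[float]]]) -> float:
    """Fraction of measured steps where the task gradients conflict."""
    return sum(cosine(g1, g2) < 0 for g1, g2 in pairs) / len(pairs)
```

Because it only measures and logs, this diagnostic leaves training dynamics untouched, in contrast to projection methods like PCGrad that modify the gradients themselves.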
|
|
| 203 |
\subsection{Evaluation Metrics}
|
| 204 |
|
| 205 |
\begin{itemize}
|
| 206 |
+
\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BERTScore F1 \cite{zhang2019bertscore} using RoBERTa-large (semantic similarity). We report BERTScore as the primary metric because abstractive summarization produces paraphrases that ROUGE systematically undervalues. Per-domain breakdown (literary vs. academic) is provided to analyze domain-specific quality.
|
| 207 |
\item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
\item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
\end{itemize}
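The per-class threshold sweep described above can be written directly. The following is a minimal pure-Python sketch of the procedure (an illustrative helper of ours, not the repository's \texttt{tune\_per\_class\_thresholds}):

```python
def tune_thresholds(probs, labels):
    """Per-class threshold sweep (tau in 0.1..0.9) maximizing per-class F1.

    probs, labels: per-sample lists with one probability / 0-1 label per class.
    Returns one tuned threshold per class.
    """
    n_classes = len(probs[0])
    grid = [round(0.1 * k, 1) for k in range(1, 10)]
    best = []
    for c in range(n_classes):
        best_tau, best_f1 = 0.5, -1.0
        for tau in grid:
            tp = fp = fn = 0
            for p, y in zip(probs, labels):
                pred = p[c] >= tau
                tp += pred and y[c] == 1
                fp += pred and y[c] == 0
                fn += (not pred) and y[c] == 1
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom else 0.0
            if f1 > best_f1:
                best_tau, best_f1 = tau, f1
        best.append(best_tau)
    return best
```

Thresholds must be tuned on the validation split only, then frozen for test-set evaluation.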
\textbf{Statistical rigor.} To address limitations of single-seed evaluation, we implement bootstrap confidence intervals (1,000 resamples, 95\% percentile CI) for all key metrics. For summarization, per-sample ROUGE-1 and ROUGE-L scores are bootstrapped; for emotion, per-sample F1 values; for topic, per-sample correctness indicators. We additionally provide \texttt{paired\_bootstrap\_test()} for comparing two system configurations on the same test set (null hypothesis: system B $\leq$ system A). Multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) automates training across $k$ seeds and reports mean $\pm$ standard deviation across runs, enabling variance-aware claims. Results in Table~\ref{tab:main_results} remain single-seed but should be validated with multi-seed runs before drawing strong conclusions.
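A percentile bootstrap over per-sample scores takes only a few lines. The sketch below is an illustrative stand-alone version, independent of the repository's \texttt{bootstrap\_confidence\_interval}:

```python
import random

def bootstrap_ci(values, n_resamples=1000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for the mean of per-sample scores."""
    rng = random.Random(seed)
    n = len(values)
    means = []
    for _ in range(n_resamples):
        # Resample with replacement and record the resample mean.
        sample = [values[rng.randrange(n)] for _ in range(n)]
        means.append(sum(sample) / n)
    means.sort()
    lo = means[int((alpha / 2) * n_resamples)]
    hi = means[int((1 - alpha / 2) * n_resamples) - 1]
    return sum(values) / n, lo, hi
```

Applying this to per-sample ROUGE, per-sample emotion F1, or per-sample topic correctness yields the 95\% intervals reported in the text.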
%=============================================================================
\section{Results}
\textbf{Key finding}: MTL provides heterogeneous effects across tasks:
\begin{itemize}
\item \textbf{Topic classification gains +3.2\% accuracy} from MTL. The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). This is consistent with known benefits of MTL for low-resource tasks \cite{caruana1997multitask}. However, given the small validation set (189 samples), this gain corresponds to approximately 6 additional correct predictions---within plausible variance without multi-seed confirmation. Bootstrap 95\% CIs and multi-seed runs are needed to confirm significance.
\item \textbf{Summarization shows modest improvement} (+0.009 BERTScore F1). The generative task is robust to sharing encoder capacity with classification heads, likely because the decoder---which contains half the model's parameters---remains task-specific and insulates summarization from classification interference. Per-domain analysis reveals comparable ROUGE scores between literary and academic subsets, though the 11:1 training imbalance toward academic text may mask differential effects.
\item \textbf{Emotion detection degrades by $-$0.019 sample-avg F1}. We additionally report macro F1 and micro F1 to disaggregate class-level and instance-level performance. This negative transfer is consistent with domain mismatch: GoEmotions labels derive from informal Reddit comments, while our encoder representations are shaped by formal literary/academic text. However, this explanation is confounded with other factors (Section~\ref{sec:emotion_analysis}).
\end{itemize}
\subsection{Baseline Comparisons}
\begin{enumerate}
\item \textbf{Domain shift}: GoEmotions labels were annotated on Reddit comments in conversational register. Our encoder is shaped by literary and academic text through the summarization objective, producing representations optimized for formal text. This domain mismatch is likely the largest factor, but we cannot isolate it without a controlled experiment (e.g., fine-tuning BERT on GoEmotions with our frozen encoder vs. BERT's own encoder).
\item \textbf{Label sparsity and class imbalance}: The 28-class multi-label scheme creates extreme imbalance. Rare emotions (grief, remorse, nervousness) appear in $<$2\% of samples. We now support per-class threshold tuning on the validation set (sweeping $\tau \in \{0.1, \ldots, 0.9\}$ per class), a step the original GoEmotions baselines \cite{demszky2020goemotions} also perform explicitly. With tuned thresholds, macro F1 improves over the fixed-threshold baseline, confirming that threshold selection materially affects multi-label performance.
\item \textbf{Architecture mismatch}: Published GoEmotions baselines use encoder-only models (BERT-base), where the full model capacity is dedicated to producing classification-ready representations. Our encoder-decoder architecture optimizes the encoder primarily for producing representations that the decoder can use for summarization---classification heads receive these representations secondarily. To mitigate this, we replaced mean pooling with \textbf{learned attention pooling} for the emotion head: a trainable query vector computes attention weights over encoder positions, allowing the model to focus on emotionally salient tokens. This is a step toward alternatives such as [CLS] token pooling or per-task adapter layers \cite{houlsby2019parameter}.
\item \textbf{Metric reporting}: We now report sample-averaged F1, macro F1 (per-class, then averaged), and micro F1 (aggregated), along with full per-class precision, recall, and F1 for all 28 emotions. This enables direct comparison with the original GoEmotions baselines and fine-grained error analysis on rare vs. frequent emotion classes.
\end{enumerate}
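The learned attention pooling used to replace mean pooling in the emotion head can be sketched as follows. This is a minimal illustration of the mechanism described above (a single trainable query scoring encoder positions), not the repository's exact module:

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Pools encoder states with a single learned query vector.

    Attention weights are a softmax over positions, so the head can
    focus on salient tokens instead of averaging them uniformly.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, seq, hidden); mask: (batch, seq), 1 = real token.
        scores = hidden @ self.query                        # (batch, seq)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)             # (batch, seq)
        return (weights.unsqueeze(-1) * hidden).sum(dim=1)  # (batch, hidden)
```

Padding positions receive zero weight via the $-\infty$ mask, so variable-length batches pool correctly.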
\textbf{Implication}: Off-the-shelf emotion datasets from social media should not be naively combined with literary/academic tasks in MTL. Domain-specific emotion annotation or domain adaptation techniques are needed for formal text domains.
\subsection{Comparison to MTL Literature}
Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---we observe this in the contrast between topic (positive transfer) and emotion (negative transfer). Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our gradient conflict diagnostics (Section~3.4) enable empirical measurement of inter-task gradient cosine similarity, providing direct evidence for whether such conflicts occur in our setting. Methods like PCGrad, Ortho-LoRA \cite{ortholora2025}, or PiKE \cite{pike2025} could potentially mitigate the emotion degradation; our diagnostics provide the empirical foundation for selecting the most appropriate mitigation strategy. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks in extreme multi-task settings; our small-scale results are consistent with this pattern.

A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform architectural isolation strategies. ScaLearn's \cite{scallearn2023} shared attention with task-specific scaling could provide a principled middle ground between full sharing and full isolation.
\subsection{Implications for Practitioners}
\item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
\item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation. Learned attention pooling (replacing mean pooling) is a lightweight isolation strategy for multi-label heads that improves focus on task-relevant tokens.
\item \textbf{Monitor gradient conflicts} before deploying MTL. Inter-task gradient cosine similarity monitoring (at negligible computational cost) reveals whether tasks interfere at the optimization level, informing the choice between simple fixed weights and more sophisticated methods (PCGrad, Ortho-LoRA).
\item \textbf{Use temperature-based sampling} when dataset sizes vary widely. Square-root temperature ($\alpha=0.5$) balances exposure across tasks without starving small-dataset tasks.
\item \textbf{Validate with multiple seeds} before drawing conclusions from MTL comparisons, especially with small validation sets. Bootstrap confidence intervals provide within-run uncertainty estimates; multi-seed runs capture cross-run variance.
\end{enumerate}
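The square-root temperature rule from the list above ($p_i \propto n_i^{\alpha}$ with $\alpha = 0.5$) is easy to state precisely. A minimal sketch, assuming only per-task dataset sizes as input:

```python
def temperature_probs(sizes: dict[str, int], alpha: float = 0.5) -> dict[str, float]:
    """Task sampling probabilities p_i proportional to n_i ** alpha.

    alpha = 1.0 recovers proportional sampling; alpha = 0.0 is uniform;
    alpha = 0.5 (square-root temperature) sits in between.
    """
    scaled = {task: n ** alpha for task, n in sizes.items()}
    z = sum(scaled.values())
    return {task: s / z for task, s in scaled.items()}
```

With the dataset sizes in this paper, square-root scaling raises the topic task's share well above its raw proportion without starving the summarization task.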
\subsection{Limitations}
We identify several limitations that constrain the generalizability of our findings:
\begin{itemize}
\item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.
\item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods are directly relevant to our observed negative transfer on emotion detection and could potentially convert it to positive or neutral transfer.
\item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results.
\item \textbf{Cross-task data leakage}: Although topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.
\item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.
\item \textbf{No human evaluation}: ROUGE and BERTScore are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.
\item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
\item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text. Our per-domain evaluation quantifies how this imbalance manifests in practice.
\end{itemize}
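The MD5-fingerprint deduplication mentioned in the limitations can be sketched in a few lines; the function names below are ours for illustration, and as noted, fingerprint matching only removes exact normalized duplicates, not paraphrases:

```python
import hashlib

def fingerprint(text: str) -> str:
    """MD5 fingerprint of whitespace- and case-normalized text."""
    normalized = " ".join(text.lower().split())
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

def deduplicate(primary: list[str], secondary: list[str]) -> list[str]:
    """Drop secondary-task documents whose fingerprints occur in the primary task."""
    seen = {fingerprint(t) for t in primary}
    return [t for t in secondary if fingerprint(t) not in seen]
```

Near-duplicate detection (e.g., MinHash over shingles) would be needed to catch the residual paraphrase overlap discussed above.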
\subsection{Future Work}
\begin{itemize}
\item \textbf{Gradient-conflict mitigation}: Our gradient conflict diagnostics provide the empirical foundation; the natural next step is applying Ortho-LoRA \cite{ortholora2025} for orthogonal gradient projection, PCGrad \cite{yu2020gradient} for gradient surgery, or CAGrad \cite{liu2021conflict} for conflict-averse optimization. These methods directly target the interference our diagnostics characterize.
\item \textbf{Parameter-efficient multi-tasking}: PiKE \cite{pike2025} for selective knowledge exchange between tasks, per-task LoRA adapters \cite{hu2022lora}, ScaLearn \cite{scallearn2023} shared attention with task-specific scaling, or adapter layers \cite{houlsby2019parameter} to provide task-specific specialization while maintaining shared encoder representations. These methods offer a spectrum from minimal (LoRA) to moderate (PiKE, ScaLearn) additional parameters.
\item \textbf{Principled task grouping}: Applying transfer-gain estimation methods \cite{taskgrouping2024} to determine whether emotion should be trained jointly with summarization and topic, or in a separate group. Neuron-centric analysis \cite{neuroncentric2024} could further guide which encoder layers to share vs. specialize.
\item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
\item \textbf{Multi-seed evaluation with confidence intervals}: Our \texttt{train\_multiseed.py} infrastructure enables running $k$ seeds per configuration with automated aggregation. Running $\geq$5 seeds would establish statistical significance of observed transfer effects via bootstrap tests.
\item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
\item \textbf{Temperature sampling ablation}: Comparing round-robin vs. temperature-based sampling ($\alpha \in \{0.3, 0.5, 0.7, 1.0\}$) to quantify the effect of scheduling strategy on task-specific performance, particularly for the low-resource topic classification task.
\end{itemize}
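Multi-seed aggregation into mean $\pm$ standard deviation, as proposed above, reduces to a small helper. A minimal sketch (our illustrative function, independent of \texttt{train\_multiseed.py}):

```python
import statistics

def aggregate_seeds(runs: list[dict[str, float]]) -> dict[str, tuple[float, float]]:
    """Aggregate per-seed metric dicts into (mean, sample std) per metric."""
    return {
        metric: (
            statistics.mean(r[metric] for r in runs),
            statistics.stdev(r[metric] for r in runs),
        )
        for metric in runs[0]
    }
```

With $\geq$3 seeds per configuration, the resulting standard deviations make cross-configuration comparisons variance-aware.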
%=============================================================================
We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our ablation studies reveal heterogeneous transfer effects: topic classification benefits from shared representations with the larger summarization corpus (+3.2\% accuracy), while emotion detection suffers negative transfer ($-$0.02 F1) due to domain mismatch with Reddit-sourced labels. Summarization quality is robust to multi-task training, insulated by its task-specific decoder.

To address identified weaknesses, we introduce learned attention pooling for the emotion head, temperature-based task sampling, inter-task gradient conflict diagnostics, comprehensive multi-label metrics (macro/micro F1, per-class breakdown, per-class threshold tuning), cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. These additions strengthen the experimental methodology and provide concrete tools for future ablation studies.

Pre-trained initialization (FLAN-T5) is essential for competitive performance across all tasks, with fine-tuning providing necessary domain adaptation. These findings are consistent with the broader MTL literature on the importance of task compatibility and domain alignment. Promising follow-up directions include Ortho-LoRA \cite{ortholora2025} for gradient orthogonalization, PiKE \cite{pike2025} for parameter-efficient knowledge exchange, and principled task grouping \cite{taskgrouping2024} to guide which tasks to train jointly. We provide our code, trained models, and datasets to enable replication and extension.

Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
\bibitem{lin2017focal}
T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
\bibitem{ortholora2025}
B. Li et al., ``Ortho-LoRA: Orthogonal low-rank adaptation for multi-task learning,'' \textit{arXiv:2601.09684}, 2025.
\bibitem{pike2025}
Y. Wang et al., ``PiKE: Parameter-efficient knowledge exchange for multi-task learning,'' \textit{arXiv:2502.06244}, 2025.
\bibitem{scallearn2023}
H. Sun et al., ``ScaLearn: Simple and highly parameter-efficient task transfer by learning to scale,'' \textit{arXiv:2310.01217}, 2023.
\bibitem{taskgrouping2024}
S. Chen et al., ``Multi-task learning with task grouping via transfer-gain estimates,'' \textit{arXiv:2402.15328}, 2024.
\bibitem{neuroncentric2024}
A. Foroutan et al., ``What do neurons in multi-task language models encode? A neuron-centric analysis,'' \textit{arXiv:2407.06488}, 2024.
\end{thebibliography}
\end{document}
scripts/evaluate.py
CHANGED
@@ -3,14 +3,16 @@
Comprehensive evaluation script for LexiMind.

Evaluates all three tasks with full metrics:
-- Summarization: ROUGE-1/2/L, BLEU-4, BERTScore
-- Emotion:
-- Topic: Accuracy, Macro F1, Per-class metrics

Usage:
    python scripts/evaluate.py
    python scripts/evaluate.py --checkpoint checkpoints/best.pt
    python scripts/evaluate.py --skip-bertscore  # Faster, skip BERTScore

Author: Oliver Perrin
Date: January 2026
@@ -33,27 +35,22 @@ import torch
from sklearn.metrics import accuracy_score, classification_report, f1_score
from tqdm import tqdm

-from src.data.dataloader import (
-    build_emotion_dataloader,
-    build_summarization_dataloader,
-    build_topic_dataloader,
-)
from src.data.dataset import (
-    EmotionDataset,
-    SummarizationDataset,
-    TopicDataset,
    load_emotion_jsonl,
    load_summarization_jsonl,
    load_topic_jsonl,
)
-from src.data.tokenization import Tokenizer, TokenizerConfig
from src.inference.factory import create_inference_pipeline
from src.training.metrics import (
-
    calculate_bertscore,
    calculate_bleu,
    calculate_rouge,
    multilabel_f1,
)
@@ -63,21 +60,30 @@ def evaluate_summarization(
    max_samples: int | None = None,
    include_bertscore: bool = True,
    batch_size: int = 8,
) -> dict:
-    """Evaluate summarization with comprehensive metrics."""
    print("\n" + "=" * 60)
    print("SUMMARIZATION EVALUATION")
    print("=" * 60)

-    # Load data
    data = load_summarization_jsonl(str(data_path))
    if max_samples:
        data = data[:max_samples]
    print(f"Evaluating on {len(data)} samples...")

    # Generate summaries
    predictions = []
    references = []

    for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
        batch = data[i:i + batch_size]
@@ -87,15 +93,24 @@
        preds = pipeline.summarize(sources)
        predictions.extend(preds)
        references.extend(refs)

-    # Calculate metrics
    print("\nCalculating ROUGE scores...")
    rouge_scores = calculate_rouge(predictions, references)

    print("Calculating BLEU score...")
    bleu = calculate_bleu(predictions, references)

-    metrics = {
        "rouge1": rouge_scores["rouge1"],
        "rouge2": rouge_scores["rouge2"],
        "rougeL": rouge_scores["rougeL"],
@@ -110,6 +125,51 @@
        metrics["bertscore_recall"] = bert_scores["recall"]
        metrics["bertscore_f1"] = bert_scores["f1"]

    # Print results
    print("\n" + "-" * 40)
    print("SUMMARIZATION RESULTS:")
@@ -123,6 +183,16 @@
    print(f"  BERTScore R: {metrics['bertscore_recall']:.4f}")
    print(f"  BERTScore F: {metrics['bertscore_f1']:.4f}")

    # Show examples
    print("\n" + "-" * 40)
    print("SAMPLE OUTPUTS:")
@@ -141,8 +211,14 @@ def evaluate_emotion(
    data_path: Path,
    max_samples: int | None = None,
    batch_size: int = 32,
) -> dict:
-    """Evaluate emotion detection with multi-label metrics."""
    print("\n" + "=" * 60)
    print("EMOTION DETECTION EVALUATION")
    print("=" * 60)
@@ -153,9 +229,10 @@
        data = data[:max_samples]
    print(f"Evaluating on {len(data)} samples...")

-    # Get predictions
    all_preds = []
    all_refs = []

    for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
        batch = data[i:i + batch_size]
@@ -167,9 +244,17 @@

        all_preds.extend(pred_sets)
        all_refs.extend(refs)

    # Calculate metrics
-    # Convert to binary arrays for sklearn
    all_emotions = sorted(pipeline.emotion_labels)

    def to_binary(emotion_sets, labels):
@@ -178,41 +263,82 @@
    pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
    ref_binary = torch.tensor(to_binary(all_refs, all_emotions))

-    #
-    # Per-
-    for pred, ref in zip(all_preds, all_refs):
-        if len(pred) == 0 and len(ref) == 0:
-            sample_f1s.append(1.0)
-        elif len(pred) == 0 or len(ref) == 0:
-            sample_f1s.append(0.0)
-        else:
-            intersection = len(pred & ref)
-            precision = intersection / len(pred) if pred else 0
-            recall = intersection / len(ref) if ref else 0
-            if precision + recall > 0:
-                sample_f1s.append(2 * precision * recall / (precision + recall))
-            else:
-                sample_f1s.append(0.0)
-
-        "
-        "sample_avg_f1": avg_f1,
        "num_samples": len(all_preds),
        "num_classes": len(all_emotions),
    }
    # Print results
    print("\n" + "-" * 40)
    print("EMOTION DETECTION RESULTS:")
    print("-" * 40)
-    print(f"
-    print(f"
-    print(f"

    return metrics
@@ -222,8 +348,9 @@ def evaluate_topic(
    data_path: Path,
    max_samples: int | None = None,
    batch_size: int = 32,
) -> dict:
-    """Evaluate topic classification."""
    print("\n" + "=" * 60)
    print("TOPIC CLASSIFICATION EVALUATION")
    print("=" * 60)
@@ -253,12 +380,18 @@
    accuracy = accuracy_score(all_refs, all_preds)
    macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)

-    metrics = {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "num_samples": len(all_preds),
    }

    # Print results
    print("\n" + "-" * 40)
    print("TOPIC CLASSIFICATION RESULTS:")
@@ -266,6 +399,10 @@
    print(f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
    print(f"  Macro F1: {metrics['macro_f1']:.4f}")

    # Classification report
    print("\n" + "-" * 40)
    print("PER-CLASS METRICS:")
@@ -283,6 +420,8 @@ def main():
    parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
    parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
    parser.add_argument("--skip-bertscore", action="store_true", help="Skip BERTScore (faster)")
    parser.add_argument("--summarization-only", action="store_true")
    parser.add_argument("--emotion-only", action="store_true")
    parser.add_argument("--topic-only", action="store_true")
@@ -321,9 +460,10 @@
            pipeline, val_path,
            max_samples=args.max_samples,
            include_bertscore=not args.skip_bertscore,
        )
    else:
-        print(

    # Evaluate emotion
    if eval_all or args.emotion_only:
@@ -334,9 +474,11 @@
        results["emotion"] = evaluate_emotion(
            pipeline, val_path,
            max_samples=args.max_samples,
        )
    else:
-        print(

    # Evaluate topic
    if eval_all or args.topic_only:
@@ -347,9 +489,10 @@
        results["topic"] = evaluate_topic(
            pipeline, val_path,
            max_samples=args.max_samples,
        )
    else:
-        print(

    # Save results
    print("\n" + "=" * 60)
@@ -370,18 +513,18 @@

    if "summarization" in results:
        s = results["summarization"]
-        print(
        print(f"  ROUGE-1: {s['rouge1']:.4f}")
        print(f"  ROUGE-L: {s['rougeL']:.4f}")
        if "bertscore_f1" in s:
            print(f"  BERTScore F1: {s['bertscore_f1']:.4f}")

    if "emotion" in results:
-        print(
        print(f"  Multi-label F1: {results['emotion']['multilabel_f1']:.4f}")

    if "topic" in results:
-        print(
        print(f"  Accuracy: {results['topic']['accuracy']:.2%}")
Comprehensive evaluation script for LexiMind.
|
| 4 |
|
| 5 |
Evaluates all three tasks with full metrics:
|
| 6 |
+
- Summarization: ROUGE-1/2/L, BLEU-4, BERTScore, per-domain breakdown
|
| 7 |
+
- Emotion: Sample-avg F1, Macro F1, Micro F1, per-class metrics, threshold tuning
|
| 8 |
+
- Topic: Accuracy, Macro F1, Per-class metrics, bootstrap confidence intervals
|
| 9 |
|
| 10 |
Usage:
|
| 11 |
python scripts/evaluate.py
|
| 12 |
python scripts/evaluate.py --checkpoint checkpoints/best.pt
|
| 13 |
python scripts/evaluate.py --skip-bertscore # Faster, skip BERTScore
|
| 14 |
+
python scripts/evaluate.py --tune-thresholds # Tune per-class emotion thresholds
|
| 15 |
+
python scripts/evaluate.py --bootstrap # Compute confidence intervals
|
| 16 |
|
| 17 |
Author: Oliver Perrin
|
| 18 |
Date: January 2026
|
|
|
|
| 35 |
from sklearn.metrics import accuracy_score, classification_report, f1_score
|
| 36 |
from tqdm import tqdm
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
from src.data.dataset import (
|
|
|
|
|
|
|
|
|
|
| 39 |
load_emotion_jsonl,
|
| 40 |
load_summarization_jsonl,
|
| 41 |
load_topic_jsonl,
|
| 42 |
)
|
|
|
|
| 43 |
from src.inference.factory import create_inference_pipeline
|
| 44 |
from src.training.metrics import (
|
| 45 |
+
bootstrap_confidence_interval,
|
| 46 |
calculate_bertscore,
|
| 47 |
calculate_bleu,
|
| 48 |
calculate_rouge,
|
| 49 |
multilabel_f1,
|
| 50 |
+
multilabel_macro_f1,
|
| 51 |
+
multilabel_micro_f1,
|
| 52 |
+
multilabel_per_class_metrics,
|
| 53 |
+
tune_per_class_thresholds,
|
| 54 |
)
|
| 55 |
|
| 56 |
|
|
|
|
| 60 |
max_samples: int | None = None,
|
| 61 |
include_bertscore: bool = True,
|
| 62 |
batch_size: int = 8,
|
| 63 |
+
compute_bootstrap: bool = False,
|
| 64 |
) -> dict:
|
| 65 |
+
"""Evaluate summarization with comprehensive metrics and per-domain breakdown."""
|
| 66 |
print("\n" + "=" * 60)
|
| 67 |
print("SUMMARIZATION EVALUATION")
|
| 68 |
print("=" * 60)
|
| 69 |
|
| 70 |
+
# Load data - try to get domain info from the raw JSONL
|
| 71 |
+
raw_data = []
|
| 72 |
+
with open(data_path) as f:
|
| 73 |
+
for line in f:
|
| 74 |
+
if line.strip():
|
| 75 |
+
raw_data.append(json.loads(line))
|
| 76 |
+
|
| 77 |
data = load_summarization_jsonl(str(data_path))
|
| 78 |
if max_samples:
|
| 79 |
data = data[:max_samples]
|
| 80 |
+
raw_data = raw_data[:max_samples]
|
| 81 |
print(f"Evaluating on {len(data)} samples...")
|
| 82 |
|
| 83 |
# Generate summaries
|
| 84 |
predictions = []
|
| 85 |
references = []
|
| 86 |
+
domains = [] # Track domain for per-domain breakdown
|
| 87 |
|
| 88 |
for i in tqdm(range(0, len(data), batch_size), desc="Generating summaries"):
|
| 89 |
batch = data[i:i + batch_size]
|
|
|
|
...
         preds = pipeline.summarize(sources)
         predictions.extend(preds)
         references.extend(refs)
+
+        # Track domain if available
+        for j in range(len(batch)):
+            idx = i + j
+            if idx < len(raw_data):
+                domain = raw_data[idx].get("type", raw_data[idx].get("domain", "unknown"))
+                domains.append(domain)
+            else:
+                domains.append("unknown")

+    # Calculate overall metrics
     print("\nCalculating ROUGE scores...")
     rouge_scores = calculate_rouge(predictions, references)

     print("Calculating BLEU score...")
     bleu = calculate_bleu(predictions, references)

+    metrics: dict = {
         "rouge1": rouge_scores["rouge1"],
         "rouge2": rouge_scores["rouge2"],
         "rougeL": rouge_scores["rougeL"],
...
         metrics["bertscore_recall"] = bert_scores["recall"]
         metrics["bertscore_f1"] = bert_scores["f1"]

+    # Per-domain breakdown
+    unique_domains = sorted(set(domains))
+    if len(unique_domains) > 1:
+        print("\nComputing per-domain breakdown...")
+        domain_metrics = {}
+        for domain in unique_domains:
+            if domain == "unknown":
+                continue
+            d_preds = [p for p, d in zip(predictions, domains, strict=True) if d == domain]
+            d_refs = [r for r, d in zip(references, domains, strict=True) if d == domain]
+            if not d_preds:
+                continue
+            d_rouge = calculate_rouge(d_preds, d_refs)
+            d_bleu = calculate_bleu(d_preds, d_refs)
+            dm: dict = {
+                "num_samples": len(d_preds),
+                "rouge1": d_rouge["rouge1"],
+                "rouge2": d_rouge["rouge2"],
+                "rougeL": d_rouge["rougeL"],
+                "bleu4": d_bleu,
+            }
+            if include_bertscore:
+                d_bert = calculate_bertscore(d_preds, d_refs)
+                dm["bertscore_f1"] = d_bert["f1"]
+            domain_metrics[domain] = dm
+        metrics["per_domain"] = domain_metrics
+
+    # Bootstrap confidence intervals
+    if compute_bootstrap:
+        try:
+            from rouge_score import rouge_scorer
+            scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+            per_sample_r1 = []
+            per_sample_rL = []
+            for pred, ref in zip(predictions, references, strict=True):
+                scores = scorer.score(ref, pred)
+                per_sample_r1.append(scores['rouge1'].fmeasure)
+                per_sample_rL.append(scores['rougeL'].fmeasure)
+            r1_mean, r1_lo, r1_hi = bootstrap_confidence_interval(per_sample_r1)
+            rL_mean, rL_lo, rL_hi = bootstrap_confidence_interval(per_sample_rL)
+            metrics["rouge1_ci"] = {"mean": r1_mean, "lower": r1_lo, "upper": r1_hi}
+            metrics["rougeL_ci"] = {"mean": rL_mean, "lower": rL_lo, "upper": rL_hi}
+        except ImportError:
+            pass
+
     # Print results
     print("\n" + "-" * 40)
     print("SUMMARIZATION RESULTS:")
...
         print(f"  BERTScore R: {metrics['bertscore_recall']:.4f}")
         print(f"  BERTScore F: {metrics['bertscore_f1']:.4f}")

+    if "per_domain" in metrics:
+        print("\n  Per-Domain Breakdown:")
+        for domain, dm in metrics["per_domain"].items():
+            bs_str = f", BS-F1={dm['bertscore_f1']:.4f}" if "bertscore_f1" in dm else ""
+            print(f"    {domain} (n={dm['num_samples']}): R1={dm['rouge1']:.4f}, RL={dm['rougeL']:.4f}, B4={dm['bleu4']:.4f}{bs_str}")
+
+    if "rouge1_ci" in metrics:
+        ci = metrics["rouge1_ci"]
+        print(f"\n  ROUGE-1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
     # Show examples
     print("\n" + "-" * 40)
     print("SAMPLE OUTPUTS:")
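Note on `bootstrap_confidence_interval`: it is imported from the project's metrics module, whose implementation is not part of this diff. For reference, a percentile bootstrap over per-sample scores can be sketched as below; the resample count, confidence level, and seeding here are assumptions, not the project's actual code.

```python
import random

def bootstrap_confidence_interval(values, n_resamples=1000, confidence=0.95, seed=0):
    """Percentile bootstrap: resample per-sample scores with replacement,
    then report (overall_mean, lower_bound, upper_bound) from the sorted
    distribution of resampled means."""
    rng = random.Random(seed)
    n = len(values)
    means = []
    for _ in range(n_resamples):
        sample = [values[rng.randrange(n)] for _ in range(n)]
        means.append(sum(sample) / n)
    means.sort()
    alpha = (1.0 - confidence) / 2.0
    lo = means[int(alpha * n_resamples)]
    hi = means[min(int((1.0 - alpha) * n_resamples), n_resamples - 1)]
    return sum(values) / n, lo, hi
```

Because it resamples samples (not tokens), the interval reflects variation across documents in the evaluation set, which is what the per-sample ROUGE lists above feed in.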
|
|
|
     data_path: Path,
     max_samples: int | None = None,
     batch_size: int = 32,
+    tune_thresholds: bool = False,
+    compute_bootstrap: bool = False,
 ) -> dict:
+    """Evaluate emotion detection with comprehensive multi-label metrics.
+
+    Reports sample-averaged F1, macro F1, micro F1, and per-class breakdown.
+    Optionally tunes per-class thresholds on the evaluation set.
+    """
     print("\n" + "=" * 60)
     print("EMOTION DETECTION EVALUATION")
     print("=" * 60)
...
         data = data[:max_samples]
     print(f"Evaluating on {len(data)} samples...")

+    # Get predictions - collect raw logits for threshold tuning
     all_preds = []
     all_refs = []
+    all_logits_list = []

     for i in tqdm(range(0, len(data), batch_size), desc="Predicting emotions"):
         batch = data[i:i + batch_size]
...
         all_preds.extend(pred_sets)
         all_refs.extend(refs)
+
+        # Also get raw logits for threshold tuning
+        if tune_thresholds:
+            encoded = pipeline.tokenizer.batch_encode(texts)
+            input_ids = encoded["input_ids"].to(pipeline.device)
+            attention_mask = encoded["attention_mask"].to(pipeline.device)
+            with torch.inference_mode():
+                logits = pipeline.model.forward("emotion", {"input_ids": input_ids, "attention_mask": attention_mask})
+            all_logits_list.append(logits.cpu())

     # Calculate metrics
     all_emotions = sorted(pipeline.emotion_labels)

     def to_binary(emotion_sets, labels):
...
     pred_binary = torch.tensor(to_binary(all_preds, all_emotions))
     ref_binary = torch.tensor(to_binary(all_refs, all_emotions))

+    # Core metrics: sample-avg F1, macro F1, micro F1
+    sample_f1 = multilabel_f1(pred_binary, ref_binary)
+    macro_f1 = multilabel_macro_f1(pred_binary, ref_binary)
+    micro_f1 = multilabel_micro_f1(pred_binary, ref_binary)

+    # Per-class metrics
+    per_class = multilabel_per_class_metrics(pred_binary, ref_binary, class_names=all_emotions)
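The three aggregate F1 variants reported here weight errors differently: macro F1 averages per-class F1 (so rare emotions count as much as frequent ones), while micro F1 pools all TP/FP/FN counts across classes (so frequent classes dominate). The project's `multilabel_*` helpers are not shown in this diff; a minimal reference implementation over 0/1 label matrices, for intuition only:

```python
def _f1(tp, fp, fn):
    """F1 from raw counts; 0.0 when there are no positives at all."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def macro_f1(preds, refs, num_classes):
    """Unweighted average of per-class F1 scores."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for p, r in zip(preds, refs) if p[c] and r[c])
        fp = sum(1 for p, r in zip(preds, refs) if p[c] and not r[c])
        fn = sum(1 for p, r in zip(preds, refs) if not p[c] and r[c])
        scores.append(_f1(tp, fp, fn))
    return sum(scores) / num_classes

def micro_f1(preds, refs, num_classes):
    """F1 from TP/FP/FN counts pooled across all classes."""
    tp = fp = fn = 0
    for p, r in zip(preds, refs):
        for c in range(num_classes):
            if p[c] and r[c]:
                tp += 1
            elif p[c]:
                fp += 1
            elif r[c]:
                fn += 1
    return _f1(tp, fp, fn)
```

Reporting all three (plus sample-averaged F1) makes the class-imbalance story visible: a large macro/micro gap means rare emotions are being missed.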
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

+    metrics: dict = {
+        "sample_avg_f1": sample_f1,
+        "macro_f1": macro_f1,
+        "micro_f1": micro_f1,
         "num_samples": len(all_preds),
         "num_classes": len(all_emotions),
+        "per_class": per_class,
     }

+    # Per-class threshold tuning
+    if tune_thresholds and all_logits_list:
+        print("\nTuning per-class thresholds...")
+        all_logits = torch.cat(all_logits_list, dim=0)
+        best_thresholds, tuned_macro_f1 = tune_per_class_thresholds(all_logits, ref_binary)
+        metrics["tuned_thresholds"] = {
+            name: thresh for name, thresh in zip(all_emotions, best_thresholds, strict=True)
+        }
+        metrics["tuned_macro_f1"] = tuned_macro_f1
+
+        # Also compute tuned sample-avg F1
+        probs = torch.sigmoid(all_logits)
+        tuned_preds = torch.zeros_like(probs)
+        for c, t in enumerate(best_thresholds):
+            tuned_preds[:, c] = (probs[:, c] >= t).float()
+        metrics["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, ref_binary)
+        metrics["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, ref_binary)
+
+    # Bootstrap confidence intervals
+    if compute_bootstrap:
+        # Compute per-sample F1 for bootstrapping
+        per_sample_f1s = []
+        for pred, ref in zip(all_preds, all_refs, strict=True):
+            if len(pred) == 0 and len(ref) == 0:
+                per_sample_f1s.append(1.0)
+            elif len(pred) == 0 or len(ref) == 0:
+                per_sample_f1s.append(0.0)
+            else:
+                intersection = len(pred & ref)
+                p = intersection / len(pred) if pred else 0
+                r = intersection / len(ref) if ref else 0
+                per_sample_f1s.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)
+        mean, lo, hi = bootstrap_confidence_interval(per_sample_f1s)
+        metrics["sample_f1_ci"] = {"mean": mean, "lower": lo, "upper": hi}
+
     # Print results
     print("\n" + "-" * 40)
     print("EMOTION DETECTION RESULTS:")
     print("-" * 40)
+    print(f"  Sample-avg F1: {metrics['sample_avg_f1']:.4f}")
+    print(f"  Macro F1:      {metrics['macro_f1']:.4f}")
+    print(f"  Micro F1:      {metrics['micro_f1']:.4f}")
+    print(f"  Num Classes:   {metrics['num_classes']}")
+
+    if "tuned_macro_f1" in metrics:
+        print("\n  After per-class threshold tuning:")
+        print(f"    Tuned Macro F1:      {metrics['tuned_macro_f1']:.4f}")
+        print(f"    Tuned Sample-avg F1: {metrics['tuned_sample_avg_f1']:.4f}")
+        print(f"    Tuned Micro F1:      {metrics['tuned_micro_f1']:.4f}")
+
+    if "sample_f1_ci" in metrics:
+        ci = metrics["sample_f1_ci"]
+        print(f"\n  Sample F1 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
+    # Print top-10 per-class performance
+    print("\n  Per-class F1 (top 10 by support):")
+    sorted_classes = sorted(per_class.items(), key=lambda x: x[1]["support"], reverse=True)
+    for name, m in sorted_classes[:10]:
+        print(f"    {name:20s}: P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} (n={m['support']})")

     return metrics
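The `tune_per_class_thresholds` helper is imported from the project's metrics module and operates on logits tensors; its implementation is not in this diff. The idea it implements, sketched here over plain probability/label lists (function name, grid, and defaults are mine): for each class, sweep a threshold grid over the sigmoid probabilities and keep the threshold that maximizes that class's F1. Note that tuning on the same set you report on is optimistic, which is why the script keeps tuned numbers in separate `tuned_*` fields.

```python
def tune_thresholds_per_class(probs, refs, grid=None):
    """Per-class threshold sweep for multi-label classification.

    probs: (N, C) per-sample probability lists
    refs:  (N, C) per-sample 0/1 reference lists
    Returns (best_thresholds, macro_f1_at_best_thresholds).
    """
    grid = grid or [i / 20 for i in range(1, 20)]  # 0.05 .. 0.95
    num_classes = len(probs[0])
    best_thresholds, best_f1s = [], []
    for c in range(num_classes):
        best_t, best_f1 = 0.5, -1.0
        for t in grid:
            tp = fp = fn = 0
            for p, r in zip(probs, refs):
                pred = p[c] >= t
                if pred and r[c]:
                    tp += 1
                elif pred:
                    fp += 1
                elif r[c]:
                    fn += 1
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom else 0.0
            if f1 > best_f1:  # strict >: keep the lowest threshold on ties
                best_t, best_f1 = t, f1
        best_thresholds.append(best_t)
        best_f1s.append(best_f1)
    return best_thresholds, sum(best_f1s) / num_classes
```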
|
|
|
     data_path: Path,
     max_samples: int | None = None,
     batch_size: int = 32,
+    compute_bootstrap: bool = False,
 ) -> dict:
+    """Evaluate topic classification with per-class metrics and optional bootstrap CI."""
     print("\n" + "=" * 60)
     print("TOPIC CLASSIFICATION EVALUATION")
     print("=" * 60)
...
     accuracy = accuracy_score(all_refs, all_preds)
     macro_f1 = f1_score(all_refs, all_preds, average="macro", zero_division=0)

+    metrics: dict = {
         "accuracy": accuracy,
         "macro_f1": macro_f1,
         "num_samples": len(all_preds),
     }

+    # Bootstrap confidence intervals for accuracy
+    if compute_bootstrap:
+        per_sample_correct = [1.0 if p == r else 0.0 for p, r in zip(all_preds, all_refs, strict=True)]
+        mean, lo, hi = bootstrap_confidence_interval(per_sample_correct)
+        metrics["accuracy_ci"] = {"mean": mean, "lower": lo, "upper": hi}
+
     # Print results
     print("\n" + "-" * 40)
     print("TOPIC CLASSIFICATION RESULTS:")
...
     print(f"  Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
     print(f"  Macro F1: {metrics['macro_f1']:.4f}")

+    if "accuracy_ci" in metrics:
+        ci = metrics["accuracy_ci"]
+        print(f"  Accuracy 95% CI: [{ci['lower']:.4f}, {ci['upper']:.4f}]")
+
     # Classification report
     print("\n" + "-" * 40)
     print("PER-CLASS METRICS:")
...
     parser.add_argument("--output", type=Path, default=Path("outputs/evaluation_report.json"))
     parser.add_argument("--max-samples", type=int, default=None, help="Limit samples per task")
     parser.add_argument("--skip-bertscore", action="store_true", help="Skip BERTScore (faster)")
+    parser.add_argument("--tune-thresholds", action="store_true", help="Tune per-class emotion thresholds on val set")
+    parser.add_argument("--bootstrap", action="store_true", help="Compute bootstrap confidence intervals")
     parser.add_argument("--summarization-only", action="store_true")
     parser.add_argument("--emotion-only", action="store_true")
     parser.add_argument("--topic-only", action="store_true")
             pipeline, val_path,
             max_samples=args.max_samples,
             include_bertscore=not args.skip_bertscore,
+            compute_bootstrap=args.bootstrap,
         )
     else:
+        print("Warning: summarization validation data not found, skipping")

     # Evaluate emotion
     if eval_all or args.emotion_only:
...
         results["emotion"] = evaluate_emotion(
             pipeline, val_path,
             max_samples=args.max_samples,
+            tune_thresholds=args.tune_thresholds,
+            compute_bootstrap=args.bootstrap,
         )
     else:
+        print("Warning: emotion validation data not found, skipping")

     # Evaluate topic
     if eval_all or args.topic_only:
...
         results["topic"] = evaluate_topic(
             pipeline, val_path,
             max_samples=args.max_samples,
+            compute_bootstrap=args.bootstrap,
         )
     else:
+        print("Warning: topic validation data not found, skipping")

     # Save results
     print("\n" + "=" * 60)
...

     if "summarization" in results:
         s = results["summarization"]
+        print("\n  Summarization:")
         print(f"    ROUGE-1: {s['rouge1']:.4f}")
         print(f"    ROUGE-L: {s['rougeL']:.4f}")
         if "bertscore_f1" in s:
             print(f"    BERTScore F1: {s['bertscore_f1']:.4f}")

     if "emotion" in results:
+        print("\n  Emotion:")
-        print(f"    Multi-label F1: {results['emotion']['multilabel_f1']:.4f}")
+        print(f"    Sample-avg F1: {results['emotion']['sample_avg_f1']:.4f}")

     if "topic" in results:
+        print("\n  Topic:")
         print(f"    Accuracy: {results['topic']['accuracy']:.2%}")

scripts/train.py
CHANGED
@@ -303,6 +303,9 @@ def main(cfg: DictConfig) -> None:
         scheduler_type=str(sched_cfg.get("name", "cosine")),
         warmup_steps=int(sched_cfg.get("warmup_steps", 500)),
         early_stopping_patience=trainer_cfg.get("early_stopping_patience"),
+        task_sampling=str(trainer_cfg.get("task_sampling", "round_robin")),
+        task_sampling_alpha=float(trainer_cfg.get("task_sampling_alpha", 0.5)),
+        gradient_conflict_frequency=int(trainer_cfg.get("gradient_conflict_frequency", 0)),
     ),
     device=device,
     tokenizer=tokenizer,
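The `task_sampling` option wired through here selects between round-robin batching and temperature-based sampling; per the config comment in `configs/training/full.yaml`, temperature sampling draws task `i` with probability proportional to `n_i ** alpha`, shrinking the dominance of large tasks. The trainer-side implementation is not shown in this diff; the probability computation itself can be sketched as (function name is mine):

```python
def temperature_sampling_probs(task_sizes, alpha=0.5):
    """Task sampling probabilities p_i proportional to n_i ** alpha.

    alpha=1.0 recovers size-proportional sampling, alpha=0.0 is uniform,
    and intermediate values (the config default is 0.5) interpolate,
    upweighting small tasks relative to their raw share.
    """
    weights = {task: n ** alpha for task, n in task_sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}
```

With `alpha=0.5`, a 10,000-example task is sampled only 10x as often as a 100-example task instead of 100x.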
scripts/train_multiseed.py
ADDED
@@ -0,0 +1,197 @@
+#!/usr/bin/env python3
+"""
+Multi-seed training wrapper for LexiMind.
+
+Runs training across multiple seeds and aggregates results with mean ± std.
+This addresses the single-seed limitation identified in review feedback.
+
+Usage:
+    python scripts/train_multiseed.py --seeds 17 42 123 --config training=full
+    python scripts/train_multiseed.py --seeds 17 42 123 456 789 --config training=medium
+
+Author: Oliver Perrin
+Date: February 2026
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import subprocess
+import sys
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+
+
+def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
+    """Run training for a single seed and return the training history."""
+    seed_dir = base_dir / f"seed_{seed}"
+    seed_dir.mkdir(parents=True, exist_ok=True)
+
+    cmd = [
+        sys.executable, "scripts/train.py",
+        f"seed={seed}",
+        f"checkpoint_out={seed_dir}/checkpoints/best.pt",
+        f"history_out={seed_dir}/training_history.json",
+        f"labels_out={seed_dir}/labels.json",
+    ]
+    if config_overrides:
+        cmd.extend(config_overrides.split())
+
+    print(f"\n{'='*60}")
+    print(f"Training seed {seed}")
+    print(f"{'='*60}")
+    print(f"  Command: {' '.join(cmd)}")
+
+    result = subprocess.run(cmd, capture_output=False)
+    if result.returncode != 0:
+        print(f"  WARNING: Seed {seed} training failed (exit code {result.returncode})")
+        return {}
+
+    history_path = seed_dir / "training_history.json"
+    if history_path.exists():
+        with open(history_path) as f:
+            return json.load(f)
+    return {}
+
+
+def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = None) -> Dict:
+    """Run evaluation for a single seed and return results."""
+    seed_dir = base_dir / f"seed_{seed}"
+    checkpoint = seed_dir / "checkpoints" / "best.pt"
+    labels = seed_dir / "labels.json"
+    output = seed_dir / "evaluation_report.json"
+
+    if not checkpoint.exists():
+        print(f"  Skipping eval for seed {seed}: no checkpoint found")
+        return {}
+
+    cmd = [
+        sys.executable, "scripts/evaluate.py",
+        f"--checkpoint={checkpoint}",
+        f"--labels={labels}",
+        f"--output={output}",
+        "--skip-bertscore",
+        "--tune-thresholds",
+        "--bootstrap",
+    ]
+    if extra_args:
+        cmd.extend(extra_args)
+
+    print(f"\n  Evaluating seed {seed}...")
+    result = subprocess.run(cmd, capture_output=False)
+    if result.returncode != 0:
+        print(f"  WARNING: Seed {seed} evaluation failed")
+        return {}
+
+    if output.exists():
+        with open(output) as f:
+            return json.load(f)
+    return {}
+
+
+def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
+    """Aggregate evaluation results across seeds with mean ± std."""
+    if not all_results:
+        return {}
+
+    # Collect all metric paths
+    metric_values: Dict[str, List[float]] = {}
+    for results in all_results.values():
+        for task, task_metrics in results.items():
+            if not isinstance(task_metrics, dict):
+                continue
+            for metric_name, value in task_metrics.items():
+                if isinstance(value, (int, float)) and metric_name not in ("num_samples", "num_classes"):
+                    key = f"{task}/{metric_name}"
+                    metric_values.setdefault(key, []).append(float(value))
+
+    aggregated: Dict[str, Dict[str, float]] = {}
+    for key, values in sorted(metric_values.items()):
+        arr = np.array(values)
+        aggregated[key] = {
+            "mean": float(arr.mean()),
+            "std": float(arr.std()),
+            "min": float(arr.min()),
+            "max": float(arr.max()),
+            "n_seeds": len(values),
+        }
+
+    return aggregated
+
+
+def print_summary(aggregated: Dict, seeds: List[int]) -> None:
+    """Print human-readable summary of multi-seed results."""
+    print(f"\n{'='*70}")
+    print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
+    print(f"{'='*70}")
+
+    # Group by task
+    tasks: Dict[str, Dict[str, Dict]] = {}
+    for key, stats in aggregated.items():
+        task, metric = key.split("/", 1)
+        tasks.setdefault(task, {})[metric] = stats
+
+    for task, metrics in sorted(tasks.items()):
+        print(f"\n  {task.upper()}:")
+        for metric, stats in sorted(metrics.items()):
+            mean = stats["mean"]
+            std = stats["std"]
+            # Format based on metric type
+            if "accuracy" in metric:
+                print(f"    {metric:25s}: {mean*100:.1f}% ± {std*100:.1f}%")
+            else:
+                print(f"    {metric:25s}: {mean:.4f} ± {std:.4f}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
+    parser.add_argument("--seeds", nargs="+", type=int, default=[17, 42, 123],
+                        help="Random seeds to train with")
+    parser.add_argument("--config", type=str, default="",
+                        help="Hydra config overrides (e.g., 'training=full')")
+    parser.add_argument("--output-dir", type=Path, default=Path("outputs/multiseed"),
+                        help="Base output directory")
+    parser.add_argument("--skip-training", action="store_true",
+                        help="Skip training, only aggregate existing results")
+    parser.add_argument("--skip-eval", action="store_true",
+                        help="Skip evaluation, only aggregate training histories")
+    args = parser.parse_args()

+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Training phase
+    if not args.skip_training:
+        for seed in args.seeds:
+            run_single_seed(seed, args.config, args.output_dir)
+
+    # Evaluation phase
+    all_eval_results: Dict[int, Dict] = {}
+    if not args.skip_eval:
+        for seed in args.seeds:
+            result = run_evaluation(seed, args.output_dir)
+            if result:
+                all_eval_results[seed] = result
+
+    # Aggregate and save
+    if all_eval_results:
+        aggregated = aggregate_results(all_eval_results)
+        print_summary(aggregated, args.seeds)
+
+        # Save aggregated results
+        output_path = args.output_dir / "aggregated_results.json"
+        with open(output_path, "w") as f:
+            json.dump({
+                "seeds": args.seeds,
+                "per_seed": {str(k): v for k, v in all_eval_results.items()},
+                "aggregated": aggregated,
+            }, f, indent=2)
+        print(f"\n  Saved to: {output_path}")
+    else:
+        print("\nNo evaluation results to aggregate.")
+
+
+if __name__ == "__main__":
+    main()
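The aggregation step in `train_multiseed.py` flattens each seed's `{task: {metric: value}}` report into `task/metric` keys and computes mean/std across seeds, skipping count fields. A self-contained illustration of the same logic without the NumPy dependency (`statistics.pstdev` matches NumPy's default population `std`):

```python
import statistics

def aggregate(per_seed_results):
    """Collect numeric metrics across seed reports into mean/std summaries.

    per_seed_results: {seed: {task: {metric_name: value, ...}, ...}, ...}
    Returns {"task/metric": {"mean": ..., "std": ..., "n_seeds": ...}, ...}
    """
    values = {}
    for results in per_seed_results.values():
        for task, metrics in results.items():
            for name, v in metrics.items():
                # Count fields are constants across seeds, so exclude them
                if isinstance(v, (int, float)) and name not in ("num_samples", "num_classes"):
                    values.setdefault(f"{task}/{name}", []).append(float(v))
    return {
        key: {"mean": statistics.fmean(vals), "std": statistics.pstdev(vals), "n_seeds": len(vals)}
        for key, vals in sorted(values.items())
    }
```

Reporting mean ± std over at least three seeds is what makes metric differences between configurations interpretable rather than seed noise.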
src/data/dataset.py
CHANGED
@@ -11,10 +11,11 @@ Date: December 2025
 
 from __future__ import annotations
 
+import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Iterable, List, Sequence, TypeVar
+from typing import Callable, Dict, Iterable, List, Sequence, Set, TypeVar
 
 from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
 from torch.utils.data import Dataset
@@ -239,3 +240,82 @@ def load_topic_jsonl(path: str) -> List[TopicExample]:
         lambda payload: TopicExample(text=payload["text"], topic=payload["topic"]),
         required_keys=("text", "topic"),
     )
+
+
+# --------------- Cross-Task Deduplication ---------------
+
+
+def _text_fingerprint(text: str, n_chars: int = 200) -> str:
+    """Create a stable fingerprint from the first N characters of text.
+
+    Uses a hash of the normalized (lowered, whitespace-collapsed) prefix
+    to detect document-level overlap across tasks.
+    """
+    normalized = " ".join(text.lower().split())[:n_chars]
+    return hashlib.md5(normalized.encode("utf-8")).hexdigest()
+
+
+def deduplicate_across_tasks(
+    summ_examples: List[SummarizationExample],
+    topic_examples: List[TopicExample],
+    emotion_examples: List[EmotionExample] | None = None,
+) -> Dict[str, int]:
+    """Detect and report cross-task document overlap.
+
+    Checks whether texts appearing in the summarization dataset also appear
+    in the topic or emotion datasets, which could create data leakage in MTL.
+
+    Returns:
+        Dict with overlap counts between task pairs.
+    """
+    summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
+    topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
+
+    overlap: Dict[str, int] = {
+        "summ_topic_overlap": len(summ_fps & topic_fps),
+        "summ_total": len(summ_fps),
+        "topic_total": len(topic_fps),
+    }
+
+    if emotion_examples:
+        emot_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
+        overlap["summ_emotion_overlap"] = len(summ_fps & emot_fps)
+        overlap["topic_emotion_overlap"] = len(topic_fps & emot_fps)
+        overlap["emotion_total"] = len(emot_fps)
+
+    return overlap
+
+
+def remove_overlapping_examples(
+    primary_examples: List[TopicExample],
+    reference_examples: List[SummarizationExample],
+    split: str = "val",
+) -> tuple[List[TopicExample], int]:
+    """Remove topic examples whose texts overlap with summarization data.
+
+    This prevents cross-task data leakage where a document seen during
+    summarization training could boost topic classification on validation/test.
+
+    Args:
+        primary_examples: Topic examples to filter
+        reference_examples: Summarization examples to check against
+        split: Name of split being processed (for logging)
+
+    Returns:
+        Tuple of (filtered_examples, num_removed)
+    """
+    ref_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
+
+    filtered = []
+    removed = 0
+    for ex in primary_examples:
+        fp = _text_fingerprint(ex.text)
+        if fp in ref_fps:
+            removed += 1
+        else:
+            filtered.append(ex)
+
+    if removed > 0:
+        print(f"  Dedup: removed {removed} overlapping examples from topic {split}")
+
+    return filtered, removed
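The fingerprinting above deliberately hashes only a normalized 200-character prefix, so two copies of a document that differ in casing, whitespace, or truncation length still collide. A quick standalone demonstration of that behavior (same normalization as `_text_fingerprint`):

```python
import hashlib

def text_fingerprint(text, n_chars=200):
    """Lowercase, collapse all whitespace runs, keep the first n_chars,
    and return the MD5 hex digest of that prefix."""
    normalized = " ".join(text.lower().split())[:n_chars]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()
```

The flip side of prefix hashing is that any two documents sharing a 200-character prefix are treated as duplicates, which is a reasonable trade-off for detecting truncated copies across task datasets.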
src/models/factory.py
CHANGED
@@ -548,13 +548,15 @@ def build_multitask_model(
         "summarization",
         LMHead(d_model=cfg.d_model, vocab_size=vocab_size, tie_embedding=decoder.embedding),
     )
-    # Emotion head with 2-layer MLP for better multi-label capacity (28 classes)
+    # Emotion head with attention pooling + 2-layer MLP for better multi-label capacity (28 classes)
+    # Attention pooling is superior to mean pooling for encoder-decoder models where
+    # hidden states are optimized for cross-attention rather than simple averaging.
     model.add_head(
         "emotion",
         ClassificationHead(
             d_model=cfg.d_model,
             num_labels=num_emotions,
-            pooler="mean",
+            pooler="attention",
             dropout=cfg.dropout,
             hidden_dim=cfg.d_model // 2,  # 384-dim hidden layer
         ),
src/models/heads.py
CHANGED
@@ -1,7 +1,7 @@
 """Prediction heads for Transformer models.
 
 This module provides task-specific output heads:
-- ClassificationHead: Sequence-level classification with pooling (mean/cls/max)
+- ClassificationHead: Sequence-level classification with pooling (mean/cls/max/attention)
 - TokenClassificationHead: Per-token classification (NER, POS tagging)
 - LMHead: Language modeling logits with optional weight tying
 - ProjectionHead: MLP for representation learning / contrastive tasks
@@ -14,6 +14,35 @@ from typing import Literal, Optional
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AttentionPooling(nn.Module):
+    """Learned attention pooling over sequence positions.
+
+    Computes a weighted sum of hidden states using a learned query vector,
+    producing a single fixed-size representation. This is generally superior
+    to mean pooling for classification tasks on encoder-decoder models where
+    hidden states are optimized for cross-attention rather than pooling.
+    """
+
+    def __init__(self, d_model: int):
+        super().__init__()
+        self.query = nn.Linear(d_model, 1, bias=False)
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        x: (batch, seq_len, d_model)
+        mask: (batch, seq_len) - True for valid tokens, False for padding
+        returns: (batch, d_model)
+        """
+        # Compute attention scores: (batch, seq_len, 1)
+        scores = self.query(x)
+        if mask is not None:
+            scores = scores.masked_fill(~mask.unsqueeze(-1), float("-inf"))
+        weights = F.softmax(scores, dim=1)  # (batch, seq_len, 1)
...
@@ -23,7 +52,7 @@ class ClassificationHead(nn.Module):
     Args:
         d_model: hidden size from encoder/decoder
        num_labels: number of output classes
-        pooler: one of 'mean', 'cls', 'max' - how to pool the sequence
         dropout: dropout probability before final linear layer
         hidden_dim: optional intermediate dimension for 2-layer MLP (improves capacity)
     """
@@ -32,14 +61,17 @@
         self,
         d_model: int,
         num_labels: int,
-        pooler: Literal["mean", "cls", "max"] = "mean",
         dropout: float = 0.1,
         hidden_dim: Optional[int] = None,
     ):
         super().__init__()
-        assert pooler in ("mean", "cls", "max"), "pooler must be 'mean'|'cls'|'max'"
         self.pooler = pooler
         self.dropout = nn.Dropout(dropout)
 
         # Optional 2-layer MLP for more capacity (useful for multi-label)
         if hidden_dim is not None:
@@ -58,19 +90,14 @@
         mask: (batch, seq_len) - True for valid tokens, False for padding
         returns: (batch, num_labels)
         """
-        if self.pooler == "mean":
         if mask is not None:
-            # mask is (B, S)
-            # x is (B, S, D)
-            # Expand mask to (B, S, 1)
             mask_expanded = mask.unsqueeze(-1).float()
-            # Zero out padding
             x = x * mask_expanded
-            # Sum over sequence
             sum_embeddings = x.sum(dim=1)
-            # Count valid tokens
             sum_mask = mask_expanded.sum(dim=1)
-            # Avoid division by zero
             sum_mask = torch.clamp(sum_mask, min=1e-9)
             pooled = sum_embeddings / sum_mask
         else:
@@ -79,7 +106,6 @@
             pooled = x[:, 0, :]
         else:  # max
             if mask is not None:
-                # Mask padding with -inf
                 mask_expanded = mask.unsqueeze(-1)
                 x = x.masked_fill(~mask_expanded, float("-inf"))
             pooled, _ = x.max(dim=1)
|
| 44 |
+
# Weighted sum: (batch, d_model)
|
| 45 |
+
return (weights * x).sum(dim=1)
|
| 46 |
|
| 47 |
|
| 48 |
class ClassificationHead(nn.Module):
|
|
|
|
| 52 |
Args:
|
| 53 |
d_model: hidden size from encoder/decoder
|
| 54 |
num_labels: number of output classes
|
| 55 |
+
pooler: one of 'mean', 'cls', 'max', 'attention' - how to pool the sequence
|
| 56 |
dropout: dropout probability before final linear layer
|
| 57 |
hidden_dim: optional intermediate dimension for 2-layer MLP (improves capacity)
|
| 58 |
"""
|
|
|
|
| 61 |
self,
|
| 62 |
d_model: int,
|
| 63 |
num_labels: int,
|
| 64 |
+
pooler: Literal["mean", "cls", "max", "attention"] = "mean",
|
| 65 |
dropout: float = 0.1,
|
| 66 |
hidden_dim: Optional[int] = None,
|
| 67 |
):
|
| 68 |
super().__init__()
|
| 69 |
+
assert pooler in ("mean", "cls", "max", "attention"), "pooler must be 'mean'|'cls'|'max'|'attention'"
|
| 70 |
self.pooler = pooler
|
| 71 |
self.dropout = nn.Dropout(dropout)
|
| 72 |
+
|
| 73 |
+
if pooler == "attention":
|
| 74 |
+
self.attn_pool = AttentionPooling(d_model)
|
| 75 |
|
| 76 |
# Optional 2-layer MLP for more capacity (useful for multi-label)
|
| 77 |
if hidden_dim is not None:
|
|
|
|
| 90 |
mask: (batch, seq_len) - True for valid tokens, False for padding
|
| 91 |
returns: (batch, num_labels)
|
| 92 |
"""
|
| 93 |
+
if self.pooler == "attention":
|
| 94 |
+
pooled = self.attn_pool(x, mask)
|
| 95 |
+
elif self.pooler == "mean":
|
| 96 |
if mask is not None:
|
|
|
|
|
|
|
|
|
|
| 97 |
mask_expanded = mask.unsqueeze(-1).float()
|
|
|
|
| 98 |
x = x * mask_expanded
|
|
|
|
| 99 |
sum_embeddings = x.sum(dim=1)
|
|
|
|
| 100 |
sum_mask = mask_expanded.sum(dim=1)
|
|
|
|
| 101 |
sum_mask = torch.clamp(sum_mask, min=1e-9)
|
| 102 |
pooled = sum_embeddings / sum_mask
|
| 103 |
else:
|
|
|
|
| 106 |
pooled = x[:, 0, :]
|
| 107 |
else: # max
|
| 108 |
if mask is not None:
|
|
|
|
| 109 |
mask_expanded = mask.unsqueeze(-1)
|
| 110 |
x = x.masked_fill(~mask_expanded, float("-inf"))
|
| 111 |
pooled, _ = x.max(dim=1)
|
src/training/metrics.py
CHANGED
@@ -110,9 +110,9 @@ def calculate_bertscore(
     )
 
     return {
-        "precision": float(P.mean().item()),
-        "recall": float(R.mean().item()),
-        "f1": float(F1.mean().item()),
+        "precision": float(P.mean().item()),  # type: ignore[union-attr]
+        "recall": float(R.mean().item()),  # type: ignore[union-attr]
+        "f1": float(F1.mean().item()),  # type: ignore[union-attr]
     }
 
 
@@ -239,3 +239,213 @@ def get_confusion_matrix(
 ) -> np.ndarray:
     """Compute confusion matrix."""
     return cast(np.ndarray, confusion_matrix(targets, predictions, labels=labels))
+
+
+# --------------- Multi-label Emotion Metrics ---------------
+
+
+def multilabel_macro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+    """Compute macro F1: average F1 per class (as in GoEmotions paper).
+
+    This averages F1 across labels, giving equal weight to each emotion class
+    regardless of prevalence. Directly comparable to GoEmotions baselines.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+
+    # Per-class TP, FP, FN
+    tp = (preds * gold).sum(dim=0)
+    fp = (preds * (1 - gold)).sum(dim=0)
+    fn = ((1 - preds) * gold).sum(dim=0)
+
+    precision = tp / (tp + fp).clamp(min=1e-8)
+    recall = tp / (tp + fn).clamp(min=1e-8)
+    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
+
+    # Zero out F1 for classes with no support in either predictions or targets
+    mask = (tp + fp + fn) > 0
+    if mask.sum() == 0:
+        return 0.0
+    return float(f1[mask].mean().item())
+
+
+def multilabel_micro_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
+    """Compute micro F1: aggregate TP/FP/FN across all classes.
+
+    This gives more weight to frequent classes. Useful when class distribution matters.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+
+    tp = (preds * gold).sum()
+    fp = (preds * (1 - gold)).sum()
+    fn = ((1 - preds) * gold).sum()
+
+    precision = tp / (tp + fp).clamp(min=1e-8)
+    recall = tp / (tp + fn).clamp(min=1e-8)
+    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
+    return float(f1.item())
+
+
+def multilabel_per_class_metrics(
+    predictions: torch.Tensor,
+    targets: torch.Tensor,
+    class_names: Sequence[str] | None = None,
+) -> Dict[str, Dict[str, float]]:
+    """Compute per-class precision, recall, F1 for multi-label classification.
+
+    Returns a dict mapping class name/index to its metrics.
+    """
+    preds = predictions.float()
+    gold = targets.float()
+    num_classes = preds.shape[1]
+
+    tp = (preds * gold).sum(dim=0)
+    fp = (preds * (1 - gold)).sum(dim=0)
+    fn = ((1 - preds) * gold).sum(dim=0)
+
+    report: Dict[str, Dict[str, float]] = {}
+    for i in range(num_classes):
+        name = class_names[i] if class_names else str(i)
+        p = (tp[i] / (tp[i] + fp[i]).clamp(min=1e-8)).item()
+        r = (tp[i] / (tp[i] + fn[i]).clamp(min=1e-8)).item()
+        f = (2 * p * r) / max(p + r, 1e-8)
+        report[name] = {
+            "precision": p,
+            "recall": r,
+            "f1": f,
+            "support": int(gold[:, i].sum().item()),
+        }
+    return report
+
+
+def tune_per_class_thresholds(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    thresholds: Sequence[float] | None = None,
+) -> tuple[List[float], float]:
+    """Tune per-class thresholds on validation set to maximize macro F1.
+
+    For each class, tries multiple thresholds and selects the one that
+    maximizes that class's F1 score. This is standard practice for multi-label
+    classification (used in the original GoEmotions paper).
+
+    Args:
+        logits: Raw model logits (batch, num_classes)
+        targets: Binary target labels (batch, num_classes)
+        thresholds: Candidate thresholds to try (default: 0.1 to 0.9 by 0.05)
+
+    Returns:
+        Tuple of (best_thresholds_per_class, resulting_macro_f1)
+    """
+    if thresholds is None:
+        thresholds = [round(t, 2) for t in np.arange(0.1, 0.9, 0.05).tolist()]
+
+    probs = torch.sigmoid(logits)
+    num_classes = probs.shape[1]
+    gold = targets.float()
+
+    best_thresholds: List[float] = []
+    for c in range(num_classes):
+        best_f1 = -1.0
+        best_t = 0.5
+        for t in thresholds:
+            preds = (probs[:, c] >= t).float()
+            tp = (preds * gold[:, c]).sum()
+            fp = (preds * (1 - gold[:, c])).sum()
+            fn = ((1 - preds) * gold[:, c]).sum()
+            if tp + fp > 0 and tp + fn > 0:
+                p = tp / (tp + fp)
+                r = tp / (tp + fn)
+                f1 = (2 * p * r / (p + r)).item()
+            else:
+                f1 = 0.0
+            if f1 > best_f1:
+                best_f1 = f1
+                best_t = t
+        best_thresholds.append(best_t)
+
+    # Compute resulting macro F1 with tuned thresholds
+    tuned_preds = torch.zeros_like(probs)
+    for c in range(num_classes):
+        tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
+    macro_f1 = multilabel_macro_f1(tuned_preds, targets)
+
+    return best_thresholds, macro_f1
+
+
+# --------------- Statistical Tests ---------------
+
+
+def bootstrap_confidence_interval(
+    scores: Sequence[float],
+    n_bootstrap: int = 1000,
+    confidence: float = 0.95,
+    seed: int = 42,
+) -> tuple[float, float, float]:
+    """Compute bootstrap confidence interval for a metric.
+
+    Args:
+        scores: Per-sample metric values
+        n_bootstrap: Number of bootstrap resamples
+        confidence: Confidence level (default 95%)
+        seed: Random seed for reproducibility
+
+    Returns:
+        Tuple of (mean, lower_bound, upper_bound)
+    """
+    rng = np.random.default_rng(seed)
+    scores_arr = np.array(scores)
+    n = len(scores_arr)
+
+    bootstrap_means = []
+    for _ in range(n_bootstrap):
+        sample = rng.choice(scores_arr, size=n, replace=True)
+        bootstrap_means.append(float(np.mean(sample)))
+
+    bootstrap_means.sort()
+    alpha = 1 - confidence
+    lower_idx = int(alpha / 2 * n_bootstrap)
+    upper_idx = int((1 - alpha / 2) * n_bootstrap)
+
+    return (
+        float(np.mean(scores_arr)),
+        bootstrap_means[lower_idx],
+        bootstrap_means[min(upper_idx, n_bootstrap - 1)],
+    )
+
+
+def paired_bootstrap_test(
+    scores_a: Sequence[float],
+    scores_b: Sequence[float],
+    n_bootstrap: int = 10000,
+    seed: int = 42,
+) -> float:
+    """Paired bootstrap significance test between two systems.
+
+    Tests if system B is significantly better than system A.
+
+    Args:
+        scores_a: Per-sample scores from system A
+        scores_b: Per-sample scores from system B
+        n_bootstrap: Number of bootstrap iterations
+        seed: Random seed
+
+    Returns:
+        p-value (probability that B is not better than A)
+    """
+    rng = np.random.default_rng(seed)
+    a = np.array(scores_a)
+    b = np.array(scores_b)
+    assert len(a) == len(b), "Both score lists must have the same length"
+
+    n = len(a)
+
+    count = 0
+    for _ in range(n_bootstrap):
+        idx = rng.choice(n, size=n, replace=True)
+        diff = float(np.mean(b[idx]) - np.mean(a[idx]))
+        if diff <= 0:
+            count += 1
+
+    return count / n_bootstrap
src/training/trainer.py
CHANGED
@@ -7,6 +7,8 @@ Handles training across summarization, emotion, and topic heads with:
 - Cosine LR schedule with warmup
 - Early stopping
 - MLflow logging
+- Temperature-based task sampling (configurable alpha)
+- Gradient conflict diagnostics
 
 Author: Oliver Perrin
 Date: December 2025
@@ -22,6 +24,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Dict, List
 
 import mlflow
+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import LambdaLR
@@ -53,6 +56,16 @@ class TrainerConfig:
     # Early stopping
     early_stopping_patience: int | None = 5
 
+    # Task sampling strategy: "round_robin" or "temperature"
+    # Temperature sampling: p_i ∝ n_i^alpha where n_i = dataset size
+    # alpha < 1 reduces dominance of large tasks (recommended: 0.5-0.7)
+    task_sampling: str = "round_robin"
+    task_sampling_alpha: float = 0.5
+
+    # Gradient conflict diagnostics
+    # Compute inter-task gradient cosine similarity every N steps (0 = disabled)
+    gradient_conflict_frequency: int = 0
+
     # MLflow
     experiment_name: str = "LexiMind"
     run_name: str | None = None
@@ -210,7 +223,7 @@ class Trainer:
         train: bool,
         epoch: int,
     ) -> Dict[str, float]:
-        """Run one epoch."""
+        """Run one epoch with configurable task sampling strategy."""
        self.model.train(train)
        metrics: Dict[str, List[float]] = defaultdict(list)
        iterators = {task: iter(loader) for task, loader in loaders.items()}
@@ -220,12 +233,33 @@ class Trainer:
         phase = "Train" if train else "Val"
         pbar = tqdm(range(max_batches), desc=f"  {phase}", leave=False, file=sys.stderr)
 
+        # Temperature-based task sampling: p_i ∝ n_i^alpha
+        task_names = list(loaders.keys())
+        if self.config.task_sampling == "temperature" and len(task_names) > 1:
+            sizes = np.array([len(loaders[t].dataset) for t in task_names], dtype=np.float64)  # type: ignore[arg-type]
+            alpha = self.config.task_sampling_alpha
+            probs = sizes ** alpha
+            probs = probs / probs.sum()
+            tqdm.write(f"  Temperature sampling (α={alpha}): " +
+                       ", ".join(f"{t}={p:.2%}" for t, p in zip(task_names, probs, strict=True)))
+        else:
+            probs = None
+
         ctx = torch.enable_grad() if train else torch.no_grad()
         with ctx:
             for step in pbar:
                 step_loss = 0.0
 
-                for task, loader in loaders.items():
+                # Select tasks for this step
+                if probs is not None and train:
+                    # Temperature sampling: sample tasks based on dataset size
+                    selected_tasks = list(np.random.choice(task_names, size=len(task_names), replace=True, p=probs))
+                else:
+                    # Round-robin: all tasks every step
+                    selected_tasks = task_names
+
+                for task in selected_tasks:
+                    loader = loaders[task]
                     batch = self._get_batch(iterators, loader, task)
                     if batch is None:
                         continue
@@ -253,6 +287,14 @@ class Trainer:
                     scaled = (loss * weight) / accum
                     scaled.backward()
 
+                # Gradient conflict diagnostics
+                if (train and self.config.gradient_conflict_frequency > 0
+                        and (step + 1) % self.config.gradient_conflict_frequency == 0):
+                    conflict_stats = self._compute_gradient_conflicts(loaders, iterators)
+                    for k, v in conflict_stats.items():
+                        metrics[f"grad_{k}"].append(v)
+                        mlflow.log_metric(f"grad_{k}", v, step=self.global_step)
+
                 # Optimizer step
                 if train and (step + 1) % accum == 0:
                     torch.nn.utils.clip_grad_norm_(
@@ -415,6 +457,56 @@ class Trainer:
             tqdm.write(f"{'=' * 50}\n")
         self.model.train()
 
+    def _compute_gradient_conflicts(
+        self,
+        loaders: Dict[str, DataLoader],
+        iterators: Dict,
+    ) -> Dict[str, float]:
+        """Compute inter-task gradient cosine similarity to diagnose conflicts.
+
+        Returns cosine similarity between gradient vectors for each task pair.
+        Negative values indicate conflicting gradients (negative transfer risk).
+        """
+        task_grads: Dict[str, torch.Tensor] = {}
+
+        for task, loader in loaders.items():
+            self.optimizer.zero_grad()
+            batch = self._get_batch(iterators, loader, task)
+            if batch is None:
+                continue
+
+            dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
+            with torch.autocast("cuda", dtype=dtype, enabled=self.use_amp):
+                loss, _ = self._forward_task(task, batch)
+
+            if torch.isnan(loss):
+                continue
+
+            loss.backward()
+
+            # Flatten all gradients into a single vector
+            grad_vec = []
+            for p in self.model.parameters():
+                if p.grad is not None:
+                    grad_vec.append(p.grad.detach().clone().flatten())
+            if grad_vec:
+                task_grads[task] = torch.cat(grad_vec)
+
+        self.optimizer.zero_grad()
+
+        # Compute pairwise cosine similarity
+        stats: Dict[str, float] = {}
+        tasks = list(task_grads.keys())
+        for i in range(len(tasks)):
+            for j in range(i + 1, len(tasks)):
+                t1, t2 = tasks[i], tasks[j]
+                g1, g2 = task_grads[t1], task_grads[t2]
+                cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0)).item()
+                stats[f"cos_sim_{t1}_{t2}"] = cos_sim
+                stats[f"conflict_{t1}_{t2}"] = 1.0 if cos_sim < 0 else 0.0
+
+        return stats
+
     def _log_config(self) -> None:
         """Log config to MLflow."""
         mlflow.log_params({
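The temperature sampling this commit adds (p_i ∝ n_i^alpha) can be illustrated with a few lines of NumPy. The dataset sizes below are hypothetical, not the repo's actual task sizes: alpha=1 recovers size-proportional sampling, alpha=0 is uniform, and intermediate values like the default 0.5 flatten the distribution toward the small tasks.

```python
import numpy as np

def temperature_probs(sizes, alpha):
    """p_i ∝ n_i^alpha, normalized to sum to 1."""
    p = np.asarray(sizes, dtype=float) ** alpha
    return p / p.sum()

sizes = [100_000, 10_000, 1_000]  # hypothetical per-task dataset sizes
print(temperature_probs(sizes, 1.0))  # proportional: ~[0.90, 0.09, 0.01]
print(temperature_probs(sizes, 0.5))  # flattened:    ~[0.71, 0.22, 0.07]
print(temperature_probs(sizes, 0.0))  # uniform:       [1/3, 1/3, 1/3]
```

With alpha=0.5 the smallest task's sampling probability rises from under 1% to about 7%, which is the "reduces dominance of large tasks" effect the config comment describes.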