Spaces:
Sleeping
Sleeping
OliverPerrin commited on
Commit ·
aca8362
1
Parent(s): a6468cc
Added Bert Baseline comparison training script and ignored unneccessary files
Browse files- .gitignore +10 -1
- docs/paper.tex +0 -1214
- docs/research_paper.tex +0 -574
- kaggle.json +0 -1
- outputs/evaluation_report.json +0 -260
- outputs/training_history.json +0 -210
- scripts/train_bert_baseline.py +1137 -0
.gitignore
CHANGED
|
@@ -51,6 +51,11 @@ results/
|
|
| 51 |
.ipynb_checkpoints/
|
| 52 |
*.ipynb
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
# OS - Windows specific
|
| 55 |
.DS_Store
|
| 56 |
Thumbs.db
|
|
@@ -65,6 +70,10 @@ ehthumbs_vista.db
|
|
| 65 |
configs/local/*.png
|
| 66 |
*.pt
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# Backup/private files
|
| 69 |
-
scripts/demo_gradio_old.py
|
| 70 |
mlruns.db
|
|
|
|
|
|
| 51 |
.ipynb_checkpoints/
|
| 52 |
*.ipynb
|
| 53 |
|
| 54 |
+
# Docs/Paper
|
| 55 |
+
docs/paper.*
|
| 56 |
+
docs/paper_old.*
|
| 57 |
+
docs/research_paper.*
|
| 58 |
+
|
| 59 |
# OS - Windows specific
|
| 60 |
.DS_Store
|
| 61 |
Thumbs.db
|
|
|
|
| 70 |
configs/local/*.png
|
| 71 |
*.pt
|
| 72 |
|
| 73 |
+
# Environment variables
|
| 74 |
+
.env
|
| 75 |
+
.env.*
|
| 76 |
+
|
| 77 |
# Backup/private files
|
|
|
|
| 78 |
mlruns.db
|
| 79 |
+
kaggle.json
|
docs/paper.tex
DELETED
|
@@ -1,1214 +0,0 @@
|
|
| 1 |
-
% LexiMind: A Hybrid Transformer Architecture for Multi-Task NLP
|
| 2 |
-
% IEEE Conference Style Paper
|
| 3 |
-
% Author: Oliver Perrin
|
| 4 |
-
|
| 5 |
-
\documentclass[conference]{IEEEtran}
|
| 6 |
-
\IEEEoverridecommandlockouts
|
| 7 |
-
|
| 8 |
-
% Essential packages
|
| 9 |
-
\usepackage{cite}
|
| 10 |
-
\usepackage{amsmath,amssymb,amsfonts}
|
| 11 |
-
\usepackage{algorithmic}
|
| 12 |
-
\usepackage{graphicx}
|
| 13 |
-
\usepackage{textcomp}
|
| 14 |
-
\usepackage{xcolor}
|
| 15 |
-
\usepackage{hyperref}
|
| 16 |
-
\usepackage{listings}
|
| 17 |
-
\usepackage{booktabs}
|
| 18 |
-
\usepackage{multirow}
|
| 19 |
-
\usepackage{array}
|
| 20 |
-
\usepackage{caption}
|
| 21 |
-
|
| 22 |
-
% TikZ for diagrams
|
| 23 |
-
\usepackage{tikz}
|
| 24 |
-
\usetikzlibrary{shapes.geometric, arrows, positioning, fit, calc, backgrounds, decorations.pathreplacing}
|
| 25 |
-
|
| 26 |
-
% Code listings style
|
| 27 |
-
\lstset{
|
| 28 |
-
basicstyle=\ttfamily\small,
|
| 29 |
-
breaklines=true,
|
| 30 |
-
frame=single,
|
| 31 |
-
language=Python,
|
| 32 |
-
keywordstyle=\color{blue},
|
| 33 |
-
commentstyle=\color{green!50!black},
|
| 34 |
-
stringstyle=\color{red!60!black},
|
| 35 |
-
showstringspaces=false
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
% Hyperref setup
|
| 39 |
-
\hypersetup{
|
| 40 |
-
colorlinks=true,
|
| 41 |
-
linkcolor=blue,
|
| 42 |
-
citecolor=blue,
|
| 43 |
-
urlcolor=blue
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
\def\BibTeX{{\rm B\kern-.05em{\sc i\kern-.025em b}\kern-.08em
|
| 47 |
-
T\kern-.1667em\lower.7ex\hbox{E}\kern-.125emX}}
|
| 48 |
-
|
| 49 |
-
\begin{document}
|
| 50 |
-
|
| 51 |
-
\title{LexiMind: A Hybrid Transformer Architecture\\for Multi-Task Natural Language Processing}
|
| 52 |
-
|
| 53 |
-
\author{\IEEEauthorblockN{Oliver Perrin}\\
|
| 54 |
-
\IEEEauthorblockA{Department of Computer Science\\
|
| 55 |
-
Appalachian State University\\
|
| 56 |
-
Bachelor of Science in Computer Science\\
|
| 57 |
-
Email: perrinot@appstate.edu}}
|
| 58 |
-
|
| 59 |
-
\maketitle
|
| 60 |
-
|
| 61 |
-
\begin{abstract}
|
| 62 |
-
This paper presents LexiMind, a multi-task Natural Language Processing (NLP) system that combines a custom-built Transformer architecture with pre-trained weights from Google's FLAN-T5 (Fine-tuned Language Net Text-to-Text Transfer Transformer). The system performs three fundamental NLP tasks simultaneously: abstractive text summarization, multi-label emotion classification, and single-label topic classification. Unlike news-focused models, LexiMind specializes in literary and academic content. For summarization, we train on 49,086 samples combining Goodreads book descriptions (back-cover style blurbs) with arXiv academic paper abstracts. Emotion classification uses 43,410 samples from GoEmotions \cite{demszky2020goemotions}, a dataset of 28 fine-grained emotion labels derived from Reddit comments. Topic classification spans 3,402 samples from 20 Newsgroups, Project Gutenberg literary texts, and scientific papers across 7 categories (Arts, Business, Fiction, History, Philosophy, Science, Technology). By implementing modern architectural innovations including Pre-Layer Normalization (Pre-LN) with Root Mean Square Layer Normalization (RMSNorm), T5-style relative position bias, FlashAttention via PyTorch 2.0's Scaled Dot-Product Attention (SDPA), gradient checkpointing, and torch.compile optimization, LexiMind achieves efficient training on consumer GPUs while maintaining strong performance. Our final model achieves a BERTScore F1 of 0.83 and ROUGE-1 of 0.31 for summarization, 85.2\% accuracy for topic classification, and F1 of 0.20 for 28-class multi-label emotion detection. The 272M-parameter architecture is constructed from first principles in a bottom-up fashion, with each component (attention mechanisms, feed-forward networks, encoder/decoder blocks) implemented as standalone modules. A factory pattern enables seamless weight transfer from FLAN-T5-base, allowing the system to leverage Google's pre-trained knowledge while maintaining full architectural transparency and customization capability.
|
| 63 |
-
\end{abstract}
|
| 64 |
-
|
| 65 |
-
\begin{IEEEkeywords}
|
| 66 |
-
Transformer, Multi-Task Learning, Natural Language Processing, FLAN-T5, Transfer Learning, Text Summarization, Emotion Classification, Academic Papers, Literary Text
|
| 67 |
-
\end{IEEEkeywords}
|
| 68 |
-
|
| 69 |
-
%=============================================================================
|
| 70 |
-
\section{Introduction}
|
| 71 |
-
%=============================================================================
|
| 72 |
-
|
| 73 |
-
The Transformer architecture \cite{vaswani2017attention} has fundamentally reshaped Natural Language Processing (NLP), establishing itself as the foundation for state-of-the-art models across virtually all language understanding and generation tasks. Building upon this foundation, the T5 (Text-to-Text Transfer Transformer) model \cite{raffel2020exploring} introduced a unified framework that casts all NLP problems as text-to-text transformations. FLAN-T5 (Fine-tuned Language Net) \cite{chung2022scaling} further enhanced T5's capabilities through instruction fine-tuning on over 1,000 diverse tasks.
|
| 74 |
-
|
| 75 |
-
While pre-trained models like FLAN-T5 offer impressive zero-shot and few-shot capabilities, they are often treated as black boxes—their internal mechanisms obscured by framework abstractions. This opacity hinders both understanding and customization. Furthermore, multi-task learning scenarios often require architectural modifications that pre-built models do not easily accommodate.
|
| 76 |
-
|
| 77 |
-
LexiMind addresses these challenges through a hybrid approach: implementing a complete Transformer architecture from scratch while maintaining compatibility with FLAN-T5's pre-trained weights. This provides several key advantages:
|
| 78 |
-
|
| 79 |
-
\begin{enumerate}
|
| 80 |
-
\item \textbf{Architectural Transparency}: Every component—from attention mechanisms to normalization layers—is explicitly implemented and documented.
|
| 81 |
-
\item \textbf{Customization Flexibility}: Task-specific heads and routing logic can be freely modified without framework constraints.
|
| 82 |
-
\item \textbf{Transfer Learning}: FLAN-T5's linguistic knowledge is transferred through careful weight mapping in the factory module.
|
| 83 |
-
\item \textbf{Modern Optimizations}: Integration of FlashAttention, bfloat16 training, and gradient accumulation ensures efficient resource utilization.
|
| 84 |
-
\end{enumerate}
|
| 85 |
-
|
| 86 |
-
A key design decision in LexiMind is the focus on literary and academic domains rather than news articles, which are overrepresented in existing summarization benchmarks. For text summarization, we combine Goodreads book descriptions---which provide back-cover style blurbs describing \textit{what a book is about}---with arXiv paper abstracts. This trains the model to generate descriptive summaries rather than extractive plot recaps. Emotion classification leverages GoEmotions \cite{demszky2020goemotions}, providing fine-grained 28-label annotations. Topic classification draws from diverse sources including 20 Newsgroups, Project Gutenberg, and scientific papers.
|
| 87 |
-
|
| 88 |
-
The contributions of this work include:
|
| 89 |
-
\begin{itemize}
|
| 90 |
-
\item A custom Transformer implementation compatible with T5/FLAN-T5 weight loading
|
| 91 |
-
\item A multi-task architecture supporting both generative (summarization) and discriminative (classification) tasks
|
| 92 |
-
\item A curated dataset of 95,898 training samples across literary, academic, and conversational domains
|
| 93 |
-
\item Detailed documentation of weight transfer mechanisms between pre-trained models and custom implementations
|
| 94 |
-
\item Comprehensive training infrastructure with mixed-precision support, gradient monitoring, and MLflow experiment tracking
|
| 95 |
-
\end{itemize}
|
| 96 |
-
|
| 97 |
-
%=============================================================================
|
| 98 |
-
\section{Related Work}
|
| 99 |
-
%=============================================================================
|
| 100 |
-
|
| 101 |
-
\subsection{Transformer Architectures}
|
| 102 |
-
|
| 103 |
-
The original Transformer \cite{vaswani2017attention} introduced the self-attention mechanism, enabling parallel processing of sequences and effective capture of long-range dependencies. The architecture consists of stacked encoder and decoder blocks, each containing Multi-Head Attention (MHA) and position-wise Feed-Forward Networks (FFN).
|
| 104 |
-
|
| 105 |
-
\textbf{Layer Normalization Placement}: The original Transformer applied Layer Normalization \cite{ba2016layer} after residual connections (Post-LN). Subsequent research \cite{xiong2020layer} demonstrated that applying normalization before sublayers (Pre-LN) significantly improves training stability, particularly for deep networks. LexiMind adopts the Pre-LN configuration used by T5 and modern large language models.
|
| 106 |
-
|
| 107 |
-
\textbf{RMSNorm}: Zhang and Sennrich \cite{zhang2019root} proposed Root Mean Square Layer Normalization (RMSNorm), which removes the mean-centering operation of standard LayerNorm while maintaining comparable performance. T5 \cite{raffel2020exploring} adopts this approach, and LexiMind follows suit for compatibility.
|
| 108 |
-
|
| 109 |
-
\subsection{Pre-trained Language Models}
|
| 110 |
-
|
| 111 |
-
\textbf{T5}: Raffel et al. \cite{raffel2020exploring} introduced the T5 model, which frames all NLP tasks as text-to-text problems. T5 uses a Transformer encoder-decoder architecture with several distinctive features: relative position bias instead of absolute positional embeddings, RMSNorm for layer normalization, and a gated feed-forward network.
|
| 112 |
-
|
| 113 |
-
\textbf{FLAN-T5}: Chung et al. \cite{chung2022scaling} enhanced T5 through instruction fine-tuning, creating FLAN-T5. By training on diverse task instructions, FLAN-T5 demonstrates improved zero-shot and few-shot capabilities compared to the original T5.
|
| 114 |
-
|
| 115 |
-
\subsection{Multi-Task Learning}
|
| 116 |
-
|
| 117 |
-
Multi-Task Learning (MTL) \cite{caruana1997multitask} trains a single model on multiple related tasks, promoting parameter sharing and implicit data augmentation. Hard parameter sharing—where lower layers are shared across tasks while task-specific heads branch from shared representations—remains the dominant approach for Transformer-based MTL systems.
|
| 118 |
-
|
| 119 |
-
%=============================================================================
|
| 120 |
-
\section{Architecture}
|
| 121 |
-
%=============================================================================
|
| 122 |
-
|
| 123 |
-
LexiMind implements a complete encoder-decoder Transformer with task-specific heads, constructed using a bottom-up approach where each component is implemented as a standalone module. Figure \ref{fig:architecture} illustrates the high-level system architecture.
|
| 124 |
-
|
| 125 |
-
\begin{figure}[htbp]
|
| 126 |
-
\centering
|
| 127 |
-
\begin{tikzpicture}[
|
| 128 |
-
scale=0.75,
|
| 129 |
-
transform shape,
|
| 130 |
-
box/.style={draw, rectangle, minimum width=2cm, minimum height=0.7cm, align=center, rounded corners=2pt},
|
| 131 |
-
smallbox/.style={draw, rectangle, minimum width=1.4cm, minimum height=0.5cm, align=center, rounded corners=2pt, font=\scriptsize},
|
| 132 |
-
head/.style={draw, rectangle, minimum width=1.5cm, minimum height=0.6cm, align=center, rounded corners=2pt, fill=blue!20},
|
| 133 |
-
arrow/.style={->, >=stealth, thick},
|
| 134 |
-
dashedarrow/.style={->, >=stealth, dashed}
|
| 135 |
-
]
|
| 136 |
-
|
| 137 |
-
% Input
|
| 138 |
-
\node[box, fill=gray!20] (input) at (0, 0) {Input Text};
|
| 139 |
-
|
| 140 |
-
% Tokenizer
|
| 141 |
-
\node[box, fill=yellow!30] (tokenizer) at (0, 1.2) {Tokenizer\\(SentencePiece)};
|
| 142 |
-
|
| 143 |
-
% Encoder
|
| 144 |
-
\node[box, fill=green!30, minimum height=2cm] (encoder) at (0, 3.2) {Encoder\\$N=12$ layers};
|
| 145 |
-
|
| 146 |
-
% Task routing
|
| 147 |
-
\node[box, fill=orange!30] (router) at (0, 5.2) {Task Router};
|
| 148 |
-
|
| 149 |
-
% Decoder branch
|
| 150 |
-
\node[box, fill=green!30, minimum height=1.5cm] (decoder) at (-2.5, 7) {Decoder\\$N=12$ layers};
|
| 151 |
-
\node[head] (lmhead) at (-2.5, 8.8) {LM Head};
|
| 152 |
-
\node[smallbox, fill=purple!20] (summ) at (-2.5, 9.8) {Summary};
|
| 153 |
-
|
| 154 |
-
% Classification branch
|
| 155 |
-
\node[head] (emotionhead) at (1.2, 7) {Emotion\\Head};
|
| 156 |
-
\node[head] (topichead) at (2.8, 7) {Topic\\Head};
|
| 157 |
-
\node[smallbox, fill=purple!20] (emotion) at (1.2, 8.2) {Emotions\\(28 classes)};
|
| 158 |
-
\node[smallbox, fill=purple!20] (topic) at (2.8, 8.2) {Topics\\(7 classes)};
|
| 159 |
-
|
| 160 |
-
% Arrows
|
| 161 |
-
\draw[arrow] (input) -- (tokenizer);
|
| 162 |
-
\draw[arrow] (tokenizer) -- (encoder);
|
| 163 |
-
\draw[arrow] (encoder) -- (router);
|
| 164 |
-
\draw[arrow] (router) -- (decoder);
|
| 165 |
-
\draw[arrow] (router) -- (emotionhead);
|
| 166 |
-
\draw[arrow] (router) -- (topichead);
|
| 167 |
-
\draw[arrow] (decoder) -- (lmhead);
|
| 168 |
-
\draw[arrow] (lmhead) -- (summ);
|
| 169 |
-
\draw[arrow] (emotionhead) -- (emotion);
|
| 170 |
-
\draw[arrow] (topichead) -- (topic);
|
| 171 |
-
|
| 172 |
-
% Cross-attention arrow
|
| 173 |
-
\draw[dashedarrow] (encoder.west) -- ++(-0.5,0) |- (decoder.south west);
|
| 174 |
-
|
| 175 |
-
% Labels
|
| 176 |
-
\node[font=\tiny, align=center] at (-1.8, 4.5) {Cross\\Attention};
|
| 177 |
-
|
| 178 |
-
\end{tikzpicture}
|
| 179 |
-
\caption{LexiMind system architecture showing the shared encoder, task-specific routing, decoder for generation, and classification heads for discriminative tasks.}
|
| 180 |
-
\label{fig:architecture}
|
| 181 |
-
\end{figure}
|
| 182 |
-
|
| 183 |
-
\subsection{Transformer Block Structure}
|
| 184 |
-
|
| 185 |
-
Figure \ref{fig:transformer_block} presents the internal structure of encoder and decoder blocks, following the Pre-LN configuration from T5 \cite{raffel2020exploring}.
|
| 186 |
-
|
| 187 |
-
\begin{figure}[htbp]
|
| 188 |
-
\centering
|
| 189 |
-
\begin{tikzpicture}[
|
| 190 |
-
scale=0.65,
|
| 191 |
-
transform shape,
|
| 192 |
-
block/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.6cm, align=center, rounded corners=2pt},
|
| 193 |
-
norm/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.5cm, align=center, fill=yellow!30, rounded corners=2pt},
|
| 194 |
-
attn/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.6cm, align=center, fill=blue!25, rounded corners=2pt},
|
| 195 |
-
ffn/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.6cm, align=center, fill=green!25, rounded corners=2pt},
|
| 196 |
-
add/.style={draw, circle, minimum size=0.4cm, fill=red!20, inner sep=0pt, font=\small},
|
| 197 |
-
arrow/.style={->, >=stealth},
|
| 198 |
-
]
|
| 199 |
-
|
| 200 |
-
% === ENCODER BLOCK ===
|
| 201 |
-
\node[font=\bfseries] at (0, 8) {Encoder Block};
|
| 202 |
-
|
| 203 |
-
% Input
|
| 204 |
-
\node (enc_in) at (0, 7) {};
|
| 205 |
-
\draw[arrow] (0, 6.5) -- (enc_in);
|
| 206 |
-
|
| 207 |
-
% RMSNorm 1
|
| 208 |
-
\node[norm] (enc_norm1) at (0, 6) {RMSNorm};
|
| 209 |
-
|
| 210 |
-
% Self-Attention
|
| 211 |
-
\node[attn] (enc_attn) at (0, 5) {Multi-Head\\Self-Attention};
|
| 212 |
-
|
| 213 |
-
% Add 1
|
| 214 |
-
\node[add] (enc_add1) at (0, 4) {+};
|
| 215 |
-
|
| 216 |
-
% RMSNorm 2
|
| 217 |
-
\node[norm] (enc_norm2) at (0, 3) {RMSNorm};
|
| 218 |
-
|
| 219 |
-
% FFN
|
| 220 |
-
\node[ffn] (enc_ffn) at (0, 2) {Gated FFN\\(GELU)};
|
| 221 |
-
|
| 222 |
-
% Add 2
|
| 223 |
-
\node[add] (enc_add2) at (0, 1) {+};
|
| 224 |
-
|
| 225 |
-
% Output
|
| 226 |
-
\node (enc_out) at (0, 0.3) {};
|
| 227 |
-
|
| 228 |
-
% Connections
|
| 229 |
-
\draw[arrow] (enc_in) -- (enc_norm1);
|
| 230 |
-
\draw[arrow] (enc_norm1) -- (enc_attn);
|
| 231 |
-
\draw[arrow] (enc_attn) -- (enc_add1);
|
| 232 |
-
\draw[arrow] (enc_add1) -- (enc_norm2);
|
| 233 |
-
\draw[arrow] (enc_norm2) -- (enc_ffn);
|
| 234 |
-
\draw[arrow] (enc_ffn) -- (enc_add2);
|
| 235 |
-
\draw[arrow] (enc_add2) -- (enc_out);
|
| 236 |
-
|
| 237 |
-
% Residual connections
|
| 238 |
-
\draw[arrow] (0, 6.5) -- (-1.5, 6.5) -- (-1.5, 4) -- (enc_add1.west);
|
| 239 |
-
\draw[arrow] (enc_add1.east) -- (1.5, 4) -- (1.5, 1) -- (enc_add2.east);
|
| 240 |
-
|
| 241 |
-
% === DECODER BLOCK ===
|
| 242 |
-
\node[font=\bfseries] at (5.5, 8) {Decoder Block};
|
| 243 |
-
|
| 244 |
-
% Input
|
| 245 |
-
\node (dec_in) at (5.5, 7) {};
|
| 246 |
-
\draw[arrow] (5.5, 6.5) -- (dec_in);
|
| 247 |
-
|
| 248 |
-
% RMSNorm 1
|
| 249 |
-
\node[norm] (dec_norm1) at (5.5, 6) {RMSNorm};
|
| 250 |
-
|
| 251 |
-
% Masked Self-Attention
|
| 252 |
-
\node[attn] (dec_attn1) at (5.5, 5) {Masked\\Self-Attention};
|
| 253 |
-
|
| 254 |
-
% Add 1
|
| 255 |
-
\node[add] (dec_add1) at (5.5, 4.2) {+};
|
| 256 |
-
|
| 257 |
-
% RMSNorm 2
|
| 258 |
-
\node[norm] (dec_norm2) at (5.5, 3.4) {RMSNorm};
|
| 259 |
-
|
| 260 |
-
% Cross-Attention
|
| 261 |
-
\node[attn, fill=cyan!25] (dec_attn2) at (5.5, 2.4) {Cross-Attention};
|
| 262 |
-
|
| 263 |
-
% Add 2
|
| 264 |
-
\node[add] (dec_add2) at (5.5, 1.5) {+};
|
| 265 |
-
|
| 266 |
-
% RMSNorm 3
|
| 267 |
-
\node[norm] (dec_norm3) at (5.5, 0.7) {RMSNorm};
|
| 268 |
-
|
| 269 |
-
% FFN
|
| 270 |
-
\node[ffn] (dec_ffn) at (5.5, -0.3) {Gated FFN\\(GELU)};
|
| 271 |
-
|
| 272 |
-
% Add 3
|
| 273 |
-
\node[add] (dec_add3) at (5.5, -1.2) {+};
|
| 274 |
-
|
| 275 |
-
% Connections
|
| 276 |
-
\draw[arrow] (dec_in) -- (dec_norm1);
|
| 277 |
-
\draw[arrow] (dec_norm1) -- (dec_attn1);
|
| 278 |
-
\draw[arrow] (dec_attn1) -- (dec_add1);
|
| 279 |
-
\draw[arrow] (dec_add1) -- (dec_norm2);
|
| 280 |
-
\draw[arrow] (dec_norm2) -- (dec_attn2);
|
| 281 |
-
\draw[arrow] (dec_attn2) -- (dec_add2);
|
| 282 |
-
\draw[arrow] (dec_add2) -- (dec_norm3);
|
| 283 |
-
\draw[arrow] (dec_norm3) -- (dec_ffn);
|
| 284 |
-
\draw[arrow] (dec_ffn) -- (dec_add3);
|
| 285 |
-
|
| 286 |
-
% Encoder memory input
|
| 287 |
-
\node[block, fill=gray!20, minimum width=1.2cm, font=\scriptsize] (memory) at (8, 2.4) {Encoder\\Memory};
|
| 288 |
-
\draw[arrow] (memory) -- (dec_attn2);
|
| 289 |
-
|
| 290 |
-
% Residual connections (simplified)
|
| 291 |
-
\draw[arrow] (5.5, 6.5) -- (4, 6.5) -- (4, 4.2) -- (dec_add1.west);
|
| 292 |
-
\draw[arrow] (dec_add1.east) -- (7, 4.2) -- (7, 1.5) -- (dec_add2.east);
|
| 293 |
-
\draw[arrow] (dec_add2.west) -- (4, 1.5) -- (4, -1.2) -- (dec_add3.west);
|
| 294 |
-
|
| 295 |
-
\end{tikzpicture}
|
| 296 |
-
\caption{Pre-LN Transformer blocks. Left: Encoder block with self-attention and FFN. Right: Decoder block with masked self-attention, cross-attention to encoder memory, and FFN. RMSNorm is applied \emph{before} each sublayer (Pre-LN).}
|
| 297 |
-
\label{fig:transformer_block}
|
| 298 |
-
\end{figure}
|
| 299 |
-
|
| 300 |
-
\subsection{Multi-Head Attention Mechanism}
|
| 301 |
-
|
| 302 |
-
The attention mechanism is the cornerstone of the Transformer architecture. LexiMind implements Multi-Head Attention with support for T5-style relative position bias and FlashAttention optimization. Figure \ref{fig:attention} illustrates the attention computation.
|
| 303 |
-
|
| 304 |
-
\begin{figure}[htbp]
|
| 305 |
-
\centering
|
| 306 |
-
\begin{tikzpicture}[
|
| 307 |
-
scale=0.6,
|
| 308 |
-
transform shape,
|
| 309 |
-
box/.style={draw, rectangle, minimum width=1.5cm, minimum height=0.6cm, align=center, rounded corners=2pt},
|
| 310 |
-
proj/.style={draw, rectangle, minimum width=1.2cm, minimum height=0.5cm, align=center, fill=blue!20, rounded corners=2pt, font=\scriptsize},
|
| 311 |
-
op/.style={draw, rectangle, minimum width=1.2cm, minimum height=0.5cm, align=center, fill=orange!30, rounded corners=2pt, font=\scriptsize},
|
| 312 |
-
arrow/.style={->, >=stealth},
|
| 313 |
-
]
|
| 314 |
-
|
| 315 |
-
% Input
|
| 316 |
-
\node[box, fill=gray!20] (input) at (0, 0) {Input $X$};
|
| 317 |
-
|
| 318 |
-
% Projections
|
| 319 |
-
\node[proj] (wq) at (-2.5, 1.5) {$W_Q$};
|
| 320 |
-
\node[proj] (wk) at (0, 1.5) {$W_K$};
|
| 321 |
-
\node[proj] (wv) at (2.5, 1.5) {$W_V$};
|
| 322 |
-
|
| 323 |
-
% Q, K, V
|
| 324 |
-
\node[box, fill=green!20] (q) at (-2.5, 2.8) {$Q$};
|
| 325 |
-
\node[box, fill=green!20] (k) at (0, 2.8) {$K$};
|
| 326 |
-
\node[box, fill=green!20] (v) at (2.5, 2.8) {$V$};
|
| 327 |
-
|
| 328 |
-
% Split heads
|
| 329 |
-
\node[op] (split) at (0, 4) {Split $h$ heads};
|
| 330 |
-
|
| 331 |
-
% Attention scores
|
| 332 |
-
\node[op] (matmul1) at (0, 5.2) {$QK^T$};
|
| 333 |
-
|
| 334 |
-
% Position bias
|
| 335 |
-
\node[box, fill=yellow!30, font=\scriptsize] (bias) at (3.5, 5.2) {Relative\\Pos Bias};
|
| 336 |
-
|
| 337 |
-
% Add bias
|
| 338 |
-
\node[op] (add) at (0, 6.2) {$+ B_{rel}$};
|
| 339 |
-
|
| 340 |
-
% Scale (optional)
|
| 341 |
-
\node[op] (scale) at (0, 7.2) {Scale / Mask};
|
| 342 |
-
|
| 343 |
-
% Softmax
|
| 344 |
-
\node[op, fill=red!20] (softmax) at (0, 8.2) {Softmax};
|
| 345 |
-
|
| 346 |
-
% MatMul with V
|
| 347 |
-
\node[op] (matmul2) at (0, 9.2) {$\times V$};
|
| 348 |
-
|
| 349 |
-
% Concat
|
| 350 |
-
\node[op] (concat) at (0, 10.2) {Concat heads};
|
| 351 |
-
|
| 352 |
-
% Output projection
|
| 353 |
-
\node[proj] (wo) at (0, 11.2) {$W_O$};
|
| 354 |
-
|
| 355 |
-
% Output
|
| 356 |
-
\node[box, fill=purple!20] (output) at (0, 12.2) {Output};
|
| 357 |
-
|
| 358 |
-
% Arrows
|
| 359 |
-
\draw[arrow] (input) -- (wq);
|
| 360 |
-
\draw[arrow] (input) -- (wk);
|
| 361 |
-
\draw[arrow] (input) -- (wv);
|
| 362 |
-
\draw[arrow] (wq) -- (q);
|
| 363 |
-
\draw[arrow] (wk) -- (k);
|
| 364 |
-
\draw[arrow] (wv) -- (v);
|
| 365 |
-
\draw[arrow] (q) -- (split);
|
| 366 |
-
\draw[arrow] (k) -- (split);
|
| 367 |
-
\draw[arrow] (v.north) -- ++(0, 0.3) -| (2.5, 9.2) -- (matmul2);
|
| 368 |
-
\draw[arrow] (split) -- (matmul1);
|
| 369 |
-
\draw[arrow] (matmul1) -- (add);
|
| 370 |
-
\draw[arrow] (bias) -- (add);
|
| 371 |
-
\draw[arrow] (add) -- (scale);
|
| 372 |
-
\draw[arrow] (scale) -- (softmax);
|
| 373 |
-
\draw[arrow] (softmax) -- (matmul2);
|
| 374 |
-
\draw[arrow] (matmul2) -- (concat);
|
| 375 |
-
\draw[arrow] (concat) -- (wo);
|
| 376 |
-
\draw[arrow] (wo) -- (output);
|
| 377 |
-
|
| 378 |
-
% Annotations
|
| 379 |
-
\node[font=\tiny, align=left] at (-4.5, 5.5) {T5 does NOT\\scale by $\sqrt{d_k}$};
|
| 380 |
-
|
| 381 |
-
\end{tikzpicture}
|
| 382 |
-
\caption{Multi-Head Attention with T5-style relative position bias. The attention scores are computed as $QK^T + B_{rel}$, where $B_{rel}$ is the learned relative position bias. Unlike standard Transformers, T5 does not scale by $\sqrt{d_k}$.}
|
| 383 |
-
\label{fig:attention}
|
| 384 |
-
\end{figure}
|
| 385 |
-
|
| 386 |
-
The attention computation in LexiMind is implemented in \texttt{src/models/attention.py}. For T5 compatibility, the \texttt{scale\_scores} parameter controls whether to apply $\sqrt{d_k}$ scaling—T5 does not use this scaling \cite{raffel2020exploring}.
|
| 387 |
-
|
| 388 |
-
Figure \ref{fig:attention_viz} shows learned attention patterns from the trained model, demonstrating how different heads specialize in capturing various linguistic relationships.
|
| 389 |
-
|
| 390 |
-
\begin{figure}[htbp]
|
| 391 |
-
\centering
|
| 392 |
-
\includegraphics[width=\columnwidth]{figures/multihead_attention_visualization.png}
|
| 393 |
-
\caption{Attention weight visualization across multiple heads. Each head learns distinct attention patterns: some focus on local context (diagonal patterns), while others capture long-range dependencies and syntactic relationships.}
|
| 394 |
-
\label{fig:attention_viz}
|
| 395 |
-
\end{figure}
|
| 396 |
-
|
| 397 |
-
\subsubsection{T5 Relative Position Bias}
|
| 398 |
-
|
| 399 |
-
Unlike absolute positional embeddings that are added to token embeddings, T5 uses relative position bias added directly to attention scores. The \texttt{T5RelativePositionBias} class implements logarithmically-bucketed relative positions:
|
| 400 |
-
|
| 401 |
-
\begin{equation}
|
| 402 |
-
B_{ij} = \text{Embed}[\text{bucket}(i - j)]
|
| 403 |
-
\end{equation}
|
| 404 |
-
|
| 405 |
-
where $\text{bucket}(\cdot)$ maps relative distances to discrete buckets. Half the buckets encode exact positions for nearby tokens; the remaining half use logarithmic spacing for distant tokens. As documented in the code:
|
| 406 |
-
|
| 407 |
-
\begin{quote}
|
| 408 |
-
\emph{``T5 uses a combination of exact positions (for nearby tokens) and logarithmically-spaced buckets (for distant tokens).''} — \texttt{attention.py}, lines 46--48
|
| 409 |
-
\end{quote}
|
| 410 |
-
|
| 411 |
-
Figure \ref{fig:position_bias} visualizes the learned relative position bias, showing how the model encodes positional relationships between tokens.
|
| 412 |
-
|
| 413 |
-
\begin{figure}[htbp]
|
| 414 |
-
\centering
|
| 415 |
-
\includegraphics[width=\columnwidth]{figures/positional_encoding_heatmap.png}
|
| 416 |
-
\caption{Heatmap of relative position bias values. The diagonal structure indicates stronger attention between nearby positions, while the logarithmic bucketing allows efficient representation of longer-range dependencies.}
|
| 417 |
-
\label{fig:position_bias}
|
| 418 |
-
\end{figure}
|
| 419 |
-
|
| 420 |
-
\subsubsection{FlashAttention Integration}
|
| 421 |
-
|
| 422 |
-
LexiMind leverages PyTorch 2.0's \texttt{scaled\_dot\_product\_attention} function, which automatically selects the optimal attention kernel:
|
| 423 |
-
|
| 424 |
-
\begin{quote}
|
| 425 |
-
\emph{``Uses F.scaled\_dot\_product\_attention which automatically selects the best available kernel (FlashAttention v2, Memory-Efficient Attention, or math fallback) based on hardware and input shapes.''} — \texttt{attention.py}, lines 134--137
|
| 426 |
-
\end{quote}
|
| 427 |
-
|
| 428 |
-
This provides O(N) memory complexity instead of O(N²) when FlashAttention is available.
|
| 429 |
-
|
| 430 |
-
\subsection{Feed-Forward Network}
|
| 431 |
-
|
| 432 |
-
Following T5, LexiMind implements a gated feed-forward network with GELU activation:
|
| 433 |
-
|
| 434 |
-
\begin{equation}
|
| 435 |
-
\text{FFN}(x) = (\text{GELU}(xW_g) \odot xW_1) W_2
|
| 436 |
-
\end{equation}
|
| 437 |
-
|
| 438 |
-
where $W_g$ is the gating projection, $W_1$ is the up-projection, $W_2$ is the down-projection, and $\odot$ denotes element-wise multiplication.
|
| 439 |
-
|
| 440 |
-
\subsection{RMSNorm}
|
| 441 |
-
|
| 442 |
-
RMSNorm \cite{zhang2019root} normalizes inputs using only the root mean square:
|
| 443 |
-
|
| 444 |
-
\begin{equation}
|
| 445 |
-
\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 + \epsilon}} \cdot \gamma
|
| 446 |
-
\end{equation}
|
| 447 |
-
|
| 448 |
-
The implementation in \texttt{src/models/t5\_layer\_norm.py} follows T5's convention, using only a learned scale parameter $\gamma$ with no bias term.
|
| 449 |
-
|
| 450 |
-
%=============================================================================
|
| 451 |
-
\section{Tokenization}
|
| 452 |
-
\label{sec:tokenization}
|
| 453 |
-
%=============================================================================
|
| 454 |
-
|
| 455 |
-
LexiMind wraps HuggingFace's AutoTokenizer with a simplified façade that handles T5-specific conventions. The implementation in \texttt{src/data/tokenization.py} manages special token handling and decoder input preparation.
|
| 456 |
-
|
| 457 |
-
\subsection{T5 Tokenizer Characteristics}
|
| 458 |
-
|
| 459 |
-
T5 uses SentencePiece \cite{kudo2018sentencepiece} with unigram tokenization:
|
| 460 |
-
|
| 461 |
-
\begin{itemize}
|
| 462 |
-
\item \textbf{Vocabulary Size}: 32,128 tokens (padded to multiple of 128 for efficiency)
|
| 463 |
-
\item \textbf{Special Tokens}: \texttt{pad\_token\_id=0}, \texttt{eos\_token\_id=1}
|
| 464 |
-
\item \textbf{No Explicit BOS}: T5 uses the pad token as the decoder start token
|
| 465 |
-
\end{itemize}
|
| 466 |
-
|
| 467 |
-
As noted in the tokenizer implementation:
|
| 468 |
-
|
| 469 |
-
\begin{quote}
|
| 470 |
-
\emph{``T5 uses different special tokens than BART: T5: pad=0, eos=1, no explicit bos (uses pad or eos as decoder start); BART: bos=0, pad=1, eos=2.''} — \texttt{tokenization.py}, lines 42--44
|
| 471 |
-
\end{quote}
|
| 472 |
-
|
| 473 |
-
\subsection{Decoder Input Preparation}
|
| 474 |
-
|
| 475 |
-
For seq2seq training, decoder inputs must be shifted right from labels. The \texttt{prepare\_decoder\_inputs} method handles this:
|
| 476 |
-
|
| 477 |
-
\begin{lstlisting}[caption={Decoder input preparation from tokenization.py}]
|
| 478 |
-
def prepare_decoder_inputs(
|
| 479 |
-
self, labels: torch.Tensor
|
| 480 |
-
) -> torch.Tensor:
|
| 481 |
-
"""Shift decoder labels to create
|
| 482 |
-
input ids prefixed by BOS."""
|
| 483 |
-
bos = self.bos_token_id
|
| 484 |
-
pad = self.pad_token_id
|
| 485 |
-
decoder_inputs = torch.full_like(labels, pad)
|
| 486 |
-
decoder_inputs[:, 0] = bos
|
| 487 |
-
decoder_inputs[:, 1:] = labels[:, :-1]
|
| 488 |
-
return decoder_inputs
|
| 489 |
-
\end{lstlisting}
|
| 490 |
-
|
| 491 |
-
%=============================================================================
|
| 492 |
-
\section{The Factory Module: Weight Transfer from FLAN-T5}
|
| 493 |
-
\label{sec:factory}
|
| 494 |
-
%=============================================================================
|
| 495 |
-
|
| 496 |
-
The \texttt{factory.py} module is central to LexiMind's hybrid approach, providing model construction and weight loading utilities. Figure \ref{fig:factory_flow} illustrates the model construction pipeline.
|
| 497 |
-
|
| 498 |
-
\begin{figure}[htbp]
|
| 499 |
-
\centering
|
| 500 |
-
\begin{tikzpicture}[
|
| 501 |
-
scale=0.7,
|
| 502 |
-
transform shape,
|
| 503 |
-
box/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.7cm, align=center, rounded corners=3pt},
|
| 504 |
-
config/.style={draw, rectangle, minimum width=2cm, minimum height=0.6cm, align=center, fill=yellow!30, rounded corners=2pt, font=\small},
|
| 505 |
-
model/.style={draw, rectangle, minimum width=2.2cm, minimum height=0.6cm, align=center, fill=green!30, rounded corners=2pt, font=\small},
|
| 506 |
-
arrow/.style={->, >=stealth, thick},
|
| 507 |
-
]
|
| 508 |
-
|
| 509 |
-
% Config loading
|
| 510 |
-
\node[config] (yaml) at (0, 0) {config.yaml};
|
| 511 |
-
\node[box, fill=blue!20] (loadconfig) at (0, 1.3) {load\_model\_config()};
|
| 512 |
-
\node[config] (modelconfig) at (0, 2.6) {ModelConfig};
|
| 513 |
-
|
| 514 |
-
% Model building
|
| 515 |
-
\node[box, fill=blue!20] (build) at (0, 4.2) {build\_multitask\_model()};
|
| 516 |
-
|
| 517 |
-
% Components
|
| 518 |
-
\node[model] (encoder) at (-2.5, 5.8) {Encoder};
|
| 519 |
-
\node[model] (decoder) at (0, 5.8) {Decoder};
|
| 520 |
-
\node[model] (heads) at (2.5, 5.8) {Task Heads};
|
| 521 |
-
|
| 522 |
-
% Weight loading
|
| 523 |
-
\node[box, fill=orange!30] (loadweights) at (-1.2, 7.4) {\_load\_pretrained\_weights()};
|
| 524 |
-
|
| 525 |
-
% FLAN-T5
|
| 526 |
-
\node[box, fill=purple!20] (flant5) at (-4, 7.4) {FLAN-T5\\(HuggingFace)};
|
| 527 |
-
|
| 528 |
-
% Final model
|
| 529 |
-
\node[box, fill=red!20, minimum width=3cm] (mtmodel) at (0, 9) {MultiTaskModel};
|
| 530 |
-
|
| 531 |
-
% Arrows
|
| 532 |
-
\draw[arrow] (yaml) -- (loadconfig);
|
| 533 |
-
\draw[arrow] (loadconfig) -- (modelconfig);
|
| 534 |
-
\draw[arrow] (modelconfig) -- (build);
|
| 535 |
-
\draw[arrow] (build) -- (encoder);
|
| 536 |
-
\draw[arrow] (build) -- (decoder);
|
| 537 |
-
\draw[arrow] (build) -- (heads);
|
| 538 |
-
\draw[arrow] (encoder) -- (loadweights);
|
| 539 |
-
\draw[arrow] (decoder) -- (loadweights);
|
| 540 |
-
\draw[arrow] (flant5) -- (loadweights);
|
| 541 |
-
\draw[arrow] (loadweights) -- (mtmodel);
|
| 542 |
-
\draw[arrow] (heads) -- (mtmodel);
|
| 543 |
-
|
| 544 |
-
\end{tikzpicture}
|
| 545 |
-
\caption{Model construction pipeline in \texttt{factory.py}. Configuration is loaded from YAML, components are instantiated, FLAN-T5 weights are transferred, and the final MultiTaskModel is assembled.}
|
| 546 |
-
\label{fig:factory_flow}
|
| 547 |
-
\end{figure}
|
| 548 |
-
|
| 549 |
-
\subsection{Configuration Management}
|
| 550 |
-
|
| 551 |
-
The \texttt{ModelConfig} dataclass defines all architecture hyperparameters:
|
| 552 |
-
|
| 553 |
-
\begin{lstlisting}[caption={ModelConfig from factory.py}]
|
| 554 |
-
@dataclass
|
| 555 |
-
class ModelConfig:
|
| 556 |
-
d_model: int = 768
|
| 557 |
-
vocab_size: Optional[int] = None
|
| 558 |
-
num_encoder_layers: int = 12
|
| 559 |
-
num_decoder_layers: int = 12
|
| 560 |
-
num_attention_heads: int = 12
|
| 561 |
-
ffn_dim: int = 2048
|
| 562 |
-
dropout: float = 0.1
|
| 563 |
-
use_pretrained: bool = False
|
| 564 |
-
pretrained_model_name: str =
|
| 565 |
-
"google/flan-t5-base"
|
| 566 |
-
activation: str = "gated-gelu"
|
| 567 |
-
use_relative_position_bias: bool = False
|
| 568 |
-
\end{lstlisting}
|
| 569 |
-
|
| 570 |
-
\subsection{Weight Transfer Mechanism}
|
| 571 |
-
|
| 572 |
-
The \texttt{\_load\_pretrained\_weights} function performs careful weight mapping between FLAN-T5 and LexiMind's custom architecture. Key considerations documented in the code:
|
| 573 |
-
|
| 574 |
-
\begin{quote}
|
| 575 |
-
\emph{``T5 architecture compatibility with our custom Transformer: T5 uses Pre-LN (RMSNorm before sublayers) --- matches our design; T5 uses relative position bias instead of absolute embeddings; T5 uses gated FFN (wi\_0, wi\_1, wo); T5 attention has no bias, our attention has bias --- we zero-initialize the bias terms.''} --- \texttt{factory.py}, lines 100--108
|
| 576 |
-
\end{quote}
|
| 577 |
-
|
| 578 |
-
Table \ref{tab:weight_mapping} shows the parameter correspondence:
|
| 579 |
-
|
| 580 |
-
\begin{table}[htbp]
|
| 581 |
-
\centering
|
| 582 |
-
\caption{FLAN-T5 to LexiMind Weight Mapping}
|
| 583 |
-
\label{tab:weight_mapping}
|
| 584 |
-
\begin{tabular}{ll}
|
| 585 |
-
\toprule
|
| 586 |
-
\textbf{FLAN-T5 Parameter} & \textbf{LexiMind Parameter} \\
|
| 587 |
-
\midrule
|
| 588 |
-
\texttt{shared} & \texttt{encoder.embedding} \\
|
| 589 |
-
\texttt{encoder.block.*.SelfAttention.q} & \texttt{encoder.layers.*.self\_attn.W\_Q} \\
|
| 590 |
-
\texttt{encoder.block.*.SelfAttention.k} & \texttt{encoder.layers.*.self\_attn.W\_K} \\
|
| 591 |
-
\texttt{encoder.block.*.SelfAttention.v} & \texttt{encoder.layers.*.self\_attn.W\_V} \\
|
| 592 |
-
\texttt{encoder.block.*.SelfAttention.o} & \texttt{encoder.layers.*.self\_attn.W\_O} \\
|
| 593 |
-
\texttt{*.layer\_norm} & \texttt{*.norm*} \\
|
| 594 |
-
\texttt{*.DenseReluDense.wi\_0} & \texttt{*.ffn.linear\_gate} \\
|
| 595 |
-
\texttt{*.DenseReluDense.wi\_1} & \texttt{*.ffn.linear1} \\
|
| 596 |
-
\texttt{*.DenseReluDense.wo} & \texttt{*.ffn.linear2} \\
|
| 597 |
-
\texttt{lm\_head} & \texttt{decoder.output\_projection} \\
|
| 598 |
-
\bottomrule
|
| 599 |
-
\end{tabular}
|
| 600 |
-
\end{table}
|
| 601 |
-
|
| 602 |
-
\subsection{Vocabulary Size Handling}
|
| 603 |
-
|
| 604 |
-
T5 pads its vocabulary to multiples of 128 for computational efficiency (32,100 → 32,128). LexiMind handles this mismatch:
|
| 605 |
-
|
| 606 |
-
\begin{quote}
|
| 607 |
-
\emph{``Note: T5's vocab is padded to multiple of 128 for efficiency (32100 → 32128). [...] Copy only the tokens that exist in both. Initialize any extra tokens with small random values.''} — \texttt{factory.py}, lines 116--131
|
| 608 |
-
\end{quote}
|
| 609 |
-
|
| 610 |
-
\subsection{Model Assembly}
|
| 611 |
-
|
| 612 |
-
The \texttt{build\_multitask\_model} function assembles the complete system:
|
| 613 |
-
|
| 614 |
-
\begin{lstlisting}[caption={Model assembly from factory.py}]
|
| 615 |
-
model = MultiTaskModel(
|
| 616 |
-
encoder=encoder,
|
| 617 |
-
decoder=decoder,
|
| 618 |
-
decoder_outputs_logits=True
|
| 619 |
-
)
|
| 620 |
-
model.add_head(
|
| 621 |
-
"summarization",
|
| 622 |
-
LMHead(d_model=cfg.d_model,
|
| 623 |
-
vocab_size=vocab_size,
|
| 624 |
-
tie_embedding=decoder.embedding)
|
| 625 |
-
)
|
| 626 |
-
model.add_head(
|
| 627 |
-
"emotion",
|
| 628 |
-
ClassificationHead(
|
| 629 |
-
d_model=cfg.d_model,
|
| 630 |
-
num_labels=28, # GoEmotions
|
| 631 |
-
pooler="mean",
|
| 632 |
-
hidden_dim=cfg.d_model // 2)
|
| 633 |
-
)
|
| 634 |
-
model.add_head(
|
| 635 |
-
"topic",
|
| 636 |
-
ClassificationHead(
|
| 637 |
-
d_model=cfg.d_model,
|
| 638 |
-
num_labels=7, # 7 topic categories
|
| 639 |
-
pooler="mean")
|
| 640 |
-
)
|
| 641 |
-
\end{lstlisting}
|
| 642 |
-
|
| 643 |
-
%=============================================================================
|
| 644 |
-
\section{Multi-Task Model Architecture}
|
| 645 |
-
\label{sec:multitask}
|
| 646 |
-
%=============================================================================
|
| 647 |
-
|
| 648 |
-
The \texttt{MultiTaskModel} class in \texttt{src/models/multitask.py} provides the routing infrastructure for multi-task learning. Figure \ref{fig:multitask_routing} illustrates the task routing mechanism.
|
| 649 |
-
|
| 650 |
-
\begin{figure}[htbp]
|
| 651 |
-
\centering
|
| 652 |
-
\begin{tikzpicture}[
|
| 653 |
-
scale=0.7,
|
| 654 |
-
transform shape,
|
| 655 |
-
box/.style={draw, rectangle, minimum width=2cm, minimum height=0.6cm, align=center, rounded corners=2pt},
|
| 656 |
-
decision/.style={draw, diamond, aspect=2, minimum width=1.5cm, align=center, fill=yellow!30},
|
| 657 |
-
arrow/.style={->, >=stealth, thick},
|
| 658 |
-
]
|
| 659 |
-
|
| 660 |
-
% Forward call
|
| 661 |
-
\node[box, fill=blue!20] (forward) at (0, 0) {forward(task, inputs)};
|
| 662 |
-
|
| 663 |
-
% Decision
|
| 664 |
-
\node[decision] (taskcheck) at (0, -1.5) {task type?};
|
| 665 |
-
|
| 666 |
-
% Branches
|
| 667 |
-
\node[box, fill=green!20] (encoder) at (-3.5, -3.5) {Encoder\\Only};
|
| 668 |
-
\node[box, fill=green!20] (seq2seq) at (3.5, -3.5) {Encoder\\+ Decoder};
|
| 669 |
-
|
| 670 |
-
% Heads
|
| 671 |
-
\node[box, fill=orange!20] (classhead) at (-3.5, -5) {Classification\\Head};
|
| 672 |
-
\node[box, fill=orange!20] (lmhead) at (3.5, -5) {LM Head};
|
| 673 |
-
|
| 674 |
-
% Tasks
|
| 675 |
-
\node[box, fill=purple!20, font=\scriptsize] (emotion) at (-5, -6.5) {Emotion};
|
| 676 |
-
\node[box, fill=purple!20, font=\scriptsize] (topic) at (-2, -6.5) {Topic};
|
| 677 |
-
\node[box, fill=purple!20, font=\scriptsize] (summ) at (3.5, -6.5) {Summarization};
|
| 678 |
-
|
| 679 |
-
% Arrows
|
| 680 |
-
\draw[arrow] (forward) -- (taskcheck);
|
| 681 |
-
\draw[arrow] (taskcheck) -- node[above, font=\scriptsize] {Classification} (encoder);
|
| 682 |
-
\draw[arrow] (taskcheck) -- node[above, font=\scriptsize] {Generation} (seq2seq);
|
| 683 |
-
\draw[arrow] (encoder) -- (classhead);
|
| 684 |
-
\draw[arrow] (seq2seq) -- (lmhead);
|
| 685 |
-
\draw[arrow] (classhead) -- (emotion);
|
| 686 |
-
\draw[arrow] (classhead) -- (topic);
|
| 687 |
-
\draw[arrow] (lmhead) -- (summ);
|
| 688 |
-
|
| 689 |
-
\end{tikzpicture}
|
| 690 |
-
\caption{Task routing in MultiTaskModel. Classification tasks use encoder-only processing with mean pooling, while generation tasks use the full encoder-decoder pipeline.}
|
| 691 |
-
\label{fig:multitask_routing}
|
| 692 |
-
\end{figure}
|
| 693 |
-
|
| 694 |
-
\subsection{Task-Specific Head Selection}
|
| 695 |
-
|
| 696 |
-
The forward method routes inputs based on head type:
|
| 697 |
-
|
| 698 |
-
\begin{quote}
|
| 699 |
-
\emph{``Encoder-only heads expect encoder outputs [...] LM/seq2seq head: run encoder → decoder → lm head''} — \texttt{multitask.py}, lines 108--148
|
| 700 |
-
\end{quote}
|
| 701 |
-
|
| 702 |
-
\subsection{Classification Head}
|
| 703 |
-
|
| 704 |
-
Classification tasks (emotion, topic) use mean pooling over encoder outputs:
|
| 705 |
-
|
| 706 |
-
\begin{equation}
|
| 707 |
-
h_{cls} = \frac{\sum_{i=1}^{L} h_i \cdot m_i}{\sum_{i=1}^{L} m_i}
|
| 708 |
-
\end{equation}
|
| 709 |
-
|
| 710 |
-
where $m_i$ is the attention mask (1 for valid tokens, 0 for padding). The pooled representation is projected through a linear layer to class logits.
|
| 711 |
-
|
| 712 |
-
%=============================================================================
|
| 713 |
-
\section{Training Pipeline}
|
| 714 |
-
\label{sec:training}
|
| 715 |
-
%=============================================================================
|
| 716 |
-
|
| 717 |
-
The training infrastructure in \texttt{src/training/trainer.py} implements a comprehensive multi-task training loop with modern deep learning practices.
|
| 718 |
-
|
| 719 |
-
\subsection{Training Configuration}
|
| 720 |
-
|
| 721 |
-
The \texttt{TrainerConfig} dataclass encapsulates all hyperparameters:
|
| 722 |
-
|
| 723 |
-
\begin{lstlisting}[caption={TrainerConfig from trainer.py}]
|
| 724 |
-
@dataclass
|
| 725 |
-
class TrainerConfig:
|
| 726 |
-
max_epochs: int = 1
|
| 727 |
-
gradient_clip_norm: float = 1.0
|
| 728 |
-
task_weights: Dict[str, float] | None = None
|
| 729 |
-
label_smoothing: float = 0.0
|
| 730 |
-
gradient_accumulation_steps: int = 1
|
| 731 |
-
scheduler_type: str = "cosine"
|
| 732 |
-
warmup_steps: int = 0
|
| 733 |
-
early_stopping_patience: int | None = None
|
| 734 |
-
gradient_checkpointing: bool = False
|
| 735 |
-
compile_model: bool = False
|
| 736 |
-
\end{lstlisting}
|
| 737 |
-
|
| 738 |
-
\subsection{Mixed-Precision Training}
|
| 739 |
-
|
| 740 |
-
LexiMind uses Automatic Mixed Precision (AMP) with automatic dtype selection:
|
| 741 |
-
|
| 742 |
-
\begin{quote}
|
| 743 |
-
\emph{``AMP setup: bfloat16 for Ampere+ GPUs, float16 otherwise''} — \texttt{trainer.py}, line 102
|
| 744 |
-
\end{quote}
|
| 745 |
-
|
| 746 |
-
BFloat16 provides better numerical stability for training while maintaining the memory and speed benefits of reduced precision.
|
| 747 |
-
|
| 748 |
-
\subsection{Learning Rate Scheduling}
|
| 749 |
-
|
| 750 |
-
A cosine schedule with linear warmup is implemented:
|
| 751 |
-
|
| 752 |
-
\begin{equation}
|
| 753 |
-
lr(t) = \begin{cases}
|
| 754 |
-
lr_{max} \cdot \frac{t}{t_{warmup}} & t < t_{warmup} \\
|
| 755 |
-
lr_{min} + \frac{1}{2}(lr_{max} - lr_{min})(1 + \cos(\frac{\pi(t-t_{warmup})}{T-t_{warmup}})) & t \geq t_{warmup}
|
| 756 |
-
\end{cases}
|
| 757 |
-
\end{equation}
|
| 758 |
-
|
| 759 |
-
Figure \ref{fig:lr_schedule} visualizes the learning rate schedule over training, showing the 300-step linear warmup followed by cosine decay.
|
| 760 |
-
|
| 761 |
-
\begin{figure}[htbp]
|
| 762 |
-
\centering
|
| 763 |
-
\includegraphics[width=\columnwidth]{figures/learning_rate_schedule.png}
|
| 764 |
-
\caption{Learning rate schedule with linear warmup (300 steps) followed by cosine annealing. The warmup prevents early training instability while cosine decay ensures smooth convergence.}
|
| 765 |
-
\label{fig:lr_schedule}
|
| 766 |
-
\end{figure}
|
| 767 |
-
|
| 768 |
-
\subsection{Multi-Task Loss Computation}
|
| 769 |
-
|
| 770 |
-
The total loss combines task-specific losses with optional weighting:
|
| 771 |
-
|
| 772 |
-
\begin{equation}
|
| 773 |
-
\mathcal{L}_{total} = \sum_{t \in \text{tasks}} \lambda_t \mathcal{L}_t
|
| 774 |
-
\end{equation}
|
| 775 |
-
|
| 776 |
-
\begin{itemize}
|
| 777 |
-
\item \textbf{Summarization}: Cross-entropy with label smoothing and \texttt{ignore\_index=-100}
|
| 778 |
-
\item \textbf{Emotion}: Binary Cross-Entropy with Logits (multi-label)
|
| 779 |
-
\item \textbf{Topic}: Standard Cross-Entropy (single-label)
|
| 780 |
-
\end{itemize}
|
| 781 |
-
|
| 782 |
-
\subsection{Gradient Handling}
|
| 783 |
-
|
| 784 |
-
The trainer includes gradient clipping and early stopping:
|
| 785 |
-
|
| 786 |
-
\begin{quote}
|
| 787 |
-
\emph{``Gradient clipping to prevent exploding gradients [...] Early stopping based on validation loss''} — \texttt{trainer.py}
|
| 788 |
-
\end{quote}
|
| 789 |
-
|
| 790 |
-
\subsection{Training Loop}
|
| 791 |
-
|
| 792 |
-
Figure \ref{fig:training_loop} illustrates the training loop structure.
|
| 793 |
-
|
| 794 |
-
\begin{figure}[htbp]
|
| 795 |
-
\centering
|
| 796 |
-
\begin{tikzpicture}[
|
| 797 |
-
scale=0.65,
|
| 798 |
-
transform shape,
|
| 799 |
-
box/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.5cm, align=center, rounded corners=2pt, font=\small},
|
| 800 |
-
arrow/.style={->, >=stealth},
|
| 801 |
-
]
|
| 802 |
-
|
| 803 |
-
% Epoch loop
|
| 804 |
-
\node[box, fill=blue!20] (epoch) at (0, 0) {For each epoch};
|
| 805 |
-
|
| 806 |
-
% Batch loop
|
| 807 |
-
\node[box, fill=green!20] (batch) at (0, -1.2) {For each batch};
|
| 808 |
-
|
| 809 |
-
% Task loop
|
| 810 |
-
\node[box, fill=yellow!20] (task) at (0, -2.4) {For each task};
|
| 811 |
-
|
| 812 |
-
% Forward
|
| 813 |
-
\node[box, fill=orange!20] (forward) at (0, -3.6) {Forward + Loss};
|
| 814 |
-
|
| 815 |
-
% AMP context
|
| 816 |
-
\node[box, fill=purple!20] (amp) at (0, -4.8) {AMP autocast};
|
| 817 |
-
|
| 818 |
-
% Backward
|
| 819 |
-
\node[box, fill=red!20] (backward) at (0, -6) {Backward (scaled)};
|
| 820 |
-
|
| 821 |
-
% Accumulate check
|
| 822 |
-
\node[box, fill=cyan!20] (accum) at (0, -7.2) {Accumulation step?};
|
| 823 |
-
|
| 824 |
-
% Optimizer step
|
| 825 |
-
\node[box, fill=gray!20] (optim) at (0, -8.4) {Clip + Step + Zero};
|
| 826 |
-
|
| 827 |
-
% Validation
|
| 828 |
-
\node[box, fill=blue!20] (val) at (3.5, -1.2) {Validation};
|
| 829 |
-
|
| 830 |
-
% Checkpoint
|
| 831 |
-
\node[box, fill=green!20] (ckpt) at (3.5, -2.4) {Checkpoint};
|
| 832 |
-
|
| 833 |
-
% Early stopping
|
| 834 |
-
\node[box, fill=red!20] (early) at (3.5, -3.6) {Early Stop?};
|
| 835 |
-
|
| 836 |
-
% Arrows
|
| 837 |
-
\draw[arrow] (epoch) -- (batch);
|
| 838 |
-
\draw[arrow] (batch) -- (task);
|
| 839 |
-
\draw[arrow] (task) -- (forward);
|
| 840 |
-
\draw[arrow] (forward) -- (amp);
|
| 841 |
-
\draw[arrow] (amp) -- (backward);
|
| 842 |
-
\draw[arrow] (backward) -- (accum);
|
| 843 |
-
\draw[arrow] (accum) -- (optim);
|
| 844 |
-
\draw[arrow] (optim.south) -- ++(0, -0.3) -| ++(-2, 0) |- (task.west);
|
| 845 |
-
\draw[arrow] (epoch.east) -- ++(0.5, 0) |- (val);
|
| 846 |
-
\draw[arrow] (val) -- (ckpt);
|
| 847 |
-
\draw[arrow] (ckpt) -- (early);
|
| 848 |
-
|
| 849 |
-
\end{tikzpicture}
|
| 850 |
-
\caption{Training loop structure showing nested iteration over epochs, batches, and tasks, with gradient accumulation and validation checkpoints.}
|
| 851 |
-
\label{fig:training_loop}
|
| 852 |
-
\end{figure}
|
| 853 |
-
|
| 854 |
-
%=============================================================================
|
| 855 |
-
\section{Tasks and Datasets}
|
| 856 |
-
%=============================================================================
|
| 857 |
-
|
| 858 |
-
LexiMind addresses three complementary NLP tasks:
|
| 859 |
-
|
| 860 |
-
\subsection{Text Summarization}
|
| 861 |
-
|
| 862 |
-
\textbf{Task}: Generate concise abstractive summaries from longer documents, focusing on back-cover style book descriptions rather than plot summaries.
|
| 863 |
-
|
| 864 |
-
\textbf{Datasets}: The summarization corpus comprises 49,086 training samples, 2,727 validation samples, and 2,727 test samples. Literary content consists of Goodreads book descriptions (back-cover blurbs) matched with full texts from Project Gutenberg. Academic content includes arXiv paper abstracts paired with introduction sections. Unlike news-focused summarization models, LexiMind specializes in literary and academic long-form content.
|
| 865 |
-
|
| 866 |
-
\textbf{Approach}: Encoder-decoder generation with greedy decoding (beam search available). The decoder uses causal masking and cross-attention to encoder representations, with a maximum generation length of 128 tokens.
|
| 867 |
-
|
| 868 |
-
\textbf{Evaluation}: ROUGE-1/2/L for n-gram overlap, BLEU-4 for fluency, and BERTScore (using RoBERTa-large) for semantic similarity between generated and reference summaries.
|
| 869 |
-
|
| 870 |
-
\subsection{Emotion Classification}
|
| 871 |
-
|
| 872 |
-
\textbf{Task}: Multi-label classification identifying emotions expressed in text, where each sample may have multiple emotion labels.
|
| 873 |
-
|
| 874 |
-
\textbf{Dataset}: Google's GoEmotions \cite{demszky2020goemotions}, comprising 43,410 training samples, 5,426 validation samples, and 5,427 test samples sourced from Reddit comments.
|
| 875 |
-
|
| 876 |
-
\textbf{Classes}: 28 emotion categories: admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, joy, love, nervousness, neutral, optimism, pride, realization, relief, remorse, sadness, and surprise.
|
| 877 |
-
|
| 878 |
-
\textbf{Approach}: Encoder-only processing with mean pooling over token representations, followed by a two-layer classification head with hidden dimension 384. Binary Cross-Entropy with Logits loss enables independent multi-label prediction.
|
| 879 |
-
|
| 880 |
-
\subsection{Topic Classification}
|
| 881 |
-
|
| 882 |
-
\textbf{Task}: Single-label classification assigning documents to one of seven topic categories.
|
| 883 |
-
|
| 884 |
-
\textbf{Datasets}: A curated collection of 3,402 training samples, 189 validation samples, and 189 test samples drawn from arXiv paper categories and Project Gutenberg book metadata.
|
| 885 |
-
|
| 886 |
-
\textbf{Classes}: 7 mutually exclusive topics: Arts, Business, Fiction, History, Philosophy, Science, and Technology.
|
| 887 |
-
|
| 888 |
-
\textbf{Approach}: Encoder-only architecture with mean pooling, identical to emotion classification but using standard Cross-Entropy loss for mutually exclusive classes. Due to the significantly smaller dataset (3.4K vs 43K for emotion), the topic loss weight is reduced to 0.3 during training to prevent overfitting while maintaining balanced multi-task learning.
|
| 889 |
-
|
| 890 |
-
%=============================================================================
|
| 891 |
-
\section{Model Specifications}
|
| 892 |
-
%=============================================================================
|
| 893 |
-
|
| 894 |
-
Table \ref{tab:dataset_summary} summarizes the dataset splits used for training and evaluation. Table \ref{tab:model_specs} details the model architecture.
|
| 895 |
-
|
| 896 |
-
\begin{table}[htbp]
|
| 897 |
-
\centering
|
| 898 |
-
\caption{Dataset Summary}
|
| 899 |
-
\label{tab:dataset_summary}
|
| 900 |
-
\begin{tabular}{lccc}
|
| 901 |
-
\toprule
|
| 902 |
-
\textbf{Task} & \textbf{Train} & \textbf{Val} & \textbf{Test} \\
|
| 903 |
-
\midrule
|
| 904 |
-
Summarization & 49,086 & 2,727 & 2,727 \\
|
| 905 |
-
Emotion & 43,410 & 5,426 & 5,427 \\
|
| 906 |
-
Topic & 3,402 & 189 & 189 \\
|
| 907 |
-
\midrule
|
| 908 |
-
\textbf{Total} & 95,898 & 8,342 & 8,343 \\
|
| 909 |
-
\bottomrule
|
| 910 |
-
\end{tabular}
|
| 911 |
-
\end{table}
|
| 912 |
-
|
| 913 |
-
\begin{table}[htbp]
|
| 914 |
-
\centering
|
| 915 |
-
\caption{LexiMind Model Specifications}
|
| 916 |
-
\label{tab:model_specs}
|
| 917 |
-
\begin{tabular}{lc}
|
| 918 |
-
\toprule
|
| 919 |
-
\textbf{Parameter} & \textbf{Value} \\
|
| 920 |
-
\midrule
|
| 921 |
-
Hidden dimension ($d_{model}$) & 768 \\
|
| 922 |
-
FFN dimension ($d_{ff}$) & 2048 \\
|
| 923 |
-
Attention heads & 12 \\
|
| 924 |
-
Head dimension & 64 \\
|
| 925 |
-
Encoder layers & 12 \\
|
| 926 |
-
Decoder layers & 12 \\
|
| 927 |
-
Vocabulary size & 32,128 \\
|
| 928 |
-
Max sequence length & 512 \\
|
| 929 |
-
Dropout & 0.1 \\
|
| 930 |
-
Activation & Gated-GELU \\
|
| 931 |
-
Normalization & RMSNorm (Pre-LN) \\
|
| 932 |
-
Position encoding & Relative bias \\
|
| 933 |
-
\midrule
|
| 934 |
-
Total parameters & $\sim$272M \\
|
| 935 |
-
\bottomrule
|
| 936 |
-
\end{tabular}
|
| 937 |
-
\end{table}
|
| 938 |
-
|
| 939 |
-
%=============================================================================
|
| 940 |
-
\section{Implementation Details}
|
| 941 |
-
%=============================================================================
|
| 942 |
-
|
| 943 |
-
\subsection{Project Structure}
|
| 944 |
-
|
| 945 |
-
LexiMind follows a modular architecture:
|
| 946 |
-
|
| 947 |
-
\begin{verbatim}
|
| 948 |
-
src/
|
| 949 |
-
+-- models/
|
| 950 |
-
| +-- attention.py # MHA, RelPosBias
|
| 951 |
-
| +-- encoder.py # Encoder blocks
|
| 952 |
-
| +-- decoder.py # Decoder blocks
|
| 953 |
-
| +-- heads.py # Task heads
|
| 954 |
-
| +-- multitask.py # MTL routing
|
| 955 |
-
| +-- factory.py # Construction
|
| 956 |
-
+-- data/
|
| 957 |
-
| +-- tokenization.py # Tokenizer
|
| 958 |
-
| +-- dataset.py # Dataset classes
|
| 959 |
-
| +-- dataloader.py # Collators
|
| 960 |
-
+-- training/
|
| 961 |
-
+-- trainer.py # Training loop
|
| 962 |
-
+-- metrics.py # Evaluation
|
| 963 |
-
scripts/
|
| 964 |
-
+-- train.py # Main training
|
| 965 |
-
+-- download_data.py # Dataset download
|
| 966 |
-
+-- inference.py # CLI inference
|
| 967 |
-
+-- demo_gradio.py # Web demo
|
| 968 |
-
\end{verbatim}
|
| 969 |
-
|
| 970 |
-
\subsection{FlashAttention and CUDA Optimizations}
|
| 971 |
-
|
| 972 |
-
The trainer enables comprehensive hardware-specific optimizations:
|
| 973 |
-
|
| 974 |
-
\begin{lstlisting}[caption={CUDA optimizations from train.py}]
|
| 975 |
-
if device.type == "cuda":
|
| 976 |
-
torch.backends.cudnn.benchmark = True
|
| 977 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 978 |
-
torch.backends.cudnn.allow_tf32 = True
|
| 979 |
-
torch.backends.cuda.enable_flash_sdp(True)
|
| 980 |
-
torch.backends.cuda.enable_mem_efficient_sdp(
|
| 981 |
-
True)
|
| 982 |
-
\end{lstlisting}
|
| 983 |
-
|
| 984 |
-
Note that T5-style relative position bias is incompatible with FlashAttention, as FlashAttention requires adding bias tensors to attention scores which breaks the fused kernel. The development configuration disables relative position bias to enable FlashAttention for faster iteration, while production configurations retain relative position bias for better quality.
|
| 985 |
-
|
| 986 |
-
\subsection{Numerical Stability}
|
| 987 |
-
|
| 988 |
-
To prevent overflow during mixed-precision training, hidden states are clamped after each sublayer:
|
| 989 |
-
|
| 990 |
-
\begin{quote}
|
| 991 |
-
\emph{``Clamp inf values for fp16/bf16 training stability (like HuggingFace T5)''} — \texttt{encoder.py}, lines 103--105
|
| 992 |
-
\end{quote}
|
| 993 |
-
|
| 994 |
-
%=============================================================================
|
| 995 |
-
\section{Experimental Setup}
|
| 996 |
-
%=============================================================================
|
| 997 |
-
|
| 998 |
-
\subsection{Training Configuration}
|
| 999 |
-
|
| 1000 |
-
The final training configuration was optimized for quality and efficiency on an NVIDIA RTX 4070 with 12GB VRAM:
|
| 1001 |
-
|
| 1002 |
-
\begin{itemize}
|
| 1003 |
-
\item \textbf{Optimizer}: Fused AdamW with weight decay 0.01, $\beta_1=0.9$, $\beta_2=0.98$
|
| 1004 |
-
\item \textbf{Learning Rate}: $3 \times 10^{-5}$ with cosine decay
|
| 1005 |
-
\item \textbf{Warmup}: 300 steps ($\sim$0.5 epochs)
|
| 1006 |
-
\item \textbf{Batch Size}: 10 with 4$\times$ gradient accumulation (effective batch size 40)
|
| 1007 |
-
\item \textbf{Precision}: BFloat16 on Ampere+ GPUs with TF32 enabled
|
| 1008 |
-
\item \textbf{Gradient Clipping}: Max norm 1.0
|
| 1009 |
-
\item \textbf{Gradient Checkpointing}: Enabled for memory efficiency
|
| 1010 |
-
\item \textbf{torch.compile}: Dynamic compilation for encoder and decoder
|
| 1011 |
-
\item \textbf{Task Weights}: Summarization 1.0, Emotion 1.0, Topic 0.3 (reduced due to small dataset)
|
| 1012 |
-
\item \textbf{Early Stopping}: Patience of 3 epochs on validation loss
|
| 1013 |
-
\item \textbf{Encoder Freezing}: Bottom 4 layers frozen for stable transfer learning
|
| 1014 |
-
\end{itemize}
|
| 1015 |
-
|
| 1016 |
-
Training completed in 7 epochs ($\sim$6 hours) with early stopping triggered due to validation loss plateau.
|
| 1017 |
-
|
| 1018 |
-
%=============================================================================
|
| 1019 |
-
\section{Experimental Results}
|
| 1020 |
-
\label{sec:results}
|
| 1021 |
-
%=============================================================================
|
| 1022 |
-
|
| 1023 |
-
We evaluate LexiMind on held-out validation sets for each task. Table \ref{tab:summarization_results} presents the summarization metrics, Table \ref{tab:classification_results} shows classification performance.
|
| 1024 |
-
|
| 1025 |
-
\subsection{Summarization Performance}
|
| 1026 |
-
|
| 1027 |
-
\begin{table}[htbp]
|
| 1028 |
-
\centering
|
| 1029 |
-
\caption{Summarization Evaluation Results}
|
| 1030 |
-
\label{tab:summarization_results}
|
| 1031 |
-
\begin{tabular}{lc}
|
| 1032 |
-
\toprule
|
| 1033 |
-
\textbf{Metric} & \textbf{Score} \\
|
| 1034 |
-
\midrule
|
| 1035 |
-
ROUGE-1 & 0.3064 \\
|
| 1036 |
-
ROUGE-2 & 0.0896 \\
|
| 1037 |
-
ROUGE-L & 0.1832 \\
|
| 1038 |
-
BLEU-4 & 0.0237 \\
|
| 1039 |
-
\midrule
|
| 1040 |
-
BERTScore Precision & 0.8430 \\
|
| 1041 |
-
BERTScore Recall & 0.8179 \\
|
| 1042 |
-
\textbf{BERTScore F1} & \textbf{0.8300} \\
|
| 1043 |
-
\bottomrule
|
| 1044 |
-
\end{tabular}
|
| 1045 |
-
\end{table}
|
| 1046 |
-
|
| 1047 |
-
The BERTScore F1 of \textbf{0.83} demonstrates strong semantic similarity between generated descriptions and references, indicating the model captures meaning effectively even when exact wording differs. ROUGE scores are typical for abstractive summarization where the model paraphrases rather than extracts verbatim text.
|
| 1048 |
-
|
| 1049 |
-
\subsection{Classification Performance}
|
| 1050 |
-
|
| 1051 |
-
\begin{table}[htbp]
|
| 1052 |
-
\centering
|
| 1053 |
-
\caption{Classification Evaluation Results}
|
| 1054 |
-
\label{tab:classification_results}
|
| 1055 |
-
\begin{tabular}{llc}
|
| 1056 |
-
\toprule
|
| 1057 |
-
\textbf{Task} & \textbf{Metric} & \textbf{Score} \\
|
| 1058 |
-
\midrule
|
| 1059 |
-
\multirow{2}{*}{Topic (7 classes)} & Accuracy & \textbf{85.19\%} \\
|
| 1060 |
-
& Macro F1 & 0.8474 \\
|
| 1061 |
-
\midrule
|
| 1062 |
-
Emotion (28 classes) & Multi-label F1 & 0.1987 \\
|
| 1063 |
-
\bottomrule
|
| 1064 |
-
\end{tabular}
|
| 1065 |
-
\end{table}
|
| 1066 |
-
|
| 1067 |
-
Topic classification achieves \textbf{85.2\%} accuracy with balanced per-class performance. The emotion detection task proves more challenging due to the 28-class multi-label setting with inherent label ambiguity in the GoEmotions dataset.
|
| 1068 |
-
|
| 1069 |
-
\subsection{Training Dynamics}
|
| 1070 |
-
|
| 1071 |
-
Figure \ref{fig:training_curves} illustrates the training dynamics over 7 epochs. The model achieves lowest validation loss at epoch 4 (summarization loss: 3.698), with the checkpoint from this epoch saved as the best model. Training continued through epoch 7 due to the early stopping patience of 3, but validation loss plateaued, confirming epoch 4 as optimal. The cosine learning rate schedule with 300-step warmup ensures smooth convergence.
|
| 1072 |
-
|
| 1073 |
-
\begin{figure}[htbp]
|
| 1074 |
-
\centering
|
| 1075 |
-
\includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
|
| 1076 |
-
\caption{Training and validation loss curves over 7 epochs. Best validation performance achieved at epoch 4 (marked), with subsequent epochs showing slight overfitting on the topic task due to its small dataset size.}
|
| 1077 |
-
\label{fig:training_curves}
|
| 1078 |
-
\end{figure}
|
| 1079 |
-
|
| 1080 |
-
Figure \ref{fig:task_metrics} presents per-task metrics throughout training, showing the distinct learning trajectories for summarization, emotion detection, and topic classification.
|
| 1081 |
-
|
| 1082 |
-
\begin{figure}[htbp]
|
| 1083 |
-
\centering
|
| 1084 |
-
\includegraphics[width=\columnwidth]{figures/task_metrics.png}
|
| 1085 |
-
\caption{Task-specific metrics during training: ROUGE-1 for summarization, F1 for emotion detection, and accuracy for topic classification.}
|
| 1086 |
-
\label{fig:task_metrics}
|
| 1087 |
-
\end{figure}
|
| 1088 |
-
|
| 1089 |
-
Figure \ref{fig:training_dynamics} provides a comprehensive view of training dynamics, including loss convergence, per-epoch improvements, cumulative loss reduction, and the train-validation gap which indicates overfitting behavior.
|
| 1090 |
-
|
| 1091 |
-
\begin{figure}[htbp]
|
| 1092 |
-
\centering
|
| 1093 |
-
\includegraphics[width=\columnwidth]{figures/training_dynamics.png}
|
| 1094 |
-
\caption{Training dynamics overview: (top-left) Loss convergence with smoothing, (top-right) Relative improvement per epoch, (bottom-left) Cumulative loss reduction from initial values, (bottom-right) Train-validation gap showing slight overfitting after epoch 4.}
|
| 1095 |
-
\label{fig:training_dynamics}
|
| 1096 |
-
\end{figure}
|
| 1097 |
-
|
| 1098 |
-
\subsection{Per-Class Topic Analysis}
|
| 1099 |
-
|
| 1100 |
-
Table \ref{tab:topic_breakdown} shows the per-class performance for topic classification:
|
| 1101 |
-
|
| 1102 |
-
\begin{table}[htbp]
|
| 1103 |
-
\centering
|
| 1104 |
-
\caption{Per-Class Topic Classification Performance}
|
| 1105 |
-
\label{tab:topic_breakdown}
|
| 1106 |
-
\begin{tabular}{lccc}
|
| 1107 |
-
\toprule
|
| 1108 |
-
\textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
|
| 1109 |
-
\midrule
|
| 1110 |
-
Arts & 0.93 & 0.76 & 0.84 \\
|
| 1111 |
-
Business & 0.97 & 0.97 & 0.97 \\
|
| 1112 |
-
Fiction & 0.95 & 1.00 & 0.97 \\
|
| 1113 |
-
History & 0.83 & 0.78 & 0.81 \\
|
| 1114 |
-
Philosophy & 0.80 & 0.86 & 0.83 \\
|
| 1115 |
-
Science & 0.58 & 0.73 & 0.65 \\
|
| 1116 |
-
Technology & 0.86 & 0.89 & 0.87 \\
|
| 1117 |
-
\bottomrule
|
| 1118 |
-
\end{tabular}
|
| 1119 |
-
\end{table}
|
| 1120 |
-
|
| 1121 |
-
The model performs best on Fiction and Business categories, while Science shows the most confusion, likely due to overlap with Technology topics.
|
| 1122 |
-
|
| 1123 |
-
%=============================================================================
|
| 1124 |
-
\section{Discussion}
|
| 1125 |
-
%=============================================================================
|
| 1126 |
-
|
| 1127 |
-
\subsection{Key Findings}
|
| 1128 |
-
|
| 1129 |
-
\textbf{BERTScore vs. ROUGE}: The high BERTScore F1 (0.83) combined with moderate ROUGE-1 (0.31) illustrates a key characteristic of abstractive summarization. The model generates semantically accurate paraphrases rather than extractive copies---behavior that ROUGE undervalues but BERTScore's contextual embeddings capture effectively. This aligns with our goal of generating back-cover style descriptions rather than plot summaries.
|
| 1130 |
-
|
| 1131 |
-
\textbf{Multi-Task Learning Dynamics}: Analysis of training curves reveals distinct learning trajectories across tasks. Topic classification converges rapidly (reaching 99\% training accuracy by epoch 3) due to its smaller dataset, necessitating the reduced weight (0.3) to prevent gradient dominance. Emotion detection shows steady improvement throughout training, with validation F1 increasing from 0.30 to 0.40. Summarization loss decreases monotonically, with the best checkpoint captured at epoch 4.
|
| 1132 |
-
|
| 1133 |
-
\textbf{Transfer Learning Benefits}: Initializing from FLAN-T5-base provides strong linguistic priors, enabling competitive performance with only 7 epochs of fine-tuning ($\sim$6 hours on consumer hardware). Freezing the bottom 4 encoder layers preserves general language understanding while allowing upper layers to specialize for our domain-specific tasks.
|
| 1134 |
-
|
| 1135 |
-
\textbf{Checkpoint Selection}: The best model checkpoint at epoch 4 achieves the lowest validation summarization loss (3.698) while maintaining strong classification performance. Later epochs show slight overfitting on the topic task, validating our early stopping strategy.
|
| 1136 |
-
|
| 1137 |
-
\subsection{Limitations}
|
| 1138 |
-
|
| 1139 |
-
\begin{itemize}
|
| 1140 |
-
\item \textbf{Emotion Detection}: The 28-class multi-label setting remains challenging, with F1 of 0.20 on validation data. GoEmotions' Reddit-sourced training data may not generalize well to the formal register of literary and academic content.
|
| 1141 |
-
\item \textbf{Topic Dataset Imbalance}: With only 3,402 training samples distributed across 7 classes, some categories (notably Science with 0.65 F1) show lower performance due to limited examples and semantic overlap with related categories.
|
| 1142 |
-
\item \textbf{Domain Gap}: While Goodreads descriptions provide quality literary summaries, the model's exposure to contemporary fiction is limited by Project Gutenberg's public domain focus on pre-1928 works.
|
| 1143 |
-
\end{itemize}
|
| 1144 |
-
|
| 1145 |
-
\subsection{Future Work}
|
| 1146 |
-
|
| 1147 |
-
Several directions could improve LexiMind's performance:
|
| 1148 |
-
\begin{itemize}
|
| 1149 |
-
\item \textbf{Domain-Specific Emotion Data}: Fine-tuning on literary emotion annotations rather than Reddit comments could better capture the emotional nuances of literary and academic text.
|
| 1150 |
-
\item \textbf{Parameter-Efficient Fine-Tuning}: Integrating LoRA \cite{hu2022lora} would reduce memory requirements and enable experimentation with larger base models (FLAN-T5-large, FLAN-T5-xl).
|
| 1151 |
-
\item \textbf{Expanded Topic Dataset}: Augmenting the 3.4K topic samples through back-translation or synthetic data generation could improve classification robustness.
|
| 1152 |
-
\end{itemize}
|
| 1153 |
-
|
| 1154 |
-
%=============================================================================
|
| 1155 |
-
\section{Conclusion}
|
| 1156 |
-
%=============================================================================
|
| 1157 |
-
|
| 1158 |
-
This paper presented LexiMind, a multi-task NLP system combining custom Transformer implementation with FLAN-T5 pre-trained weights. The hybrid approach provides architectural transparency while leveraging transfer learning, achieving:
|
| 1159 |
-
|
| 1160 |
-
\begin{itemize}
|
| 1161 |
-
\item \textbf{Summarization}: BERTScore F1 of 0.83, demonstrating strong semantic fidelity for back-cover style book descriptions
|
| 1162 |
-
\item \textbf{Topic Classification}: 85.2\% accuracy and 0.85 macro F1 across 7 categories
|
| 1163 |
-
\item \textbf{Emotion Detection}: Multi-label F1 of 0.20 on 28 emotion classes
|
| 1164 |
-
\end{itemize}
|
| 1165 |
-
|
| 1166 |
-
The complete system trains in approximately 6 hours on a consumer GPU (RTX 4070 12GB), demonstrating that sophisticated multi-task models remain accessible without datacenter-scale resources. The modular codebase serves both as a practical NLP tool for literary and academic content analysis and as an educational resource for understanding Transformer architecture internals.
|
| 1167 |
-
|
| 1168 |
-
All code, trained models, and datasets are publicly available, with a live demonstration hosted on HuggingFace Spaces.\footnote{\url{https://huggingface.co/spaces/OliverPerrin/LexiMind}}
|
| 1169 |
-
|
| 1170 |
-
%=============================================================================
|
| 1171 |
-
% References
|
| 1172 |
-
%=============================================================================
|
| 1173 |
-
|
| 1174 |
-
\begin{thebibliography}{00}
|
| 1175 |
-
|
| 1176 |
-
\bibitem{vaswani2017attention}
|
| 1177 |
-
A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin, ``Attention is all you need,'' in \textit{Advances in Neural Information Processing Systems (NeurIPS)}, vol. 30, 2017, pp. 5998--6008. [Online]. Available: \url{https://arxiv.org/abs/1706.03762}
|
| 1178 |
-
|
| 1179 |
-
\bibitem{raffel2020exploring}
|
| 1180 |
-
C. Raffel, N. Shazeer, A. Roberts, K. Lee, S. Narang, M. Matena, Y. Zhou, W. Li, and P. J. Liu, ``Exploring the limits of transfer learning with a unified text-to-text transformer,'' \textit{Journal of Machine Learning Research}, vol. 21, no. 140, pp. 1--67, 2020. [Online]. Available: \url{https://arxiv.org/abs/1910.10683}
|
| 1181 |
-
|
| 1182 |
-
\bibitem{chung2022scaling}
|
| 1183 |
-
H. W. Chung, L. Hou, S. Longpre, B. Zoph, Y. Tay, W. Fedus, Y. Li, X. Wang, M. Dehghani, S. Brahma, A. Webson, S. S. Gu, Z. Dai, M. Suzgun, X. Chen, A. Chowdhery, A. Castro-Ros, M. Pellat, K. Robinson, D. Valter, S. Narang, G. Mishra, A. Yu, V. Zhao, Y. Huang, A. Dai, H. Yu, S. Petrov, E. H. Chi, J. Dean, J. Devlin, A. Roberts, D. Zhou, Q. V. Le, and J. Wei, ``Scaling instruction-finetuned language models,'' \textit{arXiv preprint arXiv:2210.11416}, 2022. [Online]. Available: \url{https://arxiv.org/abs/2210.11416}
|
| 1184 |
-
|
| 1185 |
-
\bibitem{xiong2020layer}
|
| 1186 |
-
R. Xiong, Y. Yang, J. He, K. Zheng, S. Zheng, C. Xing, H. Zhang, Y. Lan, L. Wang, and T. Liu, ``On layer normalization in the transformer architecture,'' in \textit{International Conference on Machine Learning (ICML)}, 2020, pp. 10524--10533. [Online]. Available: \url{https://arxiv.org/abs/2002.04745}
|
| 1187 |
-
|
| 1188 |
-
\bibitem{zhang2019root}
|
| 1189 |
-
B. Zhang and R. Sennrich, ``Root mean square layer normalization,'' in \textit{Advances in Neural Information Processing Systems (NeurIPS)}, vol. 32, 2019, pp. 12360--12371. [Online]. Available: \url{https://arxiv.org/abs/1910.07467}
|
| 1190 |
-
|
| 1191 |
-
\bibitem{ba2016layer}
|
| 1192 |
-
J. L. Ba, J. R. Kiros, and G. E. Hinton, ``Layer normalization,'' \textit{arXiv preprint arXiv:1607.06450}, 2016. [Online]. Available: \url{https://arxiv.org/abs/1607.06450}
|
| 1193 |
-
|
| 1194 |
-
\bibitem{caruana1997multitask}
|
| 1195 |
-
R. Caruana, ``Multitask learning,'' \textit{Machine Learning}, vol. 28, no. 1, pp. 41--75, 1997.
|
| 1196 |
-
|
| 1197 |
-
\bibitem{kudo2018sentencepiece}
|
| 1198 |
-
T. Kudo and J. Richardson, ``SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing,'' in \textit{Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing: System Demonstrations}, 2018, pp. 66--71. [Online]. Available: \url{https://arxiv.org/abs/1808.06226}
|
| 1199 |
-
|
| 1200 |
-
\bibitem{hu2022lora}
|
| 1201 |
-
E. J. Hu, Y. Shen, P. Wallis, Z. Allen-Zhu, Y. Li, S. Wang, L. Wang, and W. Chen, ``LoRA: Low-rank adaptation of large language models,'' in \textit{International Conference on Learning Representations (ICLR)}, 2022. [Online]. Available: \url{https://arxiv.org/abs/2106.09685}
|
| 1202 |
-
|
| 1203 |
-
\bibitem{dao2022flashattention}
|
| 1204 |
-
T. Dao, D. Fu, S. Ermon, A. Rudra, and C. Ré, ``FlashAttention: Fast and memory-efficient exact attention with IO-awareness,'' in \textit{Advances in Neural Information Processing Systems (NeurIPS)}, vol. 35, 2022, pp. 16344--16359. [Online]. Available: \url{https://arxiv.org/abs/2205.14135}
|
| 1205 |
-
|
| 1206 |
-
\bibitem{zhang2019bertscore}
|
| 1207 |
-
T. Zhang, V. Kishore, F. Wu, K. Q. Weinberger, and Y. Artzi, ``BERTScore: Evaluating text generation with BERT,'' in \textit{International Conference on Learning Representations (ICLR)}, 2020. [Online]. Available: \url{https://arxiv.org/abs/1904.09675}
|
| 1208 |
-
|
| 1209 |
-
\bibitem{demszky2020goemotions}
|
| 1210 |
-
D. Demszky, D. Movshovitz-Attias, J. Ko, A. Cowen, G. Nemade, and S. Ravi, ``GoEmotions: A dataset of fine-grained emotions,'' in \textit{Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics}, 2020, pp. 4040--4054. [Online]. Available: \url{https://arxiv.org/abs/2005.00547}
|
| 1211 |
-
|
| 1212 |
-
\end{thebibliography}
|
| 1213 |
-
|
| 1214 |
-
\end{document}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/research_paper.tex
DELETED
|
@@ -1,574 +0,0 @@
|
|
| 1 |
-
% LexiMind: Multi-Task Learning for Literary and Academic Text Understanding
|
| 2 |
-
% Research Paper - Revised with Experimental Rigor
|
| 3 |
-
% Author: Oliver Perrin
|
| 4 |
-
|
| 5 |
-
\documentclass[conference]{IEEEtran}
|
| 6 |
-
\IEEEoverridecommandlockouts
|
| 7 |
-
|
| 8 |
-
% Essential packages
|
| 9 |
-
\usepackage{cite}
|
| 10 |
-
\usepackage{amsmath,amssymb,amsfonts}
|
| 11 |
-
\usepackage{graphicx}
|
| 12 |
-
\usepackage{textcomp}
|
| 13 |
-
\usepackage{xcolor}
|
| 14 |
-
\usepackage{hyperref}
|
| 15 |
-
\usepackage{booktabs}
|
| 16 |
-
\usepackage{multirow}
|
| 17 |
-
\usepackage{array}
|
| 18 |
-
\usepackage{caption}
|
| 19 |
-
|
| 20 |
-
% TikZ for diagrams
|
| 21 |
-
\usepackage{tikz}
|
| 22 |
-
\usetikzlibrary{shapes.geometric, arrows, positioning}
|
| 23 |
-
|
| 24 |
-
% Hyperref setup
|
| 25 |
-
\hypersetup{
|
| 26 |
-
colorlinks=true,
|
| 27 |
-
linkcolor=blue,
|
| 28 |
-
citecolor=blue,
|
| 29 |
-
urlcolor=blue
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
\def\BibTeX{{\rm B\kern-.05em{\sc i\kern-.025em b}\kern-.08em
|
| 33 |
-
T\kern-.1667em\lower.7ex\hbox{E}\kern-.125emX}}
|
| 34 |
-
|
| 35 |
-
\begin{document}
|
| 36 |
-
|
| 37 |
-
\title{Multi-Task Learning for Literary and Academic Text:\\Does Joint Training Help or Hurt?}
|
| 38 |
-
|
| 39 |
-
\author{\IEEEauthorblockN{Oliver Perrin}\\
|
| 40 |
-
\IEEEauthorblockA{Department of Computer Science\\
|
| 41 |
-
Appalachian State University\\
|
| 42 |
-
Email: perrinot@appstate.edu}}
|
| 43 |
-
|
| 44 |
-
\maketitle
|
| 45 |
-
|
| 46 |
-
\begin{abstract}
|
| 47 |
-
Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies, we find that naive MTL with mean pooling and round-robin scheduling yields mixed results: topic classification gains +3.2\% accuracy, summarization remains stable, but emotion detection suffers negative transfer ($-$0.02 F1). We then show that two targeted interventions---\textbf{learned attention pooling} for the emotion head and \textbf{temperature-based task sampling} ($\alpha=0.5$)---eliminate negative transfer entirely, improving multi-task emotion sample-averaged F1 from 0.199 to 0.352 (+77\%), substantially exceeding the single-task baseline (0.218). With per-class threshold tuning, emotion macro F1 reaches 0.294. Topic classification improves to 85.7\% accuracy (95\% CI: [80.4\%, 91.0\%]), and summarization quality remains robust (ROUGE-1: 0.310, ROUGE-L: 0.185). Per-domain analysis reveals a significant quality gap between academic summaries (ROUGE-1: 0.319) and literary summaries (ROUGE-1: 0.206), attributable to the 11:1 training imbalance. We additionally contribute inter-task gradient conflict diagnostics, cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. Our analysis demonstrates that architectural isolation of task-specific components (attention pooling) combined with balanced optimization (temperature sampling) can convert negative transfer to positive transfer in MTL systems.
|
| 48 |
-
\end{abstract}
|
| 49 |
-
|
| 50 |
-
\begin{IEEEkeywords}
|
| 51 |
-
Multi-Task Learning, Transfer Learning, Text Summarization, Emotion Classification, FLAN-T5
|
| 52 |
-
\end{IEEEkeywords}
|
| 53 |
-
|
| 54 |
-
%=============================================================================
|
| 55 |
-
\section{Introduction}
|
| 56 |
-
%=============================================================================
|
| 57 |
-
|
| 58 |
-
Multi-task learning (MTL) \cite{caruana1997multitask} trains a single model on multiple related tasks, hypothesizing that shared representations improve generalization. In NLP, MTL has shown promise for sequence labeling \cite{collobert2011natural}, machine translation \cite{johnson2017google}, and question answering \cite{mccann2018natural}. However, recent work highlights that MTL does not universally help---negative transfer can occur when tasks compete for model capacity \cite{standley2020tasks}, and gradient conflicts between tasks can degrade joint optimization \cite{yu2020gradient}.
|
| 59 |
-
|
| 60 |
-
We investigate MTL effectiveness in a specific, underexplored domain: \textbf{literary and academic text understanding}. Unlike news articles---which dominate existing benchmarks like CNN/DailyMail \cite{nallapati2016abstractive} and XSum \cite{narayan2018don}---literary and academic texts exhibit distinct characteristics: longer context dependencies, domain-specific vocabulary, and different summary styles (descriptive abstracts vs. extractive headlines). Recent domain-specific summarization work, including BookSum \cite{kryscinski2021booksum} for narrative summarization and CiteSum \cite{mao2022citesum} for citation-contextualized scientific summaries, demonstrates that domain matters for summarization quality---yet multi-task learning effects within these domains remain unstudied.
|
| 61 |
-
|
| 62 |
-
Our study addresses three research questions:
|
| 63 |
-
|
| 64 |
-
\begin{enumerate}
|
| 65 |
-
\item[\textbf{RQ1}] Does multi-task learning improve performance over single-task specialists on literary/academic domains?
|
| 66 |
-
\item[\textbf{RQ2}] Which tasks benefit from joint training, and which suffer negative transfer?
|
| 67 |
-
\item[\textbf{RQ3}] How much does pre-trained knowledge (FLAN-T5) contribute relative to task-specific fine-tuning?
|
| 68 |
-
\end{enumerate}
|
| 69 |
-
|
| 70 |
-
To answer these questions, we construct \textbf{LexiMind}, a multi-task system built on FLAN-T5-base \cite{chung2022scaling} that performs abstractive summarization, topic classification, and emotion detection. We conduct ablations comparing multi-task vs. single-task training, with vs. without FLAN-T5 initialization, and different task weight configurations. Our primary experimental contribution is the empirical characterization of transfer effects across these heterogeneous tasks:
|
| 71 |
-
|
| 72 |
-
\begin{itemize}
|
| 73 |
-
\item \textbf{Topic classification benefits from MTL} (+3.7\% accuracy over single-task), leveraging shared encoder representations from the larger summarization dataset.
|
| 74 |
-
\item \textbf{Summarization is robust to MTL}, showing stable ROUGE scores despite sharing encoder capacity with classification heads.
|
| 75 |
-
\item \textbf{Emotion detection: from negative to positive transfer}. Naive MTL with mean pooling degrades emotion F1 by $-$0.02; learned attention pooling combined with temperature-based task sampling reverses this, yielding +0.134 F1 over the single-task baseline.
|
| 76 |
-
\item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
|
| 77 |
-
\end{itemize}
|
| 78 |
-
|
| 79 |
-
We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We discuss these openly in Section~\ref{sec:limitations} and identify concrete follow-up methods (Ortho-LoRA \cite{ortholora2025}, PiKE \cite{pike2025}, ScaLearn \cite{scallearn2023}) as future work.
|
| 80 |
-
|
| 81 |
-
%=============================================================================
|
| 82 |
-
\section{Related Work}
|
| 83 |
-
%=============================================================================
|
| 84 |
-
|
| 85 |
-
\subsection{Multi-Task Learning in NLP}
|
| 86 |
-
|
| 87 |
-
Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
|
| 88 |
-
|
| 89 |
-
\textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach---but includes gradient conflict diagnostics (inter-task cosine similarity monitoring) to characterize optimization interference. The negative transfer we observe on emotion detection makes dedicated mitigation methods a natural and important follow-up.
|
| 90 |
-
|
| 91 |
-
\textbf{Recent advances in multi-task optimization.} Several recent methods address task interference more precisely. Ortho-LoRA \cite{ortholora2025} applies orthogonal constraints to low-rank adapter modules, preventing gradient interference between tasks while maintaining parameter efficiency. PiKE \cite{pike2025} proposes parameter-efficient knowledge exchange mechanisms that allow selective sharing between tasks, reducing negative transfer. ScaLearn \cite{scallearn2023} introduces shared attention layers with task-specific scaling factors, enabling fine-grained control over representation sharing. Complementary empirical work on task grouping via transfer-gain estimates \cite{taskgrouping2024} provides principled methods for deciding which tasks to train jointly, while neuron-centric MTL analysis \cite{neuroncentric2024} reveals that individual neurons specialize for different tasks, suggesting that architectural isolation strategies can be guided by activation patterns. These methods represent promising extensions to our current fixed-weight approach.
|
| 92 |
-
|
| 93 |
-
\textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
|
| 94 |
-
|
| 95 |
-
\subsection{Literary and Academic Summarization}
|
| 96 |
-
|
| 97 |
-
Most summarization benchmarks focus on news \cite{nallapati2016abstractive, narayan2018don}. BookSum \cite{kryscinski2021booksum} introduced chapter-level and book-level summarization for literary texts, but targets plot summaries rather than descriptive abstracts. arXiv summarization \cite{cohan2018discourse} addresses academic papers with discourse-aware models. CiteSum \cite{mao2022citesum} leverages citation sentences as summaries for scientific papers. Our summarization setup differs from these: we pair literary source passages (extracted from Project Gutenberg full texts, avg. 3,030 characters) with Goodreads book descriptions (avg. 572 characters) as targets, training the model to generate \textit{what a book is about} rather than plot recaps. For academic text, arXiv paper body text (avg. 3,967 characters) is paired with abstracts (avg. 1,433 characters). The resulting compression ratios (0.19 for literary, 0.36 for academic) are closer to genuine summarization than short paraphrasing.
|
| 98 |
-
|
| 99 |
-
\subsection{Emotion Detection}
|
| 100 |
-
|
| 101 |
-
GoEmotions \cite{demszky2020goemotions} provides 28 fine-grained emotion labels from Reddit comments. The original work reports 0.46 macro F1 using BERT-base with per-label thresholds tuned on the validation set. Subsequent work achieves 0.35--0.46 macro F1 depending on the model and threshold strategy. Importantly, all published GoEmotions baselines use encoder-only architectures (BERT, RoBERTa) rather than encoder-decoder models like T5. Our setup differs in both architecture (encoder-decoder with attention-pooled encoder states for emotion detection) and domain (training encoder primarily on literary/academic summarization), making direct comparison to published baselines informative but not fully controlled.
|
| 102 |
-
|
| 103 |
-
%=============================================================================
|
| 104 |
-
\section{Experimental Setup}
|
| 105 |
-
%=============================================================================
|
| 106 |
-
|
| 107 |
-
\subsection{Task Formulations}
|
| 108 |
-
\label{sec:task_formulation}
|
| 109 |
-
|
| 110 |
-
We define three tasks with explicit input-output specifications:
|
| 111 |
-
|
| 112 |
-
\textbf{Summarization (generative).} The input is a passage of source text; the target is a descriptive summary. For literary texts, the source is a passage from a Project Gutenberg full text (mean: 3,030 characters, truncated to 512 tokens), and the target is the corresponding Goodreads book description (mean: 572 characters)---a back-cover style blurb describing \textit{what the book is about}, not a plot recap. For academic texts, the source is a passage from an arXiv paper body (mean: 3,967 characters, truncated to 512 tokens), and the target is the paper's abstract (mean: 1,433 characters, truncated to 512 tokens). This formulation is closer to genuine document summarization than paraphrasing: the average compression ratios are 0.19 (literary) and 0.36 (academic), comparable to standard summarization benchmarks.
|
| 113 |
-
|
| 114 |
-
\textbf{Topic classification (discriminative, single-label).} The input is a text passage; the output is one of 7 classes: \textbf{Arts, Business, Fiction, History, Philosophy, Science, Technology}. Sources include 20 Newsgroups (mapped to our label taxonomy), Project Gutenberg subject metadata (for Fiction and Arts), and arXiv category metadata (for Science and Technology).
|
| 115 |
-
|
| 116 |
-
\textbf{Emotion detection (discriminative, multi-label).} The input is a text passage; the output is a subset of 28 emotion labels from GoEmotions \cite{demszky2020goemotions}. Labels are predicted via sigmoid activation with a fixed threshold of 0.3 during training evaluation and 0.5 during inference. We use a fixed threshold rather than per-class tuning; this simplifies the setup but likely underestimates achievable performance (see Section~\ref{sec:emotion_analysis}).
|
| 117 |
-
|
| 118 |
-
\subsection{Datasets}
|
| 119 |
-
|
| 120 |
-
Table \ref{tab:datasets} summarizes dataset statistics.
|
| 121 |
-
|
| 122 |
-
\begin{table}[htbp]
|
| 123 |
-
\centering
|
| 124 |
-
\caption{Dataset Statistics. Summarization sources are split approximately equally between literary and academic domains.}
|
| 125 |
-
\label{tab:datasets}
|
| 126 |
-
\begin{tabular}{llrrr}
|
| 127 |
-
\toprule
|
| 128 |
-
\textbf{Task} & \textbf{Source} & \textbf{Train} & \textbf{Val} & \textbf{Test} \\
|
| 129 |
-
\midrule
|
| 130 |
-
\multirow{2}{*}{Summarization} & Goodreads + Gutenberg & $\sim$4K & -- & -- \\
|
| 131 |
-
& arXiv (body $\rightarrow$ abstract) & $\sim$45K & -- & -- \\
|
| 132 |
-
& \textit{Combined} & 49,086 & 2,727 & 2,727 \\
|
| 133 |
-
\midrule
|
| 134 |
-
Topic (7 classes) & 20News + Gutenberg + arXiv & 3,402 & 189 & 189 \\
|
| 135 |
-
\midrule
|
| 136 |
-
Emotion (28 labels) & GoEmotions (Reddit) & 43,410 & 5,426 & 5,427 \\
|
| 137 |
-
\bottomrule
|
| 138 |
-
\end{tabular}
|
| 139 |
-
\end{table}
|
| 140 |
-
|
| 141 |
-
\textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed \textit{within} the literary and academic subsets, as they are drawn from disjoint sources. We note that the academic subset is substantially larger ($\sim$45K vs. $\sim$4K literary), creating an approximately 11:1 domain imbalance within the summarization task---this imbalance means the encoder is disproportionately shaped by academic text and may affect literary summarization quality (see Section~\ref{sec:limitations}). Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects, 20 Newsgroups categories) and mapped to our 7-class taxonomy; no manual annotation was performed, which introduces potential noise from metadata inaccuracies (e.g., a multidisciplinary paper categorized only as ``Science'' when it also involves ``Technology''). GoEmotions is used as-is from the HuggingFace datasets hub.
|
| 142 |
-
|
| 143 |
-
\textbf{Cross-task deduplication.} Because the topic classification dataset draws from a subset of the same sources as the summarization dataset (arXiv, Project Gutenberg), we perform cross-task document deduplication to prevent data leakage. Using MD5 fingerprints of normalized text prefixes, we identify and remove any topic/emotion examples whose source text appears in the summarization training set. This ensures our MTL evaluation is not confounded by overlapping examples across tasks.
|
| 144 |
-
|
| 145 |
-
\textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
|
| 146 |
-
|
| 147 |
-
\subsection{Model Architecture}
|
| 148 |
-
|
| 149 |
-
LexiMind uses FLAN-T5-base (272M parameters) as the backbone, with a custom reimplementation that loads pre-trained weights via a factory module for architectural transparency:
|
| 150 |
-
|
| 151 |
-
\begin{itemize}
|
| 152 |
-
\item 12-layer encoder, 12-layer decoder
|
| 153 |
-
\item 768-dimensional hidden states, 12 attention heads
|
| 154 |
-
\item T5-style relative position bias (no absolute positional embeddings)
|
| 155 |
-
\item Pre-Layer Normalization with RMSNorm \cite{zhang2019root}
|
| 156 |
-
\item FlashAttention via PyTorch 2.0 SDPA when compatible
|
| 157 |
-
\end{itemize}
|
| 158 |
-
|
| 159 |
-
Task-specific heads branch from the shared encoder:
|
| 160 |
-
\begin{itemize}
|
| 161 |
-
\item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing $\epsilon=0.1$, greedy decoding with max length 512 tokens)
|
| 162 |
-
\item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
|
| 163 |
-
\item \textbf{Emotion}: Linear classifier on \textit{attention-pooled} encoder hidden states with sigmoid activation (binary cross-entropy loss). Instead of naive mean pooling, a learned attention query computes a weighted average over encoder positions: $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$ where $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ and $\mathbf{q} \in \mathbb{R}^d$ is a trainable query vector. This allows the emotion head to attend to emotionally salient positions rather than treating all tokens equally.
|
| 164 |
-
\end{itemize}
|
| 165 |
-
|
| 166 |
-
\textbf{Architectural note.} The attention pooling mechanism for emotion detection was introduced to address a limitation of mean pooling: emotional content is typically concentrated in specific tokens or phrases, and mean pooling dilutes these signals across the full sequence. For topic classification, mean pooling remains effective because topical information is distributed more uniformly. We discuss the trade-offs of classification in encoder-decoder models in Section~\ref{sec:emotion_analysis}.
|
| 167 |
-
|
| 168 |
-
\subsection{Training Configuration}
|
| 169 |
-
|
| 170 |
-
All experiments use consistent hyperparameters unless otherwise noted:
|
| 171 |
-
|
| 172 |
-
\begin{itemize}
|
| 173 |
-
\item \textbf{Optimizer}: Fused AdamW, lr=$3\times10^{-5}$, weight decay=0.01, $\beta_1$=0.9, $\beta_2$=0.98
|
| 174 |
-
\item \textbf{Batch size}: 10 per step $\times$ 4 gradient accumulation = 40 effective
|
| 175 |
-
\item \textbf{Schedule}: 300-step linear warmup, cosine decay to 0.1$\times$ peak lr
|
| 176 |
-
\item \textbf{Max epochs}: 8 with early stopping (patience=3 on validation loss)
|
| 177 |
-
\item \textbf{Precision}: BFloat16 on NVIDIA RTX 4070 (12GB VRAM)
|
| 178 |
-
\item \textbf{Gradient clipping}: Max norm 1.0
|
| 179 |
-
\item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
|
| 180 |
-
\end{itemize}
|
| 181 |
-
|
| 182 |
-
\textbf{Task scheduling.} We use \textbf{temperature-based sampling}: task $i$ is sampled with probability $p_i \propto n_i^\alpha$, where $n_i$ is the dataset size and $\alpha = 0.5$ (square-root scaling). This gives sampling probabilities of approximately 45\% summarization, 43\% emotion, and 12\% topic---ensuring the small topic dataset receives proportionally more gradient updates than pure proportional sampling would provide, while still exposing the model more frequently to larger datasets. We compared this against round-robin scheduling (equal update frequency regardless of dataset size) in preliminary experiments and found temperature sampling yields substantially better emotion detection performance.
|
| 183 |
-
|
| 184 |
-
\textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
|
| 185 |
-
|
| 186 |
-
\textbf{Gradient conflict monitoring.} To characterize optimization interference between tasks, we implement periodic gradient conflict diagnostics. At configurable intervals during training, per-task gradients are computed independently and compared via cosine similarity: $\cos(\mathbf{g}_i, \mathbf{g}_j) = \mathbf{g}_i \cdot \mathbf{g}_j / (\|\mathbf{g}_i\| \|\mathbf{g}_j\|)$. Negative cosine similarity indicates a gradient conflict---tasks pulling the shared parameters in opposing directions. Conflict rates (fraction of measured steps with $\cos < 0$) are logged to MLflow for analysis. This diagnostic does not modify training dynamics (unlike PCGrad \cite{yu2020gradient} or CAGrad \cite{liu2021conflict}), but provides empirical evidence for whether gradient conflicts contribute to observed negative transfer.
|
| 187 |
-
|
| 188 |
-
\textbf{Early stopping.} Early stopping is based on the combined weighted validation loss (using the same task weights as training) with patience of 3 epochs. The best checkpoint is selected by minimum combined validation loss.
|
| 189 |
-
|
| 190 |
-
\subsection{Baselines and Ablations}
|
| 191 |
-
|
| 192 |
-
We compare four configurations:
|
| 193 |
-
|
| 194 |
-
\begin{enumerate}
|
| 195 |
-
\item \textbf{Random/Majority}: Random predictions for classification; summarization is not evaluated against random baselines (ROUGE of random text is near zero).
|
| 196 |
-
\item \textbf{FLAN-T5-base (zero-shot)}: Pre-trained model with task-appropriate prompts, no fine-tuning.
|
| 197 |
-
\item \textbf{Single-Task}: Separate models fine-tuned on each task individually with identical hyperparameters.
|
| 198 |
-
\item \textbf{Multi-Task Baseline}: Joint training with mean pooling and round-robin scheduling.
|
| 199 |
-
\item \textbf{Multi-Task Improved}: Joint training with attention pooling for emotion and temperature sampling ($\alpha=0.5$).
|
| 200 |
-
\end{enumerate}
|
| 201 |
-
|
| 202 |
-
We additionally ablate FLAN-T5 initialization vs. random initialization to isolate transfer learning contribution.
|
| 203 |
-
|
| 204 |
-
\subsection{Evaluation Metrics}
|
| 205 |
-
|
| 206 |
-
\begin{itemize}
|
| 207 |
-
\item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BLEU-4 (n-gram precision with brevity penalty). ROUGE-1 serves as the primary metric for summarization quality. BERTScore \cite{zhang2019bertscore} is available as an optional semantic similarity metric but is not used in our primary evaluation due to its high computational cost and the difficulty of interpreting its absolute values. Per-domain breakdown (literary vs. academic) is provided to analyze domain-specific quality.
|
| 208 |
-
\item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
|
| 209 |
-
\item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
|
| 210 |
-
\end{itemize}
|
| 211 |
-
|
| 212 |
-
\textbf{Statistical rigor.} To address limitations of single-seed evaluation, we implement bootstrap confidence intervals (1,000 resamples, 95\% percentile CI) for all key metrics. For summarization, per-sample ROUGE-1 and ROUGE-L scores are bootstrapped; for emotion, per-sample F1 values; for topic, per-sample correctness indicators. We additionally provide \texttt{paired\_bootstrap\_test()} for comparing two system configurations on the same test set (null hypothesis: system B $\leq$ system A). Multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) automates training across $k$ seeds and reports mean $\pm$ standard deviation across runs, enabling variance-aware claims. Results in Table~\ref{tab:main_results} remain single-seed but should be validated with multi-seed runs before drawing strong conclusions.
|
| 213 |
-
|
| 214 |
-
%=============================================================================
|
| 215 |
-
\section{Results}
|
| 216 |
-
%=============================================================================
|
| 217 |
-
|
| 218 |
-
\subsection{Main Results}
|
| 219 |
-
|
| 220 |
-
Table \ref{tab:main_results} compares single-task specialists, baseline MTL (mean pooling, round-robin scheduling), and improved MTL (attention pooling, temperature sampling).
|
| 221 |
-
|
| 222 |
-
\begin{table}[htbp]
|
| 223 |
-
\centering
|
| 224 |
-
\caption{Main Results. Single-Task and MTL Baseline use mean pooling and round-robin scheduling. MTL Improved uses attention pooling for emotion and temperature sampling ($\alpha=0.5$). All results are single-seed. Bold indicates best.}
|
| 225 |
-
\label{tab:main_results}
|
| 226 |
-
\begin{tabular}{llccc}
|
| 227 |
-
\toprule
|
| 228 |
-
\textbf{Task} & \textbf{Metric} & \textbf{Single} & \textbf{MTL Base} & \textbf{MTL Impr.} \\
|
| 229 |
-
\midrule
|
| 230 |
-
\multirow{3}{*}{Summ.} & ROUGE-1 & 0.298 & 0.306 & \textbf{0.310} \\
|
| 231 |
-
& ROUGE-2 & 0.085 & 0.090 & \textbf{0.091} \\
|
| 232 |
-
& ROUGE-L & 0.179 & 0.183 & \textbf{0.185} \\
|
| 233 |
-
\midrule
|
| 234 |
-
\multirow{2}{*}{Topic} & Accuracy & 82.0\% & 85.2\% & \textbf{85.7\%} \\
|
| 235 |
-
& Macro F1 & 0.812 & 0.847 & \textbf{0.854} \\
|
| 236 |
-
\midrule
|
| 237 |
-
\multirow{3}{*}{Emotion} & Sample F1 & 0.218 & 0.199 & \textbf{0.352} \\
|
| 238 |
-
& Macro F1 & --- & --- & 0.143 \\
|
| 239 |
-
& Micro F1 & --- & --- & \textbf{0.443} \\
|
| 240 |
-
\bottomrule
|
| 241 |
-
\end{tabular}
|
| 242 |
-
\end{table}
|
| 243 |
-
|
| 244 |
-
\textbf{Key finding}: Attention pooling and temperature sampling yield improvements across \textit{all} tasks, with the largest impact on emotion detection:
|
| 245 |
-
|
| 246 |
-
\begin{itemize}
|
| 247 |
-
\item \textbf{Emotion detection: negative transfer eliminated.} Baseline MTL with mean pooling degraded emotion F1 by $-$0.019 vs. single-task. With attention pooling and temperature sampling, multi-task emotion F1 improves to 0.352---a +0.134 gain over single-task (0.218) and +0.153 over baseline MTL (0.199). The attention pooling mechanism allows the emotion head to focus on emotionally salient tokens rather than averaging over the full sequence, which is critical for the sparse multi-label task. Temperature sampling ensures the emotion task receives proportional gradient exposure ($\sim$43\% of steps).
|
| 248 |
-
|
| 249 |
-
\item \textbf{Topic classification: +3.7\% accuracy} over single-task (85.7\% vs. 82.0\%, 95\% CI: [80.4\%, 91.0\%]). The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). The bootstrap CI is wide due to the small validation set (189 samples), but the lower bound (80.4\%) still exceeds the single-task point estimate.
|
| 250 |
-
|
| 251 |
-
\item \textbf{Summarization remains stable} across all configurations. ROUGE-1 improves slightly from 0.298 (single-task) to 0.310 (improved MTL). The decoder---which contains half the model's parameters---insulates summarization from classification interference. ROUGE-1 95\% CI: [0.306, 0.313].
|
| 252 |
-
\end{itemize}
|
| 253 |
-
|
| 254 |
-
\subsection{Baseline Comparisons}
|
| 255 |
-
\label{sec:baseline_discussion}
|
| 256 |
-
|
| 257 |
-
Table \ref{tab:baselines} contextualizes our results against trivial and zero-shot baselines.
|
| 258 |
-
|
| 259 |
-
\begin{table}[htbp]
|
| 260 |
-
\centering
|
| 261 |
-
\caption{Comparison with Baselines (Improved MTL Configuration)}
|
| 262 |
-
\label{tab:baselines}
|
| 263 |
-
\begin{tabular}{lccc}
|
| 264 |
-
\toprule
|
| 265 |
-
\textbf{Model} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
|
| 266 |
-
\midrule
|
| 267 |
-
Random/Majority & --- & 14.3\% & 0.036 \\
|
| 268 |
-
FLAN-T5 zero-shot & 0.121 & 58.2\% & 0.089 \\
|
| 269 |
-
Single-Task & 0.179 & 82.0\% & 0.218 \\
|
| 270 |
-
\textbf{Multi-Task (Impr.)} & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
|
| 271 |
-
\bottomrule
|
| 272 |
-
\end{tabular}
|
| 273 |
-
\end{table}
|
| 274 |
-
|
| 275 |
-
Fine-tuning provides substantial gains over zero-shot across all tasks (+0.064 ROUGE-L, +27\% topic accuracy, +0.13 emotion F1), demonstrating the importance of domain adaptation even with instruction-tuned models. The improved MTL configuration further improves over single-task baselines on all three tasks, demonstrating that the combination of attention pooling and temperature sampling enables positive transfer even for the domain-mismatched emotion task.
|
| 276 |
-
|
| 277 |
-
\subsection{Ablation: Transfer Learning Contribution}
|
| 278 |
-
|
| 279 |
-
Table \ref{tab:transfer_ablation} isolates the contribution of FLAN-T5 pre-training by comparing against random initialization with identical architecture and training.
|
| 280 |
-
|
| 281 |
-
\begin{table}[htbp]
|
| 282 |
-
\centering
|
| 283 |
-
\caption{Effect of Pre-trained Initialization (Improved MTL Setting)}
|
| 284 |
-
\label{tab:transfer_ablation}
|
| 285 |
-
\begin{tabular}{lccc}
|
| 286 |
-
\toprule
|
| 287 |
-
\textbf{Initialization} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
|
| 288 |
-
\midrule
|
| 289 |
-
Random & 0.098 & 45.2\% & 0.082 \\
|
| 290 |
-
FLAN-T5-base & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
|
| 291 |
-
\midrule
|
| 292 |
-
\textit{Absolute gain} & +0.087 & +40.5\% & +0.270 \\
|
| 293 |
-
\bottomrule
|
| 294 |
-
\end{tabular}
|
| 295 |
-
\end{table}
|
| 296 |
-
|
| 297 |
-
FLAN-T5 initialization provides large absolute gains across all tasks. \textbf{Pre-training is necessary for competitive performance}---random initialization produces substantially worse results on all tasks even with identical data and training budget. Fine-tuning provides the remaining domain adaptation that zero-shot pre-training alone cannot achieve.
|
| 298 |
-
|
| 299 |
-
\subsection{Per-Class Topic Analysis}
|
| 300 |
-
|
| 301 |
-
Table \ref{tab:topic_breakdown} reveals per-class patterns in topic classification across the 7 classes.
|
| 302 |
-
|
| 303 |
-
\begin{table}[htbp]
|
| 304 |
-
\centering
|
| 305 |
-
\caption{Per-Class Topic Classification (Improved MTL)}
|
| 306 |
-
\label{tab:topic_breakdown}
|
| 307 |
-
\begin{tabular}{lccc}
|
| 308 |
-
\toprule
|
| 309 |
-
\textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
|
| 310 |
-
\midrule
|
| 311 |
-
Arts & 0.93 & 0.79 & 0.86 \\
|
| 312 |
-
Business & 0.97 & 1.00 & 0.98 \\
|
| 313 |
-
Fiction & 0.95 & 1.00 & 0.97 \\
|
| 314 |
-
History & 0.85 & 0.76 & 0.80 \\
|
| 315 |
-
Philosophy & 0.79 & 0.82 & 0.81 \\
|
| 316 |
-
Science & 0.57 & 0.80 & 0.67 \\
|
| 317 |
-
Technology & 0.89 & 0.89 & 0.89 \\
|
| 318 |
-
\midrule
|
| 319 |
-
\textit{Macro Avg} & 0.85 & 0.87 & 0.85 \\
|
| 320 |
-
\bottomrule
|
| 321 |
-
\end{tabular}
|
| 322 |
-
\end{table}
|
| 323 |
-
|
| 324 |
-
Fiction and Business achieve near-perfect classification (F1 $\geq$ 0.97), while Science shows the most confusion (F1 = 0.67). Error analysis reveals Science samples are frequently misclassified as Technology---semantically plausible given that scientific research papers often describe technical methods. The Arts class shows lower recall (0.79), suggesting some arts-related texts are misclassified into adjacent categories.
|
| 325 |
-
|
| 326 |
-
\subsection{Per-Domain Summarization Analysis}
|
| 327 |
-
|
| 328 |
-
Table \ref{tab:domain_breakdown} reveals a substantial quality gap between academic and literary summarization, reflecting the 11:1 training imbalance.
|
| 329 |
-
|
| 330 |
-
\begin{table}[htbp]
|
| 331 |
-
\centering
|
| 332 |
-
\caption{Per-Domain Summarization Performance (Improved MTL)}
|
| 333 |
-
\label{tab:domain_breakdown}
|
| 334 |
-
\begin{tabular}{lcccc}
|
| 335 |
-
\toprule
|
| 336 |
-
\textbf{Domain} & \textbf{N} & \textbf{ROUGE-1} & \textbf{ROUGE-L} & \textbf{BLEU-4} \\
|
| 337 |
-
\midrule
|
| 338 |
-
Academic & 2,493 & 0.319 & 0.189 & 0.026 \\
|
| 339 |
-
Literary & 234 & 0.206 & 0.137 & 0.008 \\
|
| 340 |
-
\midrule
|
| 341 |
-
\textit{Overall} & 2,727 & 0.310 & 0.185 & 0.024 \\
|
| 342 |
-
\bottomrule
|
| 343 |
-
\end{tabular}
|
| 344 |
-
\end{table}
|
| 345 |
-
|
| 346 |
-
Academic summaries (ROUGE-1: 0.319) outperform literary summaries (ROUGE-1: 0.206) by +0.113, a large gap attributable to two factors: (1) the encoder is disproportionately trained on academic text ($\sim$45K academic vs. $\sim$4K literary), and (2) academic abstracts follow more predictable structural conventions (background-method-result) that are easier for the model to reproduce. Literary descriptions---which describe \textit{what a book is about} in narrative prose---require more creative generation.
|
| 347 |
-
|
| 348 |
-
\subsection{Analysis: Emotion Detection Improvements}
|
| 349 |
-
\label{sec:emotion_analysis}
|
| 350 |
-
|
| 351 |
-
Our improved multi-task emotion sample-averaged F1 (0.352) represents a dramatic improvement over the baseline MTL configuration (0.199). With per-class threshold tuning, macro F1 reaches 0.294---approaching published GoEmotions baselines (0.46 macro F1 with BERT-base \cite{demszky2020goemotions}). We analyze the contributing factors:
|
| 352 |
-
|
| 353 |
-
\begin{enumerate}
|
| 354 |
-
\item \textbf{Attention pooling is critical.} Replacing mean pooling with a learned attention query allows the emotion head to focus on emotionally salient tokens. In our 28-class multi-label setting, emotional signals are typically concentrated in specific words or phrases (e.g., ``grateful,'' ``hilarious,'' ``heartbreaking''), which mean pooling dilutes across the full 512-token sequence. The top-performing classes---gratitude (F1: 0.888), amusement (0.751), love (0.740), admiration (0.653)---correspond to emotions with distinctive lexical markers that attention pooling can localize.
|
| 355 |
-
|
| 356 |
-
\item \textbf{Temperature sampling improves optimization.} With round-robin scheduling, emotion receives equal update frequency as the other tasks, but the summarization decoder backpropagates much larger gradients through the encoder, skewing shared representations toward academic text style. Temperature sampling ($\alpha=0.5$) allocates $\sim$43\% of steps to emotion---proportional to its dataset size---ensuring the encoder maintains emotion-relevant features.
|
| 357 |
-
|
| 358 |
-
\item \textbf{Remaining class-level gaps.} Despite overall improvement, 15 of 28 emotion classes still have zero F1 at the default 0.5 threshold (including approval, annoyance, disapproval, anger). These tend to be either rare classes ($<$100 support) or semantically subtle emotions that overlap with other classes. Per-class threshold tuning recovers non-zero performance for most of these classes, increasing macro F1 from 0.143 to 0.294.
|
| 359 |
-
|
| 360 |
-
\item \textbf{Domain gap persists.} Despite improvements, the remaining gap vs. published GoEmotions baselines (0.46 macro F1) reflects the fundamental domain mismatch between Reddit comments and our literary/academic encoder. Encoder-only architectures (BERT) dedicate full model capacity to classification, whereas our encoder is optimized primarily for summarization decoding.
|
| 361 |
-
\end{enumerate}
|
| 362 |
-
|
| 363 |
-
\textbf{Per-class threshold tuning results.} Sweeping $\tau \in \{0.1, \ldots, 0.9\}$ per class on the validation set yields tuned sample-averaged F1 of 0.503, tuned macro F1 of 0.294, and tuned micro F1 of 0.486. The optimal thresholds vary widely: gratitude saturates at $\tau=0.65$ (high confidence predictions), while rare classes require $\tau \leq 0.2$ to achieve non-zero recall.
|
| 364 |
-
|
| 365 |
-
\textbf{Implication}: Architectural isolation of classification heads (attention pooling) combined with balanced optimization (temperature sampling) can overcome domain mismatch in MTL, converting negative transfer to substantial positive transfer.
|
| 366 |
-
|
| 367 |
-
\subsection{Training Dynamics}
|
| 368 |
-
|
| 369 |
-
Figure \ref{fig:training_curves} shows training progression over 8 epochs (approximately 9 hours on RTX 4070 with temperature sampling).
|
| 370 |
-
|
| 371 |
-
\begin{figure}[htbp]
|
| 372 |
-
\centering
|
| 373 |
-
\includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
|
| 374 |
-
\caption{Training and validation loss with temperature sampling and attention pooling. Combined validation loss decreases from 4.298 to 3.925 over 8 epochs; best checkpoint at epoch 8.}
|
| 375 |
-
\label{fig:training_curves}
|
| 376 |
-
\end{figure}
|
| 377 |
-
|
| 378 |
-
Key observations:
|
| 379 |
-
\begin{itemize}
|
| 380 |
-
\item Topic classification converges rapidly: 91\% training accuracy by epoch 3 (84\% validation), reaching 98\% by epoch 8. Validation accuracy plateaus near 86\% from epoch 2 onward, while training accuracy continues climbing---a sign of mild overfitting on the small (3.4K) topic dataset. The reduced task weight (0.3) limits gradient dominance.
|
| 381 |
-
\item Summarization training loss decreases steadily (4.057 $\rightarrow$ 3.699), with validation loss flattening after epoch 5 (3.665 $\rightarrow$ 3.653). Training ROUGE-1 improves from 0.287 to 0.308.
|
| 382 |
-
\item Emotion F1 improves steadily throughout training: validation F1 rises from 0.197 (epoch 1) to 0.459 (epoch 8), indicating the attention pooling mechanism continues refining its weights over the full training duration.
|
| 383 |
-
\item Combined validation loss decreases from 4.298 (epoch 1) to 3.925 (epoch 8), though the decrease is marginal after epoch 5. Early stopping (patience=3) did not trigger because the combined loss continued improving slightly each epoch. Additional epochs could yield further modest gains, though the near-plateau after epoch 5 suggests diminishing returns.
|
| 384 |
-
\end{itemize}
|
| 385 |
-
|
| 386 |
-
%=============================================================================
|
| 387 |
-
\section{Discussion}
|
| 388 |
-
%=============================================================================
|
| 389 |
-
|
| 390 |
-
\subsection{When Does MTL Help?}
|
| 391 |
-
|
| 392 |
-
Our results demonstrate that MTL effectiveness depends on both task relatedness \textit{and} architectural/optimization choices:
|
| 393 |
-
|
| 394 |
-
\textbf{MTL helps when}: (1) A small-dataset task (topic: 3.4K samples) shares domain with a large-dataset task (summarization: 49K literary/academic samples)---the topic classifier benefits from shared encoder representations tuned to literary and academic vocabulary. (2) Task-specific heads are architecturally isolated from shared representations---attention pooling for emotion allows task-specific feature extraction without interfering with the shared encoder.
|
| 395 |
-
|
| 396 |
-
\textbf{MTL requires intervention when}: An auxiliary task's domain is misaligned with the primary training signal. With naive mean pooling, emotion detection suffered negative transfer because the encoder's representations were skewed toward summarization. Attention pooling and temperature sampling together overcame this: attention pooling provides architectural isolation, while temperature sampling ensures balanced optimization.
|
| 397 |
-
|
| 398 |
-
\textbf{MTL is neutral for}: The primary task (summarization) with sufficient data and a dedicated component (decoder, $\sim$136M parameters) that insulates it from interference. Classification heads are small and their gradients have limited impact relative to the decoder's backpropagation signal.
|
| 399 |
-
|
| 400 |
-
\subsection{Comparison to MTL Literature}
|
| 401 |
-
|
| 402 |
-
Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---our baseline results (positive transfer for topic, negative for emotion) confirmed this, but our improved configuration shows that \textit{architectural interventions can change these grouping dynamics}. Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our gradient conflict diagnostics (Section~3.4) enable empirical measurement of inter-task gradient cosine similarity, and our temperature sampling partially addresses gradient imbalance by controlling task exposure frequency. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks; our results suggest that per-task architectural isolation (attention pooling) can mitigate this.
|
| 403 |
-
|
| 404 |
-
A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Our results show that task-specific pooling strategies can partially compensate for this asymmetry. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform more targeted isolation strategies.
|
| 405 |
-
|
| 406 |
-
\subsection{Implications for Practitioners}
|
| 407 |
-
|
| 408 |
-
Based on our findings:
|
| 409 |
-
|
| 410 |
-
\begin{enumerate}
|
| 411 |
-
\item \textbf{Audit domain alignment} before combining tasks in MTL. If auxiliary tasks draw from different text domains (e.g., social media vs. academic), negative transfer is likely unless mitigated by gradient-conflict methods or per-task adapters.
|
| 412 |
-
|
| 413 |
-
\item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
|
| 414 |
-
|
| 415 |
-
\item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation. Learned attention pooling (replacing mean pooling) is a lightweight isolation strategy for multi-label heads that improves focus on task-relevant tokens.
|
| 416 |
-
|
| 417 |
-
\item \textbf{Monitor gradient conflicts} before deploying MTL. Inter-task gradient cosine similarity monitoring (at negligible computational cost) reveals whether tasks interfere at the optimization level, informing the choice between simple fixed weights and more sophisticated methods (PCGrad, Ortho-LoRA).
|
| 418 |
-
|
| 419 |
-
\item \textbf{Use temperature-based sampling} when dataset sizes vary widely. Square-root temperature ($\alpha=0.5$) balances exposure across tasks without starving small-dataset tasks.
|
| 420 |
-
|
| 421 |
-
\item \textbf{Validate with multiple seeds} before drawing conclusions from MTL comparisons, especially with small validation sets. Bootstrap confidence intervals provide within-run uncertainty estimates; multi-seed runs capture cross-run variance.
|
| 422 |
-
\end{enumerate}
|
| 423 |
-
|
| 424 |
-
\subsection{Limitations}
|
| 425 |
-
\label{sec:limitations}
|
| 426 |
-
|
| 427 |
-
We identify several limitations that constrain the generalizability of our findings:
|
| 428 |
-
|
| 429 |
-
\begin{itemize}
|
| 430 |
-
\item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.
|
| 431 |
-
|
| 432 |
-
\item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods could provide additional gains beyond our attention pooling and temperature sampling improvements.
|
| 433 |
-
|
| 434 |
-
\item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results. The remaining gap between our tuned macro F1 (0.294) and published GoEmotions baselines (0.46) likely reflects this architectural difference.
|
| 435 |
-
|
| 436 |
-
\item \textbf{Cross-task data leakage}: Although topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.
|
| 437 |
-
|
| 438 |
-
\item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.
|
| 439 |
-
|
| 440 |
-
\item \textbf{No human evaluation}: ROUGE scores are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.
|
| 441 |
-
|
| 442 |
-
\item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
|
| 443 |
-
|
| 444 |
-
\item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text. Per-domain evaluation reveals this imbalance in practice, and is analyzed in per-domain breakdowns.
|
| 445 |
-
\end{itemize}
|
| 446 |
-
|
| 447 |
-
\subsection{Future Work}
|
| 448 |
-
|
| 449 |
-
\begin{itemize}
|
| 450 |
-
\item \textbf{Gradient-conflict mitigation}: Our gradient conflict diagnostics provide the empirical foundation; the natural next step is applying Ortho-LoRA \cite{ortholora2025} for orthogonal gradient projection, PCGrad \cite{yu2020gradient} for gradient surgery, or CAGrad \cite{liu2021conflict} for conflict-averse optimization. These methods directly target the interference our diagnostics characterize.
|
| 451 |
-
|
| 452 |
-
\item \textbf{Parameter-efficient multi-tasking}: PiKE \cite{pike2025} for selective knowledge exchange between tasks, per-task LoRA adapters \cite{hu2022lora}, ScaLearn \cite{scallearn2023} shared attention with task-specific scaling, or adapter layers \cite{houlsby2019parameter} to provide task-specific specialization while maintaining shared encoder representations. These methods offer a spectrum from minimal (LoRA) to moderate (PiKE, ScaLearn) additional parameters.
|
| 453 |
-
|
| 454 |
-
\item \textbf{Principled task grouping}: Applying transfer-gain estimation methods \cite{taskgrouping2024} to determine whether emotion should be trained jointly with summarization and topic, or in a separate group. Neuron-centric analysis \cite{neuroncentric2024} could further guide which encoder layers to share vs. specialize.
|
| 455 |
-
|
| 456 |
-
\item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
|
| 457 |
-
|
| 458 |
-
\item \textbf{Multi-seed evaluation with confidence intervals}: Our \texttt{train\_multiseed.py} infrastructure enables running $k$ seeds per configuration with automated aggregation. Running $\geq$5 seeds would establish statistical significance of observed transfer effects via bootstrap tests.
|
| 459 |
-
|
| 460 |
-
\item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
|
| 461 |
-
|
| 462 |
-
\item \textbf{Temperature sampling ablation}: Comparing round-robin vs. temperature-based sampling ($\alpha \in \{0.3, 0.5, 0.7, 1.0\}$) to quantify the effect of scheduling strategy on task-specific performance, particularly for the low-resource topic classification task.
|
| 463 |
-
\end{itemize}
|
| 464 |
-
|
| 465 |
-
%=============================================================================
|
| 466 |
-
\section{Conclusion}
|
| 467 |
-
%=============================================================================
|
| 468 |
-
|
| 469 |
-
We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our key finding is that naive MTL with mean pooling produces heterogeneous transfer effects---positive for topic (+3.7\%), negative for emotion ($-$0.02 F1)---but that targeted interventions can eliminate negative transfer entirely. Learned attention pooling for the emotion head, combined with temperature-based task sampling ($\alpha=0.5$), improves multi-task emotion F1 from 0.199 to 0.352 (+77\%), surpassing the single-task baseline. With per-class threshold tuning, macro F1 reaches 0.294. Summarization quality remains robust across configurations (ROUGE-1: 0.310, ROUGE-L: 0.185), with per-domain analysis revealing a quality gap between academic (ROUGE-1: 0.319) and literary (ROUGE-1: 0.206) summaries driven by training data imbalance.
|
| 470 |
-
|
| 471 |
-
These results demonstrate that negative transfer in MTL is not an inherent limitation but can be addressed through architectural isolation (task-specific pooling) and balanced optimization (temperature sampling). Pre-trained initialization (FLAN-T5) remains essential for competitive performance across all tasks. Promising follow-up directions include Ortho-LoRA \cite{ortholora2025} for gradient orthogonalization, PiKE \cite{pike2025} for parameter-efficient knowledge exchange, and principled task grouping \cite{taskgrouping2024} to guide which tasks to train jointly. We provide our code, trained models, and datasets to enable replication and extension.
|
| 472 |
-
|
| 473 |
-
Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
|
| 474 |
-
Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
|
| 475 |
-
|
| 476 |
-
%=============================================================================
|
| 477 |
-
% References
|
| 478 |
-
%=============================================================================
|
| 479 |
-
|
| 480 |
-
\begin{thebibliography}{00}
|
| 481 |
-
|
| 482 |
-
\bibitem{caruana1997multitask}
|
| 483 |
-
R. Caruana, ``Multitask learning,'' \textit{Machine Learning}, vol. 28, no. 1, pp. 41--75, 1997.
|
| 484 |
-
|
| 485 |
-
\bibitem{collobert2011natural}
|
| 486 |
-
R. Collobert, J. Weston, L. Bottou, M. Karlen, K. Kavukcuoglu, and P. Kuksa, ``Natural language processing (almost) from scratch,'' \textit{JMLR}, vol. 12, pp. 2493--2537, 2011.
|
| 487 |
-
|
| 488 |
-
\bibitem{johnson2017google}
|
| 489 |
-
M. Johnson et al., ``Google's multilingual neural machine translation system: Enabling zero-shot translation,'' \textit{TACL}, vol. 5, pp. 339--351, 2017.
|
| 490 |
-
|
| 491 |
-
\bibitem{mccann2018natural}
|
| 492 |
-
B. McCann, N. S. Keskar, C. Xiong, and R. Socher, ``The natural language decathlon: Multitask learning as question answering,'' \textit{arXiv:1806.08730}, 2018.
|
| 493 |
-
|
| 494 |
-
\bibitem{standley2020tasks}
|
| 495 |
-
T. Standley, A. Zamir, D. Chen, L. Guibas, J. Malik, and S. Savarese, ``Which tasks should be learned together in multi-task learning?'' in \textit{ICML}, 2020.
|
| 496 |
-
|
| 497 |
-
\bibitem{yu2020gradient}
|
| 498 |
-
T. Yu, S. Kumar, A. Gupta, S. Levine, K. Hausman, and C. Finn, ``Gradient surgery for multi-task learning,'' in \textit{NeurIPS}, 2020.
|
| 499 |
-
|
| 500 |
-
\bibitem{liu2021conflict}
|
| 501 |
-
B. Liu, X. Liu, X. Jin, P. Stone, and Q. Liu, ``Conflict-averse gradient descent for multi-task learning,'' in \textit{NeurIPS}, 2021.
|
| 502 |
-
|
| 503 |
-
\bibitem{chen2018gradnorm}
|
| 504 |
-
Z. Chen, V. Badrinarayanan, C.-Y. Lee, and A. Rabinovich, ``GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks,'' in \textit{ICML}, 2018.
|
| 505 |
-
|
| 506 |
-
\bibitem{kendall2018multi}
|
| 507 |
-
A. Kendall, Y. Gal, and R. Cipolla, ``Multi-task learning using uncertainty to weigh losses for scene geometry and semantics,'' in \textit{CVPR}, 2018.
|
| 508 |
-
|
| 509 |
-
\bibitem{aghajanyan2021muppet}
|
| 510 |
-
A. Aghajanyan, A. Gupta, A. Shrivastava, X. Chen, L. Zettlemoyer, and S. Gupta, ``Muppet: Massive multi-task representations with pre-finetuning,'' in \textit{EMNLP}, 2021.
|
| 511 |
-
|
| 512 |
-
\bibitem{aribandi2022ext5}
|
| 513 |
-
V. Aribandi et al., ``ExT5: Towards extreme multi-task scaling for transfer learning,'' in \textit{ICLR}, 2022.
|
| 514 |
-
|
| 515 |
-
\bibitem{raffel2020exploring}
|
| 516 |
-
C. Raffel et al., ``Exploring the limits of transfer learning with a unified text-to-text transformer,'' \textit{JMLR}, vol. 21, no. 140, pp. 1--67, 2020.
|
| 517 |
-
|
| 518 |
-
\bibitem{chung2022scaling}
|
| 519 |
-
H. W. Chung et al., ``Scaling instruction-finetuned language models,'' \textit{arXiv:2210.11416}, 2022.
|
| 520 |
-
|
| 521 |
-
\bibitem{nallapati2016abstractive}
|
| 522 |
-
R. Nallapati, B. Zhou, C. dos Santos, C. Gulcehre, and B. Xiang, ``Abstractive text summarization using sequence-to-sequence RNNs and beyond,'' in \textit{CoNLL}, 2016.
|
| 523 |
-
|
| 524 |
-
\bibitem{narayan2018don}
|
| 525 |
-
S. Narayan, S. B. Cohen, and M. Lapata, ``Don't give me the details, just the summary! Topic-aware convolutional neural networks for extreme summarization,'' in \textit{EMNLP}, 2018.
|
| 526 |
-
|
| 527 |
-
\bibitem{kryscinski2021booksum}
|
| 528 |
-
W. Kryscinski, N. Rajani, D. Aber, and C. Xiong, ``BookSum: A collection of datasets for long-form narrative summarization,'' in \textit{Findings of EMNLP}, 2021.
|
| 529 |
-
|
| 530 |
-
\bibitem{cohan2018discourse}
|
| 531 |
-
A. Cohan et al., ``A discourse-aware attention model for abstractive summarization of long documents,'' in \textit{NAACL-HLT}, 2018.
|
| 532 |
-
|
| 533 |
-
\bibitem{mao2022citesum}
|
| 534 |
-
Y. Mao, M. Zhong, and J. Han, ``CiteSum: Citation text-guided scientific extreme summarization and domain adaptation with limited supervision,'' in \textit{EMNLP}, 2022.
|
| 535 |
-
|
| 536 |
-
\bibitem{demszky2020goemotions}
|
| 537 |
-
D. Demszky et al., ``GoEmotions: A dataset of fine-grained emotions,'' in \textit{ACL}, 2020.
|
| 538 |
-
|
| 539 |
-
\bibitem{zhang2019root}
|
| 540 |
-
B. Zhang and R. Sennrich, ``Root mean square layer normalization,'' in \textit{NeurIPS}, 2019.
|
| 541 |
-
|
| 542 |
-
\bibitem{lin2004rouge}
|
| 543 |
-
C.-Y. Lin, ``ROUGE: A package for automatic evaluation of summaries,'' in \textit{Text Summarization Branches Out}, 2004.
|
| 544 |
-
|
| 545 |
-
\bibitem{zhang2019bertscore}
|
| 546 |
-
T. Zhang, V. Kishore, F. Wu, K. Q. Weinberger, and Y. Artzi, ``BERTScore: Evaluating text generation with BERT,'' in \textit{ICLR}, 2020.
|
| 547 |
-
|
| 548 |
-
\bibitem{hu2022lora}
|
| 549 |
-
E. J. Hu et al., ``LoRA: Low-rank adaptation of large language models,'' in \textit{ICLR}, 2022.
|
| 550 |
-
|
| 551 |
-
\bibitem{houlsby2019parameter}
|
| 552 |
-
N. Houlsby et al., ``Parameter-efficient transfer learning for NLP,'' in \textit{ICML}, 2019.
|
| 553 |
-
|
| 554 |
-
\bibitem{lin2017focal}
|
| 555 |
-
T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
|
| 556 |
-
|
| 557 |
-
\bibitem{ortholora2025}
|
| 558 |
-
B. Li et al., ``Ortho-LoRA: Orthogonal low-rank adaptation for multi-task learning,'' \textit{arXiv:2601.09684}, 2025.
|
| 559 |
-
|
| 560 |
-
\bibitem{pike2025}
|
| 561 |
-
Y. Wang et al., ``PiKE: Parameter-efficient knowledge exchange for multi-task learning,'' \textit{arXiv:2502.06244}, 2025.
|
| 562 |
-
|
| 563 |
-
\bibitem{scallearn2023}
|
| 564 |
-
H. Sun et al., ``ScaLearn: Simple and highly parameter-efficient task transfer by learning to scale,'' \textit{arXiv:2310.01217}, 2023.
|
| 565 |
-
|
| 566 |
-
\bibitem{taskgrouping2024}
|
| 567 |
-
S. Chen et al., ``Multi-task learning with task grouping via transfer-gain estimates,'' \textit{arXiv:2402.15328}, 2024.
|
| 568 |
-
|
| 569 |
-
\bibitem{neuroncentric2024}
|
| 570 |
-
A. Foroutan et al., ``What do neurons in multi-task language models encode? A neuron-centric analysis,'' \textit{arXiv:2407.06488}, 2024.
|
| 571 |
-
|
| 572 |
-
\end{thebibliography}
|
| 573 |
-
|
| 574 |
-
\end{document}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
kaggle.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"username":"oliverperrin","key":"af2c73b5725d2839410a1f72cf84cf48"}
|
|
|
|
|
|
outputs/evaluation_report.json
DELETED
|
@@ -1,260 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"summarization": {
|
| 3 |
-
"rouge1": 0.3094793058747055,
|
| 4 |
-
"rouge2": 0.09069756722817666,
|
| 5 |
-
"rougeL": 0.1847154828322755,
|
| 6 |
-
"bleu4": 0.023982657019404153,
|
| 7 |
-
"num_samples": 2727,
|
| 8 |
-
"per_domain": {
|
| 9 |
-
"academic": {
|
| 10 |
-
"num_samples": 2493,
|
| 11 |
-
"rouge1": 0.31919183681728475,
|
| 12 |
-
"rouge2": 0.0968589097730544,
|
| 13 |
-
"rougeL": 0.18921129182459423,
|
| 14 |
-
"bleu4": 0.02551610700902003
|
| 15 |
-
},
|
| 16 |
-
"literary": {
|
| 17 |
-
"num_samples": 234,
|
| 18 |
-
"rouge1": 0.2060034954479976,
|
| 19 |
-
"rouge2": 0.0250555716539014,
|
| 20 |
-
"rougeL": 0.1368178254910352,
|
| 21 |
-
"bleu4": 0.0076455167454197795
|
| 22 |
-
}
|
| 23 |
-
},
|
| 24 |
-
"rouge1_ci": {
|
| 25 |
-
"mean": 0.30947930587470557,
|
| 26 |
-
"lower": 0.3060921045548166,
|
| 27 |
-
"upper": 0.3131015955325767
|
| 28 |
-
},
|
| 29 |
-
"rougeL_ci": {
|
| 30 |
-
"mean": 0.18471548283227546,
|
| 31 |
-
"lower": 0.18251665662669495,
|
| 32 |
-
"upper": 0.18701919830414013
|
| 33 |
-
}
|
| 34 |
-
},
|
| 35 |
-
"emotion": {
|
| 36 |
-
"sample_avg_f1": 0.3522975742816925,
|
| 37 |
-
"macro_f1": 0.14317210018634796,
|
| 38 |
-
"micro_f1": 0.4430159032344818,
|
| 39 |
-
"num_samples": 5426,
|
| 40 |
-
"num_classes": 28,
|
| 41 |
-
"per_class": {
|
| 42 |
-
"admiration": {
|
| 43 |
-
"precision": 0.714634120464325,
|
| 44 |
-
"recall": 0.6004098653793335,
|
| 45 |
-
"f1": 0.6525612537411732,
|
| 46 |
-
"support": 488
|
| 47 |
-
},
|
| 48 |
-
"amusement": {
|
| 49 |
-
"precision": 0.7708333134651184,
|
| 50 |
-
"recall": 0.7326732873916626,
|
| 51 |
-
"f1": 0.7512690366449468,
|
| 52 |
-
"support": 303
|
| 53 |
-
},
|
| 54 |
-
"anger": {
|
| 55 |
-
"precision": 0.0,
|
| 56 |
-
"recall": 0.0,
|
| 57 |
-
"f1": 0.0,
|
| 58 |
-
"support": 195
|
| 59 |
-
},
|
| 60 |
-
"annoyance": {
|
| 61 |
-
"precision": 0.0,
|
| 62 |
-
"recall": 0.0,
|
| 63 |
-
"f1": 0.0,
|
| 64 |
-
"support": 303
|
| 65 |
-
},
|
| 66 |
-
"approval": {
|
| 67 |
-
"precision": 0.0,
|
| 68 |
-
"recall": 0.0,
|
| 69 |
-
"f1": 0.0,
|
| 70 |
-
"support": 397
|
| 71 |
-
},
|
| 72 |
-
"caring": {
|
| 73 |
-
"precision": 0.0,
|
| 74 |
-
"recall": 0.0,
|
| 75 |
-
"f1": 0.0,
|
| 76 |
-
"support": 153
|
| 77 |
-
},
|
| 78 |
-
"confusion": {
|
| 79 |
-
"precision": 0.0,
|
| 80 |
-
"recall": 0.0,
|
| 81 |
-
"f1": 0.0,
|
| 82 |
-
"support": 152
|
| 83 |
-
},
|
| 84 |
-
"curiosity": {
|
| 85 |
-
"precision": 0.6166666746139526,
|
| 86 |
-
"recall": 0.14919355511665344,
|
| 87 |
-
"f1": 0.24025974958898805,
|
| 88 |
-
"support": 248
|
| 89 |
-
},
|
| 90 |
-
"desire": {
|
| 91 |
-
"precision": 0.0,
|
| 92 |
-
"recall": 0.0,
|
| 93 |
-
"f1": 0.0,
|
| 94 |
-
"support": 77
|
| 95 |
-
},
|
| 96 |
-
"disappointment": {
|
| 97 |
-
"precision": 0.0,
|
| 98 |
-
"recall": 0.0,
|
| 99 |
-
"f1": 0.0,
|
| 100 |
-
"support": 163
|
| 101 |
-
},
|
| 102 |
-
"disapproval": {
|
| 103 |
-
"precision": 0.0,
|
| 104 |
-
"recall": 0.0,
|
| 105 |
-
"f1": 0.0,
|
| 106 |
-
"support": 292
|
| 107 |
-
},
|
| 108 |
-
"disgust": {
|
| 109 |
-
"precision": 0.0,
|
| 110 |
-
"recall": 0.0,
|
| 111 |
-
"f1": 0.0,
|
| 112 |
-
"support": 97
|
| 113 |
-
},
|
| 114 |
-
"embarrassment": {
|
| 115 |
-
"precision": 0.0,
|
| 116 |
-
"recall": 0.0,
|
| 117 |
-
"f1": 0.0,
|
| 118 |
-
"support": 35
|
| 119 |
-
},
|
| 120 |
-
"excitement": {
|
| 121 |
-
"precision": 0.0,
|
| 122 |
-
"recall": 0.0,
|
| 123 |
-
"f1": 0.0,
|
| 124 |
-
"support": 96
|
| 125 |
-
},
|
| 126 |
-
"fear": {
|
| 127 |
-
"precision": 0.0,
|
| 128 |
-
"recall": 0.0,
|
| 129 |
-
"f1": 0.0,
|
| 130 |
-
"support": 90
|
| 131 |
-
},
|
| 132 |
-
"gratitude": {
|
| 133 |
-
"precision": 0.8997134566307068,
|
| 134 |
-
"recall": 0.8770949840545654,
|
| 135 |
-
"f1": 0.8882602556669954,
|
| 136 |
-
"support": 358
|
| 137 |
-
},
|
| 138 |
-
"grief": {
|
| 139 |
-
"precision": 0.0,
|
| 140 |
-
"recall": 0.0,
|
| 141 |
-
"f1": 0.0,
|
| 142 |
-
"support": 13
|
| 143 |
-
},
|
| 144 |
-
"joy": {
|
| 145 |
-
"precision": 0.0,
|
| 146 |
-
"recall": 0.0,
|
| 147 |
-
"f1": 0.0,
|
| 148 |
-
"support": 172
|
| 149 |
-
},
|
| 150 |
-
"love": {
|
| 151 |
-
"precision": 0.6996466517448425,
|
| 152 |
-
"recall": 0.7857142686843872,
|
| 153 |
-
"f1": 0.740186913163602,
|
| 154 |
-
"support": 252
|
| 155 |
-
},
|
| 156 |
-
"nervousness": {
|
| 157 |
-
"precision": 0.0,
|
| 158 |
-
"recall": 0.0,
|
| 159 |
-
"f1": 0.0,
|
| 160 |
-
"support": 21
|
| 161 |
-
},
|
| 162 |
-
"neutral": {
|
| 163 |
-
"precision": 0.6869627237319946,
|
| 164 |
-
"recall": 0.543035089969635,
|
| 165 |
-
"f1": 0.6065780936064032,
|
| 166 |
-
"support": 1766
|
| 167 |
-
},
|
| 168 |
-
"optimism": {
|
| 169 |
-
"precision": 0.7142857313156128,
|
| 170 |
-
"recall": 0.023923445492982864,
|
| 171 |
-
"f1": 0.04629629729995926,
|
| 172 |
-
"support": 209
|
| 173 |
-
},
|
| 174 |
-
"pride": {
|
| 175 |
-
"precision": 0.0,
|
| 176 |
-
"recall": 0.0,
|
| 177 |
-
"f1": 0.0,
|
| 178 |
-
"support": 15
|
| 179 |
-
},
|
| 180 |
-
"realization": {
|
| 181 |
-
"precision": 0.0,
|
| 182 |
-
"recall": 0.0,
|
| 183 |
-
"f1": 0.0,
|
| 184 |
-
"support": 127
|
| 185 |
-
},
|
| 186 |
-
"relief": {
|
| 187 |
-
"precision": 0.0,
|
| 188 |
-
"recall": 0.0,
|
| 189 |
-
"f1": 0.0,
|
| 190 |
-
"support": 18
|
| 191 |
-
},
|
| 192 |
-
"remorse": {
|
| 193 |
-
"precision": 1.0,
|
| 194 |
-
"recall": 0.014705882407724857,
|
| 195 |
-
"f1": 0.02898550735279132,
|
| 196 |
-
"support": 68
|
| 197 |
-
},
|
| 198 |
-
"sadness": {
|
| 199 |
-
"precision": 1.0,
|
| 200 |
-
"recall": 0.0279720276594162,
|
| 201 |
-
"f1": 0.054421768115822285,
|
| 202 |
-
"support": 143
|
| 203 |
-
},
|
| 204 |
-
"surprise": {
|
| 205 |
-
"precision": 0.0,
|
| 206 |
-
"recall": 0.0,
|
| 207 |
-
"f1": 0.0,
|
| 208 |
-
"support": 129
|
| 209 |
-
}
|
| 210 |
-
},
|
| 211 |
-
"tuned_thresholds": {
|
| 212 |
-
"admiration": 0.4,
|
| 213 |
-
"amusement": 0.55,
|
| 214 |
-
"anger": 0.2,
|
| 215 |
-
"annoyance": 0.15,
|
| 216 |
-
"approval": 0.15,
|
| 217 |
-
"caring": 0.1,
|
| 218 |
-
"confusion": 0.1,
|
| 219 |
-
"curiosity": 0.25,
|
| 220 |
-
"desire": 0.15,
|
| 221 |
-
"disappointment": 0.1,
|
| 222 |
-
"disapproval": 0.1,
|
| 223 |
-
"disgust": 0.1,
|
| 224 |
-
"embarrassment": 0.1,
|
| 225 |
-
"excitement": 0.1,
|
| 226 |
-
"fear": 0.1,
|
| 227 |
-
"gratitude": 0.65,
|
| 228 |
-
"grief": 0.1,
|
| 229 |
-
"joy": 0.2,
|
| 230 |
-
"love": 0.45,
|
| 231 |
-
"nervousness": 0.1,
|
| 232 |
-
"neutral": 0.3,
|
| 233 |
-
"optimism": 0.25,
|
| 234 |
-
"pride": 0.1,
|
| 235 |
-
"realization": 0.2,
|
| 236 |
-
"relief": 0.1,
|
| 237 |
-
"remorse": 0.2,
|
| 238 |
-
"sadness": 0.25,
|
| 239 |
-
"surprise": 0.1
|
| 240 |
-
},
|
| 241 |
-
"tuned_macro_f1": 0.29355332255363464,
|
| 242 |
-
"tuned_sample_avg_f1": 0.5025880336761475,
|
| 243 |
-
"tuned_micro_f1": 0.48644566535949707,
|
| 244 |
-
"sample_f1_ci": {
|
| 245 |
-
"mean": 0.3522975795552279,
|
| 246 |
-
"lower": 0.33984518982676004,
|
| 247 |
-
"upper": 0.3658618994962526
|
| 248 |
-
}
|
| 249 |
-
},
|
| 250 |
-
"topic": {
|
| 251 |
-
"accuracy": 0.8571428571428571,
|
| 252 |
-
"macro_f1": 0.8538751111963805,
|
| 253 |
-
"num_samples": 189,
|
| 254 |
-
"accuracy_ci": {
|
| 255 |
-
"mean": 0.8571428571428571,
|
| 256 |
-
"lower": 0.8042328042328042,
|
| 257 |
-
"upper": 0.91005291005291
|
| 258 |
-
}
|
| 259 |
-
}
|
| 260 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs/training_history.json
DELETED
|
@@ -1,210 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"train_epoch_1": {
|
| 3 |
-
"summarization_loss": 4.05727732604402,
|
| 4 |
-
"summarization_rouge_like": 0.20427788603502178,
|
| 5 |
-
"summarization_rouge1": 0.2867239527374218,
|
| 6 |
-
"summarization_rouge2": 0.08530419039955006,
|
| 7 |
-
"summarization_rougeL": 0.21671934441779328,
|
| 8 |
-
"summarization_bleu4": 0.046807610480627294,
|
| 9 |
-
"emotion_loss": 0.26667120894821444,
|
| 10 |
-
"emotion_f1": 0.20469974499405505,
|
| 11 |
-
"total_loss": 6.105969995046231,
|
| 12 |
-
"topic_loss": 1.6857517635423032,
|
| 13 |
-
"topic_accuracy": 0.4217703349282306
|
| 14 |
-
},
|
| 15 |
-
"val_epoch_1": {
|
| 16 |
-
"summarization_loss": 3.8147981293996174,
|
| 17 |
-
"summarization_rouge_like": 0.2193516213078271,
|
| 18 |
-
"summarization_rouge1": 0.26079796060659194,
|
| 19 |
-
"summarization_rouge2": 0.08507403927823329,
|
| 20 |
-
"summarization_rougeL": 0.2006794257877804,
|
| 21 |
-
"summarization_bleu4": 0.047830237595825456,
|
| 22 |
-
"emotion_loss": 0.14947238216797512,
|
| 23 |
-
"emotion_f1": 0.19722222675879797,
|
| 24 |
-
"topic_loss": 1.111328324476878,
|
| 25 |
-
"topic_accuracy": 0.7036666666666669,
|
| 26 |
-
"total_loss": 4.297669008910658
|
| 27 |
-
},
|
| 28 |
-
"train_epoch_2": {
|
| 29 |
-
"summarization_loss": 3.849853239841521,
|
| 30 |
-
"summarization_rouge_like": 0.21403927033293962,
|
| 31 |
-
"summarization_rouge1": 0.27951076939717684,
|
| 32 |
-
"summarization_rouge2": 0.0873046768836161,
|
| 33 |
-
"summarization_rougeL": 0.21374118927203542,
|
| 34 |
-
"summarization_bleu4": 0.04958577524880755,
|
| 35 |
-
"emotion_loss": 0.14221051054989492,
|
| 36 |
-
"emotion_f1": 0.26357290302542585,
|
| 37 |
-
"topic_loss": 0.7299397268663149,
|
| 38 |
-
"topic_accuracy": 0.8084686774942008,
|
| 39 |
-
"total_loss": 5.528282663174978
|
| 40 |
-
},
|
| 41 |
-
"val_epoch_2": {
|
| 42 |
-
"summarization_loss": 3.738964385986328,
|
| 43 |
-
"summarization_rouge_like": 0.22322817933854347,
|
| 44 |
-
"summarization_rouge1": 0.2648447987903156,
|
| 45 |
-
"summarization_rouge2": 0.08777067266852198,
|
| 46 |
-
"summarization_rougeL": 0.2049718124413594,
|
| 47 |
-
"summarization_bleu4": 0.04980800809043137,
|
| 48 |
-
"emotion_loss": 0.1332480485116442,
|
| 49 |
-
"emotion_f1": 0.3008111199736595,
|
| 50 |
-
"topic_loss": 0.5171811254819234,
|
| 51 |
-
"topic_accuracy": 0.8467777777777786,
|
| 52 |
-
"total_loss": 4.027366772142546
|
| 53 |
-
},
|
| 54 |
-
"train_epoch_3": {
|
| 55 |
-
"emotion_loss": 0.12831888329927568,
|
| 56 |
-
"emotion_f1": 0.3325013413977316,
|
| 57 |
-
"summarization_loss": 3.7839796767703127,
|
| 58 |
-
"summarization_rouge_like": 0.21797276831106976,
|
| 59 |
-
"summarization_rouge1": 0.28868883384124916,
|
| 60 |
-
"summarization_rouge2": 0.09150032176337587,
|
| 61 |
-
"summarization_rougeL": 0.22148013487440707,
|
| 62 |
-
"summarization_bleu4": 0.052993168973641876,
|
| 63 |
-
"total_loss": 5.379445686122572,
|
| 64 |
-
"topic_loss": 0.3385182340765703,
|
| 65 |
-
"topic_accuracy": 0.9149137451307789
|
| 66 |
-
},
|
| 67 |
-
"val_epoch_3": {
|
| 68 |
-
"summarization_loss": 3.699807391166687,
|
| 69 |
-
"summarization_rouge_like": 0.22613490382620294,
|
| 70 |
-
"summarization_rouge1": 0.27110048990501884,
|
| 71 |
-
"summarization_rouge2": 0.09042725720607361,
|
| 72 |
-
"summarization_rougeL": 0.209904253200661,
|
| 73 |
-
"summarization_bleu4": 0.05177241093143676,
|
| 74 |
-
"emotion_loss": 0.12147359546273946,
|
| 75 |
-
"emotion_f1": 0.3474666798238953,
|
| 76 |
-
"topic_loss": 0.5068136086066564,
|
| 77 |
-
"topic_accuracy": 0.8417777777777792,
|
| 78 |
-
"total_loss": 3.9733250692114286
|
| 79 |
-
},
|
| 80 |
-
"train_epoch_4": {
|
| 81 |
-
"summarization_loss": 3.746917572488457,
|
| 82 |
-
"summarization_rouge_like": 0.22054338132572013,
|
| 83 |
-
"summarization_rouge1": 0.29700759128401966,
|
| 84 |
-
"summarization_rouge2": 0.09528349132659034,
|
| 85 |
-
"summarization_rougeL": 0.2286643637324592,
|
| 86 |
-
"summarization_bleu4": 0.05591190647915982,
|
| 87 |
-
"emotion_loss": 0.12003502780097021,
|
| 88 |
-
"emotion_f1": 0.37240424536844824,
|
| 89 |
-
"total_loss": 5.303101773435515,
|
| 90 |
-
"topic_loss": 0.19978291214297147,
|
| 91 |
-
"topic_accuracy": 0.9528935185185234
|
| 92 |
-
},
|
| 93 |
-
"val_epoch_4": {
|
| 94 |
-
"summarization_loss": 3.6773871207237243,
|
| 95 |
-
"summarization_rouge_like": 0.22730110361278533,
|
| 96 |
-
"summarization_rouge1": 0.2719731929407321,
|
| 97 |
-
"summarization_rouge2": 0.09117786246379923,
|
| 98 |
-
"summarization_rougeL": 0.21082587270737135,
|
| 99 |
-
"summarization_bleu4": 0.052260125383420154,
|
| 100 |
-
"emotion_loss": 0.11476812147845825,
|
| 101 |
-
"emotion_f1": 0.40390001876900594,
|
| 102 |
-
"topic_loss": 0.5311758625507355,
|
| 103 |
-
"topic_accuracy": 0.8574444444444455,
|
| 104 |
-
"total_loss": 3.95150800096741
|
| 105 |
-
},
|
| 106 |
-
"train_epoch_5": {
|
| 107 |
-
"summarization_loss": 3.72376742684834,
|
| 108 |
-
"summarization_rouge_like": 0.22218972657959773,
|
| 109 |
-
"summarization_rouge1": 0.30386172952451457,
|
| 110 |
-
"summarization_rouge2": 0.09807265293507532,
|
| 111 |
-
"summarization_rougeL": 0.23422938393417417,
|
| 112 |
-
"summarization_bleu4": 0.05821407514551748,
|
| 113 |
-
"emotion_loss": 0.11460309708431649,
|
| 114 |
-
"emotion_f1": 0.41015538037428334,
|
| 115 |
-
"total_loss": 5.207888234798891,
|
| 116 |
-
"topic_loss": 0.13986067138923575,
|
| 117 |
-
"topic_accuracy": 0.9685236768802278
|
| 118 |
-
},
|
| 119 |
-
"val_epoch_5": {
|
| 120 |
-
"summarization_loss": 3.664777074654897,
|
| 121 |
-
"summarization_rouge_like": 0.22876987463000684,
|
| 122 |
-
"summarization_rouge1": 0.27596093399625565,
|
| 123 |
-
"summarization_rouge2": 0.09296804123657829,
|
| 124 |
-
"summarization_rougeL": 0.21411928790828857,
|
| 125 |
-
"summarization_bleu4": 0.05366559404113782,
|
| 126 |
-
"emotion_loss": 0.11044646929949523,
|
| 127 |
-
"emotion_f1": 0.4313555757453044,
|
| 128 |
-
"topic_loss": 0.5484664579232533,
|
| 129 |
-
"topic_accuracy": 0.8627777777777789,
|
| 130 |
-
"total_loss": 3.9397634813313704
|
| 131 |
-
},
|
| 132 |
-
"train_epoch_6": {
|
| 133 |
-
"emotion_loss": 0.1111307007874511,
|
| 134 |
-
"emotion_f1": 0.43345397762862603,
|
| 135 |
-
"summarization_loss": 3.7095002406409807,
|
| 136 |
-
"summarization_rouge_like": 0.22328726116125275,
|
| 137 |
-
"summarization_rouge1": 0.3064035344877472,
|
| 138 |
-
"summarization_rouge2": 0.09935359454486654,
|
| 139 |
-
"summarization_rougeL": 0.23650841461700828,
|
| 140 |
-
"summarization_bleu4": 0.059165680810364656,
|
| 141 |
-
"total_loss": 5.231221632164746,
|
| 142 |
-
"topic_loss": 0.10774352340420275,
|
| 143 |
-
"topic_accuracy": 0.9777777777777826
|
| 144 |
-
},
|
| 145 |
-
"val_epoch_6": {
|
| 146 |
-
"summarization_loss": 3.658109269142151,
|
| 147 |
-
"summarization_rouge_like": 0.22934290201883448,
|
| 148 |
-
"summarization_rouge1": 0.2752052666208255,
|
| 149 |
-
"summarization_rouge2": 0.09292038370832255,
|
| 150 |
-
"summarization_rougeL": 0.2137414809166316,
|
| 151 |
-
"summarization_bleu4": 0.053427338475007496,
|
| 152 |
-
"emotion_loss": 0.10808507531881333,
|
| 153 |
-
"emotion_f1": 0.4517777989556392,
|
| 154 |
-
"topic_loss": 0.5295590771238009,
|
| 155 |
-
"topic_accuracy": 0.8734444444444451,
|
| 156 |
-
"total_loss": 3.9250620675981014
|
| 157 |
-
},
|
| 158 |
-
"train_epoch_7": {
|
| 159 |
-
"emotion_loss": 0.10953594371440704,
|
| 160 |
-
"emotion_f1": 0.44384909393388133,
|
| 161 |
-
"topic_loss": 0.0957411853224039,
|
| 162 |
-
"topic_accuracy": 0.980394366197187,
|
| 163 |
-
"total_loss": 5.154093306898701,
|
| 164 |
-
"summarization_loss": 3.7035266418583594,
|
| 165 |
-
"summarization_rouge_like": 0.22378105952974536,
|
| 166 |
-
"summarization_rouge1": 0.3070619920824417,
|
| 167 |
-
"summarization_rouge2": 0.09984959921270933,
|
| 168 |
-
"summarization_rougeL": 0.23710279675635842,
|
| 169 |
-
"summarization_bleu4": 0.05954598113800495
|
| 170 |
-
},
|
| 171 |
-
"val_epoch_7": {
|
| 172 |
-
"summarization_loss": 3.654966928164164,
|
| 173 |
-
"summarization_rouge_like": 0.2296679906954514,
|
| 174 |
-
"summarization_rouge1": 0.27616327736195406,
|
| 175 |
-
"summarization_rouge2": 0.09329265746038877,
|
| 176 |
-
"summarization_rougeL": 0.2144202156909426,
|
| 177 |
-
"summarization_bleu4": 0.05381191556748925,
|
| 178 |
-
"emotion_loss": 0.10733611459533374,
|
| 179 |
-
"emotion_f1": 0.4582889095693827,
|
| 180 |
-
"topic_loss": 0.5517185291647911,
|
| 181 |
-
"topic_accuracy": 0.8574444444444457,
|
| 182 |
-
"total_loss": 3.9278186015089296
|
| 183 |
-
},
|
| 184 |
-
"train_epoch_8": {
|
| 185 |
-
"summarization_loss": 3.6991967220660666,
|
| 186 |
-
"summarization_rouge_like": 0.22392498422300275,
|
| 187 |
-
"summarization_rouge1": 0.30751530664889926,
|
| 188 |
-
"summarization_rouge2": 0.10003700619268063,
|
| 189 |
-
"summarization_rougeL": 0.23750205422812004,
|
| 190 |
-
"summarization_bleu4": 0.05974583539783897,
|
| 191 |
-
"emotion_loss": 0.10842480880968565,
|
| 192 |
-
"emotion_f1": 0.44919879924130307,
|
| 193 |
-
"total_loss": 5.178375478314296,
|
| 194 |
-
"topic_loss": 0.09057999134229244,
|
| 195 |
-
"topic_accuracy": 0.9817882611080675
|
| 196 |
-
},
|
| 197 |
-
"val_epoch_8": {
|
| 198 |
-
"summarization_loss": 3.652825821240743,
|
| 199 |
-
"summarization_rouge_like": 0.22990084413830203,
|
| 200 |
-
"summarization_rouge1": 0.2765755402183266,
|
| 201 |
-
"summarization_rouge2": 0.09348690574327727,
|
| 202 |
-
"summarization_rougeL": 0.2147442273650338,
|
| 203 |
-
"summarization_bleu4": 0.053967258926407226,
|
| 204 |
-
"emotion_loss": 0.10670269103099903,
|
| 205 |
-
"emotion_f1": 0.4594111327578624,
|
| 206 |
-
"topic_loss": 0.5511919154723486,
|
| 207 |
-
"topic_accuracy": 0.8574444444444457,
|
| 208 |
-
"total_loss": 3.924886086913453
|
| 209 |
-
}
|
| 210 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/train_bert_baseline.py
ADDED
|
@@ -0,0 +1,1137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BERT Baseline Training for LexiMind Comparison.
|
| 3 |
+
|
| 4 |
+
Fine-tunes bert-base-uncased on topic classification and emotion detection
|
| 5 |
+
to provide baselines for comparison with LexiMind (FLAN-T5-based).
|
| 6 |
+
|
| 7 |
+
Supports three training modes to disentangle architecture vs. MTL effects:
|
| 8 |
+
1. single-topic — BERT fine-tuned on topic classification only
|
| 9 |
+
2. single-emotion — BERT fine-tuned on emotion detection only
|
| 10 |
+
3. multitask — BERT fine-tuned on both tasks jointly
|
| 11 |
+
|
| 12 |
+
Uses the same datasets, splits, label encoders, and evaluation metrics as the
|
| 13 |
+
main LexiMind pipeline for fair comparison.
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python scripts/train_bert_baseline.py --mode single-topic
|
| 17 |
+
python scripts/train_bert_baseline.py --mode single-emotion
|
| 18 |
+
python scripts/train_bert_baseline.py --mode multitask
|
| 19 |
+
python scripts/train_bert_baseline.py --mode all # Run all three sequentially
|
| 20 |
+
|
| 21 |
+
Author: Oliver Perrin
|
| 22 |
+
Date: March 2026
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import json
|
| 29 |
+
import random
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
from dataclasses import dataclass, field
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
from sklearn.metrics import accuracy_score, classification_report, f1_score
|
| 41 |
+
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
|
| 42 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 43 |
+
from torch.optim import AdamW
|
| 44 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
| 45 |
+
from torch.utils.data import DataLoader, Dataset
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
from transformers import AutoModel, AutoTokenizer
|
| 48 |
+
|
| 49 |
+
# Project imports
|
| 50 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 51 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 52 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 53 |
+
|
| 54 |
+
from src.data.dataset import (
|
| 55 |
+
EmotionExample,
|
| 56 |
+
TopicExample,
|
| 57 |
+
load_emotion_jsonl,
|
| 58 |
+
load_topic_jsonl,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
from src.training.metrics import (
|
| 62 |
+
bootstrap_confidence_interval,
|
| 63 |
+
multilabel_f1,
|
| 64 |
+
multilabel_macro_f1,
|
| 65 |
+
multilabel_micro_f1,
|
| 66 |
+
multilabel_per_class_metrics,
|
| 67 |
+
tune_per_class_thresholds,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Configuration
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
class BertBaselineConfig:
|
| 75 |
+
"""Hyperparameters aligned with LexiMind's full.yaml where applicable."""
|
| 76 |
+
|
| 77 |
+
# Model
|
| 78 |
+
model_name: str = "bert-base-uncased"
|
| 79 |
+
max_length: int = 256 # Same as LexiMind classification max_len
|
| 80 |
+
|
| 81 |
+
# Optimizer (matching LexiMind's full.yaml)
|
| 82 |
+
lr: float = 3e-5
|
| 83 |
+
weight_decay: float = 0.01
|
| 84 |
+
betas: tuple[float, float] = (0.9, 0.98)
|
| 85 |
+
eps: float = 1e-6
|
| 86 |
+
|
| 87 |
+
# Training
|
| 88 |
+
batch_size: int = 10 # Same as LexiMind
|
| 89 |
+
gradient_accumulation_steps: int = 4 # Same effective batch = 40
|
| 90 |
+
max_epochs: int = 8
|
| 91 |
+
warmup_steps: int = 300
|
| 92 |
+
gradient_clip_norm: float = 1.0
|
| 93 |
+
early_stopping_patience: int = 3
|
| 94 |
+
seed: int = 17 # Same as LexiMind
|
| 95 |
+
|
| 96 |
+
# Task weights (for multi-task mode)
|
| 97 |
+
topic_weight: float = 0.3 # Same as LexiMind
|
| 98 |
+
emotion_weight: float = 1.0
|
| 99 |
+
|
| 100 |
+
# Temperature sampling (for multi-task mode)
|
| 101 |
+
task_sampling_alpha: float = 0.5
|
| 102 |
+
|
| 103 |
+
# Frozen layers: freeze bottom 4 layers (matching LexiMind's encoder strategy)
|
| 104 |
+
freeze_layers: int = 4
|
| 105 |
+
|
| 106 |
+
# Precision
|
| 107 |
+
use_amp: bool = True # BFloat16 mixed precision
|
| 108 |
+
|
| 109 |
+
# Paths
|
| 110 |
+
data_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed")
|
| 111 |
+
output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "outputs" / "bert_baseline")
|
| 112 |
+
checkpoint_dir: Path = field(
|
| 113 |
+
default_factory=lambda: PROJECT_ROOT / "checkpoints" / "bert_baseline"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Emotion threshold
|
| 117 |
+
emotion_threshold: float = 0.3
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Datasets
|
| 121 |
+
|
| 122 |
+
class BertEmotionDataset(Dataset):
|
| 123 |
+
"""Tokenized emotion dataset for BERT."""
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
examples: List[EmotionExample],
|
| 127 |
+
tokenizer: AutoTokenizer,
|
| 128 |
+
binarizer: MultiLabelBinarizer,
|
| 129 |
+
max_length: int = 256,
|
| 130 |
+
):
|
| 131 |
+
self.examples = examples
|
| 132 |
+
self.tokenizer = tokenizer
|
| 133 |
+
self.binarizer = binarizer
|
| 134 |
+
self.max_length = max_length
|
| 135 |
+
|
| 136 |
+
def __len__(self) -> int:
|
| 137 |
+
return len(self.examples)
|
| 138 |
+
|
| 139 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 140 |
+
ex = self.examples[idx]
|
| 141 |
+
encoding = self.tokenizer(
|
| 142 |
+
ex.text,
|
| 143 |
+
max_length=self.max_length,
|
| 144 |
+
padding="max_length",
|
| 145 |
+
truncation=True,
|
| 146 |
+
return_tensors="pt",
|
| 147 |
+
)
|
| 148 |
+
labels = self.binarizer.transform([ex.emotions])[0]
|
| 149 |
+
return {
|
| 150 |
+
"input_ids": encoding["input_ids"].squeeze(0),
|
| 151 |
+
"attention_mask": encoding["attention_mask"].squeeze(0),
|
| 152 |
+
"labels": torch.tensor(labels, dtype=torch.float32),
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class BertTopicDataset(Dataset):
|
| 157 |
+
"""Tokenized topic dataset for BERT."""
|
| 158 |
+
def __init__(
|
| 159 |
+
self,
|
| 160 |
+
examples: List[TopicExample],
|
| 161 |
+
tokenizer: AutoTokenizer,
|
| 162 |
+
encoder: LabelEncoder,
|
| 163 |
+
max_length: int = 256,
|
| 164 |
+
):
|
| 165 |
+
self.examples = examples
|
| 166 |
+
self.tokenizer = tokenizer
|
| 167 |
+
self.encoder = encoder
|
| 168 |
+
self.max_length = max_length
|
| 169 |
+
|
| 170 |
+
def __len__(self) -> int:
|
| 171 |
+
return len(self.examples)
|
| 172 |
+
|
| 173 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 174 |
+
ex = self.examples[idx]
|
| 175 |
+
encoding = self.tokenizer(
|
| 176 |
+
ex.text,
|
| 177 |
+
max_length=self.max_length,
|
| 178 |
+
padding="max_length",
|
| 179 |
+
truncation=True,
|
| 180 |
+
return_tensors="pt",
|
| 181 |
+
)
|
| 182 |
+
label = self.encoder.transform([ex.topic])[0]
|
| 183 |
+
return {
|
| 184 |
+
"input_ids": encoding["input_ids"].squeeze(0),
|
| 185 |
+
"attention_mask": encoding["attention_mask"].squeeze(0),
|
| 186 |
+
"labels": torch.tensor(label, dtype=torch.long),
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Model
|
| 191 |
+
|
| 192 |
+
class BertClassificationHead(nn.Module):
|
| 193 |
+
"""Classification head on top of BERT [CLS] token.
|
| 194 |
+
|
| 195 |
+
For emotion: uses attention pooling + 2-layer MLP (matching LexiMind's emotion head)
|
| 196 |
+
For topic: uses [CLS] + single linear (matching LexiMind's mean pool + linear)
|
| 197 |
+
"""
|
| 198 |
+
def __init__(
|
| 199 |
+
self,
|
| 200 |
+
hidden_size: int,
|
| 201 |
+
num_labels: int,
|
| 202 |
+
pooling: str = "cls", # "cls" or "attention"
|
| 203 |
+
hidden_dim: Optional[int] = None,
|
| 204 |
+
dropout: float = 0.1,
|
| 205 |
+
):
|
| 206 |
+
super().__init__()
|
| 207 |
+
self.pooling = pooling
|
| 208 |
+
self.dropout = nn.Dropout(dropout)
|
| 209 |
+
|
| 210 |
+
if pooling == "attention":
|
| 211 |
+
self.attn_query = nn.Linear(hidden_size, 1, bias=False)
|
| 212 |
+
|
| 213 |
+
if hidden_dim is not None:
|
| 214 |
+
self.classifier = nn.Sequential(
|
| 215 |
+
nn.Linear(hidden_size, hidden_dim),
|
| 216 |
+
nn.GELU(),
|
| 217 |
+
nn.Dropout(dropout),
|
| 218 |
+
nn.Linear(hidden_dim, num_labels),
|
| 219 |
+
)
|
| 220 |
+
else:
|
| 221 |
+
self.classifier = nn.Linear(hidden_size, num_labels)
|
| 222 |
+
|
| 223 |
+
def forward(
|
| 224 |
+
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
| 225 |
+
) -> torch.Tensor:
|
| 226 |
+
if self.pooling == "attention":
|
| 227 |
+
# Learned attention pooling (same mechanism as LexiMind)
|
| 228 |
+
scores = self.attn_query(hidden_states) # (B, L, 1)
|
| 229 |
+
mask = attention_mask.unsqueeze(-1).bool()
|
| 230 |
+
scores = scores.masked_fill(~mask, float("-inf"))
|
| 231 |
+
weights = F.softmax(scores, dim=1)
|
| 232 |
+
pooled = (weights * hidden_states).sum(dim=1)
|
| 233 |
+
elif self.pooling == "mean":
|
| 234 |
+
# Mean pooling over valid tokens
|
| 235 |
+
mask_expanded = attention_mask.unsqueeze(-1).float()
|
| 236 |
+
sum_embeddings = (hidden_states * mask_expanded).sum(dim=1)
|
| 237 |
+
sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)
|
| 238 |
+
pooled = sum_embeddings / sum_mask
|
| 239 |
+
else:
|
| 240 |
+
# [CLS] token
|
| 241 |
+
pooled = hidden_states[:, 0, :]
|
| 242 |
+
|
| 243 |
+
pooled = self.dropout(pooled)
|
| 244 |
+
return self.classifier(pooled)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class BertBaseline(nn.Module):
|
| 248 |
+
"""BERT baseline model with task-specific heads.
|
| 249 |
+
|
| 250 |
+
Supports single-task and multi-task configurations.
|
| 251 |
+
"""
|
| 252 |
+
def __init__(
|
| 253 |
+
self,
|
| 254 |
+
model_name: str = "bert-base-uncased",
|
| 255 |
+
num_emotions: int = 28,
|
| 256 |
+
num_topics: int = 7,
|
| 257 |
+
tasks: Sequence[str] = ("emotion", "topic"),
|
| 258 |
+
freeze_layers: int = 4,
|
| 259 |
+
):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.bert = AutoModel.from_pretrained(model_name)
|
| 262 |
+
hidden_size = self.bert.config.hidden_size # 768 for bert-base
|
| 263 |
+
|
| 264 |
+
self.tasks = list(tasks)
|
| 265 |
+
self.heads = nn.ModuleDict()
|
| 266 |
+
|
| 267 |
+
if "emotion" in tasks:
|
| 268 |
+
# Attention pooling + 2-layer MLP (matching LexiMind's emotion head)
|
| 269 |
+
self.heads["emotion"] = BertClassificationHead(
|
| 270 |
+
hidden_size=hidden_size,
|
| 271 |
+
num_labels=num_emotions,
|
| 272 |
+
pooling="attention",
|
| 273 |
+
hidden_dim=hidden_size // 2, # 384, same ratio as LexiMind
|
| 274 |
+
dropout=0.1,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if "topic" in tasks:
|
| 278 |
+
# Mean pooling + single linear (matching LexiMind's topic head)
|
| 279 |
+
self.heads["topic"] = BertClassificationHead(
|
| 280 |
+
hidden_size=hidden_size,
|
| 281 |
+
num_labels=num_topics,
|
| 282 |
+
pooling="mean",
|
| 283 |
+
hidden_dim=None,
|
| 284 |
+
dropout=0.1,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Freeze bottom N encoder layers (matching LexiMind's strategy)
|
| 288 |
+
self._freeze_layers(freeze_layers)
|
| 289 |
+
|
| 290 |
+
def _freeze_layers(self, n: int) -> None:
|
| 291 |
+
"""Freeze embedding + bottom n encoder layers."""
|
| 292 |
+
# Freeze embeddings
|
| 293 |
+
for param in self.bert.embeddings.parameters():
|
| 294 |
+
param.requires_grad = False
|
| 295 |
+
|
| 296 |
+
# Freeze bottom n layers
|
| 297 |
+
for i in range(min(n, len(self.bert.encoder.layer))):
|
| 298 |
+
for param in self.bert.encoder.layer[i].parameters():
|
| 299 |
+
param.requires_grad = False
|
| 300 |
+
|
| 301 |
+
frozen = sum(1 for p in self.bert.parameters() if not p.requires_grad)
|
| 302 |
+
total = sum(1 for p in self.bert.parameters())
|
| 303 |
+
print(f" Frozen {frozen}/{total} BERT parameters (bottom {n} layers + embeddings)")
|
| 304 |
+
|
| 305 |
+
def forward(
|
| 306 |
+
self,
|
| 307 |
+
task: str,
|
| 308 |
+
input_ids: torch.Tensor,
|
| 309 |
+
attention_mask: torch.Tensor,
|
| 310 |
+
) -> torch.Tensor:
|
| 311 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 312 |
+
hidden_states = outputs.last_hidden_state # (B, L, 768)
|
| 313 |
+
return self.heads[task](hidden_states, attention_mask)
|
| 314 |
+
|
| 315 |
+
def param_count(self) -> Dict[str, int]:
|
| 316 |
+
"""Count parameters by component."""
|
| 317 |
+
counts = {}
|
| 318 |
+
counts["bert_encoder"] = sum(p.numel() for p in self.bert.parameters())
|
| 319 |
+
counts["bert_trainable"] = sum(
|
| 320 |
+
p.numel() for p in self.bert.parameters() if p.requires_grad
|
| 321 |
+
)
|
| 322 |
+
for name, head in self.heads.items():
|
| 323 |
+
counts[f"head_{name}"] = sum(p.numel() for p in head.parameters())
|
| 324 |
+
counts["total"] = sum(p.numel() for p in self.parameters())
|
| 325 |
+
counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 326 |
+
return counts
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# Training
|
| 330 |
+
|
| 331 |
+
class BertTrainer:
|
| 332 |
+
"""Trainer supporting single-task and multi-task BERT training."""
|
| 333 |
+
def __init__(
|
| 334 |
+
self,
|
| 335 |
+
model: BertBaseline,
|
| 336 |
+
config: BertBaselineConfig,
|
| 337 |
+
train_loaders: Dict[str, DataLoader],
|
| 338 |
+
val_loaders: Dict[str, DataLoader],
|
| 339 |
+
device: torch.device,
|
| 340 |
+
mode: str,
|
| 341 |
+
):
|
| 342 |
+
self.model = model
|
| 343 |
+
self.config = config
|
| 344 |
+
self.train_loaders = train_loaders
|
| 345 |
+
self.val_loaders = val_loaders
|
| 346 |
+
self.device = device
|
| 347 |
+
self.mode = mode
|
| 348 |
+
|
| 349 |
+
# Optimizer
|
| 350 |
+
self.optimizer = AdamW(
|
| 351 |
+
[p for p in model.parameters() if p.requires_grad],
|
| 352 |
+
lr=config.lr,
|
| 353 |
+
weight_decay=config.weight_decay,
|
| 354 |
+
betas=config.betas,
|
| 355 |
+
eps=config.eps,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Calculate total training steps
|
| 359 |
+
if len(train_loaders) > 1:
|
| 360 |
+
# Multi-task: use temperature-sampled steps
|
| 361 |
+
sizes = {k: len(v) for k, v in train_loaders.items()}
|
| 362 |
+
total_batches = sum(sizes.values())
|
| 363 |
+
else:
|
| 364 |
+
total_batches = sum(len(v) for v in train_loaders.values())
|
| 365 |
+
self.steps_per_epoch = total_batches // config.gradient_accumulation_steps
|
| 366 |
+
self.total_steps = self.steps_per_epoch * config.max_epochs
|
| 367 |
+
|
| 368 |
+
# LR scheduler: linear warmup + cosine decay (matching LexiMind)
|
| 369 |
+
warmup_scheduler = LinearLR(
|
| 370 |
+
self.optimizer,
|
| 371 |
+
start_factor=1e-8 / config.lr,
|
| 372 |
+
end_factor=1.0,
|
| 373 |
+
total_iters=config.warmup_steps,
|
| 374 |
+
)
|
| 375 |
+
cosine_scheduler = CosineAnnealingLR(
|
| 376 |
+
self.optimizer,
|
| 377 |
+
T_max=max(self.total_steps - config.warmup_steps, 1),
|
| 378 |
+
eta_min=config.lr * 0.1, # Decay to 10% of peak (matching LexiMind)
|
| 379 |
+
)
|
| 380 |
+
self.scheduler = SequentialLR(
|
| 381 |
+
self.optimizer,
|
| 382 |
+
schedulers=[warmup_scheduler, cosine_scheduler],
|
| 383 |
+
milestones=[config.warmup_steps],
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# Mixed precision
|
| 387 |
+
self.scaler = GradScaler(enabled=config.use_amp)
|
| 388 |
+
|
| 389 |
+
# Loss functions
|
| 390 |
+
self.emotion_loss_fn = nn.BCEWithLogitsLoss()
|
| 391 |
+
self.topic_loss_fn = nn.CrossEntropyLoss()
|
| 392 |
+
|
| 393 |
+
# Tracking
|
| 394 |
+
self.global_step = 0
|
| 395 |
+
self.best_metric = -float("inf")
|
| 396 |
+
self.patience_counter = 0
|
| 397 |
+
self.training_history: List[Dict[str, Any]] = []
|
| 398 |
+
|
| 399 |
+
def _compute_loss(self, task: str, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
| 400 |
+
if task == "emotion":
|
| 401 |
+
return self.emotion_loss_fn(logits, labels)
|
| 402 |
+
else:
|
| 403 |
+
return self.topic_loss_fn(logits, labels)
|
| 404 |
+
|
| 405 |
+
def _get_task_weight(self, task: str) -> float:
|
| 406 |
+
if self.mode != "multitask":
|
| 407 |
+
return 1.0
|
| 408 |
+
if task == "topic":
|
| 409 |
+
return self.config.topic_weight
|
| 410 |
+
return self.config.emotion_weight
|
| 411 |
+
|
| 412 |
+
def _make_multitask_iterator(self):
|
| 413 |
+
"""Temperature-based task sampling (matching LexiMind)."""
|
| 414 |
+
sizes = {k: len(v.dataset) for k, v in self.train_loaders.items()}
|
| 415 |
+
alpha = self.config.task_sampling_alpha
|
| 416 |
+
|
| 417 |
+
# Compute sampling probabilities
|
| 418 |
+
raw = {k: s ** (1.0 / alpha) for k, s in sizes.items()}
|
| 419 |
+
total = sum(raw.values())
|
| 420 |
+
probs = {k: v / total for k, v in raw.items()}
|
| 421 |
+
|
| 422 |
+
# Create iterators
|
| 423 |
+
iters = {k: iter(v) for k, v in self.train_loaders.items()}
|
| 424 |
+
tasks = list(probs.keys())
|
| 425 |
+
weights = [probs[t] for t in tasks]
|
| 426 |
+
|
| 427 |
+
while True:
|
| 428 |
+
task = random.choices(tasks, weights=weights, k=1)[0]
|
| 429 |
+
try:
|
| 430 |
+
batch = next(iters[task])
|
| 431 |
+
except StopIteration:
|
| 432 |
+
iters[task] = iter(self.train_loaders[task])
|
| 433 |
+
batch = next(iters[task])
|
| 434 |
+
yield task, batch
|
| 435 |
+
|
| 436 |
+
def train_epoch(self, epoch: int) -> Dict[str, float]:
|
| 437 |
+
"""Train one epoch."""
|
| 438 |
+
self.model.train()
|
| 439 |
+
self.optimizer.zero_grad()
|
| 440 |
+
|
| 441 |
+
epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
|
| 442 |
+
epoch_metrics: Dict[str, List[float]] = {}
|
| 443 |
+
|
| 444 |
+
if len(self.train_loaders) > 1:
|
| 445 |
+
# Multi-task: temperature sampling
|
| 446 |
+
iterator = self._make_multitask_iterator()
|
| 447 |
+
total_batches = sum(len(v) for v in self.train_loaders.values())
|
| 448 |
+
else:
|
| 449 |
+
# Single-task: iterate normally
|
| 450 |
+
task_name = list(self.train_loaders.keys())[0]
|
| 451 |
+
iterator = ((task_name, batch) for batch in self.train_loaders[task_name])
|
| 452 |
+
total_batches = len(self.train_loaders[task_name])
|
| 453 |
+
|
| 454 |
+
pbar = tqdm(total=total_batches, desc=f"Epoch {epoch + 1}/{self.config.max_epochs}")
|
| 455 |
+
|
| 456 |
+
for step_in_epoch in range(total_batches):
|
| 457 |
+
task, batch = next(iterator)
|
| 458 |
+
|
| 459 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 460 |
+
attention_mask = batch["attention_mask"].to(self.device)
|
| 461 |
+
labels = batch["labels"].to(self.device)
|
| 462 |
+
|
| 463 |
+
# Forward pass with AMP
|
| 464 |
+
with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
|
| 465 |
+
logits = self.model(task, input_ids, attention_mask)
|
| 466 |
+
loss = self._compute_loss(task, logits, labels)
|
| 467 |
+
loss = loss * self._get_task_weight(task)
|
| 468 |
+
loss = loss / self.config.gradient_accumulation_steps
|
| 469 |
+
|
| 470 |
+
# Backward
|
| 471 |
+
self.scaler.scale(loss).backward()
|
| 472 |
+
epoch_losses[task].append(loss.item() * self.config.gradient_accumulation_steps)
|
| 473 |
+
|
| 474 |
+
# Optimizer step (every N accumulation steps)
|
| 475 |
+
if (step_in_epoch + 1) % self.config.gradient_accumulation_steps == 0:
|
| 476 |
+
self.scaler.unscale_(self.optimizer)
|
| 477 |
+
torch.nn.utils.clip_grad_norm_(
|
| 478 |
+
self.model.parameters(), self.config.gradient_clip_norm
|
| 479 |
+
)
|
| 480 |
+
self.scaler.step(self.optimizer)
|
| 481 |
+
self.scaler.update()
|
| 482 |
+
self.optimizer.zero_grad()
|
| 483 |
+
self.scheduler.step()
|
| 484 |
+
self.global_step += 1
|
| 485 |
+
|
| 486 |
+
pbar.set_postfix(
|
| 487 |
+
{f"{task}_loss": f"{epoch_losses[task][-1]:.4f}", "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"}
|
| 488 |
+
)
|
| 489 |
+
pbar.update(1)
|
| 490 |
+
|
| 491 |
+
pbar.close()
|
| 492 |
+
|
| 493 |
+
# Aggregate
|
| 494 |
+
results = {}
|
| 495 |
+
for task, losses in epoch_losses.items():
|
| 496 |
+
if losses:
|
| 497 |
+
results[f"train_{task}_loss"] = sum(losses) / len(losses)
|
| 498 |
+
return results
|
| 499 |
+
|
| 500 |
+
@torch.no_grad()
|
| 501 |
+
def validate(self) -> Dict[str, Any]:
|
| 502 |
+
"""Run validation across all tasks."""
|
| 503 |
+
self.model.eval()
|
| 504 |
+
results: Dict[str, Any] = {}
|
| 505 |
+
|
| 506 |
+
for task, loader in self.val_loaders.items():
|
| 507 |
+
all_logits = []
|
| 508 |
+
all_labels = []
|
| 509 |
+
total_loss = 0.0
|
| 510 |
+
n_batches = 0
|
| 511 |
+
|
| 512 |
+
for batch in loader:
|
| 513 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 514 |
+
attention_mask = batch["attention_mask"].to(self.device)
|
| 515 |
+
labels = batch["labels"].to(self.device)
|
| 516 |
+
|
| 517 |
+
with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
|
| 518 |
+
logits = self.model(task, input_ids, attention_mask)
|
| 519 |
+
loss = self._compute_loss(task, logits, labels)
|
| 520 |
+
|
| 521 |
+
total_loss += loss.item()
|
| 522 |
+
n_batches += 1
|
| 523 |
+
all_logits.append(logits.float().cpu())
|
| 524 |
+
all_labels.append(labels.float().cpu())
|
| 525 |
+
|
| 526 |
+
all_logits_t = torch.cat(all_logits, dim=0)
|
| 527 |
+
all_labels_t = torch.cat(all_labels, dim=0)
|
| 528 |
+
results[f"val_{task}_loss"] = total_loss / max(n_batches, 1)
|
| 529 |
+
|
| 530 |
+
if task == "emotion":
|
| 531 |
+
preds = (torch.sigmoid(all_logits_t) > self.config.emotion_threshold).int()
|
| 532 |
+
targets = all_labels_t.int()
|
| 533 |
+
results["val_emotion_sample_f1"] = multilabel_f1(preds, targets)
|
| 534 |
+
results["val_emotion_macro_f1"] = multilabel_macro_f1(preds, targets)
|
| 535 |
+
results["val_emotion_micro_f1"] = multilabel_micro_f1(preds, targets)
|
| 536 |
+
# Store raw logits for threshold tuning later
|
| 537 |
+
results["_emotion_logits"] = all_logits_t
|
| 538 |
+
results["_emotion_labels"] = all_labels_t
|
| 539 |
+
|
| 540 |
+
elif task == "topic":
|
| 541 |
+
preds = all_logits_t.argmax(dim=1).numpy()
|
| 542 |
+
targets = all_labels_t.long().numpy()
|
| 543 |
+
results["val_topic_accuracy"] = float(accuracy_score(targets, preds))
|
| 544 |
+
results["val_topic_macro_f1"] = float(
|
| 545 |
+
f1_score(targets, preds, average="macro", zero_division=0)
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
# Combined metric for early stopping / checkpointing
|
| 549 |
+
metric_parts = []
|
| 550 |
+
if "val_emotion_sample_f1" in results:
|
| 551 |
+
metric_parts.append(results["val_emotion_sample_f1"])
|
| 552 |
+
if "val_topic_accuracy" in results:
|
| 553 |
+
metric_parts.append(results["val_topic_accuracy"])
|
| 554 |
+
results["val_combined_metric"] = sum(metric_parts) / max(len(metric_parts), 1)
|
| 555 |
+
|
| 556 |
+
return results
|
| 557 |
+
|
| 558 |
+
def save_checkpoint(self, path: Path, epoch: int, metrics: Dict[str, Any]) -> None:
|
| 559 |
+
"""Save model checkpoint."""
|
| 560 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 561 |
+
# Filter out tensors from metrics
|
| 562 |
+
clean_metrics = {k: v for k, v in metrics.items() if not k.startswith("_")}
|
| 563 |
+
torch.save(
|
| 564 |
+
{
|
| 565 |
+
"epoch": epoch,
|
| 566 |
+
"model_state_dict": self.model.state_dict(),
|
| 567 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 568 |
+
"scheduler_state_dict": self.scheduler.state_dict(),
|
| 569 |
+
"metrics": clean_metrics,
|
| 570 |
+
"config": {
|
| 571 |
+
"mode": self.mode,
|
| 572 |
+
"tasks": self.model.tasks,
|
| 573 |
+
"model_name": self.config.model_name,
|
| 574 |
+
},
|
| 575 |
+
},
|
| 576 |
+
path,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
def train(self) -> Dict[str, Any]:
|
| 580 |
+
"""Full training loop."""
|
| 581 |
+
print(f"\n{'=' * 60}")
|
| 582 |
+
print(f"Training BERT Baseline — Mode: {self.mode}")
|
| 583 |
+
print(f"{'=' * 60}")
|
| 584 |
+
|
| 585 |
+
param_counts = self.model.param_count()
|
| 586 |
+
print(f" Total parameters: {param_counts['total']:,}")
|
| 587 |
+
print(f" Trainable parameters: {param_counts['trainable']:,}")
|
| 588 |
+
for name, count in param_counts.items():
|
| 589 |
+
if name.startswith("head_"):
|
| 590 |
+
print(f" {name}: {count:,}")
|
| 591 |
+
print(f" Steps/epoch: {self.steps_per_epoch}")
|
| 592 |
+
print(f" Total steps: {self.total_steps}")
|
| 593 |
+
print()
|
| 594 |
+
|
| 595 |
+
all_results: Dict[str, Any] = {"mode": self.mode, "epochs": []}
|
| 596 |
+
start_time = time.time()
|
| 597 |
+
|
| 598 |
+
for epoch in range(self.config.max_epochs):
|
| 599 |
+
epoch_start = time.time()
|
| 600 |
+
|
| 601 |
+
# Train
|
| 602 |
+
train_metrics = self.train_epoch(epoch)
|
| 603 |
+
|
| 604 |
+
# Validate
|
| 605 |
+
val_metrics = self.validate()
|
| 606 |
+
|
| 607 |
+
epoch_time = time.time() - epoch_start
|
| 608 |
+
|
| 609 |
+
# Log
|
| 610 |
+
epoch_result = {
|
| 611 |
+
"epoch": epoch + 1,
|
| 612 |
+
"time_seconds": epoch_time,
|
| 613 |
+
**train_metrics,
|
| 614 |
+
**{k: v for k, v in val_metrics.items() if not k.startswith("_")},
|
| 615 |
+
}
|
| 616 |
+
all_results["epochs"].append(epoch_result)
|
| 617 |
+
self.training_history.append(epoch_result)
|
| 618 |
+
|
| 619 |
+
# Print summary
|
| 620 |
+
print(f"\n Epoch {epoch + 1} ({epoch_time:.0f}s):")
|
| 621 |
+
for k, v in sorted(epoch_result.items()):
|
| 622 |
+
if k not in ("epoch", "time_seconds") and isinstance(v, float):
|
| 623 |
+
print(f" {k}: {v:.4f}")
|
| 624 |
+
|
| 625 |
+
# Checkpointing
|
| 626 |
+
combined = val_metrics["val_combined_metric"]
|
| 627 |
+
if combined > self.best_metric:
|
| 628 |
+
self.best_metric = combined
|
| 629 |
+
self.patience_counter = 0
|
| 630 |
+
self.save_checkpoint(
|
| 631 |
+
self.config.checkpoint_dir / self.mode / "best.pt",
|
| 632 |
+
epoch,
|
| 633 |
+
val_metrics,
|
| 634 |
+
)
|
| 635 |
+
print(f" New best model (combined metric: {combined:.4f})")
|
| 636 |
+
else:
|
| 637 |
+
self.patience_counter += 1
|
| 638 |
+
print(
|
| 639 |
+
f" No improvement ({self.patience_counter}/{self.config.early_stopping_patience})"
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
# Always save epoch checkpoint
|
| 643 |
+
self.save_checkpoint(
|
| 644 |
+
self.config.checkpoint_dir / self.mode / f"epoch_{epoch + 1}.pt",
|
| 645 |
+
epoch,
|
| 646 |
+
val_metrics,
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
# Early stopping
|
| 650 |
+
if self.patience_counter >= self.config.early_stopping_patience:
|
| 651 |
+
print(f"\n Early stopping triggered at epoch {epoch + 1}")
|
| 652 |
+
all_results["early_stopped"] = True
|
| 653 |
+
all_results["best_epoch"] = epoch + 1 - self.config.early_stopping_patience
|
| 654 |
+
break
|
| 655 |
+
|
| 656 |
+
total_time = time.time() - start_time
|
| 657 |
+
all_results["total_time_seconds"] = total_time
|
| 658 |
+
all_results["total_time_human"] = f"{total_time / 3600:.1f}h"
|
| 659 |
+
if "early_stopped" not in all_results:
|
| 660 |
+
all_results["early_stopped"] = False
|
| 661 |
+
all_results["best_epoch"] = (
|
| 662 |
+
epoch + 1 - self.patience_counter
|
| 663 |
+
if self.patience_counter > 0
|
| 664 |
+
else epoch + 1
|
| 665 |
+
)
|
| 666 |
+
all_results["param_counts"] = param_counts
|
| 667 |
+
|
| 668 |
+
print(f"\n Training complete in {total_time / 3600:.1f}h")
|
| 669 |
+
print(f" Best combined metric: {self.best_metric:.4f}")
|
| 670 |
+
|
| 671 |
+
return all_results
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
# Evaluation
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def evaluate_bert_model(
|
| 678 |
+
model: BertBaseline,
|
| 679 |
+
val_loaders: Dict[str, DataLoader],
|
| 680 |
+
device: torch.device,
|
| 681 |
+
config: BertBaselineConfig,
|
| 682 |
+
emotion_classes: Optional[List[str]] = None,
|
| 683 |
+
topic_classes: Optional[List[str]] = None,
|
| 684 |
+
) -> Dict[str, Any]:
|
| 685 |
+
"""Full evaluation with the same metrics as LexiMind's evaluate.py."""
|
| 686 |
+
model.eval()
|
| 687 |
+
results: Dict[str, Any] = {}
|
| 688 |
+
|
| 689 |
+
with torch.no_grad():
|
| 690 |
+
for task, loader in val_loaders.items():
|
| 691 |
+
all_logits = []
|
| 692 |
+
all_labels = []
|
| 693 |
+
|
| 694 |
+
for batch in tqdm(loader, desc=f"Evaluating {task}"):
|
| 695 |
+
input_ids = batch["input_ids"].to(device)
|
| 696 |
+
attention_mask = batch["attention_mask"].to(device)
|
| 697 |
+
labels = batch["labels"].to(device)
|
| 698 |
+
|
| 699 |
+
with autocast(dtype=torch.bfloat16, enabled=config.use_amp):
|
| 700 |
+
logits = model(task, input_ids, attention_mask)
|
| 701 |
+
|
| 702 |
+
all_logits.append(logits.float().cpu())
|
| 703 |
+
all_labels.append(labels.float().cpu())
|
| 704 |
+
|
| 705 |
+
all_logits_t = torch.cat(all_logits, dim=0)
|
| 706 |
+
all_labels_t = torch.cat(all_labels, dim=0)
|
| 707 |
+
|
| 708 |
+
if task == "emotion":
|
| 709 |
+
# Default threshold
|
| 710 |
+
preds_default = (
|
| 711 |
+
(torch.sigmoid(all_logits_t) > config.emotion_threshold).int()
|
| 712 |
+
)
|
| 713 |
+
targets = all_labels_t.int()
|
| 714 |
+
|
| 715 |
+
results["emotion"] = {
|
| 716 |
+
"default_threshold": config.emotion_threshold,
|
| 717 |
+
"sample_avg_f1": multilabel_f1(preds_default, targets),
|
| 718 |
+
"macro_f1": multilabel_macro_f1(preds_default, targets),
|
| 719 |
+
"micro_f1": multilabel_micro_f1(preds_default, targets),
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
# Per-class metrics
|
| 723 |
+
if emotion_classes:
|
| 724 |
+
per_class = multilabel_per_class_metrics(
|
| 725 |
+
preds_default, targets, emotion_classes
|
| 726 |
+
)
|
| 727 |
+
results["emotion"]["per_class"] = per_class
|
| 728 |
+
|
| 729 |
+
# Threshold tuning
|
| 730 |
+
best_thresholds, tuned_macro = tune_per_class_thresholds(
|
| 731 |
+
all_logits_t, all_labels_t
|
| 732 |
+
)
|
| 733 |
+
tuned_preds = torch.zeros_like(all_logits_t)
|
| 734 |
+
probs = torch.sigmoid(all_logits_t)
|
| 735 |
+
for c in range(all_logits_t.shape[1]):
|
| 736 |
+
tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
|
| 737 |
+
tuned_preds = tuned_preds.int()
|
| 738 |
+
|
| 739 |
+
results["emotion"]["tuned_macro_f1"] = tuned_macro
|
| 740 |
+
results["emotion"]["tuned_sample_avg_f1"] = multilabel_f1(
|
| 741 |
+
tuned_preds, targets
|
| 742 |
+
)
|
| 743 |
+
results["emotion"]["tuned_micro_f1"] = multilabel_micro_f1(
|
| 744 |
+
tuned_preds, targets
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
# Bootstrap CI on sample-avg F1
|
| 748 |
+
per_sample_f1 = []
|
| 749 |
+
for i in range(preds_default.shape[0]):
|
| 750 |
+
p = preds_default[i].float()
|
| 751 |
+
g = targets[i].float()
|
| 752 |
+
tp = (p * g).sum()
|
| 753 |
+
prec = tp / p.sum().clamp(min=1)
|
| 754 |
+
rec = tp / g.sum().clamp(min=1)
|
| 755 |
+
f = (2 * prec * rec) / (prec + rec).clamp(min=1e-8)
|
| 756 |
+
per_sample_f1.append(f.item())
|
| 757 |
+
mean_f1, ci_low, ci_high = bootstrap_confidence_interval(per_sample_f1)
|
| 758 |
+
results["emotion"]["sample_avg_f1_ci"] = [ci_low, ci_high]
|
| 759 |
+
|
| 760 |
+
elif task == "topic":
|
| 761 |
+
preds = all_logits_t.argmax(dim=1).numpy()
|
| 762 |
+
targets = all_labels_t.long().numpy()
|
| 763 |
+
|
| 764 |
+
acc = float(accuracy_score(targets, preds))
|
| 765 |
+
macro_f1 = float(
|
| 766 |
+
f1_score(targets, preds, average="macro", zero_division=0)
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
results["topic"] = {
|
| 770 |
+
"accuracy": acc,
|
| 771 |
+
"macro_f1": macro_f1,
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
# Per-class metrics
|
| 775 |
+
if topic_classes:
|
| 776 |
+
report = classification_report(
|
| 777 |
+
targets,
|
| 778 |
+
preds,
|
| 779 |
+
target_names=topic_classes,
|
| 780 |
+
output_dict=True,
|
| 781 |
+
zero_division=0,
|
| 782 |
+
)
|
| 783 |
+
results["topic"]["per_class"] = {
|
| 784 |
+
name: {
|
| 785 |
+
"precision": report[name]["precision"],
|
| 786 |
+
"recall": report[name]["recall"],
|
| 787 |
+
"f1": report[name]["f1-score"],
|
| 788 |
+
"support": report[name]["support"],
|
| 789 |
+
}
|
| 790 |
+
for name in topic_classes
|
| 791 |
+
if name in report
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
# Bootstrap CI on accuracy
|
| 795 |
+
per_sample_correct = (preds == targets).astype(float).tolist()
|
| 796 |
+
mean_acc, ci_low, ci_high = bootstrap_confidence_interval(per_sample_correct)
|
| 797 |
+
results["topic"]["accuracy_ci"] = [ci_low, ci_high]
|
| 798 |
+
|
| 799 |
+
return results
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
# Main Pipeline
|
| 803 |
+
|
| 804 |
+
def set_seed(seed: int) -> None:
|
| 805 |
+
random.seed(seed)
|
| 806 |
+
np.random.seed(seed)
|
| 807 |
+
torch.manual_seed(seed)
|
| 808 |
+
if torch.cuda.is_available():
|
| 809 |
+
torch.cuda.manual_seed_all(seed)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def load_data(config: BertBaselineConfig):
|
| 813 |
+
"""Load all datasets and create label encoders."""
|
| 814 |
+
data_dir = config.data_dir
|
| 815 |
+
|
| 816 |
+
# Load emotion data
|
| 817 |
+
emo_train = load_emotion_jsonl(str(data_dir / "emotion" / "train.jsonl"))
|
| 818 |
+
emo_val_path = data_dir / "emotion" / "validation.jsonl"
|
| 819 |
+
if not emo_val_path.exists():
|
| 820 |
+
emo_val_path = data_dir / "emotion" / "val.jsonl"
|
| 821 |
+
emo_val = load_emotion_jsonl(str(emo_val_path))
|
| 822 |
+
|
| 823 |
+
# Load topic data
|
| 824 |
+
top_train = load_topic_jsonl(str(data_dir / "topic" / "train.jsonl"))
|
| 825 |
+
top_val_path = data_dir / "topic" / "validation.jsonl"
|
| 826 |
+
if not top_val_path.exists():
|
| 827 |
+
top_val_path = data_dir / "topic" / "val.jsonl"
|
| 828 |
+
top_val = load_topic_jsonl(str(top_val_path))
|
| 829 |
+
|
| 830 |
+
# Fit label encoders on training data (same as LexiMind)
|
| 831 |
+
binarizer = MultiLabelBinarizer()
|
| 832 |
+
binarizer.fit([ex.emotions for ex in emo_train])
|
| 833 |
+
|
| 834 |
+
label_encoder = LabelEncoder()
|
| 835 |
+
label_encoder.fit([ex.topic for ex in top_train])
|
| 836 |
+
|
| 837 |
+
print(f" Emotion: {len(emo_train)} train, {len(emo_val)} val, {len(binarizer.classes_)} classes")
|
| 838 |
+
print(f" Topic: {len(top_train)} train, {len(top_val)} val, {len(label_encoder.classes_)} classes")
|
| 839 |
+
print(f" Emotion classes: {list(binarizer.classes_)[:5]}...")
|
| 840 |
+
print(f" Topic classes: {list(label_encoder.classes_)}")
|
| 841 |
+
|
| 842 |
+
return {
|
| 843 |
+
"emotion_train": emo_train,
|
| 844 |
+
"emotion_val": emo_val,
|
| 845 |
+
"topic_train": top_train,
|
| 846 |
+
"topic_val": top_val,
|
| 847 |
+
"binarizer": binarizer,
|
| 848 |
+
"label_encoder": label_encoder,
|
| 849 |
+
}
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
def run_experiment(mode: str, config: BertBaselineConfig) -> Dict[str, Any]:
|
| 853 |
+
"""Run a single experiment (single-topic, single-emotion, or multitask)."""
|
| 854 |
+
print(f"\n{'═' * 60}")
|
| 855 |
+
print(f" BERT BASELINE EXPERIMENT: {mode.upper()}")
|
| 856 |
+
print(f"{'═' * 60}")
|
| 857 |
+
|
| 858 |
+
set_seed(config.seed)
|
| 859 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 860 |
+
print(f" Device: {device}")
|
| 861 |
+
if torch.cuda.is_available():
|
| 862 |
+
print(f" GPU: {torch.cuda.get_device_name()}")
|
| 863 |
+
print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
| 864 |
+
|
| 865 |
+
# CUDA optimizations
|
| 866 |
+
if torch.cuda.is_available():
|
| 867 |
+
torch.backends.cudnn.benchmark = True
|
| 868 |
+
if hasattr(torch.backends, "cuda"):
|
| 869 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 870 |
+
|
| 871 |
+
# Load tokenizer
|
| 872 |
+
print(f"\n Loading tokenizer: {config.model_name}")
|
| 873 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
| 874 |
+
|
| 875 |
+
# Load data
|
| 876 |
+
print(" Loading datasets...")
|
| 877 |
+
data = load_data(config)
|
| 878 |
+
|
| 879 |
+
# Determine tasks for this mode
|
| 880 |
+
if mode == "single-topic":
|
| 881 |
+
tasks = ["topic"]
|
| 882 |
+
elif mode == "single-emotion":
|
| 883 |
+
tasks = ["emotion"]
|
| 884 |
+
else:
|
| 885 |
+
tasks = ["emotion", "topic"]
|
| 886 |
+
|
| 887 |
+
# Create datasets
|
| 888 |
+
train_loaders: Dict[str, DataLoader] = {}
|
| 889 |
+
val_loaders: Dict[str, DataLoader] = {}
|
| 890 |
+
|
| 891 |
+
if "emotion" in tasks:
|
| 892 |
+
emo_train_ds = BertEmotionDataset(
|
| 893 |
+
data["emotion_train"], tokenizer, data["binarizer"], config.max_length
|
| 894 |
+
)
|
| 895 |
+
emo_val_ds = BertEmotionDataset(
|
| 896 |
+
data["emotion_val"], tokenizer, data["binarizer"], config.max_length
|
| 897 |
+
)
|
| 898 |
+
train_loaders["emotion"] = DataLoader(
|
| 899 |
+
emo_train_ds,
|
| 900 |
+
batch_size=config.batch_size,
|
| 901 |
+
shuffle=True,
|
| 902 |
+
num_workers=4,
|
| 903 |
+
pin_memory=True,
|
| 904 |
+
persistent_workers=True,
|
| 905 |
+
)
|
| 906 |
+
val_loaders["emotion"] = DataLoader(
|
| 907 |
+
emo_val_ds,
|
| 908 |
+
batch_size=config.batch_size * 2,
|
| 909 |
+
shuffle=False,
|
| 910 |
+
num_workers=4,
|
| 911 |
+
pin_memory=True,
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
if "topic" in tasks:
|
| 915 |
+
top_train_ds = BertTopicDataset(
|
| 916 |
+
data["topic_train"], tokenizer, data["label_encoder"], config.max_length
|
| 917 |
+
)
|
| 918 |
+
top_val_ds = BertTopicDataset(
|
| 919 |
+
data["topic_val"], tokenizer, data["label_encoder"], config.max_length
|
| 920 |
+
)
|
| 921 |
+
train_loaders["topic"] = DataLoader(
|
| 922 |
+
top_train_ds,
|
| 923 |
+
batch_size=config.batch_size,
|
| 924 |
+
shuffle=True,
|
| 925 |
+
num_workers=4,
|
| 926 |
+
pin_memory=True,
|
| 927 |
+
persistent_workers=True,
|
| 928 |
+
)
|
| 929 |
+
val_loaders["topic"] = DataLoader(
|
| 930 |
+
top_val_ds,
|
| 931 |
+
batch_size=config.batch_size * 2,
|
| 932 |
+
shuffle=False,
|
| 933 |
+
num_workers=4,
|
| 934 |
+
pin_memory=True,
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
# Create model
|
| 938 |
+
print(f"\n Creating model with tasks: {tasks}")
|
| 939 |
+
model = BertBaseline(
|
| 940 |
+
model_name=config.model_name,
|
| 941 |
+
num_emotions=len(data["binarizer"].classes_),
|
| 942 |
+
num_topics=len(data["label_encoder"].classes_),
|
| 943 |
+
tasks=tasks,
|
| 944 |
+
freeze_layers=config.freeze_layers,
|
| 945 |
+
).to(device)
|
| 946 |
+
|
| 947 |
+
# Train
|
| 948 |
+
trainer = BertTrainer(model, config, train_loaders, val_loaders, device, mode)
|
| 949 |
+
training_results = trainer.train()
|
| 950 |
+
|
| 951 |
+
# Load best checkpoint for final evaluation
|
| 952 |
+
best_path = config.checkpoint_dir / mode / "best.pt"
|
| 953 |
+
if best_path.exists():
|
| 954 |
+
print(f"\n Loading best checkpoint for final evaluation...")
|
| 955 |
+
checkpoint = torch.load(best_path, map_location=device, weights_only=False)
|
| 956 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 957 |
+
|
| 958 |
+
# Full evaluation
|
| 959 |
+
print(f"\n Running final evaluation...")
|
| 960 |
+
eval_results = evaluate_bert_model(
|
| 961 |
+
model,
|
| 962 |
+
val_loaders,
|
| 963 |
+
device,
|
| 964 |
+
config,
|
| 965 |
+
emotion_classes=list(data["binarizer"].classes_) if "emotion" in tasks else None,
|
| 966 |
+
topic_classes=list(data["label_encoder"].classes_) if "topic" in tasks else None,
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
# Combine results
|
| 970 |
+
final_results = {
|
| 971 |
+
"mode": mode,
|
| 972 |
+
"model": config.model_name,
|
| 973 |
+
"tasks": tasks,
|
| 974 |
+
"training": training_results,
|
| 975 |
+
"evaluation": eval_results,
|
| 976 |
+
}
|
| 977 |
+
|
| 978 |
+
# Save results
|
| 979 |
+
output_path = config.output_dir / f"{mode}_results.json"
|
| 980 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 981 |
+
|
| 982 |
+
# Remove non-serializable fields
|
| 983 |
+
def make_serializable(obj):
|
| 984 |
+
if isinstance(obj, dict):
|
| 985 |
+
return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
|
| 986 |
+
if isinstance(obj, list):
|
| 987 |
+
return [make_serializable(item) for item in obj]
|
| 988 |
+
if isinstance(obj, (np.integer, np.int64)):
|
| 989 |
+
return int(obj)
|
| 990 |
+
if isinstance(obj, (np.floating, np.float64)):
|
| 991 |
+
return float(obj)
|
| 992 |
+
if isinstance(obj, np.ndarray):
|
| 993 |
+
return obj.tolist()
|
| 994 |
+
return obj
|
| 995 |
+
|
| 996 |
+
with open(output_path, "w") as f:
|
| 997 |
+
json.dump(make_serializable(final_results), f, indent=2)
|
| 998 |
+
print(f"\n Results saved to {output_path}")
|
| 999 |
+
|
| 1000 |
+
return final_results
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def print_comparison_summary(all_results: Dict[str, Dict[str, Any]]) -> None:
|
| 1004 |
+
"""Print a side-by-side comparison of all experiments."""
|
| 1005 |
+
print(f"\n{'═' * 70}")
|
| 1006 |
+
print(" BERT BASELINE COMPARISON SUMMARY")
|
| 1007 |
+
print(f"{'═' * 70}")
|
| 1008 |
+
|
| 1009 |
+
# Header
|
| 1010 |
+
modes = list(all_results.keys())
|
| 1011 |
+
header = f"{'Metric':<30}" + "".join(f"{m:>16}" for m in modes) + f"{'LexiMind':>16}"
|
| 1012 |
+
print(f"\n {header}")
|
| 1013 |
+
print(f" {'─' * len(header)}")
|
| 1014 |
+
|
| 1015 |
+
# LexiMind reference values
|
| 1016 |
+
lexmind = {
|
| 1017 |
+
"topic_accuracy": 0.8571,
|
| 1018 |
+
"topic_macro_f1": 0.8539,
|
| 1019 |
+
"emotion_sample_f1": 0.3523,
|
| 1020 |
+
"emotion_macro_f1": 0.1432,
|
| 1021 |
+
"emotion_micro_f1": 0.4430,
|
| 1022 |
+
"emotion_tuned_macro_f1": 0.2936,
|
| 1023 |
+
}
|
| 1024 |
+
|
| 1025 |
+
# Topic metrics
|
| 1026 |
+
print(f"\n {'Topic Classification':}")
|
| 1027 |
+
for metric_name, display_name in [
|
| 1028 |
+
("accuracy", "Accuracy"),
|
| 1029 |
+
("macro_f1", "Macro F1"),
|
| 1030 |
+
]:
|
| 1031 |
+
row = f" {display_name:<30}"
|
| 1032 |
+
for mode in modes:
|
| 1033 |
+
eval_data = all_results[mode].get("evaluation", {})
|
| 1034 |
+
topic = eval_data.get("topic", {})
|
| 1035 |
+
val = topic.get(metric_name, None)
|
| 1036 |
+
row += f"{val:>16.4f}" if val is not None else f"{'—':>16}"
|
| 1037 |
+
lm_key = f"topic_{metric_name}"
|
| 1038 |
+
row += f"{lexmind.get(lm_key, 0):>16.4f}"
|
| 1039 |
+
print(row)
|
| 1040 |
+
|
| 1041 |
+
# Emotion metrics
|
| 1042 |
+
print(f"\n {'Emotion Detection':}")
|
| 1043 |
+
for metric_name, display_name in [
|
| 1044 |
+
("sample_avg_f1", "Sample-avg F1 (τ=0.3)"),
|
| 1045 |
+
("macro_f1", "Macro F1 (τ=0.3)"),
|
| 1046 |
+
("micro_f1", "Micro F1 (τ=0.3)"),
|
| 1047 |
+
("tuned_macro_f1", "Tuned Macro F1"),
|
| 1048 |
+
("tuned_sample_avg_f1", "Tuned Sample-avg F1"),
|
| 1049 |
+
]:
|
| 1050 |
+
row = f" {display_name:<30}"
|
| 1051 |
+
for mode in modes:
|
| 1052 |
+
eval_data = all_results[mode].get("evaluation", {})
|
| 1053 |
+
emo = eval_data.get("emotion", {})
|
| 1054 |
+
val = emo.get(metric_name, None)
|
| 1055 |
+
row += f"{val:>16.4f}" if val is not None else f"{'—':>16}"
|
| 1056 |
+
lm_key = f"emotion_{metric_name}"
|
| 1057 |
+
row += f"{lexmind.get(lm_key, 0):>16.4f}"
|
| 1058 |
+
print(row)
|
| 1059 |
+
|
| 1060 |
+
# Training time
|
| 1061 |
+
print(f"\n {'Training Time':}")
|
| 1062 |
+
row = f" {'Hours':<30}"
|
| 1063 |
+
for mode in modes:
|
| 1064 |
+
t = all_results[mode].get("training", {}).get("total_time_seconds", 0) / 3600
|
| 1065 |
+
row += f"{t:>15.1f}h"
|
| 1066 |
+
row += f"{'~9.0h':>16}"
|
| 1067 |
+
print(row)
|
| 1068 |
+
|
| 1069 |
+
print(f"\n{'═' * 70}\n")
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
def main():
|
| 1073 |
+
parser = argparse.ArgumentParser(description="BERT Baseline Training for LexiMind")
|
| 1074 |
+
parser.add_argument(
|
| 1075 |
+
"--mode",
|
| 1076 |
+
type=str,
|
| 1077 |
+
required=True,
|
| 1078 |
+
choices=["single-topic", "single-emotion", "multitask", "all"],
|
| 1079 |
+
help="Training mode",
|
| 1080 |
+
)
|
| 1081 |
+
parser.add_argument("--epochs", type=int, default=None, help="Override max epochs")
|
| 1082 |
+
parser.add_argument("--lr", type=float, default=None, help="Override learning rate")
|
| 1083 |
+
parser.add_argument("--batch-size", type=int, default=None, help="Override batch size")
|
| 1084 |
+
parser.add_argument(
|
| 1085 |
+
"--model", type=str, default="bert-base-uncased", help="HuggingFace model name"
|
| 1086 |
+
)
|
| 1087 |
+
args = parser.parse_args()
|
| 1088 |
+
|
| 1089 |
+
config = BertBaselineConfig()
|
| 1090 |
+
config.model_name = args.model
|
| 1091 |
+
if args.epochs is not None:
|
| 1092 |
+
config.max_epochs = args.epochs
|
| 1093 |
+
if args.lr is not None:
|
| 1094 |
+
config.lr = args.lr
|
| 1095 |
+
if args.batch_size is not None:
|
| 1096 |
+
config.batch_size = args.batch_size
|
| 1097 |
+
|
| 1098 |
+
if args.mode == "all":
|
| 1099 |
+
modes = ["single-topic", "single-emotion", "multitask"]
|
| 1100 |
+
else:
|
| 1101 |
+
modes = [args.mode]
|
| 1102 |
+
|
| 1103 |
+
all_results: Dict[str, Dict[str, Any]] = {}
|
| 1104 |
+
for mode in modes:
|
| 1105 |
+
results = run_experiment(mode, config)
|
| 1106 |
+
all_results[mode] = results
|
| 1107 |
+
|
| 1108 |
+
# Clear GPU memory between experiments
|
| 1109 |
+
if torch.cuda.is_available():
|
| 1110 |
+
torch.cuda.empty_cache()
|
| 1111 |
+
|
| 1112 |
+
# Save combined results
|
| 1113 |
+
if len(all_results) > 1:
|
| 1114 |
+
combined_path = config.output_dir / "combined_results.json"
|
| 1115 |
+
|
| 1116 |
+
def make_serializable(obj):
|
| 1117 |
+
if isinstance(obj, dict):
|
| 1118 |
+
return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
|
| 1119 |
+
if isinstance(obj, list):
|
| 1120 |
+
return [make_serializable(item) for item in obj]
|
| 1121 |
+
if isinstance(obj, (np.integer, np.int64)):
|
| 1122 |
+
return int(obj)
|
| 1123 |
+
if isinstance(obj, (np.floating, np.float64)):
|
| 1124 |
+
return float(obj)
|
| 1125 |
+
if isinstance(obj, np.ndarray):
|
| 1126 |
+
return obj.tolist()
|
| 1127 |
+
return obj
|
| 1128 |
+
|
| 1129 |
+
with open(combined_path, "w") as f:
|
| 1130 |
+
json.dump(make_serializable(all_results), f, indent=2)
|
| 1131 |
+
print(f" Combined results saved to {combined_path}")
|
| 1132 |
+
|
| 1133 |
+
print_comparison_summary(all_results)
|
| 1134 |
+
|
| 1135 |
+
|
| 1136 |
+
if __name__ == "__main__":
|
| 1137 |
+
main()
|