OliverPerrin committed
Commit aca8362 · 1 Parent(s): a6468cc

Added BERT baseline comparison training script and ignored unnecessary files

.gitignore CHANGED
@@ -51,6 +51,11 @@ results/
  .ipynb_checkpoints/
  *.ipynb

+ # Docs/Paper
+ docs/paper.*
+ docs/paper_old.*
+ docs/research_paper.*
+
  # OS - Windows specific
  .DS_Store
  Thumbs.db
@@ -65,6 +70,10 @@ ehthumbs_vista.db
  configs/local/*.png
  *.pt

+ # Environment variables
+ .env
+ .env.*
+
  # Backup/private files
- scripts/demo_gradio_old.py
  mlruns.db
+ kaggle.json
docs/paper.tex DELETED
@@ -1,1214 +0,0 @@
- % LexiMind: A Hybrid Transformer Architecture for Multi-Task NLP
- % IEEE Conference Style Paper
- % Author: Oliver Perrin
-
- \documentclass[conference]{IEEEtran}
- \IEEEoverridecommandlockouts
-
- % Essential packages
- \usepackage{cite}
- \usepackage{amsmath,amssymb,amsfonts}
- \usepackage{algorithmic}
- \usepackage{graphicx}
- \usepackage{textcomp}
- \usepackage{xcolor}
- \usepackage{hyperref}
- \usepackage{listings}
- \usepackage{booktabs}
- \usepackage{multirow}
- \usepackage{array}
- \usepackage{caption}
-
- % TikZ for diagrams
- \usepackage{tikz}
- \usetikzlibrary{shapes.geometric, arrows, positioning, fit, calc, backgrounds, decorations.pathreplacing}
-
- % Code listings style
- \lstset{
-     basicstyle=\ttfamily\small,
-     breaklines=true,
-     frame=single,
-     language=Python,
-     keywordstyle=\color{blue},
-     commentstyle=\color{green!50!black},
-     stringstyle=\color{red!60!black},
-     showstringspaces=false
- }
-
- % Hyperref setup
- \hypersetup{
-     colorlinks=true,
-     linkcolor=blue,
-     citecolor=blue,
-     urlcolor=blue
- }
-
- \def\BibTeX{{\rm B\kern-.05em{\sc i\kern-.025em b}\kern-.08em
- T\kern-.1667em\lower.7ex\hbox{E}\kern-.125emX}}
-
- \begin{document}
-
- \title{LexiMind: A Hybrid Transformer Architecture\\for Multi-Task Natural Language Processing}
-
- \author{\IEEEauthorblockN{Oliver Perrin}\\
- \IEEEauthorblockA{Department of Computer Science\\
- Appalachian State University\\
- Bachelor of Science in Computer Science\\
- Email: perrinot@appstate.edu}}
-
- \maketitle
-
- \begin{abstract}
- This paper presents LexiMind, a multi-task Natural Language Processing (NLP) system that combines a custom-built Transformer architecture with pre-trained weights from Google's FLAN-T5 (Fine-tuned Language Net Text-to-Text Transfer Transformer). The system performs three fundamental NLP tasks simultaneously: abstractive text summarization, multi-label emotion classification, and single-label topic classification. Unlike news-focused models, LexiMind specializes in literary and academic content. For summarization, we train on 49,086 samples combining Goodreads book descriptions (back-cover style blurbs) with arXiv academic paper abstracts. Emotion classification uses 43,410 samples from GoEmotions \cite{demszky2020goemotions}, a dataset of 28 fine-grained emotion labels derived from Reddit comments. Topic classification spans 3,402 samples from 20 Newsgroups, Project Gutenberg literary texts, and scientific papers across 7 categories (Arts, Business, Fiction, History, Philosophy, Science, Technology). By implementing modern architectural innovations including Pre-Layer Normalization (Pre-LN) with Root Mean Square Layer Normalization (RMSNorm), T5-style relative position bias, FlashAttention via PyTorch 2.0's Scaled Dot-Product Attention (SDPA), gradient checkpointing, and torch.compile optimization, LexiMind achieves efficient training on consumer GPUs while maintaining strong performance. Our final model achieves a BERTScore F1 of 0.83 and ROUGE-1 of 0.31 for summarization, 85.2\% accuracy for topic classification, and F1 of 0.20 for 28-class multi-label emotion detection. The 272M-parameter architecture is constructed from first principles in a bottom-up fashion, with each component (attention mechanisms, feed-forward networks, encoder/decoder blocks) implemented as standalone modules. A factory pattern enables seamless weight transfer from FLAN-T5-base, allowing the system to leverage Google's pre-trained knowledge while maintaining full architectural transparency and customization capability.
- \end{abstract}
-
- \begin{IEEEkeywords}
- Transformer, Multi-Task Learning, Natural Language Processing, FLAN-T5, Transfer Learning, Text Summarization, Emotion Classification, Academic Papers, Literary Text
- \end{IEEEkeywords}
-
- %=============================================================================
- \section{Introduction}
- %=============================================================================
-
- The Transformer architecture \cite{vaswani2017attention} has fundamentally reshaped Natural Language Processing (NLP), establishing itself as the foundation for state-of-the-art models across virtually all language understanding and generation tasks. Building upon this foundation, the T5 (Text-to-Text Transfer Transformer) model \cite{raffel2020exploring} introduced a unified framework that casts all NLP problems as text-to-text transformations. FLAN-T5 (Fine-tuned Language Net) \cite{chung2022scaling} further enhanced T5's capabilities through instruction fine-tuning on over 1,000 diverse tasks.
-
- While pre-trained models like FLAN-T5 offer impressive zero-shot and few-shot capabilities, they are often treated as black boxes---their internal mechanisms obscured by framework abstractions. This opacity hinders both understanding and customization. Furthermore, multi-task learning scenarios often require architectural modifications that pre-built models do not easily accommodate.
-
- LexiMind addresses these challenges through a hybrid approach: implementing a complete Transformer architecture from scratch while maintaining compatibility with FLAN-T5's pre-trained weights. This provides several key advantages:
-
- \begin{enumerate}
-     \item \textbf{Architectural Transparency}: Every component---from attention mechanisms to normalization layers---is explicitly implemented and documented.
-     \item \textbf{Customization Flexibility}: Task-specific heads and routing logic can be freely modified without framework constraints.
-     \item \textbf{Transfer Learning}: FLAN-T5's linguistic knowledge is transferred through careful weight mapping in the factory module.
-     \item \textbf{Modern Optimizations}: Integration of FlashAttention, bfloat16 training, and gradient accumulation ensures efficient resource utilization.
- \end{enumerate}
-
- A key design decision in LexiMind is the focus on literary and academic domains rather than news articles, which are overrepresented in existing summarization benchmarks. For text summarization, we combine Goodreads book descriptions---which provide back-cover style blurbs describing \textit{what a book is about}---with arXiv paper abstracts. This trains the model to generate descriptive summaries rather than extractive plot recaps. Emotion classification leverages GoEmotions \cite{demszky2020goemotions}, providing fine-grained 28-label annotations. Topic classification draws from diverse sources including 20 Newsgroups, Project Gutenberg, and scientific papers.
-
- The contributions of this work include:
- \begin{itemize}
-     \item A custom Transformer implementation compatible with T5/FLAN-T5 weight loading
-     \item A multi-task architecture supporting both generative (summarization) and discriminative (classification) tasks
-     \item A curated dataset of 95,898 training samples across literary, academic, and conversational domains
-     \item Detailed documentation of weight transfer mechanisms between pre-trained models and custom implementations
-     \item Comprehensive training infrastructure with mixed-precision support, gradient monitoring, and MLflow experiment tracking
- \end{itemize}
-
- %=============================================================================
- \section{Related Work}
- %=============================================================================
-
- \subsection{Transformer Architectures}
-
- The original Transformer \cite{vaswani2017attention} introduced the self-attention mechanism, enabling parallel processing of sequences and effective capture of long-range dependencies. The architecture consists of stacked encoder and decoder blocks, each containing Multi-Head Attention (MHA) and position-wise Feed-Forward Networks (FFN).
-
- \textbf{Layer Normalization Placement}: The original Transformer applied Layer Normalization \cite{ba2016layer} after residual connections (Post-LN). Subsequent research \cite{xiong2020layer} demonstrated that applying normalization before sublayers (Pre-LN) significantly improves training stability, particularly for deep networks. LexiMind adopts the Pre-LN configuration used by T5 and modern large language models.
-
- \textbf{RMSNorm}: Zhang and Sennrich \cite{zhang2019root} proposed Root Mean Square Layer Normalization (RMSNorm), which removes the mean-centering operation of standard LayerNorm while maintaining comparable performance. T5 \cite{raffel2020exploring} adopts this approach, and LexiMind follows suit for compatibility.
-
- \subsection{Pre-trained Language Models}
-
- \textbf{T5}: Raffel et al. \cite{raffel2020exploring} introduced the T5 model, which frames all NLP tasks as text-to-text problems. T5 uses a Transformer encoder-decoder architecture with several distinctive features: relative position bias instead of absolute positional embeddings, RMSNorm for layer normalization, and a gated feed-forward network.
-
- \textbf{FLAN-T5}: Chung et al. \cite{chung2022scaling} enhanced T5 through instruction fine-tuning, creating FLAN-T5. By training on diverse task instructions, FLAN-T5 demonstrates improved zero-shot and few-shot capabilities compared to the original T5.
-
- \subsection{Multi-Task Learning}
-
- Multi-Task Learning (MTL) \cite{caruana1997multitask} trains a single model on multiple related tasks, promoting parameter sharing and implicit data augmentation. Hard parameter sharing---where lower layers are shared across tasks while task-specific heads branch from shared representations---remains the dominant approach for Transformer-based MTL systems.
-
- %=============================================================================
- \section{Architecture}
- %=============================================================================
-
- LexiMind implements a complete encoder-decoder Transformer with task-specific heads, constructed using a bottom-up approach where each component is implemented as a standalone module. Figure \ref{fig:architecture} illustrates the high-level system architecture.
-
- \begin{figure}[htbp]
- \centering
- \begin{tikzpicture}[
- scale=0.75,
- transform shape,
- box/.style={draw, rectangle, minimum width=2cm, minimum height=0.7cm, align=center, rounded corners=2pt},
- smallbox/.style={draw, rectangle, minimum width=1.4cm, minimum height=0.5cm, align=center, rounded corners=2pt, font=\scriptsize},
- head/.style={draw, rectangle, minimum width=1.5cm, minimum height=0.6cm, align=center, rounded corners=2pt, fill=blue!20},
- arrow/.style={->, >=stealth, thick},
- dashedarrow/.style={->, >=stealth, dashed}
- ]
-
- % Input
- \node[box, fill=gray!20] (input) at (0, 0) {Input Text};
-
- % Tokenizer
- \node[box, fill=yellow!30] (tokenizer) at (0, 1.2) {Tokenizer\\(SentencePiece)};
-
- % Encoder
- \node[box, fill=green!30, minimum height=2cm] (encoder) at (0, 3.2) {Encoder\\$N=12$ layers};
-
- % Task routing
- \node[box, fill=orange!30] (router) at (0, 5.2) {Task Router};
-
- % Decoder branch
- \node[box, fill=green!30, minimum height=1.5cm] (decoder) at (-2.5, 7) {Decoder\\$N=12$ layers};
- \node[head] (lmhead) at (-2.5, 8.8) {LM Head};
- \node[smallbox, fill=purple!20] (summ) at (-2.5, 9.8) {Summary};
-
- % Classification branch
- \node[head] (emotionhead) at (1.2, 7) {Emotion\\Head};
- \node[head] (topichead) at (2.8, 7) {Topic\\Head};
- \node[smallbox, fill=purple!20] (emotion) at (1.2, 8.2) {Emotions\\(28 classes)};
- \node[smallbox, fill=purple!20] (topic) at (2.8, 8.2) {Topics\\(7 classes)};
-
- % Arrows
- \draw[arrow] (input) -- (tokenizer);
- \draw[arrow] (tokenizer) -- (encoder);
- \draw[arrow] (encoder) -- (router);
- \draw[arrow] (router) -- (decoder);
- \draw[arrow] (router) -- (emotionhead);
- \draw[arrow] (router) -- (topichead);
- \draw[arrow] (decoder) -- (lmhead);
- \draw[arrow] (lmhead) -- (summ);
- \draw[arrow] (emotionhead) -- (emotion);
- \draw[arrow] (topichead) -- (topic);
-
- % Cross-attention arrow
- \draw[dashedarrow] (encoder.west) -- ++(-0.5,0) |- (decoder.south west);
-
- % Labels
- \node[font=\tiny, align=center] at (-1.8, 4.5) {Cross\\Attention};
-
- \end{tikzpicture}
- \caption{LexiMind system architecture showing the shared encoder, task-specific routing, decoder for generation, and classification heads for discriminative tasks.}
- \label{fig:architecture}
- \end{figure}
-
- \subsection{Transformer Block Structure}
-
- Figure \ref{fig:transformer_block} presents the internal structure of encoder and decoder blocks, following the Pre-LN configuration from T5 \cite{raffel2020exploring}.
-
- \begin{figure}[htbp]
- \centering
- \begin{tikzpicture}[
- scale=0.65,
- transform shape,
- block/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.6cm, align=center, rounded corners=2pt},
- norm/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.5cm, align=center, fill=yellow!30, rounded corners=2pt},
- attn/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.6cm, align=center, fill=blue!25, rounded corners=2pt},
- ffn/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.6cm, align=center, fill=green!25, rounded corners=2pt},
- add/.style={draw, circle, minimum size=0.4cm, fill=red!20, inner sep=0pt, font=\small},
- arrow/.style={->, >=stealth},
- ]
-
- % === ENCODER BLOCK ===
- \node[font=\bfseries] at (0, 8) {Encoder Block};
-
- % Input
- \node (enc_in) at (0, 7) {};
- \draw[arrow] (0, 6.5) -- (enc_in);
-
- % RMSNorm 1
- \node[norm] (enc_norm1) at (0, 6) {RMSNorm};
-
- % Self-Attention
- \node[attn] (enc_attn) at (0, 5) {Multi-Head\\Self-Attention};
-
- % Add 1
- \node[add] (enc_add1) at (0, 4) {+};
-
- % RMSNorm 2
- \node[norm] (enc_norm2) at (0, 3) {RMSNorm};
-
- % FFN
- \node[ffn] (enc_ffn) at (0, 2) {Gated FFN\\(GELU)};
-
- % Add 2
- \node[add] (enc_add2) at (0, 1) {+};
-
- % Output
- \node (enc_out) at (0, 0.3) {};
-
- % Connections
- \draw[arrow] (enc_in) -- (enc_norm1);
- \draw[arrow] (enc_norm1) -- (enc_attn);
- \draw[arrow] (enc_attn) -- (enc_add1);
- \draw[arrow] (enc_add1) -- (enc_norm2);
- \draw[arrow] (enc_norm2) -- (enc_ffn);
- \draw[arrow] (enc_ffn) -- (enc_add2);
- \draw[arrow] (enc_add2) -- (enc_out);
-
- % Residual connections
- \draw[arrow] (0, 6.5) -- (-1.5, 6.5) -- (-1.5, 4) -- (enc_add1.west);
- \draw[arrow] (enc_add1.east) -- (1.5, 4) -- (1.5, 1) -- (enc_add2.east);
-
- % === DECODER BLOCK ===
- \node[font=\bfseries] at (5.5, 8) {Decoder Block};
-
- % Input
- \node (dec_in) at (5.5, 7) {};
- \draw[arrow] (5.5, 6.5) -- (dec_in);
-
- % RMSNorm 1
- \node[norm] (dec_norm1) at (5.5, 6) {RMSNorm};
-
- % Masked Self-Attention
- \node[attn] (dec_attn1) at (5.5, 5) {Masked\\Self-Attention};
-
- % Add 1
- \node[add] (dec_add1) at (5.5, 4.2) {+};
-
- % RMSNorm 2
- \node[norm] (dec_norm2) at (5.5, 3.4) {RMSNorm};
-
- % Cross-Attention
- \node[attn, fill=cyan!25] (dec_attn2) at (5.5, 2.4) {Cross-Attention};
-
- % Add 2
- \node[add] (dec_add2) at (5.5, 1.5) {+};
-
- % RMSNorm 3
- \node[norm] (dec_norm3) at (5.5, 0.7) {RMSNorm};
-
- % FFN
- \node[ffn] (dec_ffn) at (5.5, -0.3) {Gated FFN\\(GELU)};
-
- % Add 3
- \node[add] (dec_add3) at (5.5, -1.2) {+};
-
- % Connections
- \draw[arrow] (dec_in) -- (dec_norm1);
- \draw[arrow] (dec_norm1) -- (dec_attn1);
- \draw[arrow] (dec_attn1) -- (dec_add1);
- \draw[arrow] (dec_add1) -- (dec_norm2);
- \draw[arrow] (dec_norm2) -- (dec_attn2);
- \draw[arrow] (dec_attn2) -- (dec_add2);
- \draw[arrow] (dec_add2) -- (dec_norm3);
- \draw[arrow] (dec_norm3) -- (dec_ffn);
- \draw[arrow] (dec_ffn) -- (dec_add3);
-
- % Encoder memory input
- \node[block, fill=gray!20, minimum width=1.2cm, font=\scriptsize] (memory) at (8, 2.4) {Encoder\\Memory};
- \draw[arrow] (memory) -- (dec_attn2);
-
- % Residual connections (simplified)
- \draw[arrow] (5.5, 6.5) -- (4, 6.5) -- (4, 4.2) -- (dec_add1.west);
- \draw[arrow] (dec_add1.east) -- (7, 4.2) -- (7, 1.5) -- (dec_add2.east);
- \draw[arrow] (dec_add2.west) -- (4, 1.5) -- (4, -1.2) -- (dec_add3.west);
-
- \end{tikzpicture}
- \caption{Pre-LN Transformer blocks. Left: Encoder block with self-attention and FFN. Right: Decoder block with masked self-attention, cross-attention to encoder memory, and FFN. RMSNorm is applied \emph{before} each sublayer (Pre-LN).}
- \label{fig:transformer_block}
- \end{figure}
-
- \subsection{Multi-Head Attention Mechanism}
-
- The attention mechanism is the cornerstone of the Transformer architecture. LexiMind implements Multi-Head Attention with support for T5-style relative position bias and FlashAttention optimization. Figure \ref{fig:attention} illustrates the attention computation.
-
- \begin{figure}[htbp]
- \centering
- \begin{tikzpicture}[
- scale=0.6,
- transform shape,
- box/.style={draw, rectangle, minimum width=1.5cm, minimum height=0.6cm, align=center, rounded corners=2pt},
- proj/.style={draw, rectangle, minimum width=1.2cm, minimum height=0.5cm, align=center, fill=blue!20, rounded corners=2pt, font=\scriptsize},
- op/.style={draw, rectangle, minimum width=1.2cm, minimum height=0.5cm, align=center, fill=orange!30, rounded corners=2pt, font=\scriptsize},
- arrow/.style={->, >=stealth},
- ]
-
- % Input
- \node[box, fill=gray!20] (input) at (0, 0) {Input $X$};
-
- % Projections
- \node[proj] (wq) at (-2.5, 1.5) {$W_Q$};
- \node[proj] (wk) at (0, 1.5) {$W_K$};
- \node[proj] (wv) at (2.5, 1.5) {$W_V$};
-
- % Q, K, V
- \node[box, fill=green!20] (q) at (-2.5, 2.8) {$Q$};
- \node[box, fill=green!20] (k) at (0, 2.8) {$K$};
- \node[box, fill=green!20] (v) at (2.5, 2.8) {$V$};
-
- % Split heads
- \node[op] (split) at (0, 4) {Split $h$ heads};
-
- % Attention scores
- \node[op] (matmul1) at (0, 5.2) {$QK^T$};
-
- % Position bias
- \node[box, fill=yellow!30, font=\scriptsize] (bias) at (3.5, 5.2) {Relative\\Pos Bias};
-
- % Add bias
- \node[op] (add) at (0, 6.2) {$+ B_{rel}$};
-
- % Scale (optional)
- \node[op] (scale) at (0, 7.2) {Scale / Mask};
-
- % Softmax
- \node[op, fill=red!20] (softmax) at (0, 8.2) {Softmax};
-
- % MatMul with V
- \node[op] (matmul2) at (0, 9.2) {$\times V$};
-
- % Concat
- \node[op] (concat) at (0, 10.2) {Concat heads};
-
- % Output projection
- \node[proj] (wo) at (0, 11.2) {$W_O$};
-
- % Output
- \node[box, fill=purple!20] (output) at (0, 12.2) {Output};
-
- % Arrows
- \draw[arrow] (input) -- (wq);
- \draw[arrow] (input) -- (wk);
- \draw[arrow] (input) -- (wv);
- \draw[arrow] (wq) -- (q);
- \draw[arrow] (wk) -- (k);
- \draw[arrow] (wv) -- (v);
- \draw[arrow] (q) -- (split);
- \draw[arrow] (k) -- (split);
- \draw[arrow] (v.north) -- ++(0, 0.3) -| (2.5, 9.2) -- (matmul2);
- \draw[arrow] (split) -- (matmul1);
- \draw[arrow] (matmul1) -- (add);
- \draw[arrow] (bias) -- (add);
- \draw[arrow] (add) -- (scale);
- \draw[arrow] (scale) -- (softmax);
- \draw[arrow] (softmax) -- (matmul2);
- \draw[arrow] (matmul2) -- (concat);
- \draw[arrow] (concat) -- (wo);
- \draw[arrow] (wo) -- (output);
-
- % Annotations
- \node[font=\tiny, align=left] at (-4.5, 5.5) {T5 does NOT\\scale by $\sqrt{d_k}$};
-
- \end{tikzpicture}
- \caption{Multi-Head Attention with T5-style relative position bias. The attention scores are computed as $QK^T + B_{rel}$, where $B_{rel}$ is the learned relative position bias. Unlike standard Transformers, T5 does not scale by $\sqrt{d_k}$.}
- \label{fig:attention}
- \end{figure}
-
- The attention computation in LexiMind is implemented in \texttt{src/models/attention.py}. For T5 compatibility, the \texttt{scale\_scores} parameter controls whether to apply $\sqrt{d_k}$ scaling---T5 does not use this scaling \cite{raffel2020exploring}.
-
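The effect of the \texttt{scale\_scores} switch can be sketched in a few lines of PyTorch. This is an illustrative reimplementation, not the code from \texttt{attention.py}; the function name and signature are assumptions.

```python
import torch

def attention_weights(q, k, rel_bias=None, scale_scores=True):
    # q, k: (batch, heads, seq_len, d_k). scale_scores=False reproduces
    # T5's convention of leaving QK^T unscaled.
    scores = q @ k.transpose(-2, -1)
    if scale_scores:
        scores = scores / (q.size(-1) ** 0.5)
    if rel_bias is not None:
        scores = scores + rel_bias  # additive relative position bias B_rel
    return torch.softmax(scores, dim=-1)

attn = attention_weights(torch.randn(2, 12, 5, 64), torch.randn(2, 12, 5, 64),
                         rel_bias=torch.zeros(1, 12, 5, 5), scale_scores=False)
```

With scaling disabled, the learned position bias carries more of the burden of controlling score magnitudes, which is how T5 operates in practice.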
- Figure \ref{fig:attention_viz} shows learned attention patterns from the trained model, demonstrating how different heads specialize in capturing various linguistic relationships.
-
- \begin{figure}[htbp]
- \centering
- \includegraphics[width=\columnwidth]{figures/multihead_attention_visualization.png}
- \caption{Attention weight visualization across multiple heads. Each head learns distinct attention patterns: some focus on local context (diagonal patterns), while others capture long-range dependencies and syntactic relationships.}
- \label{fig:attention_viz}
- \end{figure}
-
- \subsubsection{T5 Relative Position Bias}
-
- Unlike absolute positional embeddings that are added to token embeddings, T5 uses relative position bias added directly to attention scores. The \texttt{T5RelativePositionBias} class implements logarithmically-bucketed relative positions:
-
- \begin{equation}
- B_{ij} = \text{Embed}[\text{bucket}(i - j)]
- \end{equation}
-
- where $\text{bucket}(\cdot)$ maps relative distances to discrete buckets. Half the buckets encode exact positions for nearby tokens; the remaining half use logarithmic spacing for distant tokens. As documented in the code:
-
- \begin{quote}
- \emph{``T5 uses a combination of exact positions (for nearby tokens) and logarithmically-spaced buckets (for distant tokens).''} --- \texttt{attention.py}, lines 46--48
- \end{quote}
-
- Figure \ref{fig:position_bias} visualizes the learned relative position bias, showing how the model encodes positional relationships between tokens.
-
- \begin{figure}[htbp]
- \centering
- \includegraphics[width=\columnwidth]{figures/positional_encoding_heatmap.png}
- \caption{Heatmap of relative position bias values. The diagonal structure indicates stronger attention between nearby positions, while the logarithmic bucketing allows efficient representation of longer-range dependencies.}
- \label{fig:position_bias}
- \end{figure}
-
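The bucketing can be illustrated with a small PyTorch function. This sketch mirrors the scheme described above (half the buckets per sign of the offset; within a half, exact buckets for nearby offsets and log-spaced buckets for distant ones), but the exact constants and sign conventions in LexiMind's \texttt{T5RelativePositionBias} may differ.

```python
import math
import torch

def relative_position_bucket(rel_pos: torch.Tensor,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> torch.Tensor:
    # Bidirectional bucketing: one half of the buckets for each sign of
    # the offset; within a half, nearby offsets get exact buckets and
    # distant offsets share logarithmically spaced ones.
    half = num_buckets // 2
    max_exact = half // 2
    ret = (rel_pos > 0).long() * half
    n = rel_pos.abs()
    log_big = max_exact + (
        torch.log(n.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (half - max_exact)
    ).long()
    log_big = torch.minimum(log_big, torch.full_like(log_big, half - 1))
    return ret + torch.where(n < max_exact, n, log_big)

buckets = relative_position_bucket(torch.tensor([0, 1, -1, 200]))
```

Offsets 0, 1, and -1 land in distinct exact buckets, while any offset beyond \texttt{max\_distance} saturates into the last log-spaced bucket of its half.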
- \subsubsection{FlashAttention Integration}
-
- LexiMind leverages PyTorch 2.0's \texttt{scaled\_dot\_product\_attention} function, which automatically selects the optimal attention kernel:
-
- \begin{quote}
- \emph{``Uses F.scaled\_dot\_product\_attention which automatically selects the best available kernel (FlashAttention v2, Memory-Efficient Attention, or math fallback) based on hardware and input shapes.''} --- \texttt{attention.py}, lines 134--137
- \end{quote}
-
- This provides $O(N)$ memory complexity instead of $O(N^2)$ when FlashAttention is available.
-
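A minimal usage sketch (illustrative, not the repository code): since \texttt{F.scaled\_dot\_product\_attention} always applies its own $1/\sqrt{d_k}$ factor, pre-scaling the queries by $\sqrt{d_k}$ cancels it and reproduces T5's unscaled scores, while the additive relative position bias travels in \texttt{attn\_mask}.

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 12, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 12, 128, 64)
v = torch.randn(2, 12, 128, 64)
bias = torch.zeros(2, 12, 128, 128)  # stand-in for the learned B_rel

# SDPA dispatches to FlashAttention / memory-efficient / math kernels on
# its own; pre-scaling q cancels the built-in 1/sqrt(d_k) factor.
out = F.scaled_dot_product_attention(q * (q.size(-1) ** 0.5), k, v,
                                     attn_mask=bias)
```

On PyTorch 2.1 or newer, the same effect is available directly via the `scale=1.0` keyword argument instead of pre-scaling.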
- \subsection{Feed-Forward Network}
-
- Following T5, LexiMind implements a gated feed-forward network with GELU activation:
-
- \begin{equation}
- \text{FFN}(x) = (\text{GELU}(xW_g) \odot xW_1) W_2
- \end{equation}
-
- where $W_g$ is the gating projection, $W_1$ is the up-projection, $W_2$ is the down-projection, and $\odot$ denotes element-wise multiplication.
-
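The equation maps directly onto three bias-free linear layers. The module and attribute names below follow the \texttt{linear\_gate}/\texttt{linear1}/\texttt{linear2} naming used in the weight-mapping table, but the code itself is an illustrative sketch rather than the repository implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFFN(nn.Module):
    # Gated-GELU FFN: FFN(x) = (GELU(x W_g) * (x W_1)) W_2.
    # T5-style, so no bias terms on any projection.
    def __init__(self, d_model: int, ffn_dim: int) -> None:
        super().__init__()
        self.linear_gate = nn.Linear(d_model, ffn_dim, bias=False)  # W_g
        self.linear1 = nn.Linear(d_model, ffn_dim, bias=False)      # W_1
        self.linear2 = nn.Linear(ffn_dim, d_model, bias=False)      # W_2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(F.gelu(self.linear_gate(x)) * self.linear1(x))

y = GatedFFN(d_model=768, ffn_dim=2048)(torch.randn(2, 16, 768))
```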
- \subsection{RMSNorm}
-
- RMSNorm \cite{zhang2019root} normalizes inputs using only the root mean square:
-
- \begin{equation}
- \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 + \epsilon}} \cdot \gamma
- \end{equation}
-
- The implementation in \texttt{src/models/t5\_layer\_norm.py} follows T5's convention, using only a learned scale parameter $\gamma$ with no bias term.
-
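A minimal sketch of the same computation (illustrative; the repository's version lives in \texttt{src/models/t5\_layer\_norm.py}, and the \texttt{eps} default here is an assumption):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Scale-only normalization: x / RMS(x) * gamma. No mean-centering
    # and no bias term, matching the equation above.
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.gamma

y = RMSNorm(768)(torch.randn(2, 16, 768))
```

With `gamma` at its initial value of ones, the output's mean square along the last dimension is approximately 1.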
- %=============================================================================
- \section{Tokenization}
- \label{sec:tokenization}
- %=============================================================================
-
- LexiMind wraps HuggingFace's AutoTokenizer with a simplified façade that handles T5-specific conventions. The implementation in \texttt{src/data/tokenization.py} manages special token handling and decoder input preparation.
-
- \subsection{T5 Tokenizer Characteristics}
-
- T5 uses SentencePiece \cite{kudo2018sentencepiece} with unigram tokenization:
-
- \begin{itemize}
- \item \textbf{Vocabulary Size}: 32,128 tokens (padded to multiple of 128 for efficiency)
- \item \textbf{Special Tokens}: \texttt{pad\_token\_id=0}, \texttt{eos\_token\_id=1}
- \item \textbf{No Explicit BOS}: T5 uses the pad token as the decoder start token
- \end{itemize}
-
- As noted in the tokenizer implementation:
-
- \begin{quote}
- \emph{``T5 uses different special tokens than BART: T5: pad=0, eos=1, no explicit bos (uses pad or eos as decoder start); BART: bos=0, pad=1, eos=2.''} --- \texttt{tokenization.py}, lines 42--44
- \end{quote}
-
- \subsection{Decoder Input Preparation}
-
- For seq2seq training, decoder inputs must be shifted right from labels. The \texttt{prepare\_decoder\_inputs} method handles this:
-
- \begin{lstlisting}[caption={Decoder input preparation from tokenization.py}]
- def prepare_decoder_inputs(
-     self, labels: torch.Tensor
- ) -> torch.Tensor:
-     """Shift decoder labels to create
-     input ids prefixed by BOS."""
-     bos = self.bos_token_id
-     pad = self.pad_token_id
-     decoder_inputs = torch.full_like(labels, pad)
-     decoder_inputs[:, 0] = bos
-     decoder_inputs[:, 1:] = labels[:, :-1]
-     return decoder_inputs
- \end{lstlisting}
-
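A standalone version of the shift-right logic makes the behavior easy to check. This is a hypothetical free function (LexiMind implements it as the method shown above); with T5's conventions, the start token and pad token are both id 0.

```python
import torch

def shift_right(labels: torch.Tensor, start_id: int = 0,
                pad_id: int = 0) -> torch.Tensor:
    # Drop the last label position and prepend the decoder start token
    # (T5 reuses pad_token_id=0 as the decoder start).
    out = torch.full_like(labels, pad_id)
    out[:, 0] = start_id
    out[:, 1:] = labels[:, :-1]
    return out

labels = torch.tensor([[42, 43, 44, 1]])  # trailing 1 is T5's EOS
print(shift_right(labels))                # tensor([[ 0, 42, 43, 44]])
```

Note that the EOS token falls off the end of the decoder inputs, as it should: the model is asked to *predict* EOS, never to consume it.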
- %=============================================================================
- \section{The Factory Module: Weight Transfer from FLAN-T5}
- \label{sec:factory}
- %=============================================================================
-
- The \texttt{factory.py} module is central to LexiMind's hybrid approach, providing model construction and weight loading utilities. Figure \ref{fig:factory_flow} illustrates the model construction pipeline.
-
- \begin{figure}[htbp]
- \centering
- \begin{tikzpicture}[
- scale=0.7,
- transform shape,
- box/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.7cm, align=center, rounded corners=3pt},
- config/.style={draw, rectangle, minimum width=2cm, minimum height=0.6cm, align=center, fill=yellow!30, rounded corners=2pt, font=\small},
- model/.style={draw, rectangle, minimum width=2.2cm, minimum height=0.6cm, align=center, fill=green!30, rounded corners=2pt, font=\small},
- arrow/.style={->, >=stealth, thick},
- ]
-
- % Config loading
- \node[config] (yaml) at (0, 0) {config.yaml};
- \node[box, fill=blue!20] (loadconfig) at (0, 1.3) {load\_model\_config()};
- \node[config] (modelconfig) at (0, 2.6) {ModelConfig};
-
- % Model building
- \node[box, fill=blue!20] (build) at (0, 4.2) {build\_multitask\_model()};
-
- % Components
- \node[model] (encoder) at (-2.5, 5.8) {Encoder};
- \node[model] (decoder) at (0, 5.8) {Decoder};
- \node[model] (heads) at (2.5, 5.8) {Task Heads};
-
- % Weight loading
- \node[box, fill=orange!30] (loadweights) at (-1.2, 7.4) {\_load\_pretrained\_weights()};
-
- % FLAN-T5
- \node[box, fill=purple!20] (flant5) at (-4, 7.4) {FLAN-T5\\(HuggingFace)};
-
- % Final model
- \node[box, fill=red!20, minimum width=3cm] (mtmodel) at (0, 9) {MultiTaskModel};
-
- % Arrows
- \draw[arrow] (yaml) -- (loadconfig);
- \draw[arrow] (loadconfig) -- (modelconfig);
- \draw[arrow] (modelconfig) -- (build);
- \draw[arrow] (build) -- (encoder);
- \draw[arrow] (build) -- (decoder);
- \draw[arrow] (build) -- (heads);
- \draw[arrow] (encoder) -- (loadweights);
- \draw[arrow] (decoder) -- (loadweights);
- \draw[arrow] (flant5) -- (loadweights);
- \draw[arrow] (loadweights) -- (mtmodel);
- \draw[arrow] (heads) -- (mtmodel);
-
- \end{tikzpicture}
- \caption{Model construction pipeline in \texttt{factory.py}. Configuration is loaded from YAML, components are instantiated, FLAN-T5 weights are transferred, and the final MultiTaskModel is assembled.}
- \label{fig:factory_flow}
- \end{figure}
-
- \subsection{Configuration Management}
-
- The \texttt{ModelConfig} dataclass defines all architecture hyperparameters:
-
- \begin{lstlisting}[caption={ModelConfig from factory.py}]
- @dataclass
- class ModelConfig:
-     d_model: int = 768
-     vocab_size: Optional[int] = None
-     num_encoder_layers: int = 12
-     num_decoder_layers: int = 12
-     num_attention_heads: int = 12
-     ffn_dim: int = 2048
-     dropout: float = 0.1
-     use_pretrained: bool = False
-     pretrained_model_name: str = "google/flan-t5-base"
-     activation: str = "gated-gelu"
-     use_relative_position_bias: bool = False
- \end{lstlisting}
-
- \subsection{Weight Transfer Mechanism}
-
- The \texttt{\_load\_pretrained\_weights} function performs careful weight mapping between FLAN-T5 and LexiMind's custom architecture. Key considerations documented in the code:
-
- \begin{quote}
- \emph{``T5 architecture compatibility with our custom Transformer: T5 uses Pre-LN (RMSNorm before sublayers) --- matches our design; T5 uses relative position bias instead of absolute embeddings; T5 uses gated FFN (wi\_0, wi\_1, wo); T5 attention has no bias, our attention has bias --- we zero-initialize the bias terms.''} --- \texttt{factory.py}, lines 100--108
- \end{quote}
-
- Table \ref{tab:weight_mapping} shows the parameter correspondence:
-
580
- \begin{table}[htbp]
581
- \centering
582
- \caption{FLAN-T5 to LexiMind Weight Mapping}
583
- \label{tab:weight_mapping}
584
- \begin{tabular}{ll}
585
- \toprule
586
- \textbf{FLAN-T5 Parameter} & \textbf{LexiMind Parameter} \\
587
- \midrule
588
- \texttt{shared} & \texttt{encoder.embedding} \\
589
- \texttt{encoder.block.*.SelfAttention.q} & \texttt{encoder.layers.*.self\_attn.W\_Q} \\
590
- \texttt{encoder.block.*.SelfAttention.k} & \texttt{encoder.layers.*.self\_attn.W\_K} \\
591
- \texttt{encoder.block.*.SelfAttention.v} & \texttt{encoder.layers.*.self\_attn.W\_V} \\
592
- \texttt{encoder.block.*.SelfAttention.o} & \texttt{encoder.layers.*.self\_attn.W\_O} \\
593
- \texttt{*.layer\_norm} & \texttt{*.norm*} \\
594
- \texttt{*.DenseReluDense.wi\_0} & \texttt{*.ffn.linear\_gate} \\
595
- \texttt{*.DenseReluDense.wi\_1} & \texttt{*.ffn.linear1} \\
596
- \texttt{*.DenseReluDense.wo} & \texttt{*.ffn.linear2} \\
597
- \texttt{lm\_head} & \texttt{decoder.output\_projection} \\
598
- \bottomrule
599
- \end{tabular}
600
- \end{table}
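The correspondence in the table can be expressed as a small name-rewriting helper. This is an illustrative sketch, not the actual `factory.py` code, which additionally validates shapes and initializes bias terms:

```python
# Illustrative remapping of FLAN-T5 parameter names to their LexiMind
# counterparts, following the table above. Assumed helper, not factory.py code.
DIRECT = {
    "shared": "encoder.embedding",
    "lm_head": "decoder.output_projection",
}

ATTN = {"q": "W_Q", "k": "W_K", "v": "W_V", "o": "W_O"}

def remap(t5_key: str) -> str:
    """Map a FLAN-T5 parameter name to its LexiMind counterpart."""
    if t5_key in DIRECT:
        return DIRECT[t5_key]
    key = t5_key.replace(".block.", ".layers.")  # per-layer container rename
    for src, dst in ATTN.items():                # attention projection rename
        key = key.replace(f"SelfAttention.{src}", f"self_attn.{dst}")
    return key
```

A decoder-side mapping would follow the same pattern with cross-attention handled analogously.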
-
- \subsection{Vocabulary Size Handling}
-
- T5 pads its vocabulary to multiples of 128 for computational efficiency (32,100 $\rightarrow$ 32,128). LexiMind handles this mismatch:
-
- \begin{quote}
- \emph{``Note: T5's vocab is padded to multiple of 128 for efficiency (32100 $\rightarrow$ 32128). [...] Copy only the tokens that exist in both. Initialize any extra tokens with small random values.''} — \texttt{factory.py}, lines 116--131
- \end{quote}
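The overlap-copy strategy the quote describes can be sketched in plain Python. The real code operates on tensors; this stand-in uses list-of-list matrices, and the `init_scale` value is an assumption:

```python
import random

def transfer_embedding_rows(src_rows, dst_rows, init_scale=0.02):
    """Copy embedding rows shared by both vocabularies; randomly
    initialize the remainder.

    Pure-python sketch of the overlap-copy strategy: shared token rows
    are copied verbatim, extra (padding) rows get small random values.
    """
    shared = min(len(src_rows), len(dst_rows))
    for i in range(shared):
        dst_rows[i] = list(src_rows[i])  # copy overlapping token rows
    for i in range(shared, len(dst_rows)):  # extra padding tokens
        dst_rows[i] = [random.gauss(0.0, init_scale) for _ in dst_rows[i]]
    return dst_rows
```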
-
- \subsection{Model Assembly}
-
- The \texttt{build\_multitask\_model} function assembles the complete system:
-
- \begin{lstlisting}[caption={Model assembly from factory.py}]
- model = MultiTaskModel(
-     encoder=encoder,
-     decoder=decoder,
-     decoder_outputs_logits=True
- )
- model.add_head(
-     "summarization",
-     LMHead(d_model=cfg.d_model,
-            vocab_size=vocab_size,
-            tie_embedding=decoder.embedding)
- )
- model.add_head(
-     "emotion",
-     ClassificationHead(
-         d_model=cfg.d_model,
-         num_labels=28,  # GoEmotions
-         pooler="mean",
-         hidden_dim=cfg.d_model // 2)
- )
- model.add_head(
-     "topic",
-     ClassificationHead(
-         d_model=cfg.d_model,
-         num_labels=7,  # 7 topic categories
-         pooler="mean")
- )
- \end{lstlisting}
-
- %=============================================================================
- \section{Multi-Task Model Architecture}
- \label{sec:multitask}
- %=============================================================================
-
- The \texttt{MultiTaskModel} class in \texttt{src/models/multitask.py} provides the routing infrastructure for multi-task learning. Figure \ref{fig:multitask_routing} illustrates the task routing mechanism.
-
- \begin{figure}[htbp]
- \centering
- \begin{tikzpicture}[
- scale=0.7,
- transform shape,
- box/.style={draw, rectangle, minimum width=2cm, minimum height=0.6cm, align=center, rounded corners=2pt},
- decision/.style={draw, diamond, aspect=2, minimum width=1.5cm, align=center, fill=yellow!30},
- arrow/.style={->, >=stealth, thick},
- ]
-
- % Forward call
- \node[box, fill=blue!20] (forward) at (0, 0) {forward(task, inputs)};
-
- % Decision
- \node[decision] (taskcheck) at (0, -1.5) {task type?};
-
- % Branches
- \node[box, fill=green!20] (encoder) at (-3.5, -3.5) {Encoder\\Only};
- \node[box, fill=green!20] (seq2seq) at (3.5, -3.5) {Encoder\\+ Decoder};
-
- % Heads
- \node[box, fill=orange!20] (classhead) at (-3.5, -5) {Classification\\Head};
- \node[box, fill=orange!20] (lmhead) at (3.5, -5) {LM Head};
-
- % Tasks
- \node[box, fill=purple!20, font=\scriptsize] (emotion) at (-5, -6.5) {Emotion};
- \node[box, fill=purple!20, font=\scriptsize] (topic) at (-2, -6.5) {Topic};
- \node[box, fill=purple!20, font=\scriptsize] (summ) at (3.5, -6.5) {Summarization};
-
- % Arrows
- \draw[arrow] (forward) -- (taskcheck);
- \draw[arrow] (taskcheck) -- node[above, font=\scriptsize] {Classification} (encoder);
- \draw[arrow] (taskcheck) -- node[above, font=\scriptsize] {Generation} (seq2seq);
- \draw[arrow] (encoder) -- (classhead);
- \draw[arrow] (seq2seq) -- (lmhead);
- \draw[arrow] (classhead) -- (emotion);
- \draw[arrow] (classhead) -- (topic);
- \draw[arrow] (lmhead) -- (summ);
-
- \end{tikzpicture}
- \caption{Task routing in MultiTaskModel. Classification tasks use encoder-only processing with mean pooling, while generation tasks use the full encoder-decoder pipeline.}
- \label{fig:multitask_routing}
- \end{figure}
-
- \subsection{Task-Specific Head Selection}
-
- The forward method routes inputs based on head type:
-
- \begin{quote}
- \emph{``Encoder-only heads expect encoder outputs [...] LM/seq2seq head: run encoder $\rightarrow$ decoder $\rightarrow$ lm head''} — \texttt{multitask.py}, lines 108--148
- \end{quote}
-
- \subsection{Classification Head}
-
- Classification tasks (emotion, topic) use mean pooling over encoder outputs:
-
- \begin{equation}
- h_{cls} = \frac{\sum_{i=1}^{L} h_i \cdot m_i}{\sum_{i=1}^{L} m_i}
- \end{equation}
-
- where $m_i$ is the attention mask (1 for valid tokens, 0 for padding). The pooled representation is projected through a linear layer to class logits.
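The pooling equation above amounts to averaging only the unmasked token vectors. A minimal pure-python sketch (the actual head works on batched tensors):

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Mean-pool token vectors, skipping padding positions (mask value 0).

    hidden_states: list of L vectors (list of lists of floats);
    attention_mask: list of L ints (1 = valid token, 0 = padding).
    Implements h_cls = sum(h_i * m_i) / sum(m_i) from the equation above.
    """
    dim = len(hidden_states[0])
    totals = [0.0] * dim
    valid = 0
    for vec, m in zip(hidden_states, attention_mask):
        if m:
            valid += 1
            for j in range(dim):
                totals[j] += vec[j]
    return [t / max(valid, 1) for t in totals]
```

Note how the padded position contributes nothing to the pooled vector, so summary statistics are independent of sequence padding.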
-
- %=============================================================================
- \section{Training Pipeline}
- \label{sec:training}
- %=============================================================================
-
- The training infrastructure in \texttt{src/training/trainer.py} implements a comprehensive multi-task training loop with modern deep learning practices.
-
- \subsection{Training Configuration}
-
- The \texttt{TrainerConfig} dataclass encapsulates all hyperparameters:
-
- \begin{lstlisting}[caption={TrainerConfig from trainer.py}]
- from dataclasses import dataclass
- from typing import Dict
-
- @dataclass
- class TrainerConfig:
-     max_epochs: int = 1
-     gradient_clip_norm: float = 1.0
-     task_weights: Dict[str, float] | None = None
-     label_smoothing: float = 0.0
-     gradient_accumulation_steps: int = 1
-     scheduler_type: str = "cosine"
-     warmup_steps: int = 0
-     early_stopping_patience: int | None = None
-     gradient_checkpointing: bool = False
-     compile_model: bool = False
- \end{lstlisting}
-
- \subsection{Mixed-Precision Training}
-
- LexiMind uses Automatic Mixed Precision (AMP) with automatic dtype selection:
-
- \begin{quote}
- \emph{``AMP setup: bfloat16 for Ampere+ GPUs, float16 otherwise''} — \texttt{trainer.py}, line 102
- \end{quote}
-
- BFloat16 provides better numerical stability for training while maintaining the memory and speed benefits of reduced precision.
-
- \subsection{Learning Rate Scheduling}
-
- A cosine schedule with linear warmup is implemented:
-
- \begin{equation}
- lr(t) = \begin{cases}
- lr_{max} \cdot \frac{t}{t_{warmup}} & t < t_{warmup} \\
- lr_{min} + \frac{1}{2}(lr_{max} - lr_{min})\left(1 + \cos\left(\frac{\pi(t-t_{warmup})}{T-t_{warmup}}\right)\right) & t \geq t_{warmup}
- \end{cases}
- \end{equation}
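The schedule can be transcribed directly from the equation. A self-contained sketch (not the trainer's scheduler object, which wraps PyTorch's API):

```python
import math

def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at a given step: linear warmup, then cosine decay.

    Direct transcription of the piecewise schedule equation above,
    with T = total_steps.
    """
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)  # linear ramp to max_lr
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

At the end of warmup the cosine term is at its peak, so the two pieces meet continuously at `max_lr`.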
-
- Figure \ref{fig:lr_schedule} visualizes the learning rate schedule over training, showing the 300-step linear warmup followed by cosine decay.
-
- \begin{figure}[htbp]
- \centering
- \includegraphics[width=\columnwidth]{figures/learning_rate_schedule.png}
- \caption{Learning rate schedule with linear warmup (300 steps) followed by cosine annealing. The warmup prevents early training instability while cosine decay ensures smooth convergence.}
- \label{fig:lr_schedule}
- \end{figure}
-
- \subsection{Multi-Task Loss Computation}
-
- The total loss combines task-specific losses with optional weighting:
-
- \begin{equation}
- \mathcal{L}_{total} = \sum_{t \in \text{tasks}} \lambda_t \mathcal{L}_t
- \end{equation}
-
- \begin{itemize}
- \item \textbf{Summarization}: Cross-entropy with label smoothing and \texttt{ignore\_index=-100}
- \item \textbf{Emotion}: Binary Cross-Entropy with Logits (multi-label)
- \item \textbf{Topic}: Standard Cross-Entropy (single-label)
- \end{itemize}
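The weighted combination above reduces to a few lines. A scalar sketch, assuming unlisted tasks default to weight 1.0 as the equation implies:

```python
def total_loss(task_losses, task_weights=None):
    """Weighted sum of per-task losses; unlisted tasks default to weight 1.0.

    Implements L_total = sum_t lambda_t * L_t with lambda_t taken from
    task_weights (e.g. {"topic": 0.3} to down-weight the small topic dataset).
    """
    task_weights = task_weights or {}
    return sum(task_weights.get(task, 1.0) * loss
               for task, loss in task_losses.items())
```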
-
- \subsection{Gradient Handling}
-
- The trainer includes gradient clipping and early stopping:
-
- \begin{quote}
- \emph{``Gradient clipping to prevent exploding gradients [...] Early stopping based on validation loss''} — \texttt{trainer.py}
- \end{quote}
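The early-stopping behavior described in the quote is simple state tracking. A minimal sketch, assuming the trainer's actual bookkeeping in `trainer.py` differs in detail:

```python
class EarlyStopping:
    """Signal a stop after `patience` epochs without validation improvement.

    Minimal sketch of patience-based early stopping on validation loss.
    """

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def update(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0  # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```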
-
- \subsection{Training Loop}
-
- Figure \ref{fig:training_loop} illustrates the training loop structure.
-
- \begin{figure}[htbp]
- \centering
- \begin{tikzpicture}[
- scale=0.65,
- transform shape,
- box/.style={draw, rectangle, minimum width=2.5cm, minimum height=0.5cm, align=center, rounded corners=2pt, font=\small},
- arrow/.style={->, >=stealth},
- ]
-
- % Epoch loop
- \node[box, fill=blue!20] (epoch) at (0, 0) {For each epoch};
-
- % Batch loop
- \node[box, fill=green!20] (batch) at (0, -1.2) {For each batch};
-
- % Task loop
- \node[box, fill=yellow!20] (task) at (0, -2.4) {For each task};
-
- % Forward
- \node[box, fill=orange!20] (forward) at (0, -3.6) {Forward + Loss};
-
- % AMP context
- \node[box, fill=purple!20] (amp) at (0, -4.8) {AMP autocast};
-
- % Backward
- \node[box, fill=red!20] (backward) at (0, -6) {Backward (scaled)};
-
- % Accumulate check
- \node[box, fill=cyan!20] (accum) at (0, -7.2) {Accumulation step?};
-
- % Optimizer step
- \node[box, fill=gray!20] (optim) at (0, -8.4) {Clip + Step + Zero};
-
- % Validation
- \node[box, fill=blue!20] (val) at (3.5, -1.2) {Validation};
-
- % Checkpoint
- \node[box, fill=green!20] (ckpt) at (3.5, -2.4) {Checkpoint};
-
- % Early stopping
- \node[box, fill=red!20] (early) at (3.5, -3.6) {Early Stop?};
-
- % Arrows
- \draw[arrow] (epoch) -- (batch);
- \draw[arrow] (batch) -- (task);
- \draw[arrow] (task) -- (forward);
- \draw[arrow] (forward) -- (amp);
- \draw[arrow] (amp) -- (backward);
- \draw[arrow] (backward) -- (accum);
- \draw[arrow] (accum) -- (optim);
- \draw[arrow] (optim.south) -- ++(0, -0.3) -| ++(-2, 0) |- (task.west);
- \draw[arrow] (epoch.east) -- ++(0.5, 0) |- (val);
- \draw[arrow] (val) -- (ckpt);
- \draw[arrow] (ckpt) -- (early);
-
- \end{tikzpicture}
- \caption{Training loop structure showing nested iteration over epochs, batches, and tasks, with gradient accumulation and validation checkpoints.}
- \label{fig:training_loop}
- \end{figure}
-
- %=============================================================================
- \section{Tasks and Datasets}
- %=============================================================================
-
- LexiMind addresses three complementary NLP tasks:
-
- \subsection{Text Summarization}
-
- \textbf{Task}: Generate concise abstractive summaries from longer documents, focusing on back-cover style book descriptions rather than plot summaries.
-
- \textbf{Datasets}: The summarization corpus comprises 49,086 training samples, 2,727 validation samples, and 2,727 test samples. Literary content consists of Goodreads book descriptions (back-cover blurbs) matched with full texts from Project Gutenberg. Academic content includes arXiv paper abstracts paired with introduction sections. Unlike news-focused summarization models, LexiMind specializes in literary and academic long-form content.
-
- \textbf{Approach}: Encoder-decoder generation with greedy decoding (beam search available). The decoder uses causal masking and cross-attention to encoder representations, with a maximum generation length of 128 tokens.
-
- \textbf{Evaluation}: ROUGE-1/2/L for n-gram overlap, BLEU-4 for fluency, and BERTScore (using RoBERTa-large) for semantic similarity between generated and reference summaries.
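The core idea behind ROUGE-1 is unigram overlap between candidate and reference. A toy sketch of the F1 variant; the reported scores come from a standard ROUGE implementation with its own tokenization and stemming:

```python
from collections import Counter

def rouge1_f1(candidate, reference):
    """Unigram-overlap ROUGE-1 F1 on whitespace tokens (illustrative only)."""
    cand, ref = Counter(candidate.split()), Counter(reference.split())
    overlap = sum((cand & ref).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)
```

This makes the later discussion concrete: a fluent paraphrase with few shared unigrams scores low on ROUGE even when its meaning matches the reference.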
-
- \subsection{Emotion Classification}
-
- \textbf{Task}: Multi-label classification identifying emotions expressed in text, where each sample may have multiple emotion labels.
-
- \textbf{Dataset}: Google's GoEmotions \cite{demszky2020goemotions}, comprising 43,410 training samples, 5,426 validation samples, and 5,427 test samples sourced from Reddit comments.
-
- \textbf{Classes}: 28 emotion categories: admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, joy, love, nervousness, neutral, optimism, pride, realization, relief, remorse, sadness, and surprise.
-
- \textbf{Approach}: Encoder-only processing with mean pooling over token representations, followed by a two-layer classification head with hidden dimension 384. Binary Cross-Entropy with Logits loss enables independent multi-label prediction.
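At inference time, independent multi-label prediction thresholds each logit separately. A one-line sketch (the 0.5-probability cutoff is the conventional choice, assumed here):

```python
def multilabel_predict(logits, threshold=0.0):
    """Indices of predicted labels for one sample.

    With BCE-with-logits, sigmoid(z) > 0.5 is equivalent to z > 0, so a
    zero logit threshold realizes the usual 0.5 probability cutoff while
    letting any number of labels fire independently.
    """
    return [i for i, z in enumerate(logits) if z > threshold]
```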
-
- \subsection{Topic Classification}
-
- \textbf{Task}: Single-label classification assigning documents to one of seven topic categories.
-
- \textbf{Datasets}: A curated collection of 3,402 training samples, 189 validation samples, and 189 test samples drawn from arXiv paper categories and Project Gutenberg book metadata.
-
- \textbf{Classes}: 7 mutually exclusive topics: Arts, Business, Fiction, History, Philosophy, Science, and Technology.
-
- \textbf{Approach}: Encoder-only architecture with mean pooling, identical to emotion classification but using standard Cross-Entropy loss for mutually exclusive classes. Due to the significantly smaller dataset (3.4K vs 43K for emotion), the topic loss weight is reduced to 0.3 during training to prevent overfitting while maintaining balanced multi-task learning.
-
- %=============================================================================
- \section{Model Specifications}
- %=============================================================================
-
- Table \ref{tab:dataset_summary} summarizes the dataset splits used for training and evaluation. Table \ref{tab:model_specs} details the model architecture.
-
- \begin{table}[htbp]
- \centering
- \caption{Dataset Summary}
- \label{tab:dataset_summary}
- \begin{tabular}{lccc}
- \toprule
- \textbf{Task} & \textbf{Train} & \textbf{Val} & \textbf{Test} \\
- \midrule
- Summarization & 49,086 & 2,727 & 2,727 \\
- Emotion & 43,410 & 5,426 & 5,427 \\
- Topic & 3,402 & 189 & 189 \\
- \midrule
- \textbf{Total} & 95,898 & 8,342 & 8,343 \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- \begin{table}[htbp]
- \centering
- \caption{LexiMind Model Specifications}
- \label{tab:model_specs}
- \begin{tabular}{lc}
- \toprule
- \textbf{Parameter} & \textbf{Value} \\
- \midrule
- Hidden dimension ($d_{model}$) & 768 \\
- FFN dimension ($d_{ff}$) & 2048 \\
- Attention heads & 12 \\
- Head dimension & 64 \\
- Encoder layers & 12 \\
- Decoder layers & 12 \\
- Vocabulary size & 32,128 \\
- Max sequence length & 512 \\
- Dropout & 0.1 \\
- Activation & Gated-GELU \\
- Normalization & RMSNorm (Pre-LN) \\
- Position encoding & Relative bias \\
- \midrule
- Total parameters & $\sim$272M \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- %=============================================================================
- \section{Implementation Details}
- %=============================================================================
-
- \subsection{Project Structure}
-
- LexiMind follows a modular architecture:
-
- \begin{verbatim}
- src/
- +-- models/
- |   +-- attention.py     # MHA, RelPosBias
- |   +-- encoder.py       # Encoder blocks
- |   +-- decoder.py       # Decoder blocks
- |   +-- heads.py         # Task heads
- |   +-- multitask.py     # MTL routing
- |   +-- factory.py       # Construction
- +-- data/
- |   +-- tokenization.py  # Tokenizer
- |   +-- dataset.py       # Dataset classes
- |   +-- dataloader.py    # Collators
- +-- training/
-     +-- trainer.py       # Training loop
-     +-- metrics.py       # Evaluation
- scripts/
- +-- train.py             # Main training
- +-- download_data.py     # Dataset download
- +-- inference.py         # CLI inference
- +-- demo_gradio.py       # Web demo
- \end{verbatim}
-
- \subsection{FlashAttention and CUDA Optimizations}
-
- The trainer enables comprehensive hardware-specific optimizations:
-
- \begin{lstlisting}[caption={CUDA optimizations from train.py}]
- if device.type == "cuda":
-     torch.backends.cudnn.benchmark = True
-     torch.backends.cuda.matmul.allow_tf32 = True
-     torch.backends.cudnn.allow_tf32 = True
-     torch.backends.cuda.enable_flash_sdp(True)
-     torch.backends.cuda.enable_mem_efficient_sdp(True)
- \end{lstlisting}
-
- Note that T5-style relative position bias is incompatible with FlashAttention: the bias must be added to the raw attention scores, which breaks the fused kernel. The development configuration therefore disables relative position bias to enable FlashAttention for faster iteration, while production configurations retain relative position bias for better quality.
-
- \subsection{Numerical Stability}
-
- To prevent overflow during mixed-precision training, hidden states are clamped after each sublayer:
-
- \begin{quote}
- \emph{``Clamp inf values for fp16/bf16 training stability (like HuggingFace T5)''} — \texttt{encoder.py}, lines 103--105
- \end{quote}
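The clamping the quote refers to can be shown element-wise in plain Python. The cutoff below is an assumption chosen just under the fp16 maximum of 65,504; the real code uses `torch.clamp` on tensors with a dtype-derived bound:

```python
def clamp_finite(values, max_val=64000.0):
    """Clamp activations into a finite range so fp16 cannot overflow to inf.

    Pure-python stand-in for torch.clamp(x, -max_val, max_val); the 64000.0
    cutoff is an assumed value slightly below fp16's maximum of 65504.
    """
    return [min(max(v, -max_val), max_val) for v in values]
```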
-
- %=============================================================================
- \section{Experimental Setup}
- %=============================================================================
-
- \subsection{Training Configuration}
-
- The final training configuration was optimized for quality and efficiency on an NVIDIA RTX 4070 with 12GB VRAM:
-
- \begin{itemize}
- \item \textbf{Optimizer}: Fused AdamW with weight decay 0.01, $\beta_1=0.9$, $\beta_2=0.98$
- \item \textbf{Learning Rate}: $3 \times 10^{-5}$ with cosine decay
- \item \textbf{Warmup}: 300 steps ($\sim$0.5 epochs)
- \item \textbf{Batch Size}: 10 with 4$\times$ gradient accumulation (effective batch size 40)
- \item \textbf{Precision}: BFloat16 on Ampere+ GPUs with TF32 enabled
- \item \textbf{Gradient Clipping}: Max norm 1.0
- \item \textbf{Gradient Checkpointing}: Enabled for memory efficiency
- \item \textbf{torch.compile}: Dynamic compilation for encoder and decoder
- \item \textbf{Task Weights}: Summarization 1.0, Emotion 1.0, Topic 0.3 (reduced due to small dataset)
- \item \textbf{Early Stopping}: Patience of 3 epochs on validation loss
- \item \textbf{Encoder Freezing}: Bottom 4 layers frozen for stable transfer learning
- \end{itemize}
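The 4$\times$ gradient accumulation above can be illustrated with a scalar sketch: each micro-batch gradient is scaled by 1/accum\_steps so the accumulated update matches one large batch of size 10 $\times$ 4 = 40. Illustrative only; the trainer applies this scaling to the loss before `backward()`:

```python
def accumulated_steps(micro_grads, accum_steps=4):
    """Average each group of `accum_steps` micro-batch gradients into one update.

    Scalar stand-in for gradient accumulation: scale by 1/accum_steps, sum
    within a group, and emit one optimizer update per group boundary.
    """
    updates, buf = [], 0.0
    for i, g in enumerate(micro_grads, 1):
        buf += g / accum_steps      # scale so the sum matches a big-batch mean
        if i % accum_steps == 0:    # optimizer-step boundary
            updates.append(buf)
            buf = 0.0
    return updates
```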
-
- Training completed in 7 epochs ($\sim$6 hours), with early stopping triggered by a validation loss plateau.
-
- %=============================================================================
- \section{Experimental Results}
- \label{sec:results}
- %=============================================================================
-
- We evaluate LexiMind on held-out validation sets for each task. Table \ref{tab:summarization_results} presents the summarization metrics, and Table \ref{tab:classification_results} shows classification performance.
-
- \subsection{Summarization Performance}
-
- \begin{table}[htbp]
- \centering
- \caption{Summarization Evaluation Results}
- \label{tab:summarization_results}
- \begin{tabular}{lc}
- \toprule
- \textbf{Metric} & \textbf{Score} \\
- \midrule
- ROUGE-1 & 0.3064 \\
- ROUGE-2 & 0.0896 \\
- ROUGE-L & 0.1832 \\
- BLEU-4 & 0.0237 \\
- \midrule
- BERTScore Precision & 0.8430 \\
- BERTScore Recall & 0.8179 \\
- \textbf{BERTScore F1} & \textbf{0.8300} \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- The BERTScore F1 of \textbf{0.83} demonstrates strong semantic similarity between generated descriptions and references, indicating the model captures meaning effectively even when exact wording differs. ROUGE scores are typical for abstractive summarization, where the model paraphrases rather than extracts verbatim text.
-
- \subsection{Classification Performance}
-
- \begin{table}[htbp]
- \centering
- \caption{Classification Evaluation Results}
- \label{tab:classification_results}
- \begin{tabular}{llc}
- \toprule
- \textbf{Task} & \textbf{Metric} & \textbf{Score} \\
- \midrule
- \multirow{2}{*}{Topic (7 classes)} & Accuracy & \textbf{85.19\%} \\
- & Macro F1 & 0.8474 \\
- \midrule
- Emotion (28 classes) & Multi-label F1 & 0.1987 \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- Topic classification achieves \textbf{85.2\%} accuracy with balanced per-class performance. The emotion detection task proves more challenging due to the 28-class multi-label setting with inherent label ambiguity in the GoEmotions dataset.
-
- \subsection{Training Dynamics}
-
- Figure \ref{fig:training_curves} illustrates the training dynamics over 7 epochs. The model achieves its lowest validation loss at epoch 4 (summarization loss: 3.698), with the checkpoint from this epoch saved as the best model. Training continued through epoch 7 due to the early stopping patience of 3, but validation loss plateaued, confirming epoch 4 as optimal. The cosine learning rate schedule with 300-step warmup ensures smooth convergence.
-
- \begin{figure}[htbp]
- \centering
- \includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
- \caption{Training and validation loss curves over 7 epochs. Best validation performance achieved at epoch 4 (marked), with subsequent epochs showing slight overfitting on the topic task due to its small dataset size.}
- \label{fig:training_curves}
- \end{figure}
-
- Figure \ref{fig:task_metrics} presents per-task metrics throughout training, showing the distinct learning trajectories for summarization, emotion detection, and topic classification.
-
- \begin{figure}[htbp]
- \centering
- \includegraphics[width=\columnwidth]{figures/task_metrics.png}
- \caption{Task-specific metrics during training: ROUGE-1 for summarization, F1 for emotion detection, and accuracy for topic classification.}
- \label{fig:task_metrics}
- \end{figure}
-
- Figure \ref{fig:training_dynamics} provides a comprehensive view of training dynamics, including loss convergence, per-epoch improvements, cumulative loss reduction, and the train-validation gap, which indicates overfitting behavior.
-
- \begin{figure}[htbp]
- \centering
- \includegraphics[width=\columnwidth]{figures/training_dynamics.png}
- \caption{Training dynamics overview: (top-left) Loss convergence with smoothing, (top-right) Relative improvement per epoch, (bottom-left) Cumulative loss reduction from initial values, (bottom-right) Train-validation gap showing slight overfitting after epoch 4.}
- \label{fig:training_dynamics}
- \end{figure}
-
- \subsection{Per-Class Topic Analysis}
-
- Table \ref{tab:topic_breakdown} shows the per-class performance for topic classification:
-
- \begin{table}[htbp]
- \centering
- \caption{Per-Class Topic Classification Performance}
- \label{tab:topic_breakdown}
- \begin{tabular}{lccc}
- \toprule
- \textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
- \midrule
- Arts & 0.93 & 0.76 & 0.84 \\
- Business & 0.97 & 0.97 & 0.97 \\
- Fiction & 0.95 & 1.00 & 0.97 \\
- History & 0.83 & 0.78 & 0.81 \\
- Philosophy & 0.80 & 0.86 & 0.83 \\
- Science & 0.58 & 0.73 & 0.65 \\
- Technology & 0.86 & 0.89 & 0.87 \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- The model performs best on the Fiction and Business categories, while Science shows the most confusion, likely due to overlap with Technology topics.
-
- %=============================================================================
- \section{Discussion}
- %=============================================================================
-
- \subsection{Key Findings}
-
- \textbf{BERTScore vs. ROUGE}: The high BERTScore F1 (0.83) combined with moderate ROUGE-1 (0.31) illustrates a key characteristic of abstractive summarization. The model generates semantically accurate paraphrases rather than extractive copies---behavior that ROUGE undervalues but BERTScore's contextual embeddings capture effectively. This aligns with our goal of generating back-cover style descriptions rather than plot summaries.
-
- \textbf{Multi-Task Learning Dynamics}: Analysis of training curves reveals distinct learning trajectories across tasks. Topic classification converges rapidly (reaching 99\% training accuracy by epoch 3) due to its smaller dataset, necessitating the reduced weight (0.3) to prevent gradient dominance. Emotion detection shows steady improvement throughout training, with validation F1 increasing from 0.30 to 0.40. Summarization loss decreases monotonically, with the best checkpoint captured at epoch 4.
-
- \textbf{Transfer Learning Benefits}: Initializing from FLAN-T5-base provides strong linguistic priors, enabling competitive performance with only 7 epochs of fine-tuning ($\sim$6 hours on consumer hardware). Freezing the bottom 4 encoder layers preserves general language understanding while allowing upper layers to specialize for our domain-specific tasks.
-
- \textbf{Checkpoint Selection}: The best model checkpoint at epoch 4 achieves the lowest validation summarization loss (3.698) while maintaining strong classification performance. Later epochs show slight overfitting on the topic task, validating our early stopping strategy.
-
- \subsection{Limitations}
-
- \begin{itemize}
- \item \textbf{Emotion Detection}: The 28-class multi-label setting remains challenging, with F1 of 0.20 on validation data. GoEmotions' Reddit-sourced training data may not generalize well to the formal register of literary and academic content.
- \item \textbf{Topic Dataset Imbalance}: With only 3,402 training samples distributed across 7 classes, some categories (notably Science with 0.65 F1) show lower performance due to limited examples and semantic overlap with related categories.
- \item \textbf{Domain Gap}: While Goodreads descriptions provide quality literary summaries, the model's exposure to contemporary fiction is limited by Project Gutenberg's public domain focus on pre-1928 works.
- \end{itemize}
-
- \subsection{Future Work}
-
- Several directions could improve LexiMind's performance:
- \begin{itemize}
- \item \textbf{Domain-Specific Emotion Data}: Fine-tuning on literary emotion annotations rather than Reddit comments could better capture the emotional nuances of literary and academic text.
- \item \textbf{Parameter-Efficient Fine-Tuning}: Integrating LoRA \cite{hu2022lora} would reduce memory requirements and enable experimentation with larger base models (FLAN-T5-large, FLAN-T5-xl).
- \item \textbf{Expanded Topic Dataset}: Augmenting the 3.4K topic samples through back-translation or synthetic data generation could improve classification robustness.
- \end{itemize}
-
- %=============================================================================
- \section{Conclusion}
- %=============================================================================
-
- This paper presented LexiMind, a multi-task NLP system combining a custom Transformer implementation with FLAN-T5 pre-trained weights. The hybrid approach provides architectural transparency while leveraging transfer learning, achieving:
-
- \begin{itemize}
- \item \textbf{Summarization}: BERTScore F1 of 0.83, demonstrating strong semantic fidelity for back-cover style book descriptions
- \item \textbf{Topic Classification}: 85.2\% accuracy and 0.85 macro F1 across 7 categories
- \item \textbf{Emotion Detection}: Multi-label F1 of 0.20 on 28 emotion classes
- \end{itemize}
-
- The complete system trains in approximately 6 hours on a consumer GPU (RTX 4070 12GB), demonstrating that sophisticated multi-task models remain accessible without datacenter-scale resources. The modular codebase serves both as a practical NLP tool for literary and academic content analysis and as an educational resource for understanding Transformer architecture internals.
-
- All code, trained models, and datasets are publicly available, with a live demonstration hosted on HuggingFace Spaces.\footnote{\url{https://huggingface.co/spaces/OliverPerrin/LexiMind}}
-
- %=============================================================================
- % References
- %=============================================================================
-
- \begin{thebibliography}{00}
-
- \bibitem{vaswani2017attention}
- A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin, ``Attention is all you need,'' in \textit{Advances in Neural Information Processing Systems (NeurIPS)}, vol. 30, 2017, pp. 5998--6008. [Online]. Available: \url{https://arxiv.org/abs/1706.03762}
-
- \bibitem{raffel2020exploring}
- C. Raffel, N. Shazeer, A. Roberts, K. Lee, S. Narang, M. Matena, Y. Zhou, W. Li, and P. J. Liu, ``Exploring the limits of transfer learning with a unified text-to-text transformer,'' \textit{Journal of Machine Learning Research}, vol. 21, no. 140, pp. 1--67, 2020. [Online]. Available: \url{https://arxiv.org/abs/1910.10683}
-
- \bibitem{chung2022scaling}
- H. W. Chung, L. Hou, S. Longpre, B. Zoph, Y. Tay, W. Fedus, Y. Li, X. Wang, M. Dehghani, S. Brahma, A. Webson, S. S. Gu, Z. Dai, M. Suzgun, X. Chen, A. Chowdhery, A. Castro-Ros, M. Pellat, K. Robinson, D. Valter, S. Narang, G. Mishra, A. Yu, V. Zhao, Y. Huang, A. Dai, H. Yu, S. Petrov, E. H. Chi, J. Dean, J. Devlin, A. Roberts, D. Zhou, Q. V. Le, and J. Wei, ``Scaling instruction-finetuned language models,'' \textit{arXiv preprint arXiv:2210.11416}, 2022. [Online]. Available: \url{https://arxiv.org/abs/2210.11416}
-
- \bibitem{xiong2020layer}
- R. Xiong, Y. Yang, J. He, K. Zheng, S. Zheng, C. Xing, H. Zhang, Y. Lan, L. Wang, and T. Liu, ``On layer normalization in the transformer architecture,'' in \textit{International Conference on Machine Learning (ICML)}, 2020, pp. 10524--10533. [Online]. Available: \url{https://arxiv.org/abs/2002.04745}
-
- \bibitem{zhang2019root}
- B. Zhang and R. Sennrich, ``Root mean square layer normalization,'' in \textit{Advances in Neural Information Processing Systems (NeurIPS)}, vol. 32, 2019, pp. 12360--12371. [Online]. Available: \url{https://arxiv.org/abs/1910.07467}
-
- \bibitem{ba2016layer}
- J. L. Ba, J. R. Kiros, and G. E. Hinton, ``Layer normalization,'' \textit{arXiv preprint arXiv:1607.06450}, 2016. [Online]. Available: \url{https://arxiv.org/abs/1607.06450}
-
- \bibitem{caruana1997multitask}
- R. Caruana, ``Multitask learning,'' \textit{Machine Learning}, vol. 28, no. 1, pp. 41--75, 1997.
-
- \bibitem{kudo2018sentencepiece}
- T. Kudo and J. Richardson, ``SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing,'' in \textit{Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing: System Demonstrations}, 2018, pp. 66--71. [Online]. Available: \url{https://arxiv.org/abs/1808.06226}
-
- \bibitem{hu2022lora}
- E. J. Hu, Y. Shen, P. Wallis, Z. Allen-Zhu, Y. Li, S. Wang, L. Wang, and W. Chen, ``LoRA: Low-rank adaptation of large language models,'' in \textit{International Conference on Learning Representations (ICLR)}, 2022. [Online]. Available: \url{https://arxiv.org/abs/2106.09685}
-
- \bibitem{dao2022flashattention}
- T. Dao, D. Fu, S. Ermon, A. Rudra, and C. Ré, ``FlashAttention: Fast and memory-efficient exact attention with IO-awareness,'' in \textit{Advances in Neural Information Processing Systems (NeurIPS)}, vol. 35, 2022, pp. 16344--16359. [Online]. Available: \url{https://arxiv.org/abs/2205.14135}
-
- \bibitem{zhang2019bertscore}
- T. Zhang, V. Kishore, F. Wu, K. Q. Weinberger, and Y. Artzi, ``BERTScore: Evaluating text generation with BERT,'' in \textit{International Conference on Learning Representations (ICLR)}, 2020. [Online]. Available: \url{https://arxiv.org/abs/1904.09675}
-
- \bibitem{demszky2020goemotions}
- D. Demszky, D. Movshovitz-Attias, J. Ko, A. Cowen, G. Nemade, and S. Ravi, ``GoEmotions: A dataset of fine-grained emotions,'' in \textit{Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics}, 2020, pp. 4040--4054. [Online]. Available: \url{https://arxiv.org/abs/2005.00547}
-
- \end{thebibliography}
-
- \end{document}
docs/research_paper.tex DELETED
@@ -1,574 +0,0 @@
- % LexiMind: Multi-Task Learning for Literary and Academic Text Understanding
- % Research Paper - Revised with Experimental Rigor
- % Author: Oliver Perrin
-
- \documentclass[conference]{IEEEtran}
- \IEEEoverridecommandlockouts
-
- % Essential packages
- \usepackage{cite}
- \usepackage{amsmath,amssymb,amsfonts}
- \usepackage{graphicx}
- \usepackage{textcomp}
- \usepackage{xcolor}
- \usepackage{hyperref}
- \usepackage{booktabs}
- \usepackage{multirow}
- \usepackage{array}
- \usepackage{caption}
-
- % TikZ for diagrams
- \usepackage{tikz}
- \usetikzlibrary{shapes.geometric, arrows, positioning}
-
- % Hyperref setup
- \hypersetup{
- colorlinks=true,
- linkcolor=blue,
- citecolor=blue,
- urlcolor=blue
- }
-
- \def\BibTeX{{\rm B\kern-.05em{\sc i\kern-.025em b}\kern-.08em
- T\kern-.1667em\lower.7ex\hbox{E}\kern-.125emX}}
-
- \begin{document}
-
- \title{Multi-Task Learning for Literary and Academic Text:\\Does Joint Training Help or Hurt?}
- \author{\IEEEauthorblockN{Oliver Perrin}\\
- \IEEEauthorblockA{Department of Computer Science\\
- Appalachian State University\\
- Email: perrinot@appstate.edu}}
-
- \maketitle
-
- \begin{abstract}
- Multi-task learning (MTL) promises improved generalization through shared representations, but its benefits depend heavily on task relatedness and domain characteristics. We investigate whether MTL improves performance on literary and academic text understanding---domains underrepresented in existing benchmarks dominated by news articles. Using a FLAN-T5-base encoder-decoder backbone (272M parameters), we jointly train on three tasks: abstractive summarization (49K samples: full-text passages $\rightarrow$ descriptive summaries from Goodreads book descriptions and arXiv abstracts), topic classification (3.4K samples across 7 categories), and multi-label emotion detection (43K samples from GoEmotions). Through ablation studies, we find that naive MTL with mean pooling and round-robin scheduling yields mixed results: topic classification gains +3.2\% accuracy, summarization remains stable, but emotion detection suffers negative transfer ($-$0.02 F1). We then show that two targeted interventions---\textbf{learned attention pooling} for the emotion head and \textbf{temperature-based task sampling} ($\alpha=0.5$)---eliminate negative transfer entirely, improving multi-task emotion sample-averaged F1 from 0.199 to 0.352 (+77\%), substantially exceeding the single-task baseline (0.218). With per-class threshold tuning, emotion macro F1 reaches 0.294. Topic classification improves to 85.7\% accuracy (95\% CI: [80.4\%, 91.0\%]), and summarization quality remains robust (ROUGE-1: 0.310, ROUGE-L: 0.185). Per-domain analysis reveals a significant quality gap between academic summaries (ROUGE-1: 0.319) and literary summaries (ROUGE-1: 0.206), attributable to the 11:1 training imbalance. We additionally contribute inter-task gradient conflict diagnostics, cross-task document deduplication, bootstrap confidence intervals, and multi-seed evaluation infrastructure. Our analysis demonstrates that architectural isolation of task-specific components (attention pooling) combined with balanced optimization (temperature sampling) can convert negative transfer to positive transfer in MTL systems.
- \end{abstract}
-
- \begin{IEEEkeywords}
- Multi-Task Learning, Transfer Learning, Text Summarization, Emotion Classification, FLAN-T5
- \end{IEEEkeywords}
-
- %=============================================================================
- \section{Introduction}
- %=============================================================================
-
- Multi-task learning (MTL) \cite{caruana1997multitask} trains a single model on multiple related tasks, hypothesizing that shared representations improve generalization. In NLP, MTL has shown promise for sequence labeling \cite{collobert2011natural}, machine translation \cite{johnson2017google}, and question answering \cite{mccann2018natural}. However, recent work highlights that MTL does not universally help---negative transfer can occur when tasks compete for model capacity \cite{standley2020tasks}, and gradient conflicts between tasks can degrade joint optimization \cite{yu2020gradient}.
-
- We investigate MTL effectiveness in a specific, underexplored domain: \textbf{literary and academic text understanding}. Unlike news articles---which dominate existing benchmarks like CNN/DailyMail \cite{nallapati2016abstractive} and XSum \cite{narayan2018don}---literary and academic texts exhibit distinct characteristics: longer context dependencies, domain-specific vocabulary, and different summary styles (descriptive abstracts vs. extractive headlines). Recent domain-specific summarization work, including BookSum \cite{kryscinski2021booksum} for narrative summarization and CiteSum \cite{mao2022citesum} for citation-contextualized scientific summaries, demonstrates that domain matters for summarization quality---yet multi-task learning effects within these domains remain unstudied.
-
- Our study addresses three research questions:
-
- \begin{enumerate}
- \item[\textbf{RQ1}] Does multi-task learning improve performance over single-task specialists on literary/academic domains?
- \item[\textbf{RQ2}] Which tasks benefit from joint training, and which suffer negative transfer?
- \item[\textbf{RQ3}] How much does pre-trained knowledge (FLAN-T5) contribute relative to task-specific fine-tuning?
- \end{enumerate}
-
- To answer these questions, we construct \textbf{LexiMind}, a multi-task system built on FLAN-T5-base \cite{chung2022scaling} that performs abstractive summarization, topic classification, and emotion detection. We conduct ablations comparing multi-task vs. single-task training, with vs. without FLAN-T5 initialization, and different task weight configurations. Our primary experimental contribution is the empirical characterization of transfer effects across these heterogeneous tasks:
-
- \begin{itemize}
- \item \textbf{Topic classification benefits from MTL} (+3.7\% accuracy over single-task), leveraging shared encoder representations from the larger summarization dataset.
- \item \textbf{Summarization is robust to MTL}, showing stable ROUGE scores despite sharing encoder capacity with classification heads.
- \item \textbf{Emotion detection: from negative to positive transfer}. Naive MTL with mean pooling degrades emotion F1 by $-$0.02; learned attention pooling combined with temperature-based task sampling reverses this, yielding +0.134 F1 over the single-task baseline.
- \item \textbf{Transfer learning dominates}: FLAN-T5 initialization provides the bulk of final performance; fine-tuning adds crucial domain adaptation.
- \end{itemize}
-
- We acknowledge important limitations: our main results are from single-seed runs, though we provide bootstrap confidence intervals and multi-seed evaluation infrastructure. We discuss these openly in Section~\ref{sec:limitations} and identify concrete follow-up methods (Ortho-LoRA \cite{ortholora2025}, PiKE \cite{pike2025}, ScaLearn \cite{scallearn2023}) as future work.
-
- %=============================================================================
- \section{Related Work}
- %=============================================================================
-
- \subsection{Multi-Task Learning in NLP}
-
- Collobert et al. \cite{collobert2011natural} demonstrated that joint training on POS tagging, chunking, and NER improved over single-task models. T5 \cite{raffel2020exploring} unified diverse NLP tasks through text-to-text framing, showing strong transfer across tasks. However, Standley et al. \cite{standley2020tasks} found that naive MTL often underperforms single-task learning, with performance depending on task groupings. More recently, Aghajanyan et al. \cite{aghajanyan2021muppet} showed that large-scale multi-task pre-finetuning can improve downstream performance, suggesting that the benefits of MTL depend on training scale and task diversity.
-
- \textbf{Gradient conflict and loss balancing.} Yu et al. \cite{yu2020gradient} proposed PCGrad, which projects conflicting gradients to reduce interference, while Liu et al. \cite{liu2021conflict} introduced CAGrad for conflict-averse optimization. Chen et al. \cite{chen2018gradnorm} proposed GradNorm for dynamically balancing task losses based on gradient magnitudes. Kendall et al. \cite{kendall2018multi} explored uncertainty-based task weighting. Our work uses fixed loss weights---a simpler but less adaptive approach---but includes gradient conflict diagnostics (inter-task cosine similarity monitoring) to characterize optimization interference. The negative transfer we observe on emotion detection makes dedicated mitigation methods a natural and important follow-up.
-
- \textbf{Recent advances in multi-task optimization.} Several recent methods address task interference more precisely. Ortho-LoRA \cite{ortholora2025} applies orthogonal constraints to low-rank adapter modules, preventing gradient interference between tasks while maintaining parameter efficiency. PiKE \cite{pike2025} proposes parameter-efficient knowledge exchange mechanisms that allow selective sharing between tasks, reducing negative transfer. ScaLearn \cite{scallearn2023} introduces shared attention layers with task-specific scaling factors, enabling fine-grained control over representation sharing. Complementary empirical work on task grouping via transfer-gain estimates \cite{taskgrouping2024} provides principled methods for deciding which tasks to train jointly, while neuron-centric MTL analysis \cite{neuroncentric2024} reveals that individual neurons specialize for different tasks, suggesting that architectural isolation strategies can be guided by activation patterns. These methods represent promising extensions to our current fixed-weight approach.
-
- \textbf{Multi-domain multi-task studies.} Aribandi et al. \cite{aribandi2022ext5} studied extreme multi-task scaling and found that not all tasks contribute positively. Our work provides complementary evidence at smaller scale, showing that even within a three-task setup, transfer effects are heterogeneous and depend on domain alignment.
-
- \subsection{Literary and Academic Summarization}
-
- Most summarization benchmarks focus on news \cite{nallapati2016abstractive, narayan2018don}. BookSum \cite{kryscinski2021booksum} introduced chapter-level and book-level summarization for literary texts, but targets plot summaries rather than descriptive abstracts. arXiv summarization \cite{cohan2018discourse} addresses academic papers with discourse-aware models. CiteSum \cite{mao2022citesum} leverages citation sentences as summaries for scientific papers. Our summarization setup differs from these: we pair literary source passages (extracted from Project Gutenberg full texts, avg. 3,030 characters) with Goodreads book descriptions (avg. 572 characters) as targets, training the model to generate \textit{what a book is about} rather than plot recaps. For academic text, arXiv paper body text (avg. 3,967 characters) is paired with abstracts (avg. 1,433 characters). The resulting compression ratios (0.19 for literary, 0.36 for academic) are closer to genuine summarization than short paraphrasing.
-
- \subsection{Emotion Detection}
-
- GoEmotions \cite{demszky2020goemotions} provides 28 fine-grained emotion labels from Reddit comments. The original work reports 0.46 macro F1 using BERT-base with per-label thresholds tuned on the validation set. Subsequent work achieves 0.35--0.46 macro F1 depending on the model and threshold strategy. Importantly, all published GoEmotions baselines use encoder-only architectures (BERT, RoBERTa) rather than encoder-decoder models like T5. Our setup differs in both architecture (encoder-decoder with attention-pooled encoder states for emotion detection) and domain (training encoder primarily on literary/academic summarization), making direct comparison to published baselines informative but not fully controlled.
-
- %=============================================================================
- \section{Experimental Setup}
- %=============================================================================
-
- \subsection{Task Formulations}
- \label{sec:task_formulation}
-
- We define three tasks with explicit input-output specifications:
-
- \textbf{Summarization (generative).} The input is a passage of source text; the target is a descriptive summary. For literary texts, the source is a passage from a Project Gutenberg full text (mean: 3,030 characters, truncated to 512 tokens), and the target is the corresponding Goodreads book description (mean: 572 characters)---a back-cover style blurb describing \textit{what the book is about}, not a plot recap. For academic texts, the source is a passage from an arXiv paper body (mean: 3,967 characters, truncated to 512 tokens), and the target is the paper's abstract (mean: 1,433 characters, truncated to 512 tokens). This formulation is closer to genuine document summarization than paraphrasing: the average compression ratios are 0.19 (literary) and 0.36 (academic), comparable to standard summarization benchmarks.
-
- \textbf{Topic classification (discriminative, single-label).} The input is a text passage; the output is one of 7 classes: \textbf{Arts, Business, Fiction, History, Philosophy, Science, Technology}. Sources include 20 Newsgroups (mapped to our label taxonomy), Project Gutenberg subject metadata (for Fiction and Arts), and arXiv category metadata (for Science and Technology).
-
- \textbf{Emotion detection (discriminative, multi-label).} The input is a text passage; the output is a subset of 28 emotion labels from GoEmotions \cite{demszky2020goemotions}. Labels are predicted via sigmoid activation with a fixed threshold of 0.3 during training evaluation and 0.5 during inference. We use a fixed threshold rather than per-class tuning; this simplifies the setup but likely underestimates achievable performance (see Section~\ref{sec:emotion_analysis}).
-
- \subsection{Datasets}
-
- Table \ref{tab:datasets} summarizes dataset statistics.
-
- \begin{table}[htbp]
- \centering
- \caption{Dataset Statistics. Summarization pairs combine literary (Goodreads + Gutenberg) and academic (arXiv) sources; the academic subset is roughly 11$\times$ larger.}
- \label{tab:datasets}
- \begin{tabular}{llrrr}
- \toprule
- \textbf{Task} & \textbf{Source} & \textbf{Train} & \textbf{Val} & \textbf{Test} \\
- \midrule
- \multirow{3}{*}{Summarization} & Goodreads + Gutenberg & $\sim$4K & -- & -- \\
- & arXiv (body $\rightarrow$ abstract) & $\sim$45K & -- & -- \\
- & \textit{Combined} & 49,086 & 2,727 & 2,727 \\
- \midrule
- Topic (7 classes) & 20News + Gutenberg + arXiv & 3,402 & 189 & 189 \\
- \midrule
- Emotion (28 labels) & GoEmotions (Reddit) & 43,410 & 5,426 & 5,427 \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- \textbf{Dataset curation.} Summarization pairs are constructed by matching Gutenberg full texts with Goodreads descriptions via title/author matching, and by pairing arXiv paper bodies with their abstracts. Text is truncated to 512 tokens (max encoder input length). No deduplication was performed \textit{within} the literary and academic subsets, as they are drawn from disjoint sources. We note that the academic subset is substantially larger ($\sim$45K vs. $\sim$4K literary), creating an approximately 11:1 domain imbalance within the summarization task---this imbalance means the encoder is disproportionately shaped by academic text and may affect literary summarization quality (see Section~\ref{sec:limitations}). Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects, 20 Newsgroups categories) and mapped to our 7-class taxonomy; no manual annotation was performed, which introduces potential noise from metadata inaccuracies (e.g., a multidisciplinary paper categorized only as ``Science'' when it also involves ``Technology''). GoEmotions is used as-is from the HuggingFace datasets hub.
-
- \textbf{Cross-task deduplication.} Because the topic classification dataset draws from a subset of the same sources as the summarization dataset (arXiv, Project Gutenberg), we perform cross-task document deduplication to prevent data leakage. Using MD5 fingerprints of normalized text prefixes, we identify and remove any topic/emotion examples whose source text appears in the summarization training set. This ensures our MTL evaluation is not confounded by overlapping examples across tasks.
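The fingerprint-based filter described above can be sketched as follows. This is a minimal illustration, not the repository's actual code: the names `fingerprint` and `dedup_against`, the whitespace/case normalization, and the 256-character prefix length are all assumptions.

```python
import hashlib

def fingerprint(text: str, prefix_len: int = 256) -> str:
    # Normalize (lowercase, collapse whitespace), then hash a fixed-length
    # prefix so trivially reformatted copies of a document still collide.
    normalized = " ".join(text.lower().split())[:prefix_len]
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

def dedup_against(summarization_texts, examples, key=lambda ex: ex["text"]):
    # Drop classification examples whose source text already appears
    # in the summarization training set.
    seen = {fingerprint(t) for t in summarization_texts}
    return [ex for ex in examples if fingerprint(key(ex)) not in seen]
```

Hashing a normalized prefix rather than the full document keeps the filter cheap and robust to truncation differences between the two datasets.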
-
- \textbf{Note on dataset sizes.} The large disparity between topic (3.4K) and summarization (49K) training sets is a key experimental variable: it tests whether a low-resource classification task can benefit from shared representations with a high-resource generative task.
-
- \subsection{Model Architecture}
-
- LexiMind uses FLAN-T5-base (272M parameters) as the backbone, with a custom reimplementation that loads pre-trained weights via a factory module for architectural transparency:
-
- \begin{itemize}
- \item 12-layer encoder, 12-layer decoder
- \item 768-dimensional hidden states, 12 attention heads
- \item T5-style relative position bias (no absolute positional embeddings)
- \item Pre-Layer Normalization with RMSNorm \cite{zhang2019root}
- \item FlashAttention via PyTorch 2.0 SDPA when compatible
- \end{itemize}
-
- Task-specific heads branch from the shared encoder:
- \begin{itemize}
- \item \textbf{Summarization}: Full decoder with language modeling head (cross-entropy loss with label smoothing $\epsilon=0.1$, greedy decoding with max length 512 tokens)
- \item \textbf{Topic}: Linear classifier on mean-pooled encoder hidden states (cross-entropy loss)
- \item \textbf{Emotion}: Linear classifier on \textit{attention-pooled} encoder hidden states with sigmoid activation (binary cross-entropy loss). Instead of naive mean pooling, a learned attention query computes a weighted average over encoder positions: $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$ where $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ and $\mathbf{q} \in \mathbb{R}^d$ is a trainable query vector. This allows the emotion head to attend to emotionally salient positions rather than treating all tokens equally.
- \end{itemize}
-
- \textbf{Architectural note.} The attention pooling mechanism for emotion detection was introduced to address a limitation of mean pooling: emotional content is typically concentrated in specific tokens or phrases, and mean pooling dilutes these signals across the full sequence. For topic classification, mean pooling remains effective because topical information is distributed more uniformly. We discuss the trade-offs of classification in encoder-decoder models in Section~\ref{sec:emotion_analysis}.
-
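The pooling equation $\mathbf{h} = \sum_i \alpha_i \mathbf{h}_i$ with $\alpha_i = \mathrm{softmax}(\mathbf{q}^\top \mathbf{h}_i / \sqrt{d})$ can be sketched in plain Python (the real model would do this batched in PyTorch; the function name and mask convention here are illustrative assumptions):

```python
import math

def attention_pool(hidden, query, mask=None):
    # hidden: list of T token vectors (each of length d); query: learned q.
    # Returns sum_i alpha_i * h_i with alpha = softmax(q . h_i / sqrt(d));
    # masked (padding) positions receive zero attention weight.
    d = len(query)
    scores = []
    for i, h in enumerate(hidden):
        if mask is not None and not mask[i]:
            scores.append(float("-inf"))
        else:
            scores.append(sum(q * x for q, x in zip(query, h)) / math.sqrt(d))
    peak = max(scores)                       # stabilize the softmax
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    return [sum(a * h[j] for a, h in zip(alpha, hidden)) for j in range(d)]
```

With a zero query all positions get equal weight (recovering mean pooling); a trained query skews the weights toward salient tokens.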
- \subsection{Training Configuration}
-
- All experiments use consistent hyperparameters unless otherwise noted:
-
- \begin{itemize}
- \item \textbf{Optimizer}: Fused AdamW, lr=$3\times10^{-5}$, weight decay=0.01, $\beta_1$=0.9, $\beta_2$=0.98
- \item \textbf{Batch size}: 10 per step $\times$ 4 gradient accumulation = 40 effective
- \item \textbf{Schedule}: 300-step linear warmup, cosine decay to 0.1$\times$ peak lr
- \item \textbf{Max epochs}: 8 with early stopping (patience=3 on validation loss)
- \item \textbf{Precision}: BFloat16 on NVIDIA RTX 4070 (12GB VRAM)
- \item \textbf{Gradient clipping}: Max norm 1.0
- \item \textbf{Encoder freezing}: Bottom 4 layers frozen for stable transfer learning
- \end{itemize}
-
- \textbf{Task scheduling.} We use \textbf{temperature-based sampling}: task $i$ is sampled with probability $p_i \propto n_i^\alpha$, where $n_i$ is the dataset size and $\alpha = 0.5$ (square-root scaling). This gives sampling probabilities of approximately 45\% summarization, 43\% emotion, and 12\% topic---ensuring the small topic dataset receives proportionally more gradient updates than pure proportional sampling would provide, while still exposing the model more frequently to larger datasets. We compared this against round-robin scheduling (equal update frequency regardless of dataset size) in preliminary experiments and found temperature sampling yields substantially better emotion detection performance.
-
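The sampling rule $p_i \propto n_i^\alpha$ is a one-liner; the sketch below (function name is an assumption) reproduces the quoted probabilities from the three dataset sizes:

```python
def sampling_probs(sizes, alpha=0.5):
    # p_i proportional to n_i ** alpha: alpha=1 is proportional sampling,
    # alpha=0 is uniform, alpha=0.5 is the square-root schedule used here.
    weights = [n ** alpha for n in sizes]
    total = sum(weights)
    return [w / total for w in weights]
```

Plugging in the training-set sizes (49,086 summarization, 43,410 emotion, 3,402 topic) gives roughly 0.45 / 0.43 / 0.12, matching the percentages stated above.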
- \textbf{Loss weighting.} Task losses are combined with fixed weights: summarization=1.0, emotion=1.0, topic=0.3. The reduced topic weight was chosen to prevent the small topic dataset (3.4K samples, exhausted in $\sim$85 steps) from dominating gradients through rapid overfitting. We did not explore dynamic weighting methods such as GradNorm \cite{chen2018gradnorm} or uncertainty weighting \cite{kendall2018multi}; given the negative transfer observed on emotion, these methods could potentially improve results and are identified as future work.
-
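The fixed-weight combination amounts to a simple scalarization of the per-task losses; a minimal sketch (names and dict layout are illustrative assumptions):

```python
TASK_WEIGHTS = {"summarization": 1.0, "emotion": 1.0, "topic": 0.3}

def combined_loss(task_losses, weights=TASK_WEIGHTS):
    # Fixed-weight scalarization: L = 1.0*L_summ + 1.0*L_emotion + 0.3*L_topic.
    return sum(weights[task] * loss for task, loss in task_losses.items())
```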
- \textbf{Gradient conflict monitoring.} To characterize optimization interference between tasks, we implement periodic gradient conflict diagnostics. At configurable intervals during training, per-task gradients are computed independently and compared via cosine similarity: $\cos(\mathbf{g}_i, \mathbf{g}_j) = \mathbf{g}_i \cdot \mathbf{g}_j / (\|\mathbf{g}_i\| \|\mathbf{g}_j\|)$. Negative cosine similarity indicates a gradient conflict---tasks pulling the shared parameters in opposing directions. Conflict rates (fraction of measured steps with $\cos < 0$) are logged to MLflow for analysis. This diagnostic does not modify training dynamics (unlike PCGrad \cite{yu2020gradient} or CAGrad \cite{liu2021conflict}), but provides empirical evidence for whether gradient conflicts contribute to observed negative transfer.
-
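The diagnostic above reduces to a cosine similarity over flattened gradient vectors plus a running conflict rate; a minimal sketch under that reading (function names are assumptions, and a real implementation would operate on framework tensors):

```python
import math

def grad_cosine(g1, g2):
    # Cosine similarity between two flattened per-task gradient vectors;
    # a negative value flags a gradient conflict.
    dot = sum(a * b for a, b in zip(g1, g2))
    norm1 = math.sqrt(sum(a * a for a in g1))
    norm2 = math.sqrt(sum(b * b for b in g2))
    return dot / (norm1 * norm2)

def conflict_rate(gradient_pairs):
    # Fraction of measured steps whose task gradients oppose each other.
    conflicts = sum(1 for g1, g2 in gradient_pairs if grad_cosine(g1, g2) < 0)
    return conflicts / len(gradient_pairs)
```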
- \textbf{Early stopping.} Early stopping is based on the combined weighted validation loss (using the same task weights as training) with patience of 3 epochs. The best checkpoint is selected by minimum combined validation loss.
-
- \subsection{Baselines and Ablations}
-
- We compare five configurations:
-
- \begin{enumerate}
- \item \textbf{Random/Majority}: Random predictions for classification; summarization is not evaluated against random baselines (ROUGE of random text is near zero).
- \item \textbf{FLAN-T5-base (zero-shot)}: Pre-trained model with task-appropriate prompts, no fine-tuning.
- \item \textbf{Single-Task}: Separate models fine-tuned on each task individually with identical hyperparameters.
- \item \textbf{Multi-Task Baseline}: Joint training with mean pooling and round-robin scheduling.
- \item \textbf{Multi-Task Improved}: Joint training with attention pooling for emotion and temperature sampling ($\alpha=0.5$).
- \end{enumerate}
-
- We additionally ablate FLAN-T5 initialization vs. random initialization to isolate transfer learning contribution.
-
- \subsection{Evaluation Metrics}
-
- \begin{itemize}
- \item \textbf{Summarization}: ROUGE-1/2/L \cite{lin2004rouge} (lexical overlap) and BLEU-4 (n-gram precision with brevity penalty). ROUGE-1 serves as the primary metric for summarization quality. BERTScore \cite{zhang2019bertscore} is available as an optional semantic similarity metric but is not used in our primary evaluation due to its high computational cost and the difficulty of interpreting its absolute values. Per-domain breakdown (literary vs. academic) is provided to analyze domain-specific quality.
- \item \textbf{Topic}: Accuracy and Macro F1 (unweighted average across 7 classes).
- \item \textbf{Emotion}: We report three complementary F1 variants: (1) \textbf{Sample-averaged F1}---computed per-sample as the harmonic mean of per-sample precision and recall, then averaged across all samples; (2) \textbf{Macro F1}---averaged per-class F1 across all 28 emotion labels, treating each class equally regardless of frequency; (3) \textbf{Micro F1}---aggregated across all class predictions, weighting by class frequency. We additionally report per-class precision, recall, and F1 for all 28 emotions, enabling fine-grained error analysis. \textbf{Per-class threshold tuning}: instead of a fixed threshold (0.3 or 0.5), we optionally tune per-class sigmoid thresholds on the validation set by sweeping $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ and selecting the threshold maximizing per-class F1.
- \end{itemize}
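The per-class threshold sweep described above is a small grid search per emotion label; a self-contained sketch for one class (function names are assumptions, not the repository's API):

```python
def f1(tp, fp, fn):
    # Harmonic mean of precision and recall, expressed via counts.
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def tune_threshold(probs, labels, grid=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)):
    # Sweep tau over the grid on validation data for a single class and
    # return the threshold that maximizes that class's F1.
    best_tau, best_f1 = 0.5, -1.0
    for tau in grid:
        preds = [p >= tau for p in probs]
        tp = sum(1 for p, y in zip(preds, labels) if p and y)
        fp = sum(1 for p, y in zip(preds, labels) if p and not y)
        fn = sum(1 for p, y in zip(preds, labels) if not p and y)
        score = f1(tp, fp, fn)
        if score > best_f1:
            best_tau, best_f1 = tau, score
    return best_tau, best_f1
```

Running this independently for each of the 28 labels yields the per-class threshold vector used for the tuned macro-F1 numbers.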
-
- \textbf{Statistical rigor.} To address limitations of single-seed evaluation, we implement bootstrap confidence intervals (1,000 resamples, 95\% percentile CI) for all key metrics. For summarization, per-sample ROUGE-1 and ROUGE-L scores are bootstrapped; for emotion, per-sample F1 values; for topic, per-sample correctness indicators. We additionally provide \texttt{paired\_bootstrap\_test()} for comparing two system configurations on the same test set (null hypothesis: system B $\leq$ system A). Multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) automates training across $k$ seeds and reports mean $\pm$ standard deviation across runs, enabling variance-aware claims. Results in Table~\ref{tab:main_results} remain single-seed but should be validated with multi-seed runs before drawing strong conclusions.
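A percentile bootstrap over per-sample scores, as used here, can be sketched in a few lines (the function name, seeding, and index rounding are illustrative assumptions):

```python
import random

def bootstrap_ci(scores, n_resamples=1000, level=0.95, seed=0):
    # Percentile bootstrap CI for the mean of per-sample metric scores:
    # resample with replacement, collect means, take the 2.5/97.5 percentiles.
    rng = random.Random(seed)
    n = len(scores)
    means = sorted(
        sum(rng.choice(scores) for _ in range(n)) / n for _ in range(n_resamples)
    )
    lower = means[int(round((1 - level) / 2 * (n_resamples - 1)))]
    upper = means[int(round((1 + level) / 2 * (n_resamples - 1)))]
    return lower, upper
```

The same resampling indices, applied to two systems' per-sample scores simultaneously, give the paired bootstrap test mentioned above.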
-
- %=============================================================================
- \section{Results}
- %=============================================================================
-
- \subsection{Main Results}
-
- Table \ref{tab:main_results} compares single-task specialists, baseline MTL (mean pooling, round-robin scheduling), and improved MTL (attention pooling, temperature sampling).
-
- \begin{table}[htbp]
- \centering
- \caption{Main Results. Single-Task and MTL Baseline use mean pooling and round-robin scheduling. MTL Improved uses attention pooling for emotion and temperature sampling ($\alpha=0.5$). All results are single-seed. Bold indicates best.}
- \label{tab:main_results}
- \begin{tabular}{llccc}
- \toprule
- \textbf{Task} & \textbf{Metric} & \textbf{Single} & \textbf{MTL Base} & \textbf{MTL Impr.} \\
- \midrule
- \multirow{3}{*}{Summ.} & ROUGE-1 & 0.298 & 0.306 & \textbf{0.310} \\
- & ROUGE-2 & 0.085 & 0.090 & \textbf{0.091} \\
- & ROUGE-L & 0.179 & 0.183 & \textbf{0.185} \\
- \midrule
- \multirow{2}{*}{Topic} & Accuracy & 82.0\% & 85.2\% & \textbf{85.7\%} \\
- & Macro F1 & 0.812 & 0.847 & \textbf{0.854} \\
- \midrule
- \multirow{3}{*}{Emotion} & Sample F1 & 0.218 & 0.199 & \textbf{0.352} \\
- & Macro F1 & --- & --- & 0.143 \\
- & Micro F1 & --- & --- & \textbf{0.443} \\
- \bottomrule
- \end{tabular}
- \end{table}
-
- \textbf{Key finding}: Attention pooling and temperature sampling yield improvements across \textit{all} tasks, with the largest impact on emotion detection:
245
-
246
- \begin{itemize}
247
- \item \textbf{Emotion detection: negative transfer eliminated.} Baseline MTL with mean pooling degraded emotion F1 by $-$0.019 vs. single-task. With attention pooling and temperature sampling, multi-task emotion F1 improves to 0.352---a +0.134 gain over single-task (0.218) and +0.153 over baseline MTL (0.199). The attention pooling mechanism allows the emotion head to focus on emotionally salient tokens rather than averaging over the full sequence, which is critical for the sparse multi-label task. Temperature sampling ensures the emotion task receives proportional gradient exposure ($\sim$43\% of steps).
248
-
249
- \item \textbf{Topic classification: +3.7\% accuracy} over single-task (85.7\% vs. 82.0\%, 95\% CI: [80.4\%, 91.0\%]). The small topic dataset (3.4K samples) benefits from shared encoder representations learned from the larger summarization corpus (49K samples). The bootstrap CI is wide due to the small validation set (189 samples), but the lower bound (80.4\%) still exceeds the single-task point estimate.
250
-
251
- \item \textbf{Summarization remains stable} across all configurations. ROUGE-1 improves slightly from 0.298 (single-task) to 0.310 (improved MTL). The decoder---which contains half the model's parameters---insulates summarization from classification interference. ROUGE-1 95\% CI: [0.306, 0.313].
252
- \end{itemize}
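The attention-pooling mechanism credited above can be sketched in a few lines of NumPy. This is an illustrative sketch only: the learned query vector `q`, the shapes, and the masking constant are assumptions, not the repository's actual implementation.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_pool(H, q, mask):
    # H: (T, d) token states; q: (d,) learned query; mask: (T,) 1 for real tokens.
    # Padding positions get a large negative score so softmax assigns them ~0 weight.
    scores = np.where(mask > 0, H @ q, -1e9)
    w = softmax(scores)
    return w @ H, w  # weighted sum emphasizes salient tokens; w exposes the weights

rng = np.random.default_rng(0)
T, d = 6, 4
H = rng.normal(size=(T, d))
mask = np.array([1, 1, 1, 1, 0, 0], dtype=float)  # last two tokens are padding
q = rng.normal(size=d)
pooled, w = attention_pool(H, q, mask)
```

Unlike mean pooling, the weighted sum lets a single emotionally salient token dominate the pooled representation.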
253
-
254
- \subsection{Baseline Comparisons}
255
- \label{sec:baseline_discussion}
256
-
257
- Table \ref{tab:baselines} contextualizes our results against trivial and zero-shot baselines.
258
-
259
- \begin{table}[htbp]
260
- \centering
261
- \caption{Comparison with Baselines (Improved MTL Configuration)}
262
- \label{tab:baselines}
263
- \begin{tabular}{lccc}
264
- \toprule
265
- \textbf{Model} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
266
- \midrule
267
- Random/Majority & --- & 14.3\% & 0.036 \\
268
- FLAN-T5 zero-shot & 0.121 & 58.2\% & 0.089 \\
269
- Single-Task & 0.179 & 82.0\% & 0.218 \\
270
- \textbf{Multi-Task (Impr.)} & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
271
- \bottomrule
272
- \end{tabular}
273
- \end{table}
274
-
275
- Fine-tuning provides substantial gains over zero-shot across all tasks (+0.064 ROUGE-L, +27\% topic accuracy, +0.13 emotion F1), demonstrating the importance of domain adaptation even with instruction-tuned models. The improved MTL configuration also surpasses the single-task baselines on all three tasks, showing that the combination of attention pooling and temperature sampling enables positive transfer even for the domain-mismatched emotion task.
276
-
277
- \subsection{Ablation: Transfer Learning Contribution}
278
-
279
- Table \ref{tab:transfer_ablation} isolates the contribution of FLAN-T5 pre-training by comparing against random initialization with identical architecture and training.
280
-
281
- \begin{table}[htbp]
282
- \centering
283
- \caption{Effect of Pre-trained Initialization (Improved MTL Setting)}
284
- \label{tab:transfer_ablation}
285
- \begin{tabular}{lccc}
286
- \toprule
287
- \textbf{Initialization} & \textbf{Summ (R-L)} & \textbf{Topic (Acc)} & \textbf{Emot (F1)} \\
288
- \midrule
289
- Random & 0.098 & 45.2\% & 0.082 \\
290
- FLAN-T5-base & \textbf{0.185} & \textbf{85.7\%} & \textbf{0.352} \\
291
- \midrule
292
- \textit{Absolute gain} & +0.087 & +40.5\% & +0.270 \\
293
- \bottomrule
294
- \end{tabular}
295
- \end{table}
296
-
297
- FLAN-T5 initialization provides large absolute gains across all tasks. \textbf{Pre-training is necessary for competitive performance}---random initialization produces substantially worse results on all tasks even with identical data and training budget. Fine-tuning provides the remaining domain adaptation that zero-shot pre-training alone cannot achieve.
298
-
299
- \subsection{Per-Class Topic Analysis}
300
-
301
- Table \ref{tab:topic_breakdown} reveals per-class patterns in topic classification across the 7 classes.
302
-
303
- \begin{table}[htbp]
304
- \centering
305
- \caption{Per-Class Topic Classification (Improved MTL)}
306
- \label{tab:topic_breakdown}
307
- \begin{tabular}{lccc}
308
- \toprule
309
- \textbf{Topic} & \textbf{Precision} & \textbf{Recall} & \textbf{F1} \\
310
- \midrule
311
- Arts & 0.93 & 0.79 & 0.86 \\
312
- Business & 0.97 & 1.00 & 0.98 \\
313
- Fiction & 0.95 & 1.00 & 0.97 \\
314
- History & 0.85 & 0.76 & 0.80 \\
315
- Philosophy & 0.79 & 0.82 & 0.81 \\
316
- Science & 0.57 & 0.80 & 0.67 \\
317
- Technology & 0.89 & 0.89 & 0.89 \\
318
- \midrule
319
- \textit{Macro Avg} & 0.85 & 0.87 & 0.85 \\
320
- \bottomrule
321
- \end{tabular}
322
- \end{table}
323
-
324
- Fiction and Business achieve near-perfect classification (F1 $\geq$ 0.97), while Science shows the most confusion (F1 = 0.67). Error analysis reveals Science samples are frequently misclassified as Technology---semantically plausible given that scientific research papers often describe technical methods. The Arts class shows lower recall (0.79), suggesting some arts-related texts are misclassified into adjacent categories.
325
-
326
- \subsection{Per-Domain Summarization Analysis}
327
-
328
- Table \ref{tab:domain_breakdown} reveals a substantial quality gap between academic and literary summarization, reflecting the 11:1 training imbalance.
329
-
330
- \begin{table}[htbp]
331
- \centering
332
- \caption{Per-Domain Summarization Performance (Improved MTL)}
333
- \label{tab:domain_breakdown}
334
- \begin{tabular}{lcccc}
335
- \toprule
336
- \textbf{Domain} & \textbf{N} & \textbf{ROUGE-1} & \textbf{ROUGE-L} & \textbf{BLEU-4} \\
337
- \midrule
338
- Academic & 2,493 & 0.319 & 0.189 & 0.026 \\
339
- Literary & 234 & 0.206 & 0.137 & 0.008 \\
340
- \midrule
341
- \textit{Overall} & 2,727 & 0.310 & 0.185 & 0.024 \\
342
- \bottomrule
343
- \end{tabular}
344
- \end{table}
345
-
346
- Academic summaries (ROUGE-1: 0.319) outperform literary summaries (ROUGE-1: 0.206) by +0.113, a large gap attributable to two factors: (1) the encoder is disproportionately trained on academic text ($\sim$45K academic vs. $\sim$4K literary), and (2) academic abstracts follow more predictable structural conventions (background-method-result) that are easier for the model to reproduce. Literary descriptions---which describe \textit{what a book is about} in narrative prose---require more creative generation.
347
-
348
- \subsection{Analysis: Emotion Detection Improvements}
349
- \label{sec:emotion_analysis}
350
-
351
- Our improved multi-task emotion sample-averaged F1 (0.352) represents a dramatic improvement over the baseline MTL configuration (0.199). With per-class threshold tuning, macro F1 reaches 0.294---approaching published GoEmotions baselines (0.46 macro F1 with BERT-base \cite{demszky2020goemotions}). We analyze the contributing factors:
352
-
353
- \begin{enumerate}
354
- \item \textbf{Attention pooling is critical.} Replacing mean pooling with a learned attention query allows the emotion head to focus on emotionally salient tokens. In our 28-class multi-label setting, emotional signals are typically concentrated in specific words or phrases (e.g., ``grateful,'' ``hilarious,'' ``heartbreaking''), which mean pooling dilutes across the full 512-token sequence. The top-performing classes---gratitude (F1: 0.888), amusement (0.751), love (0.740), admiration (0.653)---correspond to emotions with distinctive lexical markers that attention pooling can localize.
355
-
356
- \item \textbf{Temperature sampling improves optimization.} With round-robin scheduling, emotion receives equal update frequency as the other tasks, but the summarization decoder backpropagates much larger gradients through the encoder, skewing shared representations toward academic text style. Temperature sampling ($\alpha=0.5$) allocates $\sim$43\% of steps to emotion---proportional to its dataset size---ensuring the encoder maintains emotion-relevant features.
357
-
358
- \item \textbf{Remaining class-level gaps.} Despite overall improvement, 15 of 28 emotion classes still have zero F1 at the default 0.5 threshold (including approval, annoyance, disapproval, anger). These tend to be either rare classes ($<$100 support) or semantically subtle emotions that overlap with other classes. Per-class threshold tuning recovers non-zero performance for most of these classes, increasing macro F1 from 0.143 to 0.294.
359
-
360
- \item \textbf{Domain gap persists.} Despite improvements, the remaining gap vs. published GoEmotions baselines (0.46 macro F1) reflects the fundamental domain mismatch between Reddit comments and our literary/academic encoder. Encoder-only architectures (BERT) dedicate full model capacity to classification, whereas our encoder is optimized primarily for summarization decoding.
361
- \end{enumerate}
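The temperature-sampling rule from point 2 can be sketched directly: sampling probabilities are proportional to $n_i^{\alpha}$. The dataset sizes below are approximations inferred from the paper (49K summarization, ~43K GoEmotions, 3.4K topic), not exact counts.

```python
import numpy as np

def task_sampling_probs(sizes, alpha=0.5):
    # p_i ∝ n_i**alpha; alpha=1 recovers proportional sampling, alpha=0 is uniform
    p = np.asarray(sizes, dtype=float) ** alpha
    return p / p.sum()

# Approximate dataset sizes: (summarization, emotion, topic)
sizes = [49000, 43000, 3400]
probs = task_sampling_probs(sizes, alpha=0.5)
# With alpha=0.5, emotion receives roughly 43% of steps, matching the paper's figure,
# while the small topic dataset is boosted well above its ~3.6% proportional share.
```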
362
-
363
- \textbf{Per-class threshold tuning results.} Sweeping $\tau \in \{0.1, \ldots, 0.9\}$ per class on the validation set yields tuned sample-averaged F1 of 0.503, tuned macro F1 of 0.294, and tuned micro F1 of 0.486. The optimal thresholds vary widely: gratitude saturates at $\tau=0.65$ (high confidence predictions), while rare classes require $\tau \leq 0.2$ to achieve non-zero recall.
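A minimal version of the per-class threshold sweep reads as follows; the grid, the synthetic validation data, and the tie-breaking behavior are illustrative assumptions rather than the paper's exact procedure.

```python
import numpy as np

def tune_threshold(probs, labels, grid=np.arange(0.1, 0.91, 0.05)):
    # Pick the decision threshold maximizing F1 for one class on validation data
    best_tau, best_f1 = 0.5, -1.0
    for tau in grid:
        pred = probs >= tau
        tp = np.sum(pred & (labels == 1))
        fp = np.sum(pred & (labels == 0))
        fn = np.sum(~pred & (labels == 1))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_tau, best_f1 = tau, f1
    return best_tau, best_f1

# Synthetic rare class with weak, low-confidence scores (hypothetical data)
rng = np.random.default_rng(1)
labels = (rng.random(500) < 0.05).astype(int)
probs = np.clip(0.15 * labels + 0.3 * rng.random(500), 0.0, 1.0)
tau, f1 = tune_threshold(probs, labels)
```

Repeating this sweep independently per class yields the widely varying optima reported above (e.g., high $\tau$ for gratitude, $\tau \leq 0.2$ for rare classes).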
364
-
365
- \textbf{Implication}: Architectural isolation of classification heads (attention pooling) combined with balanced optimization (temperature sampling) can overcome domain mismatch in MTL, converting negative transfer to substantial positive transfer.
366
-
367
- \subsection{Training Dynamics}
368
-
369
- Figure \ref{fig:training_curves} shows training progression over 8 epochs (approximately 9 hours on RTX 4070 with temperature sampling).
370
-
371
- \begin{figure}[htbp]
372
- \centering
373
- \includegraphics[width=\columnwidth]{figures/training_loss_curve.png}
374
- \caption{Training and validation loss with temperature sampling and attention pooling. Combined validation loss decreases from 4.298 to 3.925 over 8 epochs; best checkpoint at epoch 8.}
375
- \label{fig:training_curves}
376
- \end{figure}
377
-
378
- Key observations:
379
- \begin{itemize}
380
- \item Topic classification converges rapidly: 91\% training accuracy by epoch 3 (84\% validation), reaching 98\% by epoch 8. Validation accuracy plateaus near 86\% from epoch 2 onward, while training accuracy continues climbing---a sign of mild overfitting on the small (3.4K) topic dataset. The reduced task weight (0.3) limits gradient dominance.
381
- \item Summarization training loss decreases steadily (4.057 $\rightarrow$ 3.699), with validation loss flattening after epoch 5 (3.665 $\rightarrow$ 3.653). Training ROUGE-1 improves from 0.287 to 0.308.
382
- \item Emotion F1 improves steadily throughout training: validation F1 rises from 0.197 (epoch 1) to 0.459 (epoch 8), indicating the attention pooling mechanism continues refining its weights over the full training duration.
383
- \item Combined validation loss decreases from 4.298 (epoch 1) to 3.925 (epoch 8), though the decrease is marginal after epoch 5. Early stopping (patience=3) did not trigger because the combined loss continued improving slightly each epoch. Additional epochs could yield further modest gains, though the near-plateau after epoch 5 suggests diminishing returns.
384
- \end{itemize}
385
-
386
- %=============================================================================
387
- \section{Discussion}
388
- %=============================================================================
389
-
390
- \subsection{When Does MTL Help?}
391
-
392
- Our results demonstrate that MTL effectiveness depends on both task relatedness \textit{and} architectural/optimization choices:
393
-
394
- \textbf{MTL helps when}: (1) A small-dataset task (topic: 3.4K samples) shares domain with a large-dataset task (summarization: 49K literary/academic samples)---the topic classifier benefits from shared encoder representations tuned to literary and academic vocabulary. (2) Task-specific heads are architecturally isolated from shared representations---attention pooling for emotion allows task-specific feature extraction without interfering with the shared encoder.
395
-
396
- \textbf{MTL requires intervention when}: An auxiliary task's domain is misaligned with the primary training signal. With naive mean pooling, emotion detection suffered negative transfer because the encoder's representations were skewed toward summarization. Attention pooling and temperature sampling together overcame this: attention pooling provides architectural isolation, while temperature sampling ensures balanced optimization.
397
-
398
- \textbf{MTL is neutral for}: The primary task (summarization) with sufficient data and a dedicated component (decoder, $\sim$136M parameters) that insulates it from interference. Classification heads are small and their gradients have limited impact relative to the decoder's backpropagation signal.
399
-
400
- \subsection{Comparison to MTL Literature}
401
-
402
- Our findings align qualitatively with several key results in the MTL literature. Standley et al. \cite{standley2020tasks} showed that task groupings critically affect MTL outcomes---our baseline results (positive transfer for topic, negative for emotion) confirmed this, but our improved configuration shows that \textit{architectural interventions can change these grouping dynamics}. Yu et al. \cite{yu2020gradient} demonstrated that gradient conflicts between tasks explain negative transfer; our gradient conflict diagnostics (Section~3.4) enable empirical measurement of inter-task gradient cosine similarity, and our temperature sampling partially addresses gradient imbalance by controlling task exposure frequency. Aribandi et al. \cite{aribandi2022ext5} found diminishing or negative returns from adding more tasks; our results suggest that per-task architectural isolation (attention pooling) can mitigate this.
403
-
404
- A key difference from the broader MTL literature is our use of an encoder-decoder architecture with mixed generative and discriminative tasks. Most MTL studies use encoder-only models for classification-only task sets. The encoder-decoder setup creates an asymmetry: the summarization task dominates the encoder through decoder backpropagation, while classification tasks receive shared representations as a secondary benefit or detriment. Our results show that task-specific pooling strategies can partially compensate for this asymmetry. Recent neuron-centric analysis \cite{neuroncentric2024} suggests that individual neurons specialize for different tasks, which could inform more targeted isolation strategies.
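The gradient-conflict diagnostic referenced here reduces to a cosine similarity between flattened per-task gradients of the shared encoder. A minimal sketch, with toy gradient vectors standing in for real encoder gradients:

```python
import numpy as np

def grad_cosine(g1, g2):
    # Cosine similarity between two tasks' flattened shared-parameter gradients;
    # values near -1 indicate directly conflicting update directions
    g1, g2 = np.ravel(g1), np.ravel(g2)
    return float(g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2) + 1e-12))

# Toy example: exactly opposed gradients yield cosine ~ -1
g_summ = np.array([0.5, -0.2, 0.1])
g_emo = np.array([-0.5, 0.2, -0.1])
cos = grad_cosine(g_summ, g_emo)
```

Logging this quantity periodically during training costs one extra backward pass per monitored task pair and reveals whether interventions such as PCGrad are warranted.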
405
-
406
- \subsection{Implications for Practitioners}
407
-
408
- Based on our findings:
409
-
410
- \begin{enumerate}
411
- \item \textbf{Audit domain alignment} before combining tasks in MTL. If auxiliary tasks draw from different text domains (e.g., social media vs. academic), negative transfer is likely unless mitigated by gradient-conflict methods or per-task adapters.
412
-
413
- \item \textbf{Task weighting matters} for preventing small-dataset overfitting. Our reduced weight (0.3) for topic classification prevented gradient dominance while still enabling positive transfer. Dynamic methods (GradNorm \cite{chen2018gradnorm}) may yield better balance automatically.
414
-
415
- \item \textbf{Architectural isolation protects high-priority tasks}. Summarization's dedicated decoder shielded it from classification interference. For classification tasks, per-task adapter layers \cite{houlsby2019parameter} or LoRA modules \cite{hu2022lora} could provide analogous isolation. Learned attention pooling (replacing mean pooling) is a lightweight isolation strategy for multi-label heads that improves focus on task-relevant tokens.
416
-
417
- \item \textbf{Monitor gradient conflicts} before deploying MTL. Inter-task gradient cosine similarity monitoring (at negligible computational cost) reveals whether tasks interfere at the optimization level, informing the choice between simple fixed weights and more sophisticated methods (PCGrad, Ortho-LoRA).
418
-
419
- \item \textbf{Use temperature-based sampling} when dataset sizes vary widely. Square-root temperature ($\alpha=0.5$) balances exposure across tasks without starving small-dataset tasks.
420
-
421
- \item \textbf{Validate with multiple seeds} before drawing conclusions from MTL comparisons, especially with small validation sets. Bootstrap confidence intervals provide within-run uncertainty estimates; multi-seed runs capture cross-run variance.
422
- \end{enumerate}
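The bootstrap confidence intervals recommended in point 6 can be computed with a simple percentile bootstrap over per-sample metric scores; the resample count and the synthetic score distribution below are illustrative choices.

```python
import numpy as np

def bootstrap_ci(scores, n_boot=2000, level=0.95, seed=0):
    # Percentile bootstrap over per-example scores (e.g., per-example ROUGE-L)
    rng = np.random.default_rng(seed)
    n = len(scores)
    means = np.array([np.mean(rng.choice(scores, size=n, replace=True))
                      for _ in range(n_boot)])
    lo, hi = np.percentile(means, [(1 - level) / 2 * 100, (1 + level) / 2 * 100])
    return float(np.mean(scores)), float(lo), float(hi)

# Hypothetical per-example scores centered near the paper's ROUGE-1 of ~0.31
scores = np.random.default_rng(2).normal(0.31, 0.1, size=500)
mean, lo, hi = bootstrap_ci(scores)
```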
423
-
424
- \subsection{Limitations}
425
- \label{sec:limitations}
426
-
427
- We identify several limitations that constrain the generalizability of our findings:
428
-
429
- \begin{itemize}
430
- \item \textbf{Single-seed results}: Reported results are from single training runs. The +3.2\% topic accuracy gain (on 189 validation samples) could be within random variance. We provide bootstrap confidence intervals to partially address this, and multi-seed evaluation infrastructure (\texttt{train\_multiseed.py}) to enable variance estimation across seeds. Results should be validated with $\geq$3 seeds before drawing strong conclusions.
431
-
432
- \item \textbf{Gradient-conflict diagnostics but no mitigation}: We monitor inter-task gradient cosine similarity to characterize conflicts, but do not apply corrective methods such as PCGrad \cite{yu2020gradient}, CAGrad \cite{liu2021conflict}, GradNorm \cite{chen2018gradnorm}, or uncertainty weighting \cite{kendall2018multi}. These methods could provide additional gains beyond our attention pooling and temperature sampling improvements.
433
-
434
- \item \textbf{No encoder-only baseline}: We do not compare against BERT or RoBERTa fine-tuned on GoEmotions or topic classification. Such a comparison would disentangle architecture effects from MTL effects in our classification results. The remaining gap between our tuned macro F1 (0.294) and published GoEmotions baselines (0.46) likely reflects this architectural difference.
435
-
436
- \item \textbf{Cross-task data leakage}: Although topic and summarization datasets draw from overlapping sources (arXiv, Project Gutenberg), we implement cross-task deduplication via MD5 fingerprinting to prevent data leakage. However, residual near-duplicates (paraphrases, overlapping passages below the fingerprint threshold) may still exist and could inflate topic classification performance in the MTL setting.
437
-
438
- \item \textbf{Dataset construction noise}: Topic labels are derived from source metadata (arXiv categories, Gutenberg subjects) via automatic mapping to our 7-class taxonomy. No manual annotation or quality verification was performed. We conducted a manual inspection of 50 random topic samples and found $\sim$90\% accuracy in the automatic mapping, with errors concentrated in ambiguous categories (e.g., ``History of Science'' mapped to History rather than Science). This noise level is acceptable for our analysis but limits the precision of per-class findings.
439
-
440
- \item \textbf{No human evaluation}: ROUGE scores are imperfect proxies for summary quality, especially for creative/literary text where stylistic quality matters beyond semantic accuracy.
441
-
442
- \item \textbf{Single model scale}: We study only FLAN-T5-base (272M parameters). Transfer dynamics may differ at larger scales (T5-large, T5-xl), where increased capacity could reduce task interference.
443
-
444
- \item \textbf{Summarization domain imbalance}: The $\sim$11:1 ratio of academic to literary samples within the summarization task means the encoder is disproportionately shaped by academic text; the per-domain results in Table \ref{tab:domain_breakdown} quantify the resulting quality gap.
445
- \end{itemize}
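The MD5-fingerprint deduplication mentioned in the cross-task leakage limitation can be sketched as follows. The whitespace/case normalization shown is an assumption; as the limitation notes, any such exact-hash scheme misses paraphrases and partial overlaps.

```python
import hashlib

def fingerprint(text):
    # Normalize whitespace and case before hashing so trivial variants collide
    norm = " ".join(text.lower().split())
    return hashlib.md5(norm.encode("utf-8")).hexdigest()

def dedup_across_tasks(task_a, task_b):
    # Drop task_b examples whose fingerprint already appears in task_a
    seen = {fingerprint(t) for t in task_a}
    return [t for t in task_b if fingerprint(t) not in seen]

# Hypothetical texts: the first topic sample duplicates a summarization source
summ = ["The cat sat on the mat.", "Quantum computing basics."]
topic = ["the cat  sat on the mat.", "A novel about whales."]
kept = dedup_across_tasks(summ, topic)
```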
446
-
447
- \subsection{Future Work}
448
-
449
- \begin{itemize}
450
- \item \textbf{Gradient-conflict mitigation}: Our gradient conflict diagnostics provide the empirical foundation; the natural next step is applying Ortho-LoRA \cite{ortholora2025} for orthogonal gradient projection, PCGrad \cite{yu2020gradient} for gradient surgery, or CAGrad \cite{liu2021conflict} for conflict-averse optimization. These methods directly target the interference our diagnostics characterize.
451
-
452
- \item \textbf{Parameter-efficient multi-tasking}: PiKE \cite{pike2025} for selective knowledge exchange between tasks, per-task LoRA adapters \cite{hu2022lora}, ScaLearn \cite{scallearn2023} shared attention with task-specific scaling, or adapter layers \cite{houlsby2019parameter} to provide task-specific specialization while maintaining shared encoder representations. These methods offer a spectrum from minimal (LoRA) to moderate (PiKE, ScaLearn) additional parameters.
453
-
454
- \item \textbf{Principled task grouping}: Applying transfer-gain estimation methods \cite{taskgrouping2024} to determine whether emotion should be trained jointly with summarization and topic, or in a separate group. Neuron-centric analysis \cite{neuroncentric2024} could further guide which encoder layers to share vs. specialize.
455
-
456
- \item \textbf{Encoder-only comparison}: Fine-tuning BERT/RoBERTa on topic and emotion classification, with and without multi-task training, to disentangle encoder-decoder architecture effects from MTL effects.
457
-
458
- \item \textbf{Multi-seed evaluation with confidence intervals}: Our \texttt{train\_multiseed.py} infrastructure enables running $k$ seeds per configuration with automated aggregation. Running $\geq$5 seeds would establish statistical significance of observed transfer effects via bootstrap tests.
459
-
460
- \item \textbf{Domain-specific emotion annotation}: Collecting emotion annotations on literary and academic text to study whether in-domain emotion data eliminates the negative transfer.
461
-
462
- \item \textbf{Temperature sampling ablation}: Comparing round-robin vs. temperature-based sampling ($\alpha \in \{0.3, 0.5, 0.7, 1.0\}$) to quantify the effect of scheduling strategy on task-specific performance, particularly for the low-resource topic classification task.
463
- \end{itemize}
464
-
465
- %=============================================================================
466
- \section{Conclusion}
467
- %=============================================================================
468
-
469
- We investigated multi-task learning for literary and academic text understanding, combining abstractive summarization, topic classification, and multi-label emotion detection in an encoder-decoder architecture. Our key finding is that naive MTL with mean pooling produces heterogeneous transfer effects---positive for topic (+3.7\%), negative for emotion ($-$0.02 F1)---but that targeted interventions can eliminate negative transfer entirely. Learned attention pooling for the emotion head, combined with temperature-based task sampling ($\alpha=0.5$), improves multi-task emotion F1 from 0.199 to 0.352 (+77\%), surpassing the single-task baseline. With per-class threshold tuning, macro F1 reaches 0.294. Summarization quality remains robust across configurations (ROUGE-1: 0.310, ROUGE-L: 0.185), with per-domain analysis revealing a quality gap between academic (ROUGE-1: 0.319) and literary (ROUGE-1: 0.206) summaries driven by training data imbalance.
470
-
471
- These results demonstrate that negative transfer in MTL is not an inherent limitation but can be addressed through architectural isolation (task-specific pooling) and balanced optimization (temperature sampling). Pre-trained initialization (FLAN-T5) remains essential for competitive performance across all tasks. Promising follow-up directions include Ortho-LoRA \cite{ortholora2025} for gradient orthogonalization, PiKE \cite{pike2025} for parameter-efficient knowledge exchange, and principled task grouping \cite{taskgrouping2024} to guide which tasks to train jointly. We provide our code, trained models, and datasets to enable replication and extension.
472
-
473
- Code and models: \url{https://github.com/OliverPerrin/LexiMind}\\
474
- Live demo: \url{https://huggingface.co/spaces/OliverPerrin/LexiMind}
475
-
476
- %=============================================================================
477
- % References
478
- %=============================================================================
479
-
480
- \begin{thebibliography}{00}
481
-
482
- \bibitem{caruana1997multitask}
483
- R. Caruana, ``Multitask learning,'' \textit{Machine Learning}, vol. 28, no. 1, pp. 41--75, 1997.
484
-
485
- \bibitem{collobert2011natural}
486
- R. Collobert, J. Weston, L. Bottou, M. Karlen, K. Kavukcuoglu, and P. Kuksa, ``Natural language processing (almost) from scratch,'' \textit{JMLR}, vol. 12, pp. 2493--2537, 2011.
487
-
488
- \bibitem{johnson2017google}
489
- M. Johnson et al., ``Google's multilingual neural machine translation system: Enabling zero-shot translation,'' \textit{TACL}, vol. 5, pp. 339--351, 2017.
490
-
491
- \bibitem{mccann2018natural}
492
- B. McCann, N. S. Keskar, C. Xiong, and R. Socher, ``The natural language decathlon: Multitask learning as question answering,'' \textit{arXiv:1806.08730}, 2018.
493
-
494
- \bibitem{standley2020tasks}
495
- T. Standley, A. Zamir, D. Chen, L. Guibas, J. Malik, and S. Savarese, ``Which tasks should be learned together in multi-task learning?'' in \textit{ICML}, 2020.
496
-
497
- \bibitem{yu2020gradient}
498
- T. Yu, S. Kumar, A. Gupta, S. Levine, K. Hausman, and C. Finn, ``Gradient surgery for multi-task learning,'' in \textit{NeurIPS}, 2020.
499
-
500
- \bibitem{liu2021conflict}
501
- B. Liu, X. Liu, X. Jin, P. Stone, and Q. Liu, ``Conflict-averse gradient descent for multi-task learning,'' in \textit{NeurIPS}, 2021.
502
-
503
- \bibitem{chen2018gradnorm}
504
- Z. Chen, V. Badrinarayanan, C.-Y. Lee, and A. Rabinovich, ``GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks,'' in \textit{ICML}, 2018.
505
-
506
- \bibitem{kendall2018multi}
507
- A. Kendall, Y. Gal, and R. Cipolla, ``Multi-task learning using uncertainty to weigh losses for scene geometry and semantics,'' in \textit{CVPR}, 2018.
508
-
509
- \bibitem{aghajanyan2021muppet}
510
- A. Aghajanyan, A. Gupta, A. Shrivastava, X. Chen, L. Zettlemoyer, and S. Gupta, ``Muppet: Massive multi-task representations with pre-finetuning,'' in \textit{EMNLP}, 2021.
511
-
512
- \bibitem{aribandi2022ext5}
513
- V. Aribandi et al., ``ExT5: Towards extreme multi-task scaling for transfer learning,'' in \textit{ICLR}, 2022.
514
-
515
- \bibitem{raffel2020exploring}
516
- C. Raffel et al., ``Exploring the limits of transfer learning with a unified text-to-text transformer,'' \textit{JMLR}, vol. 21, no. 140, pp. 1--67, 2020.
517
-
518
- \bibitem{chung2022scaling}
519
- H. W. Chung et al., ``Scaling instruction-finetuned language models,'' \textit{arXiv:2210.11416}, 2022.
520
-
521
- \bibitem{nallapati2016abstractive}
522
- R. Nallapati, B. Zhou, C. dos Santos, C. Gulcehre, and B. Xiang, ``Abstractive text summarization using sequence-to-sequence RNNs and beyond,'' in \textit{CoNLL}, 2016.
523
-
524
- \bibitem{narayan2018don}
525
- S. Narayan, S. B. Cohen, and M. Lapata, ``Don't give me the details, just the summary! Topic-aware convolutional neural networks for extreme summarization,'' in \textit{EMNLP}, 2018.
526
-
527
- \bibitem{kryscinski2021booksum}
528
- W. Kryscinski, N. Rajani, D. Agarwal, C. Xiong, and D. Radev, ``BookSum: A collection of datasets for long-form narrative summarization,'' in \textit{Findings of EMNLP}, 2021.
529
-
530
- \bibitem{cohan2018discourse}
531
- A. Cohan et al., ``A discourse-aware attention model for abstractive summarization of long documents,'' in \textit{NAACL-HLT}, 2018.
532
-
533
- \bibitem{mao2022citesum}
534
- Y. Mao, M. Zhong, and J. Han, ``CiteSum: Citation text-guided scientific extreme summarization and domain adaptation with limited supervision,'' in \textit{EMNLP}, 2022.
535
-
536
- \bibitem{demszky2020goemotions}
537
- D. Demszky et al., ``GoEmotions: A dataset of fine-grained emotions,'' in \textit{ACL}, 2020.
538
-
539
- \bibitem{zhang2019root}
540
- B. Zhang and R. Sennrich, ``Root mean square layer normalization,'' in \textit{NeurIPS}, 2019.
541
-
542
- \bibitem{lin2004rouge}
543
- C.-Y. Lin, ``ROUGE: A package for automatic evaluation of summaries,'' in \textit{Text Summarization Branches Out}, 2004.
544
-
545
- \bibitem{zhang2019bertscore}
546
- T. Zhang, V. Kishore, F. Wu, K. Q. Weinberger, and Y. Artzi, ``BERTScore: Evaluating text generation with BERT,'' in \textit{ICLR}, 2020.
547
-
548
- \bibitem{hu2022lora}
549
- E. J. Hu et al., ``LoRA: Low-rank adaptation of large language models,'' in \textit{ICLR}, 2022.
550
-
551
- \bibitem{houlsby2019parameter}
552
- N. Houlsby et al., ``Parameter-efficient transfer learning for NLP,'' in \textit{ICML}, 2019.
553
-
554
- \bibitem{lin2017focal}
555
- T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Doll\'{a}r, ``Focal loss for dense object detection,'' in \textit{ICCV}, 2017.
556
-
557
- \bibitem{ortholora2025}
558
- B. Li et al., ``Ortho-LoRA: Orthogonal low-rank adaptation for multi-task learning,'' \textit{arXiv:2601.09684}, 2025.
559
-
560
- \bibitem{pike2025}
561
- Y. Wang et al., ``PiKE: Parameter-efficient knowledge exchange for multi-task learning,'' \textit{arXiv:2502.06244}, 2025.
562
-
563
- \bibitem{scallearn2023}
564
- H. Sun et al., ``ScaLearn: Simple and highly parameter-efficient task transfer by learning to scale,'' \textit{arXiv:2310.01217}, 2023.
565
-
566
- \bibitem{taskgrouping2024}
567
- S. Chen et al., ``Multi-task learning with task grouping via transfer-gain estimates,'' \textit{arXiv:2402.15328}, 2024.
568
-
569
- \bibitem{neuroncentric2024}
570
- A. Foroutan et al., ``What do neurons in multi-task language models encode? A neuron-centric analysis,'' \textit{arXiv:2407.06488}, 2024.
571
-
572
- \end{thebibliography}
573
-
574
- \end{document}
 
kaggle.json DELETED
@@ -1 +0,0 @@
1
- {"username":"oliverperrin","key":"af2c73b5725d2839410a1f72cf84cf48"}
 
 
outputs/evaluation_report.json DELETED
@@ -1,260 +0,0 @@
1
- {
2
- "summarization": {
3
- "rouge1": 0.3094793058747055,
4
- "rouge2": 0.09069756722817666,
5
- "rougeL": 0.1847154828322755,
6
- "bleu4": 0.023982657019404153,
7
- "num_samples": 2727,
8
- "per_domain": {
9
- "academic": {
10
- "num_samples": 2493,
11
- "rouge1": 0.31919183681728475,
12
- "rouge2": 0.0968589097730544,
13
- "rougeL": 0.18921129182459423,
14
- "bleu4": 0.02551610700902003
15
- },
16
- "literary": {
17
- "num_samples": 234,
18
- "rouge1": 0.2060034954479976,
19
- "rouge2": 0.0250555716539014,
20
- "rougeL": 0.1368178254910352,
21
- "bleu4": 0.0076455167454197795
22
- }
23
- },
24
- "rouge1_ci": {
25
- "mean": 0.30947930587470557,
26
- "lower": 0.3060921045548166,
27
- "upper": 0.3131015955325767
28
- },
29
- "rougeL_ci": {
30
- "mean": 0.18471548283227546,
31
- "lower": 0.18251665662669495,
32
- "upper": 0.18701919830414013
33
- }
34
- },
35
- "emotion": {
36
- "sample_avg_f1": 0.3522975742816925,
37
- "macro_f1": 0.14317210018634796,
38
- "micro_f1": 0.4430159032344818,
39
- "num_samples": 5426,
40
- "num_classes": 28,
41
- "per_class": {
42
- "admiration": {
43
- "precision": 0.714634120464325,
44
- "recall": 0.6004098653793335,
45
- "f1": 0.6525612537411732,
46
- "support": 488
47
- },
48
- "amusement": {
49
- "precision": 0.7708333134651184,
50
- "recall": 0.7326732873916626,
51
- "f1": 0.7512690366449468,
52
- "support": 303
53
- },
54
- "anger": {
55
- "precision": 0.0,
56
- "recall": 0.0,
57
- "f1": 0.0,
58
- "support": 195
59
- },
60
- "annoyance": {
61
- "precision": 0.0,
62
- "recall": 0.0,
63
- "f1": 0.0,
64
- "support": 303
65
- },
66
- "approval": {
67
- "precision": 0.0,
68
- "recall": 0.0,
69
- "f1": 0.0,
70
- "support": 397
71
- },
72
- "caring": {
73
- "precision": 0.0,
74
- "recall": 0.0,
75
- "f1": 0.0,
76
- "support": 153
77
- },
78
- "confusion": {
79
- "precision": 0.0,
80
- "recall": 0.0,
81
- "f1": 0.0,
82
- "support": 152
83
- },
84
- "curiosity": {
85
- "precision": 0.6166666746139526,
86
- "recall": 0.14919355511665344,
87
- "f1": 0.24025974958898805,
88
- "support": 248
89
- },
90
- "desire": {
91
- "precision": 0.0,
92
- "recall": 0.0,
93
- "f1": 0.0,
94
- "support": 77
95
- },
96
- "disappointment": {
97
- "precision": 0.0,
98
- "recall": 0.0,
99
- "f1": 0.0,
100
- "support": 163
101
- },
102
- "disapproval": {
103
- "precision": 0.0,
104
- "recall": 0.0,
105
- "f1": 0.0,
106
- "support": 292
107
- },
108
- "disgust": {
109
- "precision": 0.0,
110
- "recall": 0.0,
111
- "f1": 0.0,
112
- "support": 97
113
- },
114
- "embarrassment": {
115
- "precision": 0.0,
116
- "recall": 0.0,
117
- "f1": 0.0,
118
- "support": 35
119
- },
120
- "excitement": {
121
- "precision": 0.0,
122
- "recall": 0.0,
123
- "f1": 0.0,
124
- "support": 96
125
- },
126
- "fear": {
127
- "precision": 0.0,
128
- "recall": 0.0,
129
- "f1": 0.0,
130
- "support": 90
131
- },
132
- "gratitude": {
133
- "precision": 0.8997134566307068,
134
- "recall": 0.8770949840545654,
135
- "f1": 0.8882602556669954,
136
- "support": 358
137
- },
138
- "grief": {
139
- "precision": 0.0,
140
- "recall": 0.0,
141
- "f1": 0.0,
142
- "support": 13
143
- },
144
- "joy": {
145
- "precision": 0.0,
146
- "recall": 0.0,
147
- "f1": 0.0,
148
- "support": 172
149
- },
150
- "love": {
151
- "precision": 0.6996466517448425,
152
- "recall": 0.7857142686843872,
153
- "f1": 0.740186913163602,
154
- "support": 252
155
- },
156
- "nervousness": {
157
- "precision": 0.0,
158
- "recall": 0.0,
159
- "f1": 0.0,
160
- "support": 21
161
- },
162
- "neutral": {
163
- "precision": 0.6869627237319946,
164
- "recall": 0.543035089969635,
165
- "f1": 0.6065780936064032,
166
- "support": 1766
167
- },
168
- "optimism": {
169
- "precision": 0.7142857313156128,
170
- "recall": 0.023923445492982864,
171
- "f1": 0.04629629729995926,
172
- "support": 209
173
- },
174
- "pride": {
175
- "precision": 0.0,
176
- "recall": 0.0,
177
- "f1": 0.0,
178
- "support": 15
179
- },
180
- "realization": {
181
- "precision": 0.0,
182
- "recall": 0.0,
183
- "f1": 0.0,
184
- "support": 127
185
- },
186
- "relief": {
187
- "precision": 0.0,
188
- "recall": 0.0,
189
- "f1": 0.0,
190
- "support": 18
191
- },
192
- "remorse": {
193
- "precision": 1.0,
194
- "recall": 0.014705882407724857,
195
- "f1": 0.02898550735279132,
196
- "support": 68
197
- },
198
- "sadness": {
199
- "precision": 1.0,
200
- "recall": 0.0279720276594162,
201
- "f1": 0.054421768115822285,
202
- "support": 143
203
- },
204
- "surprise": {
205
- "precision": 0.0,
206
- "recall": 0.0,
207
- "f1": 0.0,
208
- "support": 129
209
- }
210
- },
211
- "tuned_thresholds": {
212
- "admiration": 0.4,
213
- "amusement": 0.55,
214
- "anger": 0.2,
215
- "annoyance": 0.15,
216
- "approval": 0.15,
217
- "caring": 0.1,
218
- "confusion": 0.1,
219
- "curiosity": 0.25,
220
- "desire": 0.15,
221
- "disappointment": 0.1,
222
- "disapproval": 0.1,
223
- "disgust": 0.1,
224
- "embarrassment": 0.1,
225
- "excitement": 0.1,
226
- "fear": 0.1,
227
- "gratitude": 0.65,
228
- "grief": 0.1,
229
- "joy": 0.2,
230
- "love": 0.45,
231
- "nervousness": 0.1,
232
- "neutral": 0.3,
233
- "optimism": 0.25,
234
- "pride": 0.1,
235
- "realization": 0.2,
236
- "relief": 0.1,
237
- "remorse": 0.2,
238
- "sadness": 0.25,
239
- "surprise": 0.1
240
- },
241
- "tuned_macro_f1": 0.29355332255363464,
242
- "tuned_sample_avg_f1": 0.5025880336761475,
243
- "tuned_micro_f1": 0.48644566535949707,
244
- "sample_f1_ci": {
245
- "mean": 0.3522975795552279,
246
- "lower": 0.33984518982676004,
247
- "upper": 0.3658618994962526
248
- }
249
- },
250
- "topic": {
251
- "accuracy": 0.8571428571428571,
252
- "macro_f1": 0.8538751111963805,
253
- "num_samples": 189,
254
- "accuracy_ci": {
255
- "mean": 0.8571428571428571,
256
- "lower": 0.8042328042328042,
257
- "upper": 0.91005291005291
258
- }
259
- }
260
- }
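The `rouge1_ci`, `sample_f1_ci`, and `accuracy_ci` fields in the deleted metrics above are percentile bootstrap intervals over per-sample scores. As a rough sketch of how such an interval is computed (this is an illustration, not the project's actual `bootstrap_confidence_interval` implementation; the resample count, seed, and percentile choices here are assumptions):

```python
import random

def bootstrap_ci(scores, n_resamples=1000, alpha=0.05, seed=17):
    """Percentile bootstrap CI for the mean of per-sample scores (illustrative)."""
    rng = random.Random(seed)
    n = len(scores)
    means = []
    for _ in range(n_resamples):
        # Resample n scores with replacement and record the resample mean
        sample = [scores[rng.randrange(n)] for _ in range(n)]
        means.append(sum(sample) / n)
    means.sort()
    lower = means[int((alpha / 2) * n_resamples)]
    upper = means[int((1 - alpha / 2) * n_resamples) - 1]
    return {"mean": sum(scores) / n, "lower": lower, "upper": upper}
```

With per-sample ROUGE-1 scores as input, this yields a `{"mean", "lower", "upper"}` dict of the same shape as the fields stored above.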
 
outputs/training_history.json DELETED
@@ -1,210 +0,0 @@
- {
-   "train_epoch_1": {
-     "summarization_loss": 4.05727732604402,
-     "summarization_rouge_like": 0.20427788603502178,
-     "summarization_rouge1": 0.2867239527374218,
-     "summarization_rouge2": 0.08530419039955006,
-     "summarization_rougeL": 0.21671934441779328,
-     "summarization_bleu4": 0.046807610480627294,
-     "emotion_loss": 0.26667120894821444,
-     "emotion_f1": 0.20469974499405505,
-     "total_loss": 6.105969995046231,
-     "topic_loss": 1.6857517635423032,
-     "topic_accuracy": 0.4217703349282306
-   },
-   "val_epoch_1": {
-     "summarization_loss": 3.8147981293996174,
-     "summarization_rouge_like": 0.2193516213078271,
-     "summarization_rouge1": 0.26079796060659194,
-     "summarization_rouge2": 0.08507403927823329,
-     "summarization_rougeL": 0.2006794257877804,
-     "summarization_bleu4": 0.047830237595825456,
-     "emotion_loss": 0.14947238216797512,
-     "emotion_f1": 0.19722222675879797,
-     "topic_loss": 1.111328324476878,
-     "topic_accuracy": 0.7036666666666669,
-     "total_loss": 4.297669008910658
-   },
-   "train_epoch_2": {
-     "summarization_loss": 3.849853239841521,
-     "summarization_rouge_like": 0.21403927033293962,
-     "summarization_rouge1": 0.27951076939717684,
-     "summarization_rouge2": 0.0873046768836161,
-     "summarization_rougeL": 0.21374118927203542,
-     "summarization_bleu4": 0.04958577524880755,
-     "emotion_loss": 0.14221051054989492,
-     "emotion_f1": 0.26357290302542585,
-     "topic_loss": 0.7299397268663149,
-     "topic_accuracy": 0.8084686774942008,
-     "total_loss": 5.528282663174978
-   },
-   "val_epoch_2": {
-     "summarization_loss": 3.738964385986328,
-     "summarization_rouge_like": 0.22322817933854347,
-     "summarization_rouge1": 0.2648447987903156,
-     "summarization_rouge2": 0.08777067266852198,
-     "summarization_rougeL": 0.2049718124413594,
-     "summarization_bleu4": 0.04980800809043137,
-     "emotion_loss": 0.1332480485116442,
-     "emotion_f1": 0.3008111199736595,
-     "topic_loss": 0.5171811254819234,
-     "topic_accuracy": 0.8467777777777786,
-     "total_loss": 4.027366772142546
-   },
-   "train_epoch_3": {
-     "emotion_loss": 0.12831888329927568,
-     "emotion_f1": 0.3325013413977316,
-     "summarization_loss": 3.7839796767703127,
-     "summarization_rouge_like": 0.21797276831106976,
-     "summarization_rouge1": 0.28868883384124916,
-     "summarization_rouge2": 0.09150032176337587,
-     "summarization_rougeL": 0.22148013487440707,
-     "summarization_bleu4": 0.052993168973641876,
-     "total_loss": 5.379445686122572,
-     "topic_loss": 0.3385182340765703,
-     "topic_accuracy": 0.9149137451307789
-   },
-   "val_epoch_3": {
-     "summarization_loss": 3.699807391166687,
-     "summarization_rouge_like": 0.22613490382620294,
-     "summarization_rouge1": 0.27110048990501884,
-     "summarization_rouge2": 0.09042725720607361,
-     "summarization_rougeL": 0.209904253200661,
-     "summarization_bleu4": 0.05177241093143676,
-     "emotion_loss": 0.12147359546273946,
-     "emotion_f1": 0.3474666798238953,
-     "topic_loss": 0.5068136086066564,
-     "topic_accuracy": 0.8417777777777792,
-     "total_loss": 3.9733250692114286
-   },
-   "train_epoch_4": {
-     "summarization_loss": 3.746917572488457,
-     "summarization_rouge_like": 0.22054338132572013,
-     "summarization_rouge1": 0.29700759128401966,
-     "summarization_rouge2": 0.09528349132659034,
-     "summarization_rougeL": 0.2286643637324592,
-     "summarization_bleu4": 0.05591190647915982,
-     "emotion_loss": 0.12003502780097021,
-     "emotion_f1": 0.37240424536844824,
-     "total_loss": 5.303101773435515,
-     "topic_loss": 0.19978291214297147,
-     "topic_accuracy": 0.9528935185185234
-   },
-   "val_epoch_4": {
-     "summarization_loss": 3.6773871207237243,
-     "summarization_rouge_like": 0.22730110361278533,
-     "summarization_rouge1": 0.2719731929407321,
-     "summarization_rouge2": 0.09117786246379923,
-     "summarization_rougeL": 0.21082587270737135,
-     "summarization_bleu4": 0.052260125383420154,
-     "emotion_loss": 0.11476812147845825,
-     "emotion_f1": 0.40390001876900594,
-     "topic_loss": 0.5311758625507355,
-     "topic_accuracy": 0.8574444444444455,
-     "total_loss": 3.95150800096741
-   },
-   "train_epoch_5": {
-     "summarization_loss": 3.72376742684834,
-     "summarization_rouge_like": 0.22218972657959773,
-     "summarization_rouge1": 0.30386172952451457,
-     "summarization_rouge2": 0.09807265293507532,
-     "summarization_rougeL": 0.23422938393417417,
-     "summarization_bleu4": 0.05821407514551748,
-     "emotion_loss": 0.11460309708431649,
-     "emotion_f1": 0.41015538037428334,
-     "total_loss": 5.207888234798891,
-     "topic_loss": 0.13986067138923575,
-     "topic_accuracy": 0.9685236768802278
-   },
-   "val_epoch_5": {
-     "summarization_loss": 3.664777074654897,
-     "summarization_rouge_like": 0.22876987463000684,
-     "summarization_rouge1": 0.27596093399625565,
-     "summarization_rouge2": 0.09296804123657829,
-     "summarization_rougeL": 0.21411928790828857,
-     "summarization_bleu4": 0.05366559404113782,
-     "emotion_loss": 0.11044646929949523,
-     "emotion_f1": 0.4313555757453044,
-     "topic_loss": 0.5484664579232533,
-     "topic_accuracy": 0.8627777777777789,
-     "total_loss": 3.9397634813313704
-   },
-   "train_epoch_6": {
-     "emotion_loss": 0.1111307007874511,
-     "emotion_f1": 0.43345397762862603,
-     "summarization_loss": 3.7095002406409807,
-     "summarization_rouge_like": 0.22328726116125275,
-     "summarization_rouge1": 0.3064035344877472,
-     "summarization_rouge2": 0.09935359454486654,
-     "summarization_rougeL": 0.23650841461700828,
-     "summarization_bleu4": 0.059165680810364656,
-     "total_loss": 5.231221632164746,
-     "topic_loss": 0.10774352340420275,
-     "topic_accuracy": 0.9777777777777826
-   },
-   "val_epoch_6": {
-     "summarization_loss": 3.658109269142151,
-     "summarization_rouge_like": 0.22934290201883448,
-     "summarization_rouge1": 0.2752052666208255,
-     "summarization_rouge2": 0.09292038370832255,
-     "summarization_rougeL": 0.2137414809166316,
-     "summarization_bleu4": 0.053427338475007496,
-     "emotion_loss": 0.10808507531881333,
-     "emotion_f1": 0.4517777989556392,
-     "topic_loss": 0.5295590771238009,
-     "topic_accuracy": 0.8734444444444451,
-     "total_loss": 3.9250620675981014
-   },
-   "train_epoch_7": {
-     "emotion_loss": 0.10953594371440704,
-     "emotion_f1": 0.44384909393388133,
-     "topic_loss": 0.0957411853224039,
-     "topic_accuracy": 0.980394366197187,
-     "total_loss": 5.154093306898701,
-     "summarization_loss": 3.7035266418583594,
-     "summarization_rouge_like": 0.22378105952974536,
-     "summarization_rouge1": 0.3070619920824417,
-     "summarization_rouge2": 0.09984959921270933,
-     "summarization_rougeL": 0.23710279675635842,
-     "summarization_bleu4": 0.05954598113800495
-   },
-   "val_epoch_7": {
-     "summarization_loss": 3.654966928164164,
-     "summarization_rouge_like": 0.2296679906954514,
-     "summarization_rouge1": 0.27616327736195406,
-     "summarization_rouge2": 0.09329265746038877,
-     "summarization_rougeL": 0.2144202156909426,
-     "summarization_bleu4": 0.05381191556748925,
-     "emotion_loss": 0.10733611459533374,
-     "emotion_f1": 0.4582889095693827,
-     "topic_loss": 0.5517185291647911,
-     "topic_accuracy": 0.8574444444444457,
-     "total_loss": 3.9278186015089296
-   },
-   "train_epoch_8": {
-     "summarization_loss": 3.6991967220660666,
-     "summarization_rouge_like": 0.22392498422300275,
-     "summarization_rouge1": 0.30751530664889926,
-     "summarization_rouge2": 0.10003700619268063,
-     "summarization_rougeL": 0.23750205422812004,
-     "summarization_bleu4": 0.05974583539783897,
-     "emotion_loss": 0.10842480880968565,
-     "emotion_f1": 0.44919879924130307,
-     "total_loss": 5.178375478314296,
-     "topic_loss": 0.09057999134229244,
-     "topic_accuracy": 0.9817882611080675
-   },
-   "val_epoch_8": {
-     "summarization_loss": 3.652825821240743,
-     "summarization_rouge_like": 0.22990084413830203,
-     "summarization_rouge1": 0.2765755402183266,
-     "summarization_rouge2": 0.09348690574327727,
-     "summarization_rougeL": 0.2147442273650338,
-     "summarization_bleu4": 0.053967258926407226,
-     "emotion_loss": 0.10670269103099903,
-     "emotion_f1": 0.4594111327578624,
-     "topic_loss": 0.5511919154723486,
-     "topic_accuracy": 0.8574444444444457,
-     "total_loss": 3.924886086913453
-   }
- }
scripts/train_bert_baseline.py ADDED
@@ -0,0 +1,1137 @@
+ """
+ BERT Baseline Training for LexiMind Comparison.
+
+ Fine-tunes bert-base-uncased on topic classification and emotion detection
+ to provide baselines for comparison with LexiMind (FLAN-T5-based).
+
+ Supports three training modes to disentangle architecture vs. MTL effects:
+     1. single-topic — BERT fine-tuned on topic classification only
+     2. single-emotion — BERT fine-tuned on emotion detection only
+     3. multitask — BERT fine-tuned on both tasks jointly
+
+ Uses the same datasets, splits, label encoders, and evaluation metrics as the
+ main LexiMind pipeline for fair comparison.
+
+ Usage:
+     python scripts/train_bert_baseline.py --mode single-topic
+     python scripts/train_bert_baseline.py --mode single-emotion
+     python scripts/train_bert_baseline.py --mode multitask
+     python scripts/train_bert_baseline.py --mode all  # Run all three sequentially
+
+ Author: Oliver Perrin
+ Date: March 2026
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import random
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from sklearn.metrics import accuracy_score, classification_report, f1_score
+ from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
+ from torch.cuda.amp import GradScaler, autocast
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoTokenizer
+
+ # Project imports
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
+ if str(PROJECT_ROOT) not in sys.path:
+     sys.path.insert(0, str(PROJECT_ROOT))
+
+ from src.data.dataset import (
+     EmotionExample,
+     TopicExample,
+     load_emotion_jsonl,
+     load_topic_jsonl,
+ )
+
+ from src.training.metrics import (
+     bootstrap_confidence_interval,
+     multilabel_f1,
+     multilabel_macro_f1,
+     multilabel_micro_f1,
+     multilabel_per_class_metrics,
+     tune_per_class_thresholds,
+ )
+
+
+ # Configuration
+
+ @dataclass
+ class BertBaselineConfig:
+     """Hyperparameters aligned with LexiMind's full.yaml where applicable."""
+
+     # Model
+     model_name: str = "bert-base-uncased"
+     max_length: int = 256  # Same as LexiMind classification max_len
+
+     # Optimizer (matching LexiMind's full.yaml)
+     lr: float = 3e-5
+     weight_decay: float = 0.01
+     betas: tuple[float, float] = (0.9, 0.98)
+     eps: float = 1e-6
+
+     # Training
+     batch_size: int = 10  # Same as LexiMind
+     gradient_accumulation_steps: int = 4  # Same effective batch = 40
+     max_epochs: int = 8
+     warmup_steps: int = 300
+     gradient_clip_norm: float = 1.0
+     early_stopping_patience: int = 3
+     seed: int = 17  # Same as LexiMind
+
+     # Task weights (for multi-task mode)
+     topic_weight: float = 0.3  # Same as LexiMind
+     emotion_weight: float = 1.0
+
+     # Temperature sampling (for multi-task mode)
+     task_sampling_alpha: float = 0.5
+
+     # Frozen layers: freeze bottom 4 layers (matching LexiMind's encoder strategy)
+     freeze_layers: int = 4
+
+     # Precision
+     use_amp: bool = True  # BFloat16 mixed precision
+
+     # Paths
+     data_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed")
+     output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "outputs" / "bert_baseline")
+     checkpoint_dir: Path = field(
+         default_factory=lambda: PROJECT_ROOT / "checkpoints" / "bert_baseline"
+     )
+
+     # Emotion threshold
+     emotion_threshold: float = 0.3
+
+
+ # Datasets
+
+ class BertEmotionDataset(Dataset):
+     """Tokenized emotion dataset for BERT."""
+
+     def __init__(
+         self,
+         examples: List[EmotionExample],
+         tokenizer: AutoTokenizer,
+         binarizer: MultiLabelBinarizer,
+         max_length: int = 256,
+     ):
+         self.examples = examples
+         self.tokenizer = tokenizer
+         self.binarizer = binarizer
+         self.max_length = max_length
+
+     def __len__(self) -> int:
+         return len(self.examples)
+
+     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+         ex = self.examples[idx]
+         encoding = self.tokenizer(
+             ex.text,
+             max_length=self.max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         )
+         labels = self.binarizer.transform([ex.emotions])[0]
+         return {
+             "input_ids": encoding["input_ids"].squeeze(0),
+             "attention_mask": encoding["attention_mask"].squeeze(0),
+             "labels": torch.tensor(labels, dtype=torch.float32),
+         }
+
+
+ class BertTopicDataset(Dataset):
+     """Tokenized topic dataset for BERT."""
+
+     def __init__(
+         self,
+         examples: List[TopicExample],
+         tokenizer: AutoTokenizer,
+         encoder: LabelEncoder,
+         max_length: int = 256,
+     ):
+         self.examples = examples
+         self.tokenizer = tokenizer
+         self.encoder = encoder
+         self.max_length = max_length
+
+     def __len__(self) -> int:
+         return len(self.examples)
+
+     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+         ex = self.examples[idx]
+         encoding = self.tokenizer(
+             ex.text,
+             max_length=self.max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         )
+         label = self.encoder.transform([ex.topic])[0]
+         return {
+             "input_ids": encoding["input_ids"].squeeze(0),
+             "attention_mask": encoding["attention_mask"].squeeze(0),
+             "labels": torch.tensor(label, dtype=torch.long),
+         }
+
+
+ # Model
+
+ class BertClassificationHead(nn.Module):
+     """Classification head on top of BERT [CLS] token.
+
+     For emotion: uses attention pooling + 2-layer MLP (matching LexiMind's emotion head)
+     For topic: uses [CLS] + single linear (matching LexiMind's mean pool + linear)
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_labels: int,
+         pooling: str = "cls",  # "cls" or "attention"
+         hidden_dim: Optional[int] = None,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         self.pooling = pooling
+         self.dropout = nn.Dropout(dropout)
+
+         if pooling == "attention":
+             self.attn_query = nn.Linear(hidden_size, 1, bias=False)
+
+         if hidden_dim is not None:
+             self.classifier = nn.Sequential(
+                 nn.Linear(hidden_size, hidden_dim),
+                 nn.GELU(),
+                 nn.Dropout(dropout),
+                 nn.Linear(hidden_dim, num_labels),
+             )
+         else:
+             self.classifier = nn.Linear(hidden_size, num_labels)
+
+     def forward(
+         self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
+     ) -> torch.Tensor:
+         if self.pooling == "attention":
+             # Learned attention pooling (same mechanism as LexiMind)
+             scores = self.attn_query(hidden_states)  # (B, L, 1)
+             mask = attention_mask.unsqueeze(-1).bool()
+             scores = scores.masked_fill(~mask, float("-inf"))
+             weights = F.softmax(scores, dim=1)
+             pooled = (weights * hidden_states).sum(dim=1)
+         elif self.pooling == "mean":
+             # Mean pooling over valid tokens
+             mask_expanded = attention_mask.unsqueeze(-1).float()
+             sum_embeddings = (hidden_states * mask_expanded).sum(dim=1)
+             sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)
+             pooled = sum_embeddings / sum_mask
+         else:
+             # [CLS] token
+             pooled = hidden_states[:, 0, :]
+
+         pooled = self.dropout(pooled)
+         return self.classifier(pooled)
+
+
+ class BertBaseline(nn.Module):
+     """BERT baseline model with task-specific heads.
+
+     Supports single-task and multi-task configurations.
+     """
+
+     def __init__(
+         self,
+         model_name: str = "bert-base-uncased",
+         num_emotions: int = 28,
+         num_topics: int = 7,
+         tasks: Sequence[str] = ("emotion", "topic"),
+         freeze_layers: int = 4,
+     ):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained(model_name)
+         hidden_size = self.bert.config.hidden_size  # 768 for bert-base
+
+         self.tasks = list(tasks)
+         self.heads = nn.ModuleDict()
+
+         if "emotion" in tasks:
+             # Attention pooling + 2-layer MLP (matching LexiMind's emotion head)
+             self.heads["emotion"] = BertClassificationHead(
+                 hidden_size=hidden_size,
+                 num_labels=num_emotions,
+                 pooling="attention",
+                 hidden_dim=hidden_size // 2,  # 384, same ratio as LexiMind
+                 dropout=0.1,
+             )
+
+         if "topic" in tasks:
+             # Mean pooling + single linear (matching LexiMind's topic head)
+             self.heads["topic"] = BertClassificationHead(
+                 hidden_size=hidden_size,
+                 num_labels=num_topics,
+                 pooling="mean",
+                 hidden_dim=None,
+                 dropout=0.1,
+             )
+
+         # Freeze bottom N encoder layers (matching LexiMind's strategy)
+         self._freeze_layers(freeze_layers)
+
+     def _freeze_layers(self, n: int) -> None:
+         """Freeze embedding + bottom n encoder layers."""
+         # Freeze embeddings
+         for param in self.bert.embeddings.parameters():
+             param.requires_grad = False
+
+         # Freeze bottom n layers
+         for i in range(min(n, len(self.bert.encoder.layer))):
+             for param in self.bert.encoder.layer[i].parameters():
+                 param.requires_grad = False
+
+         frozen = sum(1 for p in self.bert.parameters() if not p.requires_grad)
+         total = sum(1 for p in self.bert.parameters())
+         print(f"  Frozen {frozen}/{total} BERT parameters (bottom {n} layers + embeddings)")
+
+     def forward(
+         self,
+         task: str,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         hidden_states = outputs.last_hidden_state  # (B, L, 768)
+         return self.heads[task](hidden_states, attention_mask)
+
+     def param_count(self) -> Dict[str, int]:
+         """Count parameters by component."""
+         counts = {}
+         counts["bert_encoder"] = sum(p.numel() for p in self.bert.parameters())
+         counts["bert_trainable"] = sum(
+             p.numel() for p in self.bert.parameters() if p.requires_grad
+         )
+         for name, head in self.heads.items():
+             counts[f"head_{name}"] = sum(p.numel() for p in head.parameters())
+         counts["total"] = sum(p.numel() for p in self.parameters())
+         counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         return counts
+
+
+ # Training
+
+ class BertTrainer:
+     """Trainer supporting single-task and multi-task BERT training."""
+
+     def __init__(
+         self,
+         model: BertBaseline,
+         config: BertBaselineConfig,
+         train_loaders: Dict[str, DataLoader],
+         val_loaders: Dict[str, DataLoader],
+         device: torch.device,
+         mode: str,
+     ):
+         self.model = model
+         self.config = config
+         self.train_loaders = train_loaders
+         self.val_loaders = val_loaders
+         self.device = device
+         self.mode = mode
+
+         # Optimizer
+         self.optimizer = AdamW(
+             [p for p in model.parameters() if p.requires_grad],
+             lr=config.lr,
+             weight_decay=config.weight_decay,
+             betas=config.betas,
+             eps=config.eps,
+         )
+
+         # Calculate total training steps
+         if len(train_loaders) > 1:
+             # Multi-task: use temperature-sampled steps
+             sizes = {k: len(v) for k, v in train_loaders.items()}
+             total_batches = sum(sizes.values())
+         else:
+             total_batches = sum(len(v) for v in train_loaders.values())
+         self.steps_per_epoch = total_batches // config.gradient_accumulation_steps
+         self.total_steps = self.steps_per_epoch * config.max_epochs
+
+         # LR scheduler: linear warmup + cosine decay (matching LexiMind)
+         warmup_scheduler = LinearLR(
+             self.optimizer,
+             start_factor=1e-8 / config.lr,
+             end_factor=1.0,
+             total_iters=config.warmup_steps,
+         )
+         cosine_scheduler = CosineAnnealingLR(
+             self.optimizer,
+             T_max=max(self.total_steps - config.warmup_steps, 1),
+             eta_min=config.lr * 0.1,  # Decay to 10% of peak (matching LexiMind)
+         )
+         self.scheduler = SequentialLR(
+             self.optimizer,
+             schedulers=[warmup_scheduler, cosine_scheduler],
+             milestones=[config.warmup_steps],
+         )
+
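The composed schedule above (linear warmup to the peak LR, then cosine decay toward 10% of peak) can also be written in closed form. A small standalone sketch mirroring the `LinearLR`/`CosineAnnealingLR` pair with the config's peak LR and warmup steps (the `total` step count here is a made-up toy value; the script derives it from loader sizes):

```python
import math

def lr_at(step, peak=3e-5, warmup=300, total=3000, floor_ratio=0.1):
    """Closed-form warmup + cosine schedule (illustrative, not the trainer's code)."""
    floor = peak * floor_ratio
    if step < warmup:
        # Linear warmup from ~0 up to the peak LR
        return peak * (step + 1) / warmup
    # Cosine decay from the peak down to the floor (10% of peak)
    progress = (step - warmup) / max(total - warmup, 1)
    return floor + (peak - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

Plotting `lr_at` over `range(total)` reproduces the ramp-then-decay shape the two chained schedulers produce.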
+         # Mixed precision
+         self.scaler = GradScaler(enabled=config.use_amp)
+
+         # Loss functions
+         self.emotion_loss_fn = nn.BCEWithLogitsLoss()
+         self.topic_loss_fn = nn.CrossEntropyLoss()
+
+         # Tracking
+         self.global_step = 0
+         self.best_metric = -float("inf")
+         self.patience_counter = 0
+         self.training_history: List[Dict[str, Any]] = []
+
+     def _compute_loss(self, task: str, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         if task == "emotion":
+             return self.emotion_loss_fn(logits, labels)
+         else:
+             return self.topic_loss_fn(logits, labels)
+
+     def _get_task_weight(self, task: str) -> float:
+         if self.mode != "multitask":
+             return 1.0
+         if task == "topic":
+             return self.config.topic_weight
+         return self.config.emotion_weight
+
+     def _make_multitask_iterator(self):
+         """Temperature-based task sampling (matching LexiMind)."""
+         sizes = {k: len(v.dataset) for k, v in self.train_loaders.items()}
+         alpha = self.config.task_sampling_alpha
+
+         # Compute sampling probabilities
+         raw = {k: s ** (1.0 / alpha) for k, s in sizes.items()}
+         total = sum(raw.values())
+         probs = {k: v / total for k, v in raw.items()}
+
+         # Create iterators
+         iters = {k: iter(v) for k, v in self.train_loaders.items()}
+         tasks = list(probs.keys())
+         weights = [probs[t] for t in tasks]
+
+         while True:
+             task = random.choices(tasks, weights=weights, k=1)[0]
+             try:
+                 batch = next(iters[task])
+             except StopIteration:
+                 iters[task] = iter(self.train_loaders[task])
+                 batch = next(iters[task])
+             yield task, batch
+
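The generator above samples each task with probability proportional to `dataset_size ** (1 / alpha)`. Note that with the default `alpha = 0.5` the exponent is 2, so the larger task is sampled even more often than its raw share (temperature schemes that flatten toward uniform instead use `size ** alpha` with `alpha < 1`). A standalone sketch of the weighting as written, with made-up dataset sizes for illustration:

```python
def task_sampling_probs(sizes, alpha=0.5):
    """Reproduces the iterator's weighting: p(task) ∝ size ** (1 / alpha)."""
    raw = {k: s ** (1.0 / alpha) for k, s in sizes.items()}
    total = sum(raw.values())
    return {k: v / total for k, v in raw.items()}

# Illustrative sizes only: a 3:1 size ratio becomes a 9:1 sampling ratio at alpha=0.5
probs = task_sampling_probs({"emotion": 30000, "topic": 10000}, alpha=0.5)
```

`alpha = 1.0` recovers plain size-proportional sampling; values below 1 sharpen it further under this parameterization.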
+    def train_epoch(self, epoch: int) -> Dict[str, float]:
+        """Train one epoch."""
+        self.model.train()
+        self.optimizer.zero_grad()
+
+        epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
+
+        if len(self.train_loaders) > 1:
+            # Multi-task: temperature sampling
+            iterator = self._make_multitask_iterator()
+            total_batches = sum(len(v) for v in self.train_loaders.values())
+        else:
+            # Single-task: iterate normally
+            task_name = next(iter(self.train_loaders))
+            iterator = ((task_name, batch) for batch in self.train_loaders[task_name])
+            total_batches = len(self.train_loaders[task_name])
+
+        pbar = tqdm(total=total_batches, desc=f"Epoch {epoch + 1}/{self.config.max_epochs}")
+
+        for step_in_epoch in range(total_batches):
+            task, batch = next(iterator)
+
+            input_ids = batch["input_ids"].to(self.device)
+            attention_mask = batch["attention_mask"].to(self.device)
+            labels = batch["labels"].to(self.device)
+
+            # Forward pass with AMP
+            with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
+                logits = self.model(task, input_ids, attention_mask)
+                loss = self._compute_loss(task, logits, labels)
+                loss = loss * self._get_task_weight(task)
+                loss = loss / self.config.gradient_accumulation_steps
+
+            # Backward
+            self.scaler.scale(loss).backward()
+            epoch_losses[task].append(loss.item() * self.config.gradient_accumulation_steps)
+
+            # Optimizer step (every N accumulation steps)
+            if (step_in_epoch + 1) % self.config.gradient_accumulation_steps == 0:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), self.config.gradient_clip_norm
+                )
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()
+                self.scheduler.step()
+                self.global_step += 1
+
+            pbar.set_postfix(
+                {f"{task}_loss": f"{epoch_losses[task][-1]:.4f}", "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"}
+            )
+            pbar.update(1)
+
+        pbar.close()
+
+        # Aggregate
+        results = {}
+        for task, losses in epoch_losses.items():
+            if losses:
+                results[f"train_{task}_loss"] = sum(losses) / len(losses)
+        return results
+
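`_make_multitask_iterator` is defined elsewhere in the file, so only its call site appears in this hunk. The sketch below shows one common way to implement temperature-based task sampling; the function name, signature, and the `temperature` default here are illustrative assumptions, not the file's actual implementation:

```python
import itertools
import random
from typing import Dict, Iterable, Iterator, Tuple


def make_multitask_iterator(
    loaders: Dict[str, Iterable],
    temperature: float = 0.5,  # assumed default, not from this file
    seed: int = 42,
) -> Iterator[Tuple[str, object]]:
    """Yield (task, batch) pairs, sampling tasks with weight len(loader) ** temperature.

    temperature=1.0 reproduces proportional sampling; temperature=0.0 samples
    tasks uniformly, which up-weights smaller datasets. Each loader is cycled
    so a small task can be revisited within one epoch.
    """
    rng = random.Random(seed)
    tasks = list(loaders)
    weights = [len(loaders[t]) ** temperature for t in tasks]
    cycles = {t: itertools.cycle(loaders[t]) for t in tasks}
    while True:
        task = rng.choices(tasks, weights=weights, k=1)[0]
        yield task, next(cycles[task])
```

With temperature 0 a 9:1 size imbalance between tasks still yields roughly 50/50 sampling; with temperature 1 the larger task dominates in proportion to its size.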
+    @torch.no_grad()
+    def validate(self) -> Dict[str, Any]:
+        """Run validation across all tasks."""
+        self.model.eval()
+        results: Dict[str, Any] = {}
+
+        for task, loader in self.val_loaders.items():
+            all_logits = []
+            all_labels = []
+            total_loss = 0.0
+            n_batches = 0
+
+            for batch in loader:
+                input_ids = batch["input_ids"].to(self.device)
+                attention_mask = batch["attention_mask"].to(self.device)
+                labels = batch["labels"].to(self.device)
+
+                with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
+                    logits = self.model(task, input_ids, attention_mask)
+                    loss = self._compute_loss(task, logits, labels)
+
+                total_loss += loss.item()
+                n_batches += 1
+                all_logits.append(logits.float().cpu())
+                all_labels.append(labels.float().cpu())
+
+            all_logits_t = torch.cat(all_logits, dim=0)
+            all_labels_t = torch.cat(all_labels, dim=0)
+            results[f"val_{task}_loss"] = total_loss / max(n_batches, 1)
+
+            if task == "emotion":
+                preds = (torch.sigmoid(all_logits_t) > self.config.emotion_threshold).int()
+                targets = all_labels_t.int()
+                results["val_emotion_sample_f1"] = multilabel_f1(preds, targets)
+                results["val_emotion_macro_f1"] = multilabel_macro_f1(preds, targets)
+                results["val_emotion_micro_f1"] = multilabel_micro_f1(preds, targets)
+                # Store raw logits for threshold tuning later
+                results["_emotion_logits"] = all_logits_t
+                results["_emotion_labels"] = all_labels_t
+
+            elif task == "topic":
+                preds = all_logits_t.argmax(dim=1).numpy()
+                targets = all_labels_t.long().numpy()
+                results["val_topic_accuracy"] = float(accuracy_score(targets, preds))
+                results["val_topic_macro_f1"] = float(
+                    f1_score(targets, preds, average="macro", zero_division=0)
+                )
+
+        # Combined metric for early stopping / checkpointing
+        metric_parts = []
+        if "val_emotion_sample_f1" in results:
+            metric_parts.append(results["val_emotion_sample_f1"])
+        if "val_topic_accuracy" in results:
+            metric_parts.append(results["val_topic_accuracy"])
+        results["val_combined_metric"] = sum(metric_parts) / max(len(metric_parts), 1)
+
+        return results
+
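The raw logits stashed under `_emotion_logits` feed the per-class threshold tuning used at evaluation time. `tune_per_class_thresholds` comes from the project's metrics module and is not shown in this hunk; below is a sketch of one plausible grid-search implementation, assuming the function returns per-class sigmoid thresholds plus the resulting macro F1 (the grid values are an assumption):

```python
from typing import List, Optional, Tuple

import numpy as np


def tune_per_class_thresholds(
    logits, labels, grid: Optional[List[float]] = None
) -> Tuple[List[float], float]:
    """For each class, pick the sigmoid threshold that maximizes that class's F1.

    Returns (per-class thresholds, macro F1 at those thresholds).
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    if grid is None:
        grid = [round(0.05 * i, 2) for i in range(1, 19)]  # 0.05 .. 0.90
    probs = 1.0 / (1.0 + np.exp(-logits))

    thresholds, f1s = [], []
    for c in range(probs.shape[1]):
        best_t, best_f1 = 0.5, 0.0
        for t in grid:
            pred = (probs[:, c] >= t).astype(float)
            tp = float((pred * labels[:, c]).sum())
            prec = tp / max(pred.sum(), 1.0)
            rec = tp / max(labels[:, c].sum(), 1.0)
            f1 = 2 * prec * rec / max(prec + rec, 1e-8)
            if f1 > best_f1:
                best_t, best_f1 = t, f1
        thresholds.append(best_t)
        f1s.append(best_f1)
    return thresholds, float(np.mean(f1s))
```

Tuning thresholds on the validation split (rather than the test split) keeps the tuned numbers honest; the trainer stores validation logits for exactly that reason.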
+    def save_checkpoint(self, path: Path, epoch: int, metrics: Dict[str, Any]) -> None:
+        """Save model checkpoint."""
+        path.parent.mkdir(parents=True, exist_ok=True)
+        # Filter out tensors from metrics
+        clean_metrics = {k: v for k, v in metrics.items() if not k.startswith("_")}
+        torch.save(
+            {
+                "epoch": epoch,
+                "model_state_dict": self.model.state_dict(),
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "scheduler_state_dict": self.scheduler.state_dict(),
+                "metrics": clean_metrics,
+                "config": {
+                    "mode": self.mode,
+                    "tasks": self.model.tasks,
+                    "model_name": self.config.model_name,
+                },
+            },
+            path,
+        )
+
+    def train(self) -> Dict[str, Any]:
+        """Full training loop."""
+        print(f"\n{'=' * 60}")
+        print(f"Training BERT Baseline — Mode: {self.mode}")
+        print(f"{'=' * 60}")
+
+        param_counts = self.model.param_count()
+        print(f"  Total parameters: {param_counts['total']:,}")
+        print(f"  Trainable parameters: {param_counts['trainable']:,}")
+        for name, count in param_counts.items():
+            if name.startswith("head_"):
+                print(f"  {name}: {count:,}")
+        print(f"  Steps/epoch: {self.steps_per_epoch}")
+        print(f"  Total steps: {self.total_steps}")
+        print()
+
+        all_results: Dict[str, Any] = {"mode": self.mode, "epochs": []}
+        start_time = time.time()
+
+        for epoch in range(self.config.max_epochs):
+            epoch_start = time.time()
+
+            # Train
+            train_metrics = self.train_epoch(epoch)
+
+            # Validate
+            val_metrics = self.validate()
+
+            epoch_time = time.time() - epoch_start
+
+            # Log
+            epoch_result = {
+                "epoch": epoch + 1,
+                "time_seconds": epoch_time,
+                **train_metrics,
+                **{k: v for k, v in val_metrics.items() if not k.startswith("_")},
+            }
+            all_results["epochs"].append(epoch_result)
+            self.training_history.append(epoch_result)
+
+            # Print summary
+            print(f"\n  Epoch {epoch + 1} ({epoch_time:.0f}s):")
+            for k, v in sorted(epoch_result.items()):
+                if k not in ("epoch", "time_seconds") and isinstance(v, float):
+                    print(f"    {k}: {v:.4f}")
+
+            # Checkpointing
+            combined = val_metrics["val_combined_metric"]
+            if combined > self.best_metric:
+                self.best_metric = combined
+                self.patience_counter = 0
+                self.save_checkpoint(
+                    self.config.checkpoint_dir / self.mode / "best.pt",
+                    epoch,
+                    val_metrics,
+                )
+                print(f"  New best model (combined metric: {combined:.4f})")
+            else:
+                self.patience_counter += 1
+                print(
+                    f"  No improvement ({self.patience_counter}/{self.config.early_stopping_patience})"
+                )
+
+            # Always save epoch checkpoint
+            self.save_checkpoint(
+                self.config.checkpoint_dir / self.mode / f"epoch_{epoch + 1}.pt",
+                epoch,
+                val_metrics,
+            )
+
+            # Early stopping
+            if self.patience_counter >= self.config.early_stopping_patience:
+                print(f"\n  Early stopping triggered at epoch {epoch + 1}")
+                all_results["early_stopped"] = True
+                all_results["best_epoch"] = epoch + 1 - self.config.early_stopping_patience
+                break
+
+        total_time = time.time() - start_time
+        all_results["total_time_seconds"] = total_time
+        all_results["total_time_human"] = f"{total_time / 3600:.1f}h"
+        if "early_stopped" not in all_results:
+            all_results["early_stopped"] = False
+            all_results["best_epoch"] = (
+                epoch + 1 - self.patience_counter
+                if self.patience_counter > 0
+                else epoch + 1
+            )
+        all_results["param_counts"] = param_counts
+
+        print(f"\n  Training complete in {total_time / 3600:.1f}h")
+        print(f"  Best combined metric: {self.best_metric:.4f}")
+
+        return all_results
+
+
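`self.scheduler` is created in the trainer's `__init__`, outside this hunk. For BERT fine-tuning the usual choice is linear warmup followed by linear decay; a multiplier function of that shape, suitable for `torch.optim.lr_scheduler.LambdaLR`, might look like the sketch below (the function name and any warmup fraction are assumptions, not this file's code):

```python
def linear_warmup_then_linear_decay(warmup_steps: int, total_steps: int):
    """Return an LR multiplier: ramps 0 -> 1 over warmup_steps, then decays 1 -> 0 by total_steps."""

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    return lr_lambda
```

Because `self.scheduler.step()` above is called once per optimizer step (not per batch), `total_steps` would be computed from `steps_per_epoch` after dividing by `gradient_accumulation_steps`.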
+# Evaluation
+
+
+def evaluate_bert_model(
+    model: BertBaseline,
+    val_loaders: Dict[str, DataLoader],
+    device: torch.device,
+    config: BertBaselineConfig,
+    emotion_classes: Optional[List[str]] = None,
+    topic_classes: Optional[List[str]] = None,
+) -> Dict[str, Any]:
+    """Full evaluation with the same metrics as LexiMind's evaluate.py."""
+    model.eval()
+    results: Dict[str, Any] = {}
+
+    with torch.no_grad():
+        for task, loader in val_loaders.items():
+            all_logits = []
+            all_labels = []
+
+            for batch in tqdm(loader, desc=f"Evaluating {task}"):
+                input_ids = batch["input_ids"].to(device)
+                attention_mask = batch["attention_mask"].to(device)
+                labels = batch["labels"].to(device)
+
+                with autocast(dtype=torch.bfloat16, enabled=config.use_amp):
+                    logits = model(task, input_ids, attention_mask)
+
+                all_logits.append(logits.float().cpu())
+                all_labels.append(labels.float().cpu())
+
+            all_logits_t = torch.cat(all_logits, dim=0)
+            all_labels_t = torch.cat(all_labels, dim=0)
+
+            if task == "emotion":
+                # Default threshold
+                preds_default = (torch.sigmoid(all_logits_t) > config.emotion_threshold).int()
+                targets = all_labels_t.int()
+
+                results["emotion"] = {
+                    "default_threshold": config.emotion_threshold,
+                    "sample_avg_f1": multilabel_f1(preds_default, targets),
+                    "macro_f1": multilabel_macro_f1(preds_default, targets),
+                    "micro_f1": multilabel_micro_f1(preds_default, targets),
+                }
+
+                # Per-class metrics
+                if emotion_classes:
+                    per_class = multilabel_per_class_metrics(
+                        preds_default, targets, emotion_classes
+                    )
+                    results["emotion"]["per_class"] = per_class
+
+                # Threshold tuning
+                best_thresholds, tuned_macro = tune_per_class_thresholds(
+                    all_logits_t, all_labels_t
+                )
+                tuned_preds = torch.zeros_like(all_logits_t)
+                probs = torch.sigmoid(all_logits_t)
+                for c in range(all_logits_t.shape[1]):
+                    tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
+                tuned_preds = tuned_preds.int()
+
+                results["emotion"]["tuned_macro_f1"] = tuned_macro
+                results["emotion"]["tuned_sample_avg_f1"] = multilabel_f1(
+                    tuned_preds, targets
+                )
+                results["emotion"]["tuned_micro_f1"] = multilabel_micro_f1(
+                    tuned_preds, targets
+                )
+
+                # Bootstrap CI on sample-avg F1
+                per_sample_f1 = []
+                for i in range(preds_default.shape[0]):
+                    p = preds_default[i].float()
+                    g = targets[i].float()
+                    tp = (p * g).sum()
+                    prec = tp / p.sum().clamp(min=1)
+                    rec = tp / g.sum().clamp(min=1)
+                    f = (2 * prec * rec) / (prec + rec).clamp(min=1e-8)
+                    per_sample_f1.append(f.item())
+                _, ci_low, ci_high = bootstrap_confidence_interval(per_sample_f1)
+                results["emotion"]["sample_avg_f1_ci"] = [ci_low, ci_high]
+
+            elif task == "topic":
+                preds = all_logits_t.argmax(dim=1).numpy()
+                targets = all_labels_t.long().numpy()
+
+                acc = float(accuracy_score(targets, preds))
+                macro_f1 = float(
+                    f1_score(targets, preds, average="macro", zero_division=0)
+                )
+
+                results["topic"] = {
+                    "accuracy": acc,
+                    "macro_f1": macro_f1,
+                }
+
+                # Per-class metrics
+                if topic_classes:
+                    report = classification_report(
+                        targets,
+                        preds,
+                        target_names=topic_classes,
+                        output_dict=True,
+                        zero_division=0,
+                    )
+                    results["topic"]["per_class"] = {
+                        name: {
+                            "precision": report[name]["precision"],
+                            "recall": report[name]["recall"],
+                            "f1": report[name]["f1-score"],
+                            "support": report[name]["support"],
+                        }
+                        for name in topic_classes
+                        if name in report
+                    }
+
+                # Bootstrap CI on accuracy
+                per_sample_correct = (preds == targets).astype(float).tolist()
+                _, ci_low, ci_high = bootstrap_confidence_interval(per_sample_correct)
+                results["topic"]["accuracy_ci"] = [ci_low, ci_high]
+
+    return results
+
+
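`bootstrap_confidence_interval` is imported from the shared metrics utilities; its call sites above expect a `(mean, ci_low, ci_high)` triple over a list of per-sample scores. A sketch of the standard percentile bootstrap it presumably implements (the resample count, `alpha`, and `seed` defaults are assumptions):

```python
import random
from typing import List, Tuple


def bootstrap_confidence_interval(
    values: List[float],
    n_resamples: int = 1000,
    alpha: float = 0.05,
    seed: int = 42,
) -> Tuple[float, float, float]:
    """Percentile bootstrap over per-sample scores: returns (mean, ci_low, ci_high)."""
    rng = random.Random(seed)
    n = len(values)
    # Resample with replacement, record each resample's mean, sort.
    resample_means = sorted(
        sum(rng.choices(values, k=n)) / n for _ in range(n_resamples)
    )
    lo = resample_means[int((alpha / 2) * n_resamples)]
    hi = resample_means[int((1 - alpha / 2) * n_resamples) - 1]
    return sum(values) / n, lo, hi
```

Bootstrapping per-sample scores (0/1 correctness for accuracy, per-sample F1 for emotion) avoids any distributional assumption about the metric itself.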
+# Main Pipeline
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+def load_data(config: BertBaselineConfig):
+    """Load all datasets and create label encoders."""
+    data_dir = config.data_dir
+
+    # Load emotion data
+    emo_train = load_emotion_jsonl(str(data_dir / "emotion" / "train.jsonl"))
+    emo_val_path = data_dir / "emotion" / "validation.jsonl"
+    if not emo_val_path.exists():
+        emo_val_path = data_dir / "emotion" / "val.jsonl"
+    emo_val = load_emotion_jsonl(str(emo_val_path))
+
+    # Load topic data
+    top_train = load_topic_jsonl(str(data_dir / "topic" / "train.jsonl"))
+    top_val_path = data_dir / "topic" / "validation.jsonl"
+    if not top_val_path.exists():
+        top_val_path = data_dir / "topic" / "val.jsonl"
+    top_val = load_topic_jsonl(str(top_val_path))
+
+    # Fit label encoders on training data (same as LexiMind)
+    binarizer = MultiLabelBinarizer()
+    binarizer.fit([ex.emotions for ex in emo_train])
+
+    label_encoder = LabelEncoder()
+    label_encoder.fit([ex.topic for ex in top_train])
+
+    print(f"  Emotion: {len(emo_train)} train, {len(emo_val)} val, {len(binarizer.classes_)} classes")
+    print(f"  Topic: {len(top_train)} train, {len(top_val)} val, {len(label_encoder.classes_)} classes")
+    print(f"  Emotion classes: {list(binarizer.classes_)[:5]}...")
+    print(f"  Topic classes: {list(label_encoder.classes_)}")
+
+    return {
+        "emotion_train": emo_train,
+        "emotion_val": emo_val,
+        "topic_train": top_train,
+        "topic_val": top_val,
+        "binarizer": binarizer,
+        "label_encoder": label_encoder,
+    }
+
+
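For reference, this is how the two scikit-learn encoders fitted above behave on toy labels (the label strings here are illustrative, not the project's actual classes):

```python
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

# Multi-label: each example can carry several emotions at once.
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform([["joy", "love"], ["anger"], ["joy"]])
print(list(mlb.classes_))  # classes are sorted: ['anger', 'joy', 'love']
print(Y.tolist())          # [[0, 1, 1], [1, 0, 0], [0, 1, 0]]

# Single-label: exactly one topic per example, encoded as an integer id.
le = LabelEncoder()
y = le.fit_transform(["sports", "tech", "sports"])
print(y.tolist())          # [0, 1, 0]
```

Fitting both encoders on the training split only (as above) guarantees the baseline and LexiMind map labels to identical indices.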
+def run_experiment(mode: str, config: BertBaselineConfig) -> Dict[str, Any]:
+    """Run a single experiment (single-topic, single-emotion, or multitask)."""
+    print(f"\n{'═' * 60}")
+    print(f"  BERT BASELINE EXPERIMENT: {mode.upper()}")
+    print(f"{'═' * 60}")
+
+    set_seed(config.seed)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"  Device: {device}")
+    if torch.cuda.is_available():
+        print(f"  GPU: {torch.cuda.get_device_name()}")
+        print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+
+        # CUDA optimizations
+        torch.backends.cudnn.benchmark = True
+        if hasattr(torch.backends, "cuda"):
+            torch.backends.cuda.matmul.allow_tf32 = True
+
+    # Load tokenizer
+    print(f"\n  Loading tokenizer: {config.model_name}")
+    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+
+    # Load data
+    print("  Loading datasets...")
+    data = load_data(config)
+
+    # Determine tasks for this mode
+    if mode == "single-topic":
+        tasks = ["topic"]
+    elif mode == "single-emotion":
+        tasks = ["emotion"]
+    else:
+        tasks = ["emotion", "topic"]
+
+    # Create datasets
+    train_loaders: Dict[str, DataLoader] = {}
+    val_loaders: Dict[str, DataLoader] = {}
+
+    if "emotion" in tasks:
+        emo_train_ds = BertEmotionDataset(
+            data["emotion_train"], tokenizer, data["binarizer"], config.max_length
+        )
+        emo_val_ds = BertEmotionDataset(
+            data["emotion_val"], tokenizer, data["binarizer"], config.max_length
+        )
+        train_loaders["emotion"] = DataLoader(
+            emo_train_ds,
+            batch_size=config.batch_size,
+            shuffle=True,
+            num_workers=4,
+            pin_memory=True,
+            persistent_workers=True,
+        )
+        val_loaders["emotion"] = DataLoader(
+            emo_val_ds,
+            batch_size=config.batch_size * 2,
+            shuffle=False,
+            num_workers=4,
+            pin_memory=True,
+        )
+
+    if "topic" in tasks:
+        top_train_ds = BertTopicDataset(
+            data["topic_train"], tokenizer, data["label_encoder"], config.max_length
+        )
+        top_val_ds = BertTopicDataset(
+            data["topic_val"], tokenizer, data["label_encoder"], config.max_length
+        )
+        train_loaders["topic"] = DataLoader(
+            top_train_ds,
+            batch_size=config.batch_size,
+            shuffle=True,
+            num_workers=4,
+            pin_memory=True,
+            persistent_workers=True,
+        )
+        val_loaders["topic"] = DataLoader(
+            top_val_ds,
+            batch_size=config.batch_size * 2,
+            shuffle=False,
+            num_workers=4,
+            pin_memory=True,
+        )
+
+    # Create model
+    print(f"\n  Creating model with tasks: {tasks}")
+    model = BertBaseline(
+        model_name=config.model_name,
+        num_emotions=len(data["binarizer"].classes_),
+        num_topics=len(data["label_encoder"].classes_),
+        tasks=tasks,
+        freeze_layers=config.freeze_layers,
+    ).to(device)
+
+    # Train
+    trainer = BertTrainer(model, config, train_loaders, val_loaders, device, mode)
+    training_results = trainer.train()
+
+    # Load best checkpoint for final evaluation
+    best_path = config.checkpoint_dir / mode / "best.pt"
+    if best_path.exists():
+        print("\n  Loading best checkpoint for final evaluation...")
+        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
+        model.load_state_dict(checkpoint["model_state_dict"])
+
+    # Full evaluation
+    print("\n  Running final evaluation...")
+    eval_results = evaluate_bert_model(
+        model,
+        val_loaders,
+        device,
+        config,
+        emotion_classes=list(data["binarizer"].classes_) if "emotion" in tasks else None,
+        topic_classes=list(data["label_encoder"].classes_) if "topic" in tasks else None,
+    )
+
+    # Combine results
+    final_results = {
+        "mode": mode,
+        "model": config.model_name,
+        "tasks": tasks,
+        "training": training_results,
+        "evaluation": eval_results,
+    }
+
+    # Save results
+    output_path = config.output_dir / f"{mode}_results.json"
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Remove non-serializable fields
+    def make_serializable(obj):
+        if isinstance(obj, dict):
+            return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
+        if isinstance(obj, list):
+            return [make_serializable(item) for item in obj]
+        if isinstance(obj, (np.integer, np.int64)):
+            return int(obj)
+        if isinstance(obj, (np.floating, np.float64)):
+            return float(obj)
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return obj
+
+    with open(output_path, "w") as f:
+        json.dump(make_serializable(final_results), f, indent=2)
+    print(f"\n  Results saved to {output_path}")
+
+    return final_results
+
+
+def print_comparison_summary(all_results: Dict[str, Dict[str, Any]]) -> None:
+    """Print a side-by-side comparison of all experiments."""
+    print(f"\n{'═' * 70}")
+    print("  BERT BASELINE COMPARISON SUMMARY")
+    print(f"{'═' * 70}")
+
+    # Header
+    modes = list(all_results.keys())
+    header = f"{'Metric':<30}" + "".join(f"{m:>16}" for m in modes) + f"{'LexiMind':>16}"
+    print(f"\n  {header}")
+    print(f"  {'─' * len(header)}")
+
+    # LexiMind reference values
+    lexmind = {
+        "topic_accuracy": 0.8571,
+        "topic_macro_f1": 0.8539,
+        "emotion_sample_f1": 0.3523,
+        "emotion_macro_f1": 0.1432,
+        "emotion_micro_f1": 0.4430,
+        "emotion_tuned_macro_f1": 0.2936,
+    }
+
+    # Topic metrics
+    print("\n  Topic Classification")
+    for metric_name, display_name in [
+        ("accuracy", "Accuracy"),
+        ("macro_f1", "Macro F1"),
+    ]:
+        row = f"  {display_name:<30}"
+        for mode in modes:
+            eval_data = all_results[mode].get("evaluation", {})
+            topic = eval_data.get("topic", {})
+            val = topic.get(metric_name, None)
+            row += f"{val:>16.4f}" if val is not None else f"{'—':>16}"
+        lm_key = f"topic_{metric_name}"
+        row += f"{lexmind.get(lm_key, 0):>16.4f}"
+        print(row)
+
+    # Emotion metrics
+    print("\n  Emotion Detection")
+    for metric_name, display_name in [
+        ("sample_avg_f1", "Sample-avg F1 (τ=0.3)"),
+        ("macro_f1", "Macro F1 (τ=0.3)"),
+        ("micro_f1", "Micro F1 (τ=0.3)"),
+        ("tuned_macro_f1", "Tuned Macro F1"),
+        ("tuned_sample_avg_f1", "Tuned Sample-avg F1"),
+    ]:
+        row = f"  {display_name:<30}"
+        for mode in modes:
+            eval_data = all_results[mode].get("evaluation", {})
+            emo = eval_data.get("emotion", {})
+            val = emo.get(metric_name, None)
+            row += f"{val:>16.4f}" if val is not None else f"{'—':>16}"
+        lm_key = f"emotion_{metric_name}"
+        row += f"{lexmind.get(lm_key, 0):>16.4f}"
+        print(row)
+
+    # Training time
+    print("\n  Training Time")
+    row = f"  {'Hours':<30}"
+    for mode in modes:
+        t = all_results[mode].get("training", {}).get("total_time_seconds", 0) / 3600
+        row += f"{t:>15.1f}h"
+    row += f"{'~9.0h':>16}"
+    print(row)
+
+    print(f"\n{'═' * 70}\n")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="BERT Baseline Training for LexiMind")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        required=True,
+        choices=["single-topic", "single-emotion", "multitask", "all"],
+        help="Training mode",
+    )
+    parser.add_argument("--epochs", type=int, default=None, help="Override max epochs")
+    parser.add_argument("--lr", type=float, default=None, help="Override learning rate")
+    parser.add_argument("--batch-size", type=int, default=None, help="Override batch size")
+    parser.add_argument(
+        "--model", type=str, default="bert-base-uncased", help="HuggingFace model name"
+    )
+    args = parser.parse_args()
+
+    config = BertBaselineConfig()
+    config.model_name = args.model
+    if args.epochs is not None:
+        config.max_epochs = args.epochs
+    if args.lr is not None:
+        config.lr = args.lr
+    if args.batch_size is not None:
+        config.batch_size = args.batch_size
+
+    if args.mode == "all":
+        modes = ["single-topic", "single-emotion", "multitask"]
+    else:
+        modes = [args.mode]
+
+    all_results: Dict[str, Dict[str, Any]] = {}
+    for mode in modes:
+        results = run_experiment(mode, config)
+        all_results[mode] = results
+
+        # Clear GPU memory between experiments
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    # Save combined results
+    if len(all_results) > 1:
+        combined_path = config.output_dir / "combined_results.json"
+
+        def make_serializable(obj):
+            if isinstance(obj, dict):
+                return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
+            if isinstance(obj, list):
+                return [make_serializable(item) for item in obj]
+            if isinstance(obj, (np.integer, np.int64)):
+                return int(obj)
+            if isinstance(obj, (np.floating, np.float64)):
+                return float(obj)
+            if isinstance(obj, np.ndarray):
+                return obj.tolist()
+            return obj
+
+        with open(combined_path, "w") as f:
+            json.dump(make_serializable(all_results), f, indent=2)
+        print(f"  Combined results saved to {combined_path}")
+
+    print_comparison_summary(all_results)
+
+
+if __name__ == "__main__":
+    main()