apapagi commited on
Commit
37751f5
·
verified ·
1 Parent(s): 51522dc

Upload 3 files

Browse files
Files changed (3) hide show
  1. eurovoc.py +691 -0
  2. inference_test.ipynb +386 -0
  3. train_lora_included.ipynb +687 -0
eurovoc.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ import torch.nn as nn
6
+ import torch
7
+ from transformers import BertTokenizerFast as BertTokenizer, get_linear_schedule_with_warmup, AutoTokenizer, AutoModel
8
+ import json
9
+ import random
10
+ from collections import Counter
11
+ from tqdm.auto import tqdm
12
+ import gzip
13
+ from pathlib import Path
14
+ import os
15
+ from datasets import load_dataset
16
+ from peft import LoraConfig, get_peft_model, TaskType, PeftModel
17
+
18
+
19
def save_split_config(train_files, val_files, config_path, metadata=None):
    """
    Persist a train/validation split to a JSON file.

    Args:
        train_files: List of training file paths
        val_files: List of validation file paths
        config_path: Path to save the configuration JSON
        metadata: Optional dict with additional info (train_ratio, seed, etc.)
    """
    payload = {
        'train_files': train_files,
        'val_files': val_files,
        'num_train_files': len(train_files),
        'num_val_files': len(val_files),
        'metadata': metadata or {}
    }

    # Ensure the destination directory exists ('' dirname means cwd).
    target_dir = os.path.dirname(config_path) or '.'
    os.makedirs(target_dir, exist_ok=True)

    with open(config_path, 'w') as fh:
        json.dump(payload, fh, indent=2)

    print(f"✓ Split configuration saved to {config_path}")
44
+
45
def load_split_config(config_path):
    """
    Load train/val split configuration from a JSON file.

    Args:
        config_path: Path to the configuration JSON

    Returns:
        Tuple of (train_files, val_files, metadata)
    """
    with open(config_path) as fh:
        cfg = json.load(fh)

    print(f"✓ Loaded split configuration from {config_path}")
    print(f" Train files: {cfg['num_train_files']}")
    print(f" Val files: {cfg['num_val_files']}")

    # Older configs may lack 'metadata'; default to an empty dict.
    return cfg['train_files'], cfg['val_files'], cfg.get('metadata', {})
63
+
64
def get_file_label_stats(jsonl_files):
    """
    Get label distribution from all files.
    Since we need accurate stats for rare labels, we count everything.

    Args:
        jsonl_files: List of paths to JSONL files (plain or gzip-compressed)

    Returns:
        Dict mapping file paths to their label statistics:
        {'label_counts': Counter, 'total_records': int}
    """
    file_labels = {}

    print(f"Analyzing {len(jsonl_files)} files...")
    for file_path in tqdm(jsonl_files):
        label_counts = Counter()
        total_records = 0

        # gzip.open and the builtin open share the same text-mode signature.
        opener = gzip.open if file_path.endswith('.gz') else open

        with opener(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                # Fixed: the original caught bare `Exception`, which silently
                # hid unrelated bugs. Only malformed JSON lines are skipped.
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                label_counts.update(record.get('eurovoc_ids', []))
                total_records += 1

        file_labels[file_path] = {
            'label_counts': label_counts,
            'total_records': total_records
        }

    return file_labels
103
+
104
+
105
def smart_split_files(all_jsonl_files, train_ratio=0.92,
                      rare_threshold=0.005, seed=42, verbose=True,
                      save_config_path=None):
    """
    Split files ensuring rare labels appear in training set.

    Files are ranked by how many distinct rare labels they contain, and the
    top-ranked files are assigned to training, so rare labels are seen
    during training.

    Args:
        all_jsonl_files: List of all JSONL file paths
        train_ratio: Fraction of files for training (default 0.92)
        rare_threshold: Labels appearing in < this fraction are considered rare
        seed: Random seed for reproducibility
        verbose: Print statistics
        save_config_path: If provided, save the split configuration to this path

    Returns:
        Tuple of (train_files, val_files)
    """
    # NOTE(review): the seed is set for reproducibility, although the split
    # below is deterministic (sort-based) — confirm no randomness is intended.
    random.seed(seed)

    if verbose:
        print("Analyzing label distribution across files...")

    file_stats = get_file_label_stats(all_jsonl_files)

    # Aggregate per-file counts into a global label distribution.
    global_label_counts = Counter()
    for stats in file_stats.values():
        global_label_counts.update(stats['label_counts'])

    # A label is "rare" when its share of all label occurrences is below
    # rare_threshold.
    total_labels = sum(global_label_counts.values())
    rare_count_threshold = total_labels * rare_threshold
    rare_labels = {label for label, count in global_label_counts.items()
                   if count < rare_count_threshold}

    if verbose:
        print(f"Found {len(rare_labels)} rare labels out of {len(global_label_counts)} total")

    # Score each file by the number of distinct rare labels it contains.
    file_rare_counts = {}
    for file_path, stats in file_stats.items():
        rare_in_file = set(stats['label_counts'].keys()) & rare_labels
        file_rare_counts[file_path] = len(rare_in_file)

    # Rare-label-rich files first, so they land in the training split.
    sorted_files = sorted(file_rare_counts.items(), key=lambda x: x[1], reverse=True)

    split_idx = int(len(all_jsonl_files) * train_ratio)
    train_files = [f for f, _ in sorted_files[:split_idx]]
    val_files = [f for f, _ in sorted_files[split_idx:]]

    train_rare_count = sum(1 for f in train_files if file_rare_counts[f] > 0)
    val_rare_count = sum(1 for f in val_files if file_rare_counts[f] > 0)

    if verbose:
        print(f"Train files: {len(train_files)} ({train_rare_count} with rare labels)")
        print(f"Val files: {len(val_files)} ({val_rare_count} with rare labels)")

    # Coverage check: find labels exclusive to one split.
    train_labels = set()
    val_labels = set()
    for f in train_files:
        train_labels.update(file_stats[f]['label_counts'].keys())
    for f in val_files:
        val_labels.update(file_stats[f]['label_counts'].keys())

    labels_only_in_train = train_labels - val_labels
    labels_only_in_val = val_labels - train_labels

    # Fixed inconsistency: these informational prints now respect `verbose`
    # like every other status message in this function.
    if verbose:
        print(f"Labels only in train: {len(labels_only_in_train)}")
        print(f"Labels only in val: {len(labels_only_in_val)}")
    # The warning stays unconditional: val-only labels can never be learned.
    if len(labels_only_in_val) > 0:
        print(f"⚠️ WARNING: {len(labels_only_in_val)} labels appear only in validation!")

    if save_config_path:
        metadata = {
            'train_ratio': train_ratio,
            'rare_threshold': rare_threshold,
            'seed': seed,
            'total_files': len(all_jsonl_files),
            'num_rare_labels': len(rare_labels),
            'num_total_labels': len(global_label_counts),
            'train_rare_count': train_rare_count,
            'val_rare_count': val_rare_count
        }
        save_split_config(train_files, val_files, save_config_path, metadata)

    return train_files, val_files
199
+
200
+
201
+
202
class EurovocDataset(Dataset):
    """Map-style dataset yielding one tokenized text with its multi-hot labels."""

    def __init__(
        self,
        text: np.array,
        labels: np.array,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

    def __len__(self):
        # One sample per label row.
        return len(self.labels)

    def __getitem__(self, index: int):
        # NOTE(review): each text entry is wrapped in a length-1 sequence —
        # [0] unwraps it; confirm against the caller's data layout.
        sample_text = self.text[index][0]
        sample_labels = self.labels[index]

        enc = self.tokenizer.encode_plus(
            sample_text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            "text": sample_text,
            "input_ids": enc["input_ids"].flatten(),
            "attention_mask": enc["attention_mask"].flatten(),
            "labels": torch.FloatTensor(sample_labels),
        }
240
+
241
+
242
class StreamingEurovocDataset(IterableDataset):
    """
    Streaming dataset that doesn't load everything into memory.
    Processes one record at a time from disk.
    """

    def __init__(self, jsonl_files, mlb, tokenizer, max_token_len=512, split='train'):
        self.jsonl_files = jsonl_files
        self.mlb = mlb
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.split = split

    def __iter__(self):
        # The HF 'json' loader in streaming mode reads records lazily.
        stream = load_dataset(
            'json',
            data_files=self.jsonl_files,
            streaming=True,
            split='train'
        )

        for record in stream:
            doc_text = record.get('text')
            concept_ids = record.get('eurovoc_ids', [])

            # Drop records missing either the text or its labels.
            if not doc_text or not concept_ids:
                continue

            # Multi-hot encode the EuroVoc concept ids.
            target = self.mlb.transform([concept_ids])[0]

            enc = self.tokenizer.encode_plus(
                doc_text,
                add_special_tokens=True,
                max_length=self.max_token_len,
                return_token_type_ids=False,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )

            yield {
                'input_ids': enc["input_ids"].flatten(),
                'attention_mask': enc["attention_mask"].flatten(),
                'labels': torch.FloatTensor(target)
            }
290
+
291
+
292
+
293
class EuroVocLongTextDataset(Dataset):
    """
    Dataset that splits long documents into fixed-size word chunks; every
    chunk inherits the labels of its source document.

    Bug fixes vs. the original:
      * ``__splitter__`` was declared without ``self`` but invoked as a bound
        method with one argument, so ``text`` received the instance and
        ``max_lenght`` received the document — it crashed in ``__init__``.
      * Chunks were yielded as word lists; they are now joined back into
        strings so the tokenizer receives plain text.
      * Per-item tensors are read as ``encoding["input_ids"][index]``:
        a BatchEncoding is keyed by field first, then indexed by row.
    """

    def __splitter__(self, text, max_length=None):
        """Yield chunks of at most ``max_length`` whitespace-separated words."""
        if max_length is None:
            # Default chunk size mirrors the tokenizer truncation length.
            max_length = self.max_token_len
        words = text.split()
        for i in range(0, len(words), max_length):
            yield ' '.join(words[i:i + max_length])

    def __init__(
        self,
        text: np.array,
        labels: np.array,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

        # Expand each (document, labels) pair into (chunk, labels) pairs.
        self.chunks_and_labels = [(c, l) for t, l in zip(self.text, self.labels)
                                  for c in self.__splitter__(t)]

        # Tokenize all chunks once, up front, in a single batch.
        self.encoding = self.tokenizer.batch_encode_plus(
            [c for c, _ in self.chunks_and_labels],
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

    def __len__(self):
        return len(self.chunks_and_labels)

    def __getitem__(self, index: int):
        text, labels = self.chunks_and_labels[index]

        return dict(
            text=text,
            input_ids=self.encoding["input_ids"][index].flatten(),
            attention_mask=self.encoding["attention_mask"][index].flatten(),
            labels=torch.FloatTensor(labels)
        )
337
+
338
+
339
class EurovocDataModule(pl.LightningDataModule):
    """Lightning data module wrapping in-memory EurovocDataset splits."""

    def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
        super().__init__()

        self.batch_size = batch_size
        self.x_tr = x_tr
        self.y_tr = y_tr
        self.x_test = x_test
        self.y_test = y_test
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.max_token_len = max_token_len

    def setup(self, stage=None):
        # The held-out split serves as both validation and test set.
        self.train_dataset = EurovocDataset(
            self.x_tr, self.y_tr, self.tokenizer, self.max_token_len)
        self.test_dataset = EurovocDataset(
            self.x_test, self.y_test, self.tokenizer, self.max_token_len)

    def train_dataloader(self):
        # Only the training loader shuffles.
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=2)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=2)
388
+
389
class StreamingEurovocDataModule(pl.LightningDataModule):
    """
    DataModule that uses streaming datasets.
    Supports both random and smart (stratified) file splitting.
    Can load pre-computed splits from config file.
    """

    def __init__(self, bert_model_name, all_jsonl_files, mlb,
                 batch_size=64, max_token_len=512,
                 train_ratio=0.92, rare_threshold=0.005,
                 split_strategy='smart',
                 split_config_path="../eurovoc_data/train_val_split_config.json",
                 save_split_config_path="../eurovoc_data/train_val_split_config.json"):
        """
        Args:
            bert_model_name: Name of the BERT model to use
            all_jsonl_files: List of all JSONL file paths (ignored if split_config_path provided)
            mlb: Fitted MultiLabelBinarizer
            batch_size: Batch size for dataloaders
            max_token_len: Maximum token length for tokenization
            train_ratio: Fraction of files for training
            rare_threshold: Threshold for rare label identification
            split_strategy: 'random' or 'smart'
            split_config_path: Path to existing split config JSON (if provided, loads from this)
            save_split_config_path: Path to save new split config JSON
        """
        super().__init__()
        self.batch_size = batch_size
        self.mlb = mlb
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.max_token_len = max_token_len

        # Option 1: Load from existing config (keeps runs comparable).
        if split_config_path and os.path.exists(split_config_path):
            print(f"Loading split from existing config: {split_config_path}")
            self.train_files, self.val_files, metadata = load_split_config(split_config_path)
            if metadata:
                print(f"Split metadata: {metadata}")

        # Option 2: Create new split
        else:
            if split_strategy == 'smart':
                print("Using smart split strategy (ensuring rare label coverage)...")
                self.train_files, self.val_files = smart_split_files(
                    all_jsonl_files,
                    train_ratio=train_ratio,
                    rare_threshold=rare_threshold,
                    save_config_path=save_split_config_path
                )
            elif split_strategy == 'random':
                print("Using random split strategy...")
                # Fixed: shuffle a copy — the original shuffled the caller's
                # list in place (mutation of a shared argument).
                shuffled_files = list(all_jsonl_files)
                random.shuffle(shuffled_files)

                split_idx = int(len(shuffled_files) * train_ratio)
                self.train_files = shuffled_files[:split_idx]
                self.val_files = shuffled_files[split_idx:]

                print(f"Train files: {len(self.train_files)}")
                print(f"Val files: {len(self.val_files)}")

                # Save config if requested
                if save_split_config_path:
                    metadata = {
                        'train_ratio': train_ratio,
                        'split_strategy': 'random',
                        'total_files': len(all_jsonl_files)
                    }
                    save_split_config(self.train_files, self.val_files,
                                      save_split_config_path, metadata)
            else:
                raise ValueError(f"Unknown split_strategy: {split_strategy}. Use 'random' or 'smart'")

    def setup(self, stage=None):
        self.train_dataset = StreamingEurovocDataset(
            self.train_files,
            self.mlb,
            self.tokenizer,
            self.max_token_len
        )

        self.val_dataset = StreamingEurovocDataset(
            self.val_files,
            self.mlb,
            self.tokenizer,
            self.max_token_len
        )

    def train_dataloader(self):
        # No shuffle: IterableDataset controls its own iteration order.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True
        )
490
+
491
+
492
class EurovocTagger(pl.LightningModule):
    """BERT encoder + single linear head; sigmoid probabilities with BCELoss."""

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.criterion = nn.BCELoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        """Return (loss, probabilities); loss is 0 when labels are absent."""
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier1(self.dropout(encoded.pooler_output))
        probs = torch.sigmoid(logits)
        loss = 0 if labels is None else self.criterion(probs, labels)
        return loss, probs

    def _run_step(self, batch, metric_name):
        # Shared forward + logging for train/val/test steps.
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss, outputs

    def training_step(self, batch, batch_idx):
        loss, outputs = self._run_step(batch, "train_loss")
        return {"loss": loss, "predictions": outputs, "labels": batch["labels"]}

    def validation_step(self, batch, batch_idx):
        return self._run_step(batch, "val_loss")[0]

    def test_step(self, batch, batch_idx):
        return self._run_step(batch, "test_loss")[0]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
540
+
541
+
542
class EurovocTaggerBCELogit(pl.LightningModule):
    """BERT encoder + linear head; raw logits trained with BCEWithLogitsLoss."""

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.criterion = nn.BCEWithLogitsLoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        """Return (loss, logits); loss is 0 when labels are absent.

        Unlike EurovocTagger no sigmoid is applied — the loss expects logits.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier1(self.dropout(encoded.pooler_output))
        loss = 0 if labels is None else self.criterion(logits, labels)
        return loss, logits

    def _run_step(self, batch, metric_name):
        # Shared forward + logging for train/val/test steps.
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss, outputs

    def training_step(self, batch, batch_idx):
        loss, outputs = self._run_step(batch, "train_loss")
        return {"loss": loss, "predictions": outputs, "labels": batch["labels"]}

    def validation_step(self, batch, batch_idx):
        return self._run_step(batch, "val_loss")[0]

    def test_step(self, batch, batch_idx):
        return self._run_step(batch, "test_loss")[0]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
588
+
589
+
590
class EurovocTaggerLoRA(pl.LightningModule):
    """
    EuroVoc tagger with a LoRA-adapted BERT encoder and a two-layer
    bottleneck classification head, trained with BCEWithLogitsLoss.
    """

    def __init__(self, bert_model_name, n_classes, n_intermediate=256, lr=2e-5, eps=1e-8, lora_r=8, lora_alpha=16, lora_dropout=0.1):
        super().__init__()

        # Base encoder.
        self.bert = AutoModel.from_pretrained(bert_model_name)

        # LoRA touches only the attention query/value projections, leaving
        # the bulk of the BERT weights frozen.
        lora_config = LoraConfig(
            r=lora_r,                      # rank of the low-rank update matrices
            lora_alpha=lora_alpha,         # scaling factor
            target_modules=["query", "value"],
            lora_dropout=lora_dropout,
            bias="none",
            task_type=TaskType.FEATURE_EXTRACTION  # encoder used for embeddings
        )
        self.bert = get_peft_model(self.bert, lora_config)
        self.bert.print_trainable_parameters()

        # Bottleneck head: hidden_size -> n_intermediate -> n_classes needs
        # far fewer parameters than a direct hidden_size -> n_classes layer
        # (e.g. 768→256→6800 vs 768→6800).
        hidden_size = self.bert.config.hidden_size
        self.dropout1 = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(hidden_size, n_intermediate)
        self.relu = nn.ReLU()
        self.dropout2 = nn.Dropout(p=0.2)
        self.classifier2 = nn.Linear(n_intermediate, n_classes)

        self.criterion = nn.BCEWithLogitsLoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        """Return (loss, logits); loss is 0 when labels are absent."""
        encoded = self.bert(input_ids, attention_mask=attention_mask)

        # Pooled CLS representation: (batch, hidden_size).
        hidden = self.dropout1(encoded.pooler_output)

        # Bottleneck head: compress, activate, then expand to all labels.
        hidden = self.relu(self.classifier1(hidden))
        logits = self.classifier2(self.dropout2(hidden))

        loss = 0 if labels is None else self.criterion(logits, labels)
        return loss, logits

    def _run_step(self, batch, metric_name):
        # Shared forward + logging for train/val/test steps.
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss, outputs

    def training_step(self, batch, batch_idx):
        loss, outputs = self._run_step(batch, "train_loss")
        return {"loss": loss, "predictions": outputs, "labels": batch["labels"]}

    def validation_step(self, batch, batch_idx):
        return self._run_step(batch, "val_loss")[0]

    def test_step(self, batch, batch_idx):
        return self._run_step(batch, "test_loss")[0]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)

    def save_lora_adapter(self, path):
        """Save only the LoRA adapter weights"""
        self.bert.save_pretrained(path)

    def load_lora_adapter(self, path):
        """Load LoRA adapter weights"""
        self.bert = PeftModel.from_pretrained(self.bert, path)
inference_test.ipynb ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "11ab9cd5-a6e4-416a-b44f-201e8bf8ee84",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Test inference"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "40523be3-6ec7-4cac-aa90-6b5177c0f07d",
15
+ "metadata": {
16
+ "tags": []
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "from pdfminer.high_level import extract_text"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 26,
26
+ "id": "c0e5cc3f-5a9d-4b0f-8f7c-d46c0f79b5df",
27
+ "metadata": {
28
+ "tags": []
29
+ },
30
+ "outputs": [
31
+ {
32
+ "name": "stderr",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "Cannot set gray non-stroke color because /'P3954' is an invalid float value\n"
36
+ ]
37
+ }
38
+ ],
39
+ "source": [
40
+ "text = extract_text(\"./example_docs_for_inference/publication_climate.pdf\")"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 27,
46
+ "id": "120528e3-26b9-40ce-ac8c-3c30c3092d28",
47
+ "metadata": {
48
+ "tags": []
49
+ },
50
+ "outputs": [
51
+ {
52
+ "name": "stdout",
53
+ "output_type": "stream",
54
+ "text": [
55
+ "ISSN 1831-9424 \n",
56
+ "\n",
57
+ "How to plan mitigation, adaptatio\n"
58
+ ]
59
+ }
60
+ ],
61
+ "source": [
62
+ "print(text[0:50])"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 9,
68
+ "id": "d191928f-381e-4da3-8342-1300909b52c5",
69
+ "metadata": {
70
+ "tags": []
71
+ },
72
+ "outputs": [
73
+ {
74
+ "name": "stderr",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "/home/mbarhdadi/projects/training/eurovoc_training_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
78
+ " from .autonotebook import tqdm as notebook_tqdm\n"
79
+ ]
80
+ },
81
+ {
82
+ "name": "stdout",
83
+ "output_type": "stream",
84
+ "text": [
85
+ "Model loaded. Ready to predict 6958 eurovoc labels.\n"
86
+ ]
87
+ }
88
+ ],
89
+ "source": [
90
+ "import pickle\n",
91
+ "from transformers import AutoTokenizer, AutoModel\n",
92
+ "from eurovoc import EurovocTagger\n",
93
+ "\n",
94
+ "# Load MLBinarizer\n",
95
+ "with open('./models_finetuned/latest/mlb.pickle', 'rb') as f:\n",
96
+ " mlb = pickle.load(f)\n",
97
+ "\n",
98
+ "# Load tokenizer\n",
99
+ "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
100
+ "tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)\n",
101
+ "\n",
102
+ "# Load trained model\n",
103
+ "checkpoint_path = \"./models_finetuned/latest/EurovocTaggerFP32-epoch=04-val_loss=0.00.ckpt\" \n",
104
+ "model = EurovocTagger.load_from_checkpoint(\n",
105
+ " checkpoint_path,\n",
106
+ " bert_model_name=BERT_MODEL_NAME,\n",
107
+ " n_classes=len(mlb.classes_)\n",
108
+ ")\n",
109
+ "\n",
110
+ "\n",
111
+ "print(f\"Model loaded. Ready to predict {len(mlb.classes_)} eurovoc labels.\")"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 15,
117
+ "id": "7a1fd7e6-e14d-4c24-97ae-abcd5a30ab71",
118
+ "metadata": {
119
+ "tags": []
120
+ },
121
+ "outputs": [],
122
+ "source": [
123
+ "def get_eurovoc_id_to_term_mapping():\n",
124
+ " \"\"\"\n",
125
+ " Create a mapping from eurovoc IDs to their human-readable terms.\n",
126
+ " \n",
127
+ " Returns:\n",
128
+ " Dict mapping eurovoc_id -> term_name\n",
129
+ " \"\"\"\n",
130
+ " import requests\n",
131
+ " import xmltodict\n",
132
+ " \n",
133
+ " eurovoc_id_to_term = {}\n",
134
+ " \n",
135
+ " response = requests.get(\n",
136
+ " 'http://publications.europa.eu/resource/dataset/eurovoc',\n",
137
+ " headers={\n",
138
+ " 'Accept': 'application/xml',\n",
139
+ " 'Accept-Language': 'en',\n",
140
+ " 'User-Agent': 'Mozilla/5.0'\n",
141
+ " }\n",
142
+ " )\n",
143
+ " \n",
144
+ " data = xmltodict.parse(response.content)\n",
145
+ " \n",
146
+ " for term in data['xs:schema']['xs:simpleType']['xs:restriction']['xs:enumeration']:\n",
147
+ " try:\n",
148
+ " name = term['xs:annotation']['xs:documentation'].split('/')[0].strip()\n",
149
+ " eurovoc_id = term['@value'].split(':')[1]\n",
150
+ " \n",
151
+ " # Map ID -> term \n",
152
+ " eurovoc_id_to_term[eurovoc_id] = {\n",
153
+ " 'original': name,\n",
154
+ " 'lowercase': name.lower()\n",
155
+ " }\n",
156
+ " except (KeyError, IndexError) as e:\n",
157
+ " print(f\"⚠️ Could not parse term: {term}\")\n",
158
+ " \n",
159
+ " print(f\"✓ Loaded {len(eurovoc_id_to_term)} eurovoc terms\")\n",
160
+ " return eurovoc_id_to_term"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 23,
166
+ "id": "d2b703ea-ca41-4353-8776-1a226f02c56b",
167
+ "metadata": {
168
+ "tags": []
169
+ },
170
+ "outputs": [
171
+ {
172
+ "name": "stdout",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "Loading Eurovoc terms...\n",
176
+ "✓ Loaded 7488 eurovoc terms\n"
177
+ ]
178
+ }
179
+ ],
180
+ "source": [
181
+ "print(\"Loading Eurovoc terms...\")\n",
182
+ "eurovoc_id_to_term = get_eurovoc_id_to_term_mapping()\n"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": 24,
188
+ "id": "7a5fed81-64e8-4454-a56b-73eb50676b75",
189
+ "metadata": {
190
+ "tags": []
191
+ },
192
+ "outputs": [],
193
+ "source": [
194
+ "import torch\n",
195
+ "import numpy as np\n",
196
+ "from transformers import AutoTokenizer\n",
197
+ "\n",
198
+ "def predict_eurovoc_labels(text, model, mlb, tokenizer, \n",
199
+ " eurovoc_id_to_term=None,\n",
200
+ " max_token_len=512, \n",
201
+ " threshold=0.5, \n",
202
+ " top_k=10,\n",
203
+ " device='cuda'):\n",
204
+ " model.eval()\n",
205
+ " model.to(device)\n",
206
+ " \n",
207
+ " # Tokenize\n",
208
+ " encoding = tokenizer.encode_plus(\n",
209
+ " text,\n",
210
+ " add_special_tokens=True,\n",
211
+ " max_length=max_token_len,\n",
212
+ " return_token_type_ids=False,\n",
213
+ " padding=\"max_length\",\n",
214
+ " truncation=True,\n",
215
+ " return_attention_mask=True,\n",
216
+ " return_tensors='pt',\n",
217
+ " )\n",
218
+ " \n",
219
+ " input_ids = encoding[\"input_ids\"].to(device)\n",
220
+ " attention_mask = encoding[\"attention_mask\"].to(device)\n",
221
+ " \n",
222
+ " # Predict\n",
223
+ " with torch.no_grad():\n",
224
+ " _, outputs = model(input_ids, attention_mask)\n",
225
+ " \n",
226
+ "\n",
227
+ " probabilities = outputs\n",
228
+ " \n",
229
+ " probabilities = probabilities.cpu().numpy()[0]\n",
230
+ " \n",
231
+ " # Helper function to enrich labels with terms\n",
232
+ " def enrich_labels(label_ids, probs):\n",
233
+ " \"\"\"Add human-readable terms to eurovoc IDs\"\"\"\n",
234
+ " enriched = []\n",
235
+ " for label_id, prob in zip(label_ids, probs):\n",
236
+ " entry = {\n",
237
+ " 'eurovoc_id': label_id,\n",
238
+ " 'probability': float(prob)\n",
239
+ " }\n",
240
+ " \n",
241
+ " # Add term if mapping available\n",
242
+ " if eurovoc_id_to_term and label_id in eurovoc_id_to_term:\n",
243
+ " entry['term'] = eurovoc_id_to_term[label_id]['original']\n",
244
+ " entry['term_lower'] = eurovoc_id_to_term[label_id]['lowercase']\n",
245
+ " else:\n",
246
+ " entry['term'] = None\n",
247
+ " entry['term_lower'] = None\n",
248
+ " \n",
249
+ " enriched.append(entry)\n",
250
+ " \n",
251
+ " return enriched\n",
252
+ " \n",
253
+ " # Get predictions above threshold\n",
254
+ " predicted_indices = np.where(probabilities >= threshold)[0]\n",
255
+ " predicted_labels = mlb.classes_[predicted_indices]\n",
256
+ " predicted_probs = probabilities[predicted_indices]\n",
257
+ " \n",
258
+ " # Get top-k predictions\n",
259
+ " top_k_indices = np.argsort(probabilities)[-top_k:][::-1]\n",
260
+ " top_k_labels = mlb.classes_[top_k_indices]\n",
261
+ " top_k_probs = probabilities[top_k_indices]\n",
262
+ " \n",
263
+ " return {\n",
264
+ " 'above_threshold': {\n",
265
+ " 'predictions': enrich_labels(predicted_labels, predicted_probs),\n",
266
+ " 'count': len(predicted_labels)\n",
267
+ " },\n",
268
+ " 'top_k': {\n",
269
+ " 'predictions': enrich_labels(top_k_labels, top_k_probs)\n",
270
+ " }\n",
271
+ " }"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 28,
277
+ "id": "030b99aa-edc7-472a-8c7f-636a47a9cdce",
278
+ "metadata": {
279
+ "tags": []
280
+ },
281
+ "outputs": [
282
+ {
283
+ "name": "stdout",
284
+ "output_type": "stream",
285
+ "text": [
286
+ "Document length: 696483 characters\n",
287
+ "Truncated to: 2048 tokens (~2048 chars)\n",
288
+ "\n",
289
+ "Running inference...\n",
290
+ "\n",
291
+ "================================================================================\n",
292
+ "TOP 15 PREDICTED EUROVOC LABELS (with terms)\n",
293
+ "================================================================================\n",
294
+ "642 | energy saving | 0.8567\n",
295
+ "6700 | energy efficiency | 0.7060\n",
296
+ "2281 | poverty | 0.4645\n",
297
+ "5311 | user guide | 0.4198\n",
298
+ "2498 | energy policy | 0.3545\n",
299
+ "5482 | climate change | 0.1736\n",
300
+ "754 | renewable energy | 0.1338\n",
301
+ "6400 | reduction of gas emissions | 0.1321\n",
302
+ "2517 | social policy | 0.1260\n",
303
+ "475 | energy distribution | 0.1253\n",
304
+ "5188 | information technology | 0.1087\n",
305
+ "2715 | energy production | 0.1087\n",
306
+ "2451 | EU policy | 0.0812\n",
307
+ "4139 | serial publication | 0.0808\n",
308
+ "83 | living conditions | 0.0793\n",
309
+ "\n",
310
+ "5 labels above threshold (0.3)\n",
311
+ "\n",
312
+ "================================================================================\n",
313
+ "PREDICTIONS ABOVE THRESHOLD (with readable terms)\n",
314
+ "================================================================================\n",
315
+ "2281 | poverty | 0.4645\n",
316
+ "2498 | energy policy | 0.3545\n",
317
+ "5311 | user guide | 0.4198\n",
318
+ "642 | energy saving | 0.8567\n",
319
+ "6700 | energy efficiency | 0.7060\n"
320
+ ]
321
+ }
322
+ ],
323
+ "source": [
324
+ "print(f\"Document length: {len(text)} characters\")\n",
325
+ "print(f\"Truncated to: {512 * 4} tokens (~2048 chars)\\n\") \n",
326
+ "\n",
327
+ "print(\"Running inference...\\n\")\n",
328
+ "results = predict_eurovoc_labels(\n",
329
+ " text=text,\n",
330
+ " model=model,\n",
331
+ " mlb=mlb,\n",
332
+ " tokenizer=tokenizer,\n",
333
+ " eurovoc_id_to_term=eurovoc_id_to_term, # ← Pass the mapping\n",
334
+ " threshold=0.3,\n",
335
+ " top_k=15\n",
336
+ ")\n",
337
+ "print(\"=\" * 80)\n",
338
+ "print(\"TOP 15 PREDICTED EUROVOC LABELS\")\n",
339
+ "print(\"=\" * 80)\n",
340
+ "\n",
341
+ "for pred in results['top_k']['predictions']:\n",
342
+ " term = pred['term'] if pred['term'] else \"(term not found)\"\n",
343
+ " print(f\"{pred['eurovoc_id']:15s} | {term:45s} | {pred['probability']:.4f}\")\n",
344
+ "\n",
345
+ "print(f\"\\n{results['above_threshold']['count']} labels above threshold (0.3)\")\n",
346
+ "\n",
347
+ "print(\"\\n\" + \"=\" * 80)\n",
348
+ "print(\"PREDICTIONS ABOVE THRESHOLD\")\n",
349
+ "print(\"=\" * 80)\n",
350
+ "\n",
351
+ "for pred in results['above_threshold']['predictions']:\n",
352
+ " if pred['term']: # Only show if term was found\n",
353
+ " print(f\"{pred['eurovoc_id']:15s} | {pred['term']:45s} | {pred['probability']:.4f}\")"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "id": "27ebc73c-5832-4702-bc1e-dd026ebeed02",
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": []
363
+ }
364
+ ],
365
+ "metadata": {
366
+ "kernelspec": {
367
+ "display_name": "eurovoc_training_env",
368
+ "language": "python",
369
+ "name": "eurovoc_training_env"
370
+ },
371
+ "language_info": {
372
+ "codemirror_mode": {
373
+ "name": "ipython",
374
+ "version": 3
375
+ },
376
+ "file_extension": ".py",
377
+ "mimetype": "text/x-python",
378
+ "name": "python",
379
+ "nbconvert_exporter": "python",
380
+ "pygments_lexer": "ipython3",
381
+ "version": "3.10.12"
382
+ }
383
+ },
384
+ "nbformat": 4,
385
+ "nbformat_minor": 5
386
+ }
train_lora_included.ipynb ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "3dc740a0-1865-40da-a163-b858f29d1313",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 🇪🇺 🏷️ Eurovoc Model Training Notebook"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "64a1dc4a-5bf5-46d9-9356-3958802837ac",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pickle \n",
19
+ "import pandas as pd\n",
20
+ "from transformers import AutoTokenizer, AutoModel\n",
21
+ "\n",
22
+ "from datasets import load_dataset\n",
23
+ "\n",
24
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
25
+ "import torch\n",
26
+ "\n",
27
+ "import pytorch_lightning as pl\n",
28
+ "from pytorch_lightning.callbacks import ModelCheckpoint"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "id": "caa5dc4b-2fe3-43da-846d-a866c2224280",
35
+ "metadata": {
36
+ "tags": []
37
+ },
38
+ "outputs": [],
39
+ "source": [
40
+ "fixed_dir = fix_all_files(all_jsonl_files)\n",
41
+ "logger.info(f\"Done! Use files from: {fixed_dir}\")"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "id": "6d63a920-52aa-4c73-bd2d-575e888d3d55",
47
+ "metadata": {
48
+ "tags": []
49
+ },
50
+ "source": [
51
+ "### Create the MultiLabel Binarizer and save it in a file for prediction "
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "921fd5cd-67e7-4962-8e5e-15e055dd63b6",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "from tqdm import tqdm\n",
62
+ "\n",
63
+ "\n",
64
+ "import os\n",
65
+ "from datetime import datetime\n",
66
+ "\n",
67
+ "FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
68
+ "\n",
69
+ "def list_all_json_files(directory=FIXED_DIR):\n",
70
+ " # List all items in the directory\n",
71
+ " all_items = os.listdir(directory)\n",
72
+ "\n",
73
+ " def extract_date_key(filename):\n",
74
+ " \"\"\"\n",
75
+ " Extracts a datetime object from filenames containing YYYY-MM.\n",
76
+ " Handles .jsonl and .jsonl.gz.\n",
77
+ " \"\"\"\n",
78
+ " base = filename.split('.')[0] \n",
79
+ " yyyy, mm = base.split('-') \n",
80
+ " return datetime(int(yyyy), int(mm), 1)\n",
81
+ "\n",
82
+ "\n",
83
+ " jsonl_files = [\n",
84
+ " f for f in all_items\n",
85
+ " if f.endswith(\".jsonl\") or f.endswith(\".jsonl.gz\")\n",
86
+ " ]\n",
87
+ "\n",
88
+ " # Sort newest to oldest\n",
89
+ " jsonl_files_sorted = sorted(\n",
90
+ " jsonl_files,\n",
91
+ " key=extract_date_key,\n",
92
+ " reverse=True\n",
93
+ " )\n",
94
+ " return [os.path.join(directory, f) for f in jsonl_files_sorted]\n",
95
+ "\n",
96
+ "all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
97
+ "\n",
98
+ " \n",
99
+ "print(f\"Found {len(all_jsonl_files)} files to load (including compressed).\")\n",
100
+ "\n",
101
+ "\n",
102
+ "def build_mlb_from_streaming(all_jsonl_files, output_path='../eurovoc_data/mlb.pickle'):\n",
103
+ " \"\"\"\n",
104
+ " Build MLBinarizer by scanning all files once to collect unique concepts.\n",
105
+ " This is more memory efficient than loading everything.\n",
106
+ " \"\"\"\n",
107
+ " print(\"Scanning files to collect all unique eurovoc concepts...\")\n",
108
+ " all_concepts = set()\n",
109
+ " \n",
110
+ " dataset = load_dataset(\n",
111
+ " 'json',\n",
112
+ " data_files=all_jsonl_files,\n",
113
+ " streaming=True,\n",
114
+ " split='train'\n",
115
+ " )\n",
116
+ " \n",
117
+ " for record in tqdm(dataset, desc=\"Collecting eurovoc IDS\"):\n",
118
+ " concepts = record.get('eurovoc_ids', [])\n",
119
+ " if concepts:\n",
120
+ " all_concepts.update(concepts)\n",
121
+ " \n",
122
+ " print(f\"Found {len(all_concepts)} unique eurovoc IDS\")\n",
123
+ " \n",
124
+ " # Create and fit MLBinarizer\n",
125
+ " mlb = MultiLabelBinarizer()\n",
126
+ " mlb.fit([sorted(list(all_concepts))])\n",
127
+ " \n",
128
+ " # Save it\n",
129
+ " with open(output_path, 'wb') as f:\n",
130
+ " pickle.dump(mlb, f)\n",
131
+ " \n",
132
+ " print(f\"Saved MLBinarizer to {output_path}\")\n",
133
+ " return mlb\n",
134
+ "\n"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "66e1d48e-83a7-4a38-a081-b72ba679e960",
141
+ "metadata": {
142
+ "tags": []
143
+ },
144
+ "outputs": [],
145
+ "source": [
146
+ "build_mlb_from_streaming(all_jsonl_files)"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "markdown",
151
+ "id": "b2fd1bda-ee0e-40f2-85a6-87322a9db725",
152
+ "metadata": {
153
+ "tags": []
154
+ },
155
+ "source": [
156
+ "---\n",
157
+ "## 2. Load cleaned data and Split data using iterative train test \n",
158
+ "\n",
159
+ "## THIS ASSUMES ALL DATA IS IN 'TRAIN' OF DATASET, IF NOT ALSO LOAD IT HERE\n"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "aaba16cf-a9b6-4c22-944a-2d31b8b5812d",
166
+ "metadata": {
167
+ "tags": []
168
+ },
169
+ "outputs": [],
170
+ "source": [
171
+ "import pickle\n",
172
+ "\n",
173
+ "mlb = pickle.load(open('../eurovoc_data/mlb.pickle', 'rb'))\n",
174
+ "\n",
175
+ "print(f\"Loaded MLBinarizer with {len(mlb.classes_)} classes\")\n",
176
+ " # Show first 10\n",
177
+ "print(f\"Classes: {mlb.classes_[:10]}...\") "
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "7f10ac21-5731-4937-8340-829d531c6116",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "%load_ext autoreload\n",
188
+ "%autoreload 2\n",
189
+ "\n",
190
+ "import os\n",
191
+ "from datetime import datetime\n",
192
+ "\n",
193
+ "FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
194
+ "\n",
195
+ "def list_all_json_files(directory=FIXED_DIR):\n",
196
+ " # List all items in the directory\n",
197
+ " all_items = os.listdir(directory)\n",
198
+ "\n",
199
+ " def extract_date_key(filename):\n",
200
+ " \"\"\"\n",
201
+ " Extracts a datetime object from filenames containing YYYY-MM.\n",
202
+ " Handles .jsonl and .jsonl.gz.\n",
203
+ " \"\"\"\n",
204
+ " base = filename.split('.')[0] \n",
205
+ " yyyy, mm = base.split('-') \n",
206
+ " return datetime(int(yyyy), int(mm), 1)\n",
207
+ "\n",
208
+ "\n",
209
+ " jsonl_files = [\n",
210
+ " f for f in all_items\n",
211
+ " if f.endswith(\".jsonl\") or f.endswith(\".jsonl.gz\")\n",
212
+ " ]\n",
213
+ "\n",
214
+ " # Sort newest to oldest\n",
215
+ " jsonl_files_sorted = sorted(\n",
216
+ " jsonl_files,\n",
217
+ " key=extract_date_key,\n",
218
+ " reverse=True\n",
219
+ " )\n",
220
+ " return [os.path.join(directory, f) for f in jsonl_files_sorted]\n",
221
+ "\n",
222
+ "all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
223
+ "\n",
224
+ " \n",
225
+ "print(f\"Found {len(all_jsonl_files)} files to load (including compressed).\")\n"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "25ecca51-7901-448b-9d89-4ed0663b2bae",
232
+ "metadata": {
233
+ "tags": []
234
+ },
235
+ "outputs": [],
236
+ "source": [
237
+ "import gc\n",
238
+ "gc.collect()"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "markdown",
243
+ "id": "1ff0c6b0-abcb-4424-be97-5c7bd8fb9af7",
244
+ "metadata": {},
245
+ "source": [
246
+ "## 2.1 Model definition"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "id": "aaa9dc1b-1086-47d2-9b3b-20d954bda644",
252
+ "metadata": {},
253
+ "source": [
254
+ "---\n",
255
+ "## 3. Model definition and training (NORMAL)"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "5f15e504-9431-4913-9016-4b0c6344a127",
262
+ "metadata": {
263
+ "tags": []
264
+ },
265
+ "outputs": [],
266
+ "source": [
267
+ "import wandb\n",
268
+ "\n",
269
+ "wandb.login() "
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "id": "2f780f9d-730c-4540-ad9a-e0b60c87f147",
276
+ "metadata": {
277
+ "tags": []
278
+ },
279
+ "outputs": [],
280
+ "source": [
281
+ "%load_ext autoreload\n",
282
+ "%autoreload 2\n",
283
+ "\n",
284
+ "from eurovoc import StreamingEurovocDataModule\n",
285
+ "from eurovoc import EurovocTagger\n",
286
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
287
+ "import pytorch_lightning as pl\n",
288
+ "import torch\n",
289
+ "from pytorch_lightning.callbacks import EarlyStopping\n",
290
+ "import gc\n",
291
+ "\n",
292
+ "class MemoryMonitorCallback(pl.Callback):\n",
293
+ " def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):\n",
294
+ " # Log memory every 100 batches\n",
295
+ " if batch_idx % 100 == 0:\n",
296
+ " if torch.cuda.is_available():\n",
297
+ " for i in range(torch.cuda.device_count()):\n",
298
+ " allocated = torch.cuda.memory_allocated(i) / 1e9\n",
299
+ " reserved = torch.cuda.memory_reserved(i) / 1e9\n",
300
+ " trainer.logger.experiment.log({\n",
301
+ " f\"memory/gpu_{i}_allocated_gb\": allocated,\n",
302
+ " f\"memory/gpu_{i}_reserved_gb\": reserved,\n",
303
+ " \"batch_idx\": batch_idx\n",
304
+ " })\n",
305
+ " \n",
306
+ " def on_train_epoch_end(self, trainer, pl_module):\n",
307
+ " # Force cleanup at end of each epoch\n",
308
+ " gc.collect()\n",
309
+ " torch.cuda.empty_cache()\n",
310
+ " \n",
311
+ " def on_validation_epoch_end(self, trainer, pl_module):\n",
312
+ " # Force cleanup after validation\n",
313
+ " gc.collect()\n",
314
+ " torch.cuda.empty_cache()\n",
315
+ " \n",
316
+ " \n",
317
+ "early_stop = EarlyStopping(\n",
318
+ " monitor='val_loss',\n",
319
+ " patience=4,\n",
320
+ " mode='min'\n",
321
+ ")\n",
322
+ "\n",
323
+ "memory_monitor = MemoryMonitorCallback()\n",
324
+ "\n",
325
+ "checkpoint_callback = ModelCheckpoint(\n",
326
+ " monitor='val_loss',\n",
327
+ " filename='EurovocTaggerFP32-{epoch:02d}-{val_loss:.2f}',\n",
328
+ " mode='min',\n",
329
+ ")\n"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": null,
335
+ "id": "a069d202-2e61-4148-baeb-20fbd9b7bf7b",
336
+ "metadata": {
337
+ "tags": []
338
+ },
339
+ "outputs": [],
340
+ "source": [
341
+ "from pytorch_lightning.loggers import WandbLogger\n",
342
+ "wandb_logger = WandbLogger(\n",
343
+ " project=\"EUROVOC\",\n",
344
+ " name=\"EUROVOC-FP32\",\n",
345
+ " log_model=True, \n",
346
+ " save_dir=\"../logs\"\n",
347
+ ")\n",
348
+ "\n",
349
+ "FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
350
+ "\n",
351
+ "BATCH_SIZE=58\n",
352
+ "\n",
353
+ "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
354
+ "all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
355
+ "\n",
356
+ "dataloader = StreamingEurovocDataModule(BERT_MODEL_NAME, all_jsonl_files, mlb, batch_size=BATCH_SIZE)\n",
357
+ "dataloader.setup()\n",
358
+ "\n",
359
+ "N_EPOCHS = 30\n",
360
+ "LR = 5e-05\n",
361
+ "\n",
362
+ "model = EurovocTagger(BERT_MODEL_NAME, len(mlb.classes_), lr=LR)\n",
363
+ "\n",
364
+ "\n",
365
+ "wandb_logger.experiment.config.update({\n",
366
+ " \"bert_model\": BERT_MODEL_NAME,\n",
367
+ " \"batch_size\": BATCH_SIZE,\n",
368
+ " \"learning_rate\": LR,\n",
369
+ " \"max_epochs\": N_EPOCHS,\n",
370
+ " \"num_workers\": 3,\n",
371
+ " \"num_gpus\": 4,\n",
372
+ " \"precision\": \"32\",\n",
373
+ " \"num_classes\": len(mlb.classes_)\n",
374
+ "})\n",
375
+ "\n",
376
+ "\n",
377
+ "\n",
378
+ "if torch.cuda.is_available():\n",
379
+ " torch.backends.cuda.matmul.allow_tf32 = True\n",
380
+ " torch.backends.cudnn.allow_tf32 = True\n",
381
+ "\n",
382
+ "torch.set_float32_matmul_precision('medium')\n",
383
+ "\n",
384
+ "\n",
385
+ "trainer = pl.Trainer(max_epochs=N_EPOCHS ,\n",
386
+ " accelerator=\"gpu\",\n",
387
+ " devices=4, \n",
388
+ " callbacks=[checkpoint_callback, early_stop, memory_monitor],\n",
389
+ " strategy=\"ddp_notebook\",\n",
390
+ " logger=wandb_logger,\n",
391
+ " log_every_n_steps=50,\n",
392
+ " )\n",
393
+ "\n",
394
+ "trainer.fit(model, dataloader)"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "markdown",
399
+ "id": "29d9203e-c02b-4a76-a57c-d2e0246722c7",
400
+ "metadata": {},
401
+ "source": [
402
+ "## Finetuning in BF16"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "id": "86734efb-0bcd-442f-976e-ea0bbdb393d6",
409
+ "metadata": {
410
+ "tags": []
411
+ },
412
+ "outputs": [],
413
+ "source": [
414
+ "import wandb\n",
415
+ "\n",
416
+ "wandb.login() "
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "id": "4b609efc-e3b7-4924-96c3-59a236f52ec6",
423
+ "metadata": {
424
+ "tags": []
425
+ },
426
+ "outputs": [],
427
+ "source": [
428
+ "from pytorch_lightning.loggers import WandbLogger\n",
429
+ "wandb_logger = WandbLogger(\n",
430
+ " project=\"EUROVOC\",\n",
431
+ " name=\"EUROVOC-BF16\",\n",
432
+ " log_model=True, \n",
433
+ " save_dir=\"../logs\"\n",
434
+ ")\n"
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": null,
440
+ "id": "7cab5811-0ab9-48d5-8ec4-8d835bb0d3df",
441
+ "metadata": {},
442
+ "outputs": [],
443
+ "source": [
444
+ "#%%capture output\n",
445
+ "%load_ext autoreload\n",
446
+ "%autoreload 2\n",
447
+ "\n",
448
+ "from eurovoc import StreamingEurovocDataModule\n",
449
+ "from eurovoc import EurovocTaggerBCELogit, EurovocTagger\n",
450
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
451
+ "import pytorch_lightning as pl\n",
452
+ "import torch\n",
453
+ "from pytorch_lightning.callbacks import EarlyStopping\n",
454
+ "import gc\n",
455
+ "\n",
456
+ "class MemoryMonitorCallback(pl.Callback):\n",
457
+ " def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):\n",
458
+ " # Log memory every 100 batches\n",
459
+ " if batch_idx % 100 == 0:\n",
460
+ " if torch.cuda.is_available():\n",
461
+ " for i in range(torch.cuda.device_count()):\n",
462
+ " allocated = torch.cuda.memory_allocated(i) / 1e9\n",
463
+ " reserved = torch.cuda.memory_reserved(i) / 1e9\n",
464
+ " trainer.logger.experiment.log({\n",
465
+ " f\"memory/gpu_{i}_allocated_gb\": allocated,\n",
466
+ " f\"memory/gpu_{i}_reserved_gb\": reserved\n",
467
+ " })\n",
468
+ " \n",
469
+ " def on_train_epoch_end(self, trainer, pl_module):\n",
470
+ " # Force cleanup at end of each epoch\n",
471
+ " gc.collect()\n",
472
+ " torch.cuda.empty_cache()\n",
473
+ " \n",
474
+ " def on_validation_epoch_end(self, trainer, pl_module):\n",
475
+ " # Force cleanup after validation\n",
476
+ " gc.collect()\n",
477
+ " torch.cuda.empty_cache()\n",
478
+ "\n",
479
+ " \n",
480
+ " \n",
481
+ "\n",
482
+ "early_stop = EarlyStopping(\n",
483
+ " monitor='val_loss',\n",
484
+ " patience=4,\n",
485
+ " mode='min'\n",
486
+ ")\n",
487
+ "\n",
488
+ "memory_monitor = MemoryMonitorCallback()\n",
489
+ "\n",
490
+ "checkpoint_callback = ModelCheckpoint(\n",
491
+ " monitor='val_loss',\n",
492
+ " filename='EurovocTaggerA-{epoch:02d}-{val_loss:.2f}',\n",
493
+ " mode='min',\n",
494
+ ")\n",
495
+ "\n",
496
+ "\n",
497
+ "if torch.cuda.is_available():\n",
498
+ " torch.backends.cuda.matmul.allow_tf32 = True\n",
499
+ " torch.backends.cudnn.allow_tf32 = True\n",
500
+ "\n",
501
+ "torch.set_float32_matmul_precision('medium')\n",
502
+ "\n",
503
+ "\n",
504
+ "FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
505
+ "\n",
506
+ "BATCH_SIZE=74\n",
507
+ "\n",
508
+ "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
509
+ "all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
510
+ "\n",
511
+ "dataloader = StreamingEurovocDataModule(BERT_MODEL_NAME, all_jsonl_files, mlb, batch_size=BATCH_SIZE)\n",
512
+ "dataloader.setup()\n",
513
+ "\n",
514
+ "\n",
515
+ "\n",
516
+ "N_EPOCHS = 30\n",
517
+ "LR = 5e-05\n",
518
+ "\n",
519
+ "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
520
+ "\n",
521
+ "\n",
522
+ "model = EurovocTaggerBCELogit(BERT_MODEL_NAME, len(mlb.classes_), lr=LR)\n",
523
+ "\n",
524
+ "\n",
525
+ "\n",
526
+ "wandb_logger.experiment.config.update({\n",
527
+ " \"bert_model\": BERT_MODEL_NAME,\n",
528
+ " \"batch_size\": BATCH_SIZE,\n",
529
+ " \"learning_rate\": LR,\n",
530
+ " \"max_epochs\": N_EPOCHS,\n",
531
+ " \"num_workers\": 3,\n",
532
+ " \"num_gpus\": 4,\n",
533
+ " \"precision\": \"16\",\n",
534
+ " \"num_classes\": len(mlb.classes_)\n",
535
+ "})\n",
536
+ "\n",
537
+ "trainer = pl.Trainer(max_epochs=N_EPOCHS ,\n",
538
+ " accelerator=\"gpu\",\n",
539
+ " devices=4, \n",
540
+ " callbacks=[checkpoint_callback, early_stop, memory_monitor],\n",
541
+ " strategy=\"ddp_notebook\",\n",
542
+ " accumulate_grad_batches=1,\n",
543
+ " precision=16,\n",
544
+ " logger=wandb_logger,\n",
545
+ " log_every_n_steps=50,\n",
546
+ " )\n"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": null,
552
+ "id": "6af63c61-5ecd-4207-8aaa-2a0dbd008df2",
553
+ "metadata": {
554
+ "tags": []
555
+ },
556
+ "outputs": [],
557
+ "source": [
558
+ "\n",
559
+ "trainer.fit(model, dataloader)"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "markdown",
564
+ "id": "2e2f69c2-9d89-4468-8198-b15da16e9403",
565
+ "metadata": {},
566
+ "source": [
567
+ "## 4. MODEL definition and training (LORA) (STILL USES OLD EUROVOC TAGGER)"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "id": "c28014b2-5ccb-45d9-8025-d05d31d77a08",
574
+ "metadata": {
575
+ "tags": []
576
+ },
577
+ "outputs": [],
578
+ "source": [
579
+ "from eurovoc import StreamingEurovocDataModule\n",
580
+ "from eurovoc import EurovocTaggerLoRA\n",
581
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
582
+ "import pytorch_lightning as pl\n",
583
+ "import torch\n",
584
+ "\n",
585
+ "\n",
586
+ "torch.set_float32_matmul_precision('medium')\n",
587
+ "\n",
588
+ "FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
589
+ "\n",
590
+ "BATCH_SIZE=94\n",
591
+ "\n",
592
+ "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
593
+ "\n",
594
+ "\n",
595
+ "all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
596
+ "\n",
597
+ "dataloader = StreamingEurovocDataModule(BERT_MODEL_NAME, all_jsonl_files, mlb, batch_size=BATCH_SIZE)\n",
598
+ "dataloader.setup()\n",
599
+ "\n",
600
+ "\n",
601
+ "N_EPOCHS = 30\n",
602
+ "LR = 5e-05\n",
603
+ "\n",
604
+ "# LoRA hyperparameters\n",
605
+ "# Rank of LoRA matrices\n",
606
+ "LORA_R = 16 \n",
607
+ "# Scaling factor (usually 2 * r)\n",
608
+ "LORA_ALPHA = 32 \n",
609
+ "LORA_DROPOUT = 0.1\n",
610
+ "\n",
611
+ "# Hierarchical classifier parameter (for 6800 labels)\n",
612
+ "# Bottleneck size: 768 → 256 → 6800\n",
613
+ "N_INTERMEDIATE = 256 \n",
614
+ "\n",
615
+ "\n",
616
+ "# Create LoRA model with hierarchical classifier\n",
617
+ "model = EurovocTaggerLoRA(\n",
618
+ " BERT_MODEL_NAME, \n",
619
+ " # 6800+ labels\n",
620
+ " len(mlb.classes_),\n",
621
+ " # Bottleneck size\n",
622
+ " n_intermediate=N_INTERMEDIATE, \n",
623
+ " lr=LR,\n",
624
+ " lora_r=LORA_R,\n",
625
+ " lora_alpha=LORA_ALPHA,\n",
626
+ " lora_dropout=LORA_DROPOUT\n",
627
+ ")\n",
628
+ "\n",
629
+ "checkpoint_callback = ModelCheckpoint(\n",
630
+ " monitor='val_loss',\n",
631
+ " filename='EurovocTaggerLoRA-6800-{epoch:02d}-{val_loss:.2f}',\n",
632
+ " mode='min',\n",
633
+ ")\n",
634
+ "\n",
635
+ "trainer = pl.Trainer(\n",
636
+ " max_epochs=N_EPOCHS, \n",
637
+ " accelerator=\"gpu\", \n",
638
+ " devices=4, \n",
639
+ " callbacks=[checkpoint_callback],\n",
640
+ " strategy=\"ddp_notebook\",\n",
641
+ " precision=16\n",
642
+ ")\n",
643
+ "\n",
644
+ "print(f\"Starting LoRA training with {len(mlb.classes_)} labels...\")\n",
645
+ "print(f\"Classifier architecture: 768 → {N_INTERMEDIATE} → {len(mlb.classes_)}\")\n",
646
+ "trainer.fit(model, dataloader)\n",
647
+ "\n",
648
+ "\n",
649
+ "\n"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": null,
655
+ "id": "694ad3d7-794d-4a5f-a7be-8b47e872418d",
656
+ "metadata": {},
657
+ "outputs": [],
658
+ "source": [
659
+ "# Save only the LoRA adapter \n",
660
+ "model.save_lora_adapter('./eurovoc_lora_adapter')\n",
661
+ "\n",
662
+ "print(\"LoRA adapter saved to ./eurovoc_lora_adapter\")\n"
663
+ ]
664
+ }
665
+ ],
666
+ "metadata": {
667
+ "kernelspec": {
668
+ "display_name": "eurovoc_training_env",
669
+ "language": "python",
670
+ "name": "eurovoc_training_env"
671
+ },
672
+ "language_info": {
673
+ "codemirror_mode": {
674
+ "name": "ipython",
675
+ "version": 3
676
+ },
677
+ "file_extension": ".py",
678
+ "mimetype": "text/x-python",
679
+ "name": "python",
680
+ "nbconvert_exporter": "python",
681
+ "pygments_lexer": "ipython3",
682
+ "version": "3.10.12"
683
+ }
684
+ },
685
+ "nbformat": 4,
686
+ "nbformat_minor": 5
687
+ }