Upload 3 files
Browse files- eurovoc.py +691 -0
- inference_test.ipynb +386 -0
- train_lora_included.ipynb +687 -0
eurovoc.py
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import BertTokenizerFast as BertTokenizer, get_linear_schedule_with_warmup, AutoTokenizer, AutoModel
|
| 8 |
+
import json
|
| 9 |
+
import random
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from tqdm.auto import tqdm
|
| 12 |
+
import gzip
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import os
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def save_split_config(train_files, val_files, config_path, metadata=None):
    """
    Persist a train/val split to a JSON file so it can be reused later.

    Args:
        train_files: List of training file paths
        val_files: List of validation file paths
        config_path: Destination path for the configuration JSON
        metadata: Optional dict with additional info (train_ratio, seed, etc.)
    """
    payload = {
        'train_files': train_files,
        'val_files': val_files,
        'num_train_files': len(train_files),
        'num_val_files': len(val_files),
        'metadata': {} if metadata is None else metadata,
    }

    # Make sure the destination directory exists before writing.
    parent_dir = os.path.dirname(config_path) or '.'
    os.makedirs(parent_dir, exist_ok=True)

    with open(config_path, 'w') as f:
        json.dump(payload, f, indent=2)

    print(f"✓ Split configuration saved to {config_path}")
|
| 44 |
+
|
| 45 |
+
def load_split_config(config_path):
    """
    Read a previously saved train/val split configuration.

    Args:
        config_path: Path to the configuration JSON

    Returns:
        Tuple of (train_files, val_files, metadata)
    """
    with open(config_path, 'r') as f:
        cfg = json.load(f)

    print(f"✓ Loaded split configuration from {config_path}")
    print(f" Train files: {cfg['num_train_files']}")
    print(f" Val files: {cfg['num_val_files']}")

    # Older configs may lack the metadata key; default to an empty dict.
    return cfg['train_files'], cfg['val_files'], cfg.get('metadata', {})
|
| 63 |
+
|
| 64 |
+
def get_file_label_stats(jsonl_files):
    """
    Collect the label distribution of every file.

    Every record is counted so that statistics for rare labels are exact.

    Args:
        jsonl_files: List of paths to JSONL files (plain or gzip-compressed)

    Returns:
        Dict mapping each file path to
        {'label_counts': Counter, 'total_records': int}
    """
    file_labels = {}

    print(f"Analyzing {len(jsonl_files)} files...")
    for file_path in tqdm(jsonl_files):
        label_counts = Counter()
        total_records = 0

        # gzip.open in text mode behaves like open, so one call site suffices.
        opener = gzip.open if file_path.endswith('.gz') else open
        with opener(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                # Skip malformed lines instead of aborting the whole scan.
                # Only JSON parse errors are swallowed: the previous version
                # caught Exception, which also hid real I/O problems.
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                label_counts.update(record.get('eurovoc_ids', []))
                total_records += 1

        file_labels[file_path] = {
            'label_counts': label_counts,
            'total_records': total_records,
        }

    return file_labels
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def smart_split_files(all_jsonl_files, train_ratio=0.92,
                      rare_threshold=0.005, seed=42, verbose=True,
                      save_config_path=None):
    """
    Split files ensuring rare labels appear in training set.

    Files are ranked by how many globally-rare labels they contain and the
    top `train_ratio` fraction is assigned to training, so documents with
    rare labels preferentially land in the training split.

    Args:
        all_jsonl_files: List of all JSONL file paths
        train_ratio: Fraction of files for training (default 0.92)
        rare_threshold: Labels appearing in < this fraction are considered rare
        rare_threshold: Labels whose count is below this fraction of all
            label occurrences are considered rare
        seed: Random seed for reproducibility
        verbose: Print statistics
        save_config_path: If provided, save the split configuration to this path

    Returns:
        Tuple of (train_files, val_files)
    """
    # NOTE(review): the seed is set here, but no random call follows — the
    # split below is fully deterministic given identical rare-label counts.
    # Kept for interface compatibility; confirm whether randomness was intended.
    random.seed(seed)

    if verbose:
        print("Analyzing label distribution across files...")

    file_stats = get_file_label_stats(all_jsonl_files)

    # Calculate which labels are rare globally (aggregate per-file counters).
    global_label_counts = Counter()
    for stats in file_stats.values():
        global_label_counts.update(stats['label_counts'])

    # Identify rare labels: fewer occurrences than rare_threshold of the
    # total number of label assignments across all files.
    total_labels = sum(global_label_counts.values())
    rare_count_threshold = total_labels * rare_threshold
    rare_labels = {label for label, count in global_label_counts.items()
                   if count < rare_count_threshold}

    if verbose:
        print(f"Found {len(rare_labels)} rare labels out of {len(global_label_counts)} total")

    # Score files by number of distinct rare labels they contain.
    file_rare_counts = {}
    for file_path, stats in file_stats.items():
        file_labels_set = set(stats['label_counts'].keys())
        rare_in_file = file_labels_set & rare_labels
        file_rare_counts[file_path] = len(rare_in_file)

    # Sort files by rare label count (descending). Ties keep dict insertion
    # order, i.e. the order of all_jsonl_files.
    sorted_files = sorted(file_rare_counts.items(), key=lambda x: x[1], reverse=True)

    # Calculate split point
    split_idx = int(len(all_jsonl_files) * train_ratio)

    # Assign files: rare-label-rich files go to training.
    train_files = [f for f, _ in sorted_files[:split_idx]]
    val_files = [f for f, _ in sorted_files[split_idx:]]

    # Calculate stats: how many files in each split carry at least one rare label.
    train_rare_count = sum(1 for f in train_files if file_rare_counts[f] > 0)
    val_rare_count = sum(1 for f in val_files if file_rare_counts[f] > 0)

    if verbose:
        print(f"Train files: {len(train_files)} ({train_rare_count} with rare labels)")
        print(f"Val files: {len(val_files)} ({val_rare_count} with rare labels)")

        # Check label coverage: labels present in only one split.
        train_labels = set()
        val_labels = set()
        for f in train_files:
            train_labels.update(file_stats[f]['label_counts'].keys())
        for f in val_files:
            val_labels.update(file_stats[f]['label_counts'].keys())

        labels_only_in_train = train_labels - val_labels
        labels_only_in_val = val_labels - train_labels

        print(f"Labels only in train: {len(labels_only_in_train)}")
        print(f"Labels only in val: {len(labels_only_in_val)}")
        # Labels never seen during training cannot be learned — warn loudly.
        if len(labels_only_in_val) > 0:
            print(f"⚠️ WARNING: {len(labels_only_in_val)} labels appear only in validation!")

    # Save configuration if path provided, so the split can be reloaded later.
    if save_config_path:
        metadata = {
            'train_ratio': train_ratio,
            'rare_threshold': rare_threshold,
            'seed': seed,
            'total_files': len(all_jsonl_files),
            'num_rare_labels': len(rare_labels),
            'num_total_labels': len(global_label_counts),
            'train_rare_count': train_rare_count,
            'val_rare_count': val_rare_count
        }
        save_split_config(train_files, val_files, save_config_path, metadata)

    return train_files, val_files
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class EurovocDataset(Dataset):
    """Map-style dataset: tokenizes one document per item on access."""

    def __init__(
        self,
        text: np.array,
        labels: np.array,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        # Each row of `text` is assumed to hold the document string at
        # position 0 (e.g. a 2-D array with one column) — confirm with callers.
        sample_text = self.text[index][0]
        sample_labels = self.labels[index]

        encoded = self.tokenizer.encode_plus(
            sample_text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Flatten drops the batch dimension the tokenizer adds for a single text.
        return dict(
            text=sample_text,
            input_ids=encoded["input_ids"].flatten(),
            attention_mask=encoded["attention_mask"].flatten(),
            labels=torch.FloatTensor(sample_labels)
        )
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class StreamingEurovocDataset(IterableDataset):
    """
    Streaming dataset that avoids loading everything into memory:
    records are read from disk and processed one at a time.
    """

    def __init__(self, jsonl_files, mlb, tokenizer, max_token_len=512, split='train'):
        self.jsonl_files = jsonl_files
        self.mlb = mlb
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.split = split

    def __iter__(self):
        # HF datasets streams JSONL lazily; 'train' here is just the default
        # split name for ad-hoc json loading, unrelated to self.split.
        stream = load_dataset(
            'json',
            data_files=self.jsonl_files,
            streaming=True,
            split='train'
        )

        for record in stream:
            text = record.get('text')
            eurovoc_ids = record.get('eurovoc_ids', [])

            # Drop records that have no text or no labels.
            if not text or not eurovoc_ids:
                continue

            # Binarize the EuroVoc concept ids with the fitted binarizer.
            labels = self.mlb.transform([eurovoc_ids])[0]

            encoded = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_token_len,
                return_token_type_ids=False,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )

            yield {
                'input_ids': encoded["input_ids"].flatten(),
                'attention_mask': encoded["attention_mask"].flatten(),
                'labels': torch.FloatTensor(labels)
            }
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class EuroVocLongTextDataset(Dataset):
    """
    Dataset for documents longer than the model's context window.

    Each document is split into chunks of at most ``max_token_len`` words;
    every chunk becomes its own sample carrying the full document's labels.
    All chunks are tokenized eagerly in ``__init__``.
    """

    @staticmethod
    def _split_words(text, max_length):
        """Yield consecutive chunks of *text*, ``max_length`` words each."""
        words = text.split()
        for i in range(0, len(words), max_length):
            # Re-join so the tokenizer receives a plain string, not a word list.
            yield " ".join(words[i:i + max_length])

    def __init__(
        self,
        text: np.array,
        labels: np.array,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

        # Bug fixes vs. the previous version:
        #  * the splitter was declared without `self` yet called as a bound
        #    method and without a chunk size, so it crashed on first use;
        #  * it yielded lists of words, which encode_plus cannot handle;
        #  * the batch encoding was indexed as encoding[i][key] instead of
        #    encoding[key][i].
        self.chunks_and_labels = [
            (chunk, label)
            for doc, label in zip(self.text, self.labels)
            for chunk in self._split_words(doc, self.max_token_len)
        ]

        self.encoding = self.tokenizer.batch_encode_plus(
            [chunk for chunk, _ in self.chunks_and_labels],
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

    def __len__(self):
        return len(self.chunks_and_labels)

    def __getitem__(self, index: int):
        chunk, labels = self.chunks_and_labels[index]

        return dict(
            text=chunk,
            input_ids=self.encoding["input_ids"][index].flatten(),
            attention_mask=self.encoding["attention_mask"][index].flatten(),
            labels=torch.FloatTensor(labels)
        )
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class EurovocDataModule(pl.LightningDataModule):
    """Lightning data module wrapping map-style EurovocDataset splits."""

    def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
        super().__init__()
        self.batch_size = batch_size
        self.x_tr = x_tr
        self.y_tr = y_tr
        self.x_test = x_test
        self.y_test = y_test
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.max_token_len = max_token_len

    def setup(self, stage=None):
        # The test split doubles as the validation split (see val_dataloader).
        self.train_dataset = EurovocDataset(
            self.x_tr, self.y_tr, self.tokenizer, self.max_token_len
        )
        self.test_dataset = EurovocDataset(
            self.x_test, self.y_test, self.tokenizer, self.max_token_len
        )

    def train_dataloader(self):
        # Only the training loader shuffles.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=2
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=2
        )
|
| 388 |
+
|
| 389 |
+
class StreamingEurovocDataModule(pl.LightningDataModule):
    """
    DataModule built on streaming datasets.

    Supports both random and smart (stratified) file splitting and can
    load a pre-computed split from a JSON config file.
    """

    def __init__(self, bert_model_name, all_jsonl_files, mlb,
                 batch_size=64, max_token_len=512,
                 train_ratio=0.92, rare_threshold=0.005,
                 split_strategy='smart',
                 split_config_path="../eurovoc_data/train_val_split_config.json",
                 save_split_config_path="../eurovoc_data/train_val_split_config.json"):
        """
        Args:
            bert_model_name: Name of the BERT model to use
            all_jsonl_files: List of all JSONL file paths (ignored if split_config_path exists)
            mlb: Fitted MultiLabelBinarizer
            batch_size: Batch size for dataloaders
            max_token_len: Maximum token length for tokenization
            train_ratio: Fraction of files for training
            rare_threshold: Threshold for rare label identification
            split_strategy: 'random' or 'smart'
            split_config_path: Path to existing split config JSON (loaded if present)
            save_split_config_path: Path to save a newly created split config JSON
        """
        super().__init__()
        self.batch_size = batch_size
        self.mlb = mlb
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.max_token_len = max_token_len

        # Option 1: reuse an existing split so runs are reproducible.
        if split_config_path and os.path.exists(split_config_path):
            print(f"Loading split from existing config: {split_config_path}")
            self.train_files, self.val_files, metadata = load_split_config(split_config_path)
            if metadata:
                print(f"Split metadata: {metadata}")

        # Option 2: create a new split with the requested strategy.
        elif split_strategy == 'smart':
            print("Using smart split strategy (ensuring rare label coverage)...")
            self.train_files, self.val_files = smart_split_files(
                all_jsonl_files,
                train_ratio=train_ratio,
                rare_threshold=rare_threshold,
                save_config_path=save_split_config_path
            )
        elif split_strategy == 'random':
            print("Using random split strategy...")
            # Fix: shuffle a copy instead of random.shuffle(all_jsonl_files),
            # which mutated the caller's list in place.
            shuffled = random.sample(all_jsonl_files, len(all_jsonl_files))

            split_idx = int(len(shuffled) * train_ratio)
            self.train_files = shuffled[:split_idx]
            self.val_files = shuffled[split_idx:]

            print(f"Train files: {len(self.train_files)}")
            print(f"Val files: {len(self.val_files)}")

            # Save config if requested so the split can be reloaded later.
            if save_split_config_path:
                metadata = {
                    'train_ratio': train_ratio,
                    'split_strategy': 'random',
                    'total_files': len(all_jsonl_files)
                }
                save_split_config(self.train_files, self.val_files,
                                  save_split_config_path, metadata)
        else:
            raise ValueError(f"Unknown split_strategy: {split_strategy}. Use 'random' or 'smart'")

    def setup(self, stage=None):
        # Both splits share tokenizer and binarizer; only the file lists differ.
        self.train_dataset = StreamingEurovocDataset(
            self.train_files,
            self.mlb,
            self.tokenizer,
            self.max_token_len
        )

        self.val_dataset = StreamingEurovocDataset(
            self.val_files,
            self.mlb,
            self.tokenizer,
            self.max_token_len
        )

    def train_dataloader(self):
        # No shuffle: IterableDataset order is defined by the stream itself.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True
        )
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class EurovocTagger(pl.LightningModule):
    """BERT multi-label classifier; outputs sigmoid probabilities, trained with BCELoss."""

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.criterion = nn.BCELoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier1(self.dropout(encoded.pooler_output))
        probs = torch.sigmoid(logits)
        # Loss defaults to 0 when no labels are given (pure inference).
        loss = 0
        if labels is not None:
            loss = self.criterion(probs, labels)
        return loss, probs

    def _shared_step(self, batch):
        """Run one forward pass on a batch dict; returns (loss, outputs)."""
        return self(batch["input_ids"], batch["attention_mask"], batch["labels"])

    def training_step(self, batch, batch_idx):
        loss, outputs = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": batch["labels"]}

    def validation_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class EurovocTaggerBCELogit(pl.LightningModule):
    """BERT multi-label classifier; outputs raw logits, trained with BCEWithLogitsLoss."""

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.criterion = nn.BCEWithLogitsLoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier1(self.dropout(encoded.pooler_output))
        loss = 0
        if labels is not None:
            # The sigmoid lives inside BCEWithLogitsLoss; logits go in raw.
            loss = self.criterion(logits, labels)
        return loss, logits

    def _run_batch(self, batch):
        """Forward one batch dict; returns (loss, logits)."""
        return self(batch["input_ids"], batch["attention_mask"], batch["labels"])

    def training_step(self, batch, batch_idx):
        loss, logits = self._run_batch(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": logits, "labels": batch["labels"]}

    def validation_step(self, batch, batch_idx):
        loss, _ = self._run_batch(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, _ = self._run_batch(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class EurovocTaggerLoRA(pl.LightningModule):
    """
    BERT multi-label classifier fine-tuned with LoRA adapters.

    Only the LoRA matrices (injected into attention query/value projections)
    and the two-layer classification head are trained; the base BERT weights
    stay frozen. Trained with BCEWithLogitsLoss on raw logits.
    """

    def __init__(self, bert_model_name, n_classes, n_intermediate=256, lr=2e-5, eps=1e-8, lora_r=8, lora_alpha=16, lora_dropout=0.1):
        """
        Args:
            bert_model_name: HF model id/path for the base encoder
            n_classes: Number of output labels
            n_intermediate: Width of the bottleneck layer in the head
            lr: AdamW learning rate
            eps: AdamW epsilon
            lora_r: LoRA rank (smaller = fewer trainable params)
            lora_alpha: LoRA scaling factor
            lora_dropout: Dropout applied inside the LoRA layers
        """
        super().__init__()

        # Load base BERT model
        self.bert = AutoModel.from_pretrained(bert_model_name)

        # Configure LoRA
        # Target modules: query and value projection layers in attention
        lora_config = LoraConfig(
            r=lora_r,  # Rank of the low-rank matrices (smaller = fewer params)
            lora_alpha=lora_alpha,  # Scaling factor
            target_modules=["query", "value"],  # Which layers to apply LoRA to
            lora_dropout=lora_dropout,
            bias="none",
            task_type=TaskType.FEATURE_EXTRACTION  # For getting embeddings
        )

        # Apply LoRA to BERT (wraps the model in a PeftModel)
        self.bert = get_peft_model(self.bert, lora_config)

        # Print trainable parameters info
        self.bert.print_trainable_parameters()

        # Hierarchical classification head for many labels:
        # instead of hidden_size -> n_classes directly (e.g. 768 -> 6800,
        # ~5.2M params), use a bottleneck hidden_size -> n_intermediate ->
        # n_classes (e.g. 768 -> 256 -> 6800, ~1.9M params).
        hidden_size = self.bert.config.hidden_size

        self.dropout1 = nn.Dropout(p=0.2)

        # Layer 1: compress to intermediate representation
        self.classifier1 = nn.Linear(hidden_size, n_intermediate)
        self.relu = nn.ReLU()

        self.dropout2 = nn.Dropout(p=0.2)

        # Layer 2: expand to all labels
        self.classifier2 = nn.Linear(n_intermediate, n_classes)

        self.criterion = nn.BCEWithLogitsLoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        """Return (loss, logits); loss is 0 when labels is None (inference)."""
        # Forward pass through LoRA-enhanced BERT
        output = self.bert(input_ids, attention_mask=attention_mask)

        # Pooled output (CLS token representation): (batch, hidden_size)
        output = self.dropout1(output.pooler_output)

        # Hierarchical classifier: (batch, n_intermediate)
        output = self.classifier1(output)
        output = self.relu(output)

        output = self.dropout2(output)
        # (batch, n_classes) raw logits — BCEWithLogitsLoss applies the sigmoid
        output = self.classifier2(output)

        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        # AdamW over all parameters; frozen base weights have requires_grad
        # False after get_peft_model, so only LoRA + head actually update.
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)

    def save_lora_adapter(self, path):
        """Save only the LoRA adapter weights"""
        self.bert.save_pretrained(path)

    def load_lora_adapter(self, path):
        """Load LoRA adapter weights"""
        self.bert = PeftModel.from_pretrained(self.bert, path)
|
inference_test.ipynb
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "11ab9cd5-a6e4-416a-b44f-201e8bf8ee84",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"## Test inference"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": 5,
|
| 14 |
+
"id": "40523be3-6ec7-4cac-aa90-6b5177c0f07d",
|
| 15 |
+
"metadata": {
|
| 16 |
+
"tags": []
|
| 17 |
+
},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"from pdfminer.high_level import extract_text"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": 26,
|
| 26 |
+
"id": "c0e5cc3f-5a9d-4b0f-8f7c-d46c0f79b5df",
|
| 27 |
+
"metadata": {
|
| 28 |
+
"tags": []
|
| 29 |
+
},
|
| 30 |
+
"outputs": [
|
| 31 |
+
{
|
| 32 |
+
"name": "stderr",
|
| 33 |
+
"output_type": "stream",
|
| 34 |
+
"text": [
|
| 35 |
+
"Cannot set gray non-stroke color because /'P3954' is an invalid float value\n"
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"source": [
|
| 40 |
+
"text = extract_text(\"./example_docs_for_inference/publication_climate.pdf\")"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"execution_count": 27,
|
| 46 |
+
"id": "120528e3-26b9-40ce-ac8c-3c30c3092d28",
|
| 47 |
+
"metadata": {
|
| 48 |
+
"tags": []
|
| 49 |
+
},
|
| 50 |
+
"outputs": [
|
| 51 |
+
{
|
| 52 |
+
"name": "stdout",
|
| 53 |
+
"output_type": "stream",
|
| 54 |
+
"text": [
|
| 55 |
+
"ISSN 1831-9424 \n",
|
| 56 |
+
"\n",
|
| 57 |
+
"How to plan mitigation, adaptatio\n"
|
| 58 |
+
]
|
| 59 |
+
}
|
| 60 |
+
],
|
| 61 |
+
"source": [
|
| 62 |
+
"print(text[0:50])"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"execution_count": 9,
|
| 68 |
+
"id": "d191928f-381e-4da3-8342-1300909b52c5",
|
| 69 |
+
"metadata": {
|
| 70 |
+
"tags": []
|
| 71 |
+
},
|
| 72 |
+
"outputs": [
|
| 73 |
+
{
|
| 74 |
+
"name": "stderr",
|
| 75 |
+
"output_type": "stream",
|
| 76 |
+
"text": [
|
| 77 |
+
"/home/mbarhdadi/projects/training/eurovoc_training_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 78 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"name": "stdout",
|
| 83 |
+
"output_type": "stream",
|
| 84 |
+
"text": [
|
| 85 |
+
"Model loaded. Ready to predict 6958 eurovoc labels.\n"
|
| 86 |
+
]
|
| 87 |
+
}
|
| 88 |
+
],
|
| 89 |
+
"source": [
|
| 90 |
+
"import pickle\n",
|
| 91 |
+
"from transformers import AutoTokenizer, AutoModel\n",
|
| 92 |
+
"from eurovoc import EurovocTagger\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# Load MLBinarizer\n",
|
| 95 |
+
"with open('./models_finetuned/latest/mlb.pickle', 'rb') as f:\n",
|
| 96 |
+
" mlb = pickle.load(f)\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"# Load tokenizer\n",
|
| 99 |
+
"BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
|
| 100 |
+
"tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"# Load trained model\n",
|
| 103 |
+
"checkpoint_path = \"./models_finetuned/latest/EurovocTaggerFP32-epoch=04-val_loss=0.00.ckpt\" \n",
|
| 104 |
+
"model = EurovocTagger.load_from_checkpoint(\n",
|
| 105 |
+
" checkpoint_path,\n",
|
| 106 |
+
" bert_model_name=BERT_MODEL_NAME,\n",
|
| 107 |
+
" n_classes=len(mlb.classes_)\n",
|
| 108 |
+
")\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"print(f\"Model loaded. Ready to predict {len(mlb.classes_)} eurovoc labels.\")"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "code",
|
| 116 |
+
"execution_count": 15,
|
| 117 |
+
"id": "7a1fd7e6-e14d-4c24-97ae-abcd5a30ab71",
|
| 118 |
+
"metadata": {
|
| 119 |
+
"tags": []
|
| 120 |
+
},
|
| 121 |
+
"outputs": [],
|
| 122 |
+
"source": [
|
| 123 |
+
"def get_eurovoc_id_to_term_mapping():\n",
|
| 124 |
+
" \"\"\"\n",
|
| 125 |
+
" Create a mapping from eurovoc IDs to their human-readable terms.\n",
|
| 126 |
+
" \n",
|
| 127 |
+
" Returns:\n",
|
| 128 |
+
" Dict mapping eurovoc_id -> term_name\n",
|
| 129 |
+
" \"\"\"\n",
|
| 130 |
+
" import requests\n",
|
| 131 |
+
" import xmltodict\n",
|
| 132 |
+
" \n",
|
| 133 |
+
" eurovoc_id_to_term = {}\n",
|
| 134 |
+
" \n",
|
| 135 |
+
" response = requests.get(\n",
|
| 136 |
+
" 'http://publications.europa.eu/resource/dataset/eurovoc',\n",
|
| 137 |
+
" headers={\n",
|
| 138 |
+
" 'Accept': 'application/xml',\n",
|
| 139 |
+
" 'Accept-Language': 'en',\n",
|
| 140 |
+
" 'User-Agent': 'Mozilla/5.0'\n",
|
| 141 |
+
" }\n",
|
| 142 |
+
" )\n",
|
| 143 |
+
" \n",
|
| 144 |
+
" data = xmltodict.parse(response.content)\n",
|
| 145 |
+
" \n",
|
| 146 |
+
" for term in data['xs:schema']['xs:simpleType']['xs:restriction']['xs:enumeration']:\n",
|
| 147 |
+
" try:\n",
|
| 148 |
+
" name = term['xs:annotation']['xs:documentation'].split('/')[0].strip()\n",
|
| 149 |
+
" eurovoc_id = term['@value'].split(':')[1]\n",
|
| 150 |
+
" \n",
|
| 151 |
+
" # Map ID -> term \n",
|
| 152 |
+
" eurovoc_id_to_term[eurovoc_id] = {\n",
|
| 153 |
+
" 'original': name,\n",
|
| 154 |
+
" 'lowercase': name.lower()\n",
|
| 155 |
+
" }\n",
|
| 156 |
+
" except (KeyError, IndexError) as e:\n",
|
| 157 |
+
" print(f\"⚠️ Could not parse term: {term}\")\n",
|
| 158 |
+
" \n",
|
| 159 |
+
" print(f\"✓ Loaded {len(eurovoc_id_to_term)} eurovoc terms\")\n",
|
| 160 |
+
" return eurovoc_id_to_term"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": 23,
|
| 166 |
+
"id": "d2b703ea-ca41-4353-8776-1a226f02c56b",
|
| 167 |
+
"metadata": {
|
| 168 |
+
"tags": []
|
| 169 |
+
},
|
| 170 |
+
"outputs": [
|
| 171 |
+
{
|
| 172 |
+
"name": "stdout",
|
| 173 |
+
"output_type": "stream",
|
| 174 |
+
"text": [
|
| 175 |
+
"Loading Eurovoc terms...\n",
|
| 176 |
+
"✓ Loaded 7488 eurovoc terms\n"
|
| 177 |
+
]
|
| 178 |
+
}
|
| 179 |
+
],
|
| 180 |
+
"source": [
|
| 181 |
+
"print(\"Loading Eurovoc terms...\")\n",
|
| 182 |
+
"eurovoc_id_to_term = get_eurovoc_id_to_term_mapping()\n"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": 24,
|
| 188 |
+
"id": "7a5fed81-64e8-4454-a56b-73eb50676b75",
|
| 189 |
+
"metadata": {
|
| 190 |
+
"tags": []
|
| 191 |
+
},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": [
|
| 194 |
+
"import torch\n",
|
| 195 |
+
"import numpy as np\n",
|
| 196 |
+
"from transformers import AutoTokenizer\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"def predict_eurovoc_labels(text, model, mlb, tokenizer, \n",
|
| 199 |
+
" eurovoc_id_to_term=None,\n",
|
| 200 |
+
" max_token_len=512, \n",
|
| 201 |
+
" threshold=0.5, \n",
|
| 202 |
+
" top_k=10,\n",
|
| 203 |
+
" device='cuda'):\n",
|
| 204 |
+
" model.eval()\n",
|
| 205 |
+
" model.to(device)\n",
|
| 206 |
+
" \n",
|
| 207 |
+
" # Tokenize\n",
|
| 208 |
+
" encoding = tokenizer.encode_plus(\n",
|
| 209 |
+
" text,\n",
|
| 210 |
+
" add_special_tokens=True,\n",
|
| 211 |
+
" max_length=max_token_len,\n",
|
| 212 |
+
" return_token_type_ids=False,\n",
|
| 213 |
+
" padding=\"max_length\",\n",
|
| 214 |
+
" truncation=True,\n",
|
| 215 |
+
" return_attention_mask=True,\n",
|
| 216 |
+
" return_tensors='pt',\n",
|
| 217 |
+
" )\n",
|
| 218 |
+
" \n",
|
| 219 |
+
" input_ids = encoding[\"input_ids\"].to(device)\n",
|
| 220 |
+
" attention_mask = encoding[\"attention_mask\"].to(device)\n",
|
| 221 |
+
" \n",
|
| 222 |
+
" # Predict\n",
|
| 223 |
+
" with torch.no_grad():\n",
|
| 224 |
+
" _, outputs = model(input_ids, attention_mask)\n",
|
| 225 |
+
" \n",
|
| 226 |
+
"\n",
|
| 227 |
+
" probabilities = outputs\n",
|
| 228 |
+
" \n",
|
| 229 |
+
" probabilities = probabilities.cpu().numpy()[0]\n",
|
| 230 |
+
" \n",
|
| 231 |
+
" # Helper function to enrich labels with terms\n",
|
| 232 |
+
" def enrich_labels(label_ids, probs):\n",
|
| 233 |
+
" \"\"\"Add human-readable terms to eurovoc IDs\"\"\"\n",
|
| 234 |
+
" enriched = []\n",
|
| 235 |
+
" for label_id, prob in zip(label_ids, probs):\n",
|
| 236 |
+
" entry = {\n",
|
| 237 |
+
" 'eurovoc_id': label_id,\n",
|
| 238 |
+
" 'probability': float(prob)\n",
|
| 239 |
+
" }\n",
|
| 240 |
+
" \n",
|
| 241 |
+
" # Add term if mapping available\n",
|
| 242 |
+
" if eurovoc_id_to_term and label_id in eurovoc_id_to_term:\n",
|
| 243 |
+
" entry['term'] = eurovoc_id_to_term[label_id]['original']\n",
|
| 244 |
+
" entry['term_lower'] = eurovoc_id_to_term[label_id]['lowercase']\n",
|
| 245 |
+
" else:\n",
|
| 246 |
+
" entry['term'] = None\n",
|
| 247 |
+
" entry['term_lower'] = None\n",
|
| 248 |
+
" \n",
|
| 249 |
+
" enriched.append(entry)\n",
|
| 250 |
+
" \n",
|
| 251 |
+
" return enriched\n",
|
| 252 |
+
" \n",
|
| 253 |
+
" # Get predictions above threshold\n",
|
| 254 |
+
" predicted_indices = np.where(probabilities >= threshold)[0]\n",
|
| 255 |
+
" predicted_labels = mlb.classes_[predicted_indices]\n",
|
| 256 |
+
" predicted_probs = probabilities[predicted_indices]\n",
|
| 257 |
+
" \n",
|
| 258 |
+
" # Get top-k predictions\n",
|
| 259 |
+
" top_k_indices = np.argsort(probabilities)[-top_k:][::-1]\n",
|
| 260 |
+
" top_k_labels = mlb.classes_[top_k_indices]\n",
|
| 261 |
+
" top_k_probs = probabilities[top_k_indices]\n",
|
| 262 |
+
" \n",
|
| 263 |
+
" return {\n",
|
| 264 |
+
" 'above_threshold': {\n",
|
| 265 |
+
" 'predictions': enrich_labels(predicted_labels, predicted_probs),\n",
|
| 266 |
+
" 'count': len(predicted_labels)\n",
|
| 267 |
+
" },\n",
|
| 268 |
+
" 'top_k': {\n",
|
| 269 |
+
" 'predictions': enrich_labels(top_k_labels, top_k_probs)\n",
|
| 270 |
+
" }\n",
|
| 271 |
+
" }"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"cell_type": "code",
|
| 276 |
+
"execution_count": 28,
|
| 277 |
+
"id": "030b99aa-edc7-472a-8c7f-636a47a9cdce",
|
| 278 |
+
"metadata": {
|
| 279 |
+
"tags": []
|
| 280 |
+
},
|
| 281 |
+
"outputs": [
|
| 282 |
+
{
|
| 283 |
+
"name": "stdout",
|
| 284 |
+
"output_type": "stream",
|
| 285 |
+
"text": [
|
| 286 |
+
"Document length: 696483 characters\n",
|
| 287 |
+
"Truncated to: 2048 tokens (~2048 chars)\n",
|
| 288 |
+
"\n",
|
| 289 |
+
"Running inference...\n",
|
| 290 |
+
"\n",
|
| 291 |
+
"================================================================================\n",
|
| 292 |
+
"TOP 15 PREDICTED EUROVOC LABELS (with terms)\n",
|
| 293 |
+
"================================================================================\n",
|
| 294 |
+
"642 | energy saving | 0.8567\n",
|
| 295 |
+
"6700 | energy efficiency | 0.7060\n",
|
| 296 |
+
"2281 | poverty | 0.4645\n",
|
| 297 |
+
"5311 | user guide | 0.4198\n",
|
| 298 |
+
"2498 | energy policy | 0.3545\n",
|
| 299 |
+
"5482 | climate change | 0.1736\n",
|
| 300 |
+
"754 | renewable energy | 0.1338\n",
|
| 301 |
+
"6400 | reduction of gas emissions | 0.1321\n",
|
| 302 |
+
"2517 | social policy | 0.1260\n",
|
| 303 |
+
"475 | energy distribution | 0.1253\n",
|
| 304 |
+
"5188 | information technology | 0.1087\n",
|
| 305 |
+
"2715 | energy production | 0.1087\n",
|
| 306 |
+
"2451 | EU policy | 0.0812\n",
|
| 307 |
+
"4139 | serial publication | 0.0808\n",
|
| 308 |
+
"83 | living conditions | 0.0793\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"5 labels above threshold (0.3)\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"================================================================================\n",
|
| 313 |
+
"PREDICTIONS ABOVE THRESHOLD (with readable terms)\n",
|
| 314 |
+
"================================================================================\n",
|
| 315 |
+
"2281 | poverty | 0.4645\n",
|
| 316 |
+
"2498 | energy policy | 0.3545\n",
|
| 317 |
+
"5311 | user guide | 0.4198\n",
|
| 318 |
+
"642 | energy saving | 0.8567\n",
|
| 319 |
+
"6700 | energy efficiency | 0.7060\n"
|
| 320 |
+
]
|
| 321 |
+
}
|
| 322 |
+
],
|
| 323 |
+
"source": [
|
| 324 |
+
"print(f\"Document length: {len(text)} characters\")\n",
|
| 325 |
+
"print(f\"Truncated to: 512 tokens (~2048 chars)\\n\") \n",
|
| 326 |
+
"\n",
|
| 327 |
+
"print(\"Running inference...\\n\")\n",
|
| 328 |
+
"results = predict_eurovoc_labels(\n",
|
| 329 |
+
" text=text,\n",
|
| 330 |
+
" model=model,\n",
|
| 331 |
+
" mlb=mlb,\n",
|
| 332 |
+
" tokenizer=tokenizer,\n",
|
| 333 |
+
" eurovoc_id_to_term=eurovoc_id_to_term, # ← Pass the mapping\n",
|
| 334 |
+
" threshold=0.3,\n",
|
| 335 |
+
" top_k=15\n",
|
| 336 |
+
")\n",
|
| 337 |
+
"print(\"=\" * 80)\n",
|
| 338 |
+
"print(\"TOP 15 PREDICTED EUROVOC LABELS\")\n",
|
| 339 |
+
"print(\"=\" * 80)\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"for pred in results['top_k']['predictions']:\n",
|
| 342 |
+
" term = pred['term'] if pred['term'] else \"(term not found)\"\n",
|
| 343 |
+
" print(f\"{pred['eurovoc_id']:15s} | {term:45s} | {pred['probability']:.4f}\")\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"print(f\"\\n{results['above_threshold']['count']} labels above threshold (0.3)\")\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"print(\"\\n\" + \"=\" * 80)\n",
|
| 348 |
+
"print(\"PREDICTIONS ABOVE THRESHOLD\")\n",
|
| 349 |
+
"print(\"=\" * 80)\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"for pred in results['above_threshold']['predictions']:\n",
|
| 352 |
+
" if pred['term']: # Only show if term was found\n",
|
| 353 |
+
" print(f\"{pred['eurovoc_id']:15s} | {pred['term']:45s} | {pred['probability']:.4f}\")"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"cell_type": "code",
|
| 358 |
+
"execution_count": null,
|
| 359 |
+
"id": "27ebc73c-5832-4702-bc1e-dd026ebeed02",
|
| 360 |
+
"metadata": {},
|
| 361 |
+
"outputs": [],
|
| 362 |
+
"source": []
|
| 363 |
+
}
|
| 364 |
+
],
|
| 365 |
+
"metadata": {
|
| 366 |
+
"kernelspec": {
|
| 367 |
+
"display_name": "eurovoc_training_env",
|
| 368 |
+
"language": "python",
|
| 369 |
+
"name": "eurovoc_training_env"
|
| 370 |
+
},
|
| 371 |
+
"language_info": {
|
| 372 |
+
"codemirror_mode": {
|
| 373 |
+
"name": "ipython",
|
| 374 |
+
"version": 3
|
| 375 |
+
},
|
| 376 |
+
"file_extension": ".py",
|
| 377 |
+
"mimetype": "text/x-python",
|
| 378 |
+
"name": "python",
|
| 379 |
+
"nbconvert_exporter": "python",
|
| 380 |
+
"pygments_lexer": "ipython3",
|
| 381 |
+
"version": "3.10.12"
|
| 382 |
+
}
|
| 383 |
+
},
|
| 384 |
+
"nbformat": 4,
|
| 385 |
+
"nbformat_minor": 5
|
| 386 |
+
}
|
train_lora_included.ipynb
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "3dc740a0-1865-40da-a163-b858f29d1313",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# 🇪🇺 🏷️ Eurovoc Model Training Notebook"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "64a1dc4a-5bf5-46d9-9356-3958802837ac",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import pickle \n",
|
| 19 |
+
"import pandas as pd\n",
|
| 20 |
+
"from transformers import AutoTokenizer, AutoModel\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"from datasets import load_dataset\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 25 |
+
"import torch\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"import pytorch_lightning as pl\n",
|
| 28 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"execution_count": null,
|
| 34 |
+
"id": "caa5dc4b-2fe3-43da-846d-a866c2224280",
|
| 35 |
+
"metadata": {
|
| 36 |
+
"tags": []
|
| 37 |
+
},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"fixed_dir = fix_all_files(all_jsonl_files)\n",
|
| 41 |
+
"logger.info(f\"Done! Use files from: {fixed_dir}\")"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "markdown",
|
| 46 |
+
"id": "6d63a920-52aa-4c73-bd2d-575e888d3d55",
|
| 47 |
+
"metadata": {
|
| 48 |
+
"tags": []
|
| 49 |
+
},
|
| 50 |
+
"source": [
|
| 51 |
+
"### Create the MultiLabel Binarizer and save it in a file for prediction "
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"id": "921fd5cd-67e7-4962-8e5e-15e055dd63b6",
|
| 58 |
+
"metadata": {},
|
| 59 |
+
"outputs": [],
|
| 60 |
+
"source": [
|
| 61 |
+
"from tqdm import tqdm\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"import os\n",
|
| 65 |
+
"from datetime import datetime\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"def list_all_json_files(directory=FIXED_DIR):\n",
|
| 70 |
+
" # List all items in the directory\n",
|
| 71 |
+
" all_items = os.listdir(directory)\n",
|
| 72 |
+
"\n",
|
| 73 |
+
" def extract_date_key(filename):\n",
|
| 74 |
+
" \"\"\"\n",
|
| 75 |
+
" Extracts a datetime object from filenames containing YYYY-MM.\n",
|
| 76 |
+
" Handles .jsonl and .jsonl.gz.\n",
|
| 77 |
+
" \"\"\"\n",
|
| 78 |
+
" base = filename.split('.')[0] \n",
|
| 79 |
+
" yyyy, mm = base.split('-') \n",
|
| 80 |
+
" return datetime(int(yyyy), int(mm), 1)\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"\n",
|
| 83 |
+
" jsonl_files = [\n",
|
| 84 |
+
" f for f in all_items\n",
|
| 85 |
+
" if f.endswith(\".jsonl\") or f.endswith(\".jsonl.gz\")\n",
|
| 86 |
+
" ]\n",
|
| 87 |
+
"\n",
|
| 88 |
+
" # Sort newest to oldest\n",
|
| 89 |
+
" jsonl_files_sorted = sorted(\n",
|
| 90 |
+
" jsonl_files,\n",
|
| 91 |
+
" key=extract_date_key,\n",
|
| 92 |
+
" reverse=True\n",
|
| 93 |
+
" )\n",
|
| 94 |
+
" return [os.path.join(directory, f) for f in jsonl_files_sorted]\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
|
| 97 |
+
"\n",
|
| 98 |
+
" \n",
|
| 99 |
+
"print(f\"Found {len(all_jsonl_files)} files to load (including compressed).\")\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"def build_mlb_from_streaming(all_jsonl_files, output_path='../eurovoc_data/mlb.pickle'):\n",
|
| 103 |
+
" \"\"\"\n",
|
| 104 |
+
" Build MLBinarizer by scanning all files once to collect unique concepts.\n",
|
| 105 |
+
" This is more memory efficient than loading everything.\n",
|
| 106 |
+
" \"\"\"\n",
|
| 107 |
+
" print(\"Scanning files to collect all unique eurovoc concepts...\")\n",
|
| 108 |
+
" all_concepts = set()\n",
|
| 109 |
+
" \n",
|
| 110 |
+
" dataset = load_dataset(\n",
|
| 111 |
+
" 'json',\n",
|
| 112 |
+
" data_files=all_jsonl_files,\n",
|
| 113 |
+
" streaming=True,\n",
|
| 114 |
+
" split='train'\n",
|
| 115 |
+
" )\n",
|
| 116 |
+
" \n",
|
| 117 |
+
" for record in tqdm(dataset, desc=\"Collecting eurovoc IDS\"):\n",
|
| 118 |
+
" concepts = record.get('eurovoc_ids', [])\n",
|
| 119 |
+
" if concepts:\n",
|
| 120 |
+
" all_concepts.update(concepts)\n",
|
| 121 |
+
" \n",
|
| 122 |
+
" print(f\"Found {len(all_concepts)} unique eurovoc IDS\")\n",
|
| 123 |
+
" \n",
|
| 124 |
+
" # Create and fit MLBinarizer\n",
|
| 125 |
+
" mlb = MultiLabelBinarizer()\n",
|
| 126 |
+
" mlb.fit([sorted(list(all_concepts))])\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" # Save it\n",
|
| 129 |
+
" with open(output_path, 'wb') as f:\n",
|
| 130 |
+
" pickle.dump(mlb, f)\n",
|
| 131 |
+
" \n",
|
| 132 |
+
" print(f\"Saved MLBinarizer to {output_path}\")\n",
|
| 133 |
+
" return mlb\n",
|
| 134 |
+
"\n"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": null,
|
| 140 |
+
"id": "66e1d48e-83a7-4a38-a081-b72ba679e960",
|
| 141 |
+
"metadata": {
|
| 142 |
+
"tags": []
|
| 143 |
+
},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"build_mlb_from_streaming(all_jsonl_files)"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "markdown",
|
| 151 |
+
"id": "b2fd1bda-ee0e-40f2-85a6-87322a9db725",
|
| 152 |
+
"metadata": {
|
| 153 |
+
"tags": []
|
| 154 |
+
},
|
| 155 |
+
"source": [
|
| 156 |
+
"---\n",
|
| 157 |
+
"## 2. Load cleaned data and Split data using iterative train test \n",
|
| 158 |
+
"\n",
|
| 159 |
+
"## THIS ASSUMES ALL DATA IS IN 'TRAIN' OF DATASET, IF NOT ALSO LOAD IT HERE\n"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"cell_type": "code",
|
| 164 |
+
"execution_count": null,
|
| 165 |
+
"id": "aaba16cf-a9b6-4c22-944a-2d31b8b5812d",
|
| 166 |
+
"metadata": {
|
| 167 |
+
"tags": []
|
| 168 |
+
},
|
| 169 |
+
"outputs": [],
|
| 170 |
+
"source": [
|
| 171 |
+
"import pickle\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"mlb = pickle.load(open('../eurovoc_data/mlb.pickle', 'rb'))\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"print(f\"Loaded MLBinarizer with {len(mlb.classes_)} classes\")\n",
|
| 176 |
+
" # Show first 10\n",
|
| 177 |
+
"print(f\"Classes: {mlb.classes_[:10]}...\") "
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"cell_type": "code",
|
| 182 |
+
"execution_count": null,
|
| 183 |
+
"id": "7f10ac21-5731-4937-8340-829d531c6116",
|
| 184 |
+
"metadata": {},
|
| 185 |
+
"outputs": [],
|
| 186 |
+
"source": [
|
| 187 |
+
"%load_ext autoreload\n",
|
| 188 |
+
"%autoreload 2\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"import os\n",
|
| 191 |
+
"from datetime import datetime\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"def list_all_json_files(directory=FIXED_DIR):\n",
|
| 196 |
+
" # List all items in the directory\n",
|
| 197 |
+
" all_items = os.listdir(directory)\n",
|
| 198 |
+
"\n",
|
| 199 |
+
" def extract_date_key(filename):\n",
|
| 200 |
+
" \"\"\"\n",
|
| 201 |
+
" Extracts a datetime object from filenames containing YYYY-MM.\n",
|
| 202 |
+
" Handles .jsonl and .jsonl.gz.\n",
|
| 203 |
+
" \"\"\"\n",
|
| 204 |
+
" base = filename.split('.')[0] \n",
|
| 205 |
+
" yyyy, mm = base.split('-') \n",
|
| 206 |
+
" return datetime(int(yyyy), int(mm), 1)\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" jsonl_files = [\n",
|
| 210 |
+
" f for f in all_items\n",
|
| 211 |
+
" if f.endswith(\".jsonl\") or f.endswith(\".jsonl.gz\")\n",
|
| 212 |
+
" ]\n",
|
| 213 |
+
"\n",
|
| 214 |
+
" # Sort newest to oldest\n",
|
| 215 |
+
" jsonl_files_sorted = sorted(\n",
|
| 216 |
+
" jsonl_files,\n",
|
| 217 |
+
" key=extract_date_key,\n",
|
| 218 |
+
" reverse=True\n",
|
| 219 |
+
" )\n",
|
| 220 |
+
" return [os.path.join(directory, f) for f in jsonl_files_sorted]\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
|
| 223 |
+
"\n",
|
| 224 |
+
" \n",
|
| 225 |
+
"print(f\"Found {len(all_jsonl_files)} files to load (including compressed).\")\n"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": null,
|
| 231 |
+
"id": "25ecca51-7901-448b-9d89-4ed0663b2bae",
|
| 232 |
+
"metadata": {
|
| 233 |
+
"tags": []
|
| 234 |
+
},
|
| 235 |
+
"outputs": [],
|
| 236 |
+
"source": [
|
| 237 |
+
"import gc\n",
|
| 238 |
+
"gc.collect()"
|
| 239 |
+
]
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"cell_type": "markdown",
|
| 243 |
+
"id": "1ff0c6b0-abcb-4424-be97-5c7bd8fb9af7",
|
| 244 |
+
"metadata": {},
|
| 245 |
+
"source": [
|
| 246 |
+
"## 2.1 Model definition"
|
| 247 |
+
]
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"cell_type": "markdown",
|
| 251 |
+
"id": "aaa9dc1b-1086-47d2-9b3b-20d954bda644",
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"source": [
|
| 254 |
+
"---\n",
|
| 255 |
+
"## 3. Model definition and training (NORMAL)"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "code",
|
| 260 |
+
"execution_count": null,
|
| 261 |
+
"id": "5f15e504-9431-4913-9016-4b0c6344a127",
|
| 262 |
+
"metadata": {
|
| 263 |
+
"tags": []
|
| 264 |
+
},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"import wandb\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"wandb.login() "
|
| 270 |
+
]
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"cell_type": "code",
|
| 274 |
+
"execution_count": null,
|
| 275 |
+
"id": "2f780f9d-730c-4540-ad9a-e0b60c87f147",
|
| 276 |
+
"metadata": {
|
| 277 |
+
"tags": []
|
| 278 |
+
},
|
| 279 |
+
"outputs": [],
|
| 280 |
+
"source": [
|
| 281 |
+
"%load_ext autoreload\n",
|
| 282 |
+
"%autoreload 2\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"from eurovoc import StreamingEurovocDataModule\n",
|
| 285 |
+
"from eurovoc import EurovocTagger\n",
|
| 286 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
| 287 |
+
"import pytorch_lightning as pl\n",
|
| 288 |
+
"import torch\n",
|
| 289 |
+
"from pytorch_lightning.callbacks import EarlyStopping\n",
|
| 290 |
+
"import gc\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"class MemoryMonitorCallback(pl.Callback):\n",
|
| 293 |
+
" def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):\n",
|
| 294 |
+
" # Log memory every 100 batches\n",
|
| 295 |
+
" if batch_idx % 100 == 0:\n",
|
| 296 |
+
" if torch.cuda.is_available():\n",
|
| 297 |
+
" for i in range(torch.cuda.device_count()):\n",
|
| 298 |
+
" allocated = torch.cuda.memory_allocated(i) / 1e9\n",
|
| 299 |
+
" reserved = torch.cuda.memory_reserved(i) / 1e9\n",
|
| 300 |
+
" trainer.logger.experiment.log({\n",
|
| 301 |
+
" f\"memory/gpu_{i}_allocated_gb\": allocated,\n",
|
| 302 |
+
" f\"memory/gpu_{i}_reserved_gb\": reserved,\n",
|
| 303 |
+
" \"batch_idx\": batch_idx\n",
|
| 304 |
+
" })\n",
|
| 305 |
+
" \n",
|
| 306 |
+
" def on_train_epoch_end(self, trainer, pl_module):\n",
|
| 307 |
+
" # Force cleanup at end of each epoch\n",
|
| 308 |
+
" gc.collect()\n",
|
| 309 |
+
" torch.cuda.empty_cache()\n",
|
| 310 |
+
" \n",
|
| 311 |
+
" def on_validation_epoch_end(self, trainer, pl_module):\n",
|
| 312 |
+
" # Force cleanup after validation\n",
|
| 313 |
+
" gc.collect()\n",
|
| 314 |
+
" torch.cuda.empty_cache()\n",
|
| 315 |
+
" \n",
|
| 316 |
+
" \n",
|
| 317 |
+
"early_stop = EarlyStopping(\n",
|
| 318 |
+
" monitor='val_loss',\n",
|
| 319 |
+
" patience=4,\n",
|
| 320 |
+
" mode='min'\n",
|
| 321 |
+
")\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"memory_monitor = MemoryMonitorCallback()\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"checkpoint_callback = ModelCheckpoint(\n",
|
| 326 |
+
" monitor='val_loss',\n",
|
| 327 |
+
" filename='EurovocTaggerFP32-{epoch:02d}-{val_loss:.2f}',\n",
|
| 328 |
+
" mode='min',\n",
|
| 329 |
+
")\n"
|
| 330 |
+
]
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"cell_type": "code",
|
| 334 |
+
"execution_count": null,
|
| 335 |
+
"id": "a069d202-2e61-4148-baeb-20fbd9b7bf7b",
|
| 336 |
+
"metadata": {
|
| 337 |
+
"tags": []
|
| 338 |
+
},
|
| 339 |
+
"outputs": [],
|
| 340 |
+
"source": [
|
| 341 |
+
"from pytorch_lightning.loggers import WandbLogger\n",
|
| 342 |
+
"wandb_logger = WandbLogger(\n",
|
| 343 |
+
" project=\"EUROVOC\",\n",
|
| 344 |
+
" name=\"EUROVOC-FP32\",\n",
|
| 345 |
+
" log_model=True, \n",
|
| 346 |
+
" save_dir=\"../logs\"\n",
|
| 347 |
+
")\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"BATCH_SIZE=58\n",
|
| 352 |
+
"\n",
|
| 353 |
+
"BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
|
| 354 |
+
"all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"dataloader = StreamingEurovocDataModule(BERT_MODEL_NAME, all_jsonl_files, mlb, batch_size=BATCH_SIZE)\n",
|
| 357 |
+
"dataloader.setup()\n",
|
| 358 |
+
"\n",
|
| 359 |
+
"N_EPOCHS = 30\n",
|
| 360 |
+
"LR = 5e-05\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"model = EurovocTagger(BERT_MODEL_NAME, len(mlb.classes_), lr=LR)\n",
|
| 363 |
+
"\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"wandb_logger.experiment.config.update({\n",
|
| 366 |
+
" \"bert_model\": BERT_MODEL_NAME,\n",
|
| 367 |
+
" \"batch_size\": BATCH_SIZE,\n",
|
| 368 |
+
" \"learning_rate\": LR,\n",
|
| 369 |
+
" \"max_epochs\": N_EPOCHS,\n",
|
| 370 |
+
" \"num_workers\": 3,\n",
|
| 371 |
+
" \"num_gpus\": 4,\n",
|
| 372 |
+
" \"precision\": \"32\",\n",
|
| 373 |
+
" \"num_classes\": len(mlb.classes_)\n",
|
| 374 |
+
"})\n",
|
| 375 |
+
"\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"if torch.cuda.is_available():\n",
|
| 379 |
+
" torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 380 |
+
" torch.backends.cudnn.allow_tf32 = True\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"torch.set_float32_matmul_precision('medium')\n",
|
| 383 |
+
"\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"trainer = pl.Trainer(max_epochs=N_EPOCHS ,\n",
|
| 386 |
+
" accelerator=\"gpu\",\n",
|
| 387 |
+
" devices=4, \n",
|
| 388 |
+
" callbacks=[checkpoint_callback, early_stop, memory_monitor],\n",
|
| 389 |
+
" strategy=\"ddp_notebook\",\n",
|
| 390 |
+
" logger=wandb_logger,\n",
|
| 391 |
+
" log_every_n_steps=50,\n",
|
| 392 |
+
" )\n",
|
| 393 |
+
"\n",
|
| 394 |
+
"trainer.fit(model, dataloader)"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"cell_type": "markdown",
|
| 399 |
+
"id": "29d9203e-c02b-4a76-a57c-d2e0246722c7",
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"source": [
|
| 402 |
+
"## Finetuning in BF16"
|
| 403 |
+
]
|
| 404 |
+
},
|
| 405 |
+
{
|
| 406 |
+
"cell_type": "code",
|
| 407 |
+
"execution_count": null,
|
| 408 |
+
"id": "86734efb-0bcd-442f-976e-ea0bbdb393d6",
|
| 409 |
+
"metadata": {
|
| 410 |
+
"tags": []
|
| 411 |
+
},
|
| 412 |
+
"outputs": [],
|
| 413 |
+
"source": [
|
| 414 |
+
"import wandb\n",
|
| 415 |
+
"\n",
|
| 416 |
+
"wandb.login() "
|
| 417 |
+
]
|
| 418 |
+
},
|
| 419 |
+
{
|
| 420 |
+
"cell_type": "code",
|
| 421 |
+
"execution_count": null,
|
| 422 |
+
"id": "4b609efc-e3b7-4924-96c3-59a236f52ec6",
|
| 423 |
+
"metadata": {
|
| 424 |
+
"tags": []
|
| 425 |
+
},
|
| 426 |
+
"outputs": [],
|
| 427 |
+
"source": [
|
| 428 |
+
"from pytorch_lightning.loggers import WandbLogger\n",
|
| 429 |
+
"wandb_logger = WandbLogger(\n",
|
| 430 |
+
" project=\"EUROVOC\",\n",
|
| 431 |
+
" name=\"EUROVOC-BF16\",\n",
|
| 432 |
+
" log_model=True, \n",
|
| 433 |
+
" save_dir=\"../logs\"\n",
|
| 434 |
+
")\n"
|
| 435 |
+
]
|
| 436 |
+
},
|
| 437 |
+
{
|
| 438 |
+
"cell_type": "code",
|
| 439 |
+
"execution_count": null,
|
| 440 |
+
"id": "7cab5811-0ab9-48d5-8ec4-8d835bb0d3df",
|
| 441 |
+
"metadata": {},
|
| 442 |
+
"outputs": [],
|
| 443 |
+
"source": [
|
| 444 |
+
"#%%capture output\n",
|
| 445 |
+
"%load_ext autoreload\n",
|
| 446 |
+
"%autoreload 2\n",
|
| 447 |
+
"\n",
|
| 448 |
+
"from eurovoc import StreamingEurovocDataModule\n",
|
| 449 |
+
"from eurovoc import EurovocTaggerBCELogit, EurovocTagger\n",
|
| 450 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
| 451 |
+
"import pytorch_lightning as pl\n",
|
| 452 |
+
"import torch\n",
|
| 453 |
+
"from pytorch_lightning.callbacks import EarlyStopping\n",
|
| 454 |
+
"import gc\n",
|
| 455 |
+
"\n",
|
| 456 |
+
"class MemoryMonitorCallback(pl.Callback):\n",
|
| 457 |
+
" def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):\n",
|
| 458 |
+
" # Log memory every 100 batches\n",
|
| 459 |
+
" if batch_idx % 100 == 0:\n",
|
| 460 |
+
" if torch.cuda.is_available():\n",
|
| 461 |
+
" for i in range(torch.cuda.device_count()):\n",
|
| 462 |
+
" allocated = torch.cuda.memory_allocated(i) / 1e9\n",
|
| 463 |
+
" reserved = torch.cuda.memory_reserved(i) / 1e9\n",
|
| 464 |
+
" trainer.logger.experiment.log({\n",
|
| 465 |
+
" f\"memory/gpu_{i}_allocated_gb\": allocated,\n",
|
| 466 |
+
" f\"memory/gpu_{i}_reserved_gb\": reserved\n",
|
| 467 |
+
" })\n",
|
| 468 |
+
" \n",
|
| 469 |
+
" def on_train_epoch_end(self, trainer, pl_module):\n",
|
| 470 |
+
" # Force cleanup at end of each epoch\n",
|
| 471 |
+
" gc.collect()\n",
|
| 472 |
+
" torch.cuda.empty_cache()\n",
|
| 473 |
+
" \n",
|
| 474 |
+
" def on_validation_epoch_end(self, trainer, pl_module):\n",
|
| 475 |
+
" # Force cleanup after validation\n",
|
| 476 |
+
" gc.collect()\n",
|
| 477 |
+
" torch.cuda.empty_cache()\n",
|
| 478 |
+
"\n",
|
| 479 |
+
" \n",
|
| 480 |
+
" \n",
|
| 481 |
+
"\n",
|
| 482 |
+
"early_stop = EarlyStopping(\n",
|
| 483 |
+
" monitor='val_loss',\n",
|
| 484 |
+
" patience=4,\n",
|
| 485 |
+
" mode='min'\n",
|
| 486 |
+
")\n",
|
| 487 |
+
"\n",
|
| 488 |
+
"memory_monitor = MemoryMonitorCallback()\n",
|
| 489 |
+
"\n",
|
| 490 |
+
"checkpoint_callback = ModelCheckpoint(\n",
|
| 491 |
+
" monitor='val_loss',\n",
|
| 492 |
+
" filename='EurovocTaggerA-{epoch:02d}-{val_loss:.2f}',\n",
|
| 493 |
+
" mode='min',\n",
|
| 494 |
+
")\n",
|
| 495 |
+
"\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"if torch.cuda.is_available():\n",
|
| 498 |
+
" torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 499 |
+
" torch.backends.cudnn.allow_tf32 = True\n",
|
| 500 |
+
"\n",
|
| 501 |
+
"torch.set_float32_matmul_precision('medium')\n",
|
| 502 |
+
"\n",
|
| 503 |
+
"\n",
|
| 504 |
+
"FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
|
| 505 |
+
"\n",
|
| 506 |
+
"BATCH_SIZE=74\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
|
| 509 |
+
"all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"dataloader = StreamingEurovocDataModule(BERT_MODEL_NAME, all_jsonl_files, mlb, batch_size=BATCH_SIZE)\n",
|
| 512 |
+
"dataloader.setup()\n",
|
| 513 |
+
"\n",
|
| 514 |
+
"\n",
|
| 515 |
+
"\n",
|
| 516 |
+
"N_EPOCHS = 30\n",
|
| 517 |
+
"LR = 5e-05\n",
|
| 518 |
+
"\n",
|
| 519 |
+
"BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
|
| 520 |
+
"\n",
|
| 521 |
+
"\n",
|
| 522 |
+
"model = EurovocTaggerBCELogit(BERT_MODEL_NAME, len(mlb.classes_), lr=LR)\n",
|
| 523 |
+
"\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"wandb_logger.experiment.config.update({\n",
|
| 527 |
+
" \"bert_model\": BERT_MODEL_NAME,\n",
|
| 528 |
+
" \"batch_size\": BATCH_SIZE,\n",
|
| 529 |
+
" \"learning_rate\": LR,\n",
|
| 530 |
+
" \"max_epochs\": N_EPOCHS,\n",
|
| 531 |
+
" \"num_workers\": 3,\n",
|
| 532 |
+
" \"num_gpus\": 4,\n",
|
| 533 |
+
" \"precision\": \"16\",\n",
|
| 534 |
+
" \"num_classes\": len(mlb.classes_)\n",
|
| 535 |
+
"})\n",
|
| 536 |
+
"\n",
|
| 537 |
+
"trainer = pl.Trainer(max_epochs=N_EPOCHS ,\n",
|
| 538 |
+
" accelerator=\"gpu\",\n",
|
| 539 |
+
" devices=4, \n",
|
| 540 |
+
" callbacks=[checkpoint_callback, early_stop, memory_monitor],\n",
|
| 541 |
+
" strategy=\"ddp_notebook\",\n",
|
| 542 |
+
" accumulate_grad_batches=1,\n",
|
| 543 |
+
" precision=16,\n",
|
| 544 |
+
" logger=wandb_logger,\n",
|
| 545 |
+
" log_every_n_steps=50,\n",
|
| 546 |
+
" )\n"
|
| 547 |
+
]
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"cell_type": "code",
|
| 551 |
+
"execution_count": null,
|
| 552 |
+
"id": "6af63c61-5ecd-4207-8aaa-2a0dbd008df2",
|
| 553 |
+
"metadata": {
|
| 554 |
+
"tags": []
|
| 555 |
+
},
|
| 556 |
+
"outputs": [],
|
| 557 |
+
"source": [
|
| 558 |
+
"\n",
|
| 559 |
+
"trainer.fit(model, dataloader)"
|
| 560 |
+
]
|
| 561 |
+
},
|
| 562 |
+
{
|
| 563 |
+
"cell_type": "markdown",
|
| 564 |
+
"id": "2e2f69c2-9d89-4468-8198-b15da16e9403",
|
| 565 |
+
"metadata": {},
|
| 566 |
+
"source": [
|
| 567 |
+
"## 4. MODEL definition and training (LORA) (STILL USES OLD EUROVOC TAGGER)"
|
| 568 |
+
]
|
| 569 |
+
},
|
| 570 |
+
{
|
| 571 |
+
"cell_type": "code",
|
| 572 |
+
"execution_count": null,
|
| 573 |
+
"id": "c28014b2-5ccb-45d9-8025-d05d31d77a08",
|
| 574 |
+
"metadata": {
|
| 575 |
+
"tags": []
|
| 576 |
+
},
|
| 577 |
+
"outputs": [],
|
| 578 |
+
"source": [
|
| 579 |
+
"from eurovoc import StreamingEurovocDataModule\n",
|
| 580 |
+
"from eurovoc import EurovocTaggerLoRA\n",
|
| 581 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
| 582 |
+
"import pytorch_lightning as pl\n",
|
| 583 |
+
"import torch\n",
|
| 584 |
+
"\n",
|
| 585 |
+
"\n",
|
| 586 |
+
"torch.set_float32_matmul_precision('medium')\n",
|
| 587 |
+
"\n",
|
| 588 |
+
"FIXED_DIR = \"../eurovoc_data/files_fixed\"\n",
|
| 589 |
+
"\n",
|
| 590 |
+
"BATCH_SIZE=94\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
|
| 593 |
+
"\n",
|
| 594 |
+
"\n",
|
| 595 |
+
"all_jsonl_files = list_all_json_files(FIXED_DIR)\n",
|
| 596 |
+
"\n",
|
| 597 |
+
"dataloader = StreamingEurovocDataModule(BERT_MODEL_NAME, all_jsonl_files, mlb, batch_size=BATCH_SIZE)\n",
|
| 598 |
+
"dataloader.setup()\n",
|
| 599 |
+
"\n",
|
| 600 |
+
"\n",
|
| 601 |
+
"N_EPOCHS = 30\n",
|
| 602 |
+
"LR = 5e-05\n",
|
| 603 |
+
"\n",
|
| 604 |
+
"# LoRA hyperparameters\n",
|
| 605 |
+
"# Rank of LoRA matrices\n",
|
| 606 |
+
"LORA_R = 16 \n",
|
| 607 |
+
"# Scaling factor (usually 2 * r)\n",
|
| 608 |
+
"LORA_ALPHA = 32 \n",
|
| 609 |
+
"LORA_DROPOUT = 0.1\n",
|
| 610 |
+
"\n",
|
| 611 |
+
"# Hierarchical classifier parameter (for 6800 labels)\n",
|
| 612 |
+
"# Bottleneck size: 768 → 256 → 6800\n",
|
| 613 |
+
"N_INTERMEDIATE = 256 \n",
|
| 614 |
+
"\n",
|
| 615 |
+
"\n",
|
| 616 |
+
"# Create LoRA model with hierarchical classifier\n",
|
| 617 |
+
"model = EurovocTaggerLoRA(\n",
|
| 618 |
+
" BERT_MODEL_NAME, \n",
|
| 619 |
+
" # 6800+ labels\n",
|
| 620 |
+
" len(mlb.classes_),\n",
|
| 621 |
+
" # Bottleneck size\n",
|
| 622 |
+
" n_intermediate=N_INTERMEDIATE, \n",
|
| 623 |
+
" lr=LR,\n",
|
| 624 |
+
" lora_r=LORA_R,\n",
|
| 625 |
+
" lora_alpha=LORA_ALPHA,\n",
|
| 626 |
+
" lora_dropout=LORA_DROPOUT\n",
|
| 627 |
+
")\n",
|
| 628 |
+
"\n",
|
| 629 |
+
"checkpoint_callback = ModelCheckpoint(\n",
|
| 630 |
+
" monitor='val_loss',\n",
|
| 631 |
+
" filename='EurovocTaggerLoRA-6800-{epoch:02d}-{val_loss:.2f}',\n",
|
| 632 |
+
" mode='min',\n",
|
| 633 |
+
")\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"trainer = pl.Trainer(\n",
|
| 636 |
+
" max_epochs=N_EPOCHS, \n",
|
| 637 |
+
" accelerator=\"gpu\", \n",
|
| 638 |
+
" devices=4, \n",
|
| 639 |
+
" callbacks=[checkpoint_callback],\n",
|
| 640 |
+
" strategy=\"ddp_notebook\",\n",
|
| 641 |
+
" precision=16\n",
|
| 642 |
+
")\n",
|
| 643 |
+
"\n",
|
| 644 |
+
"print(f\"Starting LoRA training with {len(mlb.classes_)} labels...\")\n",
|
| 645 |
+
"print(f\"Classifier architecture: 768 → {N_INTERMEDIATE} → {len(mlb.classes_)}\")\n",
|
| 646 |
+
"trainer.fit(model, dataloader)\n",
|
| 647 |
+
"\n",
|
| 648 |
+
"\n",
|
| 649 |
+
"\n"
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
{
|
| 653 |
+
"cell_type": "code",
|
| 654 |
+
"execution_count": null,
|
| 655 |
+
"id": "694ad3d7-794d-4a5f-a7be-8b47e872418d",
|
| 656 |
+
"metadata": {},
|
| 657 |
+
"outputs": [],
|
| 658 |
+
"source": [
|
| 659 |
+
"# Save only the LoRA adapter \n",
|
| 660 |
+
"model.save_lora_adapter('./eurovoc_lora_adapter')\n",
|
| 661 |
+
"\n",
|
| 662 |
+
"print(\"LoRA adapter saved to ./eurovoc_lora_adapter\")\n"
|
| 663 |
+
]
|
| 664 |
+
}
|
| 665 |
+
],
|
| 666 |
+
"metadata": {
|
| 667 |
+
"kernelspec": {
|
| 668 |
+
"display_name": "eurovoc_training_env",
|
| 669 |
+
"language": "python",
|
| 670 |
+
"name": "eurovoc_training_env"
|
| 671 |
+
},
|
| 672 |
+
"language_info": {
|
| 673 |
+
"codemirror_mode": {
|
| 674 |
+
"name": "ipython",
|
| 675 |
+
"version": 3
|
| 676 |
+
},
|
| 677 |
+
"file_extension": ".py",
|
| 678 |
+
"mimetype": "text/x-python",
|
| 679 |
+
"name": "python",
|
| 680 |
+
"nbconvert_exporter": "python",
|
| 681 |
+
"pygments_lexer": "ipython3",
|
| 682 |
+
"version": "3.10.12"
|
| 683 |
+
}
|
| 684 |
+
},
|
| 685 |
+
"nbformat": 4,
|
| 686 |
+
"nbformat_minor": 5
|
| 687 |
+
}
|