# NOTE: lines below are metadata scraped from the HuggingFace file page
# (uploader / commit message / commit hash), kept as comments so the file parses:
# fmnxl — Revert to 7-label: F1 0.72 (binary was 0.28) — commit 0774ad1 (verified)
"""
Train mmBERT for email zone classification on HuggingFace Spaces (A100).
This script runs inside the Docker container on HF Spaces.
"""
import json
import os
from collections import Counter
from pathlib import Path
import evaluate
import numpy as np
import torch
from datasets import Dataset
from huggingface_hub import login
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
EarlyStoppingCallback,
Trainer,
TrainingArguments,
)
# The 7 email-zone classes; index in this list is the numeric label id.
ZONE_LABELS = ['O', 'GREETING', 'AUTHOR', 'QUOTE_HEADER', 'QUOTE', 'SIGNATURE', 'FORWARD']
# Forward and reverse lookups between label names and integer ids.
LABEL2ID = {name: idx for idx, name in enumerate(ZONE_LABELS)}
ID2LABEL = dict(enumerate(ZONE_LABELS))
class WeightedTrainer(Trainer):
    """Trainer variant that applies per-class weights to the CE loss.

    Intended to counter class imbalance in the zone labels; pass a 1-D
    float tensor of per-class weights (one entry per ZONE_LABELS item),
    or None to fall back to the unweighted loss.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored on CPU; moved to the logits' device lazily at loss time.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        if self.class_weights is None:
            criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
        else:
            criterion = torch.nn.CrossEntropyLoss(
                weight=self.class_weights.to(logits.device),
                ignore_index=-100,
            )
        # Flatten (batch, seq, num_labels) -> (batch*seq, num_labels) for CE.
        loss = criterion(logits.view(-1, len(ZONE_LABELS)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
# Use relative paths that work inside container
OUTPUT_DIR = Path('./output')  # local checkpoint/artifact directory
MODEL_NAME = 'jhu-clsp/mmBERT-base'  # base encoder to fine-tune
HUB_MODEL_ID = 'fixthemusic/ftm-zone-classifier' # Use org namespace (token has access)
DATASET_ID = 'fixthemusic/ftm-zone-training-data'  # HF Hub dataset with train/validation jsonl splits
def load_training_data():
    """Fetch the train and validation splits from the HF Hub dataset.

    Returns:
        (train_split, validation_split) as `datasets.Dataset` objects.
    """
    from datasets import load_dataset

    split_files = {
        'train': 'train.jsonl',
        'validation': 'validation.jsonl',
    }
    dataset = load_dataset(DATASET_ID, data_files=split_files)
    return dataset['train'], dataset['validation']
def calculate_class_weights(train_dataset, labels=None):
    """Calculate class weights from actual data distribution (inverse frequency).

    Args:
        train_dataset: iterable of examples, each a dict with a 'char_labels'
            list of per-character label names.
        labels: ordered list of label names to weight; defaults to ZONE_LABELS
            (parameter added for reuse/testing — default behavior unchanged).

    Returns:
        torch.float32 tensor of len(labels) weights, normalized to mean 1.0.
    """
    if labels is None:
        labels = ZONE_LABELS
    print("Calculating class weights from data...")
    label_counts = Counter()
    # Count labels across all training examples
    for example in train_dataset:
        label_counts.update(example['char_labels'])
    total = sum(label_counts.values())
    print(f"Label distribution: {dict(label_counts)}")
    # Inverse-frequency weights: rarer classes get higher weights.
    weights = []
    for label in labels:
        count = label_counts.get(label, 1)  # floor of 1 so unseen labels don't divide by zero
        freq = count / total
        # Inverse frequency with smoothing; cap at 50x to avoid instability
        weight = min(50.0, 1.0 / (freq + 0.001))
        weights.append(weight)
    # Normalize so mean weight is 1.0
    mean_weight = sum(weights) / len(weights)
    weights = [w / mean_weight for w in weights]
    print(f"Calculated class weights: {dict(zip(labels, weights))}")
    return torch.tensor(weights, dtype=torch.float32)
def tokenize_and_align_labels(examples, tokenizer, max_length=768, label2id=None):
    """Tokenize a batch and project char-level labels onto tokens.

    Args:
        examples: batch dict with 'text' (list of strings) and 'char_labels'
            (list of per-character label-name lists, aligned with each text).
        tokenizer: HF fast tokenizer (must support return_offsets_mapping).
        max_length: truncation/padding length in tokens.
        label2id: label-name -> id mapping; defaults to module-level LABEL2ID
            (parameter added for reuse/testing — default behavior unchanged).

    Returns:
        The tokenizer output dict with an added 'labels' key (token-level
        label ids, -100 on special/padding tokens) and 'offset_mapping' removed.
    """
    if label2id is None:
        label2id = LABEL2ID
    tokenized = tokenizer(
        examples['text'],
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_offsets_mapping=True,
    )
    all_labels = []
    for i, offsets in enumerate(tokenized['offset_mapping']):
        char_labels = examples['char_labels'][i]
        token_labels = []
        for start, end in offsets:
            if start == end:
                # Zero-width offset = special/padding token; ignored by the loss.
                token_labels.append(-100)
            else:
                # Label the token by its first character; past-the-end falls back to 'O'.
                label = char_labels[start] if start < len(char_labels) else 'O'
                token_labels.append(label2id.get(label, 0))
        all_labels.append(token_labels)
    tokenized['labels'] = all_labels
    del tokenized['offset_mapping']
    return tokenized
def compute_metrics(eval_pred):
    """Compute seqeval precision/recall/F1/accuracy, skipping -100 positions.

    Args:
        eval_pred: (logits, label_ids) pair as provided by the HF Trainer.

    Returns:
        Dict of overall precision, recall, f1 and accuracy.
    """
    metric = evaluate.load('seqeval')
    logits, gold = eval_pred
    predicted = np.argmax(logits, axis=2)
    references, hypotheses = [], []
    for pred_row, gold_row in zip(predicted, gold):
        ref, hyp = [], []
        for p, g in zip(pred_row, gold_row):
            if g == -100:
                continue  # special/padding token — excluded from scoring
            ref.append(ID2LABEL[g])
            hyp.append(ID2LABEL[p])
        references.append(ref)
        hypotheses.append(hyp)
    scores = metric.compute(predictions=hypotheses, references=references)
    return {
        'precision': scores['overall_precision'],
        'recall': scores['overall_recall'],
        'f1': scores['overall_f1'],
        'accuracy': scores['overall_accuracy'],
    }
def main():
    """End-to-end fine-tune of mmBERT for 7-class email-zone token classification,
    then push the best checkpoint to the HF Hub."""
    # Login to HuggingFace
    hf_token = os.environ.get('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
        print("Logged into HuggingFace Hub")
    else:
        print("WARNING: No HF_TOKEN found - model upload will fail!")
    print(f"Device: {torch.cuda.get_device_name() if torch.cuda.is_available() else 'CPU'}")
    print('Loading data...')
    train_dataset, val_dataset = load_training_data()
    print(f'Train: {len(train_dataset)}, Validation: {len(val_dataset)}')
    # Calculate class weights from actual data
    # NOTE(review): computed but unused below — the weighted trainer regressed F1
    # (see comment before Trainer); kept for the distribution printout.
    class_weights = calculate_class_weights(train_dataset)
    print(f'Loading model: {MODEL_NAME}...')
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(ZONE_LABELS),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )
    print('Tokenizing...')
    # Map to token-level features; drop raw columns the model doesn't consume.
    train_tokenized = train_dataset.map(
        lambda x: tokenize_and_align_labels(x, tokenizer),
        batched=True,
        remove_columns=['id', 'text', 'char_labels']
    )
    val_tokenized = val_dataset.map(
        lambda x: tokenize_and_align_labels(x, tokenizer),
        batched=True,
        remove_columns=['id', 'text', 'char_labels']
    )
    data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
    training_args = TrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Hub configuration
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_token=hf_token,
        # Training configuration - Phase 2b: Phase 1 + gentle label smoothing
        num_train_epochs=3,  # Back to 3 (was 2)
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=8,  # Effective batch size = 64
        warmup_ratio=0.1,  # Back to 0.1 (was 0.15)
        weight_decay=0.01,
        logging_steps=50,
        eval_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,  # restore best-f1 checkpoint after training
        metric_for_best_model='f1',
        greater_is_better=True,
        # Use bf16 on A100
        bf16=True,
        learning_rate=2e-5,  # Back to 2e-5 (was 1e-5)
        label_smoothing_factor=0.05,  # Gentle smoothing (was 0.1)
        report_to='none',
        # Save disk space
        save_total_limit=2,
    )
    # Use standard Trainer (WeightedTrainer caused F1 regression)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=val_tokenized,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    print('Starting training with standard loss and early stopping...')
    trainer.train()
    print('Evaluating...')
    metrics = trainer.evaluate()
    print(f'Final metrics: {metrics}')
    print('Pushing to Hub...')
    trainer.push_to_hub()
    # Save label mapping locally too
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_DIR / 'labels.json', 'w') as f:
        json.dump({'labels': ZONE_LABELS, 'label2id': LABEL2ID, 'id2label': ID2LABEL}, f)
    print('Done!')
if __name__ == '__main__':
    main()