| """ |
| Train mmBERT for email zone classification on HuggingFace Spaces (A100). |
| |
| This script runs inside the Docker container on HF Spaces. |
| """ |
|
|
| import json |
| import os |
| from collections import Counter |
| from pathlib import Path |
|
|
| import evaluate |
| import numpy as np |
| import torch |
| from datasets import Dataset |
| from huggingface_hub import login |
| from transformers import ( |
| AutoModelForTokenClassification, |
| AutoTokenizer, |
| DataCollatorForTokenClassification, |
| EarlyStoppingCallback, |
| Trainer, |
| TrainingArguments, |
| ) |
|
|
| |
# Zone taxonomy for email token classification; list position == model label id.
ZONE_LABELS = ['O', 'GREETING', 'AUTHOR', 'QUOTE_HEADER', 'QUOTE', 'SIGNATURE', 'FORWARD']
# Forward and reverse lookup tables derived from the canonical ordering above.
LABEL2ID = {name: idx for idx, name in enumerate(ZONE_LABELS)}
ID2LABEL = dict(enumerate(ZONE_LABELS))
|
|
|
|
class WeightedTrainer(Trainer):
    """Trainer subclass that applies class-weighted cross-entropy.

    Intended to counter label imbalance in token classification: each
    class's contribution to the loss is scaled by a caller-supplied weight
    (typically inverse label frequency).
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        """
        Args:
            class_weights: Optional 1-D float tensor of per-class weights,
                indexed by label id. ``None`` means plain unweighted loss.
        """
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the (optionally class-weighted) token-classification loss.

        ``**kwargs`` absorbs extra arguments that newer ``transformers``
        releases pass to ``compute_loss`` (e.g. ``num_items_in_batch``).
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        weight = None
        if self.class_weights is not None:
            # Move weights to wherever the logits live (GPU/CPU) each step.
            weight = self.class_weights.to(logits.device)
        # ignore_index=-100 skips special/padding tokens labeled by the
        # tokenization step.
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=-100)

        # Derive the class count from the logits themselves instead of the
        # module-level ZONE_LABELS constant, so the trainer stays correct if
        # the label set changes or the class is reused elsewhere.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return (loss, outputs) if return_outputs else loss
|
|
|
|
| |
# Local directory for checkpoints and auxiliary artifacts (labels.json).
OUTPUT_DIR = Path('./output')
# Base encoder to fine-tune.
MODEL_NAME = 'jhu-clsp/mmBERT-base'
# Hub repo the trained model is pushed to.
HUB_MODEL_ID = 'fixthemusic/ftm-zone-classifier'
# Hub dataset repo holding train.jsonl / validation.jsonl.
DATASET_ID = 'fixthemusic/ftm-zone-training-data'
|
|
|
|
def load_training_data():
    """Fetch the train and validation splits from the Hub dataset.

    Returns:
        Tuple of (train_split, validation_split) as ``datasets.Dataset``.
    """
    from datasets import load_dataset

    split_files = {
        'train': 'train.jsonl',
        'validation': 'validation.jsonl',
    }
    dataset = load_dataset(DATASET_ID, data_files=split_files)
    return dataset['train'], dataset['validation']
|
|
|
|
def calculate_class_weights(train_dataset):
    """Derive per-class loss weights from the label distribution.

    Uses smoothed, capped inverse frequency, then normalizes so the mean
    weight is 1.0 (keeps overall loss magnitude comparable to unweighted).

    Args:
        train_dataset: Iterable of examples with a ``char_labels`` list.

    Returns:
        ``torch.FloatTensor`` of shape ``(len(ZONE_LABELS),)``.
    """
    print("Calculating class weights from data...")

    label_counts = Counter()
    for example in train_dataset:
        label_counts.update(example['char_labels'])

    total = sum(label_counts.values())
    print(f"Label distribution: {dict(label_counts)}")

    # Inverse frequency, smoothed by +0.001 and capped at 50x so that
    # near-absent classes cannot dominate the loss.
    raw_weights = [
        min(50.0, 1.0 / (label_counts.get(name, 1) / total + 0.001))
        for name in ZONE_LABELS
    ]

    # Rescale so the weights average to 1.0.
    mean_weight = sum(raw_weights) / len(raw_weights)
    weights = [w / mean_weight for w in raw_weights]

    print(f"Calculated class weights: {dict(zip(ZONE_LABELS, weights))}")
    return torch.tensor(weights, dtype=torch.float32)
|
|
|
|
def tokenize_and_align_labels(examples, tokenizer, max_length=768):
    """Tokenize a batch of texts and project char-level labels onto tokens.

    Each real token inherits the label of its first character; special and
    padding tokens (empty offset spans) receive -100 so the loss ignores them.

    Args:
        examples: Batch dict with ``text`` and ``char_labels`` lists.
        tokenizer: A fast tokenizer (offset mappings required).
        max_length: Truncation/padding length in tokens.

    Returns:
        The tokenizer output dict with an added ``labels`` field and the
        ``offset_mapping`` removed.
    """
    tokenized = tokenizer(
        examples['text'],
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_offsets_mapping=True,
    )

    batch_labels = []
    for offsets, char_labels in zip(tokenized['offset_mapping'], examples['char_labels']):
        row = []
        for start, end in offsets:
            if start == end:
                # Special/padding token: excluded from the loss.
                row.append(-100)
            else:
                # Char index may fall past the labeled span (truncation edge):
                # default to 'O' in that case; unknown labels map to id 0.
                name = char_labels[start] if start < len(char_labels) else 'O'
                row.append(LABEL2ID.get(name, 0))
        batch_labels.append(row)

    tokenized['labels'] = batch_labels
    del tokenized['offset_mapping']

    return tokenized
|
|
|
|
def compute_metrics(eval_pred):
    """Compute seqeval span metrics over non-ignored token positions.

    Args:
        eval_pred: ``(logits, labels)`` pair from the Trainer; positions with
            label -100 are dropped before scoring.

    Returns:
        Dict with overall precision, recall, f1 and accuracy.
    """
    seqeval = evaluate.load('seqeval')

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=2)

    true_labels = []
    pred_labels = []
    for pred_row, label_row in zip(predictions, labels):
        # Keep only positions that carry a real label.
        kept = [(p, l) for p, l in zip(pred_row, label_row) if l != -100]
        true_labels.append([ID2LABEL[l] for _, l in kept])
        pred_labels.append([ID2LABEL[p] for p, _ in kept])

    results = seqeval.compute(predictions=pred_labels, references=true_labels)

    return {
        'precision': results['overall_precision'],
        'recall': results['overall_recall'],
        'f1': results['overall_f1'],
        'accuracy': results['overall_accuracy'],
    }
|
|
|
|
def main():
    """End-to-end fine-tune: load data, train with weighted loss, push to Hub."""
    # Hub auth: needed for dataset access (if private) and for model upload.
    hf_token = os.environ.get('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
        print("Logged into HuggingFace Hub")
    else:
        print("WARNING: No HF_TOKEN found - model upload will fail!")

    print(f"Device: {torch.cuda.get_device_name() if torch.cuda.is_available() else 'CPU'}")

    print('Loading data...')
    train_dataset, val_dataset = load_training_data()
    print(f'Train: {len(train_dataset)}, Validation: {len(val_dataset)}')

    # Inverse-frequency weights to counter class imbalance.
    class_weights = calculate_class_weights(train_dataset)

    print(f'Loading model: {MODEL_NAME}...')
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(ZONE_LABELS),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )

    print('Tokenizing...')
    train_tokenized = train_dataset.map(
        lambda x: tokenize_and_align_labels(x, tokenizer),
        batched=True,
        remove_columns=['id', 'text', 'char_labels']
    )
    val_tokenized = val_dataset.map(
        lambda x: tokenize_and_align_labels(x, tokenizer),
        batched=True,
        remove_columns=['id', 'text', 'char_labels']
    )

    data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

    training_args = TrainingArguments(
        output_dir=str(OUTPUT_DIR),

        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_token=hf_token,

        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=8,
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_steps=50,
        eval_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='f1',
        greater_is_better=True,

        bf16=True,
        learning_rate=2e-5,
        # NOTE(review): label_smoothing_factor is bypassed by
        # WeightedTrainer.compute_loss (which builds its own CrossEntropyLoss);
        # kept for config parity — confirm whether smoothing is still wanted.
        label_smoothing_factor=0.05,
        report_to='none',

        save_total_limit=2,
    )

    # Fix: the class weights computed above were previously never used —
    # the script instantiated the plain Trainer. Wire in WeightedTrainer so
    # the weighted loss it was written for is actually applied.
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=val_tokenized,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
        class_weights=class_weights,
    )

    print('Starting training with class-weighted loss and early stopping...')
    trainer.train()

    print('Evaluating...')
    metrics = trainer.evaluate()
    print(f'Final metrics: {metrics}')

    print('Pushing to Hub...')
    trainer.push_to_hub()

    # Persist the label maps alongside checkpoints. NOTE: json stringifies the
    # int keys of ID2LABEL ("0", "1", ...); consumers must convert back.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_DIR / 'labels.json', 'w') as f:
        json.dump({'labels': ZONE_LABELS, 'label2id': LABEL2ID, 'id2label': ID2LABEL}, f)

    print('Done!')
|
|
|
|
# Script entry point (run directly inside the HF Spaces container).
if __name__ == '__main__':
    main()
|
|