# Data_augmentation/backend/train_classifier.py
# Last commit 69a2c97 by Jacek Dusza: "Initial commit: NLP Pipeline backend and React frontend"
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
import os
# Turn off Weights & Biases reporting and the tokenizers parallelism warning
# before any training work starts.
os.environ.update({
    "WANDB_DISABLED": "true",
    "TOKENIZERS_PARALLELISM": "false",
})

# Checkpoint to fine-tune and the CSV holding both original and synthetic samples.
MODEL_NAME = "allegro/herbert-base-cased"
DATA_FILE = "augmented_dataset.csv"

print("Starting the classifier evaluation on HerBERT model.")
# 1. DATA LOADING AND PREPARATION
df = pd.read_csv(DATA_FILE)

# Separate original and synthetic datasets.
df_orig = df[df['is_synthetic'] == False].copy()  # noqa: E712
df_aug = df[df['is_synthetic'] == True].copy()  # noqa: E712

# Map text labels to numerical indices. Sorting makes the mapping independent of
# the row order in the CSV, so indices are reproducible between runs.
label_mapping = {label: idx for idx, label in enumerate(sorted(df_orig['label'].unique()))}
df_orig['label_idx'] = df_orig['label'].map(label_mapping)
df_aug['label_idx'] = df_aug['label'].map(label_mapping)

# Synthetic rows whose label never occurs in the original data map to NaN. Left in
# place they turn the label column into floats (which propagates through the concat
# below and breaks single-label classification), so drop them explicitly and
# restore an integer dtype.
unknown_label = df_aug['label_idx'].isna()
if unknown_label.any():
    print(f"Warning: dropping {int(unknown_label.sum())} synthetic samples with labels not present in the original data.")
    df_aug = df_aug[~unknown_label].copy()
df_aug['label_idx'] = df_aug['label_idx'].astype('int64')

# Split original data into training (80%) and testing (20%) sets
train_orig, test_data = train_test_split(df_orig, test_size=0.2, random_state=42, stratify=df_orig['label_idx'])

# DATA LEAKAGE PROTECTION
# Ensure synthetic samples are only used if their original source is in the training set.
# The source id is taken as the part of the synthetic id before the first "_"
# (assumes ids like "<original_id>_<suffix>" and original ids without "_" — TODO confirm).
# A set gives O(1) membership tests instead of scanning a list for every synthetic row.
train_orig_ids = set(train_orig['id'].astype(str))
source_ids = df_aug['id'].astype(str).str.split('_', n=1).str[0]
df_aug_filtered = df_aug[source_ids.isin(train_orig_ids)]

# Combine original training data with valid augmented data
train_augmented = pd.concat([train_orig, df_aug_filtered], ignore_index=True)

print("\nExperiment structure:")
print(f" - Test dataset (immutable): {len(test_data)} samples")
print(f" - Training BASELINE: {len(train_orig)} samples")
print(f" - Training AUGMENTED: {len(train_augmented)} samples")
# 2. TOKENIZATION
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_data(data_df, max_length=128):
    """Convert a DataFrame into a tokenized Hugging Face ``Dataset``.

    Args:
        data_df: DataFrame with a ``text`` column and a ``label_idx`` column of
            class indices.
        max_length: length every example is truncated / padded to.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer outputs plus the labels,
        renamed to ``labels`` (the column name the Trainer reads).

    Raises:
        ValueError: if ``label_idx`` contains missing values (they cannot be
            cast to integers).
    """
    subset = data_df[['text', 'label_idx']].copy()
    # Integer labels are required for single-label classification; fail fast
    # instead of silently training on float targets.
    subset['label_idx'] = subset['label_idx'].astype('int64')
    # DataFrames coming out of train_test_split carry a shuffled, non-default index;
    # without preserve_index=False it would be stored as an extra column.
    dataset = Dataset.from_pandas(subset, preserve_index=False)
    return dataset.map(
        lambda batch: tokenizer(batch['text'], truncation=True, padding='max_length', max_length=max_length),
        batched=True,
    ).rename_column("label_idx", "labels")
print("\nTokenizing datasets...")
# Tokenize the held-out test split and both training variants with identical settings.
test_dataset, train_orig_dataset, train_aug_dataset = (
    tokenize_data(frame) for frame in (test_data, train_orig, train_augmented)
)
# 3. METRICS EVALUATION
def compute_metrics(pred):
    """Score a Trainer evaluation pass with accuracy and macro-averaged F1.

    Args:
        pred: evaluation output exposing ``predictions`` (per-class scores, class
            axis last) and ``label_ids`` (gold class indices).

    Returns:
        Dict with keys ``accuracy`` and ``f1_macro``; the Trainer reports them
        with an ``eval_`` prefix.
    """
    y_true = pred.label_ids
    # Predicted class = index of the highest score along the last axis.
    y_pred = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_macro': f1_score(y_true, y_pred, average='macro'),
    }
# 4. TRAINING ENGINE
def train_and_evaluate(train_ds, test_ds, output_dir, seed=42, num_train_epochs=3, batch_size=16):
    """Fine-tune a fresh copy of MODEL_NAME on ``train_ds`` and score it on ``test_ds``.

    Args:
        train_ds: tokenized training dataset.
        test_ds: tokenized evaluation dataset; evaluated after every epoch.
        output_dir: output directory passed to the Trainer (no checkpoints are saved).
        seed: random seed applied *before* the model is built, so the randomly
            initialized classification head is identical across runs.
        num_train_epochs: number of training epochs.
        batch_size: per-device batch size for both training and evaluation.

    Returns:
        The metrics dict from ``Trainer.evaluate`` (keys prefixed with ``eval_``).
    """
    from transformers import set_seed

    # The Trainer only seeds the RNGs after the model already exists, so without
    # this the classifier head of each run would start from a different random
    # state, confounding the baseline-vs-augmented comparison.
    set_seed(seed)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(label_mapping))
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",
        save_strategy="no",
        logging_dir='./logs',
        report_to="none",
        seed=seed,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    # Metrics of the final (last-epoch) model; no best-checkpoint selection happens.
    return trainer.evaluate()
# 5. EXPERIMENT EXECUTION
def _print_step_banner(title):
    """Print ``title`` framed by two 50-character '=' rules, after a blank line."""
    rule = "=" * 50
    print("\n" + rule)
    print(title)
    print(rule)


_print_step_banner("STEP 1: Training the BASELINE model")
base_metrics = train_and_evaluate(train_orig_dataset, test_dataset, "./results_base")

_print_step_banner("STEP 2: Training the AUGMENTED model")
aug_metrics = train_and_evaluate(train_aug_dataset, test_dataset, "./results_aug")
# 6. RESULTS OUTPUT
def _as_percent(metrics, key):
    """Return ``metrics[key]`` multiplied by 100 for display as a percentage."""
    return metrics[key] * 100


def _print_result_row(name, base, aug, diff):
    """Print one aligned table row: baseline, augmented and the change in pp."""
    print(f"{name} | {base:10.2f}% | {aug:10.2f}% | {diff:+5.2f} pp.")


separator = "-" * 52
print("\n\n" + "FINAL EXPERIMENT RESULTS".center(50))
print(separator)
print("Metric | Baseline | Augmented | Change")
print(separator)

base_f1 = _as_percent(base_metrics, 'eval_f1_macro')
aug_f1 = _as_percent(aug_metrics, 'eval_f1_macro')
diff_f1 = aug_f1 - base_f1

base_acc = _as_percent(base_metrics, 'eval_accuracy')
aug_acc = _as_percent(aug_metrics, 'eval_accuracy')
diff_acc = aug_acc - base_acc

_print_result_row("Macro-F1", base_f1, aug_f1, diff_f1)
_print_result_row("Accuracy", base_acc, aug_acc, diff_acc)
print(separator)