| |
| """LoRA Fine-Tuning KcELECTRA for 7-Class Korean Text Emotion Recognition. |
| |
| Uses PEFT LoRA on beomi/KcELECTRA-base-v2022 for text-based emotion classification. |
| Filters out samples without text (e.g., RAVDESS English data). |
| |
| Usage: |
| python scripts/train_lora_kcelectra.py \ |
| --train-manifest data/lora_dataset/train_manifest.json \ |
| --val-manifest data/lora_dataset/val_manifest.json \ |
| --output-dir data/models/lora_kcelectra_7class \ |
| --epochs 10 --batch-size 16 --device cuda |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import random |
| import time |
| from collections import Counter |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader, Dataset |
|
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# Emotion classes in classifier-head order: a label's position in this list is
# its class index (used for LABEL2IDX and the model's id2label mapping).
LABELS_7CLASS = [
    "happiness",
    "anger",
    "disgust",
    "fear",
    "neutral",
    "sadness",
    "surprise",
]
LABEL2IDX = dict(zip(LABELS_7CLASS, range(len(LABELS_7CLASS))))
NUM_CLASSES = len(LABEL2IDX)
|
|
|
|
| |
| |
| |
|
|
class TextEmotionDataset(Dataset):
    """Text/label pairs loaded from a JSON manifest for emotion classification.

    The manifest must be a JSON list of objects. An entry is kept only if its
    ``"text"`` is a non-blank string and its ``"label"`` is one of
    ``LABELS_7CLASS``. Everything else is skipped, e.g. audio-only RAVDESS
    clips with no transcript, ``"text": null``, or a missing/unknown label.

    Items are dicts with ``input_ids`` and ``attention_mask`` (1-D tensors
    padded/truncated to ``max_length``) and a scalar ``labels`` class index.
    """

    def __init__(self, manifest_path: str, tokenizer, max_length: int = 128):
        with open(manifest_path, encoding="utf-8") as f:
            raw = json.load(f)

        self.samples = []
        no_text = 0
        unknown_label = 0
        for sample in raw:
            if not self._has_text(sample):
                no_text += 1
            elif sample.get("label") not in LABEL2IDX:
                # .get(): a missing "label" key is skipped instead of raising KeyError.
                unknown_label += 1
            else:
                self.samples.append(sample)
        self.tokenizer = tokenizer
        self.max_length = max_length

        logger.info(
            "TextEmotionDataset: %d samples (filtered from %d; skipped %d without text, "
            "%d with unknown label)",
            len(self.samples), len(raw), no_text, unknown_label,
        )

    @staticmethod
    def _has_text(sample) -> bool:
        """Return True if ``sample["text"]`` is a string with non-whitespace content.

        ``isinstance`` is used instead of ``sample.get("text", "").strip()`` so an
        explicit ``null`` (or any non-string value) is skipped rather than raising
        ``AttributeError``.
        """
        text = sample.get("text")
        return isinstance(text, str) and bool(text.strip())

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` into fixed-length tensors plus its class index."""
        sample = self.samples[idx]
        encoding = self.tokenizer(
            sample["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            # squeeze(0) drops the leading batch dimension of the single-text
            # encoding so the DataLoader can stack samples itself.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(LABEL2IDX[sample["label"]], dtype=torch.long),
        }
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def validate(model, val_loader, device, criterion):
    """Score ``model`` on every batch of ``val_loader`` without tracking gradients.

    Args:
        model: classifier whose forward output exposes ``.logits``; it is put
            into eval mode.
        val_loader: iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        device: device each batch tensor is moved to.
        criterion: loss applied as ``criterion(logits, labels)``.

    Returns:
        dict with, in order: ``loss`` (batch losses weighted by batch size and
        averaged over samples), ``accuracy``, ``macro_f1`` (plain mean of the
        per-class F1 over all ``NUM_CLASSES`` labels, with undefined scores
        counted as 0), ``per_class_f1`` keyed by label name, and
        ``confusion_matrix`` as a nested list. Scalars are rounded to 4 places.
    """
    model.eval()
    running_loss = 0
    targets, predictions = [], []

    for batch in val_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        batch_loss = criterion(logits, labels)

        running_loss += batch_loss.item() * labels.shape[0]
        targets += labels.cpu().tolist()
        predictions += logits.argmax(dim=-1).cpu().tolist()

    # Imported lazily so sklearn is only required once validation actually runs.
    from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

    class_ids = list(range(NUM_CLASSES))
    accuracy = accuracy_score(targets, predictions)
    f1_per_class = f1_score(
        targets, predictions, labels=class_ids, average=None, zero_division=0,
    )
    matrix = confusion_matrix(targets, predictions, labels=class_ids)

    sample_count = max(len(targets), 1)
    return {
        "loss": round(running_loss / sample_count, 4),
        "accuracy": round(accuracy, 4),
        "macro_f1": round(float(np.mean(f1_per_class)), 4),
        "per_class_f1": {
            name: round(float(score), 4)
            for name, score in zip(LABELS_7CLASS, f1_per_class)
        },
        "confusion_matrix": matrix.tolist(),
    }
|
|
|
|
def plot_confusion_matrix(cm, output_path: Path, epoch: int):
    """Save a row-normalized confusion-matrix heatmap for one epoch.

    Each cell shows the fraction of the row's true class predicted as the
    column class, with the raw count in parentheses underneath. Rows for
    classes with no samples are drawn as zeros instead of NaN.

    Plotting is best-effort. If matplotlib or seaborn is not installed, a
    warning is logged and nothing is written.

    Args:
        cm: square count matrix of size ``NUM_CLASSES``; rows are true labels,
            columns are predicted labels.
        output_path: destination image file.
        epoch: epoch number shown in the title.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend: no display needed to save files
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        logger.warning("matplotlib/seaborn not available; skipping plot %s", output_path)
        return

    cm = np.asarray(cm)
    row_sums = cm.sum(axis=1, keepdims=True)
    # A class absent from the val set has a zero row; dividing would give NaN
    # (plus a RuntimeWarning), so leave such rows at 0.
    cm_norm = np.divide(
        cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0,
    )

    fig, ax = plt.subplots(figsize=(9, 7))
    try:
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                    xticklabels=LABELS_7CLASS, yticklabels=LABELS_7CLASS, ax=ax)
        for i in range(NUM_CLASSES):
            for j in range(NUM_CLASSES):
                # Raw count, offset below the normalized annotation.
                ax.text(j + 0.5, i + 0.7, f"({cm[i][j]})",
                        ha="center", va="center", fontsize=6, color="gray")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(f"KcELECTRA LoRA — Epoch {epoch}")
        fig.tight_layout()
        fig.savefig(str(output_path), dpi=150)
    finally:
        # Always release the figure, even if plotting fails, to avoid leaking
        # memory across epochs.
        plt.close(fig)
|
|
|
|
| |
| |
| |
|
|
# Attention projections that receive LoRA adapters.
_LORA_TARGET_MODULES = ("query", "value")


def _build_lora_model(args, adapter_dir=None):
    """Create the KcELECTRA sequence classifier wrapped with LoRA adapters.

    If ``adapter_dir`` is given, the adapter saved there is loaded in
    trainable mode (used by ``--resume``). Otherwise, a fresh adapter is
    attached from the ``--lora-*`` arguments.
    """
    from transformers import AutoModelForSequenceClassification
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model

    base = AutoModelForSequenceClassification.from_pretrained(
        args.model_id,
        num_labels=NUM_CLASSES,
        id2label=dict(enumerate(LABELS_7CLASS)),
        label2id=LABEL2IDX,
    )
    if adapter_dir is not None:
        # is_trainable=True is required: PeftModel.from_pretrained loads the
        # adapter frozen (inference mode) by default, which would leave the
        # optimizer without the LoRA/classifier parameters to update.
        return PeftModel.from_pretrained(base, str(adapter_dir), is_trainable=True)

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=list(_LORA_TARGET_MODULES),
        task_type=TaskType.SEQ_CLS,
        bias="none",
    )
    return get_peft_model(base, lora_config)


def _inverse_frequency_weights(samples):
    """Return inverse-frequency class weights normalized to sum to ``NUM_CLASSES``.

    Classes absent from ``samples`` are counted as one sample so their weight
    stays finite.
    """
    counts = np.zeros(NUM_CLASSES)
    for s in samples:
        counts[LABEL2IDX[s["label"]]] += 1
    weights = 1.0 / np.maximum(counts, 1)
    return weights / weights.sum() * NUM_CLASSES


def _train_one_epoch(model, loader, optimizer, scheduler, scaler, criterion,
                     device, use_amp, accumulate_steps, epoch):
    """Run one training epoch with optional AMP and gradient accumulation.

    Gradients are accumulated over ``accumulate_steps`` batches. A trailing
    partial group at the end of the epoch also triggers an optimizer step, so
    there are ``ceil(len(loader) / accumulate_steps)`` optimizer and scheduler
    steps per epoch.

    Returns:
        ``(mean per-sample training loss, training accuracy)``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total_samples = 0
    num_batches = len(loader)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            batch_loss = criterion(logits, labels)
        # Scale down so the accumulated gradient approximates the group mean.
        scaler.scale(batch_loss / accumulate_steps).backward()

        if (batch_idx + 1) % accumulate_steps == 0 or (batch_idx + 1) == num_batches:
            scaler.unscale_(optimizer)  # clip on the real (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        total_loss += batch_loss.item() * labels.size(0)
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total_samples += labels.size(0)

        if (batch_idx + 1) % 50 == 0:
            logger.info(" Epoch %d [%d/%d] loss=%.4f lr=%.2e",
                        epoch, batch_idx + 1, num_batches,
                        batch_loss.item(),
                        optimizer.param_groups[0]["lr"])

    return total_loss / max(total_samples, 1), correct / max(total_samples, 1)


def train(args):
    """Fine-tune KcELECTRA with LoRA for 7-class emotion classification.

    Reads the train/val manifests named in ``args`` and trains with
    class-weighted cross-entropy, AdamW, a OneCycle LR schedule, CUDA AMP
    (when ``args.device == "cuda"``) and gradient accumulation. The model is
    validated after every epoch.

    Writes into ``args.output_dir``:
      * ``best_model/``: adapter + tokenizer with the best val macro-F1, plus a
        confusion-matrix PNG for each new best;
      * ``last_model/`` and ``last_checkpoint.json``: state used by ``--resume``;
      * ``training_log.json``: per-epoch metrics;
      * ``config.json``: run summary.

    Training stops early after ``args.patience`` consecutive epochs without a
    macro-F1 improvement.

    On ``--resume``, only the adapter weights, epoch counter, best F1,
    patience and log are restored. The LR schedule is fast-forwarded to the
    resumed epoch. AdamW moments and GradScaler state are not checkpointed and
    start fresh.
    """
    import warnings
    from transformers import AutoTokenizer

    device = args.device
    use_amp = (device == "cuda")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Resume bookkeeping (read before building the model so it is loaded once) ---
    training_log = []
    best_f1 = 0.0
    patience_counter = 0
    start_epoch = 1
    adapter_dir = None
    resumed = False
    if args.resume:
        ckpt_path = output_dir / "last_checkpoint.json"
        if ckpt_path.exists():
            with open(ckpt_path) as f:
                ckpt_info = json.load(f)
            start_epoch = ckpt_info["epoch"] + 1
            best_f1 = ckpt_info["best_f1"]
            patience_counter = ckpt_info["patience_counter"]
            training_log = ckpt_info.get("training_log", [])
            resumed = True
            if (output_dir / "last_model").exists():
                adapter_dir = output_dir / "last_model"
            else:
                logger.warning("Checkpoint found but last_model/ is missing; "
                               "continuing from epoch %d with a fresh adapter", start_epoch)
        else:
            logger.warning("--resume but no checkpoint found, starting fresh")

    # --- Model ---
    logger.info("Loading KcELECTRA: %s", args.model_id)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    model = _build_lora_model(args, adapter_dir)
    model.print_trainable_parameters()
    model = model.to(device)

    # --- Data ---
    train_ds = TextEmotionDataset(args.train_manifest, tokenizer, max_length=args.max_length)
    val_ds = TextEmotionDataset(args.val_manifest, tokenizer, max_length=args.max_length)
    logger.info("Train: %d, Val: %d", len(train_ds), len(val_ds))

    pin_memory = (device == "cuda")
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=pin_memory, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=2, pin_memory=pin_memory,
    )

    # --- Loss with inverse-frequency class weights ---
    class_weights = _inverse_frequency_weights(train_ds.samples)
    logger.info("Class weights: %s",
                {LABELS_7CLASS[i]: round(float(class_weights[i]), 3) for i in range(NUM_CLASSES)})
    criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32, device=device),
    )

    # --- Optimizer and LR schedule ---
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # Ceil, not floor: _train_one_epoch also steps on a trailing partial
    # accumulation group. With floor, OneCycleLR raised "Tried to step ...
    # times" whenever len(train_loader) was not a multiple of accumulate_steps.
    steps_per_epoch = max(-(-len(train_loader) // args.accumulate_steps), 1)
    total_steps = max(steps_per_epoch * args.epochs, 1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, total_steps=total_steps,
        pct_start=0.1, anneal_strategy="cos",
    )
    completed_steps = min((start_epoch - 1) * steps_per_epoch, total_steps)
    if completed_steps > 0:
        # Continue the original schedule instead of restarting warm-up. Stepping
        # the scheduler before any optimizer.step() emits a benign ordering
        # warning, which is suppressed here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            for _ in range(completed_steps):
                scheduler.step()
    if resumed:
        logger.info("Resumed from epoch %d (best_f1=%.4f)", start_epoch, best_f1)

    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # --- Epoch loop ---
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start = time.time()
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, scheduler, scaler, criterion,
            device, use_amp, args.accumulate_steps, epoch,
        )
        val_metrics = validate(model, val_loader, device, criterion)
        epoch_time = time.time() - epoch_start

        logger.info(
            "Epoch %d/%d (%.0fs): train_loss=%.4f train_acc=%.3f | "
            "val_loss=%.4f val_f1=%.3f val_acc=%.3f",
            epoch, args.epochs, epoch_time, train_loss, train_acc,
            val_metrics["loss"], val_metrics["macro_f1"], val_metrics["accuracy"],
        )
        logger.info(" Per-class F1: %s", val_metrics["per_class_f1"])

        training_log.append({
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            **{f"val_{k}": v for k, v in val_metrics.items() if k != "confusion_matrix"},
            "epoch_time_sec": round(epoch_time, 1),
        })

        if val_metrics["macro_f1"] > best_f1:
            best_f1 = val_metrics["macro_f1"]
            patience_counter = 0
            model.save_pretrained(str(output_dir / "best_model"))
            tokenizer.save_pretrained(str(output_dir / "best_model"))
            logger.info(" New best! macro_f1=%.4f saved to best_model/", best_f1)
            plot_confusion_matrix(np.array(val_metrics["confusion_matrix"]),
                                  output_dir / f"cm_epoch{epoch}.png", epoch)
        else:
            patience_counter += 1

        # Persist resume state every epoch.
        model.save_pretrained(str(output_dir / "last_model"))
        tokenizer.save_pretrained(str(output_dir / "last_model"))
        with open(output_dir / "last_checkpoint.json", "w") as f:
            json.dump({
                "epoch": epoch, "best_f1": best_f1,
                "patience_counter": patience_counter,
                "training_log": training_log,
            }, f, indent=2)

        with open(output_dir / "training_log.json", "w") as f:
            json.dump(training_log, f, indent=2)

        if patience_counter >= args.patience:
            logger.info("Early stopping at epoch %d (patience=%d)", epoch, args.patience)
            break

        if device == "cuda":
            torch.cuda.empty_cache()

    # --- Run summary ---
    config = {
        "base_model": args.model_id,
        "method": "PEFT LoRA",
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "target_modules": list(_LORA_TARGET_MODULES),
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
        "best_val_f1": best_f1,
        "training_args": vars(args),
    }
    with open(output_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    logger.info("Training complete. Best F1=%.4f", best_f1)
|
|
|
|
def main():
    """Parse command-line arguments, seed the RNGs, and launch training."""
    parser = argparse.ArgumentParser(description="LoRA fine-tune KcELECTRA 7-class")
    parser.add_argument("--train-manifest", required=True)
    parser.add_argument("--val-manifest", required=True)
    parser.add_argument("--output-dir", default="data/models/lora_kcelectra_7class")
    parser.add_argument("--model-id", default="beomi/KcELECTRA-base-v2022")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--accumulate-steps", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=3)
    parser.add_argument("--max-length", type=int, default=128)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of crashing deep inside train(),
    # e.g. ZeroDivisionError for --accumulate-steps 0.
    for option in ("epochs", "batch_size", "accumulate_steps", "max_length", "lora_r"):
        if getattr(args, option) < 1:
            parser.error(f"--{option.replace('_', '-')} must be a positive integer")
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda was requested but CUDA is not available")

    # Fixed seeds for reproducible shuffling and adapter initialization.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    train(args)
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|