#!/usr/bin/env python3
"""LoRA Fine-Tuning KcELECTRA for 7-Class Korean Text Emotion Recognition.

Uses PEFT LoRA on beomi/KcELECTRA-base-v2022 for text-based emotion classification.
Filters out samples without text (e.g., RAVDESS English data).

Usage:
    python scripts/train_lora_kcelectra.py \
        --train-manifest data/lora_dataset/train_manifest.json \
        --val-manifest data/lora_dataset/val_manifest.json \
        --output-dir data/models/lora_kcelectra_7class \
        --epochs 10 --batch-size 16 --device cuda
"""

from __future__ import annotations

import argparse
import json
import logging
import random
import time
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

LABELS_7CLASS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
LABEL2IDX = {l: i for i, l in enumerate(LABELS_7CLASS)}
NUM_CLASSES = len(LABELS_7CLASS)

# Module names that receive LoRA adapters (passed to LoraConfig and recorded in config.json).
LORA_TARGET_MODULES = ["query", "value"]


# ──────────────────────────────────────────────
# Dataset
# ──────────────────────────────────────────────

class TextEmotionDataset(Dataset):
    """Load text + label pairs from a JSON manifest.

    The manifest is a JSON list of dicts. Only entries whose ``"text"`` is a
    non-blank string and whose ``"label"`` is one of ``LABELS_7CLASS`` are kept;
    everything else is skipped (and counted in the log line).

    Args:
        manifest_path: Path to the JSON manifest file (read as UTF-8).
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding="max_length", return_tensors="pt")`` that
            returns a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Sequence length every sample is truncated/padded to.
    """

    def __init__(self, manifest_path: str, tokenizer, max_length: int = 128):
        with open(manifest_path, encoding="utf-8") as f:
            raw = json.load(f)

        self.samples = [s for s in raw if self._is_usable(s)]
        self.tokenizer = tokenizer
        self.max_length = max_length
        logger.info(
            "TextEmotionDataset: %d samples (filtered from %d, skipped %d "
            "without text or with unknown label)",
            len(self.samples), len(raw), len(raw) - len(self.samples),
        )

    @staticmethod
    def _is_usable(sample) -> bool:
        """Return True when *sample* has non-blank string text and a known label."""
        # .get() + isinstance guards: a manifest entry with "text": null or
        # without a "label" key is skipped instead of crashing the whole load.
        text = sample.get("text")
        label = sample.get("label")
        return (
            isinstance(text, str)
            and bool(text.strip())
            and isinstance(label, str)
            and label in LABEL2IDX
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and return input_ids, attention_mask and labels."""
        sample = self.samples[idx]
        encoding = self.tokenizer(
            sample["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            # The tokenizer output has a leading batch axis of size 1; drop it
            # so the DataLoader can stack samples into a batch itself.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(LABEL2IDX[sample["label"]], dtype=torch.long),
        }


# ──────────────────────────────────────────────
# Validation
# ──────────────────────────────────────────────

@torch.no_grad()
def validate(model, val_loader, device, criterion):
    """Evaluate *model* on *val_loader* and return a metrics dict.

    Returns:
        Dict with ``loss`` (sample-weighted mean of *criterion*), ``accuracy``,
        ``macro_f1`` (mean of the per-class F1 over all ``NUM_CLASSES`` labels),
        ``per_class_f1`` (label name -> F1) and ``confusion_matrix`` (nested
        list, rows = true label, columns = predicted label).
    """
    # Imported lazily so that importing this module does not require scikit-learn.
    from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

    model.eval()
    total_loss = 0.0
    y_true, y_pred = [], []

    for batch in val_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        loss = criterion(logits, labels)

        # Weight by batch size so the final mean is per-sample, not per-batch.
        total_loss += loss.item() * labels.size(0)
        y_true.extend(labels.cpu().tolist())
        y_pred.extend(logits.argmax(dim=-1).cpu().tolist())

    class_ids = list(range(NUM_CLASSES))
    acc = accuracy_score(y_true, y_pred)
    f1_per_class = f1_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
    macro_f1 = float(np.mean(f1_per_class))
    per_class = {LABELS_7CLASS[i]: round(float(f1_per_class[i]), 4) for i in class_ids}
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    return {
        "loss": round(total_loss / max(len(y_true), 1), 4),
        "accuracy": round(acc, 4),
        "macro_f1": round(macro_f1, 4),
        "per_class_f1": per_class,
        "confusion_matrix": cm.tolist(),
    }


def plot_confusion_matrix(cm, output_path: Path, epoch: int):
    """Save a row-normalised confusion-matrix heatmap to *output_path*.

    Each cell shows the row-normalised fraction with the raw count in
    parentheses underneath. Plotting is best-effort: if matplotlib or seaborn
    is not installed, a warning is logged and nothing is written.

    Args:
        cm: Square count matrix (rows = true label, columns = predicted label).
        output_path: Destination image path.
        epoch: Epoch number shown in the plot title.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError as exc:
        logger.warning("Skipping confusion-matrix plot (%s)", exc)
        return

    cm = np.asarray(cm)
    # A class with no validation samples has an all-zero row; clamp the divisor
    # so that row is rendered as 0.00 instead of NaN.
    row_totals = np.maximum(cm.sum(axis=1, keepdims=True), 1)
    cm_norm = cm / row_totals

    fig, ax = plt.subplots(figsize=(9, 7))
    sns.heatmap(
        cm_norm, annot=True, fmt=".2f", cmap="Blues",
        xticklabels=LABELS_7CLASS, yticklabels=LABELS_7CLASS, ax=ax,
    )
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            ax.text(j + 0.5, i + 0.7, f"({cm[i][j]})",
                    ha="center", va="center", fontsize=6, color="gray")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(f"KcELECTRA LoRA — Epoch {epoch}")
    plt.tight_layout()
    plt.savefig(str(output_path), dpi=150)
    plt.close(fig)


# ──────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────

def _load_base_model(model_id: str):
    """Load the sequence-classification base model with the 7-class label maps."""
    from transformers import AutoModelForSequenceClassification

    return AutoModelForSequenceClassification.from_pretrained(
        model_id,
        num_labels=NUM_CLASSES,
        id2label=dict(enumerate(LABELS_7CLASS)),
        label2id=LABEL2IDX,
    )


def _trainable_parameters(model):
    """Return the parameters of *model* that have ``requires_grad`` set."""
    return [p for p in model.parameters() if p.requires_grad]


def _build_optimizer_and_scheduler(model, args, total_steps: int):
    """Create AdamW over the trainable parameters plus a OneCycleLR schedule."""
    optimizer = torch.optim.AdamW(
        _trainable_parameters(model),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=max(total_steps, 1),
        pct_start=0.1,
        anneal_strategy="cos",
    )
    return optimizer, scheduler


def _optimizer_steps_per_epoch(num_batches: int, accumulate_steps: int) -> int:
    """Return how many optimizer/scheduler steps one epoch performs (at least 1).

    The training loop steps after every ``accumulate_steps`` batches AND after
    the final batch, so a trailing partial group counts as a step: this is a
    ceiling division, not a floor.
    """
    return max(-(-num_batches // accumulate_steps), 1)


def _inverse_frequency_weights(label_names) -> np.ndarray:
    """Return inverse-frequency class weights, scaled to sum to NUM_CLASSES."""
    counts = Counter(label_names)
    class_counts = np.array([counts.get(name, 0) for name in LABELS_7CLASS], dtype=np.float64)
    # Classes absent from the training set are treated as count 1 to avoid 1/0.
    weights = 1.0 / np.maximum(class_counts, 1)
    return weights / weights.sum() * NUM_CLASSES


def train(args):
    """Run LoRA fine-tuning with per-epoch validation, checkpointing and early stopping.

    Writes into ``args.output_dir``: ``best_model/`` (adapter + tokenizer with
    the best validation macro-F1), ``last_model/`` and ``last_checkpoint.json``
    (used by ``--resume``), ``training_log.json``, ``cm_epoch{N}.png`` for each
    new best epoch, and a final ``config.json``.

    Note: ``--resume`` restores the adapter weights and bookkeeping (epoch,
    best F1, patience, log) only; optimizer state is rebuilt and the one-cycle
    schedule restarts over the remaining epochs.

    Args:
        args: Parsed command-line namespace as produced by ``main()``.

    Raises:
        ValueError: If ``args.accumulate_steps`` is < 1 or the training set is
            empty after filtering.
    """
    if args.accumulate_steps < 1:
        raise ValueError(f"--accumulate-steps must be >= 1, got {args.accumulate_steps}")

    device = args.device
    use_amp = (device == "cuda")

    # Heavy third-party imports are kept local so the module can be imported
    # (e.g. for the dataset class) without transformers/peft installed.
    from transformers import AutoTokenizer
    from peft import LoraConfig, get_peft_model, TaskType

    logger.info("Loading KcELECTRA: %s", args.model_id)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    model = _load_base_model(args.model_id)

    # Apply LoRA
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=LORA_TARGET_MODULES,
        task_type=TaskType.SEQ_CLS,
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    model = model.to(device)

    # Datasets
    train_ds = TextEmotionDataset(args.train_manifest, tokenizer, max_length=args.max_length)
    val_ds = TextEmotionDataset(args.val_manifest, tokenizer, max_length=args.max_length)
    logger.info("Train: %d, Val: %d", len(train_ds), len(val_ds))
    if len(train_ds) == 0:
        raise ValueError(f"No usable training samples in {args.train_manifest}")

    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=(device == "cuda"), drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=2, pin_memory=(device == "cuda"),
    )
    if len(train_loader) == 0:
        logger.warning(
            "Training set (%d samples) is smaller than batch size %d; "
            "drop_last=True leaves no batches to train on",
            len(train_ds), args.batch_size,
        )

    # Class weights: inverse frequency only (no extra boost)
    class_weights = _inverse_frequency_weights(s["label"] for s in train_ds.samples)
    logger.info("Class weights: %s",
                {LABELS_7CLASS[i]: round(float(class_weights[i]), 3) for i in range(NUM_CLASSES)})
    criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32).to(device),
    )

    # Optimizer + scheduler. The step count must match the loop below exactly:
    # OneCycleLR raises if it is stepped more than total_steps times.
    steps_per_epoch = _optimizer_steps_per_epoch(len(train_loader), args.accumulate_steps)
    optimizer, scheduler = _build_optimizer_and_scheduler(
        model, args, steps_per_epoch * args.epochs,
    )
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Training state
    training_log = []
    best_f1 = 0.0
    patience_counter = 0
    start_epoch = 1

    # Resume
    if args.resume:
        ckpt_path = output_dir / "last_checkpoint.json"
        if ckpt_path.exists():
            with open(ckpt_path, encoding="utf-8") as f:
                ckpt_info = json.load(f)
            start_epoch = ckpt_info["epoch"] + 1
            best_f1 = ckpt_info["best_f1"]
            patience_counter = ckpt_info["patience_counter"]
            training_log = ckpt_info.get("training_log", [])

            # Load model weights
            model_path = output_dir / "last_model"
            if model_path.exists():
                from peft import PeftModel
                model = _load_base_model(args.model_id)
                # is_trainable=True is required: by default PEFT loads the
                # adapter frozen for inference, which would leave the rebuilt
                # optimizer with nothing to train.
                model = PeftModel.from_pretrained(model, str(model_path), is_trainable=True)
                model = model.to(device)

            # Rebuild optimizer and schedule over the epochs that are left.
            remaining_steps = steps_per_epoch * (args.epochs - start_epoch + 1)
            optimizer, scheduler = _build_optimizer_and_scheduler(model, args, remaining_steps)
            logger.info("Resumed from epoch %d (best_f1=%.4f)", start_epoch, best_f1)
        else:
            logger.warning("--resume but no checkpoint found, starting fresh")

    num_batches = len(train_loader)
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start = time.time()
        model.train()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.amp.autocast("cuda", enabled=use_amp):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                # Scale down so gradients accumulated over several batches average.
                loss = criterion(logits, labels) / args.accumulate_steps

            scaler.scale(loss).backward()

            # Step on every full accumulation group and on the final batch, so
            # a trailing partial group is not silently dropped.
            is_last_batch = (batch_idx + 1) == num_batches
            if (batch_idx + 1) % args.accumulate_steps == 0 or is_last_batch:
                # Unscale first so clipping operates on true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(_trainable_parameters(model), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Undo the accumulation scaling so the logged loss is the raw criterion value.
            batch_loss = loss.item() * args.accumulate_steps
            total_loss += batch_loss * labels.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            if (batch_idx + 1) % 50 == 0:
                logger.info("  Epoch %d [%d/%d] loss=%.4f lr=%.2e",
                            epoch, batch_idx + 1, num_batches,
                            batch_loss, optimizer.param_groups[0]["lr"])

        train_loss = total_loss / max(total_samples, 1)
        train_acc = correct / max(total_samples, 1)

        # Validate
        val_metrics = validate(model, val_loader, device, criterion)
        epoch_time = time.time() - epoch_start

        logger.info(
            "Epoch %d/%d (%.0fs): train_loss=%.4f train_acc=%.3f | "
            "val_loss=%.4f val_f1=%.3f val_acc=%.3f",
            epoch, args.epochs, epoch_time, train_loss, train_acc,
            val_metrics["loss"], val_metrics["macro_f1"], val_metrics["accuracy"],
        )
        logger.info("  Per-class F1: %s", val_metrics["per_class_f1"])

        training_log.append({
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            **{f"val_{k}": v for k, v in val_metrics.items() if k != "confusion_matrix"},
            "epoch_time_sec": round(epoch_time, 1),
        })

        # Best model
        if val_metrics["macro_f1"] > best_f1:
            best_f1 = val_metrics["macro_f1"]
            patience_counter = 0
            model.save_pretrained(str(output_dir / "best_model"))
            tokenizer.save_pretrained(str(output_dir / "best_model"))
            logger.info("  New best! macro_f1=%.4f saved to best_model/", best_f1)

            if "confusion_matrix" in val_metrics:
                cm = np.array(val_metrics["confusion_matrix"])
                plot_confusion_matrix(cm, output_dir / f"cm_epoch{epoch}.png", epoch)
        else:
            patience_counter += 1

        # Save last (for resume)
        model.save_pretrained(str(output_dir / "last_model"))
        tokenizer.save_pretrained(str(output_dir / "last_model"))
        with open(output_dir / "last_checkpoint.json", "w", encoding="utf-8") as f:
            json.dump({
                "epoch": epoch,
                "best_f1": best_f1,
                "patience_counter": patience_counter,
                "training_log": training_log,
            }, f, indent=2)

        with open(output_dir / "training_log.json", "w", encoding="utf-8") as f:
            json.dump(training_log, f, indent=2)

        if patience_counter >= args.patience:
            logger.info("Early stopping at epoch %d (patience=%d)", epoch, args.patience)
            break

        if device == "cuda":
            torch.cuda.empty_cache()

    # Save config
    config = {
        "base_model": args.model_id,
        "method": "PEFT LoRA",
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "target_modules": LORA_TARGET_MODULES,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
        "best_val_f1": best_f1,
        "training_args": vars(args),
    }
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    logger.info("Training complete. Best F1=%.4f", best_f1)


def main():
    """Parse command-line arguments, seed the RNGs and start training."""
    parser = argparse.ArgumentParser(description="LoRA fine-tune KcELECTRA 7-class")
    parser.add_argument("--train-manifest", required=True)
    parser.add_argument("--val-manifest", required=True)
    parser.add_argument("--output-dir", default="data/models/lora_kcelectra_7class")
    parser.add_argument("--model-id", default="beomi/KcELECTRA-base-v2022")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--accumulate-steps", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=3)
    parser.add_argument("--max-length", type=int, default=128)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    # Fixed seeds for Python, NumPy and PyTorch RNGs.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    train(args)


if __name__ == "__main__":
    main()