ustwo-api / scripts /train_lora_kcelectra.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
15.8 kB
#!/usr/bin/env python3
"""LoRA Fine-Tuning KcELECTRA for 7-Class Korean Text Emotion Recognition.
Uses PEFT LoRA on beomi/KcELECTRA-base-v2022 for text-based emotion classification.
Filters out samples without text (e.g., RAVDESS English data).
Usage:
python scripts/train_lora_kcelectra.py \
--train-manifest data/lora_dataset/train_manifest.json \
--val-manifest data/lora_dataset/val_manifest.json \
--output-dir data/models/lora_kcelectra_7class \
--epochs 10 --batch-size 16 --device cuda
"""
from __future__ import annotations
import argparse
import json
import logging
import random
import time
from collections import Counter
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Canonical label order for the 7-class task; a label's position is its class id.
LABELS_7CLASS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
LABEL2IDX = dict(zip(LABELS_7CLASS, range(len(LABELS_7CLASS))))
NUM_CLASSES = len(LABELS_7CLASS)
# ──────────────────────────────────────────────
# Dataset
# ──────────────────────────────────────────────
class TextEmotionDataset(Dataset):
    """Text + emotion-label dataset loaded from a JSON manifest.

    The manifest is a JSON list of records. A record is kept only if its
    ``"text"`` is a non-blank string and its ``"label"`` is one of
    ``LABELS_7CLASS``; everything else (e.g. audio-only entries with no text)
    is dropped at load time.

    Each item is a dict of ``input_ids`` / ``attention_mask`` (tokenized,
    truncated and padded to ``max_length``) and a scalar long ``labels`` tensor.

    Args:
        manifest_path: Path to the JSON manifest (read as UTF-8).
        tokenizer: HuggingFace-style tokenizer callable.
        max_length: Sequence length to truncate/pad every sample to.
    """

    def __init__(self, manifest_path: str, tokenizer, max_length: int = 128):
        with open(manifest_path, encoding="utf-8") as f:
            raw = json.load(f)
        self.samples = [s for s in raw if self._is_usable(s)]
        self.tokenizer = tokenizer
        self.max_length = max_length
        logger.info(
            "TextEmotionDataset: %d samples (filtered from %d, skipped %d without text)",
            len(self.samples), len(raw), len(raw) - len(self.samples),
        )

    @staticmethod
    def _is_usable(sample) -> bool:
        """Return True if the record has non-blank string text and a known label."""
        text = sample.get("text")
        # A JSON null (or any non-string) text counts as "no text" instead of
        # crashing on `.strip()`; `.get("label")` skips label-less records
        # instead of raising KeyError.
        return isinstance(text, str) and bool(text.strip()) and sample.get("label") in LABEL2IDX

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        encoding = self.tokenizer(
            sample["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of 1; drop it so the
        # DataLoader can stack samples itself.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(LABEL2IDX[sample["label"]], dtype=torch.long),
        }
# ──────────────────────────────────────────────
# Validation
# ──────────────────────────────────────────────
@torch.no_grad()
def validate(model, val_loader, device, criterion):
    """Evaluate ``model`` over ``val_loader`` and summarize classification metrics.

    Switches the model to eval mode (gradients are disabled by the decorator)
    and returns a dict with:
      - ``loss``: criterion value weighted by batch size, averaged over samples
      - ``accuracy``
      - ``macro_f1``: unweighted mean of per-class F1 over all NUM_CLASSES ids
      - ``per_class_f1``: {label name: F1}
      - ``confusion_matrix``: NUM_CLASSES x NUM_CLASSES nested list
    Floating-point values are rounded to 4 decimals.
    """
    model.eval()
    running_loss = 0
    targets = []
    predictions = []
    for batch in val_loader:
        inputs = {
            "input_ids": batch["input_ids"].to(device),
            "attention_mask": batch["attention_mask"].to(device),
        }
        labels = batch["labels"].to(device)
        logits = model(**inputs).logits
        running_loss += criterion(logits, labels).item() * labels.size(0)
        targets.extend(labels.cpu().tolist())
        predictions.extend(logits.argmax(dim=-1).cpu().tolist())

    from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

    class_ids = list(range(NUM_CLASSES))
    accuracy = accuracy_score(targets, predictions)
    # average=None keeps one score per class id, including classes never seen.
    f1_scores = f1_score(targets, predictions, labels=class_ids,
                         average=None, zero_division=0)
    matrix = confusion_matrix(targets, predictions, labels=class_ids)

    summary = {
        "loss": round(running_loss / max(len(targets), 1), 4),
        "accuracy": round(accuracy, 4),
        "macro_f1": round(float(np.mean(f1_scores)), 4),
        "per_class_f1": {
            name: round(float(score), 4)
            for name, score in zip(LABELS_7CLASS, f1_scores)
        },
        "confusion_matrix": matrix.tolist(),
    }
    return summary
def plot_confusion_matrix(cm, output_path: Path, epoch: int):
    """Save a row-normalized confusion-matrix heatmap to ``output_path``.

    Each cell shows the per-row rate, with the raw count in parentheses below
    it. Rows for classes that have no true samples are drawn as zeros instead
    of NaN. Plotting is best-effort: if matplotlib/seaborn cannot be imported,
    a warning is logged and nothing is written.

    Args:
        cm: Square count matrix (rows = true class, columns = predicted),
            indexed in LABELS_7CLASS order.
        output_path: Destination image path.
        epoch: Epoch number shown in the title.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend; no display required
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError as exc:
        logger.warning("Skipping confusion-matrix plot (%s)", exc)
        return

    cm = np.asarray(cm)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Divide only where a row has samples; classes missing from the validation
    # set would otherwise produce 0/0 = NaN cells.
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                        where=row_sums > 0)

    fig, ax = plt.subplots(figsize=(9, 7))
    try:
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                    xticklabels=LABELS_7CLASS, yticklabels=LABELS_7CLASS, ax=ax)
        for i in range(NUM_CLASSES):
            for j in range(NUM_CLASSES):
                ax.text(j + 0.5, i + 0.7, f"({cm[i][j]})",
                        ha="center", va="center", fontsize=6, color="gray")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(f"KcELECTRA LoRA — Epoch {epoch}")
        fig.tight_layout()
        fig.savefig(str(output_path), dpi=150)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
# ──────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────
def _load_base_model(model_id: str):
    """Load the pretrained backbone with a NUM_CLASSES-way classification head."""
    from transformers import AutoModelForSequenceClassification

    return AutoModelForSequenceClassification.from_pretrained(
        model_id,
        num_labels=NUM_CLASSES,
        id2label=dict(enumerate(LABELS_7CLASS)),
        label2id=LABEL2IDX,
    )


def _build_optimizer_and_scheduler(model, args, total_steps: int):
    """Create AdamW over params with requires_grad=True and a OneCycleLR schedule.

    ``total_steps`` is the number of ``scheduler.step()`` calls the schedule
    must accommodate; it is clamped to at least 1.
    """
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, total_steps=max(total_steps, 1),
        pct_start=0.1, anneal_strategy="cos",
    )
    return optimizer, scheduler


def _inverse_frequency_weights(samples) -> np.ndarray:
    """Return inverse-frequency class weights rescaled to sum to NUM_CLASSES.

    Classes absent from ``samples`` are counted as 1 so the division is defined.
    """
    counts = np.zeros(NUM_CLASSES)
    for sample in samples:
        counts[LABEL2IDX[sample["label"]]] += 1
    weights = 1.0 / np.maximum(counts, 1)
    return weights / weights.sum() * NUM_CLASSES


def _train_one_epoch(model, loader, criterion, optimizer, scheduler, scaler,
                     device, use_amp, accumulate_steps, epoch):
    """Run one training epoch with optional AMP and gradient accumulation.

    The optimizer and scheduler step once every ``accumulate_steps`` batches,
    plus once more for a trailing partial group at the end of the epoch.

    Returns:
        (mean per-sample training loss, training accuracy)
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total_samples = 0
    num_batches = len(loader)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer.zero_grad()
    for batch_num, batch in enumerate(loader, start=1):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = criterion(logits, labels) / accumulate_steps
        scaler.scale(loss).backward()

        if batch_num % accumulate_steps == 0 or batch_num == num_batches:
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        batch_loss = loss.item() * accumulate_steps  # undo accumulation scaling
        total_loss += batch_loss * labels.size(0)
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total_samples += labels.size(0)
        if batch_num % 50 == 0:
            logger.info(" Epoch %d [%d/%d] loss=%.4f lr=%.2e",
                        epoch, batch_num, num_batches, batch_loss,
                        optimizer.param_groups[0]["lr"])
    return total_loss / max(total_samples, 1), correct / max(total_samples, 1)


def train(args):
    """Fine-tune a LoRA adapter on KcELECTRA for 7-class text emotion recognition.

    Trains on ``args.train_manifest``, validates on ``args.val_manifest`` each
    epoch, and writes into ``args.output_dir``:
      - ``best_model/``: adapter + tokenizer with the best validation macro-F1
      - ``cm_epoch{N}.png``: confusion matrix for each new best epoch (when
        plotting libraries are available)
      - ``last_model/`` + ``last_checkpoint.json``: state used by ``--resume``
      - ``training_log.json``: per-epoch metrics
      - ``config.json``: run configuration and best F1
    Stops early after ``args.patience`` consecutive epochs without improvement.

    Raises:
        ValueError: if the training set yields no batches or the validation
            set has no usable samples.
    """
    from transformers import AutoTokenizer
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model

    device = args.device
    use_amp = device == "cuda"

    # Tokenizer + base model with LoRA on the attention query/value projections.
    logger.info("Loading KcELECTRA: %s", args.model_id)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=["query", "value"],
        task_type=TaskType.SEQ_CLS,
        bias="none",
    )
    model = get_peft_model(_load_base_model(args.model_id), lora_config)
    model.print_trainable_parameters()
    model = model.to(device)

    # Datasets / loaders
    train_ds = TextEmotionDataset(args.train_manifest, tokenizer, max_length=args.max_length)
    val_ds = TextEmotionDataset(args.val_manifest, tokenizer, max_length=args.max_length)
    logger.info("Train: %d, Val: %d", len(train_ds), len(val_ds))
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=(device == "cuda"), drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=2, pin_memory=(device == "cuda"),
    )
    if len(train_loader) == 0:
        raise ValueError(
            f"No training batches: {len(train_ds)} usable samples with "
            f"batch_size={args.batch_size} and drop_last=True"
        )
    if len(val_ds) == 0:
        raise ValueError(f"No usable validation samples in {args.val_manifest}")

    # Class weights: inverse frequency only (no extra boost)
    class_weights = _inverse_frequency_weights(train_ds.samples)
    logger.info("Class weights: %s",
                {LABELS_7CLASS[i]: round(float(class_weights[i]), 3) for i in range(NUM_CLASSES)})
    criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32).to(device),
    )

    # The epoch loop steps the scheduler once per `accumulate_steps` batches
    # plus once for a trailing partial group, i.e. ceil(batches / accumulate).
    # Floor division here made OneCycleLR raise "Tried to step ..." in a late
    # epoch whenever the batch count was not a multiple of accumulate_steps.
    steps_per_epoch = -(-len(train_loader) // args.accumulate_steps)
    optimizer, scheduler = _build_optimizer_and_scheduler(
        model, args, steps_per_epoch * args.epochs,
    )
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Training state
    training_log = []
    best_f1 = 0.0
    patience_counter = 0
    start_epoch = 1

    if args.resume:
        ckpt_path = output_dir / "last_checkpoint.json"
        if ckpt_path.exists():
            with open(ckpt_path, encoding="utf-8") as f:
                ckpt_info = json.load(f)
            start_epoch = ckpt_info["epoch"] + 1
            best_f1 = ckpt_info["best_f1"]
            patience_counter = ckpt_info["patience_counter"]
            training_log = ckpt_info.get("training_log", [])

            model_path = output_dir / "last_model"
            if model_path.exists():
                # is_trainable=True: by default PeftModel.from_pretrained loads
                # the adapter frozen for inference, which leaves no trainable
                # parameters and makes AdamW reject an empty parameter list.
                model = PeftModel.from_pretrained(
                    _load_base_model(args.model_id), str(model_path), is_trainable=True,
                ).to(device)
            # NOTE: optimizer moments are not checkpointed; a fresh OneCycle
            # cycle (including warm-up) spans only the remaining epochs.
            remaining_steps = steps_per_epoch * (args.epochs - start_epoch + 1)
            optimizer, scheduler = _build_optimizer_and_scheduler(model, args, remaining_steps)
            logger.info("Resumed from epoch %d (best_f1=%.4f)", start_epoch, best_f1)
        else:
            logger.warning("--resume but no checkpoint found, starting fresh")

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start = time.time()
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler,
            device, use_amp, args.accumulate_steps, epoch,
        )

        val_metrics = validate(model, val_loader, device, criterion)
        epoch_time = time.time() - epoch_start
        logger.info(
            "Epoch %d/%d (%.0fs): train_loss=%.4f train_acc=%.3f | "
            "val_loss=%.4f val_f1=%.3f val_acc=%.3f",
            epoch, args.epochs, epoch_time, train_loss, train_acc,
            val_metrics["loss"], val_metrics["macro_f1"], val_metrics["accuracy"],
        )
        logger.info(" Per-class F1: %s", val_metrics["per_class_f1"])
        training_log.append({
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            **{f"val_{k}": v for k, v in val_metrics.items() if k != "confusion_matrix"},
            "epoch_time_sec": round(epoch_time, 1),
        })

        # Best model (by validation macro-F1)
        if val_metrics["macro_f1"] > best_f1:
            best_f1 = val_metrics["macro_f1"]
            patience_counter = 0
            model.save_pretrained(str(output_dir / "best_model"))
            tokenizer.save_pretrained(str(output_dir / "best_model"))
            logger.info(" New best! macro_f1=%.4f saved to best_model/", best_f1)
            plot_confusion_matrix(np.array(val_metrics["confusion_matrix"]),
                                  output_dir / f"cm_epoch{epoch}.png", epoch)
        else:
            patience_counter += 1

        # Save last (for resume)
        model.save_pretrained(str(output_dir / "last_model"))
        tokenizer.save_pretrained(str(output_dir / "last_model"))
        with open(output_dir / "last_checkpoint.json", "w", encoding="utf-8") as f:
            json.dump({
                "epoch": epoch, "best_f1": best_f1,
                "patience_counter": patience_counter,
                "training_log": training_log,
            }, f, indent=2)
        with open(output_dir / "training_log.json", "w", encoding="utf-8") as f:
            json.dump(training_log, f, indent=2)

        if patience_counter >= args.patience:
            logger.info("Early stopping at epoch %d (patience=%d)", epoch, args.patience)
            break
        if device == "cuda":
            torch.cuda.empty_cache()

    # Save config
    config = {
        "base_model": args.model_id,
        "method": "PEFT LoRA",
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "target_modules": ["query", "value"],
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
        "best_val_f1": best_f1,
        "training_args": vars(args),
    }
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info("Training complete. Best F1=%.4f", best_f1)
def main():
    """Parse CLI arguments, validate them, seed RNGs, and run training."""
    parser = argparse.ArgumentParser(description="LoRA fine-tune KcELECTRA 7-class")
    parser.add_argument("--train-manifest", required=True)
    parser.add_argument("--val-manifest", required=True)
    parser.add_argument("--output-dir", default="data/models/lora_kcelectra_7class")
    parser.add_argument("--model-id", default="beomi/KcELECTRA-base-v2022")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--accumulate-steps", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=3)
    parser.add_argument("--max-length", type=int, default=128)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside train(), e.g.
    # --accumulate-steps 0 raises ZeroDivisionError when sizing the schedule.
    for name in ("epochs", "batch_size", "accumulate_steps", "max_length"):
        if getattr(args, name) < 1:
            parser.error(f"--{name.replace('_', '-')} must be >= 1")
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda requested but CUDA is not available")

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    train(args)


if __name__ == "__main__":
    main()