#!/usr/bin/env python3
"""
Train LoRA only (supervised classification) using CSV labeled data.
Required CSV columns: text/claim, evidence, label
"""
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
from src.data.csv_loader import CSVLabeledLoader
from src.training.lora_trainer import LoRATrainingConfig, train_lora_classification
# Load environment variables from a .env file at import time so that the
# os.getenv(...) fallbacks used by main() can see them.
# NOTE(review): this runs as a side effect of importing the module, not only
# when the script is executed — confirm that is intended.
load_dotenv()
def parse_args():
    """Parse the command-line options for LoRA training.

    Every option defaults to ``None`` so the caller can tell which flags were
    omitted and fall back to environment variables or built-in defaults.

    Returns:
        argparse.Namespace with attributes ``train_csv``, ``dev_csv``, ``csv``,
        ``output``, ``batch_size``, ``epochs``, ``lr``, ``precision``,
        ``max_length``, ``grad_accum``, ``early_stopping`` and ``load_model``.
    """
    cli = argparse.ArgumentParser(
        description="Train LoRA for crypto claim classification"
    )
    add = cli.add_argument

    # Input data and output location.
    add("--train-csv", type=str, default=None, help="Path to TRAIN CSV file")
    add("--dev-csv", type=str, default=None, help="Path to DEV/EVAL CSV file")
    add("--csv", type=str, default=None, help="(Deprecated) Alias for --train-csv")
    add("--output", type=str, default=None, help="Output directory for LoRA model")

    # Optimisation hyper-parameters.
    add("--batch-size", type=int, default=None, help="Training batch size")
    add("--epochs", type=int, default=None, help="Number of training epochs")
    add("--lr", type=float, default=None, help="Learning rate")
    add(
        "--precision",
        type=str,
        choices=["auto", "bf16", "fp16", "fp32"],
        default=None,
        help="Training precision (default: auto)",
    )
    add("--max-length", type=int, default=None, help="Max sequence length")
    add("--grad-accum", type=int, default=None, help="Gradient accumulation steps")
    add(
        "--early-stopping",
        type=int,
        default=None,
        help="Early stopping patience (default: 3)",
    )

    # Resuming from an existing checkpoint.
    add(
        "--load-model",
        type=str,
        default=None,
        help="Path to LoRA checkpoint to resume training (e.g., models/lora_llm/checkpoint-190)",
    )
    return cli.parse_args()
# Values accepted by ``--precision``; LORA_PRECISION is checked against the
# same set so an environment typo fails as loudly as a bad CLI flag.
_PRECISION_CHOICES = ("auto", "bf16", "fp16", "fp32")


def _resolve_setting(cli_value, env_name, default, cast=str):
    """Resolve one setting with priority CLI > environment > default.

    Args:
        cli_value: Value parsed from the command line, or ``None`` if the flag
            was not given. Any non-``None`` value wins, including falsy ones
            such as ``0`` (e.g. ``--early-stopping 0``), which a plain ``or``
            chain would silently discard.
        env_name: Environment variable consulted when no CLI value was given.
            An unset or blank variable counts as missing.
        default: Returned unchanged when neither source supplies a value.
        cast: Callable converting the (whitespace-stripped) environment
            string, e.g. ``int``, ``float`` or ``str``.

    Returns:
        The resolved value.

    Raises:
        ValueError: If the environment value cannot be converted by ``cast``.
    """
    if cli_value is not None:
        return cli_value
    raw = os.getenv(env_name)
    if raw is None or not raw.strip():
        return default
    try:
        return cast(raw.strip())
    except ValueError as err:
        type_name = getattr(cast, "__name__", "value")
        raise ValueError(
            f"Environment variable {env_name}={raw!r} is not a valid {type_name}"
        ) from err


def main():
    """Train a LoRA classifier from labeled TRAIN and DEV CSV files.

    Every setting is resolved with priority: CLI flag > environment variable >
    built-in default. The TRAIN path additionally accepts the deprecated
    ``--csv`` flag and the legacy ``LABELED_CSV_PATH`` variable.

    Raises:
        ValueError: If the TRAIN or DEV CSV path is missing, if an environment
            override cannot be parsed, or if ``LORA_PRECISION`` is not one of
            the values accepted by ``--precision``.
    """
    args = parse_args()

    # Paths: priority is args > env > legacy env. A plain `or` chain is fine
    # here because an empty string is never a usable path.
    train_csv = (
        args.train_csv
        or args.csv
        or os.getenv("TRAIN_CSV_PATH")
        or os.getenv("LABELED_CSV_PATH")
    )
    dev_csv = args.dev_csv or os.getenv("DEV_CSV_PATH")
    if not train_csv or not dev_csv:
        raise ValueError(
            "Provide --train-csv and --dev-csv (or TRAIN_CSV_PATH and DEV_CSV_PATH). "
            "Format: text,evidence,label"
        )

    # Resolve every hyper-parameter before loading data so a bad environment
    # override fails fast rather than after the CSVs have been read.
    output_dir = _resolve_setting(args.output, "LORA_OUTPUT_DIR", "models/lora_llm")
    default_model_name = LoRATrainingConfig().model_name
    model_name = _resolve_setting(None, "LLM_MODEL_NAME", default_model_name)
    precision = _resolve_setting(args.precision, "LORA_PRECISION", "auto")
    if precision not in _PRECISION_CHOICES:
        raise ValueError(
            f"LORA_PRECISION must be one of {', '.join(_PRECISION_CHOICES)}; "
            f"got {precision!r}"
        )
    lora_config = LoRATrainingConfig(
        model_name=model_name,
        output_dir=output_dir,
        epochs=_resolve_setting(args.epochs, "LORA_EPOCHS", 3, int),
        batch_size=_resolve_setting(args.batch_size, "LORA_BATCH_SIZE", 1, int),
        learning_rate=_resolve_setting(args.lr, "LORA_LR", 1e-4, float),
        precision=precision,
        max_length=_resolve_setting(args.max_length, "LORA_MAX_LENGTH", 256, int),
        early_stopping_patience=_resolve_setting(
            args.early_stopping, "LORA_EARLY_STOPPING", 3, int
        ),
    )
    grad_accum = _resolve_setting(args.grad_accum, "LORA_GRAD_ACCUM", 4, int)

    logger.info(f"Loading TRAIN data from {train_csv}...")
    train_df = CSVLabeledLoader(train_csv).load()
    logger.info(f"TRAIN samples: {len(train_df)}")
    logger.info(f"Loading DEV data from {dev_csv}...")
    dev_df = CSVLabeledLoader(dev_csv).load()
    logger.info(f"DEV samples: {len(dev_df)}")

    # The trainer takes parallel lists; these column names must exist in the
    # frames returned by CSVLabeledLoader.
    claims = train_df["text"].tolist()
    labels = train_df["label"].tolist()
    evidences = train_df["evidence"].tolist()
    eval_claims = dev_df["text"].tolist()
    eval_labels = dev_df["label"].tolist()
    eval_evidences = dev_df["evidence"].tolist()

    lora_path = train_lora_classification(
        claims=claims,
        evidences=evidences,
        labels=labels,
        eval_claims=eval_claims,
        eval_evidences=eval_evidences,
        eval_labels=eval_labels,
        config=lora_config,
        gradient_accumulation_steps=grad_accum,
        checkpoint_path=args.load_model,  # Load checkpoint to resume training
    )
    logger.info(f"LoRA training complete. Model saved to: {lora_path}")
# Run training only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
|