#!/usr/bin/env python3
"""
Train LoRA only (supervised classification) using CSV labeled data.

Required CSV columns: text/claim, evidence, label
"""

import argparse
import os

from dotenv import load_dotenv
from loguru import logger

from src.data.csv_loader import CSVLabeledLoader
from src.training.lora_trainer import LoRATrainingConfig, train_lora_classification

# Load environment variables from .env file
load_dotenv()


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Every option defaults to ``None`` so that ``main`` can tell "not given
    on the command line" apart from an explicit value and fall back to the
    corresponding environment variable.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Train LoRA for crypto claim classification"
    )
    parser.add_argument(
        "--train-csv", type=str, default=None, help="Path to TRAIN CSV file"
    )
    parser.add_argument(
        "--dev-csv", type=str, default=None, help="Path to DEV/EVAL CSV file"
    )
    parser.add_argument(
        "--csv", type=str, default=None, help="(Deprecated) Alias for --train-csv"
    )
    parser.add_argument(
        "--output", type=str, default=None, help="Output directory for LoRA model"
    )
    parser.add_argument(
        "--batch-size", type=int, default=None, help="Training batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=None, help="Number of training epochs"
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    parser.add_argument(
        "--precision",
        type=str,
        choices=["auto", "bf16", "fp16", "fp32"],
        default=None,
        help="Training precision (default: auto)",
    )
    parser.add_argument(
        "--max-length", type=int, default=None, help="Max sequence length"
    )
    parser.add_argument(
        "--grad-accum", type=int, default=None, help="Gradient accumulation steps"
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=None,
        help="Early stopping patience (default: 3)",
    )
    parser.add_argument(
        "--load-model",
        type=str,
        default=None,
        help="Path to LoRA checkpoint to resume training (e.g., models/lora_llm/checkpoint-190)",
    )
    return parser.parse_args()


def _cli_or_env(cli_value, env_var: str, default: str, cast):
    """Resolve a numeric setting with priority CLI > environment > default.

    The CLI value is compared against ``None`` rather than tested for
    truthiness, so an explicit ``0`` / ``0.0`` given on the command line is
    honoured instead of being silently replaced by the fallback.

    Args:
        cli_value: Value parsed by argparse, or ``None`` when the option was
            not supplied.
        env_var: Name of the environment variable consulted as fallback.
        default: String default used when ``env_var`` is not set.
        cast: Callable (``int`` or ``float``) applied to the environment
            string.

    Returns:
        ``cli_value`` when it is not ``None``; otherwise the converted
        environment value or default.

    Raises:
        ValueError: If the environment value cannot be converted by ``cast``.
    """
    if cli_value is not None:
        return cli_value
    raw = os.getenv(env_var, default)
    try:
        return cast(raw)
    except ValueError as err:
        # Name the offending variable; a bare int()/float() error does not.
        raise ValueError(f"Invalid value for {env_var}: {raw!r}") from err


def main() -> None:
    """Load the TRAIN/DEV CSVs, build the LoRA config and run training.

    Each setting is taken from the command line first and from the
    environment otherwise. The TRAIN path additionally falls back to the
    legacy ``LABELED_CSV_PATH`` variable.

    Raises:
        ValueError: If no TRAIN or DEV CSV path can be resolved, or if a
            numeric environment variable holds an unparsable value.
    """
    args = parse_args()

    # Priority: args > env > legacy env fallback
    train_csv = (
        args.train_csv
        or args.csv
        or os.getenv("TRAIN_CSV_PATH")
        or os.getenv("LABELED_CSV_PATH")
    )
    dev_csv = args.dev_csv or os.getenv("DEV_CSV_PATH")

    if not train_csv or not dev_csv:
        raise ValueError(
            "Provide --train-csv and --dev-csv (or TRAIN_CSV_PATH and DEV_CSV_PATH). "
            "Format: text,evidence,label"
        )

    logger.info(f"Loading TRAIN data from {train_csv}...")
    train_df = CSVLabeledLoader(train_csv).load()
    logger.info(f"TRAIN samples: {len(train_df)}")

    logger.info(f"Loading DEV data from {dev_csv}...")
    dev_df = CSVLabeledLoader(dev_csv).load()
    logger.info(f"DEV samples: {len(dev_df)}")

    # NOTE(review): the module docstring mentions a "text/claim" column while
    # only "text" is read here -- presumably the loader normalizes the column
    # name; confirm against CSVLabeledLoader.
    claims = train_df["text"].tolist()
    labels = train_df["label"].tolist()
    evidences = train_df["evidence"].tolist()

    eval_claims = dev_df["text"].tolist()
    eval_labels = dev_df["label"].tolist()
    eval_evidences = dev_df["evidence"].tolist()

    output_dir = args.output or os.getenv("LORA_OUTPUT_DIR", "models/lora_llm")

    # A default-constructed config supplies the fallback model name.
    default_model_name = LoRATrainingConfig().model_name
    lora_config = LoRATrainingConfig(
        model_name=os.getenv("LLM_MODEL_NAME", default_model_name),
        output_dir=output_dir,
        epochs=_cli_or_env(args.epochs, "LORA_EPOCHS", "3", int),
        batch_size=_cli_or_env(args.batch_size, "LORA_BATCH_SIZE", "1", int),
        learning_rate=_cli_or_env(args.lr, "LORA_LR", "1e-4", float),
        precision=args.precision or os.getenv("LORA_PRECISION", "auto"),
        max_length=_cli_or_env(args.max_length, "LORA_MAX_LENGTH", "256", int),
        early_stopping_patience=_cli_or_env(
            args.early_stopping, "LORA_EARLY_STOPPING", "3", int
        ),
    )

    lora_path = train_lora_classification(
        claims=claims,
        evidences=evidences,
        labels=labels,
        eval_claims=eval_claims,
        eval_evidences=eval_evidences,
        eval_labels=eval_labels,
        config=lora_config,
        gradient_accumulation_steps=_cli_or_env(
            args.grad_accum, "LORA_GRAD_ACCUM", "4", int
        ),
        checkpoint_path=args.load_model,  # Load checkpoint to resume training
    )

    logger.info(f"LoRA training complete. Model saved to: {lora_path}")


if __name__ == "__main__":
    main()