Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Train LoRA only (supervised classification) using CSV labeled data. | |
| Required CSV columns: text/claim, evidence, label | |
| """ | |
| import argparse | |
| import os | |
| from dotenv import load_dotenv | |
| from loguru import logger | |
| from src.data.csv_loader import CSVLabeledLoader | |
| from src.training.lora_trainer import LoRATrainingConfig, train_lora_classification | |
# Load environment variables from the .env file at import time, so the
# os.getenv() lookups performed in main() can pick up values defined there.
load_dotenv()
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the LoRA training script.

    Every option defaults to ``None`` so that a caller can distinguish
    "flag not given" from an explicitly supplied value.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of the original zero-argument call.

    Returns:
        An ``argparse.Namespace`` with the attributes ``train_csv``,
        ``dev_csv``, ``csv``, ``output``, ``batch_size``, ``epochs``, ``lr``,
        ``precision``, ``max_length``, ``grad_accum``, ``early_stopping`` and
        ``load_model``.
    """
    parser = argparse.ArgumentParser(
        description="Train LoRA for crypto claim classification"
    )

    # Data locations.
    parser.add_argument(
        "--train-csv", type=str, default=None, help="Path to TRAIN CSV file"
    )
    parser.add_argument(
        "--dev-csv", type=str, default=None, help="Path to DEV/EVAL CSV file"
    )
    parser.add_argument(
        "--csv", type=str, default=None, help="(Deprecated) Alias for --train-csv"
    )
    parser.add_argument(
        "--output", type=str, default=None, help="Output directory for LoRA model"
    )

    # Training hyper-parameters.
    parser.add_argument(
        "--batch-size", type=int, default=None, help="Training batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=None, help="Number of training epochs"
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    parser.add_argument(
        "--precision",
        type=str,
        choices=["auto", "bf16", "fp16", "fp32"],
        default=None,
        help="Training precision (default: auto)",
    )
    parser.add_argument(
        "--max-length", type=int, default=None, help="Max sequence length"
    )
    parser.add_argument(
        "--grad-accum", type=int, default=None, help="Gradient accumulation steps"
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=None,
        help="Early stopping patience (default: 3)",
    )

    # Resuming from an earlier run.
    parser.add_argument(
        "--load-model",
        type=str,
        default=None,
        help="Path to LoRA checkpoint to resume training (e.g., models/lora_llm/checkpoint-190)",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def _resolve_setting(cli_value, env_name, fallback, cast=str):
    """Resolve one setting with priority: CLI value > environment > fallback.

    Args:
        cli_value: Value parsed from the command line, or ``None`` when the
            flag was not given.
        env_name: Environment variable consulted when ``cli_value`` is None.
        fallback: String default used when the variable is unset or empty.
        cast: Callable that converts the raw string to the target type.

    Returns:
        ``cli_value`` unchanged when it is not ``None``; otherwise the
        converted environment value, or the converted fallback.

    Raises:
        ValueError: If the raw string cannot be converted by ``cast``.
    """
    # Test against None instead of truthiness: an explicit 0 / 0.0 on the
    # command line is a real value and must not be replaced by the default.
    if cli_value is not None:
        return cli_value
    # An empty variable (e.g. "LORA_EPOCHS=") is treated like an unset one.
    raw = os.getenv(env_name) or fallback
    try:
        return cast(raw)
    except ValueError as err:
        raise ValueError(
            f"Invalid value {raw!r} for environment variable {env_name}"
        ) from err


def main() -> None:
    """Run LoRA classification training from labeled TRAIN and DEV CSV files.

    Resolves the CSV paths and hyper-parameters (command-line arguments take
    priority over environment variables, which take priority over built-in
    defaults), loads both CSV files, builds a ``LoRATrainingConfig`` and calls
    ``train_lora_classification``. The path it returns is logged.

    Raises:
        ValueError: If no TRAIN or DEV CSV path is provided, or if a numeric
            environment variable holds a value that cannot be parsed.
    """
    args = parse_args()

    # Priority: args > env > legacy env fallback
    train_csv = (
        args.train_csv
        or args.csv
        or os.getenv("TRAIN_CSV_PATH")
        or os.getenv("LABELED_CSV_PATH")
    )
    dev_csv = args.dev_csv or os.getenv("DEV_CSV_PATH")
    if not train_csv or not dev_csv:
        raise ValueError(
            "Provide --train-csv and --dev-csv (or TRAIN_CSV_PATH and DEV_CSV_PATH). "
            "Format: text,evidence,label"
        )

    logger.info(f"Loading TRAIN data from {train_csv}...")
    train_df = CSVLabeledLoader(train_csv).load()
    logger.info(f"TRAIN samples: {len(train_df)}")

    logger.info(f"Loading DEV data from {dev_csv}...")
    dev_df = CSVLabeledLoader(dev_csv).load()
    logger.info(f"DEV samples: {len(dev_df)}")

    claims = train_df["text"].tolist()
    labels = train_df["label"].tolist()
    evidences = train_df["evidence"].tolist()

    eval_claims = dev_df["text"].tolist()
    eval_labels = dev_df["label"].tolist()
    eval_evidences = dev_df["evidence"].tolist()

    output_dir = args.output or os.getenv("LORA_OUTPUT_DIR", "models/lora_llm")

    # A default-constructed config supplies the model name used when
    # LLM_MODEL_NAME is not set.
    default_model_name = LoRATrainingConfig().model_name
    lora_config = LoRATrainingConfig(
        model_name=os.getenv("LLM_MODEL_NAME", default_model_name),
        output_dir=output_dir,
        epochs=_resolve_setting(args.epochs, "LORA_EPOCHS", "3", int),
        batch_size=_resolve_setting(args.batch_size, "LORA_BATCH_SIZE", "1", int),
        learning_rate=_resolve_setting(args.lr, "LORA_LR", "1e-4", float),
        precision=_resolve_setting(args.precision, "LORA_PRECISION", "auto"),
        max_length=_resolve_setting(args.max_length, "LORA_MAX_LENGTH", "256", int),
        early_stopping_patience=_resolve_setting(
            args.early_stopping, "LORA_EARLY_STOPPING", "3", int
        ),
    )

    lora_path = train_lora_classification(
        claims=claims,
        evidences=evidences,
        labels=labels,
        eval_claims=eval_claims,
        eval_evidences=eval_evidences,
        eval_labels=eval_labels,
        config=lora_config,
        gradient_accumulation_steps=_resolve_setting(
            args.grad_accum, "LORA_GRAD_ACCUM", "4", int
        ),
        checkpoint_path=args.load_model,  # Load checkpoint to resume training
    )
    logger.info(f"LoRA training complete. Model saved to: {lora_path}")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()