File size: 4,457 Bytes
398a289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
"""
Train LoRA only (supervised classification) using CSV labeled data.
Required CSV columns: text/claim, evidence, label
"""

import argparse
import os

from dotenv import load_dotenv
from loguru import logger

from src.data.csv_loader import CSVLabeledLoader
from src.training.lora_trainer import LoRATrainingConfig, train_lora_classification

# Load environment variables from a .env file at import time, before main()
# runs, so the os.getenv() lookups there can see values defined in that file.
load_dotenv()


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the LoRA training script.

    Every option defaults to ``None`` so that the caller can distinguish
    "not given on the command line" from an explicit value and fall back to
    environment variables or built-in defaults.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace`` with the attributes
        ``train_csv``, ``dev_csv``, ``csv``, ``output``, ``batch_size``,
        ``epochs``, ``lr``, ``precision``, ``max_length``, ``grad_accum``,
        ``early_stopping`` and ``load_model``.
    """
    parser = argparse.ArgumentParser(
        description="Train LoRA for crypto claim classification"
    )

    # Dataset locations.
    parser.add_argument(
        "--train-csv", type=str, default=None, help="Path to TRAIN CSV file"
    )
    parser.add_argument(
        "--dev-csv", type=str, default=None, help="Path to DEV/EVAL CSV file"
    )
    parser.add_argument(
        "--csv", type=str, default=None, help="(Deprecated) Alias for --train-csv"
    )

    # Output location.
    parser.add_argument(
        "--output", type=str, default=None, help="Output directory for LoRA model"
    )

    # Training hyper-parameters.
    parser.add_argument(
        "--batch-size", type=int, default=None, help="Training batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=None, help="Number of training epochs"
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    parser.add_argument(
        "--precision",
        type=str,
        choices=["auto", "bf16", "fp16", "fp32"],
        default=None,
        help="Training precision (default: auto)",
    )
    parser.add_argument(
        "--max-length", type=int, default=None, help="Max sequence length"
    )
    parser.add_argument(
        "--grad-accum", type=int, default=None, help="Gradient accumulation steps"
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=None,
        help="Early stopping patience (default: 3)",
    )

    # Resuming from an earlier run.
    parser.add_argument(
        "--load-model",
        type=str,
        default=None,
        help="Path to LoRA checkpoint to resume training (e.g., models/lora_llm/checkpoint-190)",
    )
    return parser.parse_args(argv)


def _cli_or_env(cli_value, env_name, default, cast):
    """Return *cli_value* if it was supplied, else the cast environment value.

    "Supplied" means ``is not None`` rather than truthy, so an explicit ``0``
    or ``0.0`` from the command line is honoured instead of being replaced by
    the fallback. The environment variable is only read and parsed when the
    command-line value is absent.

    Args:
        cli_value: Value parsed from the command line, or ``None`` if the
            option was not given.
        env_name: Name of the environment variable used as fallback.
        default: String used when the environment variable is unset.
        cast: Callable converting the environment string to the target type.

    Raises:
        ValueError: If ``cast`` is ``int``/``float`` and the fallback string
            cannot be parsed.
    """
    if cli_value is not None:
        return cli_value
    return cast(os.getenv(env_name, default))


def main():
    """Load the TRAIN/DEV CSVs, build the LoRA config and run training.

    Each setting is resolved with the priority: command-line argument, then
    environment variable, then a built-in default. The value returned by
    ``train_lora_classification`` is logged as the saved model location.

    Raises:
        ValueError: If no TRAIN or no DEV CSV path can be resolved, or if a
            numeric environment override cannot be parsed.
    """
    args = parse_args()

    # Priority: args > env > legacy env fallback
    train_csv = (
        args.train_csv
        or args.csv
        or os.getenv("TRAIN_CSV_PATH")
        or os.getenv("LABELED_CSV_PATH")
    )
    dev_csv = args.dev_csv or os.getenv("DEV_CSV_PATH")
    if not train_csv or not dev_csv:
        raise ValueError(
            "Provide --train-csv and --dev-csv (or TRAIN_CSV_PATH and DEV_CSV_PATH). "
            "Format: text,evidence,label"
        )

    logger.info(f"Loading TRAIN data from {train_csv}...")
    train_df = CSVLabeledLoader(train_csv).load()
    logger.info(f"TRAIN samples: {len(train_df)}")

    logger.info(f"Loading DEV data from {dev_csv}...")
    dev_df = CSVLabeledLoader(dev_csv).load()
    logger.info(f"DEV samples: {len(dev_df)}")

    claims = train_df["text"].tolist()
    labels = train_df["label"].tolist()
    evidences = train_df["evidence"].tolist()
    eval_claims = dev_df["text"].tolist()
    eval_labels = dev_df["label"].tolist()
    eval_evidences = dev_df["evidence"].tolist()

    # Paths keep the truthiness fallback: an empty path is never meaningful.
    output_dir = args.output or os.getenv("LORA_OUTPUT_DIR", "models/lora_llm")
    default_model_name = LoRATrainingConfig().model_name

    # Numeric options use an explicit None check so that a deliberate falsy
    # value such as `--early-stopping 0` or `--lr 0` is not silently
    # overridden by the environment/default value.
    lora_config = LoRATrainingConfig(
        model_name=os.getenv("LLM_MODEL_NAME", default_model_name),
        output_dir=output_dir,
        epochs=_cli_or_env(args.epochs, "LORA_EPOCHS", "3", int),
        batch_size=_cli_or_env(args.batch_size, "LORA_BATCH_SIZE", "1", int),
        learning_rate=_cli_or_env(args.lr, "LORA_LR", "1e-4", float),
        precision=_cli_or_env(args.precision, "LORA_PRECISION", "auto", str),
        max_length=_cli_or_env(args.max_length, "LORA_MAX_LENGTH", "256", int),
        early_stopping_patience=_cli_or_env(
            args.early_stopping, "LORA_EARLY_STOPPING", "3", int
        ),
    )

    lora_path = train_lora_classification(
        claims=claims,
        evidences=evidences,
        labels=labels,
        eval_claims=eval_claims,
        eval_evidences=eval_evidences,
        eval_labels=eval_labels,
        config=lora_config,
        gradient_accumulation_steps=_cli_or_env(
            args.grad_accum, "LORA_GRAD_ACCUM", "4", int
        ),
        checkpoint_path=args.load_model,  # Load checkpoint to resume training
    )

    logger.info(f"LoRA training complete. Model saved to: {lora_path}")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()