File size: 3,745 Bytes
c55ab5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Fine-tune MiniCPM into the SAT transaction classifier (LoRA SFT).

🎯 Well-Tuned. Trains a small LoRA adapter on the chat-format dataset produced by
``build_classifier_dataset.py`` and (optionally) pushes it to the Hub.

This requires a GPU and is meant to run on Modal (see modal_app/finetune_modal.py)
or any CUDA box — NOT on the laptop that serves the Space. Usage:

    python -m scripts.train_classifier \
        --base openbmb/MiniCPM-... \
        --data data/finetune \
        --out  artifacts/sat-classifier-lora \
        --push your-org/cuentas-claras-sat-classifier

The design mirrors our previous hackathon planner fine-tune (LoRA SFT, ~1-2k pairs),
which is enough for this narrow, well-specified classification task.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path


def load_chat_jsonl(path: Path):
    """Load a chat-format JSON Lines file as a ``datasets.Dataset`` for SFT.

    Every non-blank line of *path* is parsed as one JSON object and only its
    ``"messages"`` value is kept, so the resulting dataset has a single
    ``messages`` column.

    Args:
        path: UTF-8 encoded JSON Lines file, one JSON object per line.

    Returns:
        The dataset built by ``Dataset.from_list`` from the kept records.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"messages"`` key.
    """
    # Imported inside the function, so importing this module does not itself
    # require the `datasets` package.
    from datasets import Dataset

    records = []
    # Iterate the file object instead of ``str.splitlines()``: splitlines()
    # also breaks on U+0085, U+2028 and U+2029, which JSON allows unescaped
    # inside strings, so a record containing one would be cut in half.
    # Text-mode iteration only breaks on "\n", "\r" and "\r\n".
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue  # tolerate blank lines between records
            # keep only the chat messages for SFT
            records.append({"messages": json.loads(line)["messages"]})
    return Dataset.from_list(records)


def main(argv=None) -> None:
    """Run the LoRA supervised fine-tuning job from the command line.

    Parses the CLI options, loads ``train.jsonl`` and ``val.jsonl`` from the
    ``--data`` directory, fine-tunes the ``--base`` model with a LoRA adapter
    through TRL's ``SFTTrainer``, runs one final evaluation, and writes the
    adapter, the tokenizer and ``eval_metrics.json`` into ``--out``. When
    ``--push`` is given, the adapter and tokenizer are also pushed to that
    Hub repo id.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Raises:
        SystemExit: raised by argparse for ``--help`` or invalid arguments.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", default="openbmb/MiniCPM3-4B",
                    help="base model repo (≤32B; ≤4B keeps the Tiny Titan option open)")
    ap.add_argument("--data", default="data/finetune")
    ap.add_argument("--out", default="artifacts/sat-classifier-lora")
    ap.add_argument("--push", default=None, help="Hub repo id to push the adapter to")
    ap.add_argument("--epochs", type=float, default=3.0)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--batch", type=int, default=8)
    ap.add_argument("--max_seq", type=int, default=1024)
    args = ap.parse_args(argv)

    # Heavy ML dependencies are imported only after argument parsing, so
    # ``--help`` and argument errors do not need them.
    import dataclasses

    import torch
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    data_dir = Path(args.data)
    train_ds = load_chat_jsonl(data_dir / "train.jsonl")
    eval_ds = load_chat_jsonl(data_dir / "val.jsonl")

    tokenizer = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.base, trust_remote_code=True,
        torch_dtype=torch.bfloat16, device_map="auto",
    )

    # NOTE(review): these target names assume the base model exposes attention
    # projections called q_proj/k_proj/v_proj/o_proj — confirm against the
    # module names of the chosen --base model, since only matching modules
    # receive LoRA weights.
    peft_config = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    # Newer TRL releases name the sequence-length field ``max_length`` while
    # older ones use ``max_seq_length``; pass whichever this SFTConfig declares
    # so the script runs on both.
    sft_field_names = {f.name for f in dataclasses.fields(SFTConfig)}
    seq_len_arg = "max_length" if "max_length" in sft_field_names else "max_seq_length"

    sft_config = SFTConfig(
        output_dir=args.out,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch,
        gradient_accumulation_steps=2,
        learning_rate=args.lr,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        bf16=True,
        report_to="none",
        # SFTTrainer applies the chat template to the "messages" column.
        packing=False,
        **{seq_len_arg: args.max_seq},
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
    trainer.train()
    metrics = trainer.evaluate()
    print("eval:", metrics)

    # Persist the adapter, tokenizer and final eval metrics side by side.
    trainer.save_model(args.out)
    tokenizer.save_pretrained(args.out)
    Path(args.out, "eval_metrics.json").write_text(
        json.dumps(metrics, indent=2), encoding="utf-8"
    )

    if args.push:
        trainer.model.push_to_hub(args.push)
        tokenizer.push_to_hub(args.push)
        print(f"Pushed adapter → https://huggingface.co/{args.push}")


# Run the fine-tuning CLI only when this file is executed as a script
# (e.g. ``python -m scripts.train_classifier``), not when it is imported.
if __name__ == "__main__":
    main()