PocketAccountant / scripts /train_classifier.py
eldinosaur's picture
PocketAccountant: custom ledger UI + deterministic agent (engine, ledger, retrieval, classifier)
c55ab5e verified
Raw
History Blame Contribute Delete
3.75 kB
"""Fine-tune MiniCPM into the SAT transaction classifier (LoRA SFT).
🎯 Well-Tuned. Trains a small LoRA adapter on the chat-format dataset produced by
``build_classifier_dataset.py`` and (optionally) pushes it to the Hub.
This requires a GPU and is meant to run on Modal (see modal_app/finetune_modal.py)
or any CUDA box — NOT on the laptop that serves the Space. Usage:
python -m scripts.train_classifier \
--base openbmb/MiniCPM-... \
--data data/finetune \
--out artifacts/sat-classifier-lora \
--push your-org/cuentas-claras-sat-classifier
The design mirrors our previous hackathon planner fine-tune (LoRA SFT, ~1-2k pairs),
which is enough for this narrow, well-specified classification task.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
def load_chat_jsonl(path: Path):
    """Load a chat-format JSONL file into a Hugging Face ``Dataset``.

    Each non-blank line must be a JSON object carrying a ``"messages"`` list.
    Only that column is kept, because it is what SFT trains on.

    Args:
        path: JSONL file to read (decoded as UTF-8).

    Returns:
        A ``datasets.Dataset`` with a single ``"messages"`` column.

    Raises:
        ValueError: if a line is not valid JSON, is not an object with a
            ``"messages"`` list, or the file contains no rows at all.
    """
    rows = []
    # Iterate the file rather than using str.splitlines(): splitlines() also
    # breaks on characters such as U+2028 that may legally appear raw inside
    # JSON string values, which would split one record into broken halves.
    with open(path, encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{path}: line {lineno} is not valid JSON: {err}"
                ) from err
            messages = record.get("messages") if isinstance(record, dict) else None
            if not isinstance(messages, list):
                raise ValueError(f"{path}: line {lineno} has no 'messages' list")
            # keep only the chat messages for SFT
            rows.append({"messages": messages})
    if not rows:
        # An empty split would otherwise only fail much later inside the trainer.
        raise ValueError(f"{path}: no chat rows found")
    # Imported lazily (as in the original) so the module imports without `datasets`.
    from datasets import Dataset
    return Dataset.from_list(rows)
def _lora_target_modules(model, override=None):
    """Return the module names LoRA should wrap.

    Args:
        model: the loaded base model (anything exposing ``named_modules()``).
        override: optional comma-separated module names from the CLI; when
            given it is used verbatim instead of auto-detection.

    Returns:
        A list of leaf module names present in ``model``.

    Raises:
        ValueError: if no known attention projection exists in ``model``.
    """
    if override:
        return [name.strip() for name in override.split(",") if name.strip()]
    # Standard attention uses q/k/v/o_proj. Latent-attention variants (the
    # naming used by MiniCPM3 / DeepSeek-V2-style models) have no q/k/v_proj
    # at all, only the low-rank projections below plus o_proj. PEFT only
    # errors when *nothing* matches, so a fixed q/k/v/o list would silently
    # adapt o_proj alone on such models.
    candidates = (
        "q_proj", "k_proj", "v_proj", "o_proj",
        "q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj",
    )
    leaf_names = {name.rsplit(".", 1)[-1] for name, _ in model.named_modules()}
    found = [name for name in candidates if name in leaf_names]
    if not found:
        raise ValueError(
            "could not find attention projection modules to apply LoRA to; "
            "pass --lora_targets explicitly"
        )
    return found


def main():
    """Fine-tune the base model with a LoRA adapter, evaluate, save, and optionally push.

    Reads ``train.jsonl`` / ``val.jsonl`` from ``--data``. Writes the adapter,
    tokenizer and ``eval_metrics.json`` to ``--out``. Pushes adapter and
    tokenizer to ``--push`` when that flag is given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", default="openbmb/MiniCPM3-4B",
                    help="base model repo (≤32B; ≤4B keeps the Tiny Titan option open)")
    ap.add_argument("--data", default="data/finetune")
    ap.add_argument("--out", default="artifacts/sat-classifier-lora")
    ap.add_argument("--push", default=None, help="Hub repo id to push the adapter to")
    ap.add_argument("--epochs", type=float, default=3.0)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--batch", type=int, default=8)
    ap.add_argument("--max_seq", type=int, default=1024)
    ap.add_argument("--grad_accum", type=int, default=2,
                    help="gradient accumulation steps")
    ap.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    ap.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    ap.add_argument("--lora_targets", default=None,
                    help="comma-separated module names to wrap with LoRA "
                         "(default: auto-detect attention projections)")
    args = ap.parse_args()

    # Heavy GPU-side imports are deferred so `--help` works without them.
    import dataclasses

    import torch
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    data_dir = Path(args.data)
    train_ds = load_chat_jsonl(data_dir / "train.jsonl")
    eval_ds = load_chat_jsonl(data_dir / "val.jsonl")

    tokenizer = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # NOTE(review): with pad == eos, collators that mask labels by
        # pad_token_id also mask the real EOS, so the model may not learn to
        # stop. Verify against the installed TRL collator.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.base, trust_remote_code=True,
        torch_dtype=torch.bfloat16, device_map="auto",
    )

    peft_config = LoraConfig(
        r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=0.05, bias="none",
        task_type="CAUSAL_LM",
        target_modules=_lora_target_modules(model, args.lora_targets),
    )

    # Newer TRL renamed SFTConfig.max_seq_length -> max_length. Use whichever
    # field the installed version actually declares.
    config_fields = {f.name for f in dataclasses.fields(SFTConfig)}
    seq_len_key = "max_length" if "max_length" in config_fields else "max_seq_length"

    sft_config = SFTConfig(
        output_dir=args.out,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        bf16=True,  # assumes a bf16-capable GPU, matching the bf16 model load
        report_to="none",
        # SFTTrainer applies the chat template to the "messages" column.
        packing=False,
        **{seq_len_key: args.max_seq},
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
    trainer.train()

    metrics = trainer.evaluate()
    print("eval:", metrics)

    trainer.save_model(args.out)
    tokenizer.save_pretrained(args.out)
    Path(args.out, "eval_metrics.json").write_text(
        json.dumps(metrics, indent=2), encoding="utf-8"
    )

    if args.push:
        # trainer.model is the PEFT-wrapped model, so this pushes the adapter.
        trainer.model.push_to_hub(args.push)
        tokenizer.push_to_hub(args.push)
        print(f"Pushed adapter → https://huggingface.co/{args.push}")
if __name__ == "__main__":
    # Train only when executed as a script (e.g. `python -m scripts.train_classifier`),
    # not when this module is imported.
    main()