# Cook_with_a_LLM / scripts / train_planner.py
# Fred1e4's picture
# Complete Cook App (#5)
# 75c5414
"""Fine-tune MiniCPM4.1-8B on the recipe SFT dataset via Modal (A10G GPU).
Usage:
modal run scripts/train_planner.py
After training, the adapter is merged and the full model is pushed to HF Hub
as <HF_USERNAME>/cook-with-me-planner-8b
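Set HF_USERNAME below. Hub credentials are not read from your local shell:
they come from the Modal secret `huggingface-secret`, which is expected to
provide HF_TOKEN inside the training container.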
"""
from __future__ import annotations
import modal
# ---------------------------------------------------------------------------
# Config — set HF_USERNAME to your Hub account; both repo ids are derived
# from it. BASE_MODEL is the checkpoint that gets fine-tuned.
# ---------------------------------------------------------------------------
HF_USERNAME = "eldinosaur"
BASE_MODEL = "openbmb/MiniCPM4.1-8B"

# Dataset repo read by the training job.
SFT_DATASET_REPO = f"{HF_USERNAME}/cook-with-me-recipes-sft"
# Repo that receives the merged model and the tokenizer.
OUTPUT_REPO = f"{HF_USERNAME}/cook-with-me-planner-8b"
# ---------------------------------------------------------------------------
# Modal application that owns the training function and the local entrypoint.
app = modal.App("cook-with-me-train")

# Named secret attached to the training function.
# NOTE(review): presumably provides HF_TOKEN for Hub access — confirm in Modal.
hf_secret = modal.Secret.from_name("huggingface-secret")

# Named volume, created on first use; the training function mounts it at /vol.
volume = modal.Volume.from_name("cook-with-me-train-vol", create_if_missing=True)

# Python packages installed into the training container image.
_TRAIN_REQUIREMENTS = (
    "torch==2.4.0",
    "transformers>=5.0",
    "peft>=0.12",
    "trl>=0.10",
    "accelerate",
    "datasets",
    "huggingface_hub>=1.17",
    "bitsandbytes",
    "sentencepiece",
    "safetensors",
)

# Debian-slim / Python 3.12 base image with the training stack on top.
_base_image = modal.Image.debian_slim(python_version="3.12")
train_image = _base_image.pip_install(*_TRAIN_REQUIREMENTS)
@app.function(
    image=train_image,
    gpu="A10G",
    timeout=60 * 60 * 3,  # 3-hour hard cap
    secrets=[hf_secret],
    volumes={"/vol": volume},
)
def train():
    """LoRA fine-tune BASE_MODEL on the SFT dataset, merge, and push to the Hub.

    Steps, in order:

    1. Load the tokenizer and the base model (bfloat16, on the GPU).
    2. Wrap the model with a LoRA adapter (r=16, all linear layers).
    3. Render each example's ``messages`` through the tokenizer's chat
       template into a single ``text`` column.
    4. Train with ``SFTTrainer`` and save the result to ``/vol/planner_out``.
    5. Reload the base model on CPU, merge the saved adapter into it, and
       push the merged model and the tokenizer to ``OUTPUT_REPO``.

    Heavy imports are done inside the function so they only have to exist
    in the container image, not on the machine that launches the job.

    Returns:
        None.
    """
    import os

    # HF_HOME has to be set *before* any Hugging Face library is imported:
    # huggingface_hub resolves its cache paths from the environment at import
    # time, so setting it afterwards leaves the cache on the container's
    # ephemeral disk and the base model is re-downloaded on every run.
    os.environ.setdefault("HF_HOME", "/vol/hf_cache")

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    # MiniCPM4.1-8B custom code references is_torch_fx_available which was
    # removed in transformers 5.x. Patch it back before loading the model.
    import transformers.utils.import_utils as _iutils

    if not hasattr(_iutils, "is_torch_fx_available"):

        def _is_torch_fx_available():
            """Report whether ``torch.fx`` can be imported."""
            try:
                import torch.fx  # noqa: F401

                return True
            except ImportError:
                return False

        _iutils.is_torch_fx_available = _is_torch_fx_available

    # ---- Load tokenizer & model ----
    print(f"Loading {BASE_MODEL}…")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS as the padding token when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cuda",
    )

    # ---- LoRA config ----
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules="all-linear",
        bias="none",
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    # ---- Dataset ----
    print(f"Loading dataset {SFT_DATASET_REPO}…")
    ds = load_dataset(SFT_DATASET_REPO, split="train")

    def _format(example):
        """Render one example's ``messages`` into a single ``text`` string."""
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"], tokenize=False, add_generation_prompt=False
            )
        }

    # Drop every original column so only the rendered "text" field remains.
    ds = ds.map(_format, remove_columns=ds.column_names)

    # ---- Training ----
    output_dir = "/vol/planner_out"
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        args=SFTConfig(
            output_dir=output_dir,
            num_train_epochs=3,  # 2046 examples — 3 epochs converges without overfitting
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            lr_scheduler_type="cosine",
            warmup_ratio=0.05,
            bf16=True,
            logging_steps=20,
            save_steps=200,
            max_length=2048,
            dataset_text_field="text",
        ),
    )
    trainer.train()
    trainer.save_model(output_dir)

    # ---- Merge LoRA + push ----
    print("Merging LoRA adapter…")
    # A fresh copy of the base model is loaded on CPU and the adapter saved
    # above is applied to it, rather than merging the in-GPU training model.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cpu"
    )
    merged = PeftModel.from_pretrained(base, output_dir)
    merged = merged.merge_and_unload()

    # MiniCPM custom code declares `_tied_weights_keys` as a list, but
    # transformers 5.x's save path calls `.keys()` on it. Patch the walker
    # to tolerate both list and dict formats before saving/pushing.
    import transformers.modeling_utils as _mu

    def _safe_get_tied_weight_keys(model, *args, **kwargs):
        """Collect fully-qualified tied-weight names from every submodule.

        Accepts ``_tied_weights_keys`` declared as either a dict (its keys
        are used) or a plain iterable of names. Extra positional and keyword
        arguments are accepted and ignored.
        """
        keys = []
        for module_name, module in model.named_modules():
            tied = getattr(module, "_tied_weights_keys", None)
            if not tied:
                continue
            names = tied.keys() if isinstance(tied, dict) else tied
            for k in names:
                # The root module has an empty name, so no prefix is added.
                keys.append(f"{module_name}.{k}" if module_name else k)
        return keys

    _mu._get_tied_weight_keys = _safe_get_tied_weight_keys

    print(f"Pushing merged model to {OUTPUT_REPO}…")
    merged.push_to_hub(OUTPUT_REPO, private=False)
    tokenizer.push_to_hub(OUTPUT_REPO, private=False)
    print("Done.")
@app.local_entrypoint()
def main():
    """Local entrypoint: invoke the ``train`` function remotely on Modal."""
    train.remote()