"""Fine-tune MiniCPM4.1-8B on the recipe SFT dataset via Modal (A10G GPU). Usage: modal run scripts/train_planner.py After training, the adapter is merged and the full model is pushed to HF Hub as /cook-with-me-planner-8b Set HF_USERNAME below (or export HF_TOKEN env var before running). """ from __future__ import annotations import modal # --------------------------------------------------------------------------- # Config — change these two values # --------------------------------------------------------------------------- HF_USERNAME = "eldinosaur" SFT_DATASET_REPO = f"{HF_USERNAME}/cook-with-me-recipes-sft" OUTPUT_REPO = f"{HF_USERNAME}/cook-with-me-planner-8b" BASE_MODEL = "openbmb/MiniCPM4.1-8B" # --------------------------------------------------------------------------- app = modal.App("cook-with-me-train") volume = modal.Volume.from_name("cook-with-me-train-vol", create_if_missing=True) train_image = ( modal.Image.debian_slim(python_version="3.12") .pip_install( "torch==2.4.0", "transformers>=5.0", "peft>=0.12", "trl>=0.10", "accelerate", "datasets", "huggingface_hub>=1.17", "bitsandbytes", "sentencepiece", "safetensors", ) ) hf_secret = modal.Secret.from_name("huggingface-secret") @app.function( image=train_image, gpu="A10G", timeout=60 * 60 * 3, # 3-hour hard cap secrets=[hf_secret], volumes={"/vol": volume}, ) def train(): import os import torch from datasets import load_dataset from peft import LoraConfig, get_peft_model, TaskType from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments from trl import SFTTrainer, SFTConfig os.environ.setdefault("HF_HOME", "/vol/hf_cache") # MiniCPM4.1-8B custom code references is_torch_fx_available which was # removed in transformers 5.x. Patch it back before loading the model. import transformers.utils.import_utils as _iutils if not hasattr(_iutils, "is_torch_fx_available"): def _is_torch_fx_available(): try: import torch.fx # noqa: F401 return True except ImportError: return False _iutils.is_torch_fx_available = _is_torch_fx_available # ---- Load tokenizer & model ---- print(f"Loading {BASE_MODEL}…") tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cuda", ) # ---- LoRA config ---- lora_cfg = LoraConfig( task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32, lora_dropout=0.05, target_modules="all-linear", bias="none", ) model = get_peft_model(model, lora_cfg) model.print_trainable_parameters() # ---- Dataset ---- print(f"Loading dataset {SFT_DATASET_REPO}…") ds = load_dataset(SFT_DATASET_REPO, split="train") def _format(example): return {"text": tokenizer.apply_chat_template( example["messages"], tokenize=False, add_generation_prompt=False )} ds = ds.map(_format, remove_columns=ds.column_names) # ---- Training ---- output_dir = "/vol/planner_out" trainer = SFTTrainer( model=model, processing_class=tokenizer, train_dataset=ds, args=SFTConfig( output_dir=output_dir, num_train_epochs=3, # 2046 examples — 3 epochs converges without overfitting per_device_train_batch_size=2, gradient_accumulation_steps=4, learning_rate=2e-4, lr_scheduler_type="cosine", warmup_ratio=0.05, bf16=True, logging_steps=20, save_steps=200, max_length=2048, dataset_text_field="text", ), ) trainer.train() trainer.save_model(output_dir) # ---- Merge LoRA + push ---- print("Merging LoRA adapter…") from peft import PeftModel base = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cpu" ) merged = PeftModel.from_pretrained(base, output_dir) merged = merged.merge_and_unload() # MiniCPM custom code declares `_tied_weights_keys` as a list, but # transformers 5.x's save path calls `.keys()` on it. Patch the walker # to tolerate both list and dict formats before saving/pushing. import transformers.modeling_utils as _mu def _safe_get_tied_weight_keys(model, *args, **kwargs): keys = [] for module_name, module in model.named_modules(): tied = getattr(module, "_tied_weights_keys", None) if not tied: continue names = tied.keys() if isinstance(tied, dict) else tied for k in names: keys.append(f"{module_name}.{k}" if module_name else k) return keys _mu._get_tied_weight_keys = _safe_get_tied_weight_keys print(f"Pushing merged model to {OUTPUT_REPO}…") merged.push_to_hub(OUTPUT_REPO, private=False) tokenizer.push_to_hub(OUTPUT_REPO, private=False) print("Done.") @app.local_entrypoint() def main(): train.remote()