Spaces:
Running on Zero
Running on Zero
"""Fine-tune MiniCPM4.1-8B on the recipe SFT dataset via Modal (A10G GPU).

Usage:
    modal run scripts/train_planner.py

After training, the LoRA adapter is merged into the base model and the full
merged model is pushed to the HF Hub as <HF_USERNAME>/cook-with-me-planner-8b.

Before running, set HF_USERNAME below and make sure a Hugging Face token
(HF_TOKEN) is available for the push.
"""
| from __future__ import annotations | |
| import modal | |
# ---------------------------------------------------------------------------
# Config — edit HF_USERNAME (and BASE_MODEL to fine-tune a different
# checkpoint). Both repo ids below are derived from HF_USERNAME.
# ---------------------------------------------------------------------------
HF_USERNAME = "eldinosaur"
BASE_MODEL = "openbmb/MiniCPM4.1-8B"

# Hub repo the SFT dataset is loaded from.
SFT_DATASET_REPO = HF_USERNAME + "/cook-with-me-recipes-sft"
# Hub repo the merged model and tokenizer are pushed to.
OUTPUT_REPO = HF_USERNAME + "/cook-with-me-planner-8b"
# ---------------------------------------------------------------------------
# Modal application object for this script.
app = modal.App("cook-with-me-train")

# Named Modal volume, created on first use. The training code reads and writes
# paths under /vol (HF cache, trainer output) — presumably this volume's mount
# point; confirm against the function definition.
volume = modal.Volume.from_name("cook-with-me-train-vol", create_if_missing=True)

# Python packages installed into the training image.
_TRAIN_REQUIREMENTS = (
    "torch==2.4.0",
    "transformers>=5.0",
    "peft>=0.12",
    "trl>=0.10",
    "accelerate",
    "datasets",
    "huggingface_hub>=1.17",
    "bitsandbytes",
    "sentencepiece",
    "safetensors",
)

# Debian-slim image on Python 3.12 with the training stack installed via pip.
_base_image = modal.Image.debian_slim(python_version="3.12")
train_image = _base_image.pip_install(*_TRAIN_REQUIREMENTS)

# Modal secret looked up by name; expected to carry the Hugging Face token
# (NOTE(review): confirm the secret exposes HF_TOKEN).
hf_secret = modal.Secret.from_name("huggingface-secret")
@app.function(
    image=train_image,
    gpu="A10G",
    # The body reads/writes /vol/hf_cache and /vol/planner_out.
    volumes={"/vol": volume},
    secrets=[hf_secret],
    # Modal's default function timeout is only a few minutes; a multi-epoch
    # fine-tune plus merge and upload needs hours. Adjust as required.
    timeout=6 * 60 * 60,
)
def train():
    """Fine-tune BASE_MODEL with LoRA, merge the adapter, and push to the Hub.

    Runs as a Modal function. Steps, in order:

    1. Point the Hugging Face cache at the mounted volume.
    2. Re-add ``is_torch_fx_available`` to transformers if it is missing
       (the MiniCPM remote code references it).
    3. Load the tokenizer and the bf16 model on the GPU and wrap the model
       with a LoRA adapter (r=16, alpha=32, all linear layers).
    4. Load ``SFT_DATASET_REPO`` and render each example's ``messages``
       through the tokenizer's chat template into a ``text`` column.
    5. Train with TRL's ``SFTTrainer`` and save the adapter to
       ``/vol/planner_out``.
    6. Reload the base model on CPU, merge the adapter into it, and push the
       merged model and tokenizer to ``OUTPUT_REPO``.

    Returns:
        None. Results are the files under ``/vol/planner_out`` and the
        uploaded Hub repository.
    """
    import os

    # Must happen BEFORE importing datasets/transformers: huggingface_hub
    # resolves its cache paths from HF_HOME at import time, so setting it
    # afterwards does not move the cache onto the volume.
    os.environ.setdefault("HF_HOME", "/vol/hf_cache")

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    # MiniCPM4.1-8B custom code references is_torch_fx_available which was
    # removed in transformers 5.x. Patch it back before loading the model.
    import transformers.utils.import_utils as _iutils

    if not hasattr(_iutils, "is_torch_fx_available"):

        def _is_torch_fx_available():
            try:
                import torch.fx  # noqa: F401
            except ImportError:
                return False
            return True

        _iutils.is_torch_fx_available = _is_torch_fx_available

    # ---- Load tokenizer & model ----
    print(f"Loading {BASE_MODEL}…")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cuda",
    )

    # ---- LoRA config ----
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules="all-linear",
        bias="none",
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    # ---- Dataset ----
    print(f"Loading dataset {SFT_DATASET_REPO}…")
    ds = load_dataset(SFT_DATASET_REPO, split="train")

    def _format(example):
        # Render the chat turns into one training string; no generation
        # prompt is appended because the assistant turn is already present.
        rendered = tokenizer.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False
        )
        return {"text": rendered}

    # Drop every original column so only "text" remains.
    ds = ds.map(_format, remove_columns=ds.column_names)

    # ---- Training ----
    output_dir = "/vol/planner_out"
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        args=SFTConfig(
            output_dir=output_dir,
            num_train_epochs=3,  # 2046 examples — 3 epochs converges without overfitting
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            lr_scheduler_type="cosine",
            warmup_ratio=0.05,
            bf16=True,
            logging_steps=20,
            save_steps=200,
            max_length=2048,
            dataset_text_field="text",
        ),
    )
    trainer.train()
    trainer.save_model(output_dir)

    # Persist the adapter now so a failure in the merge/push stage below
    # does not risk losing the training result.
    volume.commit()

    # ---- Merge LoRA + push ----
    print("Merging LoRA adapter…")
    # A fresh copy of the base weights is loaded on CPU and the saved adapter
    # is applied to it, then folded into the base weights.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cpu"
    )
    merged = PeftModel.from_pretrained(base, output_dir)
    merged = merged.merge_and_unload()

    # MiniCPM custom code declares `_tied_weights_keys` as a list, but
    # transformers 5.x's save path calls `.keys()` on it. Patch the walker
    # to tolerate both list and dict formats before saving/pushing.
    import transformers.modeling_utils as _mu

    def _safe_get_tied_weight_keys(model, *args, **kwargs):
        keys = []
        for module_name, module in model.named_modules():
            tied = getattr(module, "_tied_weights_keys", None)
            if not tied:
                continue
            # Iterating a dict yields its keys, so list and dict declarations
            # go through the same path.
            keys.extend(f"{module_name}.{k}" if module_name else k for k in tied)
        return keys

    _mu._get_tied_weight_keys = _safe_get_tied_weight_keys

    print(f"Pushing merged model to {OUTPUT_REPO}…")
    merged.push_to_hub(OUTPUT_REPO, private=False)
    tokenizer.push_to_hub(OUTPUT_REPO, private=False)
    print("Done.")


@app.local_entrypoint()
def main():
    """Entry point for ``modal run``: launch ``train`` remotely and wait for it."""
    train.remote()