File size: 5,487 Bytes
75c5414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Fine-tune MiniCPM4.1-8B on the recipe SFT dataset via Modal (A10G GPU).

Usage:
    modal run scripts/train_planner.py

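After training, the LoRA adapter is merged into the base model and the full
model is pushed to the HF Hub as <HF_USERNAME>/cook-with-me-planner-8b.

Set HF_USERNAME below. The remote function gets its Hugging Face credentials
from the Modal secret named 'huggingface-secret' (expected to provide HF_TOKEN).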
"""
from __future__ import annotations

import modal

# ---------------------------------------------------------------------------
# Config — change these two values (presumably HF_USERNAME and BASE_MODEL;
# the dataset and output repo ids below are derived from HF_USERNAME).
# ---------------------------------------------------------------------------
HF_USERNAME = "eldinosaur"
SFT_DATASET_REPO = f"{HF_USERNAME}/cook-with-me-recipes-sft"
OUTPUT_REPO = f"{HF_USERNAME}/cook-with-me-planner-8b"
BASE_MODEL = "openbmb/MiniCPM4.1-8B"
# ---------------------------------------------------------------------------

# Modal app that the remote `train` function and the local entrypoint attach to.
app = modal.App("cook-with-me-train")

# Named Modal volume (created if it does not exist yet); `train` mounts it at
# /vol and uses it for the HF cache directory and the training output.
volume = modal.Volume.from_name("cook-with-me-train-vol", create_if_missing=True)

# Container image used by `train`: Debian slim with Python 3.12 plus the
# training stack installed via pip.
train_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "torch==2.4.0",
        "transformers>=5.0",
        "peft>=0.12",
        "trl>=0.10",
        "accelerate",
        "datasets",
        "huggingface_hub>=1.17",
        "bitsandbytes",
        "sentencepiece",
        "safetensors",
    )
)

# Modal secret passed to `train`; presumably carries the Hugging Face token
# needed for the Hub pushes — confirm against the secret's contents.
hf_secret = modal.Secret.from_name("huggingface-secret")


@app.function(
    image=train_image,
    gpu="A10G",
    timeout=60 * 60 * 3,          # 3-hour hard cap
    secrets=[hf_secret],
    volumes={"/vol": volume},
)
def train():
    """Fine-tune BASE_MODEL with LoRA on the SFT dataset, merge, and push.

    Steps, in order:
      1. Load the tokenizer and the bf16 base model onto the GPU.
      2. Wrap the model with a LoRA adapter (r=16, all linear layers).
      3. Load SFT_DATASET_REPO and render each example's ``messages`` into a
         single ``text`` field with the tokenizer's chat template.
      4. Train with TRL's SFTTrainer and save the adapter to /vol/planner_out.
      5. Reload the base model on CPU, merge the saved adapter into it, and
         push the merged model and tokenizer to OUTPUT_REPO.

    Takes no arguments and returns nothing; all results are side effects
    (files on the mounted volume and uploads to the Hub).
    """
    import os

    # HF_HOME must be set BEFORE any Hugging Face library is imported:
    # huggingface_hub resolves its cache paths at import time, so setting it
    # afterwards would leave the cache off the volume and re-download the
    # base model on every run.
    os.environ.setdefault("HF_HOME", "/vol/hf_cache")

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    # MiniCPM4.1-8B custom code references is_torch_fx_available which was
    # removed in transformers 5.x. Patch it back before loading the model.
    import transformers.utils.import_utils as _iutils
    if not hasattr(_iutils, "is_torch_fx_available"):
        def _is_torch_fx_available():
            try:
                import torch.fx  # noqa: F401
                return True
            except ImportError:
                return False
        _iutils.is_torch_fx_available = _is_torch_fx_available

    # ---- Load tokenizer & model ----
    print(f"Loading {BASE_MODEL}…")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cuda",
    )

    # ---- LoRA config ----
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules="all-linear",
        bias="none",
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    # ---- Dataset ----
    print(f"Loading dataset {SFT_DATASET_REPO}…")
    ds = load_dataset(SFT_DATASET_REPO, split="train")

    def _format(example):
        # Render the chat messages to plain text; no generation prompt is
        # appended because the rendered conversation is the training text.
        return {"text": tokenizer.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False
        )}

    # Drop every original column so only the rendered "text" field remains.
    ds = ds.map(_format, remove_columns=ds.column_names)

    # ---- Training ----
    output_dir = "/vol/planner_out"
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        args=SFTConfig(
            output_dir=output_dir,
            num_train_epochs=3,   # 2046 examples — 3 epochs converges without overfitting
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            lr_scheduler_type="cosine",
            warmup_ratio=0.05,
            bf16=True,
            logging_steps=20,
            save_steps=200,
            max_length=2048,
            dataset_text_field="text",
        ),
    )
    trainer.train()
    trainer.save_model(output_dir)

    # ---- Merge LoRA + push ----
    print("Merging LoRA adapter…")

    # Reload a clean copy of the base model on CPU and fold the saved adapter
    # into its weights, yielding a plain (non-PEFT) model to upload.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cpu"
    )
    merged = PeftModel.from_pretrained(base, output_dir)
    merged = merged.merge_and_unload()

    # MiniCPM custom code declares `_tied_weights_keys` as a list, but
    # transformers 5.x's save path calls `.keys()` on it. Patch the walker
    # to tolerate both list and dict formats before saving/pushing.
    import transformers.modeling_utils as _mu

    def _safe_get_tied_weight_keys(model, *args, **kwargs):
        keys = []
        for module_name, module in model.named_modules():
            tied = getattr(module, "_tied_weights_keys", None)
            if not tied:
                continue
            names = tied.keys() if isinstance(tied, dict) else tied
            for k in names:
                # Qualify with the module path, except for the root module.
                keys.append(f"{module_name}.{k}" if module_name else k)
        return keys

    _mu._get_tied_weight_keys = _safe_get_tied_weight_keys

    print(f"Pushing merged model to {OUTPUT_REPO}…")
    merged.push_to_hub(OUTPUT_REPO, private=False)
    tokenizer.push_to_hub(OUTPUT_REPO, private=False)
    print("Done.")


@app.local_entrypoint()
def main() -> None:
    """Local entrypoint for `modal run`: invoke the `train` function remotely."""
    train.remote()