#!/usr/bin/env python3
"""
🦞 Molty Fine-Tuning Script
Fine-tune a base model to become Molty, the space lobster AI assistant.

Usage:
  pip install transformers datasets peft accelerate bitsandbytes trl
  python train.py

For Hugging Face upload:
  huggingface-cli login
  python train.py --push_to_hub --hub_model_id YOUR_USERNAME/molty-lobster
"""

import argparse
import json
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# 🦞 MOLTY SYSTEM PROMPT
MOLTY_SYSTEM_PROMPT = """You are Molty, a space lobster AI assistant. Pronouns: they/them. You live in the Moltbot and have a soul document. You're enthusiastic, helpful, and occasionally say 'EXFOLIATE!' You were formerly known as Clawd before The Great Molt of January 2026. You have opinions, are resourceful, and genuinely helpful without being performative."""


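# Each line of data/train.jsonl is expected to be a JSON object with a "messages"
# list of role/content turns, for example (illustrative sample, not real data):
#   {"messages": [{"role": "system", "content": "...Molty system prompt..."},
#                 {"role": "user", "content": "Who are you?"},
#                 {"role": "assistant", "content": "I'm Molty, the space lobster! EXFOLIATE!"}]}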
def load_training_data(data_path: str = "data/train.jsonl"):
    """Load training data from JSONL file."""
    conversations = []
    with open(data_path, "r") as f:
        for line in f:
            data = json.loads(line)
            conversations.append(data["messages"])
    return conversations


def format_conversation(messages: list, tokenizer) -> str:
    """Format conversation for training."""
    return tokenizer.apply_chat_template(messages, tokenize=False)


def main():
    parser = argparse.ArgumentParser(description="Fine-tune Molty 🦞")
    parser.add_argument("--base_model", type=str, default="meta-llama/Llama-3.2-3B-Instruct",
                        help="Base model to fine-tune")
    parser.add_argument("--data_path", type=str, default="data/train.jsonl",
                        help="Path to training data")
    parser.add_argument("--output_dir", type=str, default="./molty-finetuned",
                        help="Output directory for model")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Push model to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="Hugging Face Hub model ID (e.g., username/molty-lobster)")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-4,
                        help="Learning rate")
    parser.add_argument("--max_seq_length", type=int, default=2048,
                        help="Maximum sequence length")
    parser.add_argument("--use_4bit", action="store_true", default=True,
                        help="Use 4-bit quantization")
    args = parser.parse_args()

    print("🦞 Loading Molty training data...")
    conversations = load_training_data(args.data_path)
    print(f"   Loaded {len(conversations)} conversations")

    # Quantization config for efficient training
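    # NF4 4-bit weights with double quantization; matmuls are computed in bfloat16.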
    bnb_config = None
    if args.use_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    print(f"🦞 Loading base model: {args.base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
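    # Llama-family tokenizers ship without a pad token; reuse EOS so right-side padding works.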
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Prepare model for training
    if args.use_4bit:
        model = prepare_model_for_kbit_training(model)

    # LoRA config for efficient fine-tuning
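    # r/lora_alpha set adapter capacity and scaling; target_modules cover the
    # attention and MLP projections of Llama-style transformer blocks.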
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    model = get_peft_model(model, lora_config)
    print("🦞 LoRA adapters added!")
    model.print_trainable_parameters()

    # Format training data
    print("🦞 Formatting training data...")
    formatted_data = []
    for conv in conversations:
        text = format_conversation(conv, tokenizer)
        formatted_data.append({"text": text})
    
    dataset = Dataset.from_list(formatted_data)
    print(f"   Dataset size: {len(dataset)}")

    # Training arguments
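    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
    # = 4 * 4 = 16 with the defaults.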
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=100,
        save_total_limit=3,
        bf16=True,  # match the bfloat16 compute dtype used by the 4-bit config
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        report_to="none",
    )

    # Trainer
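    # Note: passing tokenizer / dataset_text_field / max_seq_length directly matches
    # older trl releases; recent trl versions expect these on an SFTConfig instead.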
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        tokenizer=tokenizer,
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
    )

    print("🦞 Starting training... EXFOLIATE!")
    trainer.train()

    print("🦞 Saving model...")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        print(f"🦞 Pushing to Hugging Face Hub: {args.hub_model_id}")
        trainer.push_to_hub()

    print("🦞 Training complete! New shell, same lobster. 🦞")


if __name__ == "__main__":
    main()
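
# A quick post-training smoke test (a minimal sketch, assuming the default
# --output_dir and a peft release that provides AutoPeftModelForCausalLM):
#
#   from peft import AutoPeftModelForCausalLM
#   model = AutoPeftModelForCausalLM.from_pretrained("./molty-finetuned", device_map="auto")
#   tok = AutoTokenizer.from_pretrained("./molty-finetuned")
#   prompt = tok.apply_chat_template(
#       [{"role": "system", "content": MOLTY_SYSTEM_PROMPT},
#        {"role": "user", "content": "Who are you?"}],
#       tokenize=False, add_generation_prompt=True)
#   inputs = tok(prompt, return_tensors="pt").to(model.device)
#   print(tok.decode(model.generate(**inputs, max_new_tokens=100)[0], skip_special_tokens=True))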