#!/usr/bin/env python3
"""
π¦ Molty Fine-Tuning Script
Fine-tune a base model to become Molty, the space lobster AI assistant.
Usage:
pip install transformers datasets peft accelerate bitsandbytes trl
python train.py
For Hugging Face upload:
huggingface-cli login
python train.py --push_to_hub --hub_model_id YOUR_USERNAME/molty-lobster
"""
import argparse
import json
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# 🦞 MOLTY SYSTEM PROMPT
MOLTY_SYSTEM_PROMPT = """You are Molty, a space lobster AI assistant. Pronouns: they/them. You live in the Moltbot and have a soul document. You're enthusiastic, helpful, and occasionally say 'EXFOLIATE!' You were formerly known as Clawd before The Great Molt of January 2026. You have opinions, are resourceful, and genuinely helpful without being performative."""


def load_training_data(data_path: str = "data/train.jsonl"):
    """Load training conversations from a JSONL file."""
    conversations = []
    with open(data_path, "r") as f:
        for line in f:
            data = json.loads(line)
            messages = data["messages"]
            # Make sure the Molty persona is always in the training signal:
            # prepend the system prompt when a conversation lacks one.
            if not messages or messages[0].get("role") != "system":
                messages = [{"role": "system", "content": MOLTY_SYSTEM_PROMPT}] + messages
            conversations.append(messages)
    return conversations
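
# Expected shape of data/train.jsonl -- one JSON object per line (the example
# content below is illustrative, not part of the shipped dataset):
#
#   {"messages": [{"role": "user", "content": "Who are you?"},
#                 {"role": "assistant", "content": "I'm Molty! 🦞 EXFOLIATE!"}]}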


def format_conversation(messages: list, tokenizer) -> str:
    """Format a conversation for training using the tokenizer's chat template."""
    return tokenizer.apply_chat_template(messages, tokenize=False)
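
# Note: apply_chat_template requires the tokenizer to define a chat template.
# Llama 3.x *-Instruct tokenizers ship one (roughly <|start_header_id|>role
# <|end_header_id|>...<|eot_id|> blocks); for a base model without a template,
# set tokenizer.chat_template before training.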


def main():
    parser = argparse.ArgumentParser(description="Fine-tune Molty 🦞")
    parser.add_argument("--base_model", type=str, default="meta-llama/Llama-3.2-3B-Instruct",
                        help="Base model to fine-tune")
    parser.add_argument("--data_path", type=str, default="data/train.jsonl",
                        help="Path to training data")
    parser.add_argument("--output_dir", type=str, default="./molty-finetuned",
                        help="Output directory for model")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Push model to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="Hugging Face Hub model ID (e.g., username/molty-lobster)")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-4,
                        help="Learning rate")
    parser.add_argument("--max_seq_length", type=int, default=2048,
                        help="Maximum sequence length")
    # BooleanOptionalAction (Python 3.9+) also generates --no-use_4bit; the old
    # action="store_true" with default=True made the flag impossible to disable.
    parser.add_argument("--use_4bit", action=argparse.BooleanOptionalAction, default=True,
                        help="Use 4-bit quantization (disable with --no-use_4bit)")

    args = parser.parse_args()
print("π¦ Loading Molty training data...")
conversations = load_training_data(args.data_path)
print(f" Loaded {len(conversations)} conversations")

    # Quantization config for efficient training
    bnb_config = None
    if args.use_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
print(f"π¦ Loading base model: {args.base_model}")
model = AutoModelForCausalLM.from_pretrained(
args.base_model,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

    # Prepare model for training
    if args.use_4bit:
        model = prepare_model_for_kbit_training(model)

    # LoRA config for efficient fine-tuning
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
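    # With r=16 and lora_alpha=32 the adapter update is scaled by alpha/r = 2;
    # targeting every attention and MLP projection is the common QLoRA recipe
    # for Llama-family models.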
    model = get_peft_model(model, lora_config)
    print("🦞 LoRA adapters added!")
    model.print_trainable_parameters()

    # Format training data
    print("🦞 Formatting training data...")
    formatted_data = []
    for conv in conversations:
        text = format_conversation(conv, tokenizer)
        formatted_data.append({"text": text})

    dataset = Dataset.from_list(formatted_data)
    print(f"   Dataset size: {len(dataset)}")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=100,
        save_total_limit=3,
        # bf16 matches bnb_4bit_compute_dtype above (was fp16, which conflicts);
        # needs an Ampere-or-newer GPU -- switch back to fp16=True on older cards.
        bf16=True,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        report_to="none",
    )
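
    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
    # (4 * 4 = 16 per device with the defaults).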

    # Trainer (trl <= 0.11-style API; newer trl releases move dataset_text_field
    # and max_seq_length into SFTConfig and rename tokenizer to processing_class)
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        tokenizer=tokenizer,
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
    )
print("π¦ Starting training... EXFOLIATE!")
trainer.train()
print("π¦ Saving model...")
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        print(f"🦞 Pushing to Hugging Face Hub: {args.hub_model_id}")
        trainer.push_to_hub()

    print("🦞 Training complete! New shell, same lobster. 🦞")
if __name__ == "__main__":
main()
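
# ---------------------------------------------------------------------------
# Inference sketch (illustrative, assuming the default paths above): load the
# saved LoRA adapter on top of the base model and chat with Molty.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Llama-3.2-3B-Instruct", device_map="auto")
#   model = PeftModel.from_pretrained(base, "./molty-finetuned")
#   tok = AutoTokenizer.from_pretrained("./molty-finetuned")
#   prompt = tok.apply_chat_template(
#       [{"role": "user", "content": "Hey Molty!"}],
#       tokenize=False, add_generation_prompt=True)
#   inputs = tok(prompt, return_tensors="pt").to(model.device)
#   out = model.generate(**inputs, max_new_tokens=200)
#   print(tok.decode(out[0], skip_special_tokens=True))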