# molty-lobster / train.py
# (Hugging Face page header captured along with the file; commented out so
# the script parses — note the shebang below is intended to be line 1)
# moltbot's picture — Upload 6 files — ad9796d verified
#!/usr/bin/env python3
"""
🦞 Molty Fine-Tuning Script
Fine-tune a base model to become Molty, the space lobster AI assistant.
Usage:
pip install transformers datasets peft accelerate bitsandbytes trl
python train.py
For Hugging Face upload:
huggingface-cli login
python train.py --push_to_hub --hub_model_id YOUR_USERNAME/molty-lobster
"""
import argparse
import json
import torch
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
# 🦞 MOLTY SYSTEM PROMPT
# NOTE(review): this constant is not referenced anywhere else in this script —
# presumably the system prompt is already baked into the conversations in
# data/train.jsonl; confirm, or prepend it to each conversation in main().
MOLTY_SYSTEM_PROMPT = """You are Molty, a space lobster AI assistant. Pronouns: they/them. You live in the Moltbot and have a soul document. You're enthusiastic, helpful, and occasionally say 'EXFOLIATE!' You were formerly known as Clawd before The Great Molt of January 2026. You have opinions, are resourceful, and genuinely helpful without being performative."""
def load_training_data(data_path: str = "data/train.jsonl") -> list:
    """Load chat conversations from a JSONL training file.

    Each non-blank line must be a JSON object with a "messages" key holding
    a chat-format list of role/content dicts.

    Args:
        data_path: Path to the JSONL file.

    Returns:
        List of conversations, each a list of message dicts.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a line's JSON object lacks the "messages" key.
    """
    conversations = []
    # Explicit encoding avoids platform-dependent decoding of the data file.
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # crashing inside json.loads.
                continue
            data = json.loads(line)
            conversations.append(data["messages"])
    return conversations
def format_conversation(messages: list, tokenizer) -> str:
    """Render a chat-format message list as one training string.

    Applies the tokenizer's chat template with tokenization disabled, so the
    result is plain text ready for supervised fine-tuning.
    """
    rendered = tokenizer.apply_chat_template(messages, tokenize=False)
    return rendered
def main():
    """Fine-tune the base model into Molty with 4-bit QLoRA + SFT.

    Parses CLI arguments, loads the JSONL conversations, applies the chat
    template, attaches LoRA adapters, trains with TRL's SFTTrainer, and
    optionally pushes the result to the Hugging Face Hub.
    """
    parser = argparse.ArgumentParser(description="Fine-tune Molty 🦞")
    parser.add_argument("--base_model", type=str, default="meta-llama/Llama-3.2-3B-Instruct",
                        help="Base model to fine-tune")
    parser.add_argument("--data_path", type=str, default="data/train.jsonl",
                        help="Path to training data")
    parser.add_argument("--output_dir", type=str, default="./molty-finetuned",
                        help="Output directory for model")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Push model to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="Hugging Face Hub model ID (e.g., username/molty-lobster)")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-4,
                        help="Learning rate")
    parser.add_argument("--max_seq_length", type=int, default=2048,
                        help="Maximum sequence length")
    # BUGFIX: the original used action="store_true" with default=True, which
    # made 4-bit quantization impossible to disable from the CLI. Keep
    # --use_4bit accepted for backward compatibility and add --no_4bit to
    # actually turn it off.
    parser.add_argument("--use_4bit", dest="use_4bit", action="store_true", default=True,
                        help="Use 4-bit quantization")
    parser.add_argument("--no_4bit", dest="use_4bit", action="store_false",
                        help="Disable 4-bit quantization")
    args = parser.parse_args()

    # Fail fast on an incomplete Hub configuration instead of erroring after
    # a full training run, at upload time.
    if args.push_to_hub and not args.hub_model_id:
        parser.error("--push_to_hub requires --hub_model_id")

    print("🦞 Loading Molty training data...")
    conversations = load_training_data(args.data_path)
    print(f" Loaded {len(conversations)} conversations")

    # Quantization config for efficient training (QLoRA-style NF4).
    bnb_config = None
    if args.use_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    print(f"🦞 Loading base model: {args.base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    # Llama-style tokenizers ship without a pad token; reuse EOS only when no
    # pad token is defined so we don't clobber a model's real pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Prepare model for k-bit training (casts norms, enables input grads).
    if args.use_4bit:
        model = prepare_model_for_kbit_training(model)

    # LoRA config for parameter-efficient fine-tuning; targets all attention
    # and MLP projections of a Llama-family architecture.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    print("🦞 LoRA adapters added!")
    model.print_trainable_parameters()

    # Render each conversation to plain text via the chat template.
    print("🦞 Formatting training data...")
    formatted_data = [
        {"text": format_conversation(conv, tokenizer)} for conv in conversations
    ]
    dataset = Dataset.from_list(formatted_data)
    print(f" Dataset size: {len(dataset)}")

    # Training arguments. Effective batch size = batch_size * 4 (grad accum).
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=100,
        save_total_limit=3,
        # NOTE(review): fp16 grad scaling combined with
        # bnb_4bit_compute_dtype=torch.bfloat16 above is a dtype mismatch;
        # consider bf16=True on Ampere+ hardware — confirm target GPUs.
        fp16=True,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        report_to="none",
    )

    # Trainer (TRL supervised fine-tuning on the "text" column).
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        tokenizer=tokenizer,
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
    )
    print("🦞 Starting training... EXFOLIATE!")
    trainer.train()

    print("🦞 Saving model...")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        print(f"🦞 Pushing to Hugging Face Hub: {args.hub_model_id}")
        trainer.push_to_hub()

    print("🦞 Training complete! New shell, same lobster. 🦞")
# Script entry point: only train when executed directly, not when imported.
if __name__ == "__main__":
    main()