# TrashCollector / train_unsloth.py
# Mihir Mithani
# Sync Hub-enabled code to Space (no weights)
# a8d4cdf
"""
Fine-tuning Llama-3.2-3B-Instruct with Unsloth for the Garbage Collecting Robot.
Training data: fixed_dataset.jsonl (generated by code2.py + fixer.py)
Format: {"user": "### Instruction:\n...\n\n### Input:\nENVIRONMENT STATUS:\n...", "assistant": "UP|DOWN|LEFT|RIGHT|COLLECT"}
Base model: unsloth/llama-3.2-3b-instruct-bnb-4bit (same as Unsloth Studio run)
Export: lora_garbage_robot/ (LoRA adapter)
"""
import os
import json
from datasets import Dataset
# ── Model-loading settings (forwarded to FastLanguageModel.from_pretrained) ──
# Sequence cap: the prompts are short, so 512 sits well above the longest sample.
max_seq_length: int = 512
# None means auto-detect (float16 on T4, bfloat16 on Ampere+).
dtype = None
load_in_4bit: bool = True

# ── Alpaca prompt — MUST match fixed_dataset.jsonl / code2.py / app.py ──────
# Placeholders: {instruction}, {input} (the environment status), {response}.
ALPACA_TEMPLATE: str = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\nENVIRONMENT STATUS:\n{input}\n\n"
    "### Response:\n{response}"
)

# Fixed instruction text substituted into every training sample.
INSTRUCTION: str = (
    "You are an AI brain controlling a garbage collecting robot.\n"
    "Reply with EXACTLY ONE of: UP DOWN LEFT RIGHT COLLECT"
)

# Placeholder only: main() overwrites this with tokenizer.eos_token once the
# tokenizer has been loaded, before the dataset is formatted.
EOS_TOKEN = None
def load_fixed_dataset(path: str = "fixed_dataset.jsonl") -> Dataset:
    """Build the training set from the JSONL file produced by fixer.py.

    Each non-blank line is expected to be a JSON object shaped like
    ``{"user": "<### Instruction:...### Input:...>", "assistant": "<ACTION>"}``.
    The text following the ``ENVIRONMENT STATUS:`` marker in ``user`` is
    re-wrapped with ``ALPACA_TEMPLATE`` and ``INSTRUCTION`` so the model sees
    input and target in one string, followed by ``EOS_TOKEN`` when it is set.

    Blank lines are ignored. Rows that are not JSON objects, lack the
    ``user``/``assistant`` keys, or lack the status marker are skipped, and
    the number of skipped rows is printed.

    Args:
        path: Location of the JSONL file.

    Returns:
        A ``Dataset`` with a single ``"text"`` column, one entry per usable row.

    Raises:
        OSError: If ``path`` cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON. This is
            left fatal on purpose: it suggests a corrupt or truncated file
            rather than a single bad sample.
    """
    marker = "ENVIRONMENT STATUS:\n"

    # Resolve the terminator once; it does not change while the file is read.
    eos = EOS_TOKEN or ""
    if not eos:
        print("[Dataset] Warning: EOS_TOKEN is not set; samples will have no end-of-sequence token")

    rows = []
    skipped = 0
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # blank / trailing empty line, not a row

            row = json.loads(line)
            try:
                user_text = row["user"]        # already contains ### Instruction + ### Input
                assistant = row["assistant"]   # e.g. "RIGHT"
                # Keep only the environment status message from the user field.
                env_status = user_text.split(marker)[1].strip()
            except (KeyError, TypeError, AttributeError, IndexError):
                # Missing key, non-object row, non-string user field, or no marker.
                skipped += 1
                continue

            text = ALPACA_TEMPLATE.format(
                instruction=INSTRUCTION,
                input=env_status,
                response=assistant,
            ) + eos
            rows.append({"text": text})

    print(f"[Dataset] Loaded {len(rows):,} samples from {path}")
    if skipped:
        print(f"[Dataset] Skipped {skipped:,} malformed rows")
    return Dataset.from_list(rows)
def main():
    """Fine-tune the 4-bit Llama-3.2-3B-Instruct base with LoRA and save the adapter.

    Steps: load the base model and tokenizer, attach LoRA adapters, build the
    dataset from ``fixed_dataset.jsonl``, run one epoch of supervised
    fine-tuning, then write the adapter and tokenizer to
    ``lora_garbage_robot/``.

    Side effects: sets the module-level ``EOS_TOKEN``, lets the trainer write
    under ``outputs/``, and writes the adapter under ``lora_garbage_robot/``.

    Raises:
        ValueError: If the dataset file yields no usable training samples.
    """
    # Imported lazily so the module can be imported without the training stack.
    # Import order is kept as in the original (unsloth before trl/transformers).
    # is_bfloat16_supported is a module-level helper exported by unsloth, not
    # an attribute of FastLanguageModel.
    from unsloth import FastLanguageModel, is_bfloat16_supported
    from trl import SFTTrainer
    from transformers import TrainingArguments

    global EOS_TOKEN

    print("=" * 60)
    print(" Fine-tuning Llama-3.2-3B-Instruct — Garbage Robot")
    print("=" * 60)

    # ── 1. Load base model (same as Unsloth Studio session) ──────────────────
    print("\n[1/4] Loading base model …")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/llama-3.2-3b-instruct-bnb-4bit",
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )
    # Must be set before load_fixed_dataset() runs: it reads this global to
    # terminate every training sample.
    EOS_TOKEN = tokenizer.eos_token

    # ── 2. Add LoRA adapters ─────────────────────────────────────────────────
    print("[2/4] Attaching LoRA adapters …")
    model = FastLanguageModel.get_peft_model(
        model,
        r = 16,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 16,
        lora_dropout = 0,
        bias = "none",
        use_gradient_checkpointing = "unsloth",
        random_state = 3407,
        use_rslora = False,
        loftq_config = None,
    )

    # ── 3. Load dataset ──────────────────────────────────────────────────────
    print("[3/4] Loading fixed_dataset.jsonl …")
    dataset = load_fixed_dataset("fixed_dataset.jsonl")
    if len(dataset) == 0:
        # Malformed rows are skipped by the loader, so an all-bad file would
        # otherwise only fail later inside the trainer.
        raise ValueError("fixed_dataset.jsonl contained no usable training samples")

    # ── 4. Train ─────────────────────────────────────────────────────────────
    print("[4/4] Starting fine-tuning …")
    # Probe once; exactly one of fp16 / bf16 is enabled.
    bf16_ok = is_bfloat16_supported()
    trainer = SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = dataset,
        dataset_text_field = "text",
        max_seq_length = max_seq_length,
        dataset_num_proc = 2,
        packing = True,  # efficient for short sequences
        args = TrainingArguments(
            per_device_train_batch_size = 4,
            gradient_accumulation_steps = 4,
            warmup_ratio = 0.03,
            num_train_epochs = 1,
            learning_rate = 2e-4,
            fp16 = not bf16_ok,
            bf16 = bf16_ok,
            logging_steps = 10,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "cosine",
            seed = 3407,
            output_dir = "outputs",
            save_strategy = "epoch",
        ),
    )
    trainer_stats = trainer.train()
    print(f"\nTraining complete. Loss: {trainer_stats.training_loss:.4f}")

    # ── Save LoRA adapter ────────────────────────────────────────────────────
    model.save_pretrained("lora_garbage_robot")
    tokenizer.save_pretrained("lora_garbage_robot")
    print("\nLoRA adapter saved to: lora_garbage_robot/")
    print("To export a merged model, use Unsloth Studio → Export → Merged Model.")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()