# scripts/train_unsloth_kaggle.py (Manthan-T1)
"""Kaggle/Unsloth training entrypoint for Manthan-T1 (TinyLLaVA-style).
This script is intended to be copied into a Kaggle notebook and run on 2×T4.
It supports two stages:
- stage1: projector alignment pretraining (e.g., LLaVA-CC3M-Pretrain-595K)
- stage2: instruction tuning (e.g., LLaVA-Instruct-150K)
Notes:
- We follow MicroLLaVA/TinyLLaVA convention: IMAGE_TOKEN_INDEX = -200 is inserted
into input_ids for <image> placeholders.
- Labels are IGNORE_INDEX for everything except assistant tokens.
- This script trains:
  - the multimodal projector (always)
  - LoRA adapters on the text model (optional, recommended)
- The vision tower is frozen by default.
You still need a *real* base model + vision tower weights. Stub exports will run
but won't learn useful vision-language alignment.
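Example invocation (illustrative; the repo/dataset ids are placeholders, swap in your own):
    python train_unsloth_kaggle.py --stage stage1 \
        --manthan_model <hf-repo-or-local-path-with-Manthan-remote-code> \
        --dataset <llava-style-hf-dataset-id> \
        --use_lora --batch_size 1 --grad_accum 16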
"""
from __future__ import annotations
import argparse
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from torch.utils.data import Dataset
try:
    # Kaggle + Unsloth. Import unsloth before transformers/peft so its patches apply.
    import unsloth  # noqa: F401
    from unsloth import FastLanguageModel
except Exception:  # pragma: no cover
    FastLanguageModel = None
try:
    # PEFT fallback for non-Unsloth environments.
    from peft import LoraConfig, get_peft_model
except Exception:  # pragma: no cover
    LoraConfig = None
    get_peft_model = None
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
try:
from datasets import load_dataset
except Exception as e: # pragma: no cover
raise RuntimeError(
"Missing dependency `datasets`. Install with `pip install datasets` (Kaggle: add to notebook)."
) from e
IMAGE_TOKEN_INDEX = -200
IGNORE_INDEX = -100
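# Note: IGNORE_INDEX == -100 matches torch.nn.CrossEntropyLoss's default ignore_index,
# so masked label positions contribute nothing to the loss.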
def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX) -> List[int]:
"""MicroLLaVA/TinyLLaVA tokenizer: split on '<image>' and insert a negative id."""
def _insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
input_ids: List[int] = []
offset = 0
if (
len(prompt_chunks) > 0
and len(prompt_chunks[0]) > 0
and tokenizer.bos_token_id is not None
and prompt_chunks[0][0] == tokenizer.bos_token_id
):
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
return input_ids
def build_prompt_from_conversations(conversations: List[Dict[str, str]]) -> Tuple[str, str]:
"""Return (full_prompt, assistant_answer_text).
LLaVA datasets are 2-turn: human then gpt.
We map to the string template used in `ManthanForCausalLM.format_chat_prompt`.
"""
# Expect 2 turns
human = conversations[0]["value"]
assistant = conversations[1]["value"]
system = (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
)
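    # The rendered prompt then looks roughly like (single turn, illustrative):
    #   "A chat between ... questions. USER: <image>\nWhat is in the photo? ASSISTANT:<answer>"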
# IMPORTANT: no trailing space after ASSISTANT:
full = system + f"USER: {human.strip()} ASSISTANT:" + assistant
return full, assistant
@dataclass
class TrainExample:
input_ids: torch.LongTensor
labels: torch.LongTensor
image_path: str
class LlavaLikeDataset(Dataset):
def __init__(
self,
ds_name: str,
split: str,
tokenizer,
max_length: int,
limit: Optional[int] = None,
) -> None:
self.tokenizer = tokenizer
self.max_length = max_length
# Streaming keeps Kaggle disk usage low.
self.ds = load_dataset(ds_name, split=split, streaming=True)
self.limit = limit
# Materialize a small index for non-streaming dataloader behavior.
self._cache: List[Dict[str, Any]] = []
for i, ex in enumerate(self.ds):
self._cache.append(ex)
if limit is not None and i + 1 >= limit:
break
def __len__(self) -> int:
return len(self._cache)
def __getitem__(self, idx: int) -> TrainExample:
ex = self._cache[idx]
image_path = ex["image"]
conversations = ex["conversations"]
full_prompt, _assistant = build_prompt_from_conversations(conversations)
ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX)
# Truncate
ids = ids[: self.max_length]
        # Labels: only learn on assistant answer tokens.
        # Simple heuristic: find the last occurrence of the " ASSISTANT:" marker.
        marker = " ASSISTANT:"
        # add_special_tokens=False so a BOS-prepending tokenizer doesn't stop the marker
        # from ever matching inside the sequence.
        marker_ids = self.tokenizer(marker, add_special_tokens=False).input_ids
        # Find the marker in the tokenized ids (best-effort). If it is not found,
        # `start` stays 0 and the whole sequence is supervised.
        start = 0
        for j in range(0, len(ids) - len(marker_ids) + 1):
            if ids[j : j + len(marker_ids)] == marker_ids:
                start = j + len(marker_ids)
labels = [IGNORE_INDEX] * len(ids)
for j in range(start, len(ids)):
if ids[j] == IMAGE_TOKEN_INDEX:
labels[j] = IGNORE_INDEX
else:
labels[j] = ids[j]
return TrainExample(
input_ids=torch.tensor(ids, dtype=torch.long),
labels=torch.tensor(labels, dtype=torch.long),
image_path=image_path,
)
def load_image_tensor(image_path: str, image_size: int) -> torch.FloatTensor:
"""Load image from local path in dataset.
In Kaggle, LLaVA datasets provide image paths relative to the dataset repo.
Hugging Face datasets streaming yields paths that resolve via HF cache.
"""
from PIL import Image
import torchvision.transforms as T
img = Image.open(image_path).convert("RGB")
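    # NOTE: Resize + ToTensor is a simplification. Depending on how the vision tower was
    # trained, its own image processor (e.g. SigLIP mean/std normalization) may be needed.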
tfm = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
return tfm(img)
def collate_fn(batch: List[TrainExample], image_size: int) -> Dict[str, torch.Tensor]:
# Pad to max length
max_len = max(x.input_ids.numel() for x in batch)
input_ids = torch.full((len(batch), max_len), 0, dtype=torch.long)
labels = torch.full((len(batch), max_len), IGNORE_INDEX, dtype=torch.long)
attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
for i, ex in enumerate(batch):
L = ex.input_ids.numel()
input_ids[i, :L] = ex.input_ids
labels[i, :L] = ex.labels
attention_mask[i, :L] = 1
# Images
pixel_values = torch.stack([load_image_tensor(ex.image_path, image_size) for ex in batch], dim=0)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
for p in module.parameters():
p.requires_grad = requires_grad
def save_projector(model, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
if not hasattr(model, "projector"):
return
torch.save(model.projector.state_dict(), os.path.join(output_dir, "projector.pt"))
def maybe_add_lora_to_model(model, args) -> None:
    """Attach LoRA adapters via PEFT (skipped when adapters are already present, e.g. from Unsloth)."""
    if not args.use_lora:
        return
    # If adapters are already attached (to the wrapper or to the inner language model), skip.
    if hasattr(model, "peft_config") or hasattr(getattr(model, "language_model", None), "peft_config"):
        return
if get_peft_model is None or LoraConfig is None:
raise RuntimeError("PEFT not installed, and Unsloth not available. Install `peft` or enable Unsloth.")
target_modules = [
# Qwen-like
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
# GPT-like fallback
"c_attn",
"c_proj",
]
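    # NOTE: PEFT raises an error if none of these names match modules in the text model;
    # trim the list to your backbone's actual layer names in that case.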
cfg = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=target_modules,
)
# Wrap the language model inside Manthan
model.language_model = get_peft_model(model.language_model, cfg)
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--stage", choices=["stage1", "stage2"], required=True)
ap.add_argument("--text_model", type=str, default="Qwen/Qwen3-0.6B-Base")
ap.add_argument("--vision_model", type=str, default="google/siglip-so400m-patch14-384")
ap.add_argument("--dataset", type=str, required=True)
ap.add_argument("--output_dir", type=str, default="./outputs")
ap.add_argument("--max_length", type=int, default=2048)
ap.add_argument("--image_size", type=int, default=384)
ap.add_argument("--limit", type=int, default=2048, help="For debugging: number of samples to materialize")
# Training
ap.add_argument("--epochs", type=int, default=1)
ap.add_argument("--batch_size", type=int, default=1)
ap.add_argument("--grad_accum", type=int, default=16)
ap.add_argument("--lr", type=float, default=1e-4)
ap.add_argument("--warmup_ratio", type=float, default=0.03)
ap.add_argument("--use_lora", action="store_true")
ap.add_argument("--lora_r", type=int, default=16)
ap.add_argument("--lora_alpha", type=int, default=32)
ap.add_argument("--lora_dropout", type=float, default=0.05)
ap.add_argument(
"--manthan_model",
type=str,
required=True,
help="HF repo id or local path that contains Manthan remote-code (the thing you push to HF).",
)
ap.add_argument("--save_every", type=int, default=500)
ap.add_argument("--dry_run", action="store_true", help="Run a single synthetic step (no datasets).")
args = ap.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
print("WARNING: This script is designed for CUDA (Kaggle). Running on CPU will be extremely slow.")
# Tokenizer (use the LLM tokenizer)
tok = AutoTokenizer.from_pretrained(args.text_model, trust_remote_code=True, use_fast=False)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token
# Load Manthan remote-code model
# (This should contain config that points to your desired text_model_id & vision_model_id.)
model = AutoModelForCausalLM.from_pretrained(
args.manthan_model,
trust_remote_code=True,
torch_dtype=torch.float16 if device == "cuda" else None,
)
model.train()
model.to(device)
# Make sure we don't train the vision tower (T4-friendly)
if hasattr(model, "vision_model") and model.vision_model is not None:
set_requires_grad(model.vision_model, False)
if hasattr(model, "vision_tower") and model.vision_tower is not None:
set_requires_grad(model.vision_tower, False)
# Train projector always
if hasattr(model, "projector"):
set_requires_grad(model.projector, True)
# Add LoRA to the language model (recommended)
maybe_add_lora_to_model(model, args)
# Optimizer params = trainable only
trainable_params = [p for p in model.parameters() if p.requires_grad]
if len(trainable_params) == 0:
raise RuntimeError("No trainable parameters. Did you freeze everything?")
optim = torch.optim.AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
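    # NOTE: with the model loaded in fp16, the projector/LoRA weights and optimizer states
    # are fp16 as well. Keeping trainable params in fp32 (plus a GradScaler) would be more
    # robust numerically, at the cost of some extra T4 memory.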
# Data
if args.dry_run:
# Minimal synthetic batch (no images on disk). This just validates loss pathway.
B, T = 1, min(64, args.max_length)
# IMPORTANT: some tokenizers report an imprecise `vocab_size`; `len(tok)` is the safe upper bound.
tok_vocab = int(len(tok))
input_ids = torch.randint(low=0, high=max(tok_vocab - 1, 1), size=(B, T), dtype=torch.long)
labels = input_ids.clone()
attn = torch.ones_like(input_ids)
        # Match the model dtype (fp16 when loaded on CUDA) so this non-autocast forward
        # doesn't hit a dtype mismatch in the vision tower.
        pixel_values = torch.randn(B, 3, args.image_size, args.image_size, dtype=next(model.parameters()).dtype)
# Insert one image placeholder
input_ids[0, 5] = IMAGE_TOKEN_INDEX
labels[0, :10] = IGNORE_INDEX
# If tokenizer vocab > model vocab (common in dry_run), clamp to avoid CE index errors.
lm_vocab = None
try:
if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
lm_vocab = int(getattr(model.language_model.config, "vocab_size", 0) or 0)
except Exception:
lm_vocab = None
if lm_vocab and lm_vocab > 0:
safe_ids = input_ids.clone()
mask = safe_ids >= 0
safe_ids[mask] = safe_ids[mask].clamp(min=0, max=lm_vocab - 1)
input_ids = safe_ids
safe_labels = labels.clone()
mask = safe_labels >= 0
safe_labels[mask] = safe_labels[mask].clamp(min=0, max=lm_vocab - 1)
labels = safe_labels
batch = {
"input_ids": input_ids.to(device),
"labels": labels.to(device),
"attention_mask": attn.to(device),
"pixel_values": pixel_values.to(device),
}
out = model(**batch)
print("dry_run loss:", float(out.loss))
out.loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
save_projector(model, args.output_dir)
if hasattr(model, "language_model") and hasattr(model.language_model, "save_pretrained"):
# Save adapters if present
try:
model.language_model.save_pretrained(args.output_dir)
except Exception:
pass
return 0
ds = LlavaLikeDataset(args.dataset, split="train", tokenizer=tok, max_length=args.max_length, limit=args.limit)
from torch.utils.data import DataLoader
dl = DataLoader(
ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=2,
collate_fn=lambda b: collate_fn(b, args.image_size),
)
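    # The lambda collate_fn works with Kaggle's fork-based DataLoader workers; a spawn-based
    # platform would need functools.partial(collate_fn, image_size=args.image_size) instead.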
    total_steps = max(1, (len(dl) * args.epochs) // max(1, args.grad_accum))
warmup_steps = max(1, int(total_steps * args.warmup_ratio))
sched = get_cosine_schedule_with_warmup(optim, warmup_steps, total_steps)
step = 0
optim.zero_grad(set_to_none=True)
for epoch in range(args.epochs):
for micro_idx, batch in enumerate(dl):
batch = {k: v.to(device) for k, v in batch.items()}
# Mixed precision on Kaggle
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")):
out = model(**batch)
loss = out.loss / max(1, args.grad_accum)
loss.backward()
if (micro_idx + 1) % args.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
optim.step()
sched.step()
optim.zero_grad(set_to_none=True)
step += 1
if step % 10 == 0:
print(f"epoch={epoch} step={step}/{total_steps} loss={float(out.loss):.4f}")
if step % args.save_every == 0:
save_projector(model, args.output_dir)
# Save adapters if any
try:
model.save_pretrained(args.output_dir)
except Exception:
pass
if step >= total_steps:
break
save_projector(model, args.output_dir)
try:
model.save_pretrained(args.output_dir)
except Exception:
pass
print("DONE")
return 0
if __name__ == "__main__":
raise SystemExit(main())