|
|
"""Kaggle/Unsloth training entrypoint for Manthan-T1 (TinyLLaVA-style). |
|
|
|
|
|
This script is intended to be copied into a Kaggle notebook and run on 2×T4. |
|
|
It supports two stages: |
|
|
- stage1: projector alignment pretraining (e.g., LLaVA-CC3M-Pretrain-595K) |
|
|
- stage2: instruction tuning (e.g., LLaVA-Instruct-150K) |
|
|
|
|
|
Notes: |
|
|
- We follow MicroLLaVA/TinyLLaVA convention: IMAGE_TOKEN_INDEX = -200 is inserted |
|
|
into input_ids for <image> placeholders. |
|
|
- Labels are IGNORE_INDEX for everything except assistant tokens. |
|
|
- This script trains:
  - the multimodal projector (always)
  - LoRA adapters on the text model (optional, recommended)
- The vision tower is frozen by default.
|
|
|
|
|
You still need a *real* base model + vision tower weights. Stub exports will run |
|
|
but won't learn useful vision-language alignment. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
try: |
|
|
|
|
|
from peft import LoraConfig, get_peft_model |
|
|
except Exception: |
|
|
LoraConfig = None |
|
|
get_peft_model = None |
|
|
|
|
|
try: |
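    # Unsloth is optional; importing it before `transformers` lets its runtime patches
    # take effect when available. FastLanguageModel is kept for optional fast-path loading.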
|
|
|
|
|
import unsloth |
|
|
from unsloth import FastLanguageModel |
|
|
except Exception: |
|
|
FastLanguageModel = None |
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup |
|
|
|
|
|
try: |
|
|
from datasets import load_dataset |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
"Missing dependency `datasets`. Install with `pip install datasets` (Kaggle: add to notebook)." |
|
|
) from e |
|
|
|
|
|
|
|
|
IMAGE_TOKEN_INDEX = -200 |
|
|
IGNORE_INDEX = -100 |
|
|
|
|
|
|
|
|
def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX) -> List[int]: |
|
|
"""MicroLLaVA/TinyLLaVA tokenizer: split on '<image>' and insert a negative id.""" |
|
|
|
|
|
def _insert_separator(X, sep): |
|
|
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] |
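    # Tokenize the text around each '<image>' placeholder, then splice the negative image id between chunks.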
|
|
|
|
|
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")] |
|
|
|
|
|
input_ids: List[int] = [] |
|
|
offset = 0 |
|
|
if ( |
|
|
len(prompt_chunks) > 0 |
|
|
and len(prompt_chunks[0]) > 0 |
|
|
and tokenizer.bos_token_id is not None |
|
|
and prompt_chunks[0][0] == tokenizer.bos_token_id |
|
|
): |
|
|
offset = 1 |
|
|
input_ids.append(prompt_chunks[0][0]) |
|
|
|
|
|
for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): |
|
|
input_ids.extend(x[offset:]) |
|
|
|
|
|
return input_ids |
|
|
|
|
|
|
|
|
def build_prompt_from_conversations(conversations: List[Dict[str, str]]) -> Tuple[str, str]: |
|
|
"""Return (full_prompt, assistant_answer_text). |
|
|
|
|
|
    Only the first human/gpt pair of the conversation is used. The prompt follows the
    string template used in `ManthanForCausalLM.format_chat_prompt`.
|
|
""" |
|
|
|
|
|
|
|
|
human = conversations[0]["value"] |
|
|
assistant = conversations[1]["value"] |
|
|
|
|
|
system = ( |
|
|
"A chat between a curious user and an artificial intelligence assistant. " |
|
|
"The assistant gives helpful, detailed, and polite answers to the user's questions. " |
|
|
) |
|
|
|
|
|
    full = system + f"USER: {human.strip()} ASSISTANT: {assistant.strip()}"
|
|
return full, assistant |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TrainExample: |
|
|
input_ids: torch.LongTensor |
|
|
labels: torch.LongTensor |
|
|
image_path: str |
|
|
|
|
|
|
|
|
class LlavaLikeDataset(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
ds_name: str, |
|
|
split: str, |
|
|
tokenizer, |
|
|
max_length: int, |
|
|
limit: Optional[int] = None, |
|
|
) -> None: |
|
|
self.tokenizer = tokenizer |
|
|
self.max_length = max_length |
|
|
|
|
|
|
|
|
self.ds = load_dataset(ds_name, split=split, streaming=True) |
|
|
self.limit = limit |
|
|
|
|
|
|
|
|
self._cache: List[Dict[str, Any]] = [] |
|
|
for i, ex in enumerate(self.ds): |
|
|
self._cache.append(ex) |
|
|
if limit is not None and i + 1 >= limit: |
|
|
break |
|
|
|
|
|
def __len__(self) -> int: |
|
|
return len(self._cache) |
|
|
|
|
|
def __getitem__(self, idx: int) -> TrainExample: |
|
|
ex = self._cache[idx] |
|
|
image_path = ex["image"] |
|
|
conversations = ex["conversations"] |
|
|
|
|
|
full_prompt, _assistant = build_prompt_from_conversations(conversations) |
|
|
ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX) |
|
|
|
|
|
|
|
|
        # Append EOS so the model learns to stop after its answer, then truncate.
        if self.tokenizer.eos_token_id is not None:
            ids.append(self.tokenizer.eos_token_id)
        ids = ids[: self.max_length]
|
|
|
|
|
|
|
|
|
|
|
marker = " ASSISTANT:" |
|
|
marker_ids = self.tokenizer(marker).input_ids |
|
|
|
|
|
|
|
|
        # Supervise only tokens after the last " ASSISTANT:" marker; if the marker is not
        # found, fall back to supervising the whole sequence.
        start = 0
|
|
for j in range(0, len(ids) - len(marker_ids) + 1): |
|
|
if ids[j : j + len(marker_ids)] == marker_ids: |
|
|
start = j + len(marker_ids) |
|
|
|
|
|
labels = [IGNORE_INDEX] * len(ids) |
|
|
for j in range(start, len(ids)): |
|
|
if ids[j] == IMAGE_TOKEN_INDEX: |
|
|
labels[j] = IGNORE_INDEX |
|
|
else: |
|
|
labels[j] = ids[j] |
|
|
|
|
|
return TrainExample( |
|
|
input_ids=torch.tensor(ids, dtype=torch.long), |
|
|
labels=torch.tensor(labels, dtype=torch.long), |
|
|
image_path=image_path, |
|
|
) |
|
|
|
|
|
|
|
|
def load_image_tensor(image_path: str, image_size: int) -> torch.FloatTensor: |
|
|
"""Load image from local path in dataset. |
|
|
|
|
|
In Kaggle, LLaVA datasets provide image paths relative to the dataset repo. |
|
|
Hugging Face datasets streaming yields paths that resolve via HF cache. |
|
|
""" |
|
|
|
|
|
from PIL import Image |
|
|
import torchvision.transforms as T |
|
|
|
|
|
img = Image.open(image_path).convert("RGB") |
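    # Resize + ToTensor only; if the vision tower expects normalized inputs (e.g. SigLIP/CLIP
    # mean-std normalization), add T.Normalize here or normalize inside the model.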
|
|
tfm = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()]) |
|
|
return tfm(img) |
|
|
|
|
|
|
|
|
def collate_fn(batch: List[TrainExample], image_size: int) -> Dict[str, torch.Tensor]: |
|
|
|
|
|
max_len = max(x.input_ids.numel() for x in batch) |
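    # Right-pad to the longest sequence; padded positions are masked out via attention_mask
    # and labeled IGNORE_INDEX, so the pad id value (0) never contributes to the loss.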
|
|
input_ids = torch.full((len(batch), max_len), 0, dtype=torch.long) |
|
|
labels = torch.full((len(batch), max_len), IGNORE_INDEX, dtype=torch.long) |
|
|
attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long) |
|
|
|
|
|
for i, ex in enumerate(batch): |
|
|
L = ex.input_ids.numel() |
|
|
input_ids[i, :L] = ex.input_ids |
|
|
labels[i, :L] = ex.labels |
|
|
attention_mask[i, :L] = 1 |
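    # One image per example; all images are resized to the same square, so they stack into (B, 3, H, W).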
|
|
|
|
|
|
|
|
pixel_values = torch.stack([load_image_tensor(ex.image_path, image_size) for ex in batch], dim=0) |
|
|
|
|
|
return { |
|
|
"input_ids": input_ids, |
|
|
"labels": labels, |
|
|
"attention_mask": attention_mask, |
|
|
"pixel_values": pixel_values, |
|
|
} |
|
|
|
|
|
|
|
|
def set_requires_grad(module: nn.Module, requires_grad: bool) -> None: |
|
|
for p in module.parameters(): |
|
|
p.requires_grad = requires_grad |
|
|
|
|
|
|
|
|
def save_projector(model, output_dir: str) -> None: |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
if not hasattr(model, "projector"): |
|
|
return |
|
|
torch.save(model.projector.state_dict(), os.path.join(output_dir, "projector.pt")) |
|
|
|
|
|
|
|
|
def maybe_add_lora_to_model(model, args) -> None: |
|
|
"""Attach LoRA adapters (Unsloth preferred; PEFT fallback).""" |
|
|
|
|
|
if not args.use_lora: |
|
|
return |
|
|
|
|
|
|
|
|
    # Skip if adapters were already attached (e.g. earlier in the notebook, or via Unsloth).
    if hasattr(model, "peft_config"):
        return
|
|
|
|
|
if get_peft_model is None or LoraConfig is None: |
|
|
raise RuntimeError("PEFT not installed, and Unsloth not available. Install `peft` or enable Unsloth.") |
|
|
|
|
|
target_modules = [ |
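        # LLaMA/Qwen-style attention and MLP projection names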
|
|
|
|
|
"q_proj", |
|
|
"k_proj", |
|
|
"v_proj", |
|
|
"o_proj", |
|
|
"gate_proj", |
|
|
"up_proj", |
|
|
"down_proj", |
|
|
|
|
|
"c_attn", |
|
|
"c_proj", |
|
|
] |
|
|
cfg = LoraConfig( |
|
|
r=args.lora_r, |
|
|
lora_alpha=args.lora_alpha, |
|
|
lora_dropout=args.lora_dropout, |
|
|
bias="none", |
|
|
task_type="CAUSAL_LM", |
|
|
target_modules=target_modules, |
|
|
) |
|
|
|
|
|
|
|
|
model.language_model = get_peft_model(model.language_model, cfg) |
|
|
|
|
|
|
|
|
|
|
|
def main() -> int: |
|
|
ap = argparse.ArgumentParser() |
|
|
ap.add_argument("--stage", choices=["stage1", "stage2"], required=True) |
|
|
ap.add_argument("--text_model", type=str, default="Qwen/Qwen3-0.6B-Base") |
|
|
ap.add_argument("--vision_model", type=str, default="google/siglip-so400m-patch14-384") |
|
|
ap.add_argument("--dataset", type=str, required=True) |
|
|
ap.add_argument("--output_dir", type=str, default="./outputs") |
|
|
ap.add_argument("--max_length", type=int, default=2048) |
|
|
ap.add_argument("--image_size", type=int, default=384) |
|
|
ap.add_argument("--limit", type=int, default=2048, help="For debugging: number of samples to materialize") |
|
|
|
|
|
|
|
|
ap.add_argument("--epochs", type=int, default=1) |
|
|
ap.add_argument("--batch_size", type=int, default=1) |
|
|
ap.add_argument("--grad_accum", type=int, default=16) |
|
|
ap.add_argument("--lr", type=float, default=1e-4) |
|
|
ap.add_argument("--warmup_ratio", type=float, default=0.03) |
|
|
ap.add_argument("--use_lora", action="store_true") |
|
|
ap.add_argument("--lora_r", type=int, default=16) |
|
|
ap.add_argument("--lora_alpha", type=int, default=32) |
|
|
ap.add_argument("--lora_dropout", type=float, default=0.05) |
|
|
|
|
|
ap.add_argument( |
|
|
"--manthan_model", |
|
|
type=str, |
|
|
required=True, |
|
|
help="HF repo id or local path that contains Manthan remote-code (the thing you push to HF).", |
|
|
) |
|
|
ap.add_argument("--save_every", type=int, default=500) |
|
|
ap.add_argument("--dry_run", action="store_true", help="Run a single synthetic step (no datasets).") |
|
|
|
|
|
args = ap.parse_args() |
|
|
|
|
|
os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
if device != "cuda": |
|
|
print("WARNING: This script is designed for CUDA (Kaggle). Running on CPU will be extremely slow.") |
|
|
|
|
|
|
|
|
tok = AutoTokenizer.from_pretrained(args.text_model, trust_remote_code=True, use_fast=False) |
|
|
if tok.pad_token_id is None: |
|
|
tok.pad_token = tok.eos_token |
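    # Load the Manthan wrapper with trust_remote_code; it is expected to expose
    # `language_model`, a vision tower, and `projector` submodules (see the checks below).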
|
|
|
|
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
args.manthan_model, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16 if device == "cuda" else None, |
|
|
) |
|
|
|
|
|
model.train() |
|
|
model.to(device) |
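    # Freeze the vision tower; only the projector (and optional LoRA adapters) receive gradients.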
|
|
|
|
|
|
|
|
if hasattr(model, "vision_model") and model.vision_model is not None: |
|
|
set_requires_grad(model.vision_model, False) |
|
|
if hasattr(model, "vision_tower") and model.vision_tower is not None: |
|
|
set_requires_grad(model.vision_tower, False) |
|
|
|
|
|
|
|
|
if hasattr(model, "projector"): |
|
|
set_requires_grad(model.projector, True) |
|
|
|
|
|
|
|
|
maybe_add_lora_to_model(model, args) |
|
|
|
|
|
|
|
|
trainable_params = [p for p in model.parameters() if p.requires_grad] |
|
|
if len(trainable_params) == 0: |
|
|
raise RuntimeError("No trainable parameters. Did you freeze everything?") |
|
|
|
|
|
optim = torch.optim.AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01) |
|
|
|
|
|
|
|
|
if args.dry_run: |
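        # Smoke-test: one synthetic batch (random token ids, one image slot, random pixels)
        # exercises the full forward/backward/step path without downloading any data.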
|
|
|
|
|
B, T = 1, min(64, args.max_length) |
|
|
|
|
|
|
|
|
tok_vocab = int(len(tok)) |
|
|
input_ids = torch.randint(low=0, high=max(tok_vocab - 1, 1), size=(B, T), dtype=torch.long) |
|
|
labels = input_ids.clone() |
|
|
attn = torch.ones_like(input_ids) |
|
|
pixel_values = torch.randn(B, 3, args.image_size, args.image_size) |
|
|
|
|
|
|
|
|
input_ids[0, 5] = IMAGE_TOKEN_INDEX |
|
|
labels[0, :10] = IGNORE_INDEX |
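        # The tokenizer and the LM head can have different vocab sizes; clamp the random ids
        # into the LM's range (negative image-token ids are left untouched).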
|
|
|
|
|
|
|
|
lm_vocab = None |
|
|
try: |
|
|
if hasattr(model, "language_model") and hasattr(model.language_model, "config"): |
|
|
lm_vocab = int(getattr(model.language_model.config, "vocab_size", 0) or 0) |
|
|
except Exception: |
|
|
lm_vocab = None |
|
|
|
|
|
if lm_vocab and lm_vocab > 0: |
|
|
safe_ids = input_ids.clone() |
|
|
mask = safe_ids >= 0 |
|
|
safe_ids[mask] = safe_ids[mask].clamp(min=0, max=lm_vocab - 1) |
|
|
input_ids = safe_ids |
|
|
|
|
|
safe_labels = labels.clone() |
|
|
mask = safe_labels >= 0 |
|
|
safe_labels[mask] = safe_labels[mask].clamp(min=0, max=lm_vocab - 1) |
|
|
labels = safe_labels |
|
|
|
|
|
batch = { |
|
|
"input_ids": input_ids.to(device), |
|
|
"labels": labels.to(device), |
|
|
"attention_mask": attn.to(device), |
|
|
"pixel_values": pixel_values.to(device), |
|
|
} |
|
|
out = model(**batch) |
|
|
print("dry_run loss:", float(out.loss)) |
|
|
out.loss.backward() |
|
|
optim.step() |
|
|
optim.zero_grad(set_to_none=True) |
|
|
save_projector(model, args.output_dir) |
|
|
if hasattr(model, "language_model") and hasattr(model.language_model, "save_pretrained"): |
|
|
|
|
|
try: |
|
|
model.language_model.save_pretrained(args.output_dir) |
|
|
except Exception: |
|
|
pass |
|
|
return 0 |
|
|
|
|
|
ds = LlavaLikeDataset(args.dataset, split="train", tokenizer=tok, max_length=args.max_length, limit=args.limit) |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
dl = DataLoader( |
|
|
ds, |
|
|
batch_size=args.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=2, |
|
|
collate_fn=lambda b: collate_fn(b, args.image_size), |
|
|
) |
|
|
|
|
|
    # Guard against zero optimizer steps when the (limited) dataset is smaller than grad_accum.
    total_steps = max(1, (len(dl) * args.epochs) // max(1, args.grad_accum))
|
|
warmup_steps = max(1, int(total_steps * args.warmup_ratio)) |
|
|
sched = get_cosine_schedule_with_warmup(optim, warmup_steps, total_steps) |
|
|
|
|
|
step = 0 |
|
|
optim.zero_grad(set_to_none=True) |
|
|
for epoch in range(args.epochs): |
|
|
for micro_idx, batch in enumerate(dl): |
|
|
batch = {k: v.to(device) for k, v in batch.items()} |
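            # fp16 autocast without a GradScaler can underflow gradients; if the loss turns NaN,
            # consider adding torch.cuda.amp.GradScaler or keeping the projector/LoRA params in fp32.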
|
|
|
|
|
|
|
|
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")): |
|
|
out = model(**batch) |
|
|
loss = out.loss / max(1, args.grad_accum) |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
if (micro_idx + 1) % args.grad_accum == 0: |
|
|
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) |
|
|
optim.step() |
|
|
sched.step() |
|
|
optim.zero_grad(set_to_none=True) |
|
|
step += 1 |
|
|
|
|
|
if step % 10 == 0: |
|
|
print(f"epoch={epoch} step={step}/{total_steps} loss={float(out.loss):.4f}") |
|
|
|
|
|
if step % args.save_every == 0: |
|
|
save_projector(model, args.output_dir) |
|
|
|
|
|
try: |
|
|
model.save_pretrained(args.output_dir) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if step >= total_steps: |
|
|
break |
|
|
|
|
|
save_projector(model, args.output_dir) |
|
|
try: |
|
|
model.save_pretrained(args.output_dir) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
print("DONE") |
|
|
return 0 |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
raise SystemExit(main()) |
|
|
|