"""Kaggle/Unsloth training entrypoint for Manthan-T1 (TinyLLaVA-style). This script is intended to be copied into a Kaggle notebook and run on 2×T4. It supports two stages: - stage1: projector alignment pretraining (e.g., LLaVA-CC3M-Pretrain-595K) - stage2: instruction tuning (e.g., LLaVA-Instruct-150K) Notes: - We follow MicroLLaVA/TinyLLaVA convention: IMAGE_TOKEN_INDEX = -200 is inserted into input_ids for placeholders. - Labels are IGNORE_INDEX for everything except assistant tokens. - This script trains: - the multimodal projector (always) - LoRA adapters on the text model (optional, recommended) - vision tower is frozen by default You still need a *real* base model + vision tower weights. Stub exports will run but won't learn useful vision-language alignment. """ from __future__ import annotations import argparse import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple import torch from torch import nn from torch.utils.data import Dataset try: # Fallback for non-Unsloth environments from peft import LoraConfig, get_peft_model except Exception: # pragma: no cover LoraConfig = None get_peft_model = None try: # Kaggle + Unsloth import unsloth # noqa: F401 from unsloth import FastLanguageModel except Exception: # pragma: no cover FastLanguageModel = None from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup try: from datasets import load_dataset except Exception as e: # pragma: no cover raise RuntimeError( "Missing dependency `datasets`. Install with `pip install datasets` (Kaggle: add to notebook)." ) from e IMAGE_TOKEN_INDEX = -200 IGNORE_INDEX = -100 def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX) -> List[int]: """MicroLLaVA/TinyLLaVA tokenizer: split on '' and insert a negative id.""" def _insert_separator(X, sep): return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] input_ids: List[int] = [] offset = 0 if ( len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and tokenizer.bos_token_id is not None and prompt_chunks[0][0] == tokenizer.bos_token_id ): offset = 1 input_ids.append(prompt_chunks[0][0]) for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): input_ids.extend(x[offset:]) return input_ids def build_prompt_from_conversations(conversations: List[Dict[str, str]]) -> Tuple[str, str]: """Return (full_prompt, assistant_answer_text). LLaVA datasets are 2-turn: human then gpt. We map to the string template used in `ManthanForCausalLM.format_chat_prompt`. """ # Expect 2 turns human = conversations[0]["value"] assistant = conversations[1]["value"] system = ( "A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions. " ) # IMPORTANT: no trailing space after ASSISTANT: full = system + f"USER: {human.strip()} ASSISTANT:" + assistant return full, assistant @dataclass class TrainExample: input_ids: torch.LongTensor labels: torch.LongTensor image_path: str class LlavaLikeDataset(Dataset): def __init__( self, ds_name: str, split: str, tokenizer, max_length: int, limit: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.max_length = max_length # Streaming keeps Kaggle disk usage low. self.ds = load_dataset(ds_name, split=split, streaming=True) self.limit = limit # Materialize a small index for non-streaming dataloader behavior. 
        self._cache: List[Dict[str, Any]] = []
        for i, ex in enumerate(self.ds):
            self._cache.append(ex)
            if limit is not None and i + 1 >= limit:
                break

    def __len__(self) -> int:
        return len(self._cache)

    def __getitem__(self, idx: int) -> TrainExample:
        ex = self._cache[idx]
        image_path = ex["image"]
        conversations = ex["conversations"]
        full_prompt, _assistant = build_prompt_from_conversations(conversations)
        ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX)
        # Truncate
        ids = ids[: self.max_length]
        # Labels: only learn on assistant answer tokens.
        # Simple heuristic: find the last occurrence of the " ASSISTANT:" marker.
        marker = " ASSISTANT:"
        marker_ids = self.tokenizer(marker).input_ids
        # Strip a leading BOS so the marker can match mid-sequence.
        if self.tokenizer.bos_token_id is not None and marker_ids and marker_ids[0] == self.tokenizer.bos_token_id:
            marker_ids = marker_ids[1:]
        # Find marker in tokenized ids (best-effort).
        start = 0
        for j in range(0, len(ids) - len(marker_ids) + 1):
            if ids[j : j + len(marker_ids)] == marker_ids:
                start = j + len(marker_ids)
        labels = [IGNORE_INDEX] * len(ids)
        for j in range(start, len(ids)):
            if ids[j] == IMAGE_TOKEN_INDEX:
                labels[j] = IGNORE_INDEX
            else:
                labels[j] = ids[j]
        return TrainExample(
            input_ids=torch.tensor(ids, dtype=torch.long),
            labels=torch.tensor(labels, dtype=torch.long),
            image_path=image_path,
        )


def load_image_tensor(image_path: str, image_size: int) -> torch.FloatTensor:
    """Load image from local path in dataset.

    In Kaggle, LLaVA datasets provide image paths relative to the dataset repo.
    Hugging Face datasets streaming yields paths that resolve via the HF cache.
    """
    from PIL import Image
    import torchvision.transforms as T

    img = Image.open(image_path).convert("RGB")
    tfm = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
    return tfm(img)


def collate_fn(batch: List[TrainExample], image_size: int) -> Dict[str, torch.Tensor]:
    # Pad to max length
    max_len = max(x.input_ids.numel() for x in batch)
    input_ids = torch.full((len(batch), max_len), 0, dtype=torch.long)
    labels = torch.full((len(batch), max_len), IGNORE_INDEX, dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
    for i, ex in enumerate(batch):
        L = ex.input_ids.numel()
        input_ids[i, :L] = ex.input_ids
        labels[i, :L] = ex.labels
        attention_mask[i, :L] = 1
    # Images
    pixel_values = torch.stack([load_image_tensor(ex.image_path, image_size) for ex in batch], dim=0)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
    }


def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
    for p in module.parameters():
        p.requires_grad = requires_grad


def save_projector(model, output_dir: str) -> None:
    os.makedirs(output_dir, exist_ok=True)
    if not hasattr(model, "projector"):
        return
    torch.save(model.projector.state_dict(), os.path.join(output_dir, "projector.pt"))


def maybe_add_lora_to_model(model, args) -> None:
    """Attach LoRA adapters (Unsloth preferred; PEFT fallback)."""
    if not args.use_lora:
        return
    # If the model already has adapters (e.g., loaded via Unsloth), skip.
    if hasattr(model, "peft_config"):
        return
    if get_peft_model is None or LoraConfig is None:
        raise RuntimeError("PEFT not installed, and Unsloth not available. Install `peft` or enable Unsloth.")
    target_modules = [
        # Qwen-like
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        # GPT-like fallback
        "c_attn", "c_proj",
    ]
    cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    # Wrap the language model inside Manthan
    model.language_model = get_peft_model(model.language_model, cfg)


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--stage", choices=["stage1", "stage2"], required=True)
    ap.add_argument("--text_model", type=str, default="Qwen/Qwen3-0.6B-Base")
    ap.add_argument("--vision_model", type=str, default="google/siglip-so400m-patch14-384")
    ap.add_argument("--dataset", type=str, required=True)
    ap.add_argument("--output_dir", type=str, default="./outputs")
    ap.add_argument("--max_length", type=int, default=2048)
    ap.add_argument("--image_size", type=int, default=384)
    ap.add_argument("--limit", type=int, default=2048, help="For debugging: number of samples to materialize")
    # Training
    ap.add_argument("--epochs", type=int, default=1)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--warmup_ratio", type=float, default=0.03)
    ap.add_argument("--use_lora", action="store_true")
    ap.add_argument("--lora_r", type=int, default=16)
    ap.add_argument("--lora_alpha", type=int, default=32)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    ap.add_argument(
        "--manthan_model",
        type=str,
        required=True,
        help="HF repo id or local path that contains the Manthan remote-code model (the thing you push to HF).",
    )
    ap.add_argument("--save_every", type=int, default=500)
    ap.add_argument("--dry_run", action="store_true", help="Run a single synthetic step (no datasets).")
    args = ap.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("WARNING: This script is designed for CUDA (Kaggle). Running on CPU will be extremely slow.")

    # Tokenizer (use the LLM tokenizer)
    tok = AutoTokenizer.from_pretrained(args.text_model, trust_remote_code=True, use_fast=False)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # Load the Manthan remote-code model.
    # (Its config should point to your desired text_model_id & vision_model_id.)
    model = AutoModelForCausalLM.from_pretrained(
        args.manthan_model,
        trust_remote_code=True,
        torch_dtype=torch.float16 if device == "cuda" else None,
    )
    model.train()
    model.to(device)

    # Make sure we don't train the vision tower (T4-friendly)
    if hasattr(model, "vision_model") and model.vision_model is not None:
        set_requires_grad(model.vision_model, False)
    if hasattr(model, "vision_tower") and model.vision_tower is not None:
        set_requires_grad(model.vision_tower, False)

    # Train the projector always
    if hasattr(model, "projector"):
        set_requires_grad(model.projector, True)

    # Add LoRA to the language model (recommended)
    maybe_add_lora_to_model(model, args)

    # Optimizer params = trainable only
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if len(trainable_params) == 0:
        raise RuntimeError("No trainable parameters. Did you freeze everything?")
    optim = torch.optim.AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)

    # Data
    if args.dry_run:
        # Minimal synthetic batch (no images on disk). This just validates the loss pathway.
        B, T = 1, min(64, args.max_length)
        # IMPORTANT: some tokenizers report an imprecise `vocab_size`; `len(tok)` is the safe upper bound.
        tok_vocab = int(len(tok))
        input_ids = torch.randint(low=0, high=max(tok_vocab - 1, 1), size=(B, T), dtype=torch.long)
        labels = input_ids.clone()
        attn = torch.ones_like(input_ids)
        pixel_values = torch.randn(B, 3, args.image_size, args.image_size)
        # Insert one image placeholder
        input_ids[0, 5] = IMAGE_TOKEN_INDEX
        labels[0, :10] = IGNORE_INDEX

        # If the tokenizer vocab > model vocab (common in dry_run), clamp to avoid CE index errors.
        lm_vocab = None
        try:
            if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
                lm_vocab = int(getattr(model.language_model.config, "vocab_size", 0) or 0)
        except Exception:
            lm_vocab = None
        if lm_vocab and lm_vocab > 0:
            safe_ids = input_ids.clone()
            mask = safe_ids >= 0
            safe_ids[mask] = safe_ids[mask].clamp(min=0, max=lm_vocab - 1)
            input_ids = safe_ids
            safe_labels = labels.clone()
            mask = safe_labels >= 0
            safe_labels[mask] = safe_labels[mask].clamp(min=0, max=lm_vocab - 1)
            labels = safe_labels

        batch = {
            "input_ids": input_ids.to(device),
            "labels": labels.to(device),
            "attention_mask": attn.to(device),
            "pixel_values": pixel_values.to(device),
        }
        out = model(**batch)
        print("dry_run loss:", float(out.loss))
        out.loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        save_projector(model, args.output_dir)
        if hasattr(model, "language_model") and hasattr(model.language_model, "save_pretrained"):
            # Save adapters if present
            try:
                model.language_model.save_pretrained(args.output_dir)
            except Exception:
                pass
        return 0

    ds = LlavaLikeDataset(args.dataset, split="train", tokenizer=tok, max_length=args.max_length, limit=args.limit)

    from torch.utils.data import DataLoader

    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda b: collate_fn(b, args.image_size),
    )

    # Guard against zero optimizer steps when the materialized dataset is tiny.
    total_steps = max(1, (len(dl) * args.epochs) // max(1, args.grad_accum))
    warmup_steps = max(1, int(total_steps * args.warmup_ratio))
    sched = get_cosine_schedule_with_warmup(optim, warmup_steps, total_steps)

    step = 0
    optim.zero_grad(set_to_none=True)
    for epoch in range(args.epochs):
        for micro_idx, batch in enumerate(dl):
            batch = {k: v.to(device) for k, v in batch.items()}
            # Mixed precision on Kaggle
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")):
                out = model(**batch)
                loss = out.loss / max(1, args.grad_accum)
            loss.backward()
            if (micro_idx + 1) % args.grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optim.step()
                sched.step()
                optim.zero_grad(set_to_none=True)
                step += 1
                if step % 10 == 0:
                    print(f"epoch={epoch} step={step}/{total_steps} loss={float(out.loss):.4f}")
                if step % args.save_every == 0:
                    save_projector(model, args.output_dir)
                    # Save adapters if any
                    try:
                        model.save_pretrained(args.output_dir)
                    except Exception:
                        pass
                if step >= total_steps:
                    break

    save_projector(model, args.output_dir)
    try:
        model.save_pretrained(args.output_dir)
    except Exception:
        pass
    print("DONE")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
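
# ---------------------------------------------------------------------------
# Example invocations (a sketch; the file name and dataset repo ids below are
# placeholders, not verified names -- substitute your own LLaVA-style datasets
# that yield {"image": <path>, "conversations": [...]} records):
#
#   Dry run (validates the loss pathway, no real data or images needed;
#   --dataset is required by argparse but unused in this mode):
#     python train_manthan.py --stage stage1 --dry_run \
#         --manthan_model <your-hf-repo> --dataset placeholder
#
#   Stage 1 projector alignment on a LLaVA-CC3M-Pretrain-595K-style dataset:
#     python train_manthan.py --stage stage1 --use_lora \
#         --manthan_model <your-hf-repo> --dataset <cc3m-pretrain-repo>
#
#   Stage 2 instruction tuning on a LLaVA-Instruct-150K-style dataset:
#     python train_manthan.py --stage stage2 --use_lora \
#         --manthan_model <your-hf-repo> --dataset <instruct-150k-repo>
#
#   In a Kaggle notebook cell you can instead patch sys.argv and call main():
#     import sys
#     sys.argv = ["train_manthan.py", "--stage", "stage1", "--dry_run",
#                 "--manthan_model", "<your-hf-repo>", "--dataset", "placeholder"]
#     main()
# ---------------------------------------------------------------------------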