Atah Alam committed on
Commit 7f7a72e · 0 Parent(s)

Manthan-T1 clean code-only
.gitattributes ADDED
@@ -0,0 +1 @@
1
+ # Keep this repo code-only. Don't use Git LFS here.
.gitignore ADDED
@@ -0,0 +1,24 @@
1
+ .venv/
2
+ venv/
3
+ __pycache__/
4
+ *.pyc
5
+ .DS_Store
6
+
7
+ # python tooling
8
+ .pytest_cache/
9
+ .mypy_cache/
10
+ .ruff_cache/
11
+
12
+ # local caches
13
+ .cache/
14
+ hf/
15
+
16
+ # training artifacts
17
+ wandb/
18
+ runs/
19
+ checkpoints/
20
+ outputs/
21
+ tmp/
22
+
23
+ # local exports (keep them out of git; publish separately if needed)
24
+ hf_export_ready/
.hfignore ADDED
@@ -0,0 +1,25 @@
1
+ .venv/
2
+ venv/
3
+ __pycache__/
4
+ *.pyc
5
+ .DS_Store
6
+
7
+ # local/hf caches
8
+ .cache/
9
+ hf/
10
+
11
+ # training artifacts
12
+ wandb/
13
+ runs/
14
+ checkpoints/
15
+ outputs/
16
+ tmp/
17
+ /tmp/
18
+
19
+ # large local exports (only push intentionally)
20
+ hf_export_ready/
21
+
22
+ .venv/
23
+ __pycache__/
24
+ *.pyc
25
+ .DS_Store
MODEL_CARD.md ADDED
@@ -0,0 +1,68 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ tags:
7
+ - pytorch
8
+ - safetensors
9
+ - vision-language
10
+ - visual-question-answering
11
+ pipeline_tag: visual-question-answering
12
+ base_model:
13
+ - Qwen/Qwen3-0.6B
14
+ - google/siglip-so400m-patch14-384
15
+ model-index:
16
+ - name: Manthan-T1
17
+ results:
18
+ - task:
19
+ type: visual-question-answering
20
+ name: VQAv2
21
+ dataset:
22
+ name: VQAv2
23
+ type: vqav2
24
+ metrics:
25
+ - name: Overall Accuracy
26
+ type: accuracy
27
+ value: 0.0
28
+ - name: Yes/No Accuracy
29
+ type: accuracy
30
+ value: 0.0
31
+ - name: Number Accuracy
32
+ type: accuracy
33
+ value: 0.0
34
+ - name: Other Accuracy
35
+ type: accuracy
36
+ value: 0.0
37
+ source:
38
+ name: Pending
39
+ url: https://visualqa.org/download.html
40
+ ---
41
+
42
+ # Manthan-T1
43
+
44
+ A custom **Transformers** architecture for a compact vision-language model.
45
+
46
+ ## Status
47
+
48
+ This repo currently contains:
49
+
50
+ - `ManthanConfig` (`manthan_t1/configuration_manthan.py`)
51
+ - `ManthanForCausalLM` (`manthan_t1/modeling_manthan.py`)
52
+ - vision encoder (minimal ViT-like)
53
+ - projector to text hidden size
54
+ - decoder LM (placeholder GPT-2 by default for smoke tests)
55
+
56
+ Planned next steps:
57
+
58
+ - Swap the text backbone to **Qwen3-0.6B** via `text_config` + weight loading
59
+ - Swap the vision tower to **SigLIP2-so400m (patch14-384)** and align image token handling
60
+ - Add proper processor + chat template to enforce **reply in user’s input language** (Tamil/Hindi/etc.)
61
+
62
+ ## Loading
63
+
64
+ This is intended to be loaded with:
65
+
66
+ - `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`
67
+
68
+ See `scripts/infer_hf.py`.
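A minimal loading sketch (hedged: the repo id below is a placeholder for wherever this model is published or exported; `trust_remote_code=True` is what pulls in `ManthanConfig`/`ManthanForCausalLM`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/manthan-t1"  # placeholder: HF repo id or local export folder

tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # required so the custom Manthan classes are resolved
)
model.eval()

# Text-only generation goes through the wrapped language model.
inputs = tok("Describe what a vision-language model does.", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))
```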
README.md ADDED
@@ -0,0 +1,29 @@
1
+ # Manthan-T1
2
+
3
+ A from-scratch scaffold for a custom **Transformers** vision-language architecture named **Manthan-T1**.
4
+
5
+ ## What you get (today)
6
+ - A clean project layout under `manthan_t1/`
7
+ - A full HF custom architecture:
8
+ - `ManthanConfig` in `manthan_t1/configuration_manthan.py`
9
+ - `ManthanForCausalLM` in `manthan_t1/modeling_manthan.py`
10
+ - A no-download HF forward smoke test:
11
+ - `python -m manthan_t1.hf_smoke`
12
+ - An MLX smoke test (kept for Apple Silicon readiness):
13
+ - `python -m manthan_t1.smoke_test`
14
+
15
+ ## What we’ll add next
16
+ - Qwen3-0.6B backbone wiring + weight loading (keeping the model type = `manthan_t1`)
17
+ - SigLIP2 vision tower wiring + projector alignment
18
+ - LoRA fine-tuning recipes for a 16 GB M4 Mac (MLX and/or PyTorch)
19
+ - Multilingual “reply in user language” policy (Indian languages)
20
+
21
+ ## Quick smoke test
22
+ After installing dependencies, you should be able to run:
23
+
24
+ ```bash
25
+ python -m manthan_t1.smoke_test
26
+ python -m manthan_t1.hf_smoke
27
+ ```
28
+
29
+ This does **not** download any external models yet.
docs/KAGGLE_TRAINING.md ADDED
@@ -0,0 +1,114 @@
1
+ # Kaggle training (2×T4) – Manthan‑T1
2
+
3
+ This repo includes a Kaggle-oriented training entrypoint:
4
+
5
+ - `scripts/train_unsloth_kaggle.py`
6
+
7
+ It uses the same LLaVA-style dataset format as TinyLLaVA/MicroLLaVA:
8
+ - dataset sample keys: `image`, `conversations`, `id`
9
+ - `conversations`: `[{'from':'human','value':'...<image>...'}, {'from':'gpt','value':'...'}]`
10
+ - uses `IMAGE_TOKEN_INDEX = -200`
11
+ - uses `IGNORE_INDEX = -100` for masked labels
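For illustration, a single sample in this format looks roughly like the following (a hedged sketch; the id, image path, and text are invented):

```python
IMAGE_TOKEN_INDEX = -200  # placeholder id spliced into input_ids (never a real vocab id)
IGNORE_INDEX = -100       # label value for positions excluded from the loss

sample = {
    "id": "000000123",
    "image": "coco/train2017/000000123.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is the dog doing?"},
        {"from": "gpt", "value": "It is chasing a ball on the grass."},
    ],
}
# During preprocessing, "<image>" becomes IMAGE_TOKEN_INDEX in input_ids, and every
# prompt/image position gets IGNORE_INDEX in labels so only the gpt reply is supervised.
```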
12
+
13
+ ## What this script trains
14
+
15
+ Default (recommended for 2×T4):
16
+ - vision tower: **frozen**
17
+ - multimodal projector: **trainable** (always)
18
+ - LLM: **LoRA adapters** (optional, enable `--use_lora`)
19
+
20
+ This matches the standard LLaVA/TinyLLaVA recipe: align projector first, then instruction tune.
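A hedged sketch of that default trainability split (not the actual training script; module names follow `manthan_t1/modeling_manthan.py`, and the LoRA target modules assume a Qwen-style decoder):

```python
from peft import LoraConfig, get_peft_model

def configure_trainables(model, use_lora: bool = True):
    # Vision tower: frozen (both the toy tower and a loaded vision backbone, if any).
    for p in model.vision_tower.parameters():
        p.requires_grad = False
    if model.vision_model is not None:
        for p in model.vision_model.parameters():
            p.requires_grad = False

    # Multimodal projector: always trainable.
    for p in model.projector.parameters():
        p.requires_grad = True

    # LLM: LoRA adapters when requested, otherwise frozen.
    if use_lora:
        lora_cfg = LoraConfig(
            r=16, lora_alpha=32, lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            task_type="CAUSAL_LM",
        )
        model.language_model = get_peft_model(model.language_model, lora_cfg)
    else:
        for p in model.language_model.parameters():
            p.requires_grad = False
    return model
```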
21
+
22
+ ## Kaggle setup checklist
23
+
24
+ 1) Enable GPU (2×T4) in Kaggle.
25
+ 2) Ensure `pip` deps exist in the notebook:
26
+
27
+ ```bash
28
+ pip install -U transformers accelerate datasets peft
29
+ # Optional (recommended): Unsloth if available in your notebook image
30
+ pip install -U unsloth
31
+ ```
32
+
33
+ 3) Clone this repo into the notebook (from your Hugging Face repo or any Git remote).
34
+ 4) Set HF cache to persist in `/kaggle/working` so it survives “Save Version”:
35
+
36
+ ```bash
37
+ export HF_HOME=/kaggle/working/hf
38
+ export TRANSFORMERS_CACHE=/kaggle/working/hf/transformers
39
+ export HF_DATASETS_CACHE=/kaggle/working/hf/datasets
40
+ ```
41
+
42
+ ## Stage 1 (projector alignment)
43
+
44
+ Use a smaller pretrain set first:
45
+ - `liuhaotian/LLaVA-CC3M-Pretrain-595K`
46
+
47
+ Example run:
48
+
49
+ ```bash
50
+ python scripts/train_unsloth_kaggle.py \
51
+ --stage stage1 \
52
+ --manthan_model <YOUR_HF_REPO_OR_LOCAL_PATH> \
53
+ --text_model Qwen/Qwen3-0.6B-Base \
54
+ --dataset liuhaotian/LLaVA-CC3M-Pretrain-595K \
55
+ --output_dir /kaggle/working/manthan_stage1 \
56
+ --use_lora \
57
+ --max_length 2048 \
58
+ --image_size 384 \
59
+ --batch_size 1 \
60
+ --grad_accum 32 \
61
+ --lr 1e-4 \
62
+ --epochs 1 \
63
+ --limit 20000
64
+ ```
65
+
66
+ Notes:
67
+ - Increase `--limit` as you gain confidence.
68
+ - If you run out of VRAM, reduce `--max_length` or increase `--grad_accum`.
69
+
70
+ ## Stage 2 (instruction tuning)
71
+
72
+ Dataset:
73
+ - `liuhaotian/LLaVA-Instruct-150K`
74
+
75
+ ```bash
76
+ python scripts/train_unsloth_kaggle.py \
77
+ --stage stage2 \
78
+ --manthan_model <YOUR_HF_REPO_OR_LOCAL_PATH> \
79
+ --text_model Qwen/Qwen3-0.6B-Base \
80
+ --dataset liuhaotian/LLaVA-Instruct-150K \
81
+ --output_dir /kaggle/working/manthan_stage2 \
82
+ --use_lora \
83
+ --max_length 2048 \
84
+ --image_size 384 \
85
+ --batch_size 1 \
86
+ --grad_accum 32 \
87
+ --lr 1e-4 \
88
+ --epochs 1 \
89
+ --limit 150000
90
+ ```
91
+
92
+ ## Outputs
93
+
94
+ The script saves into `--output_dir`:
95
+ - `projector.pt` (multimodal projector weights)
96
+ - `save_pretrained()` output for the model (includes remote-code config; adapters if supported)
97
+
98
+ In practice, you’ll likely upload these artifacts back to HF.
99
+
100
+ ## Dry run (local)
101
+
102
+ To validate the training loop without datasets:
103
+
104
+ ```bash
105
+ python scripts/train_unsloth_kaggle.py \
106
+ --stage stage1 \
107
+ --manthan_model hf_export_ready \
108
+ --text_model gpt2 \
109
+ --dataset dummy \
110
+ --output_dir ./tmp_out \
111
+ --dry_run
112
+ ```
113
+
114
+ (For real Kaggle training, don’t use stub weights.)
hf_export_stub/added_tokens.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "<image>": 0,
3
+ "<im_start>": 0,
4
+ "<im_end>": 0
5
+ }
hf_export_stub/chat_template.jinja ADDED
@@ -0,0 +1,26 @@
1
+ {% set system = system_message | default('You are Manthan-T1, a helpful multimodal assistant.') %}
2
+ {% set user_lang_rule = 'Reply in the same language as the user. If the user writes in Tamil, reply in Tamil; if Hindi then Hindi; if English then English.' %}
3
+
4
+ {% if messages[0]['role'] != 'system' %}
5
+ <|system|>
6
+ {{ system }}\n{{ user_lang_rule }}
7
+ <|end|>
8
+ {% endif %}
9
+
10
+ {% for m in messages %}
11
+ {% if m['role'] == 'system' %}
12
+ <|system|>
13
+ {{ m['content'] }}\n{{ user_lang_rule }}
14
+ <|end|>
15
+ {% elif m['role'] == 'user' %}
16
+ <|user|>
17
+ {{ m['content'] }}
18
+ <|end|>
19
+ {% elif m['role'] == 'assistant' %}
20
+ <|assistant|>
21
+ {{ m['content'] }}
22
+ <|end|>
23
+ {% endif %}
24
+ {% endfor %}
25
+
26
+ <|assistant|>
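A hedged sketch of rendering this template, assuming it has been written into the tokenizer's `chat_template` field (`scripts/export_hf.py` does that during export; the path below is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./hf_export_ready", use_fast=False)  # placeholder path

messages = [{"role": "user", "content": "<image>\nWhat is in this picture?"}]
prompt = tok.apply_chat_template(messages, tokenize=False)
print(prompt)
# The rendered string starts with the <|system|> block (default system message plus the
# reply-in-user-language rule), then the <|user|> turn, and ends with "<|assistant|>".
```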
hf_export_stub/config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "model_type": "manthan_t1",
3
+ "architectures": ["ManthanForCausalLM"],
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_manthan.ManthanConfig",
6
+ "AutoModelForCausalLM": "modeling_manthan.ManthanForCausalLM"
7
+ },
8
+ "text_model_id": "Qwen/Qwen3-0.6B",
9
+ "vision_model_id": "google/siglip-so400m-patch14-384",
10
+ "vision_image_size": 384,
11
+ "vision_patch_size": 14,
12
+ "vision_feature_select": "patch",
13
+ "num_image_tokens": 256,
14
+ "image_token_id": 0,
15
+ "torch_dtype": "float16"
16
+ }
hf_export_stub/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<image>",
4
+ "<im_start>",
5
+ "<im_end>"
6
+ ]
7
+ }
hf_export_stub/tokenizer_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "model_max_length": 4096,
3
+ "padding_side": "right",
4
+ "truncation_side": "right",
5
+ "use_fast": false
6
+ }
manthan_t1/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ __all__ = ["__version__"]
2
+ __version__ = "0.0.1"
manthan_t1/configuration_manthan.py ADDED
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
+ class ManthanConfig(PretrainedConfig):
10
+ """Configuration for Manthan-T1.
11
+
12
+ Matches key MicroLLaVA/TinyLLaVA conventions:
13
+ - `image_token_index` is a negative placeholder id (default -200)
14
+ - keep `image_token_id` as an alias (defaults to image_token_index)
15
+
16
+ `text_config` is kept as a dict for JSON serialization.
17
+ """
18
+
19
+ model_type = "manthan_t1"
20
+
21
+ def __init__(
22
+ self,
23
+ text_config: Optional[dict] = None,
24
+ text_model_id: Optional[str] = None,
25
+ vision_model_id: Optional[str] = None,
26
+ vision_hidden_size: int = 1024,
27
+ vision_num_hidden_layers: int = 24,
28
+ vision_num_attention_heads: int = 16,
29
+ vision_image_size: int = 384,
30
+ vision_patch_size: int = 14,
31
+ projector_hidden_size: Optional[int] = None,
32
+ image_token_index: int = -200,
33
+ image_token_id: Optional[int] = None,
34
+ num_image_tokens: int = 256,
35
+ vision_feature_select: str = "patch",
36
+ **kwargs,
37
+ ):
38
+ super().__init__(**kwargs)
39
+
40
+ self.text_config_dict = text_config or {}
41
+
42
+ # Optional resolved config for HF generation helpers
43
+ self.text_config_obj: Optional[PretrainedConfig] = None
44
+ if self.text_config_dict.get("model_type"):
45
+ from transformers import AutoConfig
46
+
47
+ try:
48
+ self.text_config_obj = AutoConfig.for_model(**self.text_config_dict)
49
+ except Exception:
50
+ self.text_config_obj = None
51
+
52
+ self.text_model_id = text_model_id
53
+ self.vision_model_id = vision_model_id
54
+
55
+ self.vision_hidden_size = int(vision_hidden_size)
56
+ self.vision_num_hidden_layers = int(vision_num_hidden_layers)
57
+ self.vision_num_attention_heads = int(vision_num_attention_heads)
58
+ self.vision_image_size = int(vision_image_size)
59
+ self.vision_patch_size = int(vision_patch_size)
60
+
61
+ self.projector_hidden_size = projector_hidden_size
62
+
63
+ self.image_token_index = int(image_token_index)
64
+ self.image_token_id = int(image_token_id) if image_token_id is not None else int(image_token_index)
65
+ self.num_image_tokens = int(num_image_tokens)
66
+ self.vision_feature_select = vision_feature_select
67
+
68
+ # -------- Generation-related compatibility --------
69
+ # Transformers' generation utilities (DynamicCache, etc.) expect certain
70
+ # attributes on the *decoder/text* config. Since ManthanConfig is a
71
+ # wrapper that may not always carry a resolved `text_config_obj`, we set
72
+ # conservative defaults here to keep `model.generate()` functional in
73
+ # stub/export scenarios.
74
+ self.num_hidden_layers = int(
75
+ getattr(self.text_config_obj, "num_hidden_layers", kwargs.get("num_hidden_layers", 1))
76
+ if self.text_config_obj is not None
77
+ else kwargs.get("num_hidden_layers", 1)
78
+ )
79
+ self.num_attention_heads = int(
80
+ getattr(self.text_config_obj, "num_attention_heads", kwargs.get("num_attention_heads", 1))
81
+ if self.text_config_obj is not None
82
+ else kwargs.get("num_attention_heads", 1)
83
+ )
84
+ self.hidden_size = int(
85
+ getattr(self.text_config_obj, "hidden_size", kwargs.get("hidden_size", 256))
86
+ if self.text_config_obj is not None
87
+ else kwargs.get("hidden_size", 256)
88
+ )
89
+ self.max_position_embeddings = int(
90
+ getattr(self.text_config_obj, "max_position_embeddings", kwargs.get("max_position_embeddings", 2048))
91
+ if self.text_config_obj is not None
92
+ else kwargs.get("max_position_embeddings", 2048)
93
+ )
94
+ self.vocab_size = int(
95
+ getattr(self.text_config_obj, "vocab_size", kwargs.get("vocab_size", 32000))
96
+ if self.text_config_obj is not None
97
+ else kwargs.get("vocab_size", 32000)
98
+ )
99
+
100
+ def get_text_config(self, decoder: bool = False):
101
+ # Transformers' GenerationConfig helpers call get_text_config() during
102
+ # PreTrainedModel initialization. For stub/export-time configs we may not
103
+ # have a resolved text backbone yet; in that case, fall back to self.
104
+ if self.text_config_obj is None:
105
+ return self
106
+ return self.text_config_obj
107
+
108
+
109
+ @dataclass
110
+ class ManthanBatch:
111
+ input_ids: "torch.LongTensor"
112
+ attention_mask: Optional["torch.LongTensor"]
113
+ pixel_values: Optional["torch.FloatTensor"]
114
+ labels: Optional["torch.LongTensor"]
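A hedged usage sketch for this config (no downloads; the GPT-2 text config is built locally from its `model_type`):

```python
from manthan_t1.configuration_manthan import ManthanConfig

cfg = ManthanConfig(
    text_config={"model_type": "gpt2", "n_embd": 256, "n_layer": 4, "n_head": 4},
    vision_image_size=224,
    vision_patch_size=16,
    num_image_tokens=196,
    image_token_id=42,
)

# Generation-compat attributes are mirrored from the resolved text config (or kwargs).
print(cfg.hidden_size, cfg.num_hidden_layers)  # 256 4

# With a resolved backbone this returns the GPT2Config; otherwise it falls back to `cfg` itself.
print(type(cfg.get_text_config()).__name__)
```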
manthan_t1/hf_integration_smoke.py ADDED
@@ -0,0 +1,60 @@
1
+ """Optional integration smoke test.
2
+
3
+ This WILL download models (big) if run.
4
+
5
+ It checks that:
6
+ - Qwen/Qwen3-0.6B loads as the text backbone
7
+ - google/siglip-so400m-patch14-384 loads as the vision backbone
8
+ - a single forward pass works with <image> token injection
9
+
10
+ Run (optional):
11
+ python -m manthan_t1.hf_integration_smoke
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+ from transformers import AutoTokenizer
18
+
19
+ from manthan_t1.configuration_manthan import ManthanConfig
20
+ from manthan_t1.modeling_manthan import ManthanForCausalLM
21
+
22
+
23
+ def main() -> None:
24
+ text_id = "Qwen/Qwen3-0.6B"
25
+ vision_id = "google/siglip-so400m-patch14-384"
26
+
27
+ cfg = ManthanConfig(
28
+ text_model_id=text_id,
29
+ vision_model_id=vision_id,
30
+ # SigLIP so400m patch14 384
31
+ vision_image_size=384,
32
+ vision_patch_size=14,
33
+ # 384/14 is non-integer; many siglip variants still use patch14, but token count comes from model.
34
+ # We'll keep num_image_tokens as 256 to match common LLaVA-style settings.
35
+ num_image_tokens=256,
36
+ image_token_id=151665,
37
+ vision_feature_select="patch",
38
+ )
39
+
40
+ model = ManthanForCausalLM(cfg)
41
+ model.eval()
42
+
43
+ tok = AutoTokenizer.from_pretrained(text_id, use_fast=False, trust_remote_code=True)
44
+
45
+ # Create a prompt with enough <image> placeholders
46
+ image_tok = tok.decode([cfg.image_token_id])
47
+ prompt = (image_tok + " ") * cfg.num_image_tokens + "\nDescribe the image."
48
+ inputs = tok(prompt, return_tensors="pt")
49
+
50
+ # Dummy image tensor with expected size; processor correctness is handled in `chat()`.
51
+ pixel_values = torch.randn(1, 3, cfg.vision_image_size, cfg.vision_image_size)
52
+
53
+ with torch.no_grad():
54
+ out = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"), pixel_values=pixel_values)
55
+
56
+ print("OK integration forward", tuple(out.logits.shape))
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
manthan_t1/hf_smoke.py ADDED
@@ -0,0 +1,43 @@
1
+ """HF Transformers smoke test for Manthan-T1.
2
+
3
+ This must run without downloading any external checkpoints.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+
10
+ from manthan_t1.configuration_manthan import ManthanConfig
11
+ from manthan_t1.modeling_manthan import ManthanForCausalLM
12
+
13
+
14
+ def main() -> None:
15
+ cfg = ManthanConfig(
16
+ text_config={"model_type": "gpt2"},
17
+ vision_image_size=224,
18
+ vision_patch_size=16,
19
+ vision_hidden_size=128,
20
+ # 224/16 = 14 patches per side => 196 patch tokens (CLS is dropped by default)
21
+ num_image_tokens=196,
22
+ image_token_id=42,
23
+ )
24
+ model = ManthanForCausalLM(cfg)
25
+ model.eval()
26
+
27
+ # Fake text with num_image_tokens image tokens
28
+ B, T = 2, 32
29
+ input_ids = torch.randint(0, 100, (B, T))
30
+ # Ensure sequence is long enough to host the image tokens
31
+ if T < cfg.num_image_tokens:
32
+ input_ids = torch.randint(0, 100, (B, cfg.num_image_tokens + 8))
33
+ T = input_ids.shape[1]
34
+ input_ids[:, : cfg.num_image_tokens] = cfg.image_token_id
35
+ pixel_values = torch.randn(B, 3, cfg.vision_image_size, cfg.vision_image_size)
36
+
37
+ out = model(input_ids=input_ids, pixel_values=pixel_values)
38
+ assert out.logits.shape[:2] == (B, T)
39
+ print("OK manthan hf forward", tuple(out.logits.shape))
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
manthan_t1/modeling_manthan.py ADDED
@@ -0,0 +1,654 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from transformers import (
10
+ AutoConfig,
11
+ AutoModel,
12
+ AutoModelForCausalLM,
13
+ GenerationMixin,
14
+ PreTrainedModel,
15
+ )
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+
18
+ from .configuration_manthan import ManthanConfig
19
+
20
+
21
+ IGNORE_INDEX = -100
22
+
23
+
24
+ class ManthanVisionEncoder(nn.Module):
25
+ """Minimal ViT-like vision encoder.
26
+
27
+ This is intentionally simple so the architecture is fully defined in this repo.
28
+ You can later swap it with SigLIP2 weights by mapping parameters.
29
+ """
30
+
31
+ def __init__(self, image_size: int, patch_size: int, hidden_size: int):
32
+ super().__init__()
33
+ self.image_size = image_size
34
+ self.patch_size = patch_size
35
+ self.hidden_size = hidden_size
36
+
37
+ self.proj = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
38
+ num_patches = (image_size // patch_size) * (image_size // patch_size)
39
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
40
+ self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, hidden_size))
41
+
42
+ encoder_layer = nn.TransformerEncoderLayer(
43
+ d_model=hidden_size,
44
+ nhead=max(1, hidden_size // 64),
45
+ dim_feedforward=hidden_size * 4,
46
+ batch_first=True,
47
+ activation="gelu",
48
+ )
49
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
50
+ self.ln = nn.LayerNorm(hidden_size)
51
+
52
+ nn.init.normal_(self.pos_embed, std=0.02)
53
+ nn.init.normal_(self.cls_token, std=0.02)
54
+
55
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
56
+ # pixel_values: (B, 3, H, W)
57
+ x = self.proj(pixel_values) # (B, C, H', W')
58
+ x = x.flatten(2).transpose(1, 2) # (B, N, C)
59
+ cls = self.cls_token.expand(x.size(0), -1, -1)
60
+ x = torch.cat([cls, x], dim=1)
61
+ x = x + self.pos_embed[:, : x.size(1), :]
62
+ x = self.encoder(x)
63
+ return self.ln(x)
64
+
65
+
66
+ class ManthanProjector(nn.Module):
67
+ def __init__(self, vision_hidden: int, text_hidden: int, mid: Optional[int] = None):
68
+ super().__init__()
69
+ mid = mid or max(text_hidden, vision_hidden)
70
+ self.net = nn.Sequential(
71
+ nn.Linear(vision_hidden, mid),
72
+ nn.GELU(),
73
+ nn.Linear(mid, text_hidden),
74
+ )
75
+
76
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
77
+ return self.net(x)
78
+
79
+
80
+ class ManthanForCausalLM(PreTrainedModel, GenerationMixin):
81
+ config_class = ManthanConfig
82
+ base_model_prefix = "manthan"
83
+
84
+ def __init__(self, config: ManthanConfig):
85
+ super().__init__(config)
86
+
87
+ # Text backbone
88
+ # Priority:
89
+ # 1) `text_model_id` (e.g., Qwen/Qwen3-0.6B)
90
+ # 2) `text_config` dict (model_type + overrides)
91
+ # 3) fallback tiny GPT-2 for smoke tests
92
+ if config.text_model_id:
93
+ self.language_model = AutoModelForCausalLM.from_pretrained(
94
+ config.text_model_id,
95
+ trust_remote_code=True,
96
+ )
97
+ else:
98
+ text_cfg = getattr(config, "text_config_obj", None)
99
+ if text_cfg is None:
100
+ text_cfg = (
101
+ AutoConfig.for_model(**config.text_config_dict)
102
+ if getattr(config, "text_config_dict", {}).get("model_type")
103
+ else None
104
+ )
105
+ if text_cfg is None:
106
+ from transformers import GPT2Config
107
+
108
+ text_cfg = GPT2Config(
109
+ n_embd=256,
110
+ n_layer=4,
111
+ n_head=4,
112
+ vocab_size=32000,
113
+ )
114
+ self.language_model = AutoModelForCausalLM.from_config(text_cfg)
115
+
116
+
117
+ text_hidden = self.language_model.config.hidden_size
118
+
119
+ # Vision backbone
120
+ self.vision_model = None
121
+ if config.vision_model_id:
122
+ self.vision_model = AutoModel.from_pretrained(config.vision_model_id, trust_remote_code=True)
123
+ vision_hidden = getattr(getattr(self.vision_model, "config", None), "hidden_size", None)
124
+ if vision_hidden is not None:
125
+ config.vision_hidden_size = int(vision_hidden)
126
+
127
+ # Fallback toy tower remains available (used when vision_model_id is not set)
128
+ self.vision_tower = ManthanVisionEncoder(
129
+ image_size=config.vision_image_size,
130
+ patch_size=config.vision_patch_size,
131
+ hidden_size=config.vision_hidden_size,
132
+ )
133
+
134
+ self.projector = ManthanProjector(
135
+ vision_hidden=config.vision_hidden_size,
136
+ text_hidden=text_hidden,
137
+ mid=config.projector_hidden_size,
138
+ )
139
+
140
+ # Use TinyLLaVA-style negative placeholder for <image>
141
+ self.image_token_id = int(getattr(config, "image_token_id", -200))
142
+ self.num_image_tokens = int(getattr(config, "num_image_tokens", 256))
143
+
144
+ # Generation helpers
145
+ self._gen_pixel_values: Optional[torch.FloatTensor] = None
146
+
147
+ self.post_init()
148
+
149
+ @staticmethod
150
+ def format_chat_prompt(prompt: str, has_image: bool = False) -> str:
151
+ """Format a single user prompt similar to TinyLLaVA's Qwen3 template.
152
+
153
+ We intentionally keep it simple and *string based* so it does not depend
154
+ on the tokenizer's chat_template.
155
+ """
156
+
157
+ system = (
158
+ "A chat between a curious user and an artificial intelligence assistant. "
159
+ "The assistant gives helpful, detailed, and polite answers to the user's questions. "
160
+ )
161
+
162
+ if has_image:
163
+ # Strip any "<image>" the user already included so the placeholder isn't duplicated.
164
+ clean = prompt.replace("<image>", "").strip()
165
+ formatted = f"<image>\n{clean}"
166
+ else:
167
+ formatted = prompt.strip()
168
+
169
+ # Critical: no trailing space after ASSISTANT:
170
+ return system + f"USER: {formatted} ASSISTANT:"
171
+
172
+ def _inject_vision_embeds(
173
+ self,
174
+ input_ids: torch.LongTensor,
175
+ pixel_values: torch.FloatTensor,
176
+ attention_mask: Optional[torch.LongTensor] = None,
177
+ ) -> Tuple[torch.FloatTensor, Optional[torch.LongTensor]]:
178
+ """Create input_embeds where <image> tokens are replaced by projected vision tokens.
179
+
180
+ Contract:
181
+ - input_ids contains exactly `num_image_tokens` occurrences of `image_token_id`.
182
+ - We will replace them in sequence order with vision tokens.
183
+ """
184
+
185
+ # Vision features
186
+ if self.vision_model is not None:
187
+ vout = self.vision_model(pixel_values=pixel_values)
188
+ # Try common fields: last_hidden_state
189
+ vision = getattr(vout, "last_hidden_state", None)
190
+ if vision is None:
191
+ raise ValueError("Vision model output does not contain last_hidden_state")
192
+ else:
193
+ vision = self.vision_tower(pixel_values) # (B, 1+N, vision_hidden)
194
+
195
+ # Token selection
196
+ if self.config.vision_feature_select == "patch":
197
+ # Drop CLS and take first num_image_tokens patches
198
+ vision = vision[:, 1 : 1 + self.num_image_tokens, :]
199
+ elif self.config.vision_feature_select == "cls_patch":
200
+ vision = vision[:, : self.num_image_tokens, :]
201
+ else:
202
+ raise ValueError(f"Unknown vision_feature_select={self.config.vision_feature_select}")
203
+
204
+ # Be strict: ensure we have exactly num_image_tokens.
205
+ if vision.shape[1] != self.num_image_tokens:
206
+ # Common case for the toy vision tower: includes CLS + N patches.
207
+ if vision.shape[1] > self.num_image_tokens:
208
+ vision = vision[:, : self.num_image_tokens, :]
209
+ else:
210
+ raise ValueError(
211
+ f"vision tokens ({vision.shape[1]}) < num_image_tokens ({self.num_image_tokens}); increase image size/patching or reduce num_image_tokens"
212
+ )
213
+ vision = self.projector(vision) # (B, num_img, text_hidden)
214
+
215
+ # Text embeds
216
+ lm = self.language_model
217
+ if hasattr(lm, "get_input_embeddings"):
218
+ tok_emb = lm.get_input_embeddings()
219
+ else:
220
+ tok_emb = lm.base_model.get_input_embeddings()
221
+ inputs_embeds = tok_emb(input_ids)
222
+
223
+ # Replace image token positions
224
+ mask = input_ids.eq(self.image_token_id) # (B, T)
225
+ if mask.sum(dim=1).min().item() != self.num_image_tokens:
226
+ raise ValueError(
227
+ f"Expected exactly {self.num_image_tokens} <image> tokens per sample, got min={mask.sum(dim=1).min().item()}"
228
+ )
229
+
230
+ for b in range(input_ids.size(0)):
231
+ idx = torch.nonzero(mask[b], as_tuple=False).squeeze(-1)
232
+ inputs_embeds[b, idx, :] = vision[b, : idx.numel(), :]
233
+
234
+ return inputs_embeds, attention_mask
235
+
236
+ @staticmethod
237
+ def tokenizer_image_token(
238
+ prompt: str,
239
+ tokenizer,
240
+ image_token_index: int = -200,
241
+ return_tensors: Optional[str] = None,
242
+ ):
243
+ """MicroLLaVA/TinyLLaVA-style tokenization inserting a negative image placeholder id.
244
+
245
+ This avoids requiring `<image>` to be a real token in the tokenizer vocab.
246
+ """
247
+
248
+ def _insert_separator(X, sep):
249
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
250
+
251
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
252
+
253
+ input_ids: List[int] = []
254
+ offset = 0
255
+ if (
256
+ len(prompt_chunks) > 0
257
+ and len(prompt_chunks[0]) > 0
258
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
259
+ ):
260
+ offset = 1
261
+ input_ids.append(prompt_chunks[0][0])
262
+
263
+ for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
264
+ input_ids.extend(x[offset:])
265
+
266
+ if return_tensors is not None:
267
+ if return_tensors == "pt":
268
+ return torch.tensor(input_ids, dtype=torch.long)
269
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
270
+ return input_ids
271
+
272
+ def _encode_images(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
273
+ """Return projected vision features (B, N, text_hidden)."""
274
+
275
+ if self.vision_model is not None:
276
+ vout = self.vision_model(pixel_values=pixel_values, output_hidden_states=True)
277
+ # Prefer siglip/clip style selection when hidden_states exist.
278
+ if hasattr(vout, "hidden_states") and vout.hidden_states is not None:
279
+ layer = getattr(self.config, "vision_feature_layer", -2)
280
+ vision = vout.hidden_states[layer]
281
+ else:
282
+ vision = getattr(vout, "last_hidden_state", None)
283
+ if vision is None:
284
+ raise ValueError("Vision model output has no usable hidden states")
285
+ else:
286
+ vision = self.vision_tower(pixel_values)
287
+
288
+ # Match TinyLLaVA selection strategy
289
+ strat = getattr(self.config, "vision_feature_select", "patch")
290
+ if strat == "patch":
291
+ vision = vision[:, 1:]
292
+ elif strat == "cls_patch":
293
+ vision = vision
294
+ else:
295
+ raise ValueError(f"Unknown vision_feature_select={strat}")
296
+
297
+ # Optionally truncate to configured max image tokens
298
+ if vision.shape[1] > self.num_image_tokens:
299
+ vision = vision[:, : self.num_image_tokens]
300
+
301
+ return self.projector(vision)
302
+
303
+ def prepare_inputs_labels_for_multimodal(
304
+ self,
305
+ input_ids: torch.LongTensor,
306
+ attention_mask: Optional[torch.Tensor],
307
+ past_key_values: Optional[Any],
308
+ labels: Optional[torch.LongTensor],
309
+ pixel_values: Optional[torch.FloatTensor],
310
+ ) -> Tuple[
311
+ Optional[torch.LongTensor],
312
+ Optional[torch.LongTensor],
313
+ Optional[torch.Tensor],
314
+ Optional[Any],
315
+ Optional[torch.FloatTensor],
316
+ Optional[torch.LongTensor],
317
+ ]:
318
+ """MicroLLaVA-style splice: build inputs_embeds by inserting vision features at IMAGE_TOKEN_INDEX.
319
+
320
+ Returns:
321
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels)
322
+ where input_ids is None when inputs_embeds is provided.
323
+ """
324
+
325
+ if pixel_values is None or input_ids.shape[1] == 1 or self.vision_tower is None:
326
+ return input_ids, None, attention_mask, past_key_values, None, labels
327
+
328
+ image_features = self._encode_images(pixel_values) # (B, N, hidden)
329
+
330
+ orig_labels = labels
331
+ orig_attention_mask = attention_mask
332
+
333
+ if attention_mask is None:
334
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
335
+ else:
336
+ attention_mask = attention_mask.bool()
337
+
338
+ if labels is None:
339
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
340
+
341
+ # Remove padding
342
+ input_ids_list = [cur_ids[cur_mask] for cur_ids, cur_mask in zip(input_ids, attention_mask)]
343
+ labels_list = [cur_lbl[cur_mask] for cur_lbl, cur_mask in zip(labels, attention_mask)]
344
+
345
+ tok_emb = self.language_model.get_input_embeddings()
346
+ vocab_size = int(getattr(tok_emb, "num_embeddings", 0) or 0)
347
+
348
+ new_input_embeds: List[torch.Tensor] = []
349
+ new_labels: List[torch.Tensor] = []
350
+ cur_image_idx = 0
351
+
352
+ for batch_idx, cur_ids in enumerate(input_ids_list):
353
+ num_images = int((cur_ids == self.image_token_id).sum().item())
354
+
355
+ # No image tokens: plain text path.
356
+ if num_images == 0:
357
+ if vocab_size > 0:
358
+ cur_ids = cur_ids.clamp(min=0, max=vocab_size - 1)
359
+ new_input_embeds.append(tok_emb(cur_ids))
360
+ new_labels.append(labels_list[batch_idx])
361
+ continue
362
+
363
+ # Split around image placeholder positions
364
+ image_token_indices = [-1] + torch.where(cur_ids == self.image_token_id)[0].tolist() + [cur_ids.shape[0]]
365
+ cur_labels = labels_list[batch_idx]
366
+
367
+ seg_ids: List[torch.Tensor] = []
368
+ seg_lbls: List[torch.Tensor] = []
369
+ for i in range(len(image_token_indices) - 1):
370
+ s = image_token_indices[i] + 1
371
+ e = image_token_indices[i + 1]
372
+ seg_ids.append(cur_ids[s:e])
373
+ seg_lbls.append(cur_labels[s:e])
374
+
375
+ split_sizes = [x.shape[0] for x in seg_ids]
376
+ total = int(sum(split_sizes))
377
+
378
+ if total > 0:
379
+ flat_ids = torch.cat(seg_ids, dim=0)
380
+ # Never feed negative ids into embeddings.
381
+ flat_ids = flat_ids[flat_ids >= 0]
382
+ if vocab_size > 0 and flat_ids.numel() > 0:
383
+ flat_ids = flat_ids.clamp(min=0, max=vocab_size - 1)
384
+ flat_emb = tok_emb(flat_ids)
385
+ emb_chunks = list(torch.split(flat_emb, split_sizes, dim=0))
386
+ else:
387
+ emb_chunks = [tok_emb(cur_ids[:0]) for _ in split_sizes]
388
+
389
+ cur_new_embeds: List[torch.Tensor] = []
390
+ cur_new_labels: List[torch.Tensor] = []
391
+
392
+ for i in range(num_images + 1):
393
+ cur_new_embeds.append(emb_chunks[i])
394
+ cur_new_labels.append(seg_lbls[i])
395
+ if i < num_images:
396
+ cur_img_feat = image_features[cur_image_idx]
397
+ cur_image_idx += 1
398
+ cur_new_embeds.append(cur_img_feat)
399
+ cur_new_labels.append(
400
+ torch.full(
401
+ (cur_img_feat.shape[0],),
402
+ IGNORE_INDEX,
403
+ device=cur_labels.device,
404
+ dtype=cur_labels.dtype,
405
+ )
406
+ )
407
+
408
+ new_input_embeds.append(torch.cat([x.to(self.device) for x in cur_new_embeds], dim=0))
409
+ new_labels.append(torch.cat(cur_new_labels, dim=0))
410
+
411
+ # Truncate if needed
412
+ max_len_cfg = getattr(self.config, "tokenizer_model_max_length", None)
413
+ if max_len_cfg is not None:
414
+ new_input_embeds = [x[:max_len_cfg] for x in new_input_embeds]
415
+ new_labels = [x[:max_len_cfg] for x in new_labels]
416
+
417
+ max_len = max(x.shape[0] for x in new_input_embeds)
418
+ batch_size = len(new_input_embeds)
419
+
420
+ padded_embeds: List[torch.Tensor] = []
421
+ padded_labels = torch.full(
422
+ (batch_size, max_len),
423
+ IGNORE_INDEX,
424
+ dtype=new_labels[0].dtype,
425
+ device=new_labels[0].device,
426
+ )
427
+ padded_mask = torch.zeros((batch_size, max_len), dtype=torch.long, device=padded_labels.device)
428
+
429
+ for i, (emb, lbl) in enumerate(zip(new_input_embeds, new_labels)):
430
+ cur_len = emb.shape[0]
431
+ pad = torch.zeros((max_len - cur_len, emb.shape[1]), dtype=emb.dtype, device=emb.device)
432
+ padded_embeds.append(torch.cat([emb, pad], dim=0))
433
+ padded_labels[i, :cur_len] = lbl
434
+ padded_mask[i, :cur_len] = 1
435
+
436
+ inputs_embeds = torch.stack(padded_embeds, dim=0)
437
+
438
+ out_labels = None if orig_labels is None else padded_labels
439
+ out_mask = None if orig_attention_mask is None else padded_mask
440
+
441
+ return None, None, out_mask, past_key_values, inputs_embeds, out_labels
442
+
443
+ def forward(
444
+ self,
445
+ input_ids: Optional[torch.LongTensor] = None,
446
+ attention_mask: Optional[torch.LongTensor] = None,
447
+ pixel_values: Optional[torch.FloatTensor] = None,
448
+ labels: Optional[torch.LongTensor] = None,
449
+ **kwargs,
450
+ ) -> CausalLMOutputWithPast:
451
+ if input_ids is None:
452
+ raise ValueError("input_ids is required")
453
+
454
+ if pixel_values is None and self._gen_pixel_values is not None:
455
+ pixel_values = self._gen_pixel_values
456
+
457
+ past_key_values = kwargs.get("past_key_values", None)
458
+ (
459
+ input_ids,
460
+ position_ids,
461
+ attention_mask,
462
+ past_key_values,
463
+ inputs_embeds,
464
+ labels,
465
+ ) = self.prepare_inputs_labels_for_multimodal(
466
+ input_ids=input_ids,
467
+ attention_mask=attention_mask,
468
+ past_key_values=past_key_values,
469
+ labels=labels,
470
+ pixel_values=pixel_values,
471
+ )
472
+
473
+ return self.language_model(
474
+ input_ids=input_ids,
475
+ inputs_embeds=inputs_embeds,
476
+ attention_mask=attention_mask,
477
+ position_ids=position_ids,
478
+ past_key_values=past_key_values,
479
+ labels=labels,
480
+ **kwargs,
481
+ )
482
+
483
+ @torch.no_grad()
484
+ def generate(
485
+ self,
486
+ input_ids: Optional[torch.Tensor] = None,
487
+ pixel_values: Optional[torch.FloatTensor] = None,
488
+ attention_mask: Optional[torch.Tensor] = None,
489
+ **kwargs,
490
+ ):
491
+ if input_ids is None:
492
+ raise ValueError("input_ids is required")
493
+
494
+ if pixel_values is not None:
495
+ (
496
+ _,
497
+ position_ids,
498
+ attention_mask,
499
+ _,
500
+ inputs_embeds,
501
+ _,
502
+ ) = self.prepare_inputs_labels_for_multimodal(
503
+ input_ids=input_ids,
504
+ attention_mask=attention_mask,
505
+ past_key_values=None,
506
+ labels=None,
507
+ pixel_values=pixel_values,
508
+ )
509
+ return self.language_model.generate(
510
+ inputs_embeds=inputs_embeds,
511
+ attention_mask=attention_mask,
512
+ position_ids=position_ids,
513
+ **kwargs,
514
+ )
515
+
516
+ return self.language_model.generate(
517
+ input_ids=input_ids,
518
+ attention_mask=attention_mask,
519
+ **kwargs,
520
+ )
521
+
522
+ def prepare_inputs_for_generation(
523
+ self,
524
+ input_ids: torch.LongTensor,
525
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, ...], ...]] = None,
526
+ attention_mask: Optional[torch.LongTensor] = None,
527
+ pixel_values: Optional[torch.FloatTensor] = None,
528
+ **kwargs,
529
+ ) -> Dict[str, Any]:
530
+ """HF generation hook.
531
+
532
+ Contract:
533
+ - On the first step, we keep full `input_ids` and provide `pixel_values`.
534
+ - For subsequent steps (past_key_values != None), HF passes only the last token.
535
+ We keep cached `pixel_values` and forward normally.
536
+ """
537
+
538
+ if pixel_values is not None:
539
+ self._gen_pixel_values = pixel_values
540
+
541
+ # If we have past, only feed last token (standard causal LM behavior)
542
+ if past_key_values is not None:
543
+ input_ids = input_ids[:, -1:]
544
+
545
+ model_inputs: Dict[str, Any] = {
546
+ "input_ids": input_ids,
547
+ "attention_mask": attention_mask,
548
+ "past_key_values": past_key_values,
549
+ "use_cache": kwargs.get("use_cache", True),
550
+ }
551
+
552
+ if self._gen_pixel_values is not None:
553
+ model_inputs["pixel_values"] = self._gen_pixel_values
554
+
555
+ return model_inputs
556
+
557
+ def _reorder_cache(self, past_key_values, beam_idx):
558
+ # Delegate to the underlying LM implementation
559
+ if hasattr(self.language_model, "_reorder_cache"):
560
+ return self.language_model._reorder_cache(past_key_values, beam_idx)
561
+ return past_key_values
562
+
563
+ @torch.no_grad()
564
+ def chat(
565
+ self,
566
+ prompt: str,
567
+ tokenizer,
568
+ image: Optional[Union[str, "PIL.Image.Image"]] = None,
569
+ max_new_tokens: int = 128,
570
+ **gen_kwargs,
571
+ ) -> str:
572
+ """Simple chat helper (mirrors the style in your MicroLLaVA README).
573
+
574
+ This is intentionally minimal. A proper processor will be added later.
575
+ """
576
+
577
+ # Lazy import to keep base install small
578
+ pixel_values = None
579
+ if image is not None:
580
+ from PIL import Image
581
+ import requests
582
+ from io import BytesIO
583
+
584
+ if isinstance(image, str) and image.startswith("http"):
585
+ r = requests.get(image, timeout=30)
586
+ r.raise_for_status()
587
+ image = Image.open(BytesIO(r.content)).convert("RGB")
588
+ elif isinstance(image, str):
589
+ image = Image.open(image).convert("RGB")
590
+
591
+ # Prefer model-specific preprocessing when we have a real vision backbone.
592
+ if self.vision_model is not None and self.config.vision_model_id:
593
+ from transformers import AutoProcessor
594
+
595
+ proc = AutoProcessor.from_pretrained(self.config.vision_model_id, trust_remote_code=True)
596
+ pv = proc(images=image, return_tensors="pt")
597
+ pixel_values = pv.get("pixel_values", None)
598
+ if pixel_values is None:
599
+ raise ValueError("AutoProcessor did not return pixel_values")
600
+ else:
601
+ import torchvision.transforms as T
602
+
603
+ tfm = T.Compose(
604
+ [
605
+ T.Resize((self.config.vision_image_size, self.config.vision_image_size)),
606
+ T.ToTensor(),
607
+ ]
608
+ )
609
+ pixel_values = tfm(image).unsqueeze(0)
610
+
611
+ # Insert language mirroring instruction (simple + robust)
612
+ # We keep it short to not fight Qwen's reasoning.
613
+ user_prompt = (
614
+ "Reply in the same language as the user's prompt (e.g., Tamil in Tamil, Hindi in Hindi, English in English). "
615
+ "Be helpful and concise.\n\n" + prompt
616
+ )
617
+
618
+ formatted = self.format_chat_prompt(user_prompt, has_image=pixel_values is not None)
619
+
620
+ # MicroLLaVA-style: tokenize by inserting image placeholder ids directly.
621
+ if pixel_values is not None:
622
+ input_ids = self.tokenizer_image_token(
623
+ formatted,
624
+ tokenizer,
625
+ image_token_index=self.image_token_id,
626
+ return_tensors="pt",
627
+ ).unsqueeze(0)
628
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
629
+ else:
630
+ ids = tokenizer(formatted, return_tensors="pt")
631
+ input_ids = ids["input_ids"]
632
+ attention_mask = ids.get("attention_mask")
633
+
634
+ # Keep everything on the model device
635
+ input_ids = input_ids.to(self.device)
636
+ if attention_mask is not None:
637
+ attention_mask = attention_mask.to(self.device)
638
+ if pixel_values is not None:
639
+ pixel_values = pixel_values.to(self.device)
640
+
641
+ # Use this wrapper's generate support so pixel_values can be carried through.
642
+ out = self.generate(
643
+ input_ids=input_ids,
644
+ attention_mask=attention_mask,
645
+ pixel_values=pixel_values,
646
+ max_new_tokens=max_new_tokens,
647
+ **gen_kwargs,
648
+ )
649
+ return tokenizer.decode(out[0], skip_special_tokens=True)
650
+
651
+
652
+ # Registration so AutoModel/AutoConfig can find it when you package/export.
653
+ AutoConfig.register(ManthanConfig.model_type, ManthanConfig)
654
+ AutoModelForCausalLM.register(ManthanConfig, ManthanForCausalLM)
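A hedged, no-big-download sketch that exercises the image-placeholder path end to end with the toy vision tower and the GPT-2 fallback backbone (only the small `gpt2` tokenizer is fetched, as in the other smoke tests; the prompt and sizes are arbitrary):

```python
import torch
from transformers import AutoTokenizer

from manthan_t1.configuration_manthan import ManthanConfig
from manthan_t1.modeling_manthan import ManthanForCausalLM

cfg = ManthanConfig(
    vision_image_size=224, vision_patch_size=16, vision_hidden_size=128,
    num_image_tokens=64,  # small, arbitrary value for the sketch
)
model = ManthanForCausalLM(cfg).eval()

tok = AutoTokenizer.from_pretrained("gpt2")
prompt = model.format_chat_prompt("What is in this picture?", has_image=True)

# The single "<image>" in the prompt becomes one -200 placeholder id; the forward pass
# replaces it with `num_image_tokens` projected vision embeddings.
input_ids = model.tokenizer_image_token(
    prompt, tok, image_token_index=model.image_token_id, return_tensors="pt"
).unsqueeze(0)
pixel_values = torch.randn(1, 3, cfg.vision_image_size, cfg.vision_image_size)

with torch.no_grad():
    out = model(input_ids=input_ids, pixel_values=pixel_values)
print(tuple(out.logits.shape))  # sequence length = text tokens + num_image_tokens
```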
manthan_t1/smoke_test.py ADDED
@@ -0,0 +1,63 @@
1
+ """Minimal smoke test.
2
+
3
+ Goal: ensure the repo is runnable on macOS + MLX before we wire real Qwen/SigLIP weights.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import time
9
+
10
+ import mlx.core as mx
11
+ import mlx.nn as nn
12
+ import mlx.optimizers as optim
13
+
14
+
15
+ class TinyToyModel(nn.Module):
16
+ def __init__(self, vocab_size: int = 256, d_model: int = 128):
17
+ super().__init__()
18
+ self.emb = nn.Embedding(vocab_size, d_model)
19
+ self.l1 = nn.Linear(d_model, d_model)
20
+ self.l2 = nn.Linear(d_model, vocab_size)
21
+
22
+ def __call__(self, token_ids: mx.array) -> mx.array:
23
+ # token_ids: (B, T)
24
+ x = self.emb(token_ids)
25
+ x = nn.relu(self.l1(x))
26
+ return self.l2(x)
27
+
28
+
29
+ def main() -> None:
30
+ mx.random.seed(0)
31
+
32
+ model = TinyToyModel()
33
+
34
+ # fake batch
35
+ B, T = 4, 32
36
+ token_ids = mx.random.randint(0, 256, shape=(B, T))
37
+ targets = mx.random.randint(0, 256, shape=(B, T))
38
+
39
+ def loss_fn(m: TinyToyModel, x: mx.array, y: mx.array) -> mx.array:
40
+ logits = m(x)
41
+ logits2 = logits.reshape((-1, logits.shape[-1]))
42
+ y2 = y.reshape((-1,))
43
+ # `cross_entropy` may return per-example/per-token; reduce to scalar.
44
+ return mx.mean(nn.losses.cross_entropy(logits2, y2))
45
+
46
+ opt = optim.Adam(learning_rate=1e-3)
47
+
48
+ start = time.time()
49
+ def f(m: TinyToyModel) -> mx.array:
50
+ return loss_fn(m, token_ids, targets)
51
+
52
+ for step in range(5):
53
+ loss, grads = mx.value_and_grad(f)(model)
54
+ opt.update(model, grads)
55
+ mx.eval(loss)
56
+ print(f"step={step} loss={float(loss):.4f}")
57
+
58
+ mx.eval(model.parameters())
59
+ print(f"OK (elapsed {time.time() - start:.2f}s)")
60
+
61
+
62
+ if __name__ == "__main__":
63
+ main()
manthan_t1/text_generate_smoke.py ADDED
@@ -0,0 +1,30 @@
1
+ """No-download generate smoke test.
2
+
3
+ Ensures that `ManthanForCausalLM.generate()` works (text-only path), which is
4
+ required before enabling image+generate.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ from transformers import AutoTokenizer
11
+
12
+ from manthan_t1.configuration_manthan import ManthanConfig
13
+ from manthan_t1.modeling_manthan import ManthanForCausalLM
14
+
15
+
16
+ def main() -> None:
17
+ cfg = ManthanConfig(text_config={"model_type": "gpt2"})
18
+ model = ManthanForCausalLM(cfg)
19
+ model.eval()
20
+
21
+ tok = AutoTokenizer.from_pretrained("gpt2")
22
+ prompt = "Hello, my name is"
23
+ inputs = tok(prompt, return_tensors="pt")
24
+
25
+ out = model.generate(**inputs, max_new_tokens=10)
26
+ print(tok.decode(out[0], skip_special_tokens=True))
27
+
28
+
29
+ if __name__ == "__main__":
30
+ main()
manthan_t1/tokenizer_smoke.py ADDED
@@ -0,0 +1,28 @@
1
+ """Tokenizer + template smoke test.
2
+
3
+ - Ensures `<image>` can be added to a tokenizer
4
+ - Ensures we can build a prompt containing `<image>` token placeholders
5
+
6
+ Downloads a small tokenizer (gpt2) only.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from transformers import AutoTokenizer
12
+
13
+ from manthan_t1.tokenizer_utils import ensure_vision_special_tokens
14
+
15
+
16
+ def main() -> None:
17
+ tok = AutoTokenizer.from_pretrained("gpt2")
18
+ res = ensure_vision_special_tokens(tok, add_im_start_end=True)
19
+
20
+ prompt = ("<image> " * 4) + "\nExplain what you see."
21
+ ids = tok(prompt).input_ids
22
+
23
+ print("image_token_id", res.image_token_id)
24
+ print("count", sum(1 for i in ids if i == res.image_token_id))
25
+
26
+
27
+ if __name__ == "__main__":
28
+ main()
manthan_t1/tokenizer_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Optional, Tuple
5
+
6
+
7
+ DEFAULT_SPECIAL_TOKENS: Dict[str, str] = {
8
+ "image_token": "<image>",
9
+ "im_start": "<im_start>",
10
+ "im_end": "<im_end>",
11
+ }
12
+
13
+
14
+ @dataclass
15
+ class TokenSetupResult:
16
+ image_token_id: int
17
+ added: Dict[str, int]
18
+
19
+
20
+ def ensure_vision_special_tokens(tokenizer, add_im_start_end: bool = False) -> TokenSetupResult:
21
+ """Ensure the tokenizer has `<image>` (and optionally `<im_start>/<im_end>`).
22
+
23
+ Returns the chosen `image_token_id` and any newly-added token ids.
24
+
25
+ Works with both fast and slow tokenizers.
26
+ """
27
+
28
+ specials = {
29
+ "additional_special_tokens": [DEFAULT_SPECIAL_TOKENS["image_token"]],
30
+ }
31
+ if add_im_start_end:
32
+ specials["additional_special_tokens"].extend(
33
+ [DEFAULT_SPECIAL_TOKENS["im_start"], DEFAULT_SPECIAL_TOKENS["im_end"]]
34
+ )
35
+
36
+ # Add only missing tokens
37
+ existing = set(getattr(tokenizer, "additional_special_tokens", []) or [])
38
+ to_add = [t for t in specials["additional_special_tokens"] if t not in existing]
39
+
40
+ added_map: Dict[str, int] = {}
41
+ if to_add:
42
+ tokenizer.add_special_tokens({"additional_special_tokens": to_add})
43
+ # after add, resolve ids
44
+ for t in to_add:
45
+ added_map[t] = tokenizer.convert_tokens_to_ids(t)
46
+
47
+ image_token = DEFAULT_SPECIAL_TOKENS["image_token"]
48
+ image_token_id = tokenizer.convert_tokens_to_ids(image_token)
49
+ if image_token_id is None or image_token_id < 0:
50
+ raise ValueError("Failed to register <image> token")
51
+
52
+ return TokenSetupResult(image_token_id=image_token_id, added=added_map)
pyproject.toml ADDED
@@ -0,0 +1,29 @@
1
+ [project]
2
+ name = "manthan-t1"
3
+ version = "0.0.1"
4
+ description = "Manthan-T1: MLX-first vision-language model (Qwen3 + vision encoder) fine-tuning scaffold"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ license = {text = "Apache-2.0"}
8
+ authors = [{name = "Manthan"}]
9
+
10
+ dependencies = [
11
+ "mlx>=0.17.0",
12
+ "numpy>=1.26",
13
+ "pillow>=10.0",
14
+ "tqdm>=4.66",
15
+ "pyyaml>=6.0",
16
+ "transformers>=4.55.0",
17
+ "torch>=2.2.0",
18
+ "torchvision>=0.17.0",
19
+ "requests>=2.31",
20
+ ]
21
+
22
+ [project.optional-dependencies]
23
+ dev = [
24
+ "pytest>=8.0",
25
+ ]
26
+
27
+ [tool.pytest.ini_options]
28
+ addopts = "-q"
29
+ testpaths = ["tests"]
requirements.txt ADDED
File without changes
scripts/export_hf.py ADDED
@@ -0,0 +1,208 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Export a Manthan-T1 folder that can be uploaded to Hugging Face.
4
+
5
+ What this does:
6
+ - Copies `hf_export_stub/*` into an output directory
7
+ - Builds a tokenizer from `tokenizer_name_or_path` (defaults to Qwen3)
8
+ - Ensures `<image>` is a real special token in the tokenizer
9
+ - Writes `tokenizer_config.json`, `special_tokens_map.json`, `added_tokens.json`, and `chat_template.jinja`
10
+ - Updates `config.json` with a correct `image_token_id` (kept equal to -200 placeholder)
11
+
12
+ Note:
13
+ - This does NOT include model weights. It's intended for placeholder-weight repo layout
14
+ (like your MicroLLaVA example). For training, you'll later save actual weights.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import json
21
+ import os
22
+ import shutil
23
+ import sys
24
+ from pathlib import Path
25
+
26
+ from transformers import AutoTokenizer
27
+
28
+
29
+ # Allow running this script without installing the package.
30
+ REPO_ROOT = Path(__file__).resolve().parents[1]
31
+ if str(REPO_ROOT) not in sys.path:
32
+ sys.path.insert(0, str(REPO_ROOT))
33
+
34
+
35
+ def _copytree(src: Path, dst: Path) -> None:
36
+ dst.mkdir(parents=True, exist_ok=True)
37
+ for item in src.iterdir():
38
+ s = item
39
+ d = dst / item.name
40
+ if item.is_dir():
41
+ shutil.copytree(s, d, dirs_exist_ok=True)
42
+ else:
43
+ shutil.copy2(s, d)
44
+
45
+
46
+ def main() -> None:
47
+ ap = argparse.ArgumentParser()
48
+ ap.add_argument("--out", required=True, help="Output folder")
49
+ ap.add_argument(
50
+ "--stub",
51
+ default=str(Path(__file__).resolve().parents[1] / "hf_export_stub"),
52
+ help="Path to hf_export_stub folder",
53
+ )
54
+ ap.add_argument(
55
+ "--tokenizer",
56
+ default=None,
57
+ help="Tokenizer name/path. Defaults to config.json tokenizer_name_or_path.",
58
+ )
59
+ ap.add_argument(
60
+ "--tokenizer_local_dir",
61
+ default=None,
62
+ help="Local tokenizer directory to copy (e.g. MicroLlava-* folder). If set, no network fetch is performed.",
63
+ )
64
+ ap.add_argument(
65
+ "--write_stub_weights",
66
+ action="store_true",
67
+ help="Write randomly-initialized weights (model.safetensors) into the export dir so from_pretrained() succeeds.",
68
+ )
69
+ args = ap.parse_args()
70
+
71
+ out_dir = Path(args.out).expanduser().resolve()
72
+ stub_dir = Path(args.stub).expanduser().resolve()
73
+
74
+ if not stub_dir.exists():
75
+ raise SystemExit(f"Stub dir not found: {stub_dir}")
76
+
77
+ out_dir.mkdir(parents=True, exist_ok=True)
78
+ _copytree(stub_dir, out_dir)
79
+
80
+ # Ensure we don't keep stale remote-code python files from a previous export.
81
+ for stale in ["configuration_manthan.py", "modeling_manthan.py", "__init__.py"]:
82
+ p = out_dir / stale
83
+ if p.exists():
84
+ p.unlink()
85
+
86
+ # Copy remote-code python files to export root (HF dynamic module loader expects them)
87
+ repo_root = Path(__file__).resolve().parents[1]
88
+ pkg_dir = repo_root / "manthan_t1"
89
+ for fname in ["configuration_manthan.py", "modeling_manthan.py", "__init__.py"]:
90
+ src = pkg_dir / fname
91
+ if not src.exists():
92
+ raise SystemExit(f"Missing required source file for export: {src}")
93
+ shutil.copy2(src, out_dir / fname)
94
+
95
+ cfg_path = out_dir / "config.json"
96
+ if not cfg_path.exists():
97
+ raise SystemExit(f"config.json not found in: {out_dir}")
98
+
99
+ cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
100
+ tokenizer_name = (
101
+ args.tokenizer
102
+ or cfg.get("tokenizer_name_or_path")
103
+ or cfg.get("llm_model_name_or_path")
104
+ or cfg.get("text_model_id")
105
+ or cfg.get("vision_model_id")
106
+ )
107
+ if not tokenizer_name:
108
+ raise SystemExit("Could not infer tokenizer_name_or_path")
109
+
110
+ # Prefer an on-disk tokenizer (e.g. the attached MicroLLaVA folder) to avoid any
111
+ # network dependency during export.
112
+ repo_root = Path(__file__).resolve().parents[1]
113
+ local_tokenizer_candidates = [
114
+ repo_root / "MicroLlava-Qwen3-0.6B-base-siglip2-so400m",
115
+ ]
116
+ for cand in local_tokenizer_candidates:
117
+ if cand.exists() and (cand / "tokenizer_config.json").exists():
118
+ tokenizer_name = str(cand)
119
+ break
120
+
121
+ tok = AutoTokenizer.from_pretrained(
122
+ tokenizer_name,
123
+ trust_remote_code=True,
124
+ use_fast=bool(cfg.get("tokenizer_use_fast", False)),
125
+ local_files_only=True,
126
+ )
127
+
128
+ # Ensure special tokens exist
129
+ added = tok.add_special_tokens({"additional_special_tokens": ["<image>"]})
130
+ # Some tokenizers need a pad token for batching.
131
+ if tok.pad_token_id is None and cfg.get("pad_token"):
132
+ tok.add_special_tokens({"pad_token": cfg["pad_token"]})
133
+
134
+ # Save tokenizer files into export dir
135
+ tok.save_pretrained(out_dir)
136
+
137
+ # Copy chat template if present in stub
138
+ tmpl_src = out_dir / "chat_template.jinja"
139
+ if tmpl_src.exists():
140
+ # Ensure tokenizer_config.json references it (HF uses string field)
141
+ tok_cfg_path = out_dir / "tokenizer_config.json"
142
+ if tok_cfg_path.exists():
143
+ tok_cfg = json.loads(tok_cfg_path.read_text(encoding="utf-8"))
144
+ else:
145
+ tok_cfg = {}
146
+ tok_cfg["chat_template"] = tmpl_src.read_text(encoding="utf-8")
147
+ tok_cfg_path.write_text(json.dumps(tok_cfg, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
148
+
149
+ # Align config fields with MicroLLaVA convention
150
+ cfg.setdefault("image_token_index", -200)
151
+ cfg["image_token_index"] = -200
152
+ cfg["image_token_id"] = -200
153
+
154
+ # For user convenience record actual tokenizer vocab id of '<image>'
155
+ img_vocab_id = tok.convert_tokens_to_ids("<image>")
156
+ cfg["tokenizer_image_token_id"] = int(img_vocab_id) if img_vocab_id is not None else None
157
+ cfg["tokenizer_added_tokens"] = int(added)
158
+
159
+ cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
160
+
161
+ # Minimal README hint
162
+ readme = out_dir / "README_EXPORT.md"
163
+ readme.write_text(
164
+ "Manthan-T1 export folder (stub).\n\n"
165
+ "- `config.json` uses `image_token_index=-200` placeholder like TinyLLaVA.\n"
166
+ "- Tokenizer contains a real `<image>` special token.\n"
167
+ "- This folder does not include model weights; training should save weights here later.\n",
168
+ encoding="utf-8",
169
+ )
170
+
171
+ print(f"Exported to: {out_dir}")
172
+
173
+ if args.write_stub_weights:
174
+ # Import only when requested to avoid heavier imports for plain export.
175
+ from manthan_t1.configuration_manthan import ManthanConfig
176
+ from manthan_t1.modeling_manthan import ManthanForCausalLM
177
+
178
+ # Tiny randomly-initialized model that is loadable.
179
+ # This does not download any base weights.
180
+ stub_cfg = ManthanConfig(
181
+ text_model_id=None,
182
+ vision_model_id=None,
183
+ image_token_index=-200,
184
+ num_image_tokens=32,
185
+ )
186
+ model = ManthanForCausalLM(stub_cfg)
187
+ model.save_pretrained(out_dir, safe_serialization=True)
188
+
189
+ # Ensure auto_map is present so AutoConfig/AutoModel can resolve our
190
+ # custom classes via trust_remote_code.
191
+ saved_cfg = json.loads((out_dir / "config.json").read_text(encoding="utf-8"))
192
+ saved_cfg["auto_map"] = cfg.get(
193
+ "auto_map",
194
+ {
195
+ "AutoConfig": "configuration_manthan.ManthanConfig",
196
+ "AutoModelForCausalLM": "modeling_manthan.ManthanForCausalLM",
197
+ },
198
+ )
199
+ (out_dir / "config.json").write_text(
200
+ json.dumps(saved_cfg, indent=2, ensure_ascii=False) + "\n",
201
+ encoding="utf-8",
202
+ )
203
+
204
+ print("Wrote stub weights: model.safetensors")
205
+
206
+
207
+ if __name__ == "__main__":
208
+ main()
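
Note: the exported `config.json` records both the TinyLLaVA-style placeholder (`image_token_index = -200`) and the tokenizer's real vocab id for `<image>` (`tokenizer_image_token_id`). Below is a minimal sketch of one way a caller could reconcile the two at inference time; the field names come from the export code above, but the remapping itself is an assumption (the training script later in this commit instead splits prompts on the literal `<image>` string):

```python
import json
from pathlib import Path

from transformers import AutoTokenizer

# Folder produced by: python scripts/export_hf.py --out hf_export_ready --write_stub_weights
export_dir = Path("hf_export_ready")
cfg = json.loads((export_dir / "config.json").read_text(encoding="utf-8"))
tok = AutoTokenizer.from_pretrained(export_dir, trust_remote_code=True)

placeholder_id = cfg["image_token_index"]         # -200, never a real vocab id
vocab_image_id = cfg["tokenizer_image_token_id"]  # real id of '<image>' in the tokenizer

# Tokenize with the real <image> token, then swap it for the placeholder before
# the ids reach the model (TinyLLaVA-style convention).
ids = tok("Describe this image: <image>").input_ids
ids = [placeholder_id if t == vocab_image_id else t for t in ids]
print(ids)
```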
scripts/infer_hf.py ADDED
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
+
9
+ def main() -> None:
10
+ ap = argparse.ArgumentParser()
11
+ ap.add_argument("--model", type=str, default="./")
12
+ ap.add_argument("--prompt", type=str, required=True)
13
+ ap.add_argument("--image", type=str, default=None, help="URL or local path")
14
+ args = ap.parse_args()
15
+
16
+ model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True)
17
+ tok = AutoTokenizer.from_pretrained(args.model, use_fast=False)
18
+
19
+ if hasattr(model, "chat"):
20
+ text = model.chat(prompt=args.prompt, image=args.image, tokenizer=tok)
21
+ print(text)
22
+ return
23
+
24
+ inputs = tok(args.prompt, return_tensors="pt")
25
+ out = model.generate(**inputs, max_new_tokens=128)
26
+ print(tok.decode(out[0], skip_special_tokens=True))
27
+
28
+
29
+ if __name__ == "__main__":
30
+ main()
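
`--model` defaults to the current directory but also accepts a Hub repo id. An inline equivalent of the text-only fallback path, using the repo id from the Kaggle runner's `MANTHAN_MODEL` default further below (whether that repo is published with working weights is an assumption):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "zyxcisss/Manthan-T1"  # assumed, taken from scripts/kaggle_train_all.sh
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)

inputs = tok("Hello, who are you?", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32)
print(tok.decode(out[0], skip_special_tokens=True))
```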
scripts/infer_qwen3_siglip2.py ADDED
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+
8
+ from manthan_t1.configuration_manthan import ManthanConfig
9
+ from manthan_t1.modeling_manthan import ManthanForCausalLM
10
+
11
+
12
+ def main() -> None:
13
+ ap = argparse.ArgumentParser()
14
+ ap.add_argument("--text-model", type=str, default="Qwen/Qwen3-0.6B")
15
+ ap.add_argument(
16
+ "--vision-model",
17
+ type=str,
18
+ default="google/siglip-so400m-patch14-384",
19
+ )
20
+ ap.add_argument("--prompt", type=str, required=True)
21
+ ap.add_argument("--image", type=str, default=None, help="URL or local path")
22
+ ap.add_argument("--image-token-id", type=int, default=151665)
23
+ ap.add_argument("--num-image-tokens", type=int, default=256)
24
+ args = ap.parse_args()
25
+
26
+ cfg = ManthanConfig(
27
+ text_model_id=args.text_model,
28
+ vision_model_id=args.vision_model,
29
+ vision_image_size=384,
30
+ vision_patch_size=14,
31
+ image_token_id=args.image_token_id,
32
+ num_image_tokens=args.num_image_tokens,
33
+ vision_feature_select="patch",
34
+ )
35
+
36
+ model = ManthanForCausalLM(cfg)
37
+ tok = AutoTokenizer.from_pretrained(args.text_model, use_fast=False, trust_remote_code=True)
38
+
39
+ if args.image:
40
+ out = model.chat(prompt=args.prompt, image=args.image, tokenizer=tok)
41
+ else:
42
+ inputs = tok(args.prompt, return_tensors="pt")
43
+ out_ids = model.language_model.generate(**inputs, max_new_tokens=128)
44
+ out = tok.decode(out_ids[0], skip_special_tokens=True)
45
+
46
+ print(out)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
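
The `--image-token-id 151665` default assumes a specific slot in the Qwen3 tokenizer. If the layout differs, the id can be derived from the tokenizer itself rather than hard-coded; a sketch (whether the surrounding model resizes its embeddings for an added token is an assumption about the remote code):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", use_fast=False, trust_remote_code=True)
added = tok.add_special_tokens({"additional_special_tokens": ["<image>"]})
image_token_id = tok.convert_tokens_to_ids("<image>")
print(f"added={added} image_token_id={image_token_id}")
# Pass the printed id via --image-token-id instead of relying on the 151665 default.
```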
scripts/kaggle_train_all.sh ADDED
@@ -0,0 +1,136 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ # One-shot Kaggle runner: Stage 1 (pretrain/alignment) -> Stage 2 (instruct finetune)
5
+ # Designed for Kaggle 2xT4, but also works on other CUDA machines.
6
+
7
+ ############################################
8
+ # User config (edit these if you want)
9
+ ############################################
10
+ : "${MANTHAN_MODEL:=zyxcisss/Manthan-T1}" # HF repo or local path containing Manthan remote-code
11
+ : "${TEXT_MODEL:=Qwen/Qwen3-0.6B-Base}" # base LLM checkpoint
12
+ : "${STAGE1_DS:=liuhaotian/LLaVA-CC3M-Pretrain-595K}" # pretrain/alignment
13
+ : "${STAGE2_DS:=liuhaotian/LLaVA-Instruct-150K}" # instruction finetune
14
+
15
+ : "${OUT_BASE:=/kaggle/working/manthan_runs}" # all outputs saved here
16
+ : "${STAGE1_OUT:=${OUT_BASE}/stage1}" # stage1 output dir
17
+ : "${STAGE2_OUT:=${OUT_BASE}/stage2}" # stage2 output dir
18
+
19
+ # Training knobs (safe defaults for 2xT4)
20
+ : "${MAX_LENGTH:=2048}"
21
+ : "${IMAGE_SIZE:=384}"
22
+ : "${BATCH_SIZE:=1}"
23
+ : "${GRAD_ACCUM:=32}"
24
+ : "${LR:=1e-4}"
25
+ : "${EPOCHS_STAGE1:=1}"
26
+ : "${EPOCHS_STAGE2:=1}"
27
+
28
+ # Optional dataset limits (set empty for full)
29
+ : "${LIMIT_STAGE1:=20000}"
30
+ : "${LIMIT_STAGE2:=150000}"
31
+
32
+ # If you want to disable LoRA for projector-only training, set USE_LORA=0
33
+ : "${USE_LORA:=1}"
34
+
35
+ # If you want this script to upload artifacts via huggingface-cli, set UPLOAD=1
36
+ : "${UPLOAD:=0}"
37
+
38
+ ############################################
39
+ # Environment setup
40
+ ############################################
41
+ if command -v nvidia-smi >/dev/null 2>&1; then
42
+ echo "GPU found:"; nvidia-smi || true
43
+ else
44
+ echo "WARNING: nvidia-smi not found. This script expects a CUDA runtime (Kaggle)."
45
+ fi
46
+
47
+ # Persist caches on Kaggle
48
+ export HF_HOME="${HF_HOME:-/kaggle/working/hf}"
49
+ export TRANSFORMERS_CACHE="${TRANSFORMERS_CACHE:-/kaggle/working/hf/transformers}"
50
+ export HF_DATASETS_CACHE="${HF_DATASETS_CACHE:-/kaggle/working/hf/datasets}"
51
+
52
+ mkdir -p "${HF_HOME}" "${TRANSFORMERS_CACHE}" "${HF_DATASETS_CACHE}" "${OUT_BASE}"
53
+
54
+ ############################################
55
+ # Dependencies
56
+ ############################################
57
+ python - <<'PY'
58
+ import sys
59
+ print("python:", sys.version)
60
+ PY
61
+
62
+ # Keep installs minimal and reproducible enough for Kaggle.
63
+ python -m pip install -U pip
64
+ python -m pip install -U "transformers>=4.45" accelerate datasets peft
65
+
66
+ # Unsloth is optional; script falls back to PEFT if it isn't installed.
67
+ python -m pip install -U unsloth || true
68
+
69
+ ############################################
70
+ # Helper to add optional args
71
+ ############################################
72
+ maybe_limit_args() {
73
+ local limit_val="$1"
74
+ if [[ -n "${limit_val}" ]]; then
75
+ echo "--limit" "${limit_val}"
76
+ fi
77
+ }
78
+
79
+ maybe_lora_args() {
80
+ if [[ "${USE_LORA}" == "1" ]]; then
81
+ echo "--use_lora"
82
+ else
83
+ echo ""
84
+ fi
85
+ }
86
+
87
+ ############################################
88
+ # Stage 1
89
+ ############################################
90
+ echo "==== Stage 1: projector alignment/pretrain ===="
91
+ python scripts/train_unsloth_kaggle.py \
92
+ --stage stage1 \
93
+ --manthan_model "${MANTHAN_MODEL}" \
94
+ --text_model "${TEXT_MODEL}" \
95
+ --dataset "${STAGE1_DS}" \
96
+ --output_dir "${STAGE1_OUT}" \
97
+ $(maybe_lora_args) \
98
+ --max_length "${MAX_LENGTH}" \
99
+ --image_size "${IMAGE_SIZE}" \
100
+ --batch_size "${BATCH_SIZE}" \
101
+ --grad_accum "${GRAD_ACCUM}" \
102
+ --lr "${LR}" \
103
+ --epochs "${EPOCHS_STAGE1}" \
104
+ $(maybe_limit_args "${LIMIT_STAGE1}")
105
+
106
+ ############################################
107
+ # Stage 2
108
+ ############################################
109
+ echo "==== Stage 2: instruction finetune ===="
110
+ python scripts/train_unsloth_kaggle.py \
111
+ --stage stage2 \
112
+ --manthan_model "${MANTHAN_MODEL}" \
113
+ --text_model "${TEXT_MODEL}" \
114
+ --dataset "${STAGE2_DS}" \
115
+ --output_dir "${STAGE2_OUT}" \
116
+ $(maybe_lora_args) \
117
+ --max_length "${MAX_LENGTH}" \
118
+ --image_size "${IMAGE_SIZE}" \
119
+ --batch_size "${BATCH_SIZE}" \
120
+ --grad_accum "${GRAD_ACCUM}" \
121
+ --lr "${LR}" \
122
+ --epochs "${EPOCHS_STAGE2}" \
123
+ $(maybe_limit_args "${LIMIT_STAGE2}")
124
+
125
+ echo "==== Done ===="
126
+ echo "Stage1 outputs: ${STAGE1_OUT}"
127
+ echo "Stage2 outputs: ${STAGE2_OUT}"
128
+
129
+ ############################################
130
+ # Optional upload (manual control)
131
+ ############################################
132
+ if [[ "${UPLOAD}" == "1" ]]; then
133
+ echo "UPLOAD=1: attempting to upload artifacts (requires HF auth)."
134
+ python -m pip install -U huggingface_hub
135
+ echo "You can now upload ${OUT_BASE} with your preferred workflow."
136
+ fi
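
When `UPLOAD=1`, the script only installs `huggingface_hub` and leaves the actual push to the user. One way to do that from Python once training has finished; the repo id is a placeholder and this snippet is not part of the script:

```python
from huggingface_hub import HfApi

api = HfApi()  # expects a prior `huggingface-cli login` or HF_TOKEN in the environment
api.upload_folder(
    folder_path="/kaggle/working/manthan_runs/stage2",  # STAGE2_OUT above
    repo_id="your-username/Manthan-T1-stage2",          # placeholder repo id
    repo_type="model",
)
```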
scripts/smoke_export_load.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Smoke test: export a HF folder, then load it with trust_remote_code.
4
+
5
+ This is meant to catch:
6
+ - remote-code syntax/indentation errors
7
+ - missing auto_map
8
+ - missing model weights (optional stub)
9
+ - basic forward/generate wiring regressions
10
+
11
+ It intentionally uses the small stub-weights mode so it does not download big models.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import shutil
+ import subprocess
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+
24
+ def main() -> int:
25
+ repo_root = Path(__file__).resolve().parents[1]
26
+ out_dir = repo_root / "hf_export_ready"
27
+
28
+ if out_dir.exists():
29
+ # keep it simple
30
+ subprocess.run(["rm", "-rf", str(out_dir)], check=True)
31
+
32
+ subprocess.run(
33
+ [
34
+ sys.executable,
35
+ str(repo_root / "scripts" / "export_hf.py"),
36
+ "--out",
37
+ str(out_dir),
38
+ "--write_stub_weights",
39
+ ],
40
+ check=True,
41
+ )
42
+
43
+ print("Loading tokenizer...")
44
+ tok = AutoTokenizer.from_pretrained(out_dir, trust_remote_code=True)
45
+
46
+ print("Loading model...")
47
+ model = AutoModelForCausalLM.from_pretrained(out_dir, trust_remote_code=True)
48
+ model.eval()
49
+
50
+ # Basic forward pass (text-only)
51
+ ids = tok("Hello", return_tensors="pt").input_ids
52
+ with torch.inference_mode():
53
+ out = model(input_ids=ids)
54
+ assert out.logits.shape[:2] == ids.shape
55
+
56
+ # Tiny generate smoke
57
+ with torch.inference_mode():
58
+ gen = model.generate(ids, max_new_tokens=4, use_cache=False)
59
+ assert gen.shape[0] == 1
60
+
61
+ print("SMOKE OK")
62
+ return 0
63
+
64
+
65
+ if __name__ == "__main__":
66
+ raise SystemExit(main())
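
To run this end-to-end check under pytest alongside `tests/test_smoke.py`, a thin wrapper is enough; the file below is hypothetical and not part of this commit:

```python
# tests/test_export_load.py (hypothetical)
import subprocess
import sys
from pathlib import Path


def test_export_and_load_smoke():
    repo_root = Path(__file__).resolve().parents[1]
    script = repo_root / "scripts" / "smoke_export_load.py"
    subprocess.run([sys.executable, str(script)], check=True)
```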
scripts/train_unsloth_kaggle.py ADDED
@@ -0,0 +1,454 @@
1
+ """Kaggle/Unsloth training entrypoint for Manthan-T1 (TinyLLaVA-style).
2
+
3
+ This script is intended to be copied into a Kaggle notebook and run on 2×T4.
4
+ It supports two stages:
5
+ - stage1: projector alignment pretraining (e.g., LLaVA-CC3M-Pretrain-595K)
6
+ - stage2: instruction tuning (e.g., LLaVA-Instruct-150K)
7
+
8
+ Notes:
9
+ - We follow MicroLLaVA/TinyLLaVA convention: IMAGE_TOKEN_INDEX = -200 is inserted
10
+ into input_ids for <image> placeholders.
11
+ - Labels are IGNORE_INDEX for everything except assistant tokens.
12
+ - This script trains:
13
+ - the multimodal projector (always)
14
+ - LoRA adapters on the text model (optional, recommended)
15
+ - vision tower is frozen by default
16
+
17
+ You still need a *real* base model + vision tower weights. Stub exports will run
18
+ but won't learn useful vision-language alignment.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import os
25
+ from dataclasses import dataclass
26
+ from typing import Any, Dict, List, Optional, Tuple
27
+
28
+ import torch
29
+ from torch import nn
30
+ from torch.utils.data import Dataset
31
+
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
33
+
34
+ try:
35
+ # Fallback for non-Unsloth environments
36
+ from peft import LoraConfig, get_peft_model
37
+ except Exception: # pragma: no cover
38
+ LoraConfig = None
39
+ get_peft_model = None
40
+
41
+ try:
42
+ # Kaggle + Unsloth
43
+ from unsloth import FastLanguageModel
44
+ except Exception: # pragma: no cover
45
+ FastLanguageModel = None
46
+
47
+ try:
48
+ from datasets import load_dataset
49
+ except Exception as e: # pragma: no cover
50
+ raise RuntimeError(
51
+ "Missing dependency `datasets`. Install with `pip install datasets` (Kaggle: add to notebook)."
52
+ ) from e
53
+
54
+
55
+ IMAGE_TOKEN_INDEX = -200
56
+ IGNORE_INDEX = -100
57
+
58
+
59
+ def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX) -> List[int]:
60
+ """MicroLLaVA/TinyLLaVA tokenizer: split on '<image>' and insert a negative id."""
61
+
62
+ def _insert_separator(X, sep):
63
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
64
+
65
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
66
+
67
+ input_ids: List[int] = []
68
+ offset = 0
69
+ if (
70
+ len(prompt_chunks) > 0
71
+ and len(prompt_chunks[0]) > 0
72
+ and tokenizer.bos_token_id is not None
73
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
74
+ ):
75
+ offset = 1
76
+ input_ids.append(prompt_chunks[0][0])
77
+
78
+ for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
79
+ input_ids.extend(x[offset:])
80
+
81
+ return input_ids
82
+
83
+
84
+ def build_prompt_from_conversations(conversations: List[Dict[str, str]]) -> Tuple[str, str]:
85
+ """Return (full_prompt, assistant_answer_text).
86
+
87
+ LLaVA datasets are 2-turn: human then gpt.
88
+ We map to the string template used in `ManthanForCausalLM.format_chat_prompt`.
89
+ """
90
+
91
+ # Expect 2 turns
92
+ human = conversations[0]["value"]
93
+ assistant = conversations[1]["value"]
94
+
95
+ system = (
96
+ "A chat between a curious user and an artificial intelligence assistant. "
97
+ "The assistant gives helpful, detailed, and polite answers to the user's questions. "
98
+ )
99
+ # IMPORTANT: no trailing space after ASSISTANT:
100
+ full = system + f"USER: {human.strip()} ASSISTANT:" + assistant
101
+ return full, assistant
102
+
103
+
104
+ @dataclass
105
+ class TrainExample:
106
+ input_ids: torch.LongTensor
107
+ labels: torch.LongTensor
108
+ image_path: str
109
+
110
+
111
+ class LlavaLikeDataset(Dataset):
112
+ def __init__(
113
+ self,
114
+ ds_name: str,
115
+ split: str,
116
+ tokenizer,
117
+ max_length: int,
118
+ limit: Optional[int] = None,
119
+ ) -> None:
120
+ self.tokenizer = tokenizer
121
+ self.max_length = max_length
122
+
123
+ # Streaming keeps Kaggle disk usage low.
124
+ self.ds = load_dataset(ds_name, split=split, streaming=True)
125
+ self.limit = limit
126
+
127
+ # Materialize a small index for non-streaming dataloader behavior.
128
+ self._cache: List[Dict[str, Any]] = []
129
+ for i, ex in enumerate(self.ds):
130
+ self._cache.append(ex)
131
+ if limit is not None and i + 1 >= limit:
132
+ break
133
+
134
+ def __len__(self) -> int:
135
+ return len(self._cache)
136
+
137
+ def __getitem__(self, idx: int) -> TrainExample:
138
+ ex = self._cache[idx]
139
+ image_path = ex["image"]
140
+ conversations = ex["conversations"]
141
+
142
+ full_prompt, _assistant = build_prompt_from_conversations(conversations)
143
+ ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX)
144
+
145
+ # Truncate
146
+ ids = ids[: self.max_length]
147
+
148
+ # Labels: only learn on assistant answer tokens.
149
+ # Simple heuristic: find the last occurrence of " ASSISTANT:" marker.
150
+ marker = " ASSISTANT:"
151
+ marker_ids = self.tokenizer(marker).input_ids
152
+
153
+ # Find marker in tokenized ids (best-effort).
154
+ start = 0
155
+ for j in range(0, len(ids) - len(marker_ids) + 1):
156
+ if ids[j : j + len(marker_ids)] == marker_ids:
157
+ start = j + len(marker_ids)
158
+
159
+ labels = [IGNORE_INDEX] * len(ids)
160
+ for j in range(start, len(ids)):
161
+ if ids[j] == IMAGE_TOKEN_INDEX:
162
+ labels[j] = IGNORE_INDEX
163
+ else:
164
+ labels[j] = ids[j]
165
+
166
+ return TrainExample(
167
+ input_ids=torch.tensor(ids, dtype=torch.long),
168
+ labels=torch.tensor(labels, dtype=torch.long),
169
+ image_path=image_path,
170
+ )
171
+
172
+
173
+ def load_image_tensor(image_path: str, image_size: int) -> torch.FloatTensor:
174
+ """Load image from local path in dataset.
175
+
176
+ In Kaggle, LLaVA datasets provide image paths relative to the dataset repo.
177
+ Hugging Face datasets streaming yields paths that resolve via HF cache.
178
+ """
179
+
180
+ from PIL import Image
181
+ import torchvision.transforms as T
182
+
183
+ img = Image.open(image_path).convert("RGB")
184
+ tfm = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
185
+ return tfm(img)
186
+
187
+
188
+ def collate_fn(batch: List[TrainExample], image_size: int) -> Dict[str, torch.Tensor]:
189
+ # Pad to max length
190
+ max_len = max(x.input_ids.numel() for x in batch)
191
+ input_ids = torch.full((len(batch), max_len), 0, dtype=torch.long)
192
+ labels = torch.full((len(batch), max_len), IGNORE_INDEX, dtype=torch.long)
193
+ attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
194
+
195
+ for i, ex in enumerate(batch):
196
+ L = ex.input_ids.numel()
197
+ input_ids[i, :L] = ex.input_ids
198
+ labels[i, :L] = ex.labels
199
+ attention_mask[i, :L] = 1
200
+
201
+ # Images
202
+ pixel_values = torch.stack([load_image_tensor(ex.image_path, image_size) for ex in batch], dim=0)
203
+
204
+ return {
205
+ "input_ids": input_ids,
206
+ "labels": labels,
207
+ "attention_mask": attention_mask,
208
+ "pixel_values": pixel_values,
209
+ }
210
+
211
+
212
+ def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
213
+ for p in module.parameters():
214
+ p.requires_grad = requires_grad
215
+
216
+
217
+ def save_projector(model, output_dir: str) -> None:
218
+ os.makedirs(output_dir, exist_ok=True)
219
+ if not hasattr(model, "projector"):
220
+ return
221
+ torch.save(model.projector.state_dict(), os.path.join(output_dir, "projector.pt"))
222
+
223
+
224
+ def maybe_add_lora_to_model(model, args) -> None:
225
+ """Attach LoRA adapters (Unsloth preferred; PEFT fallback)."""
226
+
227
+ if not args.use_lora:
228
+ return
229
+
230
+ # If the model already has adapters (e.g., loaded via Unsloth), skip.
231
+ if hasattr(model, "peft_config"):
232
+ return
233
+
234
+ if get_peft_model is None or LoraConfig is None:
235
+ raise RuntimeError("PEFT not installed, and Unsloth not available. Install `peft` or enable Unsloth.")
236
+
237
+ target_modules = [
238
+ # Qwen-like
239
+ "q_proj",
240
+ "k_proj",
241
+ "v_proj",
242
+ "o_proj",
243
+ "gate_proj",
244
+ "up_proj",
245
+ "down_proj",
246
+ # GPT-like fallback
247
+ "c_attn",
248
+ "c_proj",
249
+ ]
250
+ cfg = LoraConfig(
251
+ r=args.lora_r,
252
+ lora_alpha=args.lora_alpha,
253
+ lora_dropout=args.lora_dropout,
254
+ bias="none",
255
+ task_type="CAUSAL_LM",
256
+ target_modules=target_modules,
257
+ )
258
+
259
+ # Wrap the language model inside Manthan
260
+ model.language_model = get_peft_model(model.language_model, cfg)
261
+
262
+
263
+
264
+ def main() -> int:
265
+ ap = argparse.ArgumentParser()
266
+ ap.add_argument("--stage", choices=["stage1", "stage2"], required=True)
267
+ ap.add_argument("--text_model", type=str, default="Qwen/Qwen3-0.6B-Base")
268
+ ap.add_argument("--vision_model", type=str, default="google/siglip-so400m-patch14-384")
269
+ ap.add_argument("--dataset", type=str, required=True)
270
+ ap.add_argument("--output_dir", type=str, default="./outputs")
271
+ ap.add_argument("--max_length", type=int, default=2048)
272
+ ap.add_argument("--image_size", type=int, default=384)
273
+ ap.add_argument("--limit", type=int, default=2048, help="For debugging: number of samples to materialize")
274
+
275
+ # Training
276
+ ap.add_argument("--epochs", type=int, default=1)
277
+ ap.add_argument("--batch_size", type=int, default=1)
278
+ ap.add_argument("--grad_accum", type=int, default=16)
279
+ ap.add_argument("--lr", type=float, default=1e-4)
280
+ ap.add_argument("--warmup_ratio", type=float, default=0.03)
281
+ ap.add_argument("--use_lora", action="store_true")
282
+ ap.add_argument("--lora_r", type=int, default=16)
283
+ ap.add_argument("--lora_alpha", type=int, default=32)
284
+ ap.add_argument("--lora_dropout", type=float, default=0.05)
285
+
286
+ ap.add_argument(
287
+ "--manthan_model",
288
+ type=str,
289
+ required=True,
290
+ help="HF repo id or local path that contains Manthan remote-code (the thing you push to HF).",
291
+ )
292
+ ap.add_argument("--save_every", type=int, default=500)
293
+ ap.add_argument("--dry_run", action="store_true", help="Run a single synthetic step (no datasets).")
294
+
295
+ args = ap.parse_args()
296
+
297
+ os.makedirs(args.output_dir, exist_ok=True)
298
+
299
+ device = "cuda" if torch.cuda.is_available() else "cpu"
300
+ if device != "cuda":
301
+ print("WARNING: This script is designed for CUDA (Kaggle). Running on CPU will be extremely slow.")
302
+
303
+ # Tokenizer (use the LLM tokenizer)
304
+ tok = AutoTokenizer.from_pretrained(args.text_model, trust_remote_code=True, use_fast=False)
305
+ if tok.pad_token_id is None:
306
+ tok.pad_token = tok.eos_token
307
+
308
+ # Load Manthan remote-code model
309
+ # (This should contain config that points to your desired text_model_id & vision_model_id.)
310
+ model = AutoModelForCausalLM.from_pretrained(
311
+ args.manthan_model,
312
+ trust_remote_code=True,
313
+ torch_dtype=torch.float16 if device == "cuda" else None,
314
+ )
315
+
316
+ model.train()
317
+ model.to(device)
318
+
319
+ # Make sure we don't train the vision tower (T4-friendly)
320
+ if hasattr(model, "vision_model") and model.vision_model is not None:
321
+ set_requires_grad(model.vision_model, False)
322
+ if hasattr(model, "vision_tower") and model.vision_tower is not None:
323
+ set_requires_grad(model.vision_tower, False)
324
+
325
+ # Train projector always
326
+ if hasattr(model, "projector"):
327
+ set_requires_grad(model.projector, True)
328
+
329
+ # Add LoRA to the language model (recommended)
330
+ maybe_add_lora_to_model(model, args)
331
+
332
+ # Optimizer params = trainable only
333
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
334
+ if len(trainable_params) == 0:
335
+ raise RuntimeError("No trainable parameters. Did you freeze everything?")
336
+
337
+ optim = torch.optim.AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
338
+
339
+ # Data
340
+ if args.dry_run:
341
+ # Minimal synthetic batch (no images on disk). This just validates loss pathway.
342
+ B, T = 1, min(64, args.max_length)
343
+
344
+ # IMPORTANT: some tokenizers report an imprecise `vocab_size`; `len(tok)` is the safe upper bound.
345
+ tok_vocab = int(len(tok))
346
+ input_ids = torch.randint(low=0, high=max(tok_vocab - 1, 1), size=(B, T), dtype=torch.long)
347
+ labels = input_ids.clone()
348
+ attn = torch.ones_like(input_ids)
349
+ pixel_values = torch.randn(B, 3, args.image_size, args.image_size)
350
+
351
+ # Insert one image placeholder
352
+ input_ids[0, 5] = IMAGE_TOKEN_INDEX
353
+ labels[0, :10] = IGNORE_INDEX
354
+
355
+ # If tokenizer vocab > model vocab (common in dry_run), clamp to avoid CE index errors.
356
+ lm_vocab = None
357
+ try:
358
+ if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
359
+ lm_vocab = int(getattr(model.language_model.config, "vocab_size", 0) or 0)
360
+ except Exception:
361
+ lm_vocab = None
362
+
363
+ if lm_vocab and lm_vocab > 0:
364
+ safe_ids = input_ids.clone()
365
+ mask = safe_ids >= 0
366
+ safe_ids[mask] = safe_ids[mask].clamp(min=0, max=lm_vocab - 1)
367
+ input_ids = safe_ids
368
+
369
+ safe_labels = labels.clone()
370
+ mask = safe_labels >= 0
371
+ safe_labels[mask] = safe_labels[mask].clamp(min=0, max=lm_vocab - 1)
372
+ labels = safe_labels
373
+
374
+ batch = {
375
+ "input_ids": input_ids.to(device),
376
+ "labels": labels.to(device),
377
+ "attention_mask": attn.to(device),
378
+ "pixel_values": pixel_values.to(device),
379
+ }
380
+ out = model(**batch)
381
+ print("dry_run loss:", float(out.loss))
382
+ out.loss.backward()
383
+ optim.step()
384
+ optim.zero_grad(set_to_none=True)
385
+ save_projector(model, args.output_dir)
386
+ if hasattr(model, "language_model") and hasattr(model.language_model, "save_pretrained"):
387
+ # Save adapters if present
388
+ try:
389
+ model.language_model.save_pretrained(args.output_dir)
390
+ except Exception:
391
+ pass
392
+ return 0
393
+
394
+ ds = LlavaLikeDataset(args.dataset, split="train", tokenizer=tok, max_length=args.max_length, limit=args.limit)
395
+ from torch.utils.data import DataLoader
396
+
397
+ dl = DataLoader(
398
+ ds,
399
+ batch_size=args.batch_size,
400
+ shuffle=True,
401
+ num_workers=2,
402
+ collate_fn=lambda b: collate_fn(b, args.image_size),
403
+ )
404
+
405
+ total_steps = (len(dl) * args.epochs) // max(1, args.grad_accum)
406
+ warmup_steps = max(1, int(total_steps * args.warmup_ratio))
407
+ sched = get_cosine_schedule_with_warmup(optim, warmup_steps, total_steps)
408
+
409
+ step = 0
410
+ optim.zero_grad(set_to_none=True)
411
+ for epoch in range(args.epochs):
412
+ for micro_idx, batch in enumerate(dl):
413
+ batch = {k: v.to(device) for k, v in batch.items()}
414
+
415
+ # Mixed precision on Kaggle
416
+ with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")):
417
+ out = model(**batch)
418
+ loss = out.loss / max(1, args.grad_accum)
419
+
420
+ loss.backward()
421
+
422
+ if (micro_idx + 1) % args.grad_accum == 0:
423
+ torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
424
+ optim.step()
425
+ sched.step()
426
+ optim.zero_grad(set_to_none=True)
427
+ step += 1
428
+
429
+ if step % 10 == 0:
430
+ print(f"epoch={epoch} step={step}/{total_steps} loss={float(out.loss):.4f}")
431
+
432
+ if step % args.save_every == 0:
433
+ save_projector(model, args.output_dir)
434
+ # Save adapters if any
435
+ try:
436
+ model.save_pretrained(args.output_dir)
437
+ except Exception:
438
+ pass
439
+
440
+ if step >= total_steps:
441
+ break
442
+
443
+ save_projector(model, args.output_dir)
444
+ try:
445
+ model.save_pretrained(args.output_dir)
446
+ except Exception:
447
+ pass
448
+
449
+ print("DONE")
450
+ return 0
451
+
452
+
453
+ if __name__ == "__main__":
454
+ raise SystemExit(main())
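
Because the ` ASSISTANT:` label-masking search is best-effort (if the marker tokenizes differently in context, everything from position 0 gets supervised), it is worth checking one example before a long run. A sketch reusing the helpers above; it assumes the repo root is on `sys.path`, that `datasets` is installed (the module imports it at load time), and the conversation content is made up:

```python
from transformers import AutoTokenizer

from scripts.train_unsloth_kaggle import (
    build_prompt_from_conversations,
    tokenizer_image_token,
)

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False, trust_remote_code=True)
conv = [
    {"from": "human", "value": "<image>\nWhat is shown here?"},
    {"from": "gpt", "value": "A red bicycle leaning against a wall."},
]
full_prompt, _ = build_prompt_from_conversations(conv)
ids = tokenizer_image_token(full_prompt, tok)

# Same marker search as LlavaLikeDataset.__getitem__
marker_ids = tok(" ASSISTANT:").input_ids
start = 0
for j in range(len(ids) - len(marker_ids) + 1):
    if ids[j : j + len(marker_ids)] == marker_ids:
        start = j + len(marker_ids)

supervised = [t for t in ids[start:] if t >= 0]  # drop the -200 image placeholders
print("supervised text:", tok.decode(supervised))  # should contain only the assistant answer
```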
tests/test_smoke.py ADDED
@@ -0,0 +1,10 @@
1
+ from manthan_t1.smoke_test import TinyToyModel
2
+
3
+
4
+ def test_tiny_model_forward():
5
+ import mlx.core as mx
6
+
7
+ m = TinyToyModel(vocab_size=64, d_model=32)
8
+ x = mx.random.randint(0, 64, shape=(2, 8))
9
+ y = m(x)
10
+ assert y.shape == (2, 8, 64)
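
`tests/test_smoke.py` needs MLX, which the Kaggle/CUDA environment targeted by the training scripts typically lacks. A hedged variant that skips cleanly instead of erroring (it assumes `manthan_t1.smoke_test` should only be imported once MLX is known to be present):

```python
# tests/test_smoke.py, guarded variant (a sketch, not part of this commit)
import pytest

mx = pytest.importorskip("mlx.core")  # skip the test where Apple MLX is unavailable

from manthan_t1.smoke_test import TinyToyModel  # imported only after the skip guard


def test_tiny_model_forward():
    m = TinyToyModel(vocab_size=64, d_model=32)
    x = mx.random.randint(0, 64, shape=(2, 8))
    y = m(x)
    assert y.shape == (2, 8, 64)
```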