Atah Alam committed
Commit · 7f7a72e
Parent(s): (none)

Manthan-T1 clean code-only
Browse files
- .gitattributes +1 -0
- .gitignore +24 -0
- .hfignore +25 -0
- MODEL_CARD.md +68 -0
- README.md +29 -0
- docs/KAGGLE_TRAINING.md +114 -0
- hf_export_stub/added_tokens.json +5 -0
- hf_export_stub/chat_template.jinja +26 -0
- hf_export_stub/config.json +16 -0
- hf_export_stub/special_tokens_map.json +7 -0
- hf_export_stub/tokenizer_config.json +6 -0
- manthan_t1/__init__.py +2 -0
- manthan_t1/configuration_manthan.py +114 -0
- manthan_t1/hf_integration_smoke.py +60 -0
- manthan_t1/hf_smoke.py +43 -0
- manthan_t1/modeling_manthan.py +654 -0
- manthan_t1/smoke_test.py +63 -0
- manthan_t1/text_generate_smoke.py +30 -0
- manthan_t1/tokenizer_smoke.py +28 -0
- manthan_t1/tokenizer_utils.py +52 -0
- pyproject.toml +29 -0
- requirements.txt +0 -0
- scripts/export_hf.py +208 -0
- scripts/infer_hf.py +30 -0
- scripts/infer_qwen3_siglip2.py +50 -0
- scripts/kaggle_train_all.sh +136 -0
- scripts/smoke_export_load.py +66 -0
- scripts/train_unsloth_kaggle.py +454 -0
- tests/test_smoke.py +10 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
+# Keep this repo code-only. Don't use Git LFS here.
.gitignore
ADDED
@@ -0,0 +1,24 @@
+.venv/
+venv/
+__pycache__/
+*.pyc
+.DS_Store
+
+# python tooling
+.pytest_cache/
+.mypy_cache/
+.ruff_cache/
+
+# local caches
+.cache/
+hf/
+
+# training artifacts
+wandb/
+runs/
+checkpoints/
+outputs/
+tmp/
+
+# local exports (keep them out of git; publish separately if needed)
+hf_export_ready/
.hfignore
ADDED
@@ -0,0 +1,25 @@
+.venv/
+venv/
+__pycache__/
+*.pyc
+.DS_Store
+
+# local/hf caches
+.cache/
+hf/
+
+# training artifacts
+wandb/
+runs/
+checkpoints/
+outputs/
+tmp/
+/tmp/
+
+# large local exports (only push intentionally)
+hf_export_ready/
+
+.venv/
+__pycache__/
+*.pyc
+.DS_Store
MODEL_CARD.md
ADDED
@@ -0,0 +1,68 @@
+---
+license: apache-2.0
+language:
+- en
+library_name: transformers
+tags:
+- pytorch
+- safetensors
+- vision-language
+- visual-question-answering
+pipeline_tag: visual-question-answering
+base_model:
+- Qwen/Qwen3-0.6B
+- google/siglip-so400m-patch14-384
+model-index:
+- name: Manthan-T1
+  results:
+  - task:
+      type: visual-question-answering
+      name: VQAv2
+    dataset:
+      name: VQAv2
+      type: vqav2
+    metrics:
+    - name: Overall Accuracy
+      type: accuracy
+      value: 0.0
+    - name: Yes/No Accuracy
+      type: accuracy
+      value: 0.0
+    - name: Number Accuracy
+      type: accuracy
+      value: 0.0
+    - name: Other Accuracy
+      type: accuracy
+      value: 0.0
+    source:
+      name: Pending
+      url: https://visualqa.org/download.html
+---
+
+# Manthan-T1
+
+A custom **Transformers** architecture for a compact vision-language model.
+
+## Status
+
+This repo currently contains:
+
+- `ManthanConfig` (`manthan_t1/configuration_manthan.py`)
+- `ManthanForCausalLM` (`manthan_t1/modeling_manthan.py`)
+  - vision encoder (minimal ViT-like)
+  - projector to text hidden size
+  - decoder LM (placeholder GPT-2 by default for smoke tests)
+
+Planned next steps:
+
+- Swap the text backbone to **Qwen3-0.6B** via `text_config` + weight loading
+- Swap the vision tower to **SigLIP2-so400m (patch14-384)** and align image token handling
+- Add proper processor + chat template to enforce **reply in user’s input language** (Tamil/Hindi/etc.)
+
+## Loading
+
+This is intended to be loaded with:
+
+- `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`
+
+See `scripts/infer_hf.py`.
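For orientation, loading follows the standard remote-code pattern. A minimal sketch, assuming the export has been pushed to a hub repo (`<YOUR_HF_REPO>` is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True is required because `manthan_t1` is a custom
# architecture resolved through the auto_map in config.json, not a class
# shipped with transformers itself.
model = AutoModelForCausalLM.from_pretrained("<YOUR_HF_REPO>", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("<YOUR_HF_REPO>", use_fast=False, trust_remote_code=True)
```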
README.md
ADDED
@@ -0,0 +1,29 @@
+# Manthan-T1
+
+A from-scratch scaffold for a custom **Transformers** vision-language architecture named **Manthan-T1**.
+
+## What you get (today)
+- A clean project layout under `manthan_t1/`
+- A full HF custom architecture:
+  - `ManthanConfig` in `manthan_t1/configuration_manthan.py`
+  - `ManthanForCausalLM` in `manthan_t1/modeling_manthan.py`
+- A no-download HF forward smoke test:
+  - `python -m manthan_t1.hf_smoke`
+- An MLX smoke test (kept for Apple Silicon readiness):
+  - `python -m manthan_t1.smoke_test`
+
+## What we’ll add next
+- Qwen3-0.6B backbone wiring + weight loading (keeping the model type = `manthan_t1`)
+- SigLIP2 vision tower wiring + projector alignment
+- LoRA fine-tuning recipes for M4 16GB (MLX and/or PyTorch)
+- Multilingual “reply in user language” policy (Indian languages)
+
+## Quick smoke test
+After installing dependencies, you should be able to run:
+
+```bash
+python -m manthan_t1.smoke_test
+python -m manthan_t1.hf_smoke
+```
+
+This does **not** download any external models yet.
docs/KAGGLE_TRAINING.md
ADDED
@@ -0,0 +1,114 @@
+# Kaggle training (2×T4) – Manthan‑T1
+
+This repo includes a Kaggle-oriented training entrypoint:
+
+- `scripts/train_unsloth_kaggle.py`
+
+It uses the same LLaVA-style dataset format as TinyLLaVA/MicroLLaVA:
+- dataset sample keys: `image`, `conversations`, `id`
+- `conversations`: `[{'from':'human','value':'...<image>...'}, {'from':'gpt','value':'...'}]`
+- uses `IMAGE_TOKEN_INDEX = -200`
+- uses `IGNORE_INDEX = -100` for masked labels
+
+## What this script trains
+
+Default (recommended for 2×T4):
+- vision tower: **frozen**
+- multimodal projector: **trainable** (always)
+- LLM: **LoRA adapters** (optional, enable `--use_lora`)
+
+This matches the standard LLaVA/TinyLLaVA recipe: align the projector first, then instruction-tune.
+
+## Kaggle setup checklist
+
+1) Enable GPU (2×T4) in Kaggle.
+2) Ensure the `pip` deps exist in the notebook:
+
+```bash
+pip install -U transformers accelerate datasets peft
+# Optional (recommended): Unsloth if available in your notebook image
+pip install -U unsloth
+```
+
+3) Clone this repo (from your HF repo or via `git clone`).
+4) Set the HF cache to persist in `/kaggle/working` so it survives “Save Version”:
+
+```bash
+export HF_HOME=/kaggle/working/hf
+export TRANSFORMERS_CACHE=/kaggle/working/hf/transformers
+export HF_DATASETS_CACHE=/kaggle/working/hf/datasets
+```
+
+## Stage 1 (projector alignment)
+
+Use a smaller pretrain set first:
+- `liuhaotian/LLaVA-CC3M-Pretrain-595K`
+
+Example run:
+
+```bash
+python scripts/train_unsloth_kaggle.py \
+  --stage stage1 \
+  --manthan_model <YOUR_HF_REPO_OR_LOCAL_PATH> \
+  --text_model Qwen/Qwen3-0.6B-Base \
+  --dataset liuhaotian/LLaVA-CC3M-Pretrain-595K \
+  --output_dir /kaggle/working/manthan_stage1 \
+  --use_lora \
+  --max_length 2048 \
+  --image_size 384 \
+  --batch_size 1 \
+  --grad_accum 32 \
+  --lr 1e-4 \
+  --epochs 1 \
+  --limit 20000
+```
+
+Notes:
+- Increase `--limit` as you gain confidence.
+- If you run out of VRAM, reduce `--max_length` or increase `--grad_accum`.
+
+## Stage 2 (instruction tuning)
+
+Dataset:
+- `liuhaotian/LLaVA-Instruct-150K`
+
+```bash
+python scripts/train_unsloth_kaggle.py \
+  --stage stage2 \
+  --manthan_model <YOUR_HF_REPO_OR_LOCAL_PATH> \
+  --text_model Qwen/Qwen3-0.6B-Base \
+  --dataset liuhaotian/LLaVA-Instruct-150K \
+  --output_dir /kaggle/working/manthan_stage2 \
+  --use_lora \
+  --max_length 2048 \
+  --image_size 384 \
+  --batch_size 1 \
+  --grad_accum 32 \
+  --lr 1e-4 \
+  --epochs 1 \
+  --limit 150000
+```
+
+## Outputs
+
+The script saves into `--output_dir`:
+- `projector.pt` (multimodal projector weights)
+- `save_pretrained()` output for the model (includes the remote-code config; adapters if supported)
+
+In practice, you’ll likely upload these artifacts back to HF.
+
+## Dry run (local)
+
+To validate the training loop without datasets:
+
+```bash
+python scripts/train_unsloth_kaggle.py \
+  --stage stage1 \
+  --manthan_model hf_export_ready \
+  --text_model gpt2 \
+  --dataset dummy \
+  --output_dir ./tmp_out \
+  --dry_run
+```
+
+(For real Kaggle training, don’t use stub weights.)
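To make the dataset contract above concrete, here is a sketch of one LLaVA-style record in the format the script expects (all field values are illustrative):

```python
# One LLaVA-style sample, as described in docs/KAGGLE_TRAINING.md.
# The literal "<image>" marker in the human turn is replaced by
# IMAGE_TOKEN_INDEX = -200 at tokenization time; non-assistant tokens
# are masked from the loss with IGNORE_INDEX = -100.
sample = {
    "id": "000000001",                               # hypothetical id
    "image": "coco/train2017/000000000001.jpg",      # hypothetical path
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is shown in this photo?"},
        {"from": "gpt", "value": "A dog playing fetch in a park."},
    ],
}
```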
hf_export_stub/added_tokens.json
ADDED
@@ -0,0 +1,5 @@
+{
+  "<image>": 0,
+  "<im_start>": 0,
+  "<im_end>": 0
+}
hf_export_stub/chat_template.jinja
ADDED
@@ -0,0 +1,26 @@
+{% set system = system_message | default('You are Manthan-T1, a helpful multimodal assistant.') %}
+{% set user_lang_rule = 'Reply in the same language as the user. If the user writes in Tamil, reply in Tamil; if Hindi then Hindi; if English then English.' %}
+
+{% if messages[0]['role'] != 'system' %}
+<|system|>
+{{ system }}\n{{ user_lang_rule }}
+<|end|>
+{% endif %}
+
+{% for m in messages %}
+{% if m['role'] == 'system' %}
+<|system|>
+{{ m['content'] }}\n{{ user_lang_rule }}
+<|end|>
+{% elif m['role'] == 'user' %}
+<|user|>
+{{ m['content'] }}
+<|end|>
+{% elif m['role'] == 'assistant' %}
+<|assistant|>
+{{ m['content'] }}
+<|end|>
+{% endif %}
+{% endfor %}
+
+<|assistant|>
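As a sanity check, the template can be rendered directly with jinja2. A minimal sketch (once the template is attached to a tokenizer, `tokenizer.apply_chat_template` is the usual route; the message content here is illustrative):

```python
from jinja2 import Template

template = Template(open("hf_export_stub/chat_template.jinja").read())

# With no leading system message, the template injects the default system
# block plus the language-mirroring rule, then opens an <|assistant|> turn.
print(template.render(messages=[{"role": "user", "content": "<image>\nWhat is in this picture?"}]))
```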
hf_export_stub/config.json
ADDED
@@ -0,0 +1,16 @@
+{
+  "model_type": "manthan_t1",
+  "architectures": ["ManthanForCausalLM"],
+  "auto_map": {
+    "AutoConfig": "configuration_manthan.ManthanConfig",
+    "AutoModelForCausalLM": "modeling_manthan.ManthanForCausalLM"
+  },
+  "text_model_id": "Qwen/Qwen3-0.6B",
+  "vision_model_id": "google/siglip-so400m-patch14-384",
+  "vision_image_size": 384,
+  "vision_patch_size": 14,
+  "vision_feature_select": "patch",
+  "num_image_tokens": 256,
+  "image_token_id": 0,
+  "torch_dtype": "float16"
+}
hf_export_stub/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
+{
+  "additional_special_tokens": [
+    "<image>",
+    "<im_start>",
+    "<im_end>"
+  ]
+}
hf_export_stub/tokenizer_config.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "model_max_length": 4096,
+  "padding_side": "right",
+  "truncation_side": "right",
+  "use_fast": false
+}
manthan_t1/__init__.py
ADDED
@@ -0,0 +1,2 @@
+__all__ = ["__version__"]
+__version__ = "0.0.1"
manthan_t1/configuration_manthan.py
ADDED
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional
+
+from transformers import PretrainedConfig
+
+
+class ManthanConfig(PretrainedConfig):
+    """Configuration for Manthan-T1.
+
+    Matches key MicroLLaVA/TinyLLaVA conventions:
+    - `image_token_index` is a negative placeholder id (default -200)
+    - keep `image_token_id` as an alias (defaults to image_token_index)
+
+    `text_config` is kept as a dict for JSON serialization.
+    """
+
+    model_type = "manthan_t1"
+
+    def __init__(
+        self,
+        text_config: Optional[dict] = None,
+        text_model_id: Optional[str] = None,
+        vision_model_id: Optional[str] = None,
+        vision_hidden_size: int = 1024,
+        vision_num_hidden_layers: int = 24,
+        vision_num_attention_heads: int = 16,
+        vision_image_size: int = 384,
+        vision_patch_size: int = 14,
+        projector_hidden_size: Optional[int] = None,
+        image_token_index: int = -200,
+        image_token_id: Optional[int] = None,
+        num_image_tokens: int = 256,
+        vision_feature_select: str = "patch",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.text_config_dict = text_config or {}
+
+        # Optional resolved config for HF generation helpers
+        self.text_config_obj: Optional[PretrainedConfig] = None
+        if self.text_config_dict.get("model_type"):
+            from transformers import AutoConfig
+
+            try:
+                self.text_config_obj = AutoConfig.for_model(**self.text_config_dict)
+            except Exception:
+                self.text_config_obj = None
+
+        self.text_model_id = text_model_id
+        self.vision_model_id = vision_model_id
+
+        self.vision_hidden_size = int(vision_hidden_size)
+        self.vision_num_hidden_layers = int(vision_num_hidden_layers)
+        self.vision_num_attention_heads = int(vision_num_attention_heads)
+        self.vision_image_size = int(vision_image_size)
+        self.vision_patch_size = int(vision_patch_size)
+
+        self.projector_hidden_size = projector_hidden_size
+
+        self.image_token_index = int(image_token_index)
+        self.image_token_id = int(image_token_id) if image_token_id is not None else int(image_token_index)
+        self.num_image_tokens = int(num_image_tokens)
+        self.vision_feature_select = vision_feature_select
+
+        # -------- Generation-related compatibility --------
+        # Transformers' generation utilities (DynamicCache, etc.) expect certain
+        # attributes on the *decoder/text* config. Since ManthanConfig is a
+        # wrapper that may not always carry a resolved `text_config_obj`, we set
+        # conservative defaults here to keep `model.generate()` functional in
+        # stub/export scenarios.
+        self.num_hidden_layers = int(
+            getattr(self.text_config_obj, "num_hidden_layers", kwargs.get("num_hidden_layers", 1))
+            if self.text_config_obj is not None
+            else kwargs.get("num_hidden_layers", 1)
+        )
+        self.num_attention_heads = int(
+            getattr(self.text_config_obj, "num_attention_heads", kwargs.get("num_attention_heads", 1))
+            if self.text_config_obj is not None
+            else kwargs.get("num_attention_heads", 1)
+        )
+        self.hidden_size = int(
+            getattr(self.text_config_obj, "hidden_size", kwargs.get("hidden_size", 256))
+            if self.text_config_obj is not None
+            else kwargs.get("hidden_size", 256)
+        )
+        self.max_position_embeddings = int(
+            getattr(self.text_config_obj, "max_position_embeddings", kwargs.get("max_position_embeddings", 2048))
+            if self.text_config_obj is not None
+            else kwargs.get("max_position_embeddings", 2048)
+        )
+        self.vocab_size = int(
+            getattr(self.text_config_obj, "vocab_size", kwargs.get("vocab_size", 32000))
+            if self.text_config_obj is not None
+            else kwargs.get("vocab_size", 32000)
+        )
+
+    def get_text_config(self, decoder: bool = False):
+        # Transformers' GenerationConfig helpers call get_text_config() during
+        # PreTrainedModel initialization. For stub/export-time configs we may not
+        # have a resolved text backbone yet; in that case, fall back to self.
+        if self.text_config_obj is None:
+            return self
+        return self.text_config_obj
+
+
+@dataclass
+class ManthanBatch:
+    input_ids: "torch.LongTensor"
+    attention_mask: Optional["torch.LongTensor"]
+    pixel_values: Optional["torch.FloatTensor"]
+    labels: Optional["torch.LongTensor"]
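A quick illustration of the fallback behavior above (mirroring what `manthan_t1/hf_smoke.py` exercises): a config built from a bare `text_config` dict resolves a concrete backbone config, while one built without a text backbone returns itself from `get_text_config()`. A sketch:

```python
from manthan_t1.configuration_manthan import ManthanConfig

cfg = ManthanConfig(text_config={"model_type": "gpt2"})
assert cfg.text_config_obj is not None           # resolved via AutoConfig.for_model
assert cfg.get_text_config() is cfg.text_config_obj

stub = ManthanConfig()                           # no text backbone configured yet
assert stub.get_text_config() is stub            # falls back to the wrapper itself
```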
manthan_t1/hf_integration_smoke.py
ADDED
@@ -0,0 +1,60 @@
+"""Optional integration smoke test.
+
+This WILL download models (big) if run.
+
+It checks that:
+- Qwen/Qwen3-0.6B loads as the text backbone
+- google/siglip-so400m-patch14-384 loads as the vision backbone
+- a single forward pass works with <image> token injection
+
+Run (optional):
+    python -m manthan_t1.hf_integration_smoke
+"""
+
+from __future__ import annotations
+
+import torch
+from transformers import AutoTokenizer
+
+from manthan_t1.configuration_manthan import ManthanConfig
+from manthan_t1.modeling_manthan import ManthanForCausalLM
+
+
+def main() -> None:
+    text_id = "Qwen/Qwen3-0.6B"
+    vision_id = "google/siglip-so400m-patch14-384"
+
+    cfg = ManthanConfig(
+        text_model_id=text_id,
+        vision_model_id=vision_id,
+        # SigLIP so400m patch14 384
+        vision_image_size=384,
+        vision_patch_size=14,
+        # 384/14 is non-integer; many siglip variants still use patch14, but the token count comes from the model.
+        # We'll keep num_image_tokens at 256 to match common LLaVA-style settings.
+        num_image_tokens=256,
+        image_token_id=151665,
+        vision_feature_select="patch",
+    )
+
+    model = ManthanForCausalLM(cfg)
+    model.eval()
+
+    tok = AutoTokenizer.from_pretrained(text_id, use_fast=False, trust_remote_code=True)
+
+    # Create a prompt with enough <image> placeholders
+    image_tok = tok.decode([cfg.image_token_id])
+    prompt = (image_tok + " ") * cfg.num_image_tokens + "\nDescribe the image."
+    inputs = tok(prompt, return_tensors="pt")
+
+    # Dummy image tensor with the expected size; processor correctness is handled in `chat()`.
+    pixel_values = torch.randn(1, 3, cfg.vision_image_size, cfg.vision_image_size)
+
+    with torch.no_grad():
+        out = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"), pixel_values=pixel_values)
+
+    print("OK integration forward", tuple(out.logits.shape))
+
+
+if __name__ == "__main__":
+    main()
manthan_t1/hf_smoke.py
ADDED
@@ -0,0 +1,43 @@
+"""HF Transformers smoke test for Manthan-T1.
+
+This must run without downloading any external checkpoints.
+"""
+
+from __future__ import annotations
+
+import torch
+
+from manthan_t1.configuration_manthan import ManthanConfig
+from manthan_t1.modeling_manthan import ManthanForCausalLM
+
+
+def main() -> None:
+    cfg = ManthanConfig(
+        text_config={"model_type": "gpt2"},
+        vision_image_size=224,
+        vision_patch_size=16,
+        vision_hidden_size=128,
+        # 224/16 = 14 patches per side => 196 patch tokens (CLS is dropped by default)
+        num_image_tokens=196,
+        image_token_id=42,
+    )
+    model = ManthanForCausalLM(cfg)
+    model.eval()
+
+    # Fake text with num_image_tokens image tokens
+    B, T = 2, 32
+    input_ids = torch.randint(0, 100, (B, T))
+    # Ensure the sequence is long enough to host the image tokens
+    if T < cfg.num_image_tokens:
+        input_ids = torch.randint(0, 100, (B, cfg.num_image_tokens + 8))
+        T = input_ids.shape[1]
+    input_ids[:, : cfg.num_image_tokens] = cfg.image_token_id
+    pixel_values = torch.randn(B, 3, cfg.vision_image_size, cfg.vision_image_size)
+
+    out = model(input_ids=input_ids, pixel_values=pixel_values)
+    assert out.logits.shape[:2] == (B, T)
+    print("OK manthan hf forward", tuple(out.logits.shape))
+
+
+if __name__ == "__main__":
+    main()
manthan_t1/modeling_manthan.py
ADDED
@@ -0,0 +1,654 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    GenerationMixin,
+    PreTrainedModel,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .configuration_manthan import ManthanConfig
+
+
+IGNORE_INDEX = -100
+
+
+class ManthanVisionEncoder(nn.Module):
+    """Minimal ViT-like vision encoder.
+
+    This is intentionally simple so the architecture is fully defined in this repo.
+    You can later swap it with SigLIP2 weights by mapping parameters.
+    """
+
+    def __init__(self, image_size: int, patch_size: int, hidden_size: int):
+        super().__init__()
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_size = hidden_size
+
+        self.proj = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
+        num_patches = (image_size // patch_size) * (image_size // patch_size)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, hidden_size))
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=hidden_size,
+            nhead=max(1, hidden_size // 64),
+            dim_feedforward=hidden_size * 4,
+            batch_first=True,
+            activation="gelu",
+        )
+        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
+        self.ln = nn.LayerNorm(hidden_size)
+
+        nn.init.normal_(self.pos_embed, std=0.02)
+        nn.init.normal_(self.cls_token, std=0.02)
+
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
+        # pixel_values: (B, 3, H, W)
+        x = self.proj(pixel_values)  # (B, C, H', W')
+        x = x.flatten(2).transpose(1, 2)  # (B, N, C)
+        cls = self.cls_token.expand(x.size(0), -1, -1)
+        x = torch.cat([cls, x], dim=1)
+        x = x + self.pos_embed[:, : x.size(1), :]
+        x = self.encoder(x)
+        return self.ln(x)
+
+
+class ManthanProjector(nn.Module):
+    def __init__(self, vision_hidden: int, text_hidden: int, mid: Optional[int] = None):
+        super().__init__()
+        mid = mid or max(text_hidden, vision_hidden)
+        self.net = nn.Sequential(
+            nn.Linear(vision_hidden, mid),
+            nn.GELU(),
+            nn.Linear(mid, text_hidden),
+        )
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        return self.net(x)
+
+
+class ManthanForCausalLM(PreTrainedModel, GenerationMixin):
+    config_class = ManthanConfig
+    base_model_prefix = "manthan"
+
+    def __init__(self, config: ManthanConfig):
+        super().__init__(config)
+
+        # Text backbone
+        # Priority:
+        #   1) `text_model_id` (e.g., Qwen/Qwen3-0.6B)
+        #   2) `text_config` dict (model_type + overrides)
+        #   3) fallback tiny GPT-2 for smoke tests
+        if config.text_model_id:
+            self.language_model = AutoModelForCausalLM.from_pretrained(
+                config.text_model_id,
+                trust_remote_code=True,
+            )
+        else:
+            text_cfg = getattr(config, "text_config_obj", None)
+            if text_cfg is None:
+                text_cfg = (
+                    AutoConfig.for_model(**config.text_config_dict)
+                    if getattr(config, "text_config_dict", {}).get("model_type")
+                    else None
+                )
+            if text_cfg is None:
+                from transformers import GPT2Config
+
+                text_cfg = GPT2Config(
+                    n_embd=256,
+                    n_layer=4,
+                    n_head=4,
+                    vocab_size=32000,
+                )
+            self.language_model = AutoModelForCausalLM.from_config(text_cfg)
+
+        text_hidden = self.language_model.config.hidden_size
+
+        # Vision backbone
+        self.vision_model = None
+        if config.vision_model_id:
+            self.vision_model = AutoModel.from_pretrained(config.vision_model_id, trust_remote_code=True)
+            vision_hidden = getattr(getattr(self.vision_model, "config", None), "hidden_size", None)
+            if vision_hidden is not None:
+                config.vision_hidden_size = int(vision_hidden)
+
+        # Fallback toy tower remains available (used when vision_model_id is not set)
+        self.vision_tower = ManthanVisionEncoder(
+            image_size=config.vision_image_size,
+            patch_size=config.vision_patch_size,
+            hidden_size=config.vision_hidden_size,
+        )
+
+        self.projector = ManthanProjector(
+            vision_hidden=config.vision_hidden_size,
+            text_hidden=text_hidden,
+            mid=config.projector_hidden_size,
+        )
+
+        # Use TinyLLaVA-style negative placeholder for <image>
+        self.image_token_id = int(getattr(config, "image_token_id", -200))
+        self.num_image_tokens = int(getattr(config, "num_image_tokens", 256))
+
+        # Generation helpers
+        self._gen_pixel_values: Optional[torch.FloatTensor] = None
+
+        self.post_init()
+
+    @staticmethod
+    def format_chat_prompt(prompt: str, has_image: bool = False) -> str:
+        """Format a single user prompt similar to TinyLLaVA's Qwen3 template.
+
+        We intentionally keep it simple and *string based* so it does not depend
+        on the tokenizer's chat_template.
+        """
+
+        system = (
+            "A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions. "
+        )
+
+        if has_image:
+            # Ensure the user doesn't redundantly include <image> in their prompt.
+            clean = prompt.replace("<image>", "").strip()
+            formatted = f"<image>\n{clean}"
+        else:
+            formatted = prompt.strip()
+
+        # Critical: no trailing space after ASSISTANT:
+        return system + f"USER: {formatted} ASSISTANT:"
+
+    def _inject_vision_embeds(
+        self,
+        input_ids: torch.LongTensor,
+        pixel_values: torch.FloatTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.LongTensor]]:
+        """Create input_embeds where <image> tokens are replaced by projected vision tokens.
+
+        Contract:
+        - input_ids contains exactly `num_image_tokens` occurrences of `image_token_id`.
+        - We will replace them in sequence order with vision tokens.
+        """
+
+        # Vision features
+        if self.vision_model is not None:
+            vout = self.vision_model(pixel_values=pixel_values)
+            # Try common fields: last_hidden_state
+            vision = getattr(vout, "last_hidden_state", None)
+            if vision is None:
+                raise ValueError("Vision model output does not contain last_hidden_state")
+        else:
+            vision = self.vision_tower(pixel_values)  # (B, 1+N, vision_hidden)
+
+        # Token selection
+        if self.config.vision_feature_select == "patch":
+            # Drop CLS and take the first num_image_tokens patches
+            vision = vision[:, 1 : 1 + self.num_image_tokens, :]
+        elif self.config.vision_feature_select == "cls_patch":
+            vision = vision[:, : self.num_image_tokens, :]
+        else:
+            raise ValueError(f"Unknown vision_feature_select={self.config.vision_feature_select}")
+
+        # Be strict: ensure we have exactly num_image_tokens.
+        if vision.shape[1] != self.num_image_tokens:
+            # Common case for the toy vision tower: includes CLS + N patches.
+            if vision.shape[1] > self.num_image_tokens:
+                vision = vision[:, : self.num_image_tokens, :]
+            else:
+                raise ValueError(
+                    f"vision tokens ({vision.shape[1]}) < num_image_tokens ({self.num_image_tokens}); increase image size/patching or reduce num_image_tokens"
+                )
+        vision = self.projector(vision)  # (B, num_img, text_hidden)
+
+        # Text embeds
+        lm = self.language_model
+        if hasattr(lm, "get_input_embeddings"):
+            tok_emb = lm.get_input_embeddings()
+        else:
+            tok_emb = lm.base_model.get_input_embeddings()
+        inputs_embeds = tok_emb(input_ids)
+
+        # Replace image token positions
+        mask = input_ids.eq(self.image_token_id)  # (B, T)
+        if mask.sum(dim=1).min().item() != self.num_image_tokens:
+            raise ValueError(
+                f"Expected exactly {self.num_image_tokens} <image> tokens per sample, got min={mask.sum(dim=1).min().item()}"
+            )
+
+        for b in range(input_ids.size(0)):
+            idx = torch.nonzero(mask[b], as_tuple=False).squeeze(-1)
+            inputs_embeds[b, idx, :] = vision[b, : idx.numel(), :]
+
+        return inputs_embeds, attention_mask
+
+    @staticmethod
+    def tokenizer_image_token(
+        prompt: str,
+        tokenizer,
+        image_token_index: int = -200,
+        return_tensors: Optional[str] = None,
+    ):
+        """MicroLLaVA/TinyLLaVA-style tokenization inserting a negative image placeholder id.
+
+        This avoids requiring `<image>` to be a real token in the tokenizer vocab.
+        """
+
+        def _insert_separator(X, sep):
+            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
+
+        input_ids: List[int] = []
+        offset = 0
+        if (
+            len(prompt_chunks) > 0
+            and len(prompt_chunks[0]) > 0
+            and prompt_chunks[0][0] == tokenizer.bos_token_id
+        ):
+            offset = 1
+            input_ids.append(prompt_chunks[0][0])
+
+        for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+            input_ids.extend(x[offset:])
+
+        if return_tensors is not None:
+            if return_tensors == "pt":
+                return torch.tensor(input_ids, dtype=torch.long)
+            raise ValueError(f"Unsupported tensor type: {return_tensors}")
+        return input_ids
+
+    def _encode_images(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
+        """Return projected vision features (B, N, text_hidden)."""
+
+        if self.vision_model is not None:
+            vout = self.vision_model(pixel_values=pixel_values, output_hidden_states=True)
+            # Prefer siglip/clip style selection when hidden_states exist.
+            if hasattr(vout, "hidden_states") and vout.hidden_states is not None:
+                layer = getattr(self.config, "vision_feature_layer", -2)
+                vision = vout.hidden_states[layer]
+            else:
+                vision = getattr(vout, "last_hidden_state", None)
+                if vision is None:
+                    raise ValueError("Vision model output has no usable hidden states")
+        else:
+            vision = self.vision_tower(pixel_values)
+
+        # Match TinyLLaVA selection strategy
+        strat = getattr(self.config, "vision_feature_select", "patch")
+        if strat == "patch":
+            vision = vision[:, 1:]
+        elif strat == "cls_patch":
+            vision = vision
+        else:
+            raise ValueError(f"Unknown vision_feature_select={strat}")
+
+        # Optionally truncate to the configured max image tokens
+        if vision.shape[1] > self.num_image_tokens:
+            vision = vision[:, : self.num_image_tokens]
+
+        return self.projector(vision)
+
+    def prepare_inputs_labels_for_multimodal(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Any],
+        labels: Optional[torch.LongTensor],
+        pixel_values: Optional[torch.FloatTensor],
+    ) -> Tuple[
+        Optional[torch.LongTensor],
+        Optional[torch.LongTensor],
+        Optional[torch.Tensor],
+        Optional[Any],
+        Optional[torch.FloatTensor],
+        Optional[torch.LongTensor],
+    ]:
+        """MicroLLaVA-style splice: build inputs_embeds by inserting vision features at IMAGE_TOKEN_INDEX.
+
+        Returns:
+            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels)
+            where input_ids is None when inputs_embeds is provided.
+        """
+
+        if pixel_values is None or input_ids.shape[1] == 1 or self.vision_tower is None:
+            return input_ids, None, attention_mask, past_key_values, None, labels
+
+        image_features = self._encode_images(pixel_values)  # (B, N, hidden)
+
+        orig_labels = labels
+        orig_attention_mask = attention_mask
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+
+        if labels is None:
+            labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+        # Remove padding
+        input_ids_list = [cur_ids[cur_mask] for cur_ids, cur_mask in zip(input_ids, attention_mask)]
+        labels_list = [cur_lbl[cur_mask] for cur_lbl, cur_mask in zip(labels, attention_mask)]
+
+        tok_emb = self.language_model.get_input_embeddings()
+        vocab_size = int(getattr(tok_emb, "num_embeddings", 0) or 0)
+
+        new_input_embeds: List[torch.Tensor] = []
+        new_labels: List[torch.Tensor] = []
+        cur_image_idx = 0
+
+        for batch_idx, cur_ids in enumerate(input_ids_list):
+            num_images = int((cur_ids == self.image_token_id).sum().item())
+
+            # No image tokens: plain text path.
+            if num_images == 0:
+                if vocab_size > 0:
+                    cur_ids = cur_ids.clamp(min=0, max=vocab_size - 1)
+                new_input_embeds.append(tok_emb(cur_ids))
+                new_labels.append(labels_list[batch_idx])
+                continue
+
+            # Split around image placeholder positions
+            image_token_indices = [-1] + torch.where(cur_ids == self.image_token_id)[0].tolist() + [cur_ids.shape[0]]
+            cur_labels = labels_list[batch_idx]
+
+            seg_ids: List[torch.Tensor] = []
+            seg_lbls: List[torch.Tensor] = []
+            for i in range(len(image_token_indices) - 1):
+                s = image_token_indices[i] + 1
+                e = image_token_indices[i + 1]
+                seg_ids.append(cur_ids[s:e])
+                seg_lbls.append(cur_labels[s:e])
+
+            split_sizes = [x.shape[0] for x in seg_ids]
+            total = int(sum(split_sizes))
+
+            if total > 0:
+                flat_ids = torch.cat(seg_ids, dim=0)
+                # Never feed negative ids into embeddings.
+                flat_ids = flat_ids[flat_ids >= 0]
+                if vocab_size > 0 and flat_ids.numel() > 0:
+                    flat_ids = flat_ids.clamp(min=0, max=vocab_size - 1)
+                flat_emb = tok_emb(flat_ids)
+                emb_chunks = list(torch.split(flat_emb, split_sizes, dim=0))
+            else:
+                emb_chunks = [tok_emb(cur_ids[:0]) for _ in split_sizes]
+
+            cur_new_embeds: List[torch.Tensor] = []
+            cur_new_labels: List[torch.Tensor] = []
+
+            for i in range(num_images + 1):
+                cur_new_embeds.append(emb_chunks[i])
+                cur_new_labels.append(seg_lbls[i])
+                if i < num_images:
+                    cur_img_feat = image_features[cur_image_idx]
+                    cur_image_idx += 1
+                    cur_new_embeds.append(cur_img_feat)
+                    cur_new_labels.append(
+                        torch.full(
+                            (cur_img_feat.shape[0],),
+                            IGNORE_INDEX,
+                            device=cur_labels.device,
+                            dtype=cur_labels.dtype,
+                        )
+                    )
+
+            new_input_embeds.append(torch.cat([x.to(self.device) for x in cur_new_embeds], dim=0))
+            new_labels.append(torch.cat(cur_new_labels, dim=0))
+
+        # Truncate if needed
+        max_len_cfg = getattr(self.config, "tokenizer_model_max_length", None)
+        if max_len_cfg is not None:
+            new_input_embeds = [x[:max_len_cfg] for x in new_input_embeds]
+            new_labels = [x[:max_len_cfg] for x in new_labels]
+
+        max_len = max(x.shape[0] for x in new_input_embeds)
+        batch_size = len(new_input_embeds)
+
+        padded_embeds: List[torch.Tensor] = []
+        padded_labels = torch.full(
+            (batch_size, max_len),
+            IGNORE_INDEX,
+            dtype=new_labels[0].dtype,
+            device=new_labels[0].device,
+        )
+        padded_mask = torch.zeros((batch_size, max_len), dtype=torch.long, device=padded_labels.device)
+
+        for i, (emb, lbl) in enumerate(zip(new_input_embeds, new_labels)):
+            cur_len = emb.shape[0]
+            pad = torch.zeros((max_len - cur_len, emb.shape[1]), dtype=emb.dtype, device=emb.device)
+            padded_embeds.append(torch.cat([emb, pad], dim=0))
+            padded_labels[i, :cur_len] = lbl
+            padded_mask[i, :cur_len] = 1
+
+        inputs_embeds = torch.stack(padded_embeds, dim=0)
+
+        out_labels = None if orig_labels is None else padded_labels
+        out_mask = None if orig_attention_mask is None else padded_mask
+
+        return None, None, out_mask, past_key_values, inputs_embeds, out_labels
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        if input_ids is None:
+            raise ValueError("input_ids is required")
+
+        if pixel_values is None and self._gen_pixel_values is not None:
+            pixel_values = self._gen_pixel_values
+
+        # pop (not get), so past_key_values is not passed twice via **kwargs below
+        past_key_values = kwargs.pop("past_key_values", None)
+        (
+            input_ids,
+            position_ids,
+            attention_mask,
+            past_key_values,
+            inputs_embeds,
+            labels,
+        ) = self.prepare_inputs_labels_for_multimodal(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            labels=labels,
+            pixel_values=pixel_values,
+        )
+
+        return self.language_model(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            labels=labels,
+            **kwargs,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        if input_ids is None:
+            raise ValueError("input_ids is required")
+
+        if pixel_values is not None:
+            (
+                _,
+                position_ids,
+                attention_mask,
+                _,
+                inputs_embeds,
+                _,
+            ) = self.prepare_inputs_labels_for_multimodal(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=None,
+                labels=None,
+                pixel_values=pixel_values,
+            )
+            return self.language_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                **kwargs,
+            )
+
+        return self.language_model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, ...], ...]] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """HF generation hook.
+
+        Contract:
+        - On the first step, we keep full `input_ids` and provide `pixel_values`.
+        - For subsequent steps (past_key_values != None), HF passes only the last token.
+          We keep cached `pixel_values` and forward normally.
+        """
+
+        if pixel_values is not None:
+            self._gen_pixel_values = pixel_values
+
+        # If we have past, only feed the last token (standard causal LM behavior)
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        model_inputs: Dict[str, Any] = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache", True),
+        }
+
+        if self._gen_pixel_values is not None:
+            model_inputs["pixel_values"] = self._gen_pixel_values
+
+        return model_inputs
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        # Delegate to the underlying LM implementation
+        if hasattr(self.language_model, "_reorder_cache"):
+            return self.language_model._reorder_cache(past_key_values, beam_idx)
+        return past_key_values
+
+    @torch.no_grad()
+    def chat(
+        self,
+        prompt: str,
+        tokenizer,
+        image: Optional[Union[str, "PIL.Image.Image"]] = None,
+        max_new_tokens: int = 128,
+        **gen_kwargs,
+    ) -> str:
+        """Simple chat helper (mirrors the style in your MicroLLaVA README).
+
+        This is intentionally minimal. A proper processor will be added later.
+        """
+
+        # Lazy import to keep the base install small
+        pixel_values = None
+        if image is not None:
+            from PIL import Image
+            import requests
+            from io import BytesIO
+
+            if isinstance(image, str) and image.startswith("http"):
+                r = requests.get(image, timeout=30)
+                r.raise_for_status()
+                image = Image.open(BytesIO(r.content)).convert("RGB")
+            elif isinstance(image, str):
+                image = Image.open(image).convert("RGB")
+
+            # Prefer model-specific preprocessing when we have a real vision backbone.
+            if self.vision_model is not None and self.config.vision_model_id:
+                from transformers import AutoProcessor
+
+                proc = AutoProcessor.from_pretrained(self.config.vision_model_id, trust_remote_code=True)
+                pv = proc(images=image, return_tensors="pt")
+                pixel_values = pv.get("pixel_values", None)
+                if pixel_values is None:
+                    raise ValueError("AutoProcessor did not return pixel_values")
+            else:
+                import torchvision.transforms as T
+
+                tfm = T.Compose(
+                    [
+                        T.Resize((self.config.vision_image_size, self.config.vision_image_size)),
+                        T.ToTensor(),
+                    ]
+                )
+                pixel_values = tfm(image).unsqueeze(0)
+
+        # Insert language-mirroring instruction (simple + robust).
+        # We keep it short to not fight Qwen's reasoning.
+        user_prompt = (
+            "Reply in the same language as the user's prompt (e.g., Tamil in Tamil, Hindi in Hindi, English in English). "
+            "Be helpful and concise.\n\n" + prompt
+        )
+
+        formatted = self.format_chat_prompt(user_prompt, has_image=pixel_values is not None)
+
+        # MicroLLaVA-style: tokenize by inserting image placeholder ids directly.
+        if pixel_values is not None:
+            input_ids = self.tokenizer_image_token(
+                formatted,
+                tokenizer,
+                image_token_index=self.image_token_id,
+                return_tensors="pt",
+            ).unsqueeze(0)
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+        else:
+            ids = tokenizer(formatted, return_tensors="pt")
+            input_ids = ids["input_ids"]
+            attention_mask = ids.get("attention_mask")
+
+        # Keep everything on the model device
+        input_ids = input_ids.to(self.device)
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(self.device)
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(self.device)
+
+        # Use this wrapper's generate support so pixel_values can be carried through.
+        out = self.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            max_new_tokens=max_new_tokens,
+            **gen_kwargs,
+        )
+        return tokenizer.decode(out[0], skip_special_tokens=True)
+
+
+# Registration so AutoModel/AutoConfig can find it when you package/export.
+AutoConfig.register(ManthanConfig.model_type, ManthanConfig)
+AutoModelForCausalLM.register(ManthanConfig, ManthanForCausalLM)
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal smoke test.
|
| 2 |
+
|
| 3 |
+
Goal: ensure the repo is runnable on macOS + MLX before we wire real Qwen/SigLIP weights.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import mlx.core as mx
|
| 11 |
+
import mlx.nn as nn
|
| 12 |
+
import mlx.optimizers as optim
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TinyToyModel(nn.Module):
|
| 16 |
+
def __init__(self, vocab_size: int = 256, d_model: int = 128):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.emb = nn.Embedding(vocab_size, d_model)
|
| 19 |
+
self.l1 = nn.Linear(d_model, d_model)
|
| 20 |
+
self.l2 = nn.Linear(d_model, vocab_size)
|
| 21 |
+
|
| 22 |
+
def __call__(self, token_ids: mx.array) -> mx.array:
|
| 23 |
+
# token_ids: (B, T)
|
| 24 |
+
x = self.emb(token_ids)
|
| 25 |
+
x = nn.relu(self.l1(x))
|
| 26 |
+
return self.l2(x)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def main() -> None:
|
| 30 |
+
mx.random.seed(0)
|
| 31 |
+
|
| 32 |
+
model = TinyToyModel()
|
| 33 |
+
|
| 34 |
+
# fake batch
|
| 35 |
+
B, T = 4, 32
|
| 36 |
+
token_ids = mx.random.randint(0, 256, shape=(B, T))
|
| 37 |
+
targets = mx.random.randint(0, 256, shape=(B, T))
|
| 38 |
+
|
| 39 |
+
def loss_fn(m: TinyToyModel, x: mx.array, y: mx.array) -> mx.array:
|
| 40 |
+
logits = m(x)
|
| 41 |
+
logits2 = logits.reshape((-1, logits.shape[-1]))
|
| 42 |
+
y2 = y.reshape((-1,))
|
| 43 |
+
# `cross_entropy` may return per-example/per-token; reduce to scalar.
|
| 44 |
+
return mx.mean(nn.losses.cross_entropy(logits2, y2))
|
| 45 |
+
|
| 46 |
+
opt = optim.Adam(learning_rate=1e-3)
|
| 47 |
+
|
| 48 |
+
start = time.time()
|
| 49 |
+
def f(m: TinyToyModel) -> mx.array:
|
| 50 |
+
return loss_fn(m, token_ids, targets)
|
| 51 |
+
|
| 52 |
+
for step in range(5):
|
| 53 |
+
loss, grads = mx.value_and_grad(f)(model)
|
| 54 |
+
opt.update(model, grads)
|
| 55 |
+
mx.eval(loss)
|
| 56 |
+
print(f"step={step} loss={float(loss):.4f}")
|
| 57 |
+
|
| 58 |
+
mx.eval(model.parameters())
|
| 59 |
+
print(f"OK (elapsed {time.time() - start:.2f}s)")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
main()
|
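Assuming MLX is installed (Apple silicon), the check runs as `python -m manthan_t1.smoke_test` and should print five step losses followed by the `OK (elapsed ...)` line; an import failure here means the MLX environment itself is broken, independent of any model code.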
manthan_t1/text_generate_smoke.py
ADDED
@@ -0,0 +1,30 @@
"""No-download generate smoke test (only the small gpt2 tokenizer is fetched).

Ensures that `ManthanForCausalLM.generate()` works (text-only path), which is
required before enabling image+generate.
"""

from __future__ import annotations

import torch
from transformers import AutoTokenizer

from manthan_t1.configuration_manthan import ManthanConfig
from manthan_t1.modeling_manthan import ManthanForCausalLM


def main() -> None:
    cfg = ManthanConfig(text_config={"model_type": "gpt2"})
    model = ManthanForCausalLM(cfg)
    model.eval()

    tok = AutoTokenizer.from_pretrained("gpt2")
    prompt = "Hello, my name is"
    inputs = tok(prompt, return_tensors="pt")

    out = model.generate(**inputs, max_new_tokens=10)
    print(tok.decode(out[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
manthan_t1/tokenizer_smoke.py
ADDED
@@ -0,0 +1,28 @@
"""Tokenizer + template smoke test.

- Ensures `<image>` can be added to a tokenizer
- Ensures we can build a prompt containing `<image>` token placeholders

Downloads a small tokenizer (gpt2) only.
"""

from __future__ import annotations

from transformers import AutoTokenizer

from manthan_t1.tokenizer_utils import ensure_vision_special_tokens


def main() -> None:
    tok = AutoTokenizer.from_pretrained("gpt2")
    res = ensure_vision_special_tokens(tok, add_im_start_end=True)

    prompt = ("<image> " * 4) + "\nExplain what you see."
    ids = tok(prompt).input_ids

    print("image_token_id", res.image_token_id)
    print("count", sum(1 for i in ids if i == res.image_token_id))


if __name__ == "__main__":
    main()
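A healthy run prints the id that the tokenizer assigned to `<image>` and then `count 4`: the prompt embeds four placeholders, and because `<image>` is registered as a special token the tokenizer never splits it.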
manthan_t1/tokenizer_utils.py
ADDED
@@ -0,0 +1,52 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, Tuple


DEFAULT_SPECIAL_TOKENS: Dict[str, str] = {
    "image_token": "<image>",
    "im_start": "<im_start>",
    "im_end": "<im_end>",
}


@dataclass
class TokenSetupResult:
    image_token_id: int
    added: Dict[str, int]


def ensure_vision_special_tokens(tokenizer, add_im_start_end: bool = False) -> TokenSetupResult:
    """Ensure the tokenizer has `<image>` (and optionally `<im_start>/<im_end>`).

    Returns the chosen `image_token_id` and any newly-added token ids.

    Works with both fast and slow tokenizers.
    """

    specials = {
        "additional_special_tokens": [DEFAULT_SPECIAL_TOKENS["image_token"]],
    }
    if add_im_start_end:
        specials["additional_special_tokens"].extend(
            [DEFAULT_SPECIAL_TOKENS["im_start"], DEFAULT_SPECIAL_TOKENS["im_end"]]
        )

    # Add only missing tokens
    existing = set(getattr(tokenizer, "additional_special_tokens", []) or [])
    to_add = [t for t in specials["additional_special_tokens"] if t not in existing]

    added_map: Dict[str, int] = {}
    if to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": to_add})
        # after add, resolve ids
        for t in to_add:
            added_map[t] = tokenizer.convert_tokens_to_ids(t)

    image_token = DEFAULT_SPECIAL_TOKENS["image_token"]
    image_token_id = tokenizer.convert_tokens_to_ids(image_token)
    if image_token_id is None or image_token_id < 0:
        raise ValueError("Failed to register <image> token")

    return TokenSetupResult(image_token_id=image_token_id, added=added_map)
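One caveat for callers of `ensure_vision_special_tokens`: when the call actually grows the vocabulary, any model instantiated around the same tokenizer needs its embedding matrix resized, or the new id will index out of range. A minimal sketch (gpt2 used purely as a stand-in):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from manthan_t1.tokenizer_utils import ensure_vision_special_tokens

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    res = ensure_vision_special_tokens(tok)
    if res.added:
        # Grow the (possibly tied) embedding tables to cover the new ids.
        model.resize_token_embeddings(len(tok))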
pyproject.toml
ADDED
@@ -0,0 +1,29 @@
[project]
name = "manthan-t1"
version = "0.0.1"
description = "Manthan-T1: MLX-first vision-language model (Qwen3 + vision encoder) fine-tuning scaffold"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "Apache-2.0"}
authors = [{name = "Manthan"}]

dependencies = [
    "mlx>=0.17.0",
    "numpy>=1.26",
    "pillow>=10.0",
    "tqdm>=4.66",
    "pyyaml>=6.0",
    "transformers>=4.55.0",
    "torch>=2.2.0",
    "torchvision>=0.17.0",
    "requests>=2.31",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0",
]

[tool.pytest.ini_options]
addopts = "-q"
testpaths = ["tests"]
requirements.txt
ADDED
File without changes
scripts/export_hf.py
ADDED
@@ -0,0 +1,208 @@
#!/usr/bin/env python3

"""Export a Manthan-T1 folder that can be uploaded to Hugging Face.

What this does:
- Copies `hf_export_stub/*` into an output directory
- Builds a tokenizer from `tokenizer_name_or_path` (defaults to Qwen3)
- Ensures `<image>` is a real special token in the tokenizer
- Writes `tokenizer_config.json`, `special_tokens_map.json`, `added_tokens.json`, and `chat_template.jinja`
- Updates `config.json`: `image_token_id` keeps the `-200` placeholder, and the tokenizer's
  real `<image>` vocab id is recorded separately as `tokenizer_image_token_id`

Note:
- This does NOT include model weights. It's intended for placeholder-weight repo layout
  (like your MicroLLaVA example). For training, you'll later save actual weights.
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import sys
from pathlib import Path

from transformers import AutoTokenizer


# Allow running this script without installing the package.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))


def _copytree(src: Path, dst: Path) -> None:
    dst.mkdir(parents=True, exist_ok=True)
    for item in src.iterdir():
        s = item
        d = dst / item.name
        if item.is_dir():
            shutil.copytree(s, d, dirs_exist_ok=True)
        else:
            shutil.copy2(s, d)


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, help="Output folder")
    ap.add_argument(
        "--stub",
        default=str(Path(__file__).resolve().parents[1] / "hf_export_stub"),
        help="Path to hf_export_stub folder",
    )
    ap.add_argument(
        "--tokenizer",
        default=None,
        help="Tokenizer name/path. Defaults to config.json tokenizer_name_or_path.",
    )
    ap.add_argument(
        "--tokenizer_local_dir",
        default=None,
        help="Local tokenizer directory to copy (e.g. MicroLlava-* folder). If set, no network fetch is performed.",
    )
    ap.add_argument(
        "--write_stub_weights",
        action="store_true",
        help="Write randomly-initialized weights (model.safetensors) into the export dir so from_pretrained() succeeds.",
    )
    args = ap.parse_args()

    out_dir = Path(args.out).expanduser().resolve()
    stub_dir = Path(args.stub).expanduser().resolve()

    if not stub_dir.exists():
        raise SystemExit(f"Stub dir not found: {stub_dir}")

    out_dir.mkdir(parents=True, exist_ok=True)
    _copytree(stub_dir, out_dir)

    # Ensure we don't keep stale remote-code python files from a previous export.
    for stale in ["configuration_manthan.py", "modeling_manthan.py", "__init__.py"]:
        p = out_dir / stale
        if p.exists():
            p.unlink()

    # Copy remote-code python files to export root (HF dynamic module loader expects them)
    repo_root = Path(__file__).resolve().parents[1]
    pkg_dir = repo_root / "manthan_t1"
    for fname in ["configuration_manthan.py", "modeling_manthan.py", "__init__.py"]:
        src = pkg_dir / fname
        if not src.exists():
            raise SystemExit(f"Missing required source file for export: {src}")
        shutil.copy2(src, out_dir / fname)

    cfg_path = out_dir / "config.json"
    if not cfg_path.exists():
        raise SystemExit(f"config.json not found in: {out_dir}")

    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    tokenizer_name = (
        args.tokenizer
        or cfg.get("tokenizer_name_or_path")
        or cfg.get("llm_model_name_or_path")
        or cfg.get("text_model_id")
        or cfg.get("vision_model_id")
    )
    if not tokenizer_name:
        raise SystemExit("Could not infer tokenizer_name_or_path")

    # Prefer an on-disk tokenizer (e.g. the attached MicroLLaVA folder) to avoid any
    # network dependency during export.
    repo_root = Path(__file__).resolve().parents[1]
    local_tokenizer_candidates = [
        repo_root / "MicroLlava-Qwen3-0.6B-base-siglip2-so400m",
    ]
    for cand in local_tokenizer_candidates:
        if cand.exists() and (cand / "tokenizer_config.json").exists():
            tokenizer_name = str(cand)
            break

    tok = AutoTokenizer.from_pretrained(
        tokenizer_name,
        trust_remote_code=True,
        use_fast=bool(cfg.get("tokenizer_use_fast", False)),
        local_files_only=True,
    )

    # Ensure special tokens exist
    added = tok.add_special_tokens({"additional_special_tokens": ["<image>"]})
    # Some tokenizers need a pad token for batching.
    if tok.pad_token_id is None and cfg.get("pad_token"):
        tok.add_special_tokens({"pad_token": cfg["pad_token"]})

    # Save tokenizer files into export dir
    tok.save_pretrained(out_dir)

    # Copy chat template if present in stub
    tmpl_src = out_dir / "chat_template.jinja"
    if tmpl_src.exists():
        # Ensure tokenizer_config.json references it (HF uses string field)
        tok_cfg_path = out_dir / "tokenizer_config.json"
        if tok_cfg_path.exists():
            tok_cfg = json.loads(tok_cfg_path.read_text(encoding="utf-8"))
        else:
            tok_cfg = {}
        tok_cfg["chat_template"] = tmpl_src.read_text(encoding="utf-8")
        tok_cfg_path.write_text(json.dumps(tok_cfg, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")

    # Align config fields with MicroLLaVA convention
    cfg.setdefault("image_token_index", -200)
    cfg["image_token_index"] = -200
    cfg["image_token_id"] = -200

    # For user convenience, record the actual tokenizer vocab id of '<image>'.
    img_vocab_id = tok.convert_tokens_to_ids("<image>")
    cfg["tokenizer_image_token_id"] = int(img_vocab_id) if img_vocab_id is not None else None
    cfg["tokenizer_added_tokens"] = int(added)

    cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")

    # Minimal README hint
    readme = out_dir / "README_EXPORT.md"
    readme.write_text(
        "Manthan-T1 export folder (stub).\n\n"
        "- `config.json` uses `image_token_index=-200` placeholder like TinyLLaVA.\n"
        "- Tokenizer contains a real `<image>` special token.\n"
        "- This folder does not include model weights; training should save weights here later.\n",
        encoding="utf-8",
    )

    print(f"Exported to: {out_dir}")

    if args.write_stub_weights:
        # Import only when requested to avoid heavier imports for plain export.
        from manthan_t1.configuration_manthan import ManthanConfig
        from manthan_t1.modeling_manthan import ManthanForCausalLM

        # Tiny randomly-initialized model that is loadable.
        # This does not download any base weights.
        stub_cfg = ManthanConfig(
            text_model_id=None,
            vision_model_id=None,
            image_token_index=-200,
            num_image_tokens=32,
        )
        model = ManthanForCausalLM(stub_cfg)
        model.save_pretrained(out_dir, safe_serialization=True)

        # Ensure auto_map is present so AutoConfig/AutoModel can resolve our
        # custom classes via trust_remote_code.
        saved_cfg = json.loads((out_dir / "config.json").read_text(encoding="utf-8"))
        saved_cfg["auto_map"] = cfg.get(
            "auto_map",
            {
                "AutoConfig": "configuration_manthan.ManthanConfig",
                "AutoModelForCausalLM": "modeling_manthan.ManthanForCausalLM",
            },
        )
        (out_dir / "config.json").write_text(
            json.dumps(saved_cfg, indent=2, ensure_ascii=False) + "\n",
            encoding="utf-8",
        )

        print("Wrote stub weights: model.safetensors")


if __name__ == "__main__":
    main()
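A quick way to verify an export (the folder name is whatever was passed to `--out`; `--write_stub_weights` is required for the model load to succeed):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("hf_export_ready", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("hf_export_ready", trust_remote_code=True)
    print(tok.convert_tokens_to_ids("<image>"))  # real vocab id, not the -200 placeholder

`scripts/smoke_export_load.py` below automates exactly this round trip.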
scripts/infer_hf.py
ADDED
@@ -0,0 +1,30 @@
from __future__ import annotations

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="./")
    ap.add_argument("--prompt", type=str, required=True)
    ap.add_argument("--image", type=str, default=None, help="URL or local path")
    args = ap.parse_args()

    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    if hasattr(model, "chat"):
        text = model.chat(prompt=args.prompt, image=args.image, tokenizer=tok)
        print(text)
        return

    inputs = tok(args.prompt, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=128)
    print(tok.decode(out[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
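Typical invocation against a local export (paths illustrative): `python scripts/infer_hf.py --model ./hf_export_ready --prompt "Describe this image" --image cat.jpg`. With stub weights the pipeline runs end to end but the generated text is meaningless; the point is exercising the `chat()` wiring.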
scripts/infer_qwen3_siglip2.py
ADDED
@@ -0,0 +1,50 @@
from __future__ import annotations

import argparse

import torch
from transformers import AutoTokenizer

from manthan_t1.configuration_manthan import ManthanConfig
from manthan_t1.modeling_manthan import ManthanForCausalLM


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--text-model", type=str, default="Qwen/Qwen3-0.6B")
    ap.add_argument(
        "--vision-model",
        type=str,
        default="google/siglip-so400m-patch14-384",
    )
    ap.add_argument("--prompt", type=str, required=True)
    ap.add_argument("--image", type=str, default=None, help="URL or local path")
    ap.add_argument("--image-token-id", type=int, default=151665)
    ap.add_argument("--num-image-tokens", type=int, default=256)
    args = ap.parse_args()

    cfg = ManthanConfig(
        text_model_id=args.text_model,
        vision_model_id=args.vision_model,
        vision_image_size=384,
        vision_patch_size=14,
        image_token_id=args.image_token_id,
        num_image_tokens=args.num_image_tokens,
        vision_feature_select="patch",
    )

    model = ManthanForCausalLM(cfg)
    tok = AutoTokenizer.from_pretrained(args.text_model, use_fast=False, trust_remote_code=True)

    if args.image:
        out = model.chat(prompt=args.prompt, image=args.image, tokenizer=tok)
    else:
        inputs = tok(args.prompt, return_tensors="pt")
        out_ids = model.language_model.generate(**inputs, max_new_tokens=128)
        out = tok.decode(out_ids[0], skip_special_tokens=True)

    print(out)


if __name__ == "__main__":
    main()
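Note that `--image-token-id` must match whatever id the chosen tokenizer actually assigns to `<image>`; the 151665 default assumes one particular extended Qwen3 vocabulary, so when in doubt check `tok.convert_tokens_to_ids("<image>")` and pass that value explicitly.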
scripts/kaggle_train_all.sh
ADDED
@@ -0,0 +1,136 @@
#!/usr/bin/env bash
set -euo pipefail

# One-shot Kaggle runner: Stage 1 (pretrain/alignment) -> Stage 2 (instruct finetune)
# Designed for Kaggle 2xT4, but also works on other CUDA machines.

############################################
# User config (edit these if you want)
############################################
: "${MANTHAN_MODEL:=zyxcisss/Manthan-T1}"              # HF repo or local path containing Manthan remote-code
: "${TEXT_MODEL:=Qwen/Qwen3-0.6B-Base}"                # base LLM checkpoint
: "${STAGE1_DS:=liuhaotian/LLaVA-CC3M-Pretrain-595K}"  # pretrain/alignment
: "${STAGE2_DS:=liuhaotian/LLaVA-Instruct-150K}"       # instruction finetune

: "${OUT_BASE:=/kaggle/working/manthan_runs}"          # all outputs saved here
: "${STAGE1_OUT:=${OUT_BASE}/stage1}"                  # stage1 output dir
: "${STAGE2_OUT:=${OUT_BASE}/stage2}"                  # stage2 output dir

# Training knobs (safe defaults for 2xT4)
: "${MAX_LENGTH:=2048}"
: "${IMAGE_SIZE:=384}"
: "${BATCH_SIZE:=1}"
: "${GRAD_ACCUM:=32}"
: "${LR:=1e-4}"
: "${EPOCHS_STAGE1:=1}"
: "${EPOCHS_STAGE2:=1}"

# Optional dataset limits (set empty for full)
: "${LIMIT_STAGE1:=20000}"
: "${LIMIT_STAGE2:=150000}"

# If you want to disable LoRA for projector-only training, set USE_LORA=0
: "${USE_LORA:=1}"

# If you want this script to upload artifacts via huggingface-cli, set UPLOAD=1
: "${UPLOAD:=0}"

############################################
# Environment setup
############################################
if command -v nvidia-smi >/dev/null 2>&1; then
  echo "GPU found:"; nvidia-smi || true
else
  echo "WARNING: nvidia-smi not found. This script expects a CUDA runtime (Kaggle)."
fi

# Persist caches on Kaggle
export HF_HOME="${HF_HOME:-/kaggle/working/hf}"
export TRANSFORMERS_CACHE="${TRANSFORMERS_CACHE:-/kaggle/working/hf/transformers}"
export HF_DATASETS_CACHE="${HF_DATASETS_CACHE:-/kaggle/working/hf/datasets}"

mkdir -p "${HF_HOME}" "${TRANSFORMERS_CACHE}" "${HF_DATASETS_CACHE}" "${OUT_BASE}"

############################################
# Dependencies
############################################
python - <<'PY'
import sys
print("python:", sys.version)
PY

# Keep installs minimal and reproducible enough for Kaggle.
python -m pip install -U pip
python -m pip install -U "transformers>=4.45" accelerate datasets peft

# Unsloth is optional; script falls back to PEFT if it isn't installed.
python -m pip install -U unsloth || true

############################################
# Helpers to add optional args
############################################
maybe_limit_args() {
  local limit_val="$1"
  if [[ -n "${limit_val}" ]]; then
    echo "--limit" "${limit_val}"
  fi
}

maybe_lora_args() {
  if [[ "${USE_LORA}" == "1" ]]; then
    echo "--use_lora"
  else
    echo ""
  fi
}

############################################
# Stage 1
############################################
echo "==== Stage 1: projector alignment/pretrain ===="
python scripts/train_unsloth_kaggle.py \
  --stage stage1 \
  --manthan_model "${MANTHAN_MODEL}" \
  --text_model "${TEXT_MODEL}" \
  --dataset "${STAGE1_DS}" \
  --output_dir "${STAGE1_OUT}" \
  $(maybe_lora_args) \
  --max_length "${MAX_LENGTH}" \
  --image_size "${IMAGE_SIZE}" \
  --batch_size "${BATCH_SIZE}" \
  --grad_accum "${GRAD_ACCUM}" \
  --lr "${LR}" \
  --epochs "${EPOCHS_STAGE1}" \
  $(maybe_limit_args "${LIMIT_STAGE1}")

############################################
# Stage 2
############################################
echo "==== Stage 2: instruction finetune ===="
python scripts/train_unsloth_kaggle.py \
  --stage stage2 \
  --manthan_model "${MANTHAN_MODEL}" \
  --text_model "${TEXT_MODEL}" \
  --dataset "${STAGE2_DS}" \
  --output_dir "${STAGE2_OUT}" \
  $(maybe_lora_args) \
  --max_length "${MAX_LENGTH}" \
  --image_size "${IMAGE_SIZE}" \
  --batch_size "${BATCH_SIZE}" \
  --grad_accum "${GRAD_ACCUM}" \
  --lr "${LR}" \
  --epochs "${EPOCHS_STAGE2}" \
  $(maybe_limit_args "${LIMIT_STAGE2}")

echo "==== Done ===="
echo "Stage1 outputs: ${STAGE1_OUT}"
echo "Stage2 outputs: ${STAGE2_OUT}"

############################################
# Optional upload (manual control)
############################################
if [[ "${UPLOAD}" == "1" ]]; then
  echo "UPLOAD=1: attempting to upload artifacts (requires HF auth)."
  python -m pip install -U huggingface_hub
  echo "You can now upload ${OUT_BASE} with your preferred workflow."
fi
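Since every knob is a `:=`-defaulted environment variable, individual settings can be overridden without editing the file, e.g. `GRAD_ACCUM=16 LIMIT_STAGE1=1000 bash scripts/kaggle_train_all.sh` for a short end-to-end rehearsal before committing to a full run.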
scripts/smoke_export_load.py
ADDED
@@ -0,0 +1,66 @@
#!/usr/bin/env python3

"""Smoke test: export a HF folder, then load it with trust_remote_code.

This is meant to catch:
- remote-code syntax/indentation errors
- missing auto_map
- missing model weights (optional stub)
- basic forward/generate wiring regressions

It intentionally uses the small stub-weights mode so it does not download big models.
"""

from __future__ import annotations

import subprocess
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main() -> int:
    repo_root = Path(__file__).resolve().parents[1]
    out_dir = repo_root / "hf_export_ready"

    if out_dir.exists():
        # keep it simple
        subprocess.run(["rm", "-rf", str(out_dir)], check=True)

    subprocess.run(
        [
            sys.executable,
            str(repo_root / "scripts" / "export_hf.py"),
            "--out",
            str(out_dir),
            "--write_stub_weights",
        ],
        check=True,
    )

    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(out_dir, trust_remote_code=True)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(out_dir, trust_remote_code=True)
    model.eval()

    # Basic forward pass (text-only)
    ids = tok("Hello", return_tensors="pt").input_ids
    with torch.inference_mode():
        out = model(input_ids=ids)
    assert out.logits.shape[:2] == ids.shape

    # Tiny generate smoke
    with torch.inference_mode():
        gen = model.generate(ids, max_new_tokens=4, use_cache=False)
    assert gen.shape[0] == 1

    print("SMOKE OK")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
scripts/train_unsloth_kaggle.py
ADDED
@@ -0,0 +1,454 @@
"""Kaggle/Unsloth training entrypoint for Manthan-T1 (TinyLLaVA-style).

This script is intended to be copied into a Kaggle notebook and run on 2×T4.
It supports two stages:
- stage1: projector alignment pretraining (e.g., LLaVA-CC3M-Pretrain-595K)
- stage2: instruction tuning (e.g., LLaVA-Instruct-150K)

Notes:
- We follow MicroLLaVA/TinyLLaVA convention: IMAGE_TOKEN_INDEX = -200 is inserted
  into input_ids for <image> placeholders.
- Labels are IGNORE_INDEX for everything except assistant tokens.
- This script trains:
  - the multimodal projector (always)
  - LoRA adapters on the text model (optional, recommended)
  - vision tower is frozen by default

You still need a *real* base model + vision tower weights. Stub exports will run
but won't learn useful vision-language alignment.
"""

from __future__ import annotations

import argparse
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup

try:
    # Fallback for non-Unsloth environments
    from peft import LoraConfig, get_peft_model
except Exception:  # pragma: no cover
    LoraConfig = None
    get_peft_model = None

try:
    # Kaggle + Unsloth
    from unsloth import FastLanguageModel
except Exception:  # pragma: no cover
    FastLanguageModel = None

try:
    from datasets import load_dataset
except Exception as e:  # pragma: no cover
    raise RuntimeError(
        "Missing dependency `datasets`. Install with `pip install datasets` (Kaggle: add to notebook)."
    ) from e


IMAGE_TOKEN_INDEX = -200
IGNORE_INDEX = -100


def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX) -> List[int]:
    """MicroLLaVA/TinyLLaVA tokenizer: split on '<image>' and insert a negative id.

    Schematic example (hypothetical ids): "hi <image> there" splits into
    ["hi ", " there"]; if those chunks tokenize to [11, 12] and [13, 14],
    the result is [11, 12, -200, 13, 14]; the negative id marks where the
    projected vision features are spliced in later.
    """

    def _insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    input_ids: List[int] = []
    offset = 0
    if (
        len(prompt_chunks) > 0
        and len(prompt_chunks[0]) > 0
        and tokenizer.bos_token_id is not None
        and prompt_chunks[0][0] == tokenizer.bos_token_id
    ):
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    return input_ids


def build_prompt_from_conversations(conversations: List[Dict[str, str]]) -> Tuple[str, str]:
    """Return (full_prompt, assistant_answer_text).

    LLaVA datasets are 2-turn: human then gpt.
    We map to the string template used in `ManthanForCausalLM.format_chat_prompt`.
    """

    # Expect 2 turns
    human = conversations[0]["value"]
    assistant = conversations[1]["value"]

    system = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    )
    # IMPORTANT: no trailing space after ASSISTANT:
    # (a trailing space would change how the first answer token is tokenized)
    full = system + f"USER: {human.strip()} ASSISTANT:" + assistant
    return full, assistant


@dataclass
class TrainExample:
    input_ids: torch.LongTensor
    labels: torch.LongTensor
    image_path: str


class LlavaLikeDataset(Dataset):
    def __init__(
        self,
        ds_name: str,
        split: str,
        tokenizer,
        max_length: int,
        limit: Optional[int] = None,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Streaming keeps Kaggle disk usage low.
        self.ds = load_dataset(ds_name, split=split, streaming=True)
        self.limit = limit

        # Materialize a small index for non-streaming dataloader behavior.
        self._cache: List[Dict[str, Any]] = []
        for i, ex in enumerate(self.ds):
            self._cache.append(ex)
            if limit is not None and i + 1 >= limit:
                break

    def __len__(self) -> int:
        return len(self._cache)

    def __getitem__(self, idx: int) -> TrainExample:
        ex = self._cache[idx]
        image_path = ex["image"]
        conversations = ex["conversations"]

        full_prompt, _assistant = build_prompt_from_conversations(conversations)
        ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX)

        # Truncate
        ids = ids[: self.max_length]

        # Labels: only learn on assistant answer tokens.
        # Simple heuristic: find the last occurrence of " ASSISTANT:" marker
        # (keeping the last match masks everything before the final answer).
        marker = " ASSISTANT:"
        marker_ids = self.tokenizer(marker).input_ids

        # Find marker in tokenized ids (best-effort).
        start = 0
        for j in range(0, len(ids) - len(marker_ids) + 1):
            if ids[j : j + len(marker_ids)] == marker_ids:
                start = j + len(marker_ids)

        labels = [IGNORE_INDEX] * len(ids)
        for j in range(start, len(ids)):
            if ids[j] == IMAGE_TOKEN_INDEX:
                labels[j] = IGNORE_INDEX
            else:
                labels[j] = ids[j]

        return TrainExample(
            input_ids=torch.tensor(ids, dtype=torch.long),
            labels=torch.tensor(labels, dtype=torch.long),
            image_path=image_path,
        )


def load_image_tensor(image_path: str, image_size: int) -> torch.FloatTensor:
    """Load image from local path in dataset.

    In Kaggle, LLaVA datasets provide image paths relative to the dataset repo.
    Hugging Face datasets streaming yields paths that resolve via HF cache.
    """

    from PIL import Image
    import torchvision.transforms as T

    img = Image.open(image_path).convert("RGB")
    tfm = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
    return tfm(img)


def collate_fn(batch: List[TrainExample], image_size: int) -> Dict[str, torch.Tensor]:
    # Pad to max length
    max_len = max(x.input_ids.numel() for x in batch)
    input_ids = torch.full((len(batch), max_len), 0, dtype=torch.long)
    labels = torch.full((len(batch), max_len), IGNORE_INDEX, dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)

    for i, ex in enumerate(batch):
        L = ex.input_ids.numel()
        input_ids[i, :L] = ex.input_ids
        labels[i, :L] = ex.labels
        attention_mask[i, :L] = 1

    # Images
    pixel_values = torch.stack([load_image_tensor(ex.image_path, image_size) for ex in batch], dim=0)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
    }


def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
    for p in module.parameters():
        p.requires_grad = requires_grad


def save_projector(model, output_dir: str) -> None:
    os.makedirs(output_dir, exist_ok=True)
    if not hasattr(model, "projector"):
        return
    torch.save(model.projector.state_dict(), os.path.join(output_dir, "projector.pt"))


def maybe_add_lora_to_model(model, args) -> None:
    """Attach LoRA adapters (Unsloth preferred; PEFT fallback)."""

    if not args.use_lora:
        return

    # If the model already has adapters (e.g., loaded via Unsloth), skip.
    if hasattr(model, "peft_config"):
        return

    if get_peft_model is None or LoraConfig is None:
        raise RuntimeError("PEFT not installed, and Unsloth not available. Install `peft` or enable Unsloth.")

    target_modules = [
        # Qwen-like
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        # GPT-like fallback
        "c_attn",
        "c_proj",
    ]
    cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

    # Wrap the language model inside Manthan
    model.language_model = get_peft_model(model.language_model, cfg)


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--stage", choices=["stage1", "stage2"], required=True)
    ap.add_argument("--text_model", type=str, default="Qwen/Qwen3-0.6B-Base")
    ap.add_argument("--vision_model", type=str, default="google/siglip-so400m-patch14-384")
    ap.add_argument("--dataset", type=str, required=True)
    ap.add_argument("--output_dir", type=str, default="./outputs")
    ap.add_argument("--max_length", type=int, default=2048)
    ap.add_argument("--image_size", type=int, default=384)
    ap.add_argument("--limit", type=int, default=2048, help="For debugging: number of samples to materialize")

    # Training
    ap.add_argument("--epochs", type=int, default=1)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--warmup_ratio", type=float, default=0.03)
    ap.add_argument("--use_lora", action="store_true")
    ap.add_argument("--lora_r", type=int, default=16)
    ap.add_argument("--lora_alpha", type=int, default=32)
    ap.add_argument("--lora_dropout", type=float, default=0.05)

    ap.add_argument(
        "--manthan_model",
        type=str,
        required=True,
        help="HF repo id or local path that contains Manthan remote-code (the thing you push to HF).",
    )
    ap.add_argument("--save_every", type=int, default=500)
    ap.add_argument("--dry_run", action="store_true", help="Run a single synthetic step (no datasets).")

    args = ap.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("WARNING: This script is designed for CUDA (Kaggle). Running on CPU will be extremely slow.")

    # Tokenizer (use the LLM tokenizer)
    tok = AutoTokenizer.from_pretrained(args.text_model, trust_remote_code=True, use_fast=False)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # Load Manthan remote-code model
    # (This should contain config that points to your desired text_model_id & vision_model_id.)
    model = AutoModelForCausalLM.from_pretrained(
        args.manthan_model,
        trust_remote_code=True,
        torch_dtype=torch.float16 if device == "cuda" else None,
    )

    model.train()
    model.to(device)

    # Make sure we don't train the vision tower (T4-friendly)
    if hasattr(model, "vision_model") and model.vision_model is not None:
        set_requires_grad(model.vision_model, False)
    if hasattr(model, "vision_tower") and model.vision_tower is not None:
        set_requires_grad(model.vision_tower, False)

    # Train projector always
    if hasattr(model, "projector"):
        set_requires_grad(model.projector, True)

    # Add LoRA to the language model (recommended)
    maybe_add_lora_to_model(model, args)

    # Optimizer params = trainable only
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if len(trainable_params) == 0:
        raise RuntimeError("No trainable parameters. Did you freeze everything?")

    optim = torch.optim.AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)

    # Data
    if args.dry_run:
        # Minimal synthetic batch (no images on disk). This just validates loss pathway.
        B, T = 1, min(64, args.max_length)

        # IMPORTANT: some tokenizers report an imprecise `vocab_size`; `len(tok)` is the safe upper bound.
        tok_vocab = int(len(tok))
        input_ids = torch.randint(low=0, high=max(tok_vocab - 1, 1), size=(B, T), dtype=torch.long)
        labels = input_ids.clone()
        attn = torch.ones_like(input_ids)
        pixel_values = torch.randn(B, 3, args.image_size, args.image_size)

        # Insert one image placeholder
        input_ids[0, 5] = IMAGE_TOKEN_INDEX
        labels[0, :10] = IGNORE_INDEX

        # If tokenizer vocab > model vocab (common in dry_run), clamp to avoid CE index errors.
        lm_vocab = None
        try:
            if hasattr(model, "language_model") and hasattr(model.language_model, "config"):
                lm_vocab = int(getattr(model.language_model.config, "vocab_size", 0) or 0)
        except Exception:
            lm_vocab = None

        if lm_vocab and lm_vocab > 0:
            safe_ids = input_ids.clone()
            mask = safe_ids >= 0
            safe_ids[mask] = safe_ids[mask].clamp(min=0, max=lm_vocab - 1)
            input_ids = safe_ids

            safe_labels = labels.clone()
            mask = safe_labels >= 0
            safe_labels[mask] = safe_labels[mask].clamp(min=0, max=lm_vocab - 1)
            labels = safe_labels

        batch = {
            "input_ids": input_ids.to(device),
            "labels": labels.to(device),
            "attention_mask": attn.to(device),
            "pixel_values": pixel_values.to(device),
        }
        out = model(**batch)
        print("dry_run loss:", float(out.loss))
        out.loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        save_projector(model, args.output_dir)
        if hasattr(model, "language_model") and hasattr(model.language_model, "save_pretrained"):
            # Save adapters if present
            try:
                model.language_model.save_pretrained(args.output_dir)
            except Exception:
                pass
        return 0

    ds = LlavaLikeDataset(args.dataset, split="train", tokenizer=tok, max_length=args.max_length, limit=args.limit)
    from torch.utils.data import DataLoader

    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda b: collate_fn(b, args.image_size),
    )

    total_steps = (len(dl) * args.epochs) // max(1, args.grad_accum)
    warmup_steps = max(1, int(total_steps * args.warmup_ratio))
    sched = get_cosine_schedule_with_warmup(optim, warmup_steps, total_steps)

    step = 0
    optim.zero_grad(set_to_none=True)
    for epoch in range(args.epochs):
        for micro_idx, batch in enumerate(dl):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Mixed precision on Kaggle
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")):
                out = model(**batch)
                loss = out.loss / max(1, args.grad_accum)

            loss.backward()

            if (micro_idx + 1) % args.grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optim.step()
                sched.step()
                optim.zero_grad(set_to_none=True)
                step += 1

                if step % 10 == 0:
                    print(f"epoch={epoch} step={step}/{total_steps} loss={float(out.loss):.4f}")

                if step % args.save_every == 0:
                    save_projector(model, args.output_dir)
                    # Save adapters if any
                    try:
                        model.save_pretrained(args.output_dir)
                    except Exception:
                        pass

            if step >= total_steps:
                break

    save_projector(model, args.output_dir)
    try:
        model.save_pretrained(args.output_dir)
    except Exception:
        pass

    print("DONE")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
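To validate the loss pathway without downloading datasets or real weights, the dry-run mode works against a stub export (paths illustrative): `python scripts/train_unsloth_kaggle.py --stage stage1 --dataset unused --manthan_model ./hf_export_ready --dry_run`. Note that `--dataset` is still required by the parser even though the dry run never loads it.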
tests/test_smoke.py
ADDED
@@ -0,0 +1,10 @@
from manthan_t1.smoke_test import TinyToyModel


def test_tiny_model_forward():
    import mlx.core as mx

    m = TinyToyModel(vocab_size=64, d_model=32)
    x = mx.random.randint(0, 64, shape=(2, 8))
    y = m(x)
    assert y.shape == (2, 8, 64)