---
language: ["en"]
license: llama2
tags:
- image-text-to-text
- visual-question-answering
- vision-language
- llava
- multimodal
- causal-lm
- continual-pretraining
- lora
- axolotl
- deepspeed
- transformers
- eu-hpc
datasets:
- mm_captions_chat
- text_cpt_corpus
metrics: ["loss"]
library_name: transformers
framework: pytorch
base_model: llava-hf/llava-1.5-7b-hf
model_name: llava-7b-cpt
pipeline_tag: image-text-to-text
task_categories: ["image-text-to-text", "visual-question-answering"]
model_type: llava
inference:
  parameters:
    max_new_tokens: 128
    temperature: 0.2
    top_p: 0.9
trained_on: ["Leonardo EuroHPC"]
description: "Two-stage continual pretraining (CPT) of LLaVA 1.5 7B: first on text-only data, then on image–text chat-style captions. LoRA adapters merged into the base model."
---

# LLaVA 7B — Multimodal Continual Pretraining (CPT) with LoRA Adapters

**Model type:** Vision-Language Causal Model  
**Base model:** [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)  
**License:** Llama 2 Community License (inherited from the base model)  
**Framework:** Axolotl + DeepSpeed ZeRO-1

---

## Overview

`llava-7b-cpt` is a **continual-pretrained** multimodal version of **LLaVA 1.5 7B**, extending its visual and textual reasoning capabilities through domain-specific continual pretraining (CPT).
The process follows a **two-stage adaptation flow**:

1. **Textual CPT (Stage 1):**
   - Base: `llava-hf/llava-1.5-7b-hf`
   - Objective: text-only continual pretraining on scientific, governmental, news, and encyclopedic corpora.

2. **Multimodal CPT (Stage 2, this release):**
   - Base: the Stage 1 text-CPT model
   - Objective: multimodal (image–text) continual pretraining on image-caption dialogue data.

This pipeline enhances LLaVA's factual grounding and image-conditioned understanding of technical and energy-domain visual content.
Training was performed on the **Leonardo EuroHPC** supercomputer using **Axolotl 0.6** with **DeepSpeed ZeRO-1** and **bfloat16** precision.

---

## Training Setup

| Component | Specification |
|:-----------|:--------------|
| **Objective** | Multimodal continual pretraining (image–text dialogue) |
| **Adapter type** | LoRA |
| **Precision** | bfloat16 |
| **Hardware** | 8 nodes × 2 × NVIDIA A100 64 GB GPUs |
| **Framework** | Axolotl + DeepSpeed ZeRO-1 (PyTorch 2.5.1 + CUDA 12.1) |
| **Runtime** | ≈ 24 hours |
| **Checkpoints** | Saved every epoch |
| **Vision tower** | Frozen |
| **Text backbone** | LoRA-updated only |
| **Loss watchdog** | Disabled for multimodal phase |
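
The released weights already have the Stage 2 LoRA adapters merged into the base model, as noted in the card metadata. For reference, a merge of this kind typically follows the standard PEFT pattern sketched below; both local paths are hypothetical stand-ins, not published artifacts.

```python
import torch
from peft import PeftModel
from transformers import LlavaForConditionalGeneration

# Hypothetical paths: the Stage 1 text-CPT base and the Stage 2 LoRA
# adapter directory produced by the Axolotl run.
base = LlavaForConditionalGeneration.from_pretrained(
    "./llava-7b-text-cpt", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "./llava-7b-cpt-lora")

merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained("./llava-7b-cpt")
```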

---

## Dataset

The multimodal CPT stage was trained on **image–caption chat-style pairs**, stored as LLaVA-style message lists in an Axolotl-compatible JSONL file (`mm_captions_chat.jsonl`).

| File | Description |
|:------|:-------------|
| **mm_captions_chat.jsonl** | Image–text dialogues for visual captioning and VQA adaptation |
| **images/** | Folder of image files referenced by the dataset entries |

Each entry contains alternating `user` (image + text prompt) and `assistant` (caption/answer) messages in a chat structure compatible with the `llava` chat template.
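
For illustration, a single record in `mm_captions_chat.jsonl` has roughly the following shape. The field names here are an assumption based on Axolotl's LLaVA-style chat format, and the caption is invented; actual corpus entries may differ.

```python
import json

# Hypothetical record shape for mm_captions_chat.jsonl (assumed field names).
record = {
    "messages": [
        {"role": "user", "content": "<image>\nDescribe the equipment shown in this photo."},
        {"role": "assistant", "content": "An outdoor substation with two step-down transformers behind a safety fence."},
    ],
    "images": ["images/substation_0001.jpg"],
}

with open("mm_captions_chat.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```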

---

## Hyperparameters

| Parameter | Value |
|:-----------|:------|
| Sequence length | 2048 |
| Micro batch size | 1 |
| Gradient accumulation | 4 |
| Epochs | 1 |
| Max steps | 6000 |
| Learning rate | 0.00015 |
| LR scheduler | cosine |
| Optimizer | AdamW (8-bit) |
| Warmup ratio | 0.1 |
| Weight decay | 0.0 |
| LoRA rank (r) | 16 |
| LoRA alpha | 32 |
| LoRA dropout | 0.05 |
| LoRA target modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| Gradient checkpointing | ✅ |
| Flash attention | ❌ (disabled for stability) |
| Image size | 512 |
| Resize algorithm | bilinear |
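
The LoRA values above map onto a PEFT configuration along the lines of the sketch below. This is an illustrative reconstruction, not the actual Axolotl setup; the regex form of `target_modules` is one way to keep the adapters on the language backbone only, since the CLIP vision tower also contains `q_proj`/`k_proj`/`v_proj` modules but is meant to stay frozen.

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.bfloat16
)

# Freeze the vision tower and projector; only the text backbone is adapted.
for name, param in model.named_parameters():
    if "vision_tower" in name or "multi_modal_projector" in name:
        param.requires_grad = False

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    # Full-match regex so LoRA lands only on the LLaMA backbone projections.
    target_modules=r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```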

---

## Model Flow

1. **Base:** `llava-hf/llava-1.5-7b-hf`
2. **Stage 1 — Textual Continual Pretraining (CPT)** → `llava-7b-text-cpt`
3. **Stage 2 — Multimodal Continual Pretraining (CPT)** → `ubitech-edg/llava-7b-cpt`

---

## Tokenizer & Processor

| Component | Value |
|:-----------|:------|
| **Tokenizer type** | `AutoTokenizer` |
| **Processor type** | `AutoProcessor` |
| **Special tokens** | `<pad>` = ID 32001 |
| **Chat template** | `llava` |
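
A quick sanity check of the special-token setup after loading the processor:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("ubitech-edg/llava-7b-cpt")

# The tokenizer carries a dedicated padding token at ID 32001 (see table above).
assert processor.tokenizer.pad_token == "<pad>"
assert processor.tokenizer.pad_token_id == 32001
```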

---

## Usage

To load and run `llava-7b-cpt` locally for image–text generation:

```python
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import torch

model_id = "ubitech-edg/llava-7b-cpt"

processor = AutoProcessor.from_pretrained(model_id)
# LLaVA checkpoints load through LlavaForConditionalGeneration;
# AutoModelForCausalLM does not cover the llava architecture.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

image = Image.open("example.jpg").convert("RGB")
prompt = "USER: <image>\nDescribe this image in two sentences.\nASSISTANT:"

# Move inputs to the model's device and cast pixel values to bfloat16.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(
    model.device, torch.bfloat16
)

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=128)

print(processor.decode(output[0], skip_special_tokens=True))
```
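
On recent `transformers` versions, the prompt can also be built from the bundled `llava` chat template rather than a hand-written `USER:`/`ASSISTANT:` string. Reusing `processor`, `model`, and `image` from the snippet above:

```python
# Let the processor's chat template format the prompt, including the image slot.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image in two sentences."},
        ],
    }
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

inputs = processor(images=image, text=prompt, return_tensors="pt").to(
    model.device, torch.bfloat16
)
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))
```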