kosmylo1992 commited on
Commit
2cbae32
·
verified ·
1 Parent(s): 99064f8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +170 -3
README.md CHANGED
@@ -1,3 +1,170 @@
1
- ---
2
- license: llama2
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ {
3
+ "language": ["en"],
4
+ "license": "llama2",
5
+ "tags": [
6
+ "text-generation",
7
+ "causal-lm",
8
+ "supervised-fine-tuning",
9
+ "instruction-tuning",
10
+ "synthetic-qa",
11
+ "lora",
12
+ "axolotl",
13
+ "deepspeed",
14
+ "transformers",
15
+ "llava",
16
+ "eu-hpc"
17
+ ],
18
+ "datasets": [
19
+ "axolotl_deduplicated_synthetic_qa"
20
+ ],
21
+ "metrics": [
22
+ "loss"
23
+ ],
24
+ "library_name": "transformers",
25
+ "framework": "pytorch",
26
+ "base_model": "llava-hf/llava-1.5-7b-hf",
27
+ "model_name": "llava-7b-sft",
28
+ "pipeline_tag": "text-generation",
29
+ "task_categories": ["text-generation", "question-answering"],
30
+ "model_type": "llava",
31
+ "inference": {
32
+ "parameters": {
33
+ "max_new_tokens": 512,
34
+ "temperature": 0.7,
35
+ "top_p": 0.9
36
+ }
37
+ },
38
+ "trained_on": [
39
+ "Leonardo EuroHPC"
40
+ ],
41
+ "description": "Supervised fine-tuning (SFT) of LLaVA 1.5 7B on synthetic QA pairs using Axolotl and DeepSpeed ZeRO-1. The model improves text-based question answering and instruction following while preserving its multimodal capabilities."
42
+ }
43
+ ---
44
+
45
+ # LLaVA 7B — Supervised Fine-Tuning (SFT) on Synthetic QA
46
+
47
+ **Model type:** Vision-Language Causal Model (text-finetuned LLaVA-1.5)
48
+ **Base model:** [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
49
+ **License:** Llama 2 Community License
50
+ **Framework:** Axolotl + DeepSpeed ZeRO-1 (PyTorch 2.5.1 + CUDA 12.1)
51
+
52
+ ---
53
+
54
+ ## Overview
55
+
56
+ `llava-7b-sft` is a **supervised fine-tuned** version of **LLaVA 1.5 7B**, trained on a synthetic instruction-following dataset of **question–answer pairs** to enhance text understanding and reasoning.
57
+ Although derived from a multimodal base, this SFT run fine-tunes the **language model component** using LoRA adapters which were later **merged into the full model weights**.
58
+
59
+ This model therefore supports **text-only generation** natively (without PEFT) and retains compatibility with the **multimodal processor and vision configuration** from LLaVA.
60
+
61
+ Training was conducted on the **Leonardo EuroHPC** system using **Axolotl** and **DeepSpeed ZeRO-1**.
62
+
63
+ ---
64
+
65
+ ## Training Setup
66
+
67
+ | Component | Specification |
68
+ |:-----------|:--------------|
69
+ | **Objective** | Supervised fine-tuning (instruction-following QA) |
70
+ | **Adapter type** | LoRA (merged into full model) |
71
+ | **Precision** | bfloat16 |
72
+ | **Hardware** | 8 nodes × 2 × NVIDIA A100 64 GB GPUs |
73
+ | **Framework** | Axolotl 0.6 + DeepSpeed ZeRO-1 (PyTorch 2.5.1 + CUDA 12.1) |
74
+ | **Runtime** | ~24 hours |
75
+ | **Checkpoints** | 2 per epoch |
76
+ | **Vision tower** | Frozen during SFT |
77
+ | **Dataset split** | 70% train / 30% validation |
78
+
79
+ ---
80
+
81
+ ## Dataset
82
+
83
+ **Name:** `axolotl_deduplicated_synthetic_qa.jsonl`
84
+ **Type:** Instruction-following synthetic QA dataset (Alpaca-style)
85
+
86
+ Each record contains a single-turn question and a high-quality generated answer.
87
+ This SFT data improves the model’s **reasoning**, **language coherence**, and **conversational QA** quality.
88
+
89
+ ---
90
+
91
+ ## Hyperparameters
92
+
93
+ | Parameter | Value |
94
+ |:-----------|:------|
95
+ | Sequence length | 2048 |
96
+ | Micro batch size | 1 |
97
+ | Gradient accumulation | 4 |
98
+ | Epochs | 1 |
99
+ | Learning rate | 0.0002 |
100
+ | LR scheduler | cosine |
101
+ | Optimizer | AdamW (8-bit) |
102
+ | Warmup steps | 10 |
103
+ | Weight decay | 0.0 |
104
+ | LoRA rank (r) | 16 |
105
+ | LoRA alpha | 32 |
106
+ | LoRA dropout | 0.05 |
107
+ | LoRA target modules | `q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj` |
108
+ | Gradient checkpointing | ✅ |
109
+ | Flash attention | ✅ |
110
+ | Validation set size | 0.3 |
111
+ | Evals per epoch | 2 |
112
+
113
+ ---
114
+
115
+ ## Tokenizer & Processor
116
+
117
+ | Component | Description |
118
+ |:-----------|:-------------|
119
+ | **Tokenizer type** | `AutoTokenizer` |
120
+ | **Processor type** | `AutoProcessor` (compatible with LLaVA image+text inputs) |
121
+ | **Pad token** | `<pad>` (ID 32001) |
122
+ | **Chat template** | `llava` |
123
+
124
+ The processor configuration allows image or text inputs; however, this release focuses on text-based supervised tuning.
125
+
126
+ ---
127
+
128
+ ## Files Included
129
+
130
+ This repository contains the **fully merged model weights** and all required configs for direct use with `transformers`:
131
+
132
+ - `config.json`
133
+ - `model-*.safetensors`
134
+ - `tokenizer.json`
135
+ - `tokenizer_config.json`
136
+ - `tokenizer.model`
137
+ - `special_tokens_map.json`
138
+ - `processor_config.json`
139
+ - `preprocessor_config.json`
140
+ - `vision_config.json`
141
+ - `image_processor_config.json`
142
+ - `README.md`
143
+
144
+ ---
145
+
146
+ ## Usage Example
147
+
148
+ To run text-based generation with this model:
149
+
150
+ ```python
151
+ import torch
152
+ from transformers import AutoProcessor, AutoModelForCausalLM
153
+
154
+ model_id = "ubitech-edg/llava-7b-sft"
155
+
156
+ processor = AutoProcessor.from_pretrained(model_id)
157
+ model = AutoModelForCausalLM.from_pretrained(
158
+ model_id,
159
+ torch_dtype=torch.bfloat16,
160
+ device_map="auto"
161
+ )
162
+
163
+ prompt = "USER: Explain the principle of energy conservation.\nASSISTANT:"
164
+ inputs = processor(text=prompt, return_tensors="pt").to("cuda")
165
+
166
+ with torch.inference_mode():
167
+ outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
168
+
169
+ print(processor.decode(outputs[0], skip_special_tokens=True))
170
+ ```