# File size: 7,808 Bytes
# cc9aa53
from datasets import load_dataset

# Hugging Face Hub repository id of the fine-tuning dataset.
dataset_name = "ayoubkirouane/llava-instruct-small"

# Download (or reuse the cached copy of) the dataset; the trainer further down
# reads its "train" split.
dataset = load_dataset(path=dataset_name)
# import os
# import zipfile
# import io
# # from datasets import DatasetDict
# from huggingface_hub import hf_hub_download, list_repo_files
# from PIL import Image
# dataset_train_split = "test"
# def format_data(samples: dict[str, any]) -> dict[str, list]:
# formatted_samples = {"messages": []}
# for cont in range(len(samples["question"])):
# images = []
# for img_path in samples["input_image_path"][cont]:
# try:
# with open(img_path, "rb") as f:
# img_bytes = f.read()
# image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# images.append({"type": "image", "image": image})
# except Exception as e:
# print(f"Error processing image {img_path}: {e}")
# continue
# formatted_samples["messages"].append(
# [
# {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
# {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
# {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
# ]
# )
# return formatted_samples
# For multi-image example
# def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
# all_files = list_repo_files(dataset_name, repo_type="dataset")
# zip_files = [f for f in all_files if f.endswith(".zip")]
# for zip_filename in zip_files:
# zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
# extract_folder = zip_filename.replace(".zip", "")
# os.makedirs(extract_folder, exist_ok=True)
# with zipfile.ZipFile(zip_path, "r") as zip_ref:
# zip_ref.extractall(extract_folder)
# dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
# return dataset
# dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

# Hub id of the base vision-language checkpoint that gets fine-tuned.
model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"

# bitsandbytes 4-bit quantization settings: NF4 quantization type with double
# quantization enabled, and bfloat16 for both the compute dtype and the
# quantized-weight storage dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the base model with the 4-bit quantization config applied.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Eager attention is kept from the recipe this script follows; note that the
    # reference it cites is the Gemma 3 modeling code, not SmolVLM:
    # https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934
    attn_implementation="eager",
)

# Load the matching processor (fast implementation requested) and make its
# tokenizer pad on the right-hand side.
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
processor.tokenizer.padding_side = "right"
from peft import LoraConfig, get_peft_model

# LoRA adapter settings; combined with the 4-bit base model this is QLoRA.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Attach adapters to every linear layer.
    target_modules="all-linear",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Modules listed here are trained and saved in addition to the adapters.
    modules_to_save=["lm_head", "embed_tokens"],
)
from trl import SFTConfig

training_args = SFTConfig(
    # Directory where checkpoints and the final model are written. It doubles as
    # the repository name if pushing to the Hub is ever enabled.
    output_dir="smolvlm-trl-sft-test",
    # Schedule and batch sizing. For multi-image datasets the original recipe
    # suggests per_device_train_batch_size=1 and gradient_accumulation_steps=1.
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,  # micro-batches accumulated per optimizer step
    learning_rate=2e-05,
    optim="adamw_torch_fused",  # fused AdamW implementation
    # Precision and memory savers.
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # non-reentrant variant
    # Checkpointing and reporting.
    save_strategy="epoch",  # one checkpoint at the end of each epoch
    push_to_hub=False,  # keep the result local; nothing is uploaded
    report_to="tensorboard",
    # The custom collator does all preprocessing, so the trainer must neither
    # prepare the dataset itself nor drop the raw columns the collator reads.
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
)
from PIL import Image
# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every image referenced in a chat conversation as RGB PIL images.

    Args:
        messages: Chat messages. Each message's "content" is either a list of
            content elements or a single value (which is treated as a
            one-element list). An element counts as an image when it is a dict
            that has an "image" key or whose "type" is "image".

    Returns:
        The images in the order they appear, each converted to RGB. Elements
        whose image value is None are skipped.

    Raises:
        KeyError: If an image is given as a dict without a "bytes" entry.
    """
    # Imported locally because the file's top-level `import io` is commented
    # out; without this the BytesIO call below raises NameError.
    import io

    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if not isinstance(element, dict):
                continue
            if "image" in element:
                image = element["image"]
            elif element.get("type") == "image":
                # No nested "image" entry: the element itself carries the data.
                image = element
            else:
                continue
            if image is None:
                # Non-image elements may carry an explicit image=None entry.
                continue
            if not isinstance(image, Image.Image):
                # Undecoded image stored as a dict holding the encoded bytes.
                image = Image.open(io.BytesIO(image["bytes"]))
            image_inputs.append(image.convert("RGB"))
    return image_inputs
def collate_fn(examples):
    """Build one padded training batch, including loss labels, from dataset rows.

    Each row's "messages" are rendered to a prompt string with the processor's
    chat template. Images are taken from the row's "images" column when the
    first row has one; otherwise they are extracted from the messages with
    ``process_vision_info``. The returned batch is the processor output plus a
    "labels" entry: a copy of ``input_ids`` in which padding tokens and image
    tokens are replaced by -100 so they are not used in the loss computation.
    """
    texts = []
    for row in examples:
        rendered = processor.apply_chat_template(
            row["messages"], tokenize=False, add_generation_prompt=False
        )
        texts.append(rendered.strip())

    has_image_column = "images" in examples[0]
    if has_image_column:  # single-image
        images = []
        for row in examples:
            images.append([picture.convert("RGB") for picture in row["images"]])
    else:  # multi-image
        images = [process_vision_info(row["messages"]) for row in examples]

    # Encode texts and images together into padded tensors.
    batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

    # Start from the input ids; masked positions are overwritten below.
    labels = batch["input_ids"].clone()

    # Prefer the id from the model config, falling back to the tokenizer's id
    # for the literal "<image>" token.
    image_token_id = getattr(model.config, "image_token_id", None)
    if image_token_id is None:
        image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")

    # Exclude padding and image tokens from the loss.
    for ignored_id in (processor.tokenizer.pad_token_id, image_token_id):
        labels[labels == ignored_id] = -100

    batch["labels"] = labels
    return batch
# Training
from trl import SFTTrainer

# Wire the quantized model, the LoRA config, the training arguments and the
# custom collator into the supervised fine-tuning trainer.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],  # multi-image -> train_dataset=dataset["test"]
    processing_class=processor,
    peft_config=peft_config,
)
trainer.train()

# Save the final model (written to training_args.output_dir).
trainer.save_model()