crying_cv / train.py
cho-sr's picture
Upload train.py
cc9aa53 verified
from datasets import load_dataset
dataset_name = "ayoubkirouane/llava-instruct-small"
# Load Dataset
dataset = load_dataset(dataset_name)
# import os
# import zipfile
# import io
# # from datasets import DatasetDict
# from huggingface_hub import hf_hub_download, list_repo_files
# from PIL import Image
# dataset_train_split = "test"
# def format_data(samples: dict[str, any]) -> dict[str, list]:
# formatted_samples = {"messages": []}
# for cont in range(len(samples["question"])):
# images = []
# for img_path in samples["input_image_path"][cont]:
# try:
# with open(img_path, "rb") as f:
# img_bytes = f.read()
# image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# images.append({"type": "image", "image": image})
# except Exception as e:
# print(f"Error processing image {img_path}: {e}")
# continue
# formatted_samples["messages"].append(
# [
# {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
# {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
# {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
# ]
# )
# return formatted_samples
# For multi-image example
# def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
# all_files = list_repo_files(dataset_name, repo_type="dataset")
# zip_files = [f for f in all_files if f.endswith(".zip")]
# for zip_filename in zip_files:
# zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
# extract_folder = zip_filename.replace(".zip", "")
# os.makedirs(extract_folder, exist_ok=True)
# with zipfile.ZipFile(zip_path, "r") as zip_ref:
# zip_ref.extractall(extract_folder)
# dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
# return dataset
# dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_storage=torch.bfloat16,
)
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id,use_fast=True)
processor.tokenizer.padding_side = "right"
from peft import LoraConfig, get_peft_model
# Configure QLoRA
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=[
"lm_head",
"embed_tokens",
],
)
from trl import SFTConfig
training_args = SFTConfig(
output_dir="smolvlm-trl-sft-test", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
num_train_epochs=1, # Set the number of epochs to train the model.
per_device_train_batch_size=2, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
gradient_accumulation_steps=32, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training.
optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance.
save_strategy="epoch", # Save checkpoints at the end of each epoch.
learning_rate=2e-05, # Learning rate for training.
bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations.
push_to_hub=False, # Automatically push the fine-tuned model to Hugging Face Hub after training.
report_to="tensorboard", # Automatically report metrics to tensorboard.
gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues.
dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually.
remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing).
)
from PIL import Image
# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
for msg in messages:
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
for element in content:
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
if "image" in element:
image = element["image"]
else:
image = element
if image is not None:
image = Image.open(io.BytesIO(image["bytes"]))
image_inputs.append(image.convert("RGB"))
return image_inputs
def collate_fn(examples):
texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
if "images" in examples[0]: # single-image
images = [
[img.convert("RGB") for img in example["images"]]
for example in examples
]
else: # multi-image
images = [process_vision_info(example["messages"]) for example in examples]
# Tokenize the texts and process the images
batch = processor(
images=images, text=texts, return_tensors="pt", padding=True
) # Encode texts and images into tensors
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone() # Clone input IDs for labels
# Mask image tokens
image_token_id = getattr(model.config, "image_token_id", None)
if image_token_id is None:
image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
# Mask tokens for not being used in the loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
# labels[labels == 262144] = -100
batch["labels"] = labels
return batch # Return the prepared batch
# Training
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
processing_class=processor,
peft_config=peft_config,
)
trainer.train()
# Save the final model
trainer.save_model()