from datasets import load_dataset dataset_name = "ayoubkirouane/llava-instruct-small" # Load Dataset dataset = load_dataset(dataset_name) # import os # import zipfile # import io # # from datasets import DatasetDict # from huggingface_hub import hf_hub_download, list_repo_files # from PIL import Image # dataset_train_split = "test" # def format_data(samples: dict[str, any]) -> dict[str, list]: # formatted_samples = {"messages": []} # for cont in range(len(samples["question"])): # images = [] # for img_path in samples["input_image_path"][cont]: # try: # with open(img_path, "rb") as f: # img_bytes = f.read() # image = Image.open(io.BytesIO(img_bytes)).convert("RGB") # images.append({"type": "image", "image": image}) # except Exception as e: # print(f"Error processing image {img_path}: {e}") # continue # formatted_samples["messages"].append( # [ # {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]}, # {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]}, # {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]}, # ] # ) # return formatted_samples # For multi-image example # def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict: # all_files = list_repo_files(dataset_name, repo_type="dataset") # zip_files = [f for f in all_files if f.endswith(".zip")] # for zip_filename in zip_files: # zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset") # extract_folder = zip_filename.replace(".zip", "") # os.makedirs(extract_folder, exist_ok=True) # with zipfile.ZipFile(zip_path, "r") as zip_ref: # zip_ref.extractall(extract_folder) # dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16) # return dataset # dataset = prepare_dataset(dataset, dataset_name, dataset_train_split) import torch from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig model_id = "HuggingFaceTB/SmolVLM-256M-Instruct" # BitsAndBytesConfig int-4 config bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_storage=torch.bfloat16, ) # Load model and tokenizer model = AutoModelForImageTextToText.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934) quantization_config=bnb_config ) processor = AutoProcessor.from_pretrained(model_id,use_fast=True) processor.tokenizer.padding_side = "right" from peft import LoraConfig, get_peft_model # Configure QLoRA peft_config = LoraConfig( lora_alpha=16, lora_dropout=0.05, r=16, bias="none", target_modules="all-linear", task_type="CAUSAL_LM", modules_to_save=[ "lm_head", "embed_tokens", ], ) from trl import SFTConfig training_args = SFTConfig( output_dir="smolvlm-trl-sft-test", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets). num_train_epochs=1, # Set the number of epochs to train the model. per_device_train_batch_size=2, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1 gradient_accumulation_steps=32, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1 gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training. optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance. save_strategy="epoch", # Save checkpoints at the end of each epoch. learning_rate=2e-05, # Learning rate for training. bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations. push_to_hub=False, # Automatically push the fine-tuned model to Hugging Face Hub after training. report_to="tensorboard", # Automatically report metrics to tensorboard. gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues. dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually. remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing). ) from PIL import Image # For multi-image cases def process_vision_info(messages: list[dict]) -> list[Image.Image]: image_inputs = [] for msg in messages: content = msg.get("content", []) if not isinstance(content, list): content = [content] for element in content: if isinstance(element, dict) and ("image" in element or element.get("type") == "image"): if "image" in element: image = element["image"] else: image = element if image is not None: image = Image.open(io.BytesIO(image["bytes"])) image_inputs.append(image.convert("RGB")) return image_inputs def collate_fn(examples): texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples] if "images" in examples[0]: # single-image images = [ [img.convert("RGB") for img in example["images"]] for example in examples ] else: # multi-image images = [process_vision_info(example["messages"]) for example in examples] # Tokenize the texts and process the images batch = processor( images=images, text=texts, return_tensors="pt", padding=True ) # Encode texts and images into tensors # The labels are the input_ids, and we mask the padding tokens in the loss computation labels = batch["input_ids"].clone() # Clone input IDs for labels # Mask image tokens image_token_id = getattr(model.config, "image_token_id", None) if image_token_id is None: image_token_id = processor.tokenizer.convert_tokens_to_ids("") # Mask tokens for not being used in the loss computation labels[labels == processor.tokenizer.pad_token_id] = -100 labels[labels == image_token_id] = -100 # labels[labels == 262144] = -100 batch["labels"] = labels return batch # Return the prepared batch # Training from trl import SFTTrainer trainer = SFTTrainer( model=model, args=training_args, data_collator=collate_fn, train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"], processing_class=processor, peft_config=peft_config, ) trainer.train() # Save the final model trainer.save_model()