# train.py — uploaded by fangedvampire24 (commit fcd12e4, "Create train.py", 3.63 kB).
# Fine-tunes Salesforce BLIP for image captioning with the Hugging Face Trainer.
import torch
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    TrainingArguments,
    Trainer,
)
from PIL import Image

# NOTE: `load_metric` used to be imported from `datasets` here. It was removed
# in datasets 3.0 (metrics moved to the separate `evaluate` library) and this
# script never used it, so the import only caused an ImportError.

# --- 1. CONFIGURATION ---
MODEL_NAME = "Salesforce/blip-image-captioning-base"
DATASET_ID = "lambdalabs/pokemon-blip-captions"  # Replace with COCO or your specialized dataset
OUTPUT_DIR = "./blip-image-captioning-finetuned"
NUM_TRAIN_EPOCHS = 3
BATCH_SIZE = 16
# Fixed seed so the train/validation split is the same on every run.
SPLIT_SEED = 42

# --- 2. LOAD PROCESSOR AND MODEL ---
# The processor bundles image preprocessing (resize/normalize) with the text
# tokenizer, so a single call prepares both modalities.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME)

# --- 3. LOAD & PREPARE DATASET ---
print(f"Loading dataset: {DATASET_ID}")
ds = load_dataset(DATASET_ID)
# Only the 'train' split is used; 10% of it is held out as a validation set.
ds = ds["train"].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_ds = ds["train"]
eval_ds = ds["test"]

# Captions are padded/truncated to exactly this many tokens.
max_caption_length = 50
def preprocess_data(examples):
    """Convert a batch of raw dataset rows into BLIP training tensors.

    Intended as a ``datasets`` on-the-fly transform (it is registered with
    ``set_transform`` in this script), so it receives a batch-shaped dict:
    ``examples["image"]`` is a list of images (objects with a PIL-style
    ``.convert`` method) and ``examples["text"]`` the matching captions.

    Returns:
        The processor output with ``pixel_values``, ``input_ids`` and
        ``attention_mask``, plus ``labels``: a copy of ``input_ids`` in which
        padding positions are set to -100 (the index PyTorch's
        ``CrossEntropyLoss`` ignores by default).
    """
    # One call resizes/normalizes the images and tokenizes the captions to a
    # fixed length, returning PyTorch tensors.
    inputs = processor(
        images=[image.convert("RGB") for image in examples["image"]],
        text=examples["text"],
        padding="max_length",
        max_length=max_caption_length,
        truncation=True,
        return_tensors="pt",
    )
    # The text decoder consumes input_ids/attention_mask as its input, so both
    # must be kept. Labels are an unshifted copy of the ids: the BLIP
    # captioning model applies the next-token shift itself.
    # clone() is required so that masking the labels leaves input_ids intact.
    labels = inputs["input_ids"].clone()
    # Padding added by padding="max_length" must not contribute to the loss.
    labels[inputs["attention_mask"] == 0] = -100
    inputs["labels"] = labels
    return inputs
# Apply the preprocessing function to the dataset
print("Applying preprocessing to the dataset...")
# set_transform runs preprocess_data lazily on each batch of rows as it is read,
# instead of materializing all processed tensors up front.
train_ds.set_transform(preprocess_data)
eval_ds.set_transform(preprocess_data)

# --- 4. TRAINING SETUP (Trainer API) ---
# Captioning is usually scored with BLEU, METEOR, ROUGE-L or CIDEr. For
# simplicity this script only tracks the validation loss.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    # `eval_strategy` replaces `evaluation_strategy`, which current
    # transformers releases no longer accept.
    eval_strategy="epoch",
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    # Keep the raw "image"/"text" columns. They are not model.forward
    # arguments, so by default the Trainer would drop them and the
    # set_transform above would have nothing to read.
    remove_unused_columns=False,
    fp16=torch.cuda.is_available(),  # Use mixed precision if a GPU is available
    push_to_hub=True,  # Set to False to train locally without uploading
    # A bare repo name is created under the logged-in user's namespace; use
    # "user_or_org/repo" to push somewhere else.
    hub_model_id=f"blip-finetuned-{DATASET_ID.split('/')[-1]}",
)

# No tokenizer is passed. preprocess_data already pads every example to a
# fixed length, so the default collator only needs to stack tensors. The full
# processor is saved explicitly in step 6.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

# --- 5. START TRAINING ---
print("Starting training...")
trainer.train()
print("Training complete! Pushing model to Hub...")

# --- 6. SAVE & PUSH TO HUB ---
# Save the processor (image processor + tokenizer) next to the weights so the
# output folder, and the Hub repo, can be reloaded with AutoProcessor.
processor.save_pretrained(OUTPUT_DIR)
# With push_to_hub=True in TrainingArguments, save_model also uploads the
# output directory to the Hub.
trainer.save_model(OUTPUT_DIR)
# Pushing requires authentication: run `huggingface-cli login` in a shell, or in
# a notebook: from huggingface_hub import notebook_login; notebook_login().