File size: 3,625 Bytes
fcd12e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
from datasets import load_dataset, load_metric
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    TrainingArguments,
    Trainer,
)
from PIL import Image

# --- 1. CONFIGURATION ---
MODEL_NAME = "Salesforce/blip-image-captioning-base"
DATASET_ID = "lambdalabs/pokemon-blip-captions" # Replace with COCO or your specialized dataset
OUTPUT_DIR = "./blip-image-captioning-finetuned"
NUM_TRAIN_EPOCHS = 3
BATCH_SIZE = 16
# Fixed seed for the train/validation split so that every run evaluates on
# the same held-out examples and eval losses are comparable across runs.
SPLIT_SEED = 42

# --- 2. LOAD PROCESSOR AND MODEL ---
# The processor handles both image feature extraction and text tokenization
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME)

# --- 3. LOAD & PREPARE DATASET ---
print(f"Loading dataset: {DATASET_ID}")
ds = load_dataset(DATASET_ID)
# Only the 'train' split is used; 10% of it is held out as a validation set.
# Passing a seed makes the shuffle (and therefore the split) reproducible.
ds = ds['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_ds = ds['train']
eval_ds = ds['test']

# Maximum caption length in tokens; captions are padded/truncated to this size
max_caption_length = 50

def preprocess_data(examples):
    """Turn a batch of raw examples into BLIP training inputs.

    Args:
        examples: Batch dict with an "image" entry (a list of objects that
            support ``convert("RGB")``, presumably PIL images) and a "text"
            entry (the matching list of caption strings).

    Returns:
        The processor output extended with a ``labels`` tensor. It keeps
        ``pixel_values``, ``input_ids`` and ``attention_mask`` because the
        model needs the caption tokens as decoder input during training.
    """
    # The processor resizes/normalizes the images and tokenizes the captions,
    # padding or truncating every caption to ``max_caption_length`` tokens.
    inputs = processor(
        images=[image.convert("RGB") for image in examples["image"]],
        text=examples["text"],
        padding="max_length",
        max_length=max_caption_length,
        truncation=True,
        return_tensors="pt",
    )

    # Labels are a copy of the caption token ids (the model shifts them
    # internally for next-token prediction). Padding positions are set to
    # -100, the index ignored by the cross-entropy loss, so the model is not
    # trained to predict pad tokens. ``masked_fill`` returns a new tensor, so
    # ``input_ids`` itself is left untouched.
    inputs["labels"] = inputs["input_ids"].masked_fill(
        inputs["attention_mask"] == 0, -100
    )

    # ``input_ids`` and ``attention_mask`` are intentionally kept: without
    # them the text decoder has nothing to condition on and the forward pass
    # fails.
    return inputs

# Apply the preprocessing function to the dataset
print("Applying preprocessing to the dataset...")
# set_transform is lazy: the function runs on-the-fly each time rows are
# accessed, so nothing is precomputed or cached on disk.
train_ds.set_transform(preprocess_data)
eval_ds.set_transform(preprocess_data)


# --- 4. TRAINING SETUP (Trainer API) ---
# Captioning quality is usually measured with BLEU, ROUGE, METEOR or CIDEr.
# Note: For simplicity, this basic script skips metric computation and
# relies on the evaluation loss only.

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    save_strategy="epoch",
    load_best_model_at_end=True,
    # The datasets use set_transform, so the raw "image"/"text" columns must
    # survive until the transform runs. By default the Trainer drops every
    # column that is not a model forward() argument, which would leave the
    # transform with nothing to process.
    remove_unused_columns=False,
    fp16=torch.cuda.is_available(), # Use mixed precision if a GPU is available
    push_to_hub=True, # Uploads to the Hugging Face Hub; set to False to keep the model local
    hub_model_id=f"YOUR_HUGGINGFACE_USERNAME/blip-finetuned-{DATASET_ID.split('/')[-1]}", # Customize this
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=processor.tokenizer, # Pass the tokenizer for the Trainer to use
)

# --- 5. START TRAINING ---
print("Starting training...")
trainer.train()
print("Training complete! Pushing model to Hub...")

# --- 6. SAVE & PUSH TO HUB ---
trainer.save_model(OUTPUT_DIR)
# Because push_to_hub=True is set in TrainingArguments, the Trainer takes care
# of uploading the saved model; no explicit push call is made here.

# You will need to log in to your Hugging Face account via the command line
# (huggingface-cli login) or in a notebook (from huggingface_hub import notebook_login; notebook_login()).