| import torch |
| from datasets import load_dataset, load_metric |
| from transformers import ( |
| AutoProcessor, |
| BlipForConditionalGeneration, |
| TrainingArguments, |
| Trainer, |
| ) |
| from PIL import Image |
|
|
| |
# --- Configuration -----------------------------------------------------------
MODEL_NAME = "Salesforce/blip-image-captioning-base"
DATASET_ID = "lambdalabs/pokemon-blip-captions"
OUTPUT_DIR = "./blip-image-captioning-finetuned"
NUM_TRAIN_EPOCHS = 3
BATCH_SIZE = 16
# Fixed seed so the train/eval split is the same on every run; without it the
# held-out set changes between runs and eval losses are not comparable.
SPLIT_SEED = 42

# --- Model and processor -----------------------------------------------------
# Both are loaded from the same pretrained checkpoint name.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME)

# --- Dataset -----------------------------------------------------------------
print(f"Loading dataset: {DATASET_ID}")
ds = load_dataset(DATASET_ID)
# Only the 'train' split is read here; 10% of it is held out for evaluation.
ds = ds['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_ds = ds['train']
eval_ds = ds['test']

# Captions are padded/truncated to exactly this many tokens in preprocessing.
max_caption_length = 50
|
|
def preprocess_data(examples):
    """Convert a batch of raw examples into BLIP training inputs.

    Args:
        examples: A batch mapping with an ``"image"`` list (objects exposing
            ``convert("RGB")``) and a parallel ``"text"`` list of captions.

    Returns:
        The processor output for the batch (PyTorch tensors) with a
        ``"labels"`` entry added. ``"input_ids"`` and ``"attention_mask"``
        are kept: the BLIP text decoder consumes ``input_ids`` to produce the
        logits that ``labels`` are scored against, so removing them makes the
        forward pass fail.
    """
    rgb_images = [image.convert("RGB") for image in examples["image"]]
    inputs = processor(
        images=rgb_images,
        text=examples["text"],
        padding="max_length",
        max_length=max_caption_length,
        truncation=True,
        return_tensors="pt",
    )

    # Labels are an independent copy of the caption token ids (not an alias),
    # with padded positions set to -100 -- the default ignore_index of
    # torch's CrossEntropyLoss -- so padding does not contribute to the loss.
    labels = inputs["input_ids"].clone()
    labels[inputs["attention_mask"] == 0] = -100
    inputs["labels"] = labels

    return inputs
|
|
| |
print("Applying preprocessing to the dataset...")
# Register the same preprocessing function on both splits; per the datasets
# documentation, set_transform runs it when rows are accessed rather than
# materialising a processed copy up front.
for _split in (train_ds, eval_ds):
    _split.set_transform(preprocess_data)
|
|
|
|
| |
| |
| |
|
|
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    # Evaluation and checkpointing share the same schedule, which
    # load_best_model_at_end requires.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
    # The datasets expose only the raw "image"/"text" columns and build the
    # model inputs lazily via set_transform. With the default (True) the
    # Trainer would discard those columns because they are not arguments of
    # the model's forward(), leaving the transform nothing to work on.
    remove_unused_columns=False,
    push_to_hub=True,
    # NOTE: placeholder namespace -- replace YOUR_HUGGINGFACE_USERNAME with a
    # real account/organisation before running, or the Hub push will fail.
    hub_model_id=f"YOUR_HUGGINGFACE_USERNAME/blip-finetuned-{DATASET_ID.split('/')[-1]}",
)
|
|
def _stack_features(features):
    """Collate a list of per-example tensor dicts into one dict of batch tensors.

    Assumes every example carries the same keys and that each value is a
    tensor of identical shape across examples (captions are padded to a
    fixed length upstream).
    """
    return {
        name: torch.stack([feature[name] for feature in features])
        for name in features[0]
    }


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # Explicit collator: when a tokenizer is supplied and no collator is
    # given, Trainer falls back to DataCollatorWithPadding, which routes
    # every feature (including image tensors) through tokenizer.pad and
    # requires "input_ids" to be present. Everything here is already
    # fixed-size, so plain stacking is all that is needed.
    data_collator=_stack_features,
    # Passed so the tokenizer files are saved alongside checkpoints.
    tokenizer=processor.tokenizer,
)
|
|
| |
print("Starting training...")
trainer.train()
print("Training complete! Pushing model to Hub...")

# Only the tokenizer half of the processor is attached to the Trainer, so the
# image-processor configuration would otherwise be missing from OUTPUT_DIR.
# Writing the full processor first lets the directory be reloaded with
# AutoProcessor.from_pretrained(OUTPUT_DIR).
processor.save_pretrained(OUTPUT_DIR)

# Writes the final model to OUTPUT_DIR.
# NOTE(review): with push_to_hub=True this call is expected to trigger the
# Hub upload as well -- confirm against the installed transformers version.
trainer.save_model(OUTPUT_DIR)
| |
|
|
| |
| |