updated
- .gitignore +2 -0
- .python-version +2 -0
- train.py +75 -0
.gitignore
ADDED
venv
my_blip_computer_thoughts/
.python-version
ADDED
3.9.21
train.py
ADDED
"""
train.py

A complete example of fine-tuning BLIP on 'agentsea/computer-thoughts' for captioning.
All processing happens in the collate function, which keeps things simple and avoids shape mismatches.
"""

import torch
from datasets import load_dataset, Image as HFImage
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    TrainingArguments,
    Trainer,
)

# 1. Load the dataset
dataset = load_dataset("agentsea/computer-thoughts")

# 2. Rename "image_before" -> "image" and cast to HFImage so it decodes to a PIL Image
dataset = dataset.rename_column("image_before", "image")
dataset = dataset.cast_column("image", HFImage())

# 3. Create a small subset for the demo (just 5 examples). Remove this to train on the full data.
train_subset = dataset["train"].select(range(5))

# 4. Load the BLIP image-captioning model (large checkpoint) and its processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# 5. Define a collate_fn that transforms images + text on the fly
def collate_fn(examples):
    # `examples` is a list of dicts, each with keys such as:
    # 'task', 'image', 'image_after', 'action', 'thought', 'bad_thought', 'subtask', 'bad_subtask'.
    # We'll use 'image' (PIL) as the input and 'subtask' (string) as the caption.
    images = [ex["image"] for ex in examples]  # PIL images
    texts = [ex["subtask"] for ex in examples]  # or whichever text column you want

    inputs = processor(images=images, text=texts, return_tensors="pt", padding=True)

    # Add labels so the model can compute a cross-entropy loss.
    # For a basic approach: labels = input_ids
    inputs["labels"] = inputs["input_ids"].clone()

    return inputs

# 6. Define training arguments
training_args = TrainingArguments(
    output_dir="./my_blip_computer_thoughts",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 per device
    logging_steps=5,
    save_steps=20,
    save_total_limit=2,
    remove_unused_columns=False,  # important when custom columns are in the dataset
)

# 7. Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_subset,  # or dataset["train"] for the full set
    data_collator=collate_fn,
)

# 8. Train
trainer.train()

# 9. Push the final model + processor to the Hugging Face Hub
# (Make sure you're logged in: huggingface-cli login)
model.push_to_hub("zeddotes/blip-computer-thoughts")
processor.push_to_hub("zeddotes/blip-computer-thoughts")

print("Done training and pushed model to zeddotes/blip-computer-thoughts!")
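Once the push succeeds, the fine-tuned checkpoint can be loaded straight from the Hub for inference. A minimal sketch (the screenshot filename and generation length here are illustrative, not part of the commit):

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("zeddotes/blip-computer-thoughts")
model = BlipForConditionalGeneration.from_pretrained("zeddotes/blip-computer-thoughts")

# Caption a screenshot the same way the training images were processed
image = Image.open("screenshot.png").convert("RGB")  # hypothetical input file
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True))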