|
|
|
|
|
import os |
|
|
from PIL import Image |
|
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
|
|
|
|
def process_dataset(zip_path, output_dir, generate_captions=True):
    """Extract an image dataset from a zip archive, resize the images in
    place, and optionally write a BLIP-generated caption next to each image.

    Parameters
    ----------
    zip_path : str
        Path to the input ``.zip`` archive containing images.
    output_dir : str
        Directory to extract into; created if it does not exist. Images are
        resized in place and caption ``.txt`` files are written alongside them.
    generate_captions : bool, default True
        When True, load the BLIP captioning model once and write one
        ``<image stem>.txt`` caption file per image.

    Returns
    -------
    str
        The output directory path (same value as ``output_dir``).
    """
    import zipfile  # local import kept to match the original module layout

    os.makedirs(output_dir, exist_ok=True)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

    # Only pay the (large) model download/load cost when captions are
    # actually wanted — the original loaded it unconditionally.
    processor = model = None
    if generate_captions:
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    # sorted() makes processing order deterministic across filesystems.
    # NOTE(review): only top-level files are scanned; images extracted into
    # subdirectories are skipped — confirm whether archives are flat.
    for img_name in sorted(os.listdir(output_dir)):
        # Skip non-image entries (including the caption .txt files we write).
        if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue

        img_path = os.path.join(output_dir, img_name)
        # Context manager releases the file handle promptly; convert("RGB")
        # returns a new in-memory image, so using it after the block is safe.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')

        # thumbnail() resizes in place, preserves aspect ratio, and never
        # upscales — images already within 512x512 keep their size.
        image.thumbnail((512, 512), Image.LANCZOS)
        image.save(img_path)

        if generate_captions:
            inputs = processor(image, return_tensors="pt")
            outputs = model.generate(**inputs, max_new_tokens=50)
            caption = processor.decode(outputs[0], skip_special_tokens=True)

            # Sidecar caption file: same basename, .txt extension.
            txt_path = os.path.splitext(img_path)[0] + ".txt"
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(caption)

    return output_dir