Trabre / preprocess.py
Allex21's picture
Update preprocess.py
5b0bbfc verified
# preprocess.py
import os
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
def process_dataset(zip_path, output_dir, generate_captions=True):
    """Extract an image dataset archive, downscale its images and caption them.

    The archive at ``zip_path`` is extracted into ``output_dir`` (created if
    missing). Every ``.png``/``.jpg``/``.jpeg`` file found directly inside
    ``output_dir`` (sub-folders are not scanned) is converted to RGB, shrunk
    in place so that it fits within 512x512 pixels, and written back over the
    original file. When ``generate_captions`` is true, a BLIP caption is
    generated for each image and stored next to it in a ``.txt`` file with
    the same base name.

    Args:
        zip_path: Path to the ``.zip`` archive containing the dataset.
        output_dir: Directory that receives the extracted files.
        generate_captions: Whether to write a BLIP caption file per image.

    Returns:
        ``output_dir``, unchanged.
    """
    import zipfile

    os.makedirs(output_dir, exist_ok=True)

    # Unpack the dataset.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

    # Snapshot the image list up front (sorted for a deterministic order) so
    # the caption files written below never influence the iteration.
    image_names = sorted(
        name
        for name in os.listdir(output_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    # Load BLIP (official English Salesforce model) only when captions are
    # actually needed: loading is slow and may trigger a large download.
    processor = None
    model = None
    if generate_captions and image_names:
        processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )

    for img_name in image_names:
        img_path = os.path.join(output_dir, img_name)

        # Close the source file before overwriting it; ``convert`` returns an
        # independent in-memory copy.
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        # Downscale to avoid out-of-memory errors later on.
        image.thumbnail((512, 512), Image.LANCZOS)
        image.save(img_path)  # Overwrite with the resized image.

        if processor is None:
            continue

        inputs = processor(image, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=50)
        caption = processor.decode(outputs[0], skip_special_tokens=True)

        txt_path = os.path.splitext(img_path)[0] + ".txt"
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(caption)

    return output_dir