Spaces:
Sleeping
Sleeping
File size: 978 Bytes
ddba888 ee142e9 ddba888 ee142e9 ddba888 ee142e9 ddba888 ee142e9 ddba888 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from datasets import load_dataset
from config import DATASET_NAME, TRAIN_SPLIT, TEST_SPLIT_RATIO
from core.model import processor
from PIL import Image
def preprocess_batch(batch, *, max_target_length=128):
    """Turn a batch of raw image/text examples into model inputs and labels.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column
    names to lists of values and must contain ``"image"`` (objects with a
    ``.convert`` method, presumably PIL images) and ``"text"`` (strings).

    Args:
        batch: Batched examples; mutated in place and returned.
        max_target_length: Length every label sequence is padded or
            truncated to, so all examples share one label width.

    Returns:
        ``batch`` with ``"pixel_values"`` (output of the processor's image
        processor) and ``"labels"`` (token ids, with padding positions set
        to -100) added. Code that decodes these labels must map -100 back
        to the tokenizer's pad id first.
    """
    # Force 3-channel input so grayscale/RGBA images are handled uniformly.
    images = [img.convert("RGB") for img in batch["image"]]
    pixel_values = processor.image_processor(images, return_tensors="pt").pixel_values

    tokenizer = processor.tokenizer
    # Pad to a fixed length rather than to the longest text in this map()
    # chunk: with padding=True each chunk got its own label width, so
    # examples from different chunks could not be stacked into one training
    # batch by a collator that does not itself pad labels.
    label_ids = tokenizer(
        batch["text"],
        padding="max_length",
        max_length=max_target_length,
        truncation=True,
    ).input_ids

    # Replace padding with -100 (PyTorch CrossEntropyLoss's default
    # ignore_index) so padded positions do not contribute to the loss.
    # NOTE(review): if the tokenizer reuses its EOS id as the pad id, this
    # also masks real EOS tokens — confirm for the processor in core.model.
    pad_id = tokenizer.pad_token_id
    labels = [[tok if tok != pad_id else -100 for tok in seq] for seq in label_ids]

    batch["pixel_values"] = pixel_values
    batch["labels"] = labels
    return batch
def load(*, seed=42):
    """Load the configured dataset, split it, and preprocess both parts.

    Loads split ``TRAIN_SPLIT`` of ``DATASET_NAME``, holds out a
    ``TEST_SPLIT_RATIO`` fraction for evaluation, and runs
    ``preprocess_batch`` over each part in batched mode.

    Args:
        seed: Seed for the train/eval shuffle. It is fixed by default so
            repeated calls (e.g. when resuming training) yield the same
            held-out set. Pass ``None`` for a fresh random split on each call.

    Returns:
        ``(train_ds, eval_ds)``: the processed training and evaluation
        datasets.
    """
    dataset = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)
    # Without a seed, train_test_split shuffles differently on every call, so
    # examples trained on in one run could end up in eval in the next.
    splits = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=seed)

    def _prepare(ds):
        # Drop every original column (e.g. "image", "text") so only the
        # fields added by preprocess_batch remain.
        return ds.map(preprocess_batch, batched=True, remove_columns=ds.column_names)

    return _prepare(splits["train"]), _prepare(splits["test"])
|