File size: 978 Bytes
ddba888
 
 
 
ee142e9
ddba888
 
ee142e9
 
 
 
ddba888
ee142e9
ddba888
 
 
 
 
 
 
 
 
 
ee142e9
 
ddba888
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from datasets import load_dataset
from config import DATASET_NAME, TRAIN_SPLIT, TEST_SPLIT_RATIO
from core.model import processor
from PIL import Image

def preprocess_batch(batch, max_target_length=128):
    """Turn a batch of raw (image, text) rows into model inputs and labels.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column
    names to lists of values and must contain an ``"image"`` column (objects
    with a ``.convert`` method, e.g. PIL images) and a ``"text"`` column.

    Args:
        batch: Column-name -> list-of-values mapping for one ``map`` chunk.
            It is updated in place with the new columns and returned.
        max_target_length: Length every label sequence is truncated/padded to.

    Returns:
        ``batch`` with added ``"pixel_values"`` (from the processor's image
        processor) and ``"labels"`` (token ids, with padded positions set to
        ``-100``) columns.
    """
    images = [img.convert('RGB') for img in batch["image"]]

    # Pad to a fixed length (not just the longest in this map chunk) so label
    # sequences have the same length across chunks and can be stacked by a
    # plain collator.
    encoded = processor.tokenizer(
        batch['text'],
        padding="max_length",
        max_length=max_target_length,
        truncation=True,
    )
    input_ids = encoded["input_ids"]

    # Padded label positions must be -100 so the cross-entropy loss ignores
    # them; otherwise the model is trained to predict padding. Prefer the
    # attention mask (robust even if pad id == eos id); fall back to
    # comparing against the pad id if the tokenizer did not return one.
    masks = encoded.get("attention_mask")
    if masks is None:
        pad_id = processor.tokenizer.pad_token_id
        masks = [[int(tok != pad_id) for tok in ids] for ids in input_ids]
    labels = [
        [tok if keep else -100 for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, masks)
    ]

    pixel_values = processor.image_processor(images, return_tensors="pt").pixel_values

    batch["pixel_values"] = pixel_values
    batch["labels"] = labels

    return batch

def load(seed=None):
    """Load the configured dataset, split it, and preprocess both halves.

    Loads ``DATASET_NAME`` / ``TRAIN_SPLIT``, carves out a held-out fraction
    of ``TEST_SPLIT_RATIO`` for evaluation, and runs ``preprocess_batch``
    over each half, dropping the original columns.

    Args:
        seed: Optional seed forwarded to ``Dataset.train_test_split`` so the
            train/eval partition is reproducible across runs. ``None`` (the
            default) keeps the previous unseeded behaviour.

    Returns:
        A ``(train_ds, eval_ds)`` tuple of preprocessed datasets.
    """
    dataset = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)
    splits = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=seed)

    def _prepare(ds):
        # Replace the raw columns with the model-ready ones.
        return ds.map(preprocess_batch, batched=True, remove_columns=ds.column_names)

    return _prepare(splits['train']), _prepare(splits['test'])