File size: 1,322 Bytes
80b7188 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | from datasets import load_dataset
from rllm.data.dataset import DatasetRegistry
def prepare_geo3k_data():
# Load dataset
dataset = load_dataset("linxy/LaTeX_OCR")["train"]
dataset = dataset.train_test_split(test_size=500, seed=42)
train_dataset = dataset["train"]
test_dataset = dataset["test"]
def process_fn(example, idx):
prompt = "<image>Convert the image to LaTeX code."
answer = example.pop("text")
image = example.pop("image")
data = {
"data_source": "latex_ocr",
"image": image,
"question": prompt,
"ground_truth": answer,
}
return data
# Preprocess datasets
train_dataset = train_dataset.map(function=process_fn, with_indices=True, num_proc=8)
test_dataset = test_dataset.map(function=process_fn, with_indices=True, num_proc=8)
# Register datasets
train_dataset = DatasetRegistry.register_dataset("latex_ocr", train_dataset, "train")
test_dataset = DatasetRegistry.register_dataset("latex_ocr", test_dataset, "test")
return train_dataset, test_dataset
if __name__ == "__main__":
train_dataset, test_dataset = prepare_geo3k_data()
print(train_dataset.get_data_path())
print(test_dataset.get_data_path())
|