Spaces:
Sleeping
Sleeping
from collections.abc import Iterator
from pathlib import Path
from typing import Literal

import pandas as pd
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
# File types found among GAIA attachments:
#   ['xml', 'jpg', 'pdf', 'xlsx', 'png', 'm4a', 'docx', 'pptx', 'txt', 'csv', 'mp3', 'MOV', 'json']
# Per-level counts, computed over all GAIA files:
#   level 1: {'jpg': 3, 'xlsx': 5, 'png': 2, 'pdf': 3, 'docx': 1, 'txt': 5, 'csv': 3, 'mp3': 2, 'json': 1}
#   level 2: {'xml': 1, 'pdf': 7, 'xlsx': 11, 'png': 4, 'm4a': 1, 'pptx': 1, 'txt': 5, 'mp3': 2}
#   level 3: {'xml': 1, 'jpg': 2, 'pdf': 2, 'png': 4, 'txt': 2, 'csv': 2, 'MOV': 1}

# GAIA configuration names accepted by GAIA.get_loader (passed to load_dataset).
AVAILABLE_SETS = ['2023_all', '2023_level1', '2023_level2', '2023_level3']
# Split names accepted by GAIA.get_loader.
SPLITS = ["test", "validation"]
class GAIA:
    """Wrapper around one configuration and split of the GAIA benchmark.

    The data is fetched with ``datasets.load_dataset`` (see ``get_loader``),
    converted to a pandas DataFrame and extended with a ``FileType`` column
    holding the extension of each row's attached file (``None`` when the row
    has no file name). Assigning a new ``dset_string`` or ``split`` reloads
    the data.
    """

    def __init__(self, dset_string, split):
        """Load ``split`` of the GAIA configuration ``dset_string``.

        Args:
            dset_string: a configuration name from ``AVAILABLE_SETS``.
            split: a split name from ``SPLITS``.

        Raises:
            ValueError: if ``dset_string`` or ``split`` is not recognized.
        """
        self._reload(dset_string, split)

    @staticmethod
    def _file_type(file_name):
        """Return the extension of ``file_name`` without the dot, or None if there is no file."""
        # Rows without an attachment presumably carry an empty/missing file_name.
        # The isinstance check also rejects NaN, which Path() cannot handle.
        if not isinstance(file_name, str) or not file_name:
            return None
        return Path(file_name).suffix[1:]

    @classmethod
    def _with_filetype(cls, dataset: pd.DataFrame) -> pd.DataFrame:
        """Add (or overwrite) the ``FileType`` column of ``dataset`` in place and return it."""
        dataset["FileType"] = dataset["file_name"].apply(cls._file_type)
        return dataset

    def _add_filetype(self):
        """Add (or overwrite) the ``FileType`` column of the currently loaded DataFrame."""
        self._with_filetype(self._dataset)

    def _reload(self, dset_string, split):
        """Load ``split`` of ``dset_string`` and, only on success, make it the current data."""
        # Everything that can fail (argument validation in get_loader, the
        # download, building FileType) runs before any attribute is assigned,
        # so a failed reload leaves the object exactly as it was.
        dataset = self._with_filetype(self.get_loader(dset_string, split).to_pandas())
        self._dset_string = dset_string
        self._split = split
        self._dataset: pd.DataFrame = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """The loaded split as a DataFrame (the live object, not a copy)."""
        return self._dataset

    @property
    def dset_string(self) -> str:
        """Name of the loaded GAIA configuration; assigning a new one reloads the data."""
        return self._dset_string

    @dset_string.setter
    def dset_string(self, value: str):
        self._reload(value, self._split)

    @property
    def split(self) -> str:
        """Name of the loaded split; assigning a new one reloads the data."""
        return self._split

    @split.setter
    def split(self, value: str):
        self._reload(self._dset_string, value)

    def __iter__(self) -> Iterator[tuple[int, pd.Series]]:
        """Yield ``(index_label, row)`` pairs exactly as ``DataFrame.iterrows`` does."""
        # The labels come from the DataFrame index; they are ints when
        # to_pandas() produces a default RangeIndex (assumed here).
        yield from self._dataset.iterrows()

    def num_rows(self) -> int:
        """Number of rows in the loaded split."""
        return self._dataset.shape[0]

    def num_cols(self) -> int:
        """Number of columns, including the added ``FileType`` column."""
        return self._dataset.shape[-1]

    def cols(self) -> list[str]:
        """Column names of the loaded split, in DataFrame order."""
        return self._dataset.columns.tolist()

    def get_loader(
        self,
        dset: Literal['2023_all', '2023_level1', '2023_level2', '2023_level3'] = '2023_level1',
        split: Literal['test', 'validation'] = 'test',
    ):
        """Return the Hugging Face dataset object for one GAIA configuration and split.

        Raises:
            ValueError: if ``dset`` is not in ``AVAILABLE_SETS`` or ``split`` is
                not in ``SPLITS``.
        """
        if dset not in AVAILABLE_SETS:
            raise ValueError(f"Dataset {dset} not available. Available datasets are: {AVAILABLE_SETS}")
        if split not in SPLITS:
            raise ValueError(f"Split {split} not available. Available splits are: {SPLITS}")
        # NOTE(review): trust_remote_code=True allows a loading script hosted in
        # the Hub repository to run locally; keep it only for trusted repos.
        dataset: DatasetDict = load_dataset("gaia-benchmark/GAIA", dset, trust_remote_code=True)
        return dataset[split]
def visualize_with_streamlit():
    """Render a Streamlit page for browsing one GAIA configuration and split.

    Intended to be launched with ``streamlit run``. Streamlit re-executes the
    script on every widget interaction, so the loaded dataset is cached per
    (configuration, split) pair instead of being reloaded each time.
    """
    # Imported here so the GAIA class can be used without Streamlit installed.
    import streamlit as st

    @st.cache_resource
    def load_gaia(dset_string: str, split: str) -> GAIA:
        # One shared GAIA instance per (dset_string, split); callers must not mutate it.
        return GAIA(dset_string, split)

    st.set_page_config(
        page_title="GAIA Dataset Viewer",
        page_icon=":bar_chart:",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("GAIA Dataset Viewer")
    st.write("This is a simple Streamlit app to visualize the GAIA dataset.")
    dset_string = st.selectbox("Select the dataset string", AVAILABLE_SETS)
    split = st.selectbox("Select the dataset split", SPLITS)

    # Load the dataset
    gaia = load_gaia(dset_string, split)
    n_rows = gaia.num_rows()

    # Select the columns to display (all of them by default)
    all_columns = gaia.cols()
    columns = st.multiselect("Select the columns to display", all_columns, default=all_columns)

    # Display the dataset
    st.dataframe(gaia.dataset[columns])
    st.write(f"Number of rows: {n_rows}")
    st.write(f"Number of columns: {gaia.num_cols()}")

    def get_random_row(columns: list[str]):
        """Show the row at a user-chosen positional index (despite the name, not random)."""
        # number_input rejects max_value < min_value, which an empty split would produce.
        if n_rows == 0:
            st.write("The selected dataset has no rows.")
            return
        row_idx = st.number_input("Select the row index", min_value=0, max_value=n_rows - 1, value=0)
        row = gaia.dataset.iloc[row_idx][columns]
        st.write(f"Row {row_idx}:")
        st.write(row)

    def row_iterator_trial(columns: list[str], limit=-1):
        """Write rows through ``GAIA.__iter__``, stopping after ``limit`` rows (negative = all)."""
        for shown, (row_idx, row) in enumerate(gaia):
            # limit=0 shows nothing; only a negative limit means "no limit".
            if 0 <= limit <= shown:
                break
            st.write(f"Row {row_idx}:")
            st.write(row[columns])

    if st.toggle('Try random row getter'):
        get_random_row(columns)

    if st.toggle('Try row iterator'):
        # Clamp the default so it never exceeds max_value on small splits.
        limit_rows_for_iterator = st.number_input(
            "Select the limit for the iterator",
            min_value=-1,
            max_value=n_rows,
            value=min(10, n_rows),
        )
        row_iterator_trial(columns, limit_rows_for_iterator)


if __name__ == "__main__":
    visualize_with_streamlit()