HFAgentsCourse / loader.py
nicolacaione's picture
removed langfuse
7e01be5
from collections.abc import Iterator
from pathlib import Path
from typing import Literal

import pandas as pd
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
'''
File types in GAIA:
['xml', 'jpg', 'pdf', 'xlsx', 'png', 'm4a', 'docx', 'pptx', 'txt', 'csv', 'mp3', 'MOV', 'json']
Counts computed over all GAIA files:
>>> lvl_1
{'jpg': 3, 'xlsx': 5, 'png': 2, 'pdf': 3, 'docx': 1, 'txt': 5, 'csv': 3, 'mp3': 2, 'json': 1}
>>> lvl_2
{'xml': 1, 'pdf': 7, 'xlsx': 11, 'png': 4, 'm4a': 1, 'pptx': 1, 'txt': 5, 'mp3': 2}
>>> lvl_3
{'xml': 1, 'jpg': 2, 'pdf': 2, 'png': 4, 'txt': 2, 'csv': 2, 'MOV': 1}
'''
# Configuration names accepted by `GAIA.get_loader`.
AVAILABLE_SETS = ['2023_all', '2023_level1', '2023_level2', '2023_level3']
# Split names accepted by `GAIA.get_loader`.
SPLITS = ["test", "validation"]


class GAIA:
    """Pandas-backed view of one configuration/split of the GAIA benchmark.

    The requested split is fetched through `get_loader`, converted to a
    `pandas.DataFrame`, and extended with a ``FileType`` column derived from
    the ``file_name`` column. Assigning to `dset_string` or `split` reloads
    the data for the new selection.
    """

    def __init__(self, dset_string, split):
        """Load the given configuration and split.

        Args:
            dset_string: one of ``AVAILABLE_SETS``.
            split: one of ``SPLITS``.

        Raises:
            ValueError: if either argument is not an accepted value.
        """
        self._dset_string = dset_string
        self._split = split
        self._dataset: pd.DataFrame
        self._reload(dset_string, split)

    @staticmethod
    def _file_extension(file_name):
        """Return the suffix of ``file_name`` without the dot, or None if absent."""
        # None, NaN and "" all mean "no attachment"; NaN must be caught
        # explicitly because it is truthy and Path() rejects floats.
        if pd.isna(file_name) or not file_name:
            return None
        return Path(file_name).suffix[1:]

    def _add_filetype(self, frame=None):
        """Add the ``FileType`` column to ``frame`` (default: the current dataset)."""
        target = self._dataset if frame is None else frame
        target["FileType"] = target["file_name"].apply(self._file_extension)

    def _reload(self, dset_string, split):
        """Load ``dset_string``/``split`` and commit it as the current state.

        Attributes are only updated after loading succeeded, so a rejected
        value leaves the previous configuration, split and DataFrame in place.
        """
        frame = self.get_loader(dset_string, split).to_pandas()
        self._add_filetype(frame)
        self._dset_string = dset_string
        self._split = split
        self._dataset = frame

    @property
    def dataset(self):
        """The loaded data as a DataFrame, including the ``FileType`` column."""
        return self._dataset

    @property
    def dset_string(self):
        """Name of the currently loaded configuration."""
        return self._dset_string

    @dset_string.setter
    def dset_string(self, value: str):
        self._reload(value, self._split)

    @property
    def split(self):
        """Name of the currently loaded split."""
        return self._split

    @split.setter
    def split(self, value: str):
        self._reload(self._dset_string, value)

    def __iter__(self) -> Iterator[tuple[int, pd.Series]]:
        """Yield ``(index_label, row)`` pairs in DataFrame order."""
        yield from self._dataset.iterrows()

    def num_rows(self) -> int:
        """Number of rows in the loaded DataFrame."""
        return self._dataset.shape[0]

    def num_cols(self) -> int:
        """Number of columns in the loaded DataFrame."""
        return self._dataset.shape[-1]

    def cols(self) -> list[str]:
        """Column names of the loaded DataFrame."""
        return self._dataset.columns.tolist()

    def get_loader(
        self,
        dset: Literal['2023_all', '2023_level1', '2023_level2', '2023_level3'] = '2023_level1',
        split: Literal['test', 'validation'] = 'test',
    ):
        """Return the requested split of ``gaia-benchmark/GAIA``.

        Args:
            dset: configuration name; must be in ``AVAILABLE_SETS``.
            split: split name; must be in ``SPLITS``.

        Returns:
            The ``split`` entry of the object returned by ``load_dataset``.

        Raises:
            ValueError: if ``dset`` or ``split`` is not an accepted value.
        """
        # Validate before touching the hub so bad input fails fast.
        if dset not in AVAILABLE_SETS:
            raise ValueError(f"Dataset {dset} not available. Available datasets are: {AVAILABLE_SETS}")
        if split not in SPLITS:
            raise ValueError(f"Split {split} not available. Available splits are: {SPLITS}")
        # NOTE(review): trust_remote_code=True opts in to running code shipped
        # with the dataset repository -- confirm this is still required and
        # supported by the installed `datasets` version.
        dataset: DatasetDict = load_dataset("gaia-benchmark/GAIA", dset, trust_remote_code=True)
        return dataset[split]
def visualize_with_streamlit():
    """Render a Streamlit page for browsing a GAIA configuration and split.

    The page lets the user choose a configuration, a split and the columns to
    show, displays the resulting table with its row and column counts, and
    offers two optional demos: looking up a single row by position and
    walking the rows through the ``GAIA`` iterator.
    """
    import streamlit as st

    st.set_page_config(
        page_title="GAIA Dataset Viewer",
        page_icon=":bar_chart:",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("GAIA Dataset Viewer")
    st.write("This is a simple Streamlit app to visualize the GAIA dataset.")

    # Pick what to load, then load it.
    dset_string = st.selectbox("Select the dataset string", AVAILABLE_SETS)
    split = st.selectbox("Select the dataset split", SPLITS)
    gaia = GAIA(dset_string, split)

    # Every column is pre-selected; the user can narrow the view down.
    gaia_columns = gaia.cols()
    columns = st.multiselect("Select the columns to display", gaia_columns, default=gaia_columns)

    # Table plus its dimensions.
    st.dataframe(gaia.dataset[columns])
    st.write(f"Number of rows: {gaia.num_rows()}")
    st.write(f"Number of columns: {gaia.num_cols()}")

    # Demo 1: positional lookup of one row, restricted to the chosen columns.
    if st.toggle('Try random row getter'):
        row_idx = st.number_input("Select the row index", min_value=0, max_value=gaia.num_rows()-1, value=0)
        row = gaia.dataset.iloc[row_idx][columns]
        st.write(f"Row {row_idx}:")
        st.write(row)

    # Demo 2: iterate the GAIA object; a limit of 0 or -1 means "no cap".
    if st.toggle('Try row iterator'):
        limit = st.number_input("Select the limit for the iterator", min_value=-1, max_value=gaia.num_rows(), value=10)
        for row_idx, row in gaia:
            if limit > 0 and row_idx >= limit:
                break
            st.write(f"Row {row_idx}:")
            st.write(row[columns])
# Entry point: build the viewer page only when this file is run as a script,
# not when it is imported for the GAIA class.
if __name__ == "__main__":
    visualize_with_streamlit()