| import functools |
| import hashlib |
| import logging |
| import os |
| import re |
| import shutil |
| import time |
| import traceback |
| from typing import List, Optional |
|
|
| import pandas as pd |
| from h2o_wave import Q, ui |
| from h2o_wave.types import FormCard, ImageCard, MarkupCard, StatListItem, Tab |
|
|
| from llm_studio.app_utils.config import default_cfg |
| from llm_studio.app_utils.db import Dataset |
| from llm_studio.app_utils.sections.common import clean_dashboard |
| from llm_studio.app_utils.sections.experiment import experiment_start |
| from llm_studio.app_utils.sections.histogram_card import histogram_card |
| from llm_studio.app_utils.utils import ( |
| add_model_type, |
| azure_download, |
| azure_file_options, |
| check_valid_upload_content, |
| clean_error, |
| dir_file_table, |
| get_data_dir, |
| get_dataset_elements, |
| get_datasets, |
| get_experiments_status, |
| get_frame_stats, |
| get_model_types, |
| get_problem_types, |
| get_unique_dataset_name, |
| kaggle_download, |
| local_download, |
| make_label, |
| parse_ui_elements, |
| remove_temp_files, |
| s3_download, |
| s3_file_options, |
| ) |
| from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df |
| from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains |
| from llm_studio.src.utils.config_utils import ( |
| load_config_py, |
| load_config_yaml, |
| save_config_yaml, |
| ) |
| from llm_studio.src.utils.data_utils import ( |
| get_fill_columns, |
| read_dataframe, |
| read_dataframe_drop_missing_labels, |
| sanity_check, |
| ) |
| from llm_studio.src.utils.plot_utils import PlotData |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def file_extension_is_compatible(q: Q) -> bool: |
| """Check whether the selected train/validation files have an allowed extension.""" |
| cfg = q.client["dataset/import/cfg"] |
| allowed_extensions = cfg.dataset._allowed_file_extensions |
|
|
| is_correct_extension = [] |
| for mode in ["train", "validation"]: |
| dataset_name = q.client[f"dataset/import/cfg/{mode}_dataframe"] |
|
|
| if dataset_name is None or dataset_name == "None": |
| continue |
| # str.endswith expects a str or tuple; cast defensively in case a list is configured |
| is_correct_extension.append(dataset_name.endswith(tuple(allowed_extensions))) |
| return all(is_correct_extension) |
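|
| # Illustrative example (hypothetical values): with |
| # cfg.dataset._allowed_file_extensions == (".csv", ".pq") and a selected |
| # train dataframe "train.csv", this returns True; "train.txt" would not. |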
|
|
|
|
| async def dataset_import( |
| q: Q, |
| step: int, |
| edit: bool = False, |
| error: str = "", |
| warning: str = "", |
| info: str = "", |
| allow_merge: bool = True, |
| ) -> None: |
| """Display dataset import cards. |
| |
| Args: |
| q: Q |
| step: current step of wizard |
| edit: whether in edit mode |
| error: optional error message |
| warning: optional warning message |
| info: optional info message |
| allow_merge: whether to allow merging dataset when importing |
| """ |
|
|
| await clean_dashboard(q, mode="full") |
| q.client["nav/active"] = "dataset/import" |
| if step == 1: |
| q.page["dataset/import"] = ui.form_card(box="content", items=[]) |
| q.client.delete_cards.add("dataset/import") |
|
|
| if q.client["dataset/import/source"] is None: |
| q.client["dataset/import/source"] = "Upload" |
|
|
| import_choices = [ |
| ui.choice("Upload", "Upload"), |
| ui.choice("Local", "Local"), |
| ui.choice("S3", "AWS S3"), |
| ui.choice("Azure", "Azure Datalake"), |
| ui.choice("Kaggle", "Kaggle"), |
| ] |
|
|
| items = [ |
| ui.text_l("Import dataset"), |
| ui.dropdown( |
| name="dataset/import/source", |
| label="Source", |
| value=q.client["dataset/import/source"], |
| choices=import_choices, |
| trigger=True, |
| tooltip="Source of dataset import", |
| ), |
| ] |
|
|
| if q.client["dataset/import/source"] == "S3": |
| if q.client["dataset/import/s3_bucket"] is None: |
| q.client["dataset/import/s3_bucket"] = q.client[ |
| "default_aws_bucket_name" |
| ] |
| if q.client["dataset/import/s3_access_key"] is None: |
| q.client["dataset/import/s3_access_key"] = q.client[ |
| "default_aws_access_key" |
| ] |
| if q.client["dataset/import/s3_secret_key"] is None: |
| q.client["dataset/import/s3_secret_key"] = q.client[ |
| "default_aws_secret_key" |
| ] |
|
|
| files = s3_file_options( |
| q.client["dataset/import/s3_bucket"], |
| q.client["dataset/import/s3_access_key"], |
| q.client["dataset/import/s3_secret_key"], |
| ) |
|
|
| if not files: |
| ui_filename = ui.textbox( |
| name="dataset/import/s3_filename", |
| label="File name", |
| value="", |
| required=True, |
| tooltip="File name to be imported", |
| ) |
| else: |
| if default_cfg.s3_filename in files: |
| default_file = default_cfg.s3_filename |
| else: |
| default_file = files[0] |
| ui_filename = ui.dropdown( |
| name="dataset/import/s3_filename", |
| label="File name", |
| value=default_file, |
| choices=[ui.choice(x, x.split("/")[-1]) for x in files], |
| required=True, |
| tooltip="File name to be imported", |
| ) |
|
|
| items += [ |
| ui.textbox( |
| name="dataset/import/s3_bucket", |
| label="S3 bucket name", |
| value=q.client["dataset/import/s3_bucket"], |
| trigger=True, |
| required=True, |
| tooltip="S3 bucket name including relative paths", |
| ), |
| ui.textbox( |
| name="dataset/import/s3_access_key", |
| label="AWS access key", |
| value=q.client["dataset/import/s3_access_key"], |
| trigger=True, |
| required=True, |
| password=True, |
| tooltip="Optional AWS access key; empty for anonymous access.", |
| ), |
| ui.textbox( |
| name="dataset/import/s3_secret_key", |
| label="AWS secret key", |
| value=q.client["dataset/import/s3_secret_key"], |
| trigger=True, |
| required=True, |
| password=True, |
| tooltip="Optional AWS secret key; empty for anonymous access.", |
| ), |
| ui_filename, |
| ] |
|
|
| elif q.client["dataset/import/source"] == "Azure": |
| if q.client["dataset/import/azure_conn_string"] is None: |
| q.client["dataset/import/azure_conn_string"] = q.client[ |
| "default_azure_conn_string" |
| ] |
| if q.client["dataset/import/azure_container"] is None: |
| q.client["dataset/import/azure_container"] = q.client[ |
| "default_azure_container" |
| ] |
|
|
| files = azure_file_options( |
| q.client["dataset/import/azure_conn_string"], |
| q.client["dataset/import/azure_container"], |
| ) |
|
|
| if not files: |
| ui_filename = ui.textbox( |
| name="dataset/import/azure_filename", |
| label="File name", |
| value="", |
| required=True, |
| tooltip="File name to be imported", |
| ) |
| else: |
| default_file = files[0] |
| ui_filename = ui.dropdown( |
| name="dataset/import/azure_filename", |
| label="File name", |
| value=default_file, |
| choices=[ui.choice(x, x.split("/")[-1]) for x in files], |
| required=True, |
| tooltip="File name to be imported", |
| ) |
|
|
| items += [ |
| ui.textbox( |
| name="dataset/import/azure_conn_string", |
| label="Datalake connection string", |
| value=q.client["dataset/import/azure_conn_string"], |
| trigger=True, |
| required=True, |
| password=True, |
| tooltip="Azure connection string to connect to Datalake storage", |
| ), |
| ui.textbox( |
| name="dataset/import/azure_container", |
| label="Datalake container name", |
| value=q.client["dataset/import/azure_container"], |
| trigger=True, |
| required=True, |
| tooltip="Azure Datalake container name including relative paths", |
| ), |
| ui_filename, |
| ] |
|
|
| elif q.client["dataset/import/source"] == "Upload": |
| items += [ |
| ui.file_upload( |
| name="dataset/import/local_upload", |
| label="Upload!", |
| multiple=False, |
| file_extensions=default_cfg.allowed_file_extensions, |
| ) |
| ] |
|
|
| elif q.client["dataset/import/source"] == "Local": |
| current_path = ( |
| q.client["dataset/import/local_path_current"] |
| if q.client["dataset/import/local_path_current"] is not None |
| else os.path.expanduser("~") |
| ) |
|
|
| if q.args.__wave_submission_name__ == "dataset/import/local_path_list": |
| idx = int(q.args["dataset/import/local_path_list"][0]) |
| options = q.client["dataset/import/local_path_list_last"] |
| new_path = os.path.abspath(os.path.join(current_path, options[idx])) |
| if os.path.exists(new_path): |
| current_path = new_path |
|
|
| results_df = dir_file_table(current_path) |
| files_list = results_df[current_path].tolist() |
| q.client["dataset/import/local_path_list_last"] = files_list |
| q.client["dataset/import/local_path_current"] = current_path |
|
|
| items += [ |
| ui.textbox( |
| name="dataset/import/local_path", |
| label="File location", |
| value=current_path, |
| required=True, |
| tooltip="Location of file to be imported", |
| ), |
| ui_table_from_df( |
| q=q, |
| df=results_df, |
| name="dataset/import/local_path_list", |
| sortables=[], |
| searchables=[], |
| min_widths={current_path: "400"}, |
| link_col=current_path, |
| height="calc(65vh)", |
| ), |
| ] |
|
|
| elif q.client["dataset/import/source"] == "Kaggle": |
| if q.client["dataset/import/kaggle_access_key"] is None: |
| q.client["dataset/import/kaggle_access_key"] = q.client[ |
| "default_kaggle_username" |
| ] |
| if q.client["dataset/import/kaggle_secret_key"] is None: |
| q.client["dataset/import/kaggle_secret_key"] = q.client[ |
| "default_kaggle_secret_key" |
| ] |
|
|
| items += [ |
| ui.textbox( |
| name="dataset/import/kaggle_command", |
| label="Kaggle API command", |
| value=default_cfg.kaggle_command, |
| required=True, |
| tooltip="Kaggle API command to be executed", |
| ), |
| ui.textbox( |
| name="dataset/import/kaggle_access_key", |
| label="Kaggle username", |
| value=q.client["dataset/import/kaggle_access_key"], |
| required=True, |
| password=False, |
| tooltip="Kaggle username for API authentication", |
| ), |
| ui.textbox( |
| name="dataset/import/kaggle_secret_key", |
| label="Kaggle secret key", |
| value=q.client["dataset/import/kaggle_secret_key"], |
| required=True, |
| password=True, |
| tooltip="Kaggle secret key for API authentication", |
| ), |
| ] |
|
|
| allowed_types = ", ".join(default_cfg.allowed_file_extensions) |
| allowed_types = " or".join(allowed_types.rsplit(",", 1)) |
| items += [ |
| ui.message_bar(type="info", text=info + f"Must be a {allowed_types} file."), |
| ui.message_bar(type="error", text=error), |
| ui.message_bar(type="warning", text=warning), |
| ] |
|
|
| q.page["dataset/import"].items = items |
|
|
| buttons = [ui.button(name="dataset/list", label="Abort")] |
| if q.client["dataset/import/source"] != "Upload": |
| buttons.insert( |
| 0, ui.button(name="dataset/import/2", label="Continue", primary=True) |
| ) |
|
|
| q.page["dataset/import/footer"] = ui.form_card( |
| box="footer", items=[ui.inline(items=buttons, justify="start")] |
| ) |
| q.client.delete_cards.add("dataset/import/footer") |
|
|
| q.client["dataset/import/id"] = None |
| q.client["dataset/import/cfg_file"] = None |
|
|
| elif step == 2: |
| q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[]) |
| try: |
| if not q.args["dataset/import/cfg_file"] and not edit: |
| if q.client["dataset/import/source"] == "S3": |
| ( |
| q.client["dataset/import/path"], |
| q.client["dataset/import/name"], |
| ) = await s3_download( |
| q, |
| q.client["dataset/import/s3_bucket"], |
| q.client["dataset/import/s3_filename"], |
| q.client["dataset/import/s3_access_key"], |
| q.client["dataset/import/s3_secret_key"], |
| ) |
| elif q.client["dataset/import/source"] == "Azure": |
| ( |
| q.client["dataset/import/path"], |
| q.client["dataset/import/name"], |
| ) = await azure_download( |
| q, |
| q.client["dataset/import/azure_conn_string"], |
| q.client["dataset/import/azure_container"], |
| q.client["dataset/import/azure_filename"], |
| ) |
| elif q.client["dataset/import/source"] in ("Upload", "Local"): |
| ( |
| q.client["dataset/import/path"], |
| q.client["dataset/import/name"], |
| ) = await local_download(q, q.client["dataset/import/local_path"]) |
| elif q.client["dataset/import/source"] == "Kaggle": |
| ( |
| q.client["dataset/import/path"], |
| q.client["dataset/import/name"], |
| ) = await kaggle_download( |
| q, |
| q.client["dataset/import/kaggle_command"], |
| q.client["dataset/import/kaggle_access_key"], |
| q.client["dataset/import/kaggle_secret_key"], |
| ) |
|
|
| |
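| # remember whether this import is editing an existing dataset |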
| q.client["dataset/import/edit"] = edit |
|
|
| |
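| # drop stale dataset trigger values so step 3 re-infers them |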
| for trigger_key in default_cfg.dataset_trigger_keys: |
| if q.client[f"dataset/import/cfg/{trigger_key}"]: |
| del q.client[f"dataset/import/cfg/{trigger_key}"] |
|
|
| await dataset_import( |
| q, |
| step=3, |
| edit=edit, |
| error=error, |
| warning=warning, |
| allow_merge=allow_merge, |
| ) |
| except Exception: |
| logger.error("Dataset error:", exc_info=True) |
| error = ( |
| "Dataset import failed. Please make sure all required " |
| "fields are filled correctly." |
| ) |
| await clean_dashboard(q, mode="full") |
| await dataset_import(q, step=1, error=str(error)) |
|
|
| elif step == 3: |
| q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[]) |
| try: |
| if not q.args["dataset/import/cfg_file"] and not edit: |
| q.client["dataset/import/name"] = get_unique_dataset_name( |
| q, q.client["dataset/import/name"] |
| ) |
| q.page["dataset/import"] = ui.form_card(box="content", items=[]) |
| q.client.delete_cards.add("dataset/import") |
|
|
| wizard = q.page["dataset/import"] |
|
|
| title = "Configure dataset" |
|
|
| items = [ |
| ui.text_l(title), |
| ui.textbox( |
| name="dataset/import/name", |
| label="Dataset name", |
| value=q.client["dataset/import/name"], |
| required=True, |
| ), |
| ] |
|
|
| choices_problem_types = [ |
| ui.choice(name, label) for name, label in get_problem_types() |
| ] |
|
|
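| # Infer the problem type from the dataset name: pick the config whose |
| # name (without the "_config" suffix) is the longest substring match. |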
| if q.client["dataset/import/cfg_file"] is None: |
| max_substring_len = 0 |
| for c in choices_problem_types: |
| problem_type_name = c.name.replace("_config", "") |
| if problem_type_name in q.client["dataset/import/name"]: |
| if len(problem_type_name) > max_substring_len: |
| q.client["dataset/import/cfg_file"] = c.name |
| q.client["dataset/import/cfg_category"] = c.name.split("_")[ |
| 0 |
| ] |
| max_substring_len = len(problem_type_name) |
| if q.client["dataset/import/cfg_file"] is None: |
| q.client["dataset/import/cfg_file"] = default_cfg.cfg_file |
| q.client["dataset/import/cfg_category"] = q.client[ |
| "dataset/import/cfg_file" |
| ].split("_")[0] |
|
|
| |
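| # if the stored category no longer matches the chosen config file, |
| # fall back to the first problem type of that category |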
| if ( |
| q.client["dataset/import/cfg_category"] |
| not in q.client["dataset/import/cfg_file"] |
| ): |
| q.client["dataset/import/cfg_file"] = get_problem_types( |
| category=q.client["dataset/import/cfg_category"] |
| )[0][0] |
|
|
| model_types = get_model_types(q.client["dataset/import/cfg_file"]) |
| if model_types: |
| |
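| # default to the first available model type for this config |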
| q.client["dataset/import/cfg_file"] = add_model_type( |
| q.client["dataset/import/cfg_file"], model_types[0][0] |
| ) |
| if not edit: |
| q.client["dataset/import/cfg"] = load_config_py( |
| config_path=( |
| f"llm_studio/python_configs/" |
| f"{q.client['dataset/import/cfg_file']}" |
| ), |
| config_name="ConfigProblemBase", |
| ) |
|
|
| option_items = get_dataset_elements(cfg=q.client["dataset/import/cfg"], q=q) |
| items.extend(option_items) |
| items.append(ui.message_bar(type="error", text=error)) |
| items.append(ui.message_bar(type="warning", text=warning)) |
| if file_extension_is_compatible(q): |
| ui_nav_name = "dataset/import/4/edit" if edit else "dataset/import/4" |
| buttons = [ |
| ui.button(name=ui_nav_name, label="Continue", primary=True), |
| ui.button(name="dataset/list", label="Abort"), |
| ] |
| if allow_merge: |
| datasets_df = q.client.app_db.get_datasets_df() |
| if datasets_df.shape[0]: |
| label = "Merge With Existing Dataset" |
| buttons.insert(1, ui.button(name="dataset/merge", label=label)) |
| else: |
| problem_type = make_label( |
| re.sub("_config.*", "", q.client["dataset/import/cfg_file"]) |
| ) |
| items += [ |
| ui.text( |
| "<b> The chosen file extensions is not " |
| f"compatible with {problem_type}.</b> " |
| ) |
| ] |
| buttons = [ |
| ui.button(name="dataset/list", label="Abort"), |
| ] |
| q.page["dataset/import/footer"] = ui.form_card( |
| box="footer", items=[ui.inline(items=buttons, justify="start")] |
| ) |
|
|
| wizard.items = items |
|
|
| q.client.delete_cards.add("dataset/import/footer") |
|
|
| except Exception as exception: |
| logger.error("Dataset error:", exc_info=True) |
| error = clean_error(str(exception)) |
| await clean_dashboard(q, mode="full") |
| await dataset_import(q, step=1, error=str(error)) |
|
|
| elif step == 4: |
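| # guard against duplicate dataset names before moving on to the preview |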
| dataset_name = q.client["dataset/import/name"] |
| original_name = q.client["dataset/import/original_name"] |
| valid_dataset_name = get_unique_dataset_name(q, dataset_name) |
| if valid_dataset_name != dataset_name and not ( |
| q.client["dataset/import/edit"] and dataset_name == original_name |
| ): |
| err = f"Dataset <strong>{dataset_name}</strong> already exists" |
| q.client["dataset/import/name"] = valid_dataset_name |
| await dataset_import(q, 3, edit=edit, error=err) |
| else: |
| await dataset_import(q, 5, edit=edit) |
|
|
| elif step == 5: |
| header = "<h2>Sample Data Visualization</h2>" |
| valid_visualization = False |
| try: |
| cfg = q.client["dataset/import/cfg"] |
| cfg = parse_ui_elements( |
| cfg=cfg, q=q, limit=default_cfg.dataset_keys, pre="dataset/import/cfg/" |
| ) |
|
|
| q.client["dataset/import/cfg"] = cfg |
| plot = cfg.logging.plots_class.plot_data(cfg) |
| text = ( |
| "Data Validity Check. Click <strong>Continue</strong> if the input " |
| "data and labels appear correctly." |
| ) |
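| # the preview is either a PNG image, raw HTML markup, or a path to a |
| # parquet frame rendered as a table |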
| if plot.encoding == "image": |
| plot_item = ui.image(title="", type="png", image=plot.data) |
| elif plot.encoding == "html": |
| plot_item = ui.markup(content=plot.data) |
| elif plot.encoding == "df": |
| df = pd.read_parquet(plot.data) |
| df = df.iloc[:2000] |
| min_widths = {"Content": "800"} |
| plot_item = ui_table_from_df( |
| q=q, |
| df=df, |
| name="experiment/display/table", |
| markdown_cells=list(df.columns), |
| searchables=list(df.columns), |
| downloadable=False, |
| resettable=False, |
| min_widths=min_widths, |
| height="calc(100vh - 245px)", |
| max_char_length=5_000, |
| cell_overflow="tooltip", |
| ) |
| else: |
| raise ValueError(f"Unknown plot encoding `{plot.encoding}`") |
|
|
| items = [ui.markup(content=header), ui.message_bar(text=text), plot_item] |
| valid_visualization = True |
|
|
| await busy_dialog( |
| q=q, |
| title="Performing sanity checks on the data", |
| text="Please be patient...", |
| ) |
| |
| |
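| # brief pause so the busy dialog is rendered before the blocking check |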
| time.sleep(1) |
| sanity_check(cfg) |
|
|
| except Exception as exception: |
| logger.error( |
| f"Error while plotting data preview: {exception}", exc_info=True |
| ) |
| text = ( |
| "Error occurred while visualizing the data. Please go back and verify " |
| "whether the problem type and other settings were set properly." |
| ) |
| items = [ |
| ui.markup(content=header), |
| ui.message_bar(text=text, type="error"), |
| ui.expander( |
| name="expander", |
| label="Expand Error Traceback", |
| items=[ui.markup(f"<pre>{traceback.format_exc()}</pre>")], |
| ), |
| ] |
|
|
| buttons = [ |
| ui.button( |
| name="dataset/import/6", label="Continue", primary=valid_visualization |
| ), |
| ui.button( |
| name="dataset/import/3/edit", |
| label="Back", |
| primary=not valid_visualization, |
| ), |
| ui.button(name="dataset/list", label="Abort"), |
| ] |
|
|
| q.page["dataset/import"] = ui.form_card(box="content", items=items) |
| q.client.delete_cards.add("dataset/import") |
|
|
| q.page["dataset/import/footer"] = ui.form_card( |
| box="footer", items=[ui.inline(items=buttons, justify="start")] |
| ) |
| q.client.delete_cards.add("dataset/import/footer") |
|
|
| elif step == 6: |
| if q.client["dataset/import/name"] == "": |
| await clean_dashboard(q, mode="full") |
| await dataset_import(q, step=2, error="Please enter all required fields!") |
|
|
| else: |
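| # rename the imported folder to the final dataset name, e.g. |
| # (hypothetical) ".../data/tmp_upload" -> ".../data/my_dataset" |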
| folder_name = q.client["dataset/import/path"].split("/")[-1] |
| new_folder = q.client["dataset/import/name"] |
| act_path = q.client["dataset/import/path"] |
| new_path = new_folder.join(act_path.rsplit(folder_name, 1)) |
|
|
| try: |
| shutil.move(q.client["dataset/import/path"], new_path) |
|
|
| cfg = q.client["dataset/import/cfg"] |
|
|
| |
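| # remap stored dataframe paths from the old import folder to the new one |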
| for k in default_cfg.dataset_folder_keys: |
| old_path = getattr(cfg.dataset, k, None) |
| if old_path is not None: |
| setattr( |
| cfg.dataset, |
| k, |
| old_path.replace(q.client["dataset/import/path"], new_path), |
| ) |
|
|
| |
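| # a user-provided validation dataframe implies a custom validation split |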
| if cfg.dataset.validation_dataframe != "None": |
| cfg.dataset.validation_strategy = "custom" |
| cfg_path = f"{new_path}/{q.client['dataset/import/cfg_file']}.yaml" |
| save_config_yaml(cfg_path, cfg) |
|
|
| train_rows = None |
| if os.path.exists(cfg.dataset.train_dataframe): |
| train_rows = read_dataframe_drop_missing_labels( |
| cfg.dataset.train_dataframe, cfg |
| ).shape[0] |
| validation_rows = None |
| if os.path.exists(cfg.dataset.validation_dataframe): |
| validation_rows = read_dataframe_drop_missing_labels( |
| cfg.dataset.validation_dataframe, cfg |
| ).shape[0] |
|
|
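| # persist the dataset entry; when editing, replace the existing row |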
| dataset = Dataset( |
| id=q.client["dataset/import/id"], |
| name=q.client["dataset/import/name"], |
| path=new_path, |
| config_file=cfg_path, |
| train_rows=train_rows, |
| validation_rows=validation_rows, |
| ) |
| if q.client["dataset/import/id"] is not None: |
| q.client.app_db.delete_dataset(dataset.id) |
| q.client.app_db.add_dataset(dataset) |
| await dataset_list(q) |
|
|
| except Exception as exception: |
| logger.error("Dataset error:", exc_info=True) |
| q.client.app_db._session.rollback() |
| error = clean_error(str(exception)) |
| await clean_dashboard(q, mode="full") |
| await dataset_import(q, step=2, error=str(error)) |
|
|
|
|
| async def dataset_merge(q: Q, step: int, error: str = "") -> None: |
| """Merge the dataset currently being imported into an existing dataset.""" |
| if step == 1: |
| await clean_dashboard(q, mode="full") |
| q.client["nav/active"] = "dataset/merge" |
|
|
| q.page["dataset/merge"] = ui.form_card(box="content", items=[]) |
| q.client.delete_cards.add("dataset/merge") |
|
|
| datasets_df = q.client.app_db.get_datasets_df() |
| import_choices = [ |
| ui.choice(x["path"], x["name"]) for idx, x in datasets_df.iterrows() |
| ] |
|
|
| items = [ |
| ui.text_l("Merge current dataset with an existing dataset"), |
| ui.dropdown( |
| name="dataset/merge/target", |
| label="Dataset", |
| value=datasets_df.iloc[0]["path"], |
| choices=import_choices, |
| trigger=False, |
| tooltip="Source of dataset import", |
| ), |
| ] |
|
|
| if error: |
| items.append(ui.message_bar(type="error", text=error)) |
|
|
| q.page["dataset/merge"].items = items |
|
|
| buttons = [ |
| ui.button(name="dataset/merge/action", label="Merge", primary=True), |
| ui.button(name="dataset/import/3", label="Back", primary=False), |
| ui.button(name="dataset/list", label="Abort"), |
| ] |
|
|
| q.page["dataset/import/footer"] = ui.form_card( |
| box="footer", items=[ui.inline(items=buttons, justify="start")] |
| ) |
| q.client.delete_cards.add("dataset/import/footer") |
|
|
| elif step == 2: |
| current_dir = q.client["dataset/import/path"] |
| target_dir = q.args["dataset/merge/target"] |
|
|
| if current_dir == target_dir: |
| await dataset_merge(q, step=1, error="Cannot merge dataset with itself") |
| return |
|
|
| datasets_df = q.client.app_db.get_datasets_df().set_index("path") |
| has_dataset_entry = current_dir in datasets_df.index |
|
|
| if has_dataset_entry: |
| experiment_df = q.client.app_db.get_experiments_df() |
| source_id = int(datasets_df.loc[current_dir, "id"]) |
| has_experiment = any(experiment_df["dataset"].astype(int) == source_id) |
| else: |
| source_id = None |
| has_experiment = False |
|
|
| current_files = os.listdir(current_dir) |
| current_files = [x for x in current_files if not x.endswith(".yaml")] |
| target_files = os.listdir(target_dir) |
| overlapping_files = list(set(current_files).intersection(set(target_files))) |
| rename_map = {} |
|
|
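| # resolve file-name collisions with the target dataset by appending a |
| # counter, e.g. a second "train.csv" becomes "train_1.csv" |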
| for file in overlapping_files: |
| tmp_str = file.split(".") |
| if len(tmp_str) == 1: |
| file_name, extension = file, "" |
| else: |
| file_name, extension = ".".join(tmp_str[:-1]), f".{tmp_str[-1]}" |
|
|
| cnt = 1 |
| while f"{file_name}_{cnt}{extension}" in target_files: |
| cnt += 1 |
|
|
| rename_map[file] = f"{file_name}_{cnt}{extension}" |
| target_files.append(rename_map[file]) |
|
|
| if overlapping_files: |
| warning = ( |
| f"Renamed {', '.join(rename_map.keys())} to " |
| f"{', '.join(rename_map.values())} due to duplicated entries." |
| ) |
| else: |
| warning = "" |
|
|
| for file in current_files: |
| new_file = rename_map.get(file, file) |
| src = os.path.join(current_dir, file) |
| dst = os.path.join(target_dir, new_file) |
|
|
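| # copy when experiments still reference the source dataset (keep the |
| # originals); otherwise move the files and drop the source folder below |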
| if has_experiment: |
| if os.path.isdir(src): |
| shutil.copytree(src, dst) |
| else: |
| shutil.copy(src, dst) |
| else: |
| shutil.move(src, dst) |
|
|
| if not has_experiment: |
| shutil.rmtree(current_dir) |
| if has_dataset_entry: |
| q.client.app_db.delete_dataset(source_id) |
|
|
| dataset_id = int(datasets_df.loc[target_dir, "id"]) |
| await dataset_edit(q, dataset_id, warning=warning, allow_merge=False) |
|
|
|
|
| async def dataset_list_table( |
| q: Q, |
| show_experiment_datasets: bool = True, |
| ) -> None: |
| """Pepare dataset list form card |
| |
| Args: |
| q: Q |
| show_experiment_datasets: whether to also show datasets linked to experiments |
| """ |
|
|
| q.client["dataset/list/df_datasets"] = get_datasets( |
| q=q, |
| show_experiment_datasets=show_experiment_datasets, |
| ) |
|
|
| df_viz = q.client["dataset/list/df_datasets"].copy() |
|
|
| columns_to_drop = [ |
| "id", |
| "path", |
| "config_file", |
| "validation dataframe", |
| ] |
|
|
| df_viz = df_viz.drop(columns=columns_to_drop, errors="ignore") |
| if "problem type" in df_viz.columns: |
| df_viz["problem type"] = df_viz["problem type"].str.replace("Text ", "") |
|
|
| widths = { |
| "name": "200", |
| "problem type": "210", |
| "train dataframe": "190", |
| "train rows": "120", |
| "validation rows": "130", |
| "labels": "120", |
| "actions": "5", |
| } |
|
|
| actions_dict = { |
| "dataset/newexperiment": "New experiment", |
| "dataset/edit": "Edit dataset", |
| "dataset/delete/dialog/single": "Delete dataset", |
| } |
|
|
| q.page["dataset/list"] = ui.form_card( |
| box="content", |
| items=[ |
| ui_table_from_df( |
| q=q, |
| df=df_viz, |
| name="dataset/list/table", |
| sortables=["train rows", "validation rows"], |
| filterables=["name", "problem type"], |
| searchables=[], |
| min_widths=widths, |
| link_col="name", |
| height="calc(100vh - 245px)", |
| actions=actions_dict, |
| ), |
| ui.message_bar(type="info", text=""), |
| ], |
| ) |
| q.client.delete_cards.add("dataset/list") |
|
|
|
|
| async def dataset_list(q: Q, reset: bool = True) -> None: |
| """Display all datasets.""" |
| q.client["nav/active"] = "dataset/list" |
|
|
| if reset: |
| await clean_dashboard(q, mode="full") |
| await dataset_list_table(q) |
|
|
| q.page["dataset/display/footer"] = ui.form_card( |
| box="footer", |
| items=[ |
| ui.inline( |
| items=[ |
| ui.button( |
| name="dataset/import", label="Import dataset", primary=True |
| ), |
| ui.button( |
| name="dataset/list/delete", |
| label="Delete datasets", |
| primary=False, |
| ), |
| ], |
| justify="start", |
| ) |
| ], |
| ) |
| q.client.delete_cards.add("dataset/display/footer") |
| remove_temp_files(q) |
|
|
| await q.page.save() |
|
|
|
|
| async def dataset_newexperiment(q: Q, dataset_id: int): |
| """Start a new experiment from given dataset.""" |
|
|
| dataset = q.client.app_db.get_dataset(dataset_id) |
|
|
| q.client["experiment/start/cfg_file"] = dataset.config_file.split("/")[-1].replace( |
| ".yaml", "" |
| ) |
| q.client["experiment/start/cfg_category"] = q.client[ |
| "experiment/start/cfg_file" |
| ].split("_")[0] |
| q.client["experiment/start/dataset"] = str(dataset_id) |
|
|
| await experiment_start(q) |
|
|
|
|
| async def dataset_edit( |
| q: Q, dataset_id: int, error: str = "", warning: str = "", allow_merge: bool = True |
| ): |
| """Edit selected dataset. |
| |
| Args: |
| q: Q |
| dataset_id: dataset id to edit |
| error: optional error message |
| warning: optional warning message |
| allow_merge: whether to allow merging dataset when editing |
| """ |
|
|
| dataset = q.client.app_db.get_dataset(dataset_id) |
|
|
| experiments_df = q.client.app_db.get_experiments_df() |
| experiments_df = experiments_df[experiments_df["dataset"] == str(dataset_id)] |
| statuses, _ = get_experiments_status(experiments_df) |
| num_invalid = len([stat for stat in statuses if stat in ["running", "queued"]]) |
|
|
| if num_invalid: |
| info = "s" if num_invalid > 1 else "" |
| info_str = ( |
| f"Dataset <strong>{dataset.name}</strong> is linked to {num_invalid} " |
| f"running or queued experiment{info}. Wait for them to finish or stop them " |
| "first before editing the dataset." |
| ) |
| q.page["dataset/list"].items[1].message_bar.text = info_str |
| return |
|
|
| q.client["dataset/import/id"] = dataset_id |
|
|
| q.client["dataset/import/cfg_file"] = dataset.config_file.split("/")[-1].replace( |
| ".yaml", "" |
| ) |
| q.client["dataset/import/cfg_category"] = q.client["dataset/import/cfg_file"].split( |
| "_" |
| )[0] |
| q.client["dataset/import/path"] = dataset.path |
| q.client["dataset/import/name"] = dataset.name |
| q.client["dataset/import/original_name"] = dataset.name |
| q.client["dataset/import/cfg"] = load_config_yaml(dataset.config_file) |
|
|
| if allow_merge and experiments_df.shape[0]: |
| allow_merge = False |
|
|
| await dataset_import( |
| q=q, step=2, edit=True, error=error, warning=warning, allow_merge=allow_merge |
| ) |
|
|
|
|
| async def dataset_list_delete(q: Q): |
| """Allow to select multiple datasets for deletion.""" |
|
|
| await dataset_list_table(q, show_experiment_datasets=False) |
|
|
| q.page["dataset/list"].items[0].table.multiple = True |
|
|
| info_str = "Only datasets not linked to experiments can be deleted." |
|
|
| q.page["dataset/list"].items[1].message_bar.text = info_str |
|
|
| q.page["dataset/display/footer"].items = [ |
| ui.inline( |
| items=[ |
| ui.button( |
| name="dataset/delete/dialog", label="Delete datasets", primary=True |
| ), |
| ui.button(name="dataset/list/delete/abort", label="Abort"), |
| ] |
| ) |
| ] |
|
|
|
|
| async def dataset_delete(q: Q, dataset_ids: List[int]): |
| """Delete selected datasets. |
| |
| Args: |
| q: Q |
| dataset_ids: list of dataset ids to delete |
| """ |
|
|
| for dataset_id in dataset_ids: |
| dataset = q.client.app_db.get_dataset(dataset_id) |
| q.client.app_db.delete_dataset(dataset.id) |
|
|
| try: |
| shutil.rmtree(dataset.path) |
| except OSError: |
| logger.warning(f"Could not remove dataset folder {dataset.path}") |
|
|
|
|
| async def dataset_delete_single(q: Q, dataset_id: int): |
| dataset = q.client.app_db.get_dataset(dataset_id) |
|
|
| experiments_df = q.client.app_db.get_experiments_df() |
| num_experiments = sum(experiments_df["dataset"] == str(dataset_id)) |
| if num_experiments: |
| info = "s" if num_experiments > 1 else "" |
| info_str = ( |
| f"Dataset <strong>{dataset.name}</strong> is linked to {num_experiments} " |
| f"experiment{info}. Only datasets not linked to experiments can be deleted." |
| ) |
| await dataset_list(q) |
| q.page["dataset/list"].items[1].message_bar.text = info_str |
| else: |
| await dataset_delete(q, [dataset_id]) |
| await dataset_list(q) |
|
|
|
|
| async def dataset_display(q: Q) -> None: |
| """Display a selected dataset.""" |
|
|
| dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[ |
| q.client["dataset/display/id"] |
| ] |
| dataset: Dataset = q.client.app_db.get_dataset(dataset_id) |
| config_filename = dataset.config_file |
| cfg = load_config_yaml(config_filename) |
| dataset_filename = cfg.dataset.train_dataframe |
|
|
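| # resolve the active tab: default to the data tab, then switch if a |
| # tab button was clicked |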
| if ( |
| q.client["dataset/display/tab"] is None |
| or q.args["dataset/display/data"] is not None |
| ): |
| q.client["dataset/display/tab"] = "dataset/display/data" |
|
|
| if q.args["dataset/display/visualization"] is not None: |
| q.client["dataset/display/tab"] = "dataset/display/visualization" |
|
|
| if q.args["dataset/display/statistics"] is not None: |
| q.client["dataset/display/tab"] = "dataset/display/statistics" |
|
|
| if q.args["dataset/display/summary"] is not None: |
| q.client["dataset/display/tab"] = "dataset/display/summary" |
|
|
| await clean_dashboard(q, mode=q.client["dataset/display/tab"]) |
|
|
| items: List[Tab] = [ |
| ui.tab(name="dataset/display/data", label="Sample Train Data"), |
| ui.tab( |
| name="dataset/display/visualization", label="Sample Train Visualization" |
| ), |
| ui.tab(name="dataset/display/statistics", label="Train Data Statistics"), |
| ui.tab(name="dataset/display/summary", label="Summary"), |
| ] |
|
|
| q.page["dataset/display/tab"] = ui.tab_card( |
| box="nav2", |
| link=True, |
| items=items, |
| value=q.client["dataset/display/tab"], |
| ) |
| q.client.delete_cards.add("dataset/display/tab") |
|
|
| if q.client["dataset/display/tab"] == "dataset/display/data": |
| await show_data_tab(q=q, cfg=cfg, filename=dataset_filename) |
|
|
| elif q.client["dataset/display/tab"] == "dataset/display/visualization": |
| await show_visualization_tab(q, cfg) |
|
|
| elif q.client["dataset/display/tab"] == "dataset/display/statistics": |
| await show_statistics_tab( |
| q, dataset_filename=dataset_filename, config_filename=config_filename |
| ) |
|
|
| elif q.client["dataset/display/tab"] == "dataset/display/summary": |
| await show_summary_tab(q, dataset_id) |
|
|
| q.page["dataset/display/footer"] = ui.form_card( |
| box="footer", |
| items=[ |
| ui.inline( |
| items=[ |
| ui.button( |
| name="dataset/newexperiment/from_current", |
| label="Create experiment", |
| primary=False, |
| disabled=False, |
| tooltip=None, |
| ), |
| ui.button(name="dataset/list", label="Back", primary=False), |
| ], |
| justify="start", |
| ) |
| ], |
| ) |
| q.client.delete_cards.add("dataset/display/footer") |
|
|
|
|
| async def show_data_tab(q, cfg, filename: str): |
| fill_columns = get_fill_columns(cfg) |
| df = read_dataframe(filename, n_rows=200, fill_columns=fill_columns) |
| q.page["dataset/display/data"] = ui.form_card( |
| box="first", |
| items=[ |
| ui_table_from_df( |
| q=q, |
| df=df, |
| name="dataset/display/data/table", |
| sortables=list(df.columns), |
| height="calc(100vh - 265px)", |
| cell_overflow="wrap", |
| ) |
| ], |
| ) |
| q.client.delete_cards.add("dataset/display/data") |
|
|
|
|
| async def show_visualization_tab(q, cfg): |
| try: |
| plot = cfg.logging.plots_class.plot_data(cfg) |
| except Exception as error: |
| logger.error(f"Error while plotting data preview: {error}", exc_info=True) |
| plot = PlotData("<h2>Error while plotting data.</h2>", encoding="html") |
| card: ImageCard | MarkupCard | FormCard |
| if plot.encoding == "image": |
| card = ui.image_card(box="first", title="", type="png", image=plot.data) |
| elif plot.encoding == "html": |
| card = ui.markup_card(box="first", title="", content=plot.data) |
| elif plot.encoding == "df": |
| df = pd.read_parquet(plot.data) |
| df = df.iloc[:2000] |
| min_widths = {"Content": "800"} |
| card = ui.form_card( |
| box="first", |
| items=[ |
| ui_table_from_df( |
| q=q, |
| df=df, |
| name="dataset/display/visualization/table", |
| markdown_cells=list(df.columns), |
| searchables=list(df.columns), |
| downloadable=True, |
| resettable=True, |
| min_widths=min_widths, |
| height="calc(100vh - 245px)", |
| max_char_length=50_000, |
| cell_overflow="tooltip", |
| ) |
| ], |
| ) |
|
|
| else: |
| raise ValueError(f"Unknown plot encoding `{plot.encoding}`") |
| q.page["dataset/display/visualization"] = card |
| q.client.delete_cards.add("dataset/display/visualization") |
|
|
|
|
| async def show_summary_tab(q, dataset_id): |
| dataset_df = get_datasets(q) |
| dataset_df = dataset_df[dataset_df.id == dataset_id] |
| stat_list_items: List[StatListItem] = [] |
| for col in dataset_df.columns: |
| if col in ["id", "config_file", "path", "process_id", "status"]: |
| continue |
| v = dataset_df[col].values[0] |
| t: StatListItem = ui.stat_list_item(label=make_label(col), value=str(v)) |
|
|
| stat_list_items.append(t) |
| q.page["dataset/display/summary"] = ui.stat_list_card( |
| box="first", items=stat_list_items, title="" |
| ) |
| q.client.delete_cards.add("dataset/display/summary") |
|
|
|
|
| async def show_statistics_tab(q, dataset_filename, config_filename): |
| with open(config_filename, "rb") as config_file: |
| cfg_hash = hashlib.md5(config_file.read()).hexdigest() |
| stats_dict = compute_dataset_statistics(dataset_filename, config_filename, cfg_hash) |
|
|
| for chat_type in ["prompts", "answers"]: |
| q.page[f"dataset/display/statistics/{chat_type}_histogram"] = histogram_card( |
| x=stats_dict[chat_type], |
| x_axis_description=f"text_length_{chat_type.capitalize()}", |
| title=f"Text Length Distribution for {chat_type.capitalize()}" |
| f" (split by whitespace)", |
| histogram_box="first", |
| ) |
| q.client.delete_cards.add(f"dataset/display/statistics/{chat_type}_histogram") |
|
|
| q.page["dataset/display/statistics/full_conversation_histogram"] = histogram_card( |
| x=stats_dict["complete_conversations"], |
| x_axis_description="text_length_complete_conversations", |
| title="Text Length Distribution for complete " |
| "conversations (split by whitespace)", |
| histogram_box="second", |
| ) |
| q.client.delete_cards.add("dataset/display/statistics/full_conversation_histogram") |
|
|
| if len(set(stats_dict["number_of_prompts"])) > 1: |
| q.page["dataset/display/statistics/parent_id_length_histogram"] = ( |
| histogram_card( |
| x=stats_dict["number_of_prompts"], |
| x_axis_description="number_of_prompts", |
| title="Distribution of number of prompt-answer turns per conversation.", |
| histogram_box="second", |
| ) |
| ) |
| q.client.delete_cards.add( |
| "dataset/display/statistics/parent_id_length_histogram" |
| ) |
|
|
| df_stats = stats_dict["df_stats"] |
| if df_stats is None: |
| component_items = [ |
| ui.text( |
| "Dataset does not contain numerical or text features. " |
| "No statistics available." |
| ) |
| ] |
| else: |
| if df_stats.shape[1] > 5: |
| widths = {col: "77" for col in df_stats} |
| else: |
| widths = None |
| component_items = [ |
| ui_table_from_df( |
| q=q, |
| df=df_stats, |
| name="dataset/display/statistics/table", |
| sortables=list(df_stats.columns), |
| min_widths=widths, |
| height="265px", |
| ) |
| ] |
| q.page["dataset/display/statistics"] = ui.form_card( |
| box="third", |
| items=component_items, |
| ) |
| q.client.delete_cards.add("dataset/display/statistics") |
|
|
|
|
| @functools.lru_cache() |
| def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str): |
| """ |
| Compute various statistics for a dataset. |
| - text length distribution for prompts and answers |
| - text length distribution for complete conversations |
| - distribution of number of prompt-answer turns per conversation |
| - statistics for non text features |
| |
| We use LRU caching to avoid recomputing the statistics for the same |
| dataset; cfg_hash is passed as an argument so that a modified config |
| file invalidates the cached entry. |
| """ |
| df_train = read_dataframe(dataset_path) |
| cfg = load_config_yaml(cfg_path) |
| conversations = get_conversation_chains( |
| df=df_train, cfg=cfg, limit_chained_samples=True |
| ) |
| stats_dict = {} |
| for chat_type in ["prompts", "answers"]: |
| text_lengths = [ |
| len(text.split(" ")) |
| for conversation in conversations |
| for text in conversation[chat_type] |
| ] |
| stats_dict[chat_type] = text_lengths |
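| # length of each complete conversation: the system prompt followed by |
| # all prompt/answer turns concatenated |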
| input_texts = [] |
| for conversation in conversations: |
| input_text = conversation["systems"][0] |
| prompts = conversation["prompts"] |
| answers = conversation["answers"] |
| for prompt, answer in zip(prompts, answers): |
| input_text += prompt + answer |
| input_texts += [input_text] |
| stats_dict["complete_conversations"] = [ |
| len(text.split(" ")) for text in input_texts |
| ] |
| stats_dict["number_of_prompts"] = [ |
| len(conversation["prompts"]) for conversation in conversations |
| ] |
| stats_dict["df_stats"] = get_frame_stats(df_train) |
| return stats_dict |
|
|
|
|
| async def dataset_import_uploaded_file(q: Q): |
| local_path = await q.site.download( |
| q.args["dataset/import/local_upload"][0], |
| f"{get_data_dir(q)}/" |
| f'{q.args["dataset/import/local_upload"][0].split("/")[-1]}', |
| ) |
| await q.site.unload(q.args["dataset/import/local_upload"][0]) |
| valid, error = check_valid_upload_content(local_path) |
| if valid: |
| q.args["dataset/import/local_path"] = local_path |
| q.client["dataset/import/local_path"] = q.args["dataset/import/local_path"] |
| await dataset_import(q, step=2) |
| else: |
| await dataset_import(q, step=1, error=error) |
|
|
|
|
| async def dataset_delete_current_datasets(q: Q): |
| dataset_ids = list( |
| q.client["dataset/list/df_datasets"]["id"].iloc[ |
| list(map(int, q.client["dataset/list/table"])) |
| ] |
| ) |
| await dataset_delete(q, dataset_ids) |
| await dataset_list(q) |
|
|