Spaces:
Paused
Paused
| import os | |
| import re | |
| import time | |
| import pandas as pd | |
| import streamlit as st | |
| import opendashboards.utils.utils as utils | |
| from pandas.api.types import ( | |
| is_categorical_dtype, | |
| is_datetime64_any_dtype, | |
| is_numeric_dtype, | |
| is_object_dtype, | |
| ) | |
# @st.cache_data
# Tag-classification rules are loop-invariant: compile the regexes once at
# module scope instead of rebuilding all three on every run in the loop.
_TAG_RULES = {
    'version': re.compile(r'^\d\.\d+\.\d+$'),
    'spec_version': re.compile(r'\d{4}$'),
    'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
}


def load_runs(project, filters, min_steps=10, max_recent=100, local_path='wandb_runs.csv', local_stale_time=3600):
    """Fetch run metadata from wandb (or a local CSV cache) into a DataFrame.

    Args:
        project: wandb project path, passed through to ``utils.get_runs``.
        filters: wandb API filters, passed through to ``utils.get_runs``.
        min_steps: skip runs that logged fewer than this many steps.
        max_recent: load at most this many of the most recent runs.
        local_path: CSV cache file; reused when younger than ``local_stale_time``.
        local_stale_time: cache time-to-live in seconds.

    Returns:
        pd.DataFrame of one row per run, with categorical/datetime dtypes
        applied where those columns are present.
    """
    # TODO: clean up the caching logic (e.g. take into account the args)
    dtypes = {'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category', 'start_time': 'datetime64[s]', 'end_time': 'datetime64[s]', 'duration': 'timedelta64[s]'}

    # Serve the local CSV cache when it exists and is fresh enough.
    if local_path and os.path.exists(local_path) and (time.time() - os.path.getmtime(local_path)) < local_stale_time:
        frame = pd.read_csv(local_path)
        # Only coerce dtypes for columns actually present in the cached frame.
        return frame.astype({k: v for k, v in dtypes.items() if k in frame.columns})

    runs = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Fetching runs from wandb')
    msg = st.empty()
    all_runs = utils.get_runs(project, filters)
    for i, run in enumerate(all_runs):
        # BUGFIX: was `i > max_recent`, which loaded max_recent + 1 runs.
        if i >= max_recent:
            break
        summary = run.summary
        # `_step` is the last logged step index; +1 converts it to an event count.
        step = summary.get('_step', -1) + 1
        if step < min_steps:
            msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
            continue
        prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
        progress.progress(min(i/len(all_runs), 1), f'{prog_msg}... **fetching** `{run.name}`')
        duration = summary.get('_runtime')
        end_time = summary.get('_timestamp')
        # Extract values for selected tags; if several tags match a rule, the last wins.
        tags = {k: tag for k, rule in _TAG_RULES.items() for tag in run.tags if rule.match(tag)}
        # Include bool flag for remaining (marker-style) tags.
        tags.update({k: k in run.tags for k in ('mock', 'disable_set_weights')})
        runs.append({
            'state': run.state,
            'num_steps': step,
            'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
            'duration': pd.to_timedelta(duration, unit="s").round('T'),  # round to nearest minute
            'start_time': pd.to_datetime(end_time-duration, unit="s").round('T'),
            'end_time': pd.to_datetime(end_time, unit="s").round('T'),
            'netuid': run.config.get('netuid'),
            **tags,
            'username': run.user.username,
            'run_id': run.id,
            'run_name': run.name,
            'url': run.url,
            'run_path': os.path.join(run.entity, run.project, run.id),
        })
        n_events += step
        successful += 1
    progress.empty()
    msg.empty()
    frame = pd.DataFrame(runs)
    # Refresh the local cache for subsequent calls within local_stale_time.
    frame.to_csv(local_path, index=False)
    return frame.astype({k: v for k, v in dtypes.items() if k in frame.columns})
def load_data(selected_runs, load=True, save=False):
    """Load per-run event history, reading local CSVs or downloading from wandb.

    Args:
        selected_runs: DataFrame of run metadata (as produced by ``load_runs``);
            must provide ``run_id``, ``run_path`` and ``state`` per row.
        load: when True, reuse ``data/history-<run_id>.csv`` if it exists.
        save: when True, persist downloaded history for finished runs.

    Returns:
        pd.DataFrame: concatenation of all successfully loaded run histories.
            Calls ``st.stop()`` if nothing could be loaded.
    """
    frames = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Loading data')
    info = st.empty()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('data', exist_ok=True)

    for i, idx in enumerate(selected_runs.index):
        run = selected_runs.loc[idx]
        prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({successful}/{len(selected_runs)} runs, {n_events} events)'
        file_path = os.path.join('data', f'history-{run.run_id}.csv')

        if load and os.path.exists(file_path):
            progress.progress(i/len(selected_runs), f'{prog_msg}... **reading** `{file_path}`')
            try:
                df = utils.read_data(file_path)
            except Exception as e:
                # Best effort: surface the failure in the UI and move on to the next run.
                info.warning(f'Failed to load history from `{file_path}`')
                st.exception(e)
                continue
        else:
            progress.progress(i/len(selected_runs), f'{prog_msg}... **downloading** `{run.run_path}`')
            try:
                # Download the history from wandb and add run metadata columns.
                df = utils.download_data(run.run_path).assign(**run.to_dict())
                print(f'Downloaded {df.shape[0]} events from `{run.run_path}`. Columns: {df.columns}')
                df.info()
                # Only cache finished runs: a running run's history is still growing.
                if save and run.state != 'running':
                    df.to_csv(file_path, index=False)
            except Exception as e:
                info.warning(f'Failed to download history for `{run.run_path}`')
                st.exception(e)
                continue

        frames.append(df)
        n_events += df.shape[0]
        successful += 1

    progress.empty()
    if not frames:
        info.error('No data loaded')
        st.stop()
    # TODO: rows containing chain weights reportedly break the schema, but they
    # are NOT filtered here yet (the old comment claimed otherwise).
    return pd.concat(frames)
def filter_dataframe(df: pd.DataFrame, demo_selection=None) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Args:
        df (pd.DataFrame): Original dataframe
        demo_selection (pd.Index): Index of runs to select (if demo)

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    filter_mode = st.sidebar.radio("Filter mode", ("Use demo", "Add filters"), index=0)

    run_msg = st.info("Select a single wandb run or compare multiple runs")

    # Demo mode: bypass the filter widgets entirely and return the fixed selection.
    if filter_mode == "Use demo":
        df = df.loc[demo_selection]
        run_msg.info(f"Selected {len(df)} runs")
        return df

    # Copy so widget-driven filtering never mutates the caller's dataframe.
    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Column is not datetime-like; leave it as-is.
                pass

        if is_datetime64_any_dtype(df[col]):
            # Drop timezone info so later between()/comparison filters work
            # against the naive datetimes produced by date_input.
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        # One widget row per chosen column; the widget type is picked from the dtype.
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            # Treat columns with < 10 unique values as categorical
            if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
                # Default to all values selected so adding a filter starts as a no-op.
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                # 100 slider increments across the column's observed range.
                # NOTE(review): if _min == _max this step is 0 — confirm Streamlit
                # tolerates a zero step for a degenerate range.
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    min_value=_min,
                    max_value=_max,
                    value=(_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                # date_input returns a 1-tuple while the user is mid-selection;
                # only filter once both endpoints have been picked.
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                # Fallback for free-text columns: substring/regex match.
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df = df[df[column].astype(str).str.contains(user_text_input)]

    # Load data if new runs selected
    if len(df):
        run_msg.info(f"Selected {len(df)} runs")
    else:
        # open a dialog to select runs
        run_msg.error("Please select at least one run")
        # st.snow()
        # st.stop()

    return df