"""The main page for the Trackio UI.""" import os import re import secrets import shutil from dataclasses import dataclass from typing import Any import gradio as gr import numpy as np import pandas as pd import trackio.utils as utils from trackio.media import ( TrackioAudio, TrackioImage, TrackioVideo, get_project_media_path, ) from trackio.sqlite_storage import SQLiteStorage from trackio.typehints import LogEntry, SystemLogEntry, UploadEntry from trackio.ui import fns from trackio.ui.components.colored_checkbox import ColoredCheckboxGroup from trackio.ui.files import files_page from trackio.ui.helpers.run_selection import RunSelection from trackio.ui.media_page import media_page from trackio.ui.run_detail import run_detail_page from trackio.ui.runs import run_page from trackio.ui.system_page import system_page INSTRUCTIONS_SPACES = """ ## Start logging with Trackio 🤗 To start logging to this Trackio dashboard, first make sure you have the Trackio library installed. You can do this by running: ```bash pip install trackio ``` Then, start logging to this Trackio dashboard by passing in the `space_id` to `trackio.init()`: ```python import trackio trackio.init(project="my-project", space_id="{}") ``` Then call `trackio.log()` to log metrics. ```python for i in range(10): trackio.log({{"loss": 1/(i+1)}}) ``` Finally, call `trackio.finish()` to finish the run. ```python trackio.finish() ``` """ INSTRUCTIONS_LOCAL = """ ## Start logging with Trackio 🤗 You can create a new project by calling `trackio.init()`: ```python import trackio trackio.init(project="my-project") ``` Then call `trackio.log()` to log metrics. ```python for i in range(10): trackio.log({"loss": 1/(i+1)}) ``` Finally, call `trackio.finish()` to finish the run. ```python trackio.finish() ``` Read the [Trackio documentation](https://huggingface.co/docs/trackio/en/index) for more examples. """ def get_runs(project) -> list[str]: if not project: return [] return SQLiteStorage.get_runs(project) def upload_db_to_space( project: str, uploaded_db: gr.FileData, hf_token: str | None ) -> None: """ Uploads the database of a local Trackio project to a Hugging Face Space. """ fns.check_hf_token_has_write_access(hf_token) db_project_path = SQLiteStorage.get_project_db_path(project) os.makedirs(os.path.dirname(db_project_path), exist_ok=True) shutil.copy(uploaded_db["path"], db_project_path) def get_available_metrics(project: str, runs: list[str]) -> list[str]: """Get all available metrics across all runs for x-axis selection.""" if not project or not runs: return ["step", "time"] all_metrics = set() for run in runs: metrics = SQLiteStorage.get_logs(project, run) if metrics: df = pd.DataFrame(metrics) numeric_cols = df.select_dtypes(include="number").columns numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS] all_metrics.update(numeric_cols) all_metrics.add("step") all_metrics.add("time") sorted_metrics = utils.sort_metrics_by_prefix(list(all_metrics)) result = ["step", "time"] for metric in sorted_metrics: if metric not in result: result.append(metric) return result @dataclass class MediaData: caption: str | None file_path: str type: str def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]: media_by_key: dict[str, list[MediaData]] = {} logs = sorted(logs, key=lambda x: x.get("step", 0)) for log in logs: for key, value in log.items(): if isinstance(value, dict): type = value.get("_type") if ( type == TrackioImage.TYPE or type == TrackioVideo.TYPE or type == TrackioAudio.TYPE ): if key not in media_by_key: media_by_key[key] = [] try: media_data = MediaData( file_path=utils.MEDIA_DIR / value.get("file_path"), type=type, caption=value.get("caption"), ) media_by_key[key].append(media_data) except Exception as e: print(f"Media currently unavailable: {key}: {e}") return media_by_key def load_run_data( project: str | None, run: str | None, smoothing_granularity: int = 0, x_axis: str = "step", log_scale_x: bool = False, log_scale_y: bool = False, ) -> tuple[pd.DataFrame, dict]: if not project or not run: return None, None logs = SQLiteStorage.get_logs(project, run) if not logs: return None, None media = extract_media(logs) df = pd.DataFrame(logs) if "step" not in df.columns: df["step"] = range(len(df)) if x_axis == "time" and "timestamp" in df.columns: df["timestamp"] = pd.to_datetime(df["timestamp"]) first_timestamp = df["timestamp"].min() df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds() x_column = "time" elif x_axis == "step": x_column = "step" else: x_column = x_axis if log_scale_x and x_column in df.columns: x_vals = df[x_column] if (x_vals <= 0).any(): df[x_column] = np.log10(np.maximum(x_vals, 0) + 1) else: df[x_column] = np.log10(x_vals) if log_scale_y: numeric_cols = df.select_dtypes(include="number").columns y_cols = [ c for c in numeric_cols if c not in utils.RESERVED_KEYS and c != x_column ] for y_col in y_cols: if y_col in df.columns: y_vals = df[y_col] if (y_vals <= 0).any(): df[y_col] = np.log10(np.maximum(y_vals, 0) + 1) else: df[y_col] = np.log10(y_vals) if smoothing_granularity > 0: numeric_cols = df.select_dtypes(include="number").columns numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS] df_original = df.copy() df_original["run"] = run df_original["data_type"] = "original" df_smoothed = df.copy() window_size = max(3, min(smoothing_granularity, len(df))) df_smoothed[numeric_cols] = ( df_smoothed[numeric_cols] .rolling(window=window_size, center=True, min_periods=1) .mean() ) df_smoothed["run"] = f"{run}_smoothed" df_smoothed["data_type"] = "smoothed" combined_df = pd.concat([df_original, df_smoothed], ignore_index=True) combined_df["x_axis"] = x_column return combined_df, media else: df["run"] = run df["data_type"] = "original" df["x_axis"] = x_column return df, media def refresh_runs( project: str | None, filter_text: str | None, selection: RunSelection, selected_runs_from_url: list[str] | None = None, ): if project is None: runs: list[str] = [] else: runs = get_runs(project) if filter_text: runs = [r for r in runs if filter_text in r] preferred = None if selected_runs_from_url: preferred = [r for r in runs if r in selected_runs_from_url] did_change = selection.update_choices(runs, preferred) return ( fns.run_checkbox_update(selection) if did_change else gr.skip(), gr.Textbox(label=f"Runs ({len(runs)})"), selection, ) def generate_embed(project: str, metrics: str, selection: RunSelection) -> str: return utils.generate_embed_code(project, metrics, selection.selected) def update_x_axis_choices(project, selection): """Update x-axis dropdown choices based on available metrics.""" runs = selection.selected available_metrics = get_available_metrics(project, runs) return gr.Dropdown( label="X-axis", choices=available_metrics, value="step", ) def toggle_timer(cb_value): if cb_value: return gr.Timer(active=True) else: return gr.Timer(active=False) def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None: """ Uploads media files to a Trackio dashboard. Each entry in the list is a tuple of the project, run, and media file to be uploaded. Also handles uplaoding project-level files to the project's files directory (if the run and step are not provided). """ fns.check_hf_token_has_write_access(hf_token) for upload in uploads: media_path = get_project_media_path( project=upload["project"], run=upload["run"], step=upload["step"], relative_path=upload["relative_path"], ) shutil.copy(upload["uploaded_file"]["path"], media_path) def log( project: str, run: str, metrics: dict[str, Any], step: int | None, hf_token: str | None, ) -> None: """ Note: this method is not used in the latest versions of Trackio (replaced by bulk_log) but is kept for backwards compatibility for users who are connecting to a newer version of a Trackio Spaces dashboard with an older version of Trackio installed locally. """ fns.check_hf_token_has_write_access(hf_token) SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step) def bulk_log( logs: list[LogEntry], hf_token: str | None, ) -> None: """ Logs a list of metrics to a Trackio dashboard. Each entry in the list is a dictionary of the project, run, a dictionary of metrics, and optionally, a step and config. """ fns.check_hf_token_has_write_access(hf_token) logs_by_run = {} for log_entry in logs: key = (log_entry["project"], log_entry["run"]) if key not in logs_by_run: logs_by_run[key] = {"metrics": [], "steps": [], "config": None} logs_by_run[key]["metrics"].append(log_entry["metrics"]) logs_by_run[key]["steps"].append(log_entry.get("step")) if log_entry.get("config") and logs_by_run[key]["config"] is None: logs_by_run[key]["config"] = log_entry["config"] for (project, run), data in logs_by_run.items(): SQLiteStorage.bulk_log( project=project, run=run, metrics_list=data["metrics"], steps=data["steps"], config=data["config"], ) def bulk_log_system( logs: list[SystemLogEntry], hf_token: str | None, ) -> None: """ Logs system metrics (GPU, etc.) to a Trackio dashboard. These metrics use timestamps instead of steps. """ fns.check_hf_token_has_write_access(hf_token) logs_by_run = {} for log_entry in logs: key = (log_entry["project"], log_entry["run"]) if key not in logs_by_run: logs_by_run[key] = {"metrics": [], "timestamps": []} logs_by_run[key]["metrics"].append(log_entry["metrics"]) logs_by_run[key]["timestamps"].append(log_entry.get("timestamp")) for (project, run), data in logs_by_run.items(): SQLiteStorage.bulk_log_system( project=project, run=run, metrics_list=data["metrics"], timestamps=data["timestamps"], ) def get_metric_values( project: str, run: str, metric_name: str, ) -> list[dict]: """ Get all values for a specific metric in a project/run. Returns a list of dictionaries with timestamp, step, and value. """ return SQLiteStorage.get_metric_values(project, run, metric_name) def get_runs_for_project( project: str, ) -> list[str]: """ Get all runs for a given project. Returns a list of run names. """ return SQLiteStorage.get_runs(project) def get_metrics_for_run( project: str, run: str, ) -> list[str]: """ Get all metrics for a given project and run. Returns a list of metric names. """ return SQLiteStorage.get_all_metrics_for_run(project, run) def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]: """ Filter metrics using regex pattern. Args: metrics: List of metric names to filter filter_pattern: Regex pattern to match against metric names Returns: List of metric names that match the pattern """ if not filter_pattern.strip(): return metrics try: pattern = re.compile(filter_pattern, re.IGNORECASE) return [metric for metric in metrics if pattern.search(metric)] except re.error: return [ metric for metric in metrics if filter_pattern.lower() in metric.lower() ] def get_all_projects() -> list[str]: """ Get all project names. Returns a list of project names. """ return SQLiteStorage.get_projects() def get_project_summary(project: str) -> dict: """ Get a summary of a project including number of runs and recent activity. Args: project: Project name Returns: Dictionary with project summary information """ runs = SQLiteStorage.get_runs(project) if not runs: return {"project": project, "num_runs": 0, "runs": [], "last_activity": None} last_steps = SQLiteStorage.get_max_steps_for_runs(project) return { "project": project, "num_runs": len(runs), "runs": runs, "last_activity": max(last_steps.values()) if last_steps else None, } def get_run_summary(project: str, run: str) -> dict: """ Get a summary of a specific run including metrics and configuration. Args: project: Project name run: Run name Returns: Dictionary with run summary information """ logs = SQLiteStorage.get_logs(project, run) metrics = SQLiteStorage.get_all_metrics_for_run(project, run) if not logs: return { "project": project, "run": run, "num_logs": 0, "metrics": [], "config": None, "last_step": None, } df = pd.DataFrame(logs) config = logs[0].get("config") if logs else None last_step = df["step"].max() if "step" in df.columns else len(logs) - 1 return { "project": project, "run": run, "num_logs": len(logs), "metrics": metrics, "config": config, "last_step": last_step, } def configure(request: gr.Request): sidebar_param = request.query_params.get("sidebar") match sidebar_param: case "collapsed": sidebar = gr.Sidebar(open=False, visible=True) case "hidden": sidebar = gr.Sidebar(open=False, visible=False) case _: sidebar = gr.Sidebar(open=True, visible=True) metrics_param = request.query_params.get("metrics", "") runs_param = request.query_params.get("runs", "") selected_runs = runs_param.split(",") if runs_param else [] navbar_param = request.query_params.get("navbar") x_min_param = request.query_params.get("xmin") x_max_param = request.query_params.get("xmax") x_min = float(x_min_param) if x_min_param is not None else None x_max = float(x_max_param) if x_max_param is not None else None smoothing_param = request.query_params.get("smoothing") smoothing_value = int(smoothing_param) if smoothing_param is not None else 10 match navbar_param: case "hidden": navbar = gr.Navbar(visible=False) case _: navbar = gr.Navbar(visible=True) return ( [], sidebar, metrics_param, selected_runs, navbar, [x_min, x_max], smoothing_value, ) CSS = """ .dark .caption-label { color: white; } .accent-link { font-weight: bold; } .media-gallery .fixed-height { min-height: 275px; } .media-group, .media-group > div { background: none; } .media-group .tabs { padding: 0.5em; } .media-tab { max-height: 500px; overflow-y: scroll; } .media-audio-accordion > button { border-bottom-width: 1px; padding-bottom: 3px; } .media-audio-item { border-width: 1px !important; border-radius: 0.5em; } .media-audio-row { gap: 0.25em; margin-bottom: 0.25em; } .tab-like-container { visibility: hidden; } /* Runs table styling - prevent text overflow and distortion */ .dataframe table { table-layout: fixed; width: 100%; } .dataframe td, .dataframe th { overflow: hidden; text-overflow: ellipsis; white-space: nowrap; padding: 8px 12px; } .dataframe td:hover, .dataframe th:hover { overflow: visible; white-space: normal; word-wrap: break-word; position: relative; z-index: 10; } """ HEAD = """ """ gr.set_static_paths(paths=[utils.MEDIA_DIR]) with gr.Blocks(title="Trackio Dashboard") as demo: with gr.Sidebar(open=False) as sidebar: logo = fns.create_logo() project_dd = fns.create_project_dropdown() embed_code = gr.Code( label="Embed this view", max_lines=2, lines=2, language="html", visible=bool(os.environ.get("SPACE_HOST")), ) with gr.Group(): run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...") run_group_by_dd = gr.Dropdown(label="Group by...", choices=[], value=None) grouped_runs_panel = gr.Group(visible=False) run_cb = ColoredCheckboxGroup(choices=[], colors=[], label="Runs") gr.HTML("