Spaces:
Running
Running
| import math | |
| import os | |
| import re | |
| import secrets | |
| import time | |
| import warnings | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any | |
| from urllib.parse import quote, urlencode | |
| import huggingface_hub | |
| import numpy as np | |
| from huggingface_hub.constants import HF_HOME | |
| if TYPE_CHECKING: | |
| from trackio.commit_scheduler import CommitScheduler | |
| from trackio.dummy_commit_scheduler import DummyCommitScheduler | |
# Key names that this module treats as reserved.
# NOTE(review): presumably these may not be used as user metric names — confirm against callers.
RESERVED_KEYS = ["project", "run", "timestamp", "step", "time", "metrics"]
# Directory (next to this module) that holds the bundled logo image files.
TRACKIO_LOGO_DIR = Path(__file__).parent / "assets"
def get_logo_urls() -> dict[str, str]:
    """
    Return the URLs of the light and dark Trackio logos.

    Each URL comes from its environment variable (TRACKIO_LOGO_LIGHT_URL /
    TRACKIO_LOGO_DARK_URL) when that variable is set; otherwise it points at
    the bundled PNG inside TRACKIO_LOGO_DIR through the "/file?path=" route.

    Returns:
        Dictionary with the keys "light" and "dark".
    """

    def _bundled_logo_url(file_name: str) -> str:
        # The path is percent-quoted so it can travel as a query-string value.
        return f"/file?path={quote(str(TRACKIO_LOGO_DIR / file_name))}"

    return {
        "light": os.environ.get(
            "TRACKIO_LOGO_LIGHT_URL",
            _bundled_logo_url("trackio_logo_type_light_transparent.png"),
        ),
        "dark": os.environ.get(
            "TRACKIO_LOGO_DARK_URL",
            _bundled_logo_url("trackio_logo_type_dark_transparent.png"),
        ),
    }
def order_metrics_by_plot_preference(metrics: list[str]) -> tuple[list[str], dict]:
    """
    Order metrics based on TRACKIO_PLOT_ORDER environment variable and group them.

    TRACKIO_PLOT_ORDER is a comma-separated list of patterns. A pattern is
    either an exact metric name ("train/loss") or a group wildcard ("train/*").
    Metrics without a "/" belong to the implicit "charts" group.

    Args:
        metrics: List of metric names to order and group
    Returns:
        Tuple of (ordered_group_names, grouped_metrics_dict). Every group maps to
        {"direct_metrics": [...], "subgroups": {subprefix: [...]}}.
    """
    plot_order = [
        item.strip()
        for item in os.environ.get("TRACKIO_PLOT_ORDER", "").split(",")
        if item.strip()
    ]
    no_match_priority = len(plot_order)

    def group_of(name: str) -> str:
        # The group is the text before the first "/", or "charts" when there is none.
        return name.split("/")[0] if "/" in name else "charts"

    def first_pattern_index_for_group(group_name: str) -> int:
        # Index of the first pattern that targets this group, else "no match".
        for index, pattern in enumerate(plot_order):
            if group_of(pattern) == group_name:
                return index
        return no_match_priority

    def get_metric_priority(metric: str) -> tuple[float, float, str]:
        if not plot_order:
            # No preference configured: fall back to plain alphabetical order.
            return (float("inf"), float("inf"), metric)
        within_group_priority = no_match_priority
        for index, pattern in enumerate(plot_order):
            if pattern == metric:
                # An exact match always wins over any wildcard match.
                within_group_priority = index
                break
            if (
                pattern.endswith("/*")
                and within_group_priority == no_match_priority
                and metric.startswith(pattern[:-2] + "/")
            ):
                # Wildcard matches rank after every possible exact match; only
                # the first matching wildcard counts.
                within_group_priority = index + len(plot_order)
        return (
            first_pattern_index_for_group(group_of(metric)),
            within_group_priority,
            metric,
        )

    def get_group_priority(group_name: str) -> tuple[float, str]:
        if not plot_order:
            return (float("inf"), group_name)
        return (first_pattern_index_for_group(group_name), group_name)

    result: dict[str, dict] = {}
    for metric in metrics:
        parts = metric.split("/")
        group_name = parts[0] if len(parts) > 1 else "charts"
        group = result.setdefault(group_name, {"direct_metrics": [], "subgroups": {}})
        if len(parts) <= 2:
            group["direct_metrics"].append(metric)
        else:
            group["subgroups"].setdefault(parts[1], []).append(metric)

    for group in result.values():
        group["direct_metrics"].sort(key=get_metric_priority)
        for subgroup_metrics in group["subgroups"].values():
            subgroup_metrics.sort(key=get_metric_priority)

    # Drop "charts" only when it is completely empty. The previous check looked
    # at direct metrics alone, which discarded metrics named "charts/<sub>/<name>"
    # because those live exclusively in the subgroups.
    charts = result.get("charts")
    if charts is not None and not charts["direct_metrics"] and not charts["subgroups"]:
        del result["charts"]

    ordered_groups = sorted(result.keys(), key=get_group_priority)
    return ordered_groups, result
def persistent_storage_enabled() -> bool:
    """Return True when the PERSISTANT_STORAGE_ENABLED environment variable is exactly "true"."""
    # The variable name is misspelled ("PERSISTANT"); that spelling is the one
    # actually read here, so it has to stay as it is.
    flag = os.environ.get("PERSISTANT_STORAGE_ENABLED")
    return flag == "true"
def _get_trackio_dir() -> Path:
    """
    Resolve the directory Trackio stores its files in.

    Order of precedence: "/data/trackio" when persistent storage is enabled,
    then the TRACKIO_DIR environment variable (if non-empty), then a "trackio"
    folder inside HF_HOME.
    """
    if persistent_storage_enabled():
        return Path("/data/trackio")
    custom_dir = os.environ.get("TRACKIO_DIR")
    return Path(custom_dir) if custom_dir else Path(HF_HOME) / "trackio"
# Root directory for Trackio's local files; resolved once, when this module is imported.
TRACKIO_DIR = _get_trackio_dir()
# "media" subdirectory of TRACKIO_DIR.
# NOTE(review): presumably where logged media files are written — confirm against callers.
MEDIA_DIR = TRACKIO_DIR / "media"
def get_or_create_project_hash(project: str) -> str:
    """
    Return the hash stored for a project, creating and saving one if needed.

    The hash lives in "<TRACKIO_DIR>/<project>.hash". When that file is missing,
    a fresh URL-safe random token is generated and written to it.

    Args:
        project: The project name used to build the hash file name.
    Returns:
        The stored (stripped) hash, or the newly generated one.
    """
    hash_file = TRACKIO_DIR / f"{project}.hash"
    if not hash_file.exists():
        new_hash = secrets.token_urlsafe(8)
        # The storage directory may not exist yet on first use.
        TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
        hash_file.write_text(new_hash)
        return new_hash
    return hash_file.read_text().strip()
| def generate_readable_name(used_names: list[str], space_id: str | None = None) -> str: | |
| """ | |
| Generates a random, readable name like "dainty-sunset-0". | |
| If space_id is provided, generates username-timestamp format instead. | |
| """ | |
| if space_id is not None: | |
| username = _get_default_namespace() | |
| timestamp = int(time.time()) | |
| return f"{username}-{timestamp}" | |
| adjectives = [ | |
| "dainty", | |
| "brave", | |
| "calm", | |
| "eager", | |
| "fancy", | |
| "gentle", | |
| "happy", | |
| "jolly", | |
| "kind", | |
| "lively", | |
| "merry", | |
| "nice", | |
| "proud", | |
| "quick", | |
| "hugging", | |
| "silly", | |
| "tidy", | |
| "witty", | |
| "zealous", | |
| "bright", | |
| "shy", | |
| "bold", | |
| "clever", | |
| "daring", | |
| "elegant", | |
| "faithful", | |
| "graceful", | |
| "honest", | |
| "inventive", | |
| "jovial", | |
| "keen", | |
| "lucky", | |
| "modest", | |
| "noble", | |
| "optimistic", | |
| "patient", | |
| "quirky", | |
| "resourceful", | |
| "sincere", | |
| "thoughtful", | |
| "upbeat", | |
| "valiant", | |
| "warm", | |
| "youthful", | |
| "zesty", | |
| "adventurous", | |
| "breezy", | |
| "cheerful", | |
| "delightful", | |
| "energetic", | |
| "fearless", | |
| "glad", | |
| "hopeful", | |
| "imaginative", | |
| "joyful", | |
| "kindly", | |
| "luminous", | |
| "mysterious", | |
| "neat", | |
| "outgoing", | |
| "playful", | |
| "radiant", | |
| "spirited", | |
| "tranquil", | |
| "unique", | |
| "vivid", | |
| "wise", | |
| "zany", | |
| "artful", | |
| "bubbly", | |
| "charming", | |
| "dazzling", | |
| "earnest", | |
| "festive", | |
| "gentlemanly", | |
| "hearty", | |
| "intrepid", | |
| "jubilant", | |
| "knightly", | |
| "lively", | |
| "magnetic", | |
| "nimble", | |
| "orderly", | |
| "peaceful", | |
| "quick-witted", | |
| "robust", | |
| "sturdy", | |
| "trusty", | |
| "upstanding", | |
| "vibrant", | |
| "whimsical", | |
| ] | |
| nouns = [ | |
| "sunset", | |
| "forest", | |
| "river", | |
| "mountain", | |
| "breeze", | |
| "meadow", | |
| "ocean", | |
| "valley", | |
| "sky", | |
| "field", | |
| "cloud", | |
| "star", | |
| "rain", | |
| "leaf", | |
| "stone", | |
| "flower", | |
| "bird", | |
| "tree", | |
| "wave", | |
| "trail", | |
| "island", | |
| "desert", | |
| "hill", | |
| "lake", | |
| "pond", | |
| "grove", | |
| "canyon", | |
| "reef", | |
| "bay", | |
| "peak", | |
| "glade", | |
| "marsh", | |
| "cliff", | |
| "dune", | |
| "spring", | |
| "brook", | |
| "cave", | |
| "plain", | |
| "ridge", | |
| "wood", | |
| "blossom", | |
| "petal", | |
| "root", | |
| "branch", | |
| "seed", | |
| "acorn", | |
| "pine", | |
| "willow", | |
| "cedar", | |
| "elm", | |
| "falcon", | |
| "eagle", | |
| "sparrow", | |
| "robin", | |
| "owl", | |
| "finch", | |
| "heron", | |
| "crane", | |
| "duck", | |
| "swan", | |
| "fox", | |
| "wolf", | |
| "bear", | |
| "deer", | |
| "moose", | |
| "otter", | |
| "beaver", | |
| "lynx", | |
| "hare", | |
| "badger", | |
| "butterfly", | |
| "bee", | |
| "ant", | |
| "beetle", | |
| "dragonfly", | |
| "firefly", | |
| "ladybug", | |
| "moth", | |
| "spider", | |
| "worm", | |
| "coral", | |
| "kelp", | |
| "shell", | |
| "pebble", | |
| "face", | |
| "boulder", | |
| "cobble", | |
| "sand", | |
| "wavelet", | |
| "tide", | |
| "current", | |
| "mist", | |
| ] | |
| number = 0 | |
| name = f"{adjectives[0]}-{nouns[0]}-{number}" | |
| while name in used_names: | |
| number += 1 | |
| adjective = adjectives[number % len(adjectives)] | |
| noun = nouns[number % len(nouns)] | |
| name = f"{adjective}-{noun}-{number}" | |
| return name | |
def is_in_notebook():
    """
    Detect if code is running in a notebook environment (Jupyter, Colab, etc.).

    Returns:
        True when an IPython shell is active and either its class name is one of
        the recognised names or its string form mentions "google.colab";
        False otherwise, including when IPython is not installed.
    """
    try:
        from IPython import get_ipython

        shell = get_ipython()
        if shell is None:
            # IPython is importable but no interactive shell is running.
            return False
        # "ZMQInteractiveShell" is the kernel-backed shell of Jupyter notebook/lab.
        # NOTE(review): "Shell" looks like the Colab shell class rather than the
        # plain IPython terminal (which reports "TerminalInteractiveShell") — confirm.
        recognised = shell.__class__.__name__ in ("ZMQInteractiveShell", "Shell")
        return recognised or "google.colab" in str(shell)
    except ImportError:
        return False
def block_main_thread_until_keyboard_interrupt():
    """
    Sleep in short intervals until the calling thread is interrupted.

    Returns after a KeyboardInterrupt or OSError is raised while sleeping,
    having printed a short shutdown notice.
    """
    poll_interval_seconds = 0.1
    try:
        while True:
            time.sleep(poll_interval_seconds)
    except (KeyboardInterrupt, OSError):
        print("Keyboard interruption in main thread... closing dashboard.")
def simplify_column_names(columns: list[str]) -> dict[str, str]:
    """
    Map every column name to a short, unique, simplified name.

    A simplified name is the first 10 characters that remain after removing
    everything except ASCII letters, digits and "/". A name with nothing left
    becomes "col_<count of names assigned so far>". Clashes are resolved by
    appending "_1", "_2", ... to the base name.

    Args:
        columns: List of original column names
    Returns:
        Dictionary mapping original column names to simplified names
    """
    mapping = {}
    taken = set()
    for column in columns:
        cleaned = re.sub(r"[^a-zA-Z0-9/]", "", column)
        stem = cleaned[:10] if cleaned else f"col_{len(taken)}"
        candidate = stem
        counter = 0
        # Keep bumping the numeric suffix until the name is free.
        while candidate in taken:
            counter += 1
            candidate = f"{stem}_{counter}"
        mapping[column] = candidate
        taken.add(candidate)
    return mapping
def print_dashboard_instructions(project: str) -> None:
    """
    Print how to open the Trackio dashboard for a project.

    The terminal command is wrapped in ANSI escape codes (bold, 256-colour
    orange) and followed by the equivalent Python call.

    Args:
        project: The name of the project to show dashboard for.
    """
    ORANGE = "\033[38;5;208m"
    BOLD = "\033[1m"
    RESET = "\033[0m"
    lines = (
        "* View dashboard by running in your terminal:",
        f'{BOLD}{ORANGE}trackio show --project "{project}"{RESET}',
        f'* or by running in Python: trackio.show(project="{project}")',
    )
    for line in lines:
        print(line)
| def preprocess_space_and_dataset_ids( | |
| space_id: str | None, | |
| dataset_id: str | None, | |
| bucket_id: str | None = None, | |
| ) -> tuple[str | None, str | None, str | None]: | |
| """ | |
| Preprocesses the Space and Bucket names to ensure they are valid | |
| "username/name" format. When space_id is provided and bucket_id is not | |
| explicitly set, auto-generates a bucket_id. | |
| """ | |
| if space_id is not None and "/" not in space_id: | |
| username = _get_default_namespace() | |
| space_id = f"{username}/{space_id}" | |
| if dataset_id is not None: | |
| warnings.warn( | |
| "`dataset_id` is deprecated. Use `bucket_id` instead.", | |
| DeprecationWarning, | |
| stacklevel=3, | |
| ) | |
| if dataset_id is not None and "/" not in dataset_id: | |
| username = _get_default_namespace() | |
| dataset_id = f"{username}/{dataset_id}" | |
| if bucket_id is not None and "/" not in bucket_id: | |
| username = _get_default_namespace() | |
| bucket_id = f"{username}/{bucket_id}" | |
| if space_id is not None and dataset_id is None and bucket_id is None: | |
| bucket_id = f"{space_id}-bucket" | |
| return space_id, dataset_id, bucket_id | |
def fibo():
    """Yield the Fibonacci numbers 1, 1, 2, 3, 5, 8, ... without end (used as backoff steps)."""
    previous, current = 0, 1
    while True:
        yield current
        previous, current = current, previous + current
def format_timestamp(timestamp_str):
    """
    Turn an ISO timestamp into a relative description such as "3 minutes ago".

    A trailing "Z" is read as UTC, and a timestamp without timezone information
    is assumed to be UTC. Anything empty, missing or unparsable gives "Unknown".
    """
    if not timestamp_str or is_missing_value(timestamp_str):
        return "Unknown"
    try:
        moment = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
        if moment.tzinfo is None:
            moment = moment.replace(tzinfo=timezone.utc)
        elapsed = int((datetime.now(timezone.utc) - moment).total_seconds())

        # Guard clauses from the smallest unit upwards; timestamps in the
        # future yield a negative value and therefore also read "Just now".
        if elapsed < 60:
            return "Just now"
        if elapsed < 3600:
            minutes = elapsed // 60
            return f"{minutes} minute{'s' if minutes != 1 else ''} ago"
        if elapsed < 86400:
            hours = elapsed // 3600
            return f"{hours} hour{'s' if hours != 1 else ''} ago"
        days = elapsed // 86400
        return f"{days} day{'s' if days != 1 else ''} ago"
    except Exception:
        return "Unknown"
# Colours used for runs when TRACKIO_COLOR_PALETTE does not supply any.
DEFAULT_COLOR_PALETTE = [
    "#A8769B",
    "#E89957",
    "#3B82F6",
    "#10B981",
    "#EF4444",
    "#8B5CF6",
    "#14B8A6",
    "#F59E0B",
    "#EC4899",
    "#06B6D4",
]


def get_color_palette() -> list[str]:
    """
    Get the color palette from environment variable or use default.

    TRACKIO_COLOR_PALETTE is read as a comma-separated list of colours.
    Surrounding whitespace is stripped and blank entries (from "a,,b" or a
    trailing comma) are skipped. If the variable is unset or holds no usable
    entry, the default palette is used.

    Returns:
        A new list of colour strings; modifying it does not affect
        DEFAULT_COLOR_PALETTE.
    """
    env_palette = os.environ.get("TRACKIO_COLOR_PALETTE")
    if env_palette:
        colors = [color.strip() for color in env_palette.split(",") if color.strip()]
        if colors:
            return colors
    # Hand out a copy so callers cannot mutate the shared module-level default.
    return list(DEFAULT_COLOR_PALETTE)
| def get_color_mapping( | |
| runs: list[str], smoothing: bool, color_palette: list[str] | None = None | |
| ) -> dict[str, str]: | |
| """Generate color mapping for runs, with transparency for original data when smoothing is enabled.""" | |
| if color_palette is None: | |
| color_palette = get_color_palette() | |
| color_map = {} | |
| for i, run in enumerate(runs): | |
| base_color = color_palette[i % len(color_palette)] | |
| if smoothing: | |
| color_map[run] = base_color + "4D" | |
| color_map[f"{run}_smoothed"] = base_color | |
| else: | |
| color_map[run] = base_color | |
| return color_map | |
def is_missing_value(value: object) -> bool:
    """
    Tell whether a value counts as missing, i.e. is None or a NaN number.

    Strings are never missing (not even "nan"), and values that cannot be
    checked for NaN are treated as present.

    Args:
        value: Any object.
    Returns:
        True for None and NaN, False for everything else.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return False
    try:
        return bool(math.isnan(value))
    except (TypeError, ValueError, OverflowError):
        # TypeError/ValueError: not convertible to a float at all.
        # OverflowError: an int too large to become a float — certainly not NaN.
        return False
def _to_records_with_columns(
    data: object,
) -> tuple[list[dict[str, Any]], list[str], Any]:
    """
    Convert the input into a list of row dictionaries plus column names.

    Objects that offer to_dict(orient="records") are converted through it; the
    column names then come from the object's "columns" attribute (stringified)
    and the object's class is returned as third element. If that call fails,
    the input is handled like any other value. Plain lists are copied row by
    row, with the columns taken from the first row.

    Returns:
        Tuple of (records, column names, originating class or None).
    Raises:
        TypeError: If data is neither convertible via to_dict nor a list.
    """
    if hasattr(data, "to_dict"):
        converted = None
        conversion_ok = True
        try:
            converted = data.to_dict(orient="records")
        except Exception:
            # Best effort only: an unsuitable to_dict() falls through to the
            # list / error handling below.
            conversion_ok = False
        if conversion_ok:
            column_names = [str(name) for name in getattr(data, "columns", [])]
            return [dict(entry) for entry in converted], column_names, data.__class__

    if not isinstance(data, list):
        raise TypeError(
            "downsample() expects a list of row dictionaries or a dataframe-like object."
        )
    row_dicts = [dict(entry) for entry in data]
    column_names = list(row_dicts[0].keys()) if row_dicts else []
    return row_dicts, column_names, None
def _restore_records_shape(
    records: list[dict[str, Any]],
    columns: list[str],
    dataframe_class: Any,
) -> Any:
    """
    Give the records back in the container type they originally came in.

    When dataframe_class is set and provides from_records, the records are
    rebuilt through it with the given columns; otherwise the list of row
    dictionaries is returned unchanged.
    """
    if dataframe_class is not None and hasattr(dataframe_class, "from_records"):
        return dataframe_class.from_records(records, columns=columns)
    return records
def downsample(
    data: object,
    x: str,
    y: str,
    color: str | None,
    x_lim: tuple[float | None, float | None] | None = None,
) -> tuple[Any, tuple[float, float] | None]:
    """
    Downsample the data to reduce the number of points plotted.
    Also updates the x-axis limits to the data min/max if either of the x-axis limits are None.

    Rows are processed per colour group. A group with fewer than 500 rows inside
    the x range is kept in full; a larger group is split into 100 equal-width
    x bins and only the rows with the smallest and the largest y value of each
    bin are kept. Only the x, y and (if present) colour columns are returned.

    Args:
        data: A list of row dictionaries or a dataframe-like object to downsample.
        x: The column name to use for the x-axis.
        y: The column name to use for the y-axis.
        color: The column name to use for the color.
        x_lim: The x-axis limits to use.
    Returns:
        A tuple containing the downsampled data (same container kind as the input,
        as restored by _restore_records_shape) and the updated x-axis limits
        (None when x_lim was None).
    """
    rows, columns, dataframe_class = _to_records_with_columns(data)
    if not rows:
        if x_lim is not None:
            # No data to take limits from: a None (or otherwise falsy) limit becomes 0.
            x_lim = (x_lim[0] or 0, x_lim[1] or 0)
        return _restore_records_shape([], columns, dataframe_class), x_lim
    columns_to_keep = [x, y]
    # The colour column is only kept when at least one row actually has it.
    if color is not None and any(color in row for row in rows):
        columns_to_keep.append(color)
    # Project every row onto the kept columns; a column missing from a row becomes None.
    filtered_rows = [
        {column: row.get(column) for column in columns_to_keep} for row in rows
    ]
    data_x_values = [row[x] for row in filtered_rows]
    data_x_min = min(data_x_values)
    data_x_max = max(data_x_values)
    if x_lim is not None:
        # Fill in whichever limit was left open with the data's own extreme.
        x_min, x_max = x_lim
        if x_min is None:
            x_min = data_x_min
        if x_max is None:
            x_max = data_x_max
        updated_x_lim = (x_min, x_max)
    else:
        updated_x_lim = None
    n_bins = 100
    # Colour value -> list of (index into filtered_rows, row).
    groups: dict[Any, list[tuple[int, dict[str, Any]]]] = {}
    if color is not None and color in columns_to_keep:
        for idx, row in enumerate(filtered_rows):
            groups.setdefault(row.get(color), []).append((idx, row))
    else:
        # Without a colour column all rows form one single group.
        groups[None] = list(enumerate(filtered_rows))
    downsampled_indices: list[int] = []
    for group_rows in groups.values():
        if not group_rows:
            continue
        group_rows = sorted(group_rows, key=lambda item: item[1][x])
        if updated_x_lim is not None:
            x_min, x_max = updated_x_lim
            before_point = [item for item in group_rows if item[1][x] < x_min]
            after_point = [item for item in group_rows if item[1][x] > x_max]
            group_rows = [item for item in group_rows if x_min <= item[1][x] <= x_max]
        else:
            before_point = after_point = None
            # group_rows is sorted by x, so its ends are the group's x extremes.
            x_min = group_rows[0][1][x]
            x_max = group_rows[-1][1][x]
        # Keep the closest row on each side outside the limits.
        # NOTE(review): presumably so a drawn line reaches the plot edges — confirm.
        if before_point:
            downsampled_indices.append(before_point[-1][0])
        if after_point:
            downsampled_indices.append(after_point[0][0])
        if not group_rows:
            continue
        if x_min == x_max:
            # Zero-width x range: binning is impossible, keep only the y extremes.
            min_y_idx = min(group_rows, key=lambda item: item[1][y])[0]
            max_y_idx = max(group_rows, key=lambda item: item[1][y])[0]
            if min_y_idx != max_y_idx:
                downsampled_indices.extend([min_y_idx, max_y_idx])
            else:
                downsampled_indices.append(min_y_idx)
            continue
        if len(group_rows) < 500:
            # Small groups are kept in full.
            downsampled_indices.extend(idx for idx, _ in group_rows)
            continue
        # n_bins + 1 edges give n_bins equal-width bins over [x_min, x_max].
        bins = np.linspace(x_min, x_max, n_bins + 1)
        binned_rows: dict[int, list[tuple[int, dict[str, Any]]]] = {}
        for idx, row in group_rows:
            # digitize returns 1-based positions; the clip keeps x == x_max in the last bin.
            bin_idx = int(
                np.clip(np.digitize(row[x], bins, right=False) - 1, 0, n_bins - 1)
            )
            binned_rows.setdefault(bin_idx, []).append((idx, row))
        for bin_rows in binned_rows.values():
            if not bin_rows:
                continue
            # Per bin, keep the rows with the lowest and the highest y value.
            min_y_idx = min(bin_rows, key=lambda item: item[1][y])[0]
            max_y_idx = max(bin_rows, key=lambda item: item[1][y])[0]
            downsampled_indices.append(min_y_idx)
            if min_y_idx != max_y_idx:
                downsampled_indices.append(max_y_idx)
    # De-duplicate and go back to the original row order.
    unique_indices = sorted(set(downsampled_indices))
    selected_rows = [filtered_rows[idx] for idx in unique_indices]
    if color is not None and color in columns_to_keep:
        # Emit the groups in order of first appearance, each sorted by x.
        grouped_rows: dict[Any, list[dict[str, Any]]] = {}
        group_order: list[Any] = []
        for row in selected_rows:
            group_key = row.get(color)
            if group_key not in grouped_rows:
                grouped_rows[group_key] = []
                group_order.append(group_key)
            grouped_rows[group_key].append(row)
        downsampled_rows = []
        for group_key in group_order:
            downsampled_rows.extend(
                sorted(grouped_rows[group_key], key=lambda row: row[x])
            )
    else:
        downsampled_rows = sorted(selected_rows, key=lambda row: row[x])
    return (
        _restore_records_shape(downsampled_rows, columns_to_keep, dataframe_class),
        updated_x_lim,
    )
def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
    """
    Flatten metrics into one list, ordered group by group, for dropdown/list display.

    The "charts" group (as produced by group_metrics_by_prefix) comes first,
    followed by the remaining groups in alphabetical order of their prefix.

    Args:
        metrics: List of metric names
    Returns:
        List of metric names sorted by prefix
    Example:
        Input: ["train/loss", "loss", "train/acc", "val/loss"]
        Output: ["loss", "train/acc", "train/loss", "val/loss"]
    """
    grouped = group_metrics_by_prefix(metrics)
    ordered = list(grouped.get("charts", []))
    for prefix in sorted(grouped):
        if prefix == "charts":
            continue
        ordered += grouped[prefix]
    return ordered
def group_metrics_by_prefix(metrics: list[str]) -> dict[str, list[str]]:
    """
    Group metrics by their prefix. Metrics without prefix go to 'charts' group.

    The prefix is the text before the first "/". Metrics are sorted
    alphabetically inside each group. The "charts" group of unprefixed metrics
    comes first, followed by the prefixes in alphabetical order. Metrics whose
    prefix is literally "charts" are appended to the "charts" group after the
    unprefixed metrics instead of replacing them.

    Args:
        metrics: List of metric names
    Returns:
        Dictionary with prefix names as keys and lists of metrics as values
    Example:
        Input: ["loss", "accuracy", "train/loss", "train/acc", "val/loss"]
        Output: {
            "charts": ["accuracy", "loss"],
            "train": ["train/acc", "train/loss"],
            "val": ["val/loss"]
        }
    """
    unprefixed: list[str] = []
    by_prefix: dict[str, list[str]] = {}
    for metric in metrics:
        if "/" in metric:
            by_prefix.setdefault(metric.split("/")[0], []).append(metric)
        else:
            unprefixed.append(metric)

    groups: dict[str, list[str]] = {}
    if unprefixed:
        groups["charts"] = sorted(unprefixed)
    for prefix in sorted(by_prefix):
        # setdefault + extend (rather than plain assignment) keeps the
        # unprefixed metrics when a prefix happens to be named "charts".
        groups.setdefault(prefix, []).extend(sorted(by_prefix[prefix]))
    return groups
| def get_sync_status(scheduler: "CommitScheduler | DummyCommitScheduler") -> int | None: | |
| """Get the sync status from the CommitScheduler in an integer number of minutes, or None if not synced yet.""" | |
| if getattr( | |
| scheduler, "last_push_time", None | |
| ): # DummyCommitScheduler doesn't have last_push_time | |
| time_diff = time.time() - scheduler.last_push_time | |
| return int(time_diff / 60) | |
| else: | |
| return None | |
def generate_share_url(
    project: str,
    metrics: str,
    selected_runs: list = None,
    hide_headers: bool = False,
) -> str:
    """
    Build the shareable URL of the Space for the current dashboard settings.

    The host is read from the SPACE_HOST environment variable; without it there
    is nothing to share and an empty string is returned. Only settings that are
    set end up as (URL-encoded) query parameters.

    Args:
        project: Project name, sent as "project" when non-empty.
        metrics: Metrics string, sent as "metrics" when not blank.
        selected_runs: Run names, sent comma-joined as "runs" when non-empty.
        hide_headers: When True, accordion, sidebar and navbar are set to "hidden".
    Returns:
        "https://<SPACE_HOST>?<query>", or "" when SPACE_HOST is not set.
    """
    host = os.environ.get("SPACE_HOST", "")
    if not host:
        return ""

    query: dict[str, str] = {}
    if project:
        query["project"] = project
    if metrics and metrics.strip():
        query["metrics"] = metrics
    if selected_runs:
        query["runs"] = ",".join(selected_runs)
    if hide_headers:
        query.update({"accordion": "hidden", "sidebar": "hidden", "navbar": "hidden"})
    return f"https://{host}?{urlencode(query)}"
def generate_embed_code(
    project: str,
    metrics: str,
    selected_runs: list = None,
    hide_headers: bool = False,
) -> str:
    """
    Build an <iframe> snippet that embeds the dashboard with the current settings.

    The iframe source is the URL produced by generate_share_url for the same
    arguments.

    Returns:
        The iframe HTML, or "" when no share URL is available.
    """
    share_url = generate_share_url(project, metrics, selected_runs, hide_headers)
    if share_url:
        return f'<iframe src="{share_url}" style="width:1600px; height:500px; border:0;"></iframe>'
    return ""
def serialize_values(metrics):
    """
    Make values JSON-compliant by replacing non-finite floats with strings.

    Dictionaries (keys are stringified), lists, tuples and sets (all returned as
    lists) are processed recursively, and numpy scalars are unwrapped to plain
    Python values first.

    Converts:
    - float('inf') -> "Infinity"
    - float('-inf') -> "-Infinity"
    - float('nan') -> "NaN"
    Example:
        {"loss": float('inf'), "accuracy": 0.95} -> {"loss": "Infinity", "accuracy": 0.95}
    """

    def _convert(item):
        if isinstance(item, dict):
            return {str(name): _convert(entry) for name, entry in item.items()}
        if isinstance(item, (list, tuple, set)):
            return [_convert(entry) for entry in item]
        if isinstance(item, np.generic):
            # Unwrap numpy scalars so the checks below see plain Python values.
            item = item.item()
        if isinstance(item, (bool, int)) or not isinstance(item, float):
            return item
        if math.isnan(item):
            return "NaN"
        if math.isinf(item):
            return "Infinity" if item > 0 else "-Infinity"
        return float(item)

    return _convert(metrics)
def deserialize_values(metrics):
    """
    Deserialize infinity and NaN string values back to their numeric forms.
    Only handles top-level string values.

    Converts:
    - "Infinity" -> float('inf')
    - "-Infinity" -> float('-inf')
    - "NaN" -> float('nan')
    Example:
        {"loss": "Infinity", "accuracy": 0.95} -> {"loss": float('inf'), "accuracy": 0.95}

    Args:
        metrics: Dictionary of metric values; any other input is returned unchanged.
    Returns:
        A new dictionary with the marker strings replaced, or the input itself
        when it is not a dictionary.
    """
    if not isinstance(metrics, dict):
        return metrics
    special_values = {
        "Infinity": float("inf"),
        "-Infinity": float("-inf"),
        "NaN": float("nan"),
    }
    result = {}
    for key, value in metrics.items():
        # Only strings can be markers. Checking the type first avoids comparing
        # array-like values with "==", which does not yield a plain bool (numpy
        # arrays raise when such a comparison is used as a condition).
        if isinstance(value, str):
            result[key] = special_values.get(value, value)
        else:
            result[key] = value
    return result
| def get_full_url( | |
| base_url: str, project: str | None, write_token: str, footer: bool = True | |
| ) -> str: | |
| params = [] | |
| if project: | |
| params.append(f"project={project}") | |
| params.append(f"write_token={write_token}") | |
| if not footer: | |
| params.append("footer=false") | |
| return base_url + "?" + "&".join(params) | |
def embed_url_in_notebook(url: str) -> None:
    """
    Display the given URL inside an iframe in the current notebook.

    Does nothing when IPython cannot be imported.

    Args:
        url: Address to load in the iframe.
    """
    try:
        from IPython.display import HTML, display

        iframe_markup = f'<div><iframe src="{url}" width="100%" height="1000px" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
        display(HTML(iframe_markup))
    except ImportError:
        # Without IPython there is nowhere to render the iframe.
        return None
def to_json_safe(obj):
    """
    Recursively convert an object into JSON-friendly builtin types.

    Strings, ints, floats, bools and None pass through unchanged; numpy scalars
    are unwrapped; dictionaries get string keys; lists, tuples and sets become
    lists. Objects with a callable to_dict() are converted through it, other
    objects through their public attributes (names not starting with "_").
    Anything else is turned into its str() form.
    """
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {str(name): to_json_safe(entry) for name, entry in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [to_json_safe(entry) for entry in obj]

    to_dict = getattr(obj, "to_dict", None)
    if callable(to_dict):
        return to_json_safe(to_dict())

    if hasattr(obj, "__dict__"):
        public_attributes = {}
        for name, entry in vars(obj).items():
            if name.startswith("_"):
                continue
            public_attributes[str(name)] = to_json_safe(entry)
        return public_attributes
    return str(obj)
| def get_space() -> str | None: | |
| """ | |
| Get the space ID ("user/space") if Trackio is running in a Space, or None if not. | |
| """ | |
| return os.environ.get("SPACE_ID") | |
| def ordered_subset(items: list[str], subset: list[str] | None) -> list[str]: | |
| subset_set = set(subset or []) | |
| return [item for item in items if item in subset_set] | |
def _get_default_namespace() -> str:
    """
    Return the default namespace, i.e. the "name" field of the whoami response
    for the current Hugging Face token.

    The whoami call is made with cache=True so that repeated lookups do not
    each hit the /whoami-v2 endpoint.
    """
    user_info = huggingface_hub.whoami(token=huggingface_hub.get_token(), cache=True)
    return user_info["name"]