| """The Media and Tables page for the Trackio UI.""" |
|
|
| import re |
| from dataclasses import dataclass |
|
|
| import gradio as gr |
| import pandas as pd |
|
|
| try: |
| import trackio.utils as utils |
| from trackio.media import TrackioAudio, TrackioImage, TrackioVideo |
| from trackio.sqlite_storage import SQLiteStorage |
| from trackio.table import Table |
| from trackio.ui import fns |
| from trackio.ui.components.colored_dropdown import ColoredDropdown |
| except ImportError: |
| import utils |
| from media import TrackioAudio, TrackioImage, TrackioVideo |
| from sqlite_storage import SQLiteStorage |
| from table import Table |
| from ui import fns |
| from ui.components.colored_dropdown import ColoredDropdown |
|
|
|
|
| def get_runs(project) -> list[str]: |
| if not project: |
| return [] |
| return SQLiteStorage.get_runs(project) |
|
|
|
|
| @dataclass |
| class MediaData: |
| caption: str | None |
| file_path: str |
| type: str |
|
|
|
|
def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
    """
    Collect the image, video and audio entries found in a list of log entries.

    Args:
        logs: Log entries, each a dict of logged keys to values. A value counts
            as media when it is a dict whose `"_type"` equals
            `TrackioImage.TYPE`, `TrackioVideo.TYPE` or `TrackioAudio.TYPE`.

    Returns:
        A dict mapping each media key to its `MediaData` items, in ascending
        order of the entries' `"step"` (entries without a step sort as 0).
        A key whose items all failed to build maps to an empty list.
    """
    media_by_key: dict[str, list[MediaData]] = {}
    for log in sorted(logs, key=lambda entry: entry.get("step", 0)):
        for key, value in log.items():
            if not isinstance(value, dict):
                continue
            media_type = value.get("_type")
            if media_type not in (
                TrackioImage.TYPE,
                TrackioVideo.TYPE,
                TrackioAudio.TYPE,
            ):
                continue
            # Register the key before building the item so that the key is
            # present in the result even when the item below cannot be built.
            items = media_by_key.setdefault(key, [])
            try:
                items.append(
                    MediaData(
                        file_path=utils.MEDIA_DIR / value.get("file_path"),
                        type=media_type,
                        caption=value.get("caption"),
                    )
                )
            except Exception as e:
                # Best effort: one bad entry (e.g. a missing "file_path") must
                # not prevent the remaining media from being collected.
                print(f"Media currently unavailable: {key}: {e}")
    return media_by_key
|
|
|
|
def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
    """
    Filter metrics using regex pattern.

    Args:
        metrics: List of metric names to filter
        filter_pattern: Regex pattern to match against metric names. Matching is
            case-insensitive and may occur anywhere in the name. If the pattern
            is not a valid regex, it is used as a case-insensitive substring.

    Returns:
        List of metric names that match the pattern. When the pattern is `None`,
        empty or only whitespace, `metrics` is returned unchanged.
    """
    if not filter_pattern or not filter_pattern.strip():
        return metrics

    try:
        pattern = re.compile(filter_pattern, re.IGNORECASE)
    except re.error:
        # Not a valid regex (e.g. an unbalanced bracket): fall back to a plain
        # case-insensitive substring search instead of failing.
        needle = filter_pattern.lower()
        return [metric for metric in metrics if needle in metric.lower()]
    return [metric for metric in metrics if pattern.search(metric)]
|
|
|
|
def refresh_runs_dropdown(project: str | None):
    """
    Build the run selector for a project.

    Args:
        project: Name of the selected project, or `None` when there is none.

    Returns:
        A `ColoredDropdown` listing the project's runs, each paired with a
        colour from `utils.get_color_palette()`, with the first run selected
        (or nothing selected when there are no runs).
    """
    runs: list[str] = [] if project is None else get_runs(project)

    palette = utils.get_color_palette()
    palette_size = len(palette)
    # Wrap around the palette when there are more runs than colours.
    colors = [palette[position % palette_size] for position in range(len(runs))]

    first_run = runs[0] if runs else None
    return ColoredDropdown(
        choices=runs,
        colors=colors,
        value=first_run,
        placeholder=f"Select a run ({len(runs)})",
    )
|
|
|
|
def _table_series_by_metric(df: pd.DataFrame) -> dict[str, pd.Series]:
    """
    Find the logged keys that hold tables.

    Args:
        df: Data frame built from a run's log entries (one column per key).

    Returns:
        A dict mapping each table key to its non-null logged values, in column
        order. A column qualifies when it has object dtype, is not one of
        `utils.RESERVED_KEYS`, and its first non-null value is a dict whose
        `"_type"` equals `Table.TYPE`.
    """
    table_series: dict[str, pd.Series] = {}
    for column in df.select_dtypes(include="object").columns:
        if column in utils.RESERVED_KEYS:
            continue
        values = df[column].dropna()
        if values.empty:
            continue
        first_value = values.iloc[0]
        if isinstance(first_value, dict) and first_value.get("_type") == Table.TYPE:
            table_series[column] = values
    return table_series


def _table_at_index_handler(values: pd.Series, metric_name: str):
    """
    Build the slider callback for one table metric.

    `values` and `metric_name` are bound per call. Defining the callback
    directly inside the rendering loop would make every slider read the loop
    variables of the *last* table, because closures look names up when they
    run, not when they are defined.

    Args:
        values: Non-null logged values of the metric.
        metric_name: Name of the metric, used in the table label.

    Returns:
        A function taking the 1-based slider position and returning the
        `gr.DataFrame` update for the table logged at that position.
    """

    def get_table_at_index(index: int):
        processed_data = Table.to_display_format(values.iloc[index - 1]["_value"])
        return gr.DataFrame(
            pd.DataFrame(processed_data),
            label=f"{metric_name} (index {index})",
        )

    return get_table_at_index


with gr.Blocks() as media_page:
    with gr.Sidebar() as sidebar:
        logo_urls = utils.get_logo_urls()
        logo = gr.Markdown(
            f"""
                <img src='{logo_urls["light"]}' width='80%' class='logo-light'>
                <img src='{logo_urls["dark"]}' width='80%' class='logo-dark'>
            """
        )
        project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
        runs_dropdown = ColoredDropdown(choices=[], colors=[], label="Run")

    navbar = gr.Navbar(
        value=[
            ("Metrics", ""),
            ("Media & Tables", "/media"),
            ("Runs", "/runs"),
            ("Files", "/files"),
        ],
        main_page_name=False,
    )
    timer = gr.Timer(value=1)

    @gr.render(
        triggers=[
            media_page.load,
            runs_dropdown.change,
            project_dd.change,
        ],
        inputs=[project_dd, runs_dropdown],
        show_progress="hidden",
        queue=False,
    )
    def display_media_and_tables(project: str | None, selected_run: str | None):
        """
        Render the media galleries, audio players and tables of the selected run.

        Shows a short placeholder message instead when no project/run is
        selected, when the run has no logs, or when its logs contain neither
        media nor tables.

        Args:
            project: Currently selected project, if any.
            selected_run: Currently selected run, if any.
        """
        if not project or not selected_run:
            gr.Markdown("*Select a project and run to view media and tables*")
            return

        logs = SQLiteStorage.get_logs(project, selected_run)
        if not logs:
            gr.Markdown("*No data found for this run*")
            return

        df = pd.DataFrame(logs)
        media_by_key = extract_media(logs)
        table_series = _table_series_by_metric(df)

        # A media key may map to an empty list, so check the lists, not the dict.
        has_media = any(media_by_key.values())
        has_tables = bool(table_series)

        if not has_media and not has_tables:
            gr.Markdown("*No media or tables found for this run*")
            return

        if has_media:
            for key, media_items in media_by_key.items():
                image_and_video = [
                    item
                    for item in media_items
                    if item.type in (TrackioImage.TYPE, TrackioVideo.TYPE)
                ]
                audio = [item for item in media_items if item.type == TrackioAudio.TYPE]
                if image_and_video:
                    gr.Gallery(
                        [(item.file_path, item.caption) for item in image_and_video],
                        label=key,
                        columns=6,
                        elem_classes="media-gallery",
                    )
                if audio:
                    with gr.Accordion(label=key, elem_classes="media-audio-accordion"):
                        # Lay the audio players out three per row.
                        for i in range(0, len(audio), 3):
                            with gr.Row(elem_classes="media-audio-row"):
                                for item in audio[i : i + 3]:
                                    gr.Audio(
                                        value=item.file_path,
                                        label=item.caption,
                                        elem_classes="media-audio-item",
                                    )

        if has_tables:
            with gr.Accordion(f"Tables ({len(table_series)})", open=True):
                with gr.Row(key="row"):
                    for metric_idx, (metric_name, value) in enumerate(
                        table_series.items()
                    ):
                        try:
                            with gr.Column():
                                # The slider picks which logged version of the
                                # table is shown; it starts on the latest one
                                # and is hidden when only one was logged.
                                s = gr.Slider(
                                    value=len(value),
                                    minimum=1,
                                    maximum=len(value),
                                    step=1,
                                    container=False,
                                    visible=len(value) > 1,
                                    interactive=True,
                                )
                                processed_data = Table.to_display_format(
                                    value.iloc[-1]["_value"]
                                )
                                df_table = pd.DataFrame(processed_data)
                                table = gr.DataFrame(
                                    df_table,
                                    label=f"{metric_name} (index {len(value)})",
                                    key=f"table-{metric_idx}",
                                    wrap=True,
                                    datatype="markdown",
                                    preserved_by_key=None,
                                )

                                s.input(
                                    _table_at_index_handler(value, metric_name),
                                    inputs=s,
                                    outputs=table,
                                    show_progress="hidden",
                                )
                        except Exception as e:
                            # Best effort: a table that fails to render must not
                            # take down the rest of the page.
                            gr.Warning(
                                f"Column {metric_name} failed to render as a table: {e}"
                            )

    # Refresh the project info text on every timer tick.
    gr.on(
        [timer.tick],
        fn=lambda: gr.Dropdown(info=fns.get_project_info()),
        outputs=[project_dd],
        show_progress="hidden",
        api_visibility="private",
    )

    # On page load: fill the project dropdown, then update the navbar from it.
    gr.on(
        [media_page.load],
        fn=fns.get_projects,
        outputs=project_dd,
        show_progress="hidden",
        queue=False,
        api_visibility="private",
    ).then(
        fns.update_navbar_value,
        inputs=[project_dd],
        outputs=[navbar],
        show_progress="hidden",
        api_visibility="private",
        queue=False,
    )
    # On project change: reload the run dropdown, then update the navbar.
    gr.on(
        [project_dd.change],
        fn=refresh_runs_dropdown,
        inputs=[project_dd],
        outputs=[runs_dropdown],
        show_progress="hidden",
        queue=False,
        api_visibility="private",
    ).then(
        fns.update_navbar_value,
        inputs=[project_dd],
        outputs=[navbar],
        show_progress="hidden",
        api_visibility="private",
        queue=False,
    )
|
|