Spaces:
Sleeping
Sleeping
| """The Media and Tables page for the Trackio UI.""" | |
| from dataclasses import dataclass | |
| import gradio as gr | |
| import pandas as pd | |
| import trackio.utils as utils | |
| from trackio.media import TrackioAudio, TrackioImage, TrackioVideo | |
| from trackio.sqlite_storage import SQLiteStorage | |
| from trackio.table import Table | |
| from trackio.ui import fns | |
| from trackio.ui.components.colored_dropdown import ColoredDropdown | |
| class MediaData: | |
| caption: str | None | |
| file_path: str | |
| type: str | |
| def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]: | |
| media_by_key: dict[str, list[MediaData]] = {} | |
| logs = sorted(logs, key=lambda x: x.get("step", 0)) | |
| for log in logs: | |
| for key, value in log.items(): | |
| if isinstance(value, dict): | |
| type = value.get("_type") | |
| if ( | |
| type == TrackioImage.TYPE | |
| or type == TrackioVideo.TYPE | |
| or type == TrackioAudio.TYPE | |
| ): | |
| if key not in media_by_key: | |
| media_by_key[key] = [] | |
| try: | |
| media_data = MediaData( | |
| file_path=utils.MEDIA_DIR / value.get("file_path"), | |
| type=type, | |
| caption=value.get("caption"), | |
| ) | |
| media_by_key[key].append(media_data) | |
| except Exception as e: | |
| print(f"Media currently unavailable: {key}: {e}") | |
| return media_by_key | |
| def is_table(value) -> bool: | |
| return isinstance(value, dict) and value.get("_type") == Table.TYPE | |
| def refresh_runs_dropdown(project: str | None): | |
| if project is None: | |
| runs: list[str] = [] | |
| else: | |
| runs = fns.get_runs(project) | |
| color_palette = utils.get_color_palette() | |
| colors = [color_palette[i % len(color_palette)] for i in range(len(runs))] | |
| return ColoredDropdown( | |
| choices=runs, | |
| colors=colors, | |
| value=runs[0] if runs else None, | |
| placeholder=f"Select a run ({len(runs)})", | |
| ) | |
| with gr.Blocks() as media_page: | |
| with gr.Sidebar() as sidebar: | |
| logo = fns.create_logo() | |
| project_dd = fns.create_project_dropdown() | |
| runs_dropdown = ColoredDropdown(choices=[], colors=[], label="Run") | |
| navbar = fns.create_navbar() | |
| timer = gr.Timer(value=1) | |
| fns.setup_alert_notifications(timer, project_dd) | |
| def display_media_and_tables(project: str | None, selected_run: str | None): | |
| if not project or not selected_run: | |
| gr.Markdown("*Select a project and run to view media and tables*") | |
| return | |
| logs = SQLiteStorage.get_logs(project, selected_run) | |
| if not logs: | |
| gr.Markdown("*No data found for this run*") | |
| return | |
| df = pd.DataFrame(logs) | |
| media_by_key = extract_media(logs) | |
| has_media = media_by_key and any(media_by_key.values()) | |
| has_tables = False | |
| table_cols = df.select_dtypes(include="object").columns | |
| table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS] | |
| table_cols = [c for c in table_cols if any(is_table(x) for x in df[c])] | |
| has_tables = len(table_cols) > 0 | |
| if not has_media and not has_tables: | |
| gr.Markdown( | |
| """ | |
| ## No Media or Tables Available | |
| Media and tables will appear here once logged. To log media or tables: | |
| ```python | |
| import trackio | |
| import numpy as np | |
| run = trackio.init(project="my-project") | |
| # Log an image (from numpy array, PIL Image, or file path) | |
| image = trackio.Image(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8), caption="Sample") | |
| trackio.log({"my_image": image}) | |
| # Log a table (from a pandas DataFrame) | |
| import pandas as pd | |
| df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) | |
| trackio.log({"my_table": trackio.Table(dataframe=df)}) | |
| ``` | |
| """ | |
| ) | |
| return | |
| if has_media: | |
| for key, media_items in media_by_key.items(): | |
| image_and_video = [ | |
| item | |
| for item in media_items | |
| if item.type in [TrackioImage.TYPE, TrackioVideo.TYPE] | |
| ] | |
| audio = [item for item in media_items if item.type == TrackioAudio.TYPE] | |
| if image_and_video: | |
| gr.Gallery( | |
| [(item.file_path, item.caption) for item in image_and_video], | |
| label=key, | |
| columns=6, | |
| elem_classes=("media-gallery"), | |
| ) | |
| if audio: | |
| with gr.Accordion( | |
| label=key, elem_classes=("media-audio-accordion") | |
| ): | |
| for i in range(0, len(audio), 3): | |
| with gr.Row(elem_classes=("media-audio-row")): | |
| for item in audio[i : i + 3]: | |
| gr.Audio( | |
| value=item.file_path, | |
| label=item.caption, | |
| elem_classes=("media-audio-item"), | |
| ) | |
| if has_tables: | |
| with gr.Accordion(f"Tables ({len(table_cols)})", open=True): | |
| with gr.Row(key="row"): | |
| for metric_idx, metric_name in enumerate(table_cols): | |
| metric_df = df[df[metric_name].apply(is_table)] | |
| if not metric_df.empty: | |
| value = metric_df[metric_name] | |
| first_value = value.iloc[0] | |
| if ( | |
| isinstance(first_value, dict) | |
| and "_type" in first_value | |
| and first_value["_type"] == Table.TYPE | |
| ): | |
| try: | |
| with gr.Column(): | |
| s = gr.Slider( | |
| value=len(value), | |
| minimum=1, | |
| maximum=len(value), | |
| step=1, | |
| container=False, | |
| visible=len(value) > 1, | |
| interactive=True, | |
| ) | |
| processed_data = Table.to_display_format( | |
| value.iloc[-1]["_value"] | |
| ) | |
| df_table = pd.DataFrame(processed_data) | |
| table = gr.DataFrame( | |
| df_table, | |
| label=f"{metric_name} (index {len(value)})", | |
| key=f"table-{metric_idx}", | |
| wrap=True, | |
| datatype="markdown", | |
| preserved_by_key=None, | |
| ) | |
| def make_table_renderer( | |
| capture_df, capture_name | |
| ): | |
| def get_table_at_index(index: int): | |
| value = capture_df[capture_name] | |
| processed_data = ( | |
| Table.to_display_format( | |
| value.iloc[index - 1]["_value"] | |
| ) | |
| ) | |
| df_ = pd.DataFrame(processed_data) | |
| return gr.DataFrame( | |
| df_, | |
| label=f"{capture_name} (index {index})", | |
| ) | |
| return get_table_at_index | |
| get_table_at_index = make_table_renderer( | |
| metric_df, metric_name | |
| ) | |
| s.input( | |
| get_table_at_index, | |
| inputs=s, | |
| outputs=table, | |
| show_progress="hidden", | |
| ) | |
| except Exception as e: | |
| gr.Warning( | |
| f"Column {metric_name} failed to render as a table: {e}" | |
| ) | |
| gr.on( | |
| [timer.tick], | |
| fn=lambda: gr.Dropdown(info=fns.get_project_info()), | |
| outputs=[project_dd], | |
| show_progress="hidden", | |
| api_visibility="private", | |
| ) | |
| gr.on( | |
| [media_page.load], | |
| fn=fns.get_projects, | |
| outputs=project_dd, | |
| show_progress="hidden", | |
| queue=False, | |
| api_visibility="private", | |
| ).then( | |
| fns.update_navbar_value, | |
| inputs=[project_dd], | |
| outputs=[navbar], | |
| show_progress="hidden", | |
| api_visibility="private", | |
| queue=False, | |
| ) | |
| gr.on( | |
| [project_dd.change], | |
| fn=refresh_runs_dropdown, | |
| inputs=[project_dd], | |
| outputs=[runs_dropdown], | |
| show_progress="hidden", | |
| queue=False, | |
| api_visibility="private", | |
| ).then( | |
| fns.update_navbar_value, | |
| inputs=[project_dd], | |
| outputs=[navbar], | |
| show_progress="hidden", | |
| api_visibility="private", | |
| queue=False, | |
| ) | |