File size: 8,288 Bytes
7a943a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import html
import os
from typing import Any, Literal
from urllib.parse import quote

from trackio.media.media import TrackioMedia
from trackio.utils import MEDIA_DIR


class Table:
    """
    Initializes a Table object.

    Tables can be used to log tabular data including images, numbers, and text.

    Args:
        columns (`list[str]`, *optional*):
            Names of the columns in the table. Optional if `data` is provided. Not
            expected if `dataframe` is provided. Currently ignored.
        data (`list[list[Any]]`, *optional*):
            2D row-oriented array of values. Each value can be a number, a string
            (treated as Markdown and truncated if too long), or a `Trackio.Image` or
            list of `Trackio.Image` objects.
        dataframe (`pandas.DataFrame`, *optional*):
            DataFrame used to create the table. When set, `data` and `columns`
            arguments are ignored.
        rows (`list[list[Any]]`, *optional*):
            Currently ignored.
        optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
            Currently ignored.
        allow_mixed_types (`bool`, *optional*, defaults to `False`):
            Currently ignored.
        log_mode: (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
            Currently ignored.
    """

    TYPE = "trackio.table"

    def __init__(
        self,
        columns: list[str] | None = None,
        data: list[list[Any]] | None = None,
        dataframe: Any | None = None,
        rows: list[list[Any]] | None = None,
        optional: bool | list[bool] = True,
        allow_mixed_types: bool = False,
        log_mode: Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] | None = "IMMUTABLE",
    ):
        # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
        # for now (like `rows`) they are included for API compat but don't do anything.
        self.data = self._normalize_rows(
            columns=columns, data=data, dataframe=dataframe
        )

    @staticmethod
    def _normalize_rows(
        columns: list[str] | None,
        data: list[list[Any]] | None,
        dataframe: Any | None,
    ) -> list[dict[str, Any]]:
        if dataframe is not None:
            try:
                records = dataframe.to_dict(orient="records")
            except Exception as e:
                raise TypeError(
                    "The `dataframe` argument must support `to_dict(orient='records')`."
                ) from e
            return [dict(row) for row in records]

        if data is None:
            return []

        if data and isinstance(data[0], dict):
            return [dict(row) for row in data]

        normalized_rows: list[dict[str, Any]] = []
        for row in data:
            row_dict: dict[str, Any] = {}
            if columns is None:
                for idx, value in enumerate(row):
                    row_dict[idx] = value
            else:
                for idx, column in enumerate(columns):
                    row_dict[column] = row[idx] if idx < len(row) else None
                for idx in range(len(columns), len(row)):
                    row_dict[idx] = row[idx]
            normalized_rows.append(row_dict)
        return normalized_rows

    def _has_media_objects(self, rows: list[dict[str, Any]]) -> bool:
        """Check if rows contain any TrackioMedia objects or lists of TrackioMedia objects."""
        for row in rows:
            for value in row.values():
                if isinstance(value, TrackioMedia):
                    return True
                if (
                    isinstance(value, list)
                    and len(value) > 0
                    and isinstance(value[0], TrackioMedia)
                ):
                    return True
        return False

    def _process_data(self, project: str, run: str, step: int = 0):
        """Convert rows to dict format, processing any TrackioMedia objects if present."""
        if not self._has_media_objects(self.data):
            return [dict(row) for row in self.data]

        processed_rows = [dict(row) for row in self.data]
        for row in processed_rows:
            for key, value in list(row.items()):
                if isinstance(value, TrackioMedia):
                    value._save(project, run, step)
                    row[key] = value._to_dict()
                if (
                    isinstance(value, list)
                    and len(value) > 0
                    and isinstance(value[0], TrackioMedia)
                ):
                    [v._save(project, run, step) for v in value]
                    row[key] = [v._to_dict() for v in value]

        return processed_rows

    @staticmethod
    def to_display_format(table_data: list[dict]) -> list[dict]:
        """
        Converts stored table data to display format for UI rendering.

        Note:
            This does not use the `self.data` attribute, but instead uses the
            `table_data` parameter, which is what the UI receives.

        Args:
            table_data (`list[dict]`):
                List of dictionaries representing table rows (from stored `_value`).

        Returns:
            `list[dict]`: Table data with images converted to markdown syntax and long
            text truncated.
        """
        truncate_length = int(os.getenv("TRACKIO_TABLE_TRUNCATE_LENGTH", "250"))

        def convert_image_to_markdown(image_data: dict) -> str:
            relative_path = image_data.get("file_path", "")
            caption = image_data.get("caption", "")
            absolute_path = MEDIA_DIR / relative_path
            return (
                f'<img src="/file?path={quote(str(absolute_path))}" alt="{caption}" />'
            )

        processed_data = []
        for row in table_data:
            processed_row = {}
            for key, value in row.items():
                if isinstance(value, dict) and value.get("_type") == "trackio.image":
                    processed_row[key] = convert_image_to_markdown(value)
                elif (
                    isinstance(value, list)
                    and len(value) > 0
                    and isinstance(value[0], dict)
                    and value[0].get("_type") == "trackio.image"
                ):
                    # This assumes that if the first item is an image, all items are images. Ok for now since we don't support mixed types in a single cell.
                    processed_row[key] = (
                        '<div style="display: flex; gap: 10px;">'
                        + "".join([convert_image_to_markdown(item) for item in value])
                        + "</div>"
                    )
                elif isinstance(value, str) and len(value) > truncate_length:
                    truncated = value[:truncate_length]
                    full_text = value.replace("<", "&lt;").replace(">", "&gt;")
                    processed_row[key] = (
                        f'<details style="display: inline;">'
                        f'<summary style="display: inline; cursor: pointer;">{truncated}…<span><em>(truncated, click to expand)</em></span></summary>'
                        f'<div style="margin-top: 10px; padding: 10px; background: #f5f5f5; border-radius: 4px; max-height: 400px; overflow: auto;">'
                        f'<pre style="white-space: pre-wrap; word-wrap: break-word; margin: 0;">{full_text}</pre>'
                        f"</div>"
                        f"</details>"
                    )
                else:
                    processed_row[key] = value
            processed_data.append(processed_row)
        return processed_data

    def _to_dict(self, project: str, run: str, step: int = 0):
        """
        Converts the table to a dictionary representation.

        Args:
            project (`str`):
                Project name for saving media files.
            run (`str`):
                Run name for saving media files.
            step (`int`, *optional*, defaults to `0`):
                Step number for saving media files.
        """
        data = self._process_data(project, run, step)
        return {
            "_type": self.TYPE,
            "_value": data,
        }