kashif HF Staff commited on
Commit
d3cd6c8
·
verified ·
1 Parent(s): fd03ef2

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. trackio/CHANGELOG.md +104 -0
  3. trackio/__init__.py +543 -0
  4. trackio/__pycache__/__init__.cpython-312.pyc +0 -0
  5. trackio/__pycache__/commit_scheduler.cpython-312.pyc +0 -0
  6. trackio/__pycache__/context_vars.cpython-312.pyc +0 -0
  7. trackio/__pycache__/deploy.cpython-312.pyc +0 -0
  8. trackio/__pycache__/dummy_commit_scheduler.cpython-312.pyc +0 -0
  9. trackio/__pycache__/gpu.cpython-312.pyc +0 -0
  10. trackio/__pycache__/histogram.cpython-312.pyc +0 -0
  11. trackio/__pycache__/imports.cpython-312.pyc +0 -0
  12. trackio/__pycache__/run.cpython-312.pyc +0 -0
  13. trackio/__pycache__/sqlite_storage.cpython-312.pyc +0 -0
  14. trackio/__pycache__/table.cpython-312.pyc +0 -0
  15. trackio/__pycache__/typehints.cpython-312.pyc +0 -0
  16. trackio/__pycache__/utils.cpython-312.pyc +0 -0
  17. trackio/assets/badge.png +0 -0
  18. trackio/assets/trackio_logo_dark.png +0 -0
  19. trackio/assets/trackio_logo_light.png +0 -0
  20. trackio/assets/trackio_logo_old.png +3 -0
  21. trackio/assets/trackio_logo_type_dark.png +0 -0
  22. trackio/assets/trackio_logo_type_dark_transparent.png +0 -0
  23. trackio/assets/trackio_logo_type_light.png +0 -0
  24. trackio/assets/trackio_logo_type_light_transparent.png +0 -0
  25. trackio/cli.py +93 -0
  26. trackio/commit_scheduler.py +391 -0
  27. trackio/context_vars.py +21 -0
  28. trackio/deploy.py +363 -0
  29. trackio/dummy_commit_scheduler.py +12 -0
  30. trackio/gpu.py +368 -0
  31. trackio/histogram.py +71 -0
  32. trackio/imports.py +304 -0
  33. trackio/media/__init__.py +27 -0
  34. trackio/media/__pycache__/__init__.cpython-312.pyc +0 -0
  35. trackio/media/__pycache__/audio.cpython-312.pyc +0 -0
  36. trackio/media/__pycache__/image.cpython-312.pyc +0 -0
  37. trackio/media/__pycache__/media.cpython-312.pyc +0 -0
  38. trackio/media/__pycache__/utils.cpython-312.pyc +0 -0
  39. trackio/media/__pycache__/video.cpython-312.pyc +0 -0
  40. trackio/media/audio.py +167 -0
  41. trackio/media/image.py +84 -0
  42. trackio/media/media.py +79 -0
  43. trackio/media/utils.py +60 -0
  44. trackio/media/video.py +246 -0
  45. trackio/package.json +6 -0
  46. trackio/py.typed +0 -0
  47. trackio/run.py +283 -0
  48. trackio/sqlite_storage.py +874 -0
  49. trackio/table.py +171 -0
  50. trackio/typehints.py +26 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ trackio/assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
trackio/CHANGELOG.md ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # trackio
2
+
3
+ ## 0.13.1
4
+
5
+ ### Features
6
+
7
+ - [#369](https://github.com/gradio-app/trackio/pull/369) [`767e9fe`](https://github.com/gradio-app/trackio/commit/767e9fe095d7c6ed102016caf927c1517fb8618c) - tiny pr removing unnecessary code. Thanks @abidlabs!
8
+
9
+ ## 0.13.0
10
+
11
+ ### Features
12
+
13
+ - [#358](https://github.com/gradio-app/trackio/pull/358) [`073715d`](https://github.com/gradio-app/trackio/commit/073715d1caf8282f68890117f09c3ac301205312) - Improvements to `trackio.sync()`. Thanks @abidlabs!
14
+
15
+ ## 0.12.0
16
+
17
+ ### Features
18
+
19
+ - [#357](https://github.com/gradio-app/trackio/pull/357) [`02ba815`](https://github.com/gradio-app/trackio/commit/02ba815358060f1966052de051a5bdb09702920e) - Redesign media and tables to show up on separate page. Thanks @abidlabs!
20
+ - [#359](https://github.com/gradio-app/trackio/pull/359) [`08fe9c9`](https://github.com/gradio-app/trackio/commit/08fe9c9ddd7fe99ee811555fdfb62df9ab88e939) - docs: Improve docstrings. Thanks @qgallouedec!
21
+
22
+ ## 0.11.0
23
+
24
+ ### Features
25
+
26
+ - [#355](https://github.com/gradio-app/trackio/pull/355) [`ea51f49`](https://github.com/gradio-app/trackio/commit/ea51f4954922f21be76ef828700420fe9a912c4b) - Color code run checkboxes and match with plot lines. Thanks @abidlabs!
27
+ - [#353](https://github.com/gradio-app/trackio/pull/353) [`8abe691`](https://github.com/gradio-app/trackio/commit/8abe6919aeefe21fc7a23af814883efbb037c21f) - Remove show_api from demo.launch. Thanks @sergiopaniego!
28
+ - [#351](https://github.com/gradio-app/trackio/pull/351) [`8a8957e`](https://github.com/gradio-app/trackio/commit/8a8957e530dd7908d1fef7f2df030303f808101f) - Add `trackio.save()`. Thanks @abidlabs!
29
+
30
+ ## 0.10.0
31
+
32
+ ### Features
33
+
34
+ - [#305](https://github.com/gradio-app/trackio/pull/305) [`e64883a`](https://github.com/gradio-app/trackio/commit/e64883a51f7b8b93f7d48b8afe55acdb62238b71) - bump to gradio 6.0, make `trackio` compatible, and fix related issues. Thanks @abidlabs!
35
+
36
+ ## 0.9.1
37
+
38
+ ### Features
39
+
40
+ - [#344](https://github.com/gradio-app/trackio/pull/344) [`7e01024`](https://github.com/gradio-app/trackio/commit/7e010241d9a34794e0ce0dc19c1a6f0cf94ba856) - Avoid redundant calls to /whoami-v2. Thanks @Wauplin!
41
+
42
+ ## 0.9.0
43
+
44
+ ### Features
45
+
46
+ - [#343](https://github.com/gradio-app/trackio/pull/343) [`51bea30`](https://github.com/gradio-app/trackio/commit/51bea30f2877adff8e6497466d3a799400a0a049) - Sync offline projects to Hugging Face spaces. Thanks @candemircan!
47
+ - [#341](https://github.com/gradio-app/trackio/pull/341) [`4fd841f`](https://github.com/gradio-app/trackio/commit/4fd841fa190e15071b02f6fba7683ef4f393a654) - Adds a basic UI test to `trackio`. Thanks @abidlabs!
48
+ - [#339](https://github.com/gradio-app/trackio/pull/339) [`011d91b`](https://github.com/gradio-app/trackio/commit/011d91bb6ae266516fd250a349285670a8049d05) - Allow customzing the trackio color palette. Thanks @abidlabs!
49
+
50
+ ## 0.8.1
51
+
52
+ ### Features
53
+
54
+ - [#336](https://github.com/gradio-app/trackio/pull/336) [`5f9f51d`](https://github.com/gradio-app/trackio/commit/5f9f51dac8677f240d7c42c3e3b2660a22aee138) - Support a list of `Trackio.Image` in a `trackio.Table` cell. Thanks @abidlabs!
55
+
56
+ ## 0.8.0
57
+
58
+ ### Features
59
+
60
+ - [#331](https://github.com/gradio-app/trackio/pull/331) [`2c02d0f`](https://github.com/gradio-app/trackio/commit/2c02d0fd0a5824160528782402bb0dd4083396d5) - Truncate table string values that are greater than 250 characters (configuirable via env variable). Thanks @abidlabs!
61
+ - [#324](https://github.com/gradio-app/trackio/pull/324) [`50b2122`](https://github.com/gradio-app/trackio/commit/50b2122e7965ac82a72e6cb3b7d048bc10a2a6b1) - Add log y-axis functionality to UI. Thanks @abidlabs!
62
+ - [#326](https://github.com/gradio-app/trackio/pull/326) [`61dc1f4`](https://github.com/gradio-app/trackio/commit/61dc1f40af2f545f8e70395ddf0dbb8aee6b60d5) - Fix: improve table rendering for metrics in Trackio Dashboard. Thanks @vigneshwaran!
63
+ - [#328](https://github.com/gradio-app/trackio/pull/328) [`6857cbb`](https://github.com/gradio-app/trackio/commit/6857cbbe557a59a4642f210ec42566d108294e63) - Support trackio.Table with trackio.Image columns. Thanks @abidlabs!
64
+
65
+ ## 0.7.0
66
+
67
+ ### Features
68
+
69
+ - [#277](https://github.com/gradio-app/trackio/pull/277) [`db35601`](https://github.com/gradio-app/trackio/commit/db35601b9c023423c4654c9909b8ab73e58737de) - fix: make grouped runs view reflect live updates. Thanks @Saba9!
70
+ - [#320](https://github.com/gradio-app/trackio/pull/320) [`24ae739`](https://github.com/gradio-app/trackio/commit/24ae73969b09fb3126acd2f91647cdfbf8cf72a1) - Add additional query parms for xmin, xmax, and smoothing. Thanks @abidlabs!
71
+ - [#270](https://github.com/gradio-app/trackio/pull/270) [`cd1dfc3`](https://github.com/gradio-app/trackio/commit/cd1dfc3dc641b4499ac6d4a1b066fa8e2b52c57b) - feature: add support for logging audio. Thanks @Saba9!
72
+
73
+ ## 0.6.0
74
+
75
+ ### Features
76
+
77
+ - [#309](https://github.com/gradio-app/trackio/pull/309) [`1df2353`](https://github.com/gradio-app/trackio/commit/1df23534d6c01938c8db9c0f584ffa23e8d6021d) - Add histogram support with wandb-compatible API. Thanks @abidlabs!
78
+ - [#315](https://github.com/gradio-app/trackio/pull/315) [`76ba060`](https://github.com/gradio-app/trackio/commit/76ba06055dc43ca8f03b79f3e72d761949bd19a8) - Add guards to avoid silent fails. Thanks @Xmaster6y!
79
+ - [#313](https://github.com/gradio-app/trackio/pull/313) [`a606b3e`](https://github.com/gradio-app/trackio/commit/a606b3e1c5edf3d4cf9f31bd50605226a5a1c5d0) - No longer prevent certain keys from being used. Instead, dunderify them to prevent collisions with internal usage. Thanks @abidlabs!
80
+ - [#317](https://github.com/gradio-app/trackio/pull/317) [`27370a5`](https://github.com/gradio-app/trackio/commit/27370a595d0dbdf7eebbe7159d2ba778f039da44) - quick fixes for trackio.histogram. Thanks @abidlabs!
81
+ - [#312](https://github.com/gradio-app/trackio/pull/312) [`aa0f3bf`](https://github.com/gradio-app/trackio/commit/aa0f3bf372e7a0dd592a38af699c998363830eeb) - Fix video logging by adding TRACKIO_DIR to allowed_paths. Thanks @abidlabs!
82
+
83
+ ## 0.5.3
84
+
85
+ ### Features
86
+
87
+ - [#300](https://github.com/gradio-app/trackio/pull/300) [`5e4cacf`](https://github.com/gradio-app/trackio/commit/5e4cacf2e7ce527b4ce60de3a5bc05d2c02c77fb) - Adds more environment variables to allow customization of Trackio dashboard. Thanks @abidlabs!
88
+
89
+ ## 0.5.2
90
+
91
+ ### Features
92
+
93
+ - [#293](https://github.com/gradio-app/trackio/pull/293) [`64afc28`](https://github.com/gradio-app/trackio/commit/64afc28d3ea1dfd821472dc6bf0b8ed35a9b74be) - Ensures that the TRACKIO_DIR environment variable is respected. Thanks @abidlabs!
94
+ - [#287](https://github.com/gradio-app/trackio/pull/287) [`cd3e929`](https://github.com/gradio-app/trackio/commit/cd3e9294320949e6b8b829239069a43d5d7ff4c1) - fix(sqlite): unify .sqlite extension, allow export when DBs exist, clean WAL sidecars on import. Thanks @vaibhav-research!
95
+
96
+ ### Fixes
97
+
98
+ - [#291](https://github.com/gradio-app/trackio/pull/291) [`3b5adc3`](https://github.com/gradio-app/trackio/commit/3b5adc3d1f452dbab7a714d235f4974782f93730) - Fix the wheel build. Thanks @pngwn!
99
+
100
+ ## 0.5.1
101
+
102
+ ### Fixes
103
+
104
+ - [#278](https://github.com/gradio-app/trackio/pull/278) [`314c054`](https://github.com/gradio-app/trackio/commit/314c05438007ddfea3383e06fd19143e27468e2d) - Fix row orientation of metrics plots. Thanks @abidlabs!
trackio/__init__.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import logging
4
+ import os
5
+ import warnings
6
+ import webbrowser
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import huggingface_hub
11
+ from gradio.themes import ThemeClass
12
+ from gradio.utils import TupleNoPrint
13
+ from gradio_client import Client, handle_file
14
+ from huggingface_hub import SpaceStorage
15
+ from huggingface_hub.errors import LocalTokenNotFoundError
16
+
17
+ from trackio import context_vars, deploy, utils
18
+ from trackio.deploy import sync
19
+ from trackio.gpu import gpu_available, log_gpu
20
+ from trackio.histogram import Histogram
21
+ from trackio.imports import import_csv, import_tf_events
22
+ from trackio.media import TrackioAudio, TrackioImage, TrackioVideo
23
+ from trackio.run import Run
24
+ from trackio.sqlite_storage import SQLiteStorage
25
+ from trackio.table import Table
26
+ from trackio.typehints import UploadEntry
27
+ from trackio.ui.main import CSS, HEAD, demo
28
+ from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
29
+
30
+ logging.getLogger("httpx").setLevel(logging.WARNING)
31
+
32
+ warnings.filterwarnings(
33
+ "ignore",
34
+ message="Empty session being created. Install gradio\\[oauth\\]",
35
+ category=UserWarning,
36
+ module="gradio.helpers",
37
+ )
38
+
39
+ __version__ = json.loads(Path(__file__).parent.joinpath("package.json").read_text())[
40
+ "version"
41
+ ]
42
+
43
+ __all__ = [
44
+ "init",
45
+ "log",
46
+ "log_gpu",
47
+ "finish",
48
+ "show",
49
+ "sync",
50
+ "delete_project",
51
+ "import_csv",
52
+ "import_tf_events",
53
+ "save",
54
+ "Image",
55
+ "Video",
56
+ "Audio",
57
+ "Table",
58
+ "Histogram",
59
+ ]
60
+
61
+ Image = TrackioImage
62
+ Video = TrackioVideo
63
+ Audio = TrackioAudio
64
+
65
+
66
+ config = {}
67
+
68
+
69
def init(
    project: str,
    name: str | None = None,
    group: str | None = None,
    space_id: str | None = None,
    space_storage: SpaceStorage | None = None,
    dataset_id: str | None = None,
    config: dict | None = None,
    resume: str = "never",
    settings: Any = None,
    private: bool | None = None,
    embed: bool = True,
    auto_log_gpu: bool | None = None,
    gpu_log_interval: float = 10.0,
) -> Run:
    """
    Creates a new Trackio project and returns a [`Run`] object.

    Args:
        project (`str`):
            The name of the project (can be an existing project to continue tracking or
            a new project to start tracking from scratch).
        name (`str`, *optional*):
            The name of the run (if not provided, a default name will be generated).
        group (`str`, *optional*):
            The name of the group which this run belongs to in order to help organize
            related runs together. You can toggle the entire group's visibility in the
            dashboard.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of
            a local directory. Should be a complete Space name like
            `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
            case the Space will be created in the currently-logged-in Hugging Face
            user's namespace. If the Space does not exist, it will be created. If the
            Space already exists, the project will be logged to it.
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier.
        dataset_id (`str`, *optional*):
            If a `space_id` is provided, a persistent Hugging Face Dataset will be
            created and the metrics will be synced to it every 5 minutes. Specify a
            Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
            or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
            or `None` (uses the same name as the Space but with the `"_dataset"`
            suffix). If the Dataset does not exist, it will be created. If the Dataset
            already exists, the project will be appended to it.
        config (`dict`, *optional*):
            A dictionary of configuration options. Provided for compatibility with
            `wandb.init()`.
        resume (`str`, *optional*, defaults to `"never"`):
            Controls how to handle resuming a run. Can be one of:

            - `"must"`: Must resume the run with the given name, raises error if run
              doesn't exist
            - `"allow"`: Resume the run if it exists, otherwise create a new run
            - `"never"`: Never resume a run, always create a new one
        settings (`Any`, *optional*):
            Not used. Provided for compatibility with `wandb.init()`.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        embed (`bool`, *optional*, defaults to `True`):
            If running inside a jupyter/Colab notebook, whether the dashboard should
            automatically be embedded in the cell when trackio.init() is called.
        auto_log_gpu (`bool` or `None`, *optional*, defaults to `None`):
            Controls automatic GPU metrics logging. If `None` (default), GPU logging
            is automatically enabled when `nvidia-ml-py` is installed and an NVIDIA
            GPU is detected. Set to `True` to force enable or `False` to disable.
        gpu_log_interval (`float`, *optional*, defaults to `10.0`):
            The interval in seconds between automatic GPU metric logs.
            Only used when `auto_log_gpu=True`.

    Returns:
        `Run`: A [`Run`] object that can be used to log metrics and finish the run.
    """
    if settings is not None:
        warnings.warn(
            "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
        )

    if space_id is None and dataset_id is not None:
        raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
    try:
        space_id, dataset_id = utils.preprocess_space_and_dataset_ids(
            space_id, dataset_id
        )
    except LocalTokenNotFoundError as e:
        # Re-raise with a clearer, trackio-specific message; chain the original.
        raise LocalTokenNotFoundError(
            f"You must be logged in to Hugging Face locally when `space_id` is provided to deploy to a Space. {e}"
        ) from e
    url = context_vars.current_server.get()
    share_url = context_vars.current_share_server.get()

    if url is None:
        if space_id is None:
            # No server running yet: launch the local dashboard app.
            _, url, share_url = demo.launch(
                css=CSS,
                head=HEAD,
                footer_links=["gradio", "settings"],
                inline=False,
                quiet=True,
                prevent_thread_lock=True,
                show_error=True,
                favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
                allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
            )
            context_vars.current_space_id.set(None)
        else:
            # Logging to a Space: the Space itself acts as the "server" URL.
            url = space_id
            share_url = None
            context_vars.current_space_id.set(space_id)

    context_vars.current_server.set(url)
    context_vars.current_share_server.set(share_url)
    # Only print the initialization banner when switching to a new project.
    if (
        context_vars.current_project.get() is None
        or context_vars.current_project.get() != project
    ):
        print(f"* Trackio project initialized: {project}")

        if dataset_id is not None:
            os.environ["TRACKIO_DATASET_ID"] = dataset_id
            print(
                f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
            )
        if space_id is None:
            print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
            if utils.is_in_notebook() and embed:
                base_url = share_url + "/" if share_url else url
                full_url = utils.get_full_url(
                    base_url, project=project, write_token=demo.write_token, footer=True
                )
                utils.embed_url_in_notebook(full_url)
            else:
                utils.print_dashboard_instructions(project)
        else:
            deploy.create_space_if_not_exists(
                space_id, space_storage, dataset_id, private
            )
            user_name, space_name = space_id.split("/")
            space_url = deploy.SPACE_HOST_URL.format(
                user_name=user_name, space_name=space_name
            )
            print(f"* View dashboard by going to: {space_url}")
            if utils.is_in_notebook() and embed:
                utils.embed_url_in_notebook(space_url)
    context_vars.current_project.set(project)

    # Only create a local client when logging locally; Space logging goes
    # through the Run's own upload path.
    client = None
    if not space_id:
        client = Client(url, verbose=False)

    if resume == "must":
        if name is None:
            raise ValueError("Must provide a run name when resume='must'")
        if name not in SQLiteStorage.get_runs(project):
            raise ValueError(f"Run '{name}' does not exist in project '{project}'")
        resumed = True
    elif resume == "allow":
        resumed = name is not None and name in SQLiteStorage.get_runs(project)
    elif resume == "never":
        if name is not None and name in SQLiteStorage.get_runs(project):
            # Fix: message previously read "Generating a new name and instead."
            warnings.warn(
                f"* Warning: resume='never' but a run '{name}' already exists in "
                f"project '{project}'. Generating a new name instead. If you want "
                "to resume this run, call init() with resume='must' or resume='allow'."
            )
            name = None
        resumed = False
    else:
        raise ValueError("resume must be one of: 'must', 'allow', or 'never'")

    if auto_log_gpu is None:
        # Auto-detect: enable GPU logging only when a compatible GPU is found.
        auto_log_gpu = gpu_available()
        if auto_log_gpu:
            print("* GPU detected, enabling automatic GPU metrics logging")

    run = Run(
        url=url,
        project=project,
        client=client,
        name=name,
        group=group,
        config=config,
        space_id=space_id,
        auto_log_gpu=auto_log_gpu,
        gpu_log_interval=gpu_log_interval,
    )

    if resumed:
        print(f"* Resumed existing run: {run.name}")
    else:
        print(f"* Created new run: {run.name}")

    context_vars.current_run.set(run)
    # Expose the run's config as the module-level `trackio.config`, mirroring wandb.
    globals()["config"] = run.config
    return run
266
+
267
+
268
def log(metrics: dict, step: int | None = None) -> None:
    """
    Logs metrics to the current run.

    Args:
        metrics (`dict`):
            A dictionary of metrics to log.
        step (`int`, *optional*):
            The step number. If not provided, the step will be incremented
            automatically.

    Raises:
        `RuntimeError`: If no run is active (i.e. `trackio.init()` was not called).
    """
    active_run = context_vars.current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before trackio.log().")
    active_run.log(metrics=metrics, step=step)
286
+
287
+
288
def finish():
    """
    Finishes the current run.

    Raises:
        `RuntimeError`: If no run is active (i.e. `trackio.init()` was not called).
    """
    active_run = context_vars.current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before trackio.finish().")
    active_run.finish()
296
+
297
+
298
def delete_project(project: str, force: bool = False) -> bool:
    """
    Deletes a project by removing its local SQLite database.

    Args:
        project (`str`):
            The name of the project to delete.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, deletes the project without prompting for confirmation.
            If `False`, prompts the user to confirm before deleting.

    Returns:
        `bool`: `True` if the project was deleted, `False` otherwise.
    """
    database_file = SQLiteStorage.get_project_db_path(project)

    if not database_file.exists():
        print(f"* Project '{project}' does not exist.")
        return False

    if not force:
        answer = input(
            f"Are you sure you want to delete project '{project}'? "
            f"This will permanently delete all runs and metrics. (y/N): "
        )
        if answer.lower() not in ("y", "yes"):
            print("* Deletion cancelled.")
            return False

    try:
        database_file.unlink()
        # Also remove SQLite WAL/SHM sidecar files, if any were left behind.
        for extra in ("-wal", "-shm"):
            companion = database_file.with_name(database_file.name + extra)
            if companion.exists():
                companion.unlink()
    except Exception as e:
        print(f"* Error deleting project '{project}': {e}")
        return False

    print(f"* Project '{project}' has been deleted.")
    return True
340
+
341
+
342
def save(
    glob_str: str | Path,
    project: str | None = None,
) -> str:
    """
    Saves files to a project (not linked to a specific run). If Trackio is running
    locally, the file(s) will be moved to the project's files directory. If Trackio is
    running in a Space, the file(s) will be uploaded to the Space's files directory.

    Args:
        glob_str (`str` or `Path`):
            The file path or glob pattern to save. Can be a single file or a pattern
            matching multiple files (e.g., `"*.py"`, `"models/**/*.pth"`).
        project (`str`, *optional*):
            The name of the project to save files to. If not provided, uses the current
            project from `trackio.init()`. If no project is initialized, raises an
            error.

    Returns:
        `str`: The path where the file(s) were saved (project's files directory).

    Raises:
        `RuntimeError`: If no project is specified/initialized, or if files must be
            uploaded via the server but no server is running.
        `ValueError`: If no files match `glob_str`.

    Example:
        ```python
        import trackio

        trackio.init(project="my-project")
        trackio.save("config.yaml")
        trackio.save("models/*.pth")
        ```
    """
    if project is None:
        project = context_vars.current_project.get()
        if project is None:
            raise RuntimeError(
                "No project specified. Either call trackio.init() first or provide a "
                "project parameter to trackio.save()."
            )

    glob_str = Path(glob_str)
    # Paths are stored relative to the current working directory when possible.
    base_path = Path.cwd().resolve()

    # Collect the files to save: either the single file given, or everything
    # matching the (possibly recursive) glob pattern.
    matched_files = []
    if glob_str.is_file():
        matched_files = [glob_str.resolve()]
    else:
        pattern = str(glob_str)
        if not glob_str.is_absolute():
            pattern = str((Path.cwd() / glob_str).resolve())
        matched_files = [
            Path(f).resolve()
            for f in glob.glob(pattern, recursive=True)
            if Path(f).is_file()
        ]

    if not matched_files:
        raise ValueError(f"No files found matching pattern: {glob_str}")

    url = context_vars.current_server.get()
    current_run = context_vars.current_run.get()

    upload_entries = []

    for file_path in matched_files:
        # Preserve the directory structure relative to cwd; files outside cwd
        # fall back to just their basename.
        try:
            relative_to_base = file_path.relative_to(base_path)
        except ValueError:
            relative_to_base = Path(file_path.name)

        if current_run is not None:
            # If a run is active, use its queue to upload the file to the project's files directory
            # as it's more efficient than uploading files one by one. But we should not use the run name
            # as the files should be stored in the project's files directory, not the run's, hence
            # the use_run_name flag is set to False.
            current_run._queue_upload(
                file_path,
                step=None,
                relative_path=str(relative_to_base.parent),
                use_run_name=False,
            )
        else:
            # No active run: collect entries for a single bulk upload below.
            upload_entry: UploadEntry = {
                "project": project,
                "run": None,
                "step": None,
                "relative_path": str(relative_to_base),
                "uploaded_file": handle_file(file_path),
            }
            upload_entries.append(upload_entry)

    if upload_entries:
        if url is None:
            raise RuntimeError(
                "No server available. Call trackio.init() before trackio.save() to start the server."
            )

        # Best-effort bulk upload: failures are reported as a warning rather
        # than raised, so a failed upload does not abort training code.
        try:
            client = Client(url, verbose=False, httpx_kwargs={"timeout": 90})
            client.predict(
                api_name="/bulk_upload_media",
                uploads=upload_entries,
                hf_token=huggingface_hub.utils.get_token(),
            )
        except Exception as e:
            warnings.warn(
                f"Failed to upload files: {e}. "
                "Files may not be available in the dashboard."
            )

    return str(utils.MEDIA_DIR / project / "files")
451
+
452
+
453
def show(
    project: str | None = None,
    *,
    theme: str | ThemeClass | None = None,
    mcp_server: bool | None = None,
    footer: bool = True,
    color_palette: list[str] | None = None,
    open_browser: bool = True,
    block_thread: bool | None = None,
):
    """
    Launches the Trackio dashboard.

    Args:
        project (`str`, *optional*):
            The name of the project whose runs to show. If not provided, all projects
            will be shown and the user can select one.
        theme (`str` or `ThemeClass`, *optional*):
            A Gradio Theme to use for the dashboard instead of the default Gradio theme,
            can be a built-in theme (e.g. `'soft'`, `'citrus'`), a theme from the Hub
            (e.g. `"gstaff/xkcd"`), or a custom Theme class. If not provided, the
            `TRACKIO_THEME` environment variable will be used, or if that is not set,
            the default Gradio theme will be used.
        mcp_server (`bool`, *optional*):
            If `True`, the Trackio dashboard will be set up as an MCP server and certain
            functions will be added as MCP tools. If `None` (default behavior), then the
            `GRADIO_MCP_SERVER` environment variable will be used to determine if the
            MCP server should be enabled (which is `"True"` on Hugging Face Spaces).
        footer (`bool`, *optional*, defaults to `True`):
            Whether to show the Gradio footer. When `False`, the footer will be hidden.
            This can also be controlled via the `footer` query parameter in the URL.
        color_palette (`list[str]`, *optional*):
            A list of hex color codes to use for plot lines. If not provided, the
            `TRACKIO_COLOR_PALETTE` environment variable will be used (comma-separated
            hex codes), or if that is not set, the default color palette will be used.
            Example: `['#FF0000', '#00FF00', '#0000FF']`
        open_browser (`bool`, *optional*, defaults to `True`):
            If `True` and not in a notebook, a new browser tab will be opened with the
            dashboard. If `False`, the browser will not be opened.
        block_thread (`bool`, *optional*):
            If `True`, the main thread will be blocked until the dashboard is closed.
            If `None` (default behavior), then the main thread will not be blocked if the
            dashboard is launched in a notebook, otherwise the main thread will be blocked.

    Returns:
        `app`: The Gradio app object corresponding to the dashboard launched by Trackio.
        `url`: The local URL of the dashboard.
        `share_url`: The public share URL of the dashboard.
        `full_url`: The full URL of the dashboard including the write token (will use the public share URL if launched publicly, otherwise the local URL).
    """
    # The palette is passed to the UI process via an environment variable.
    if color_palette is not None:
        os.environ["TRACKIO_COLOR_PALETTE"] = ",".join(color_palette)

    theme = theme or os.environ.get("TRACKIO_THEME")

    # Explicit argument wins; otherwise fall back to the GRADIO_MCP_SERVER env var.
    _mcp_server = (
        mcp_server
        if mcp_server is not None
        else os.environ.get("GRADIO_MCP_SERVER", "False") == "True"
    )

    app, url, share_url = demo.launch(
        css=CSS,
        head=HEAD,
        footer_links=["gradio", "settings"] + (["api"] if _mcp_server else []),
        quiet=True,
        inline=False,
        prevent_thread_lock=True,
        favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
        allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
        mcp_server=_mcp_server,
        theme=theme,
    )

    # Prefer the public share URL when one exists.
    base_url = share_url + "/" if share_url else url
    full_url = utils.get_full_url(
        base_url, project=project, write_token=demo.write_token, footer=footer
    )

    # In a terminal: print the URL, optionally open a browser tab, and block by
    # default. In a notebook: embed the dashboard and do not block by default.
    if not utils.is_in_notebook():
        print(f"* Trackio UI launched at: {full_url}")
        if open_browser:
            webbrowser.open(full_url)
        block_thread = block_thread if block_thread is not None else True
    else:
        utils.embed_url_in_notebook(full_url)
        block_thread = block_thread if block_thread is not None else False

    if block_thread:
        utils.block_main_thread_until_keyboard_interrupt()
    # TupleNoPrint suppresses the default repr when called interactively.
    return TupleNoPrint((demo, url, share_url, full_url))
trackio/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (22.8 kB). View file
 
trackio/__pycache__/commit_scheduler.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
trackio/__pycache__/context_vars.cpython-312.pyc ADDED
Binary file (1.08 kB). View file
 
trackio/__pycache__/deploy.cpython-312.pyc ADDED
Binary file (15.5 kB). View file
 
trackio/__pycache__/dummy_commit_scheduler.cpython-312.pyc ADDED
Binary file (1.01 kB). View file
 
trackio/__pycache__/gpu.cpython-312.pyc ADDED
Binary file (14.7 kB). View file
 
trackio/__pycache__/histogram.cpython-312.pyc ADDED
Binary file (3.23 kB). View file
 
trackio/__pycache__/imports.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
trackio/__pycache__/run.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
trackio/__pycache__/sqlite_storage.cpython-312.pyc ADDED
Binary file (40.7 kB). View file
 
trackio/__pycache__/table.cpython-312.pyc ADDED
Binary file (8.6 kB). View file
 
trackio/__pycache__/typehints.cpython-312.pyc ADDED
Binary file (1.25 kB). View file
 
trackio/__pycache__/utils.cpython-312.pyc ADDED
Binary file (29.8 kB). View file
 
trackio/assets/badge.png ADDED
trackio/assets/trackio_logo_dark.png ADDED
trackio/assets/trackio_logo_light.png ADDED
trackio/assets/trackio_logo_old.png ADDED

Git LFS Details

  • SHA256: 3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
trackio/assets/trackio_logo_type_dark.png ADDED
trackio/assets/trackio_logo_type_dark_transparent.png ADDED
trackio/assets/trackio_logo_type_light.png ADDED
trackio/assets/trackio_logo_type_light_transparent.png ADDED
trackio/cli.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse

from trackio import show, sync


def _register_show_command(subparsers) -> None:
    # Register the `show` subcommand: launches the dashboard UI.
    show_parser = subparsers.add_parser(
        "show", help="Show the Trackio dashboard UI for a project"
    )
    show_parser.add_argument(
        "--project", required=False, help="Project name to show in the dashboard"
    )
    show_parser.add_argument(
        "--theme",
        required=False,
        default="default",
        help="A Gradio Theme to use for the dashboard instead of the default, can be a built-in theme (e.g. 'soft', 'citrus'), or a theme from the Hub (e.g. 'gstaff/xkcd').",
    )
    show_parser.add_argument(
        "--mcp-server",
        action="store_true",
        help="Enable MCP server functionality. The Trackio dashboard will be set up as an MCP server and certain functions will be exposed as MCP tools.",
    )
    # --footer defaults to True; --no-footer flips the same dest to False.
    show_parser.add_argument(
        "--footer",
        action="store_true",
        default=True,
        help="Show the Gradio footer. Use --no-footer to hide it.",
    )
    show_parser.add_argument(
        "--no-footer",
        dest="footer",
        action="store_false",
        help="Hide the Gradio footer.",
    )
    show_parser.add_argument(
        "--color-palette",
        required=False,
        help="Comma-separated list of hex color codes for plot lines (e.g. '#FF0000,#00FF00,#0000FF'). If not provided, the TRACKIO_COLOR_PALETTE environment variable will be used, or the default palette if not set.",
    )


def _register_sync_command(subparsers) -> None:
    # Register the `sync` subcommand: pushes a local DB to a Space.
    sync_parser = subparsers.add_parser(
        "sync",
        help="Sync a local project's database to a Hugging Face Space. If the Space does not exist, it will be created.",
    )
    sync_parser.add_argument(
        "--project", required=True, help="The name of the local project."
    )
    sync_parser.add_argument(
        "--space-id",
        required=True,
        help="The Hugging Face Space ID where the project will be synced (e.g. username/space_id).",
    )
    sync_parser.add_argument(
        "--private",
        action="store_true",
        help="Make the Hugging Face Space private if creating a new Space. By default, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.",
    )
    sync_parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite the existing database without prompting for confirmation.",
    )


def main():
    """Entry point of the Trackio CLI: parse arguments and dispatch to `show` or `sync`."""
    parser = argparse.ArgumentParser(description="Trackio CLI")
    subparsers = parser.add_subparsers(dest="command")
    _register_show_command(subparsers)
    _register_sync_command(subparsers)

    args = parser.parse_args()

    if args.command == "show":
        # Split the comma-separated palette into a list of hex strings, if given.
        palette = (
            [color.strip() for color in args.color_palette.split(",")]
            if args.color_palette
            else None
        )
        show(
            project=args.project,
            theme=args.theme,
            mcp_server=args.mcp_server,
            footer=args.footer,
            color_palette=palette,
        )
    elif args.command == "sync":
        sync(
            project=args.project,
            space_id=args.space_id,
            private=args.private,
            force=args.force,
        )
    else:
        # No subcommand given: show usage instead of failing.
        parser.print_help()


if __name__ == "__main__":
    main()
trackio/commit_scheduler.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
2
+
3
+ import atexit
4
+ import logging
5
+ import os
6
+ import time
7
+ from concurrent.futures import Future
8
+ from dataclasses import dataclass
9
+ from io import SEEK_END, SEEK_SET, BytesIO
10
+ from pathlib import Path
11
+ from threading import Lock, Thread
12
+ from typing import Callable, Dict, List, Union
13
+
14
+ from huggingface_hub.hf_api import (
15
+ DEFAULT_IGNORE_PATTERNS,
16
+ CommitInfo,
17
+ CommitOperationAdd,
18
+ HfApi,
19
+ )
20
+ from huggingface_hub.utils import filter_repo_objects
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass(frozen=True)
class _FileToUpload:
    """Temporary dataclass to store info about files to upload. Not meant to be used directly."""

    # Absolute path of the file on disk at scan time.
    local_path: Path
    # Destination path inside the repo (already includes the `path_in_repo` prefix).
    path_in_repo: str
    # File size in bytes captured when the folder was scanned.
    size_limit: int
    # `st_mtime` captured when the folder was scanned; compared against
    # `last_uploaded` to skip files unchanged since the previous push.
    last_modified: float
35
class CommitScheduler:
    """
    Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).

    The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
    properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
    with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
    to learn more about how to use it.

    Args:
        repo_id (`str`):
            The id of the repo to commit to.
        folder_path (`str` or `Path`):
            Path to the local folder to upload regularly.
        every (`int` or `float`, *optional*):
            The number of minutes between each commit. Defaults to 5 minutes.
        path_in_repo (`str`, *optional*):
            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
            of the repository.
        repo_type (`str`, *optional*):
            The type of the repo to commit to. Defaults to `model`.
        revision (`str`, *optional*):
            The revision of the repo to commit to. Defaults to `main`.
        private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
        token (`str`, *optional*):
            The token to use to commit to the repo. Defaults to the token saved on the machine.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are uploaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not uploaded.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performances on the repo when it grows too large.
        hf_api (`HfApi`, *optional*):
            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
        on_before_commit (`Callable[[], None]`, *optional*):
            If specified, a function that will be called before the CommitScheduler lists files to create a commit.

    Example:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    # Scheduler uploads every 10 minutes
    >>> csv_path = Path("watched_folder/data.csv")
    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)

    >>> with csv_path.open("a") as f:
    ...     f.write("first line")

    # Some time later (...)
    >>> with csv_path.open("a") as f:
    ...     f.write("second line")
    ```

    Example using a context manager:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
    ...     csv_path = Path("watched_folder/data.csv")
    ...     with csv_path.open("a") as f:
    ...         f.write("first line")
    ...     (...)
    ...     with csv_path.open("a") as f:
    ...         f.write("second line")

    # Scheduler is now stopped and last commit have been triggered
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        every: Union[int, float] = 5,
        path_in_repo: str | None = None,
        repo_type: str | None = None,
        revision: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        squash_history: bool = False,
        hf_api: HfApi | None = None,
        on_before_commit: Callable[[], None] | None = None,
    ) -> None:
        self.api = hf_api or HfApi(token=token)
        self.on_before_commit = on_before_commit

        # Folder
        self.folder_path = Path(folder_path).expanduser().resolve()
        self.path_in_repo = path_in_repo or ""
        self.allow_patterns = allow_patterns

        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS

        if self.folder_path.is_file():
            raise ValueError(
                f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
            )
        self.folder_path.mkdir(parents=True, exist_ok=True)

        # Repository
        repo_url = self.api.create_repo(
            repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
        )
        self.repo_id = repo_url.repo_id
        self.repo_type = repo_type
        self.revision = revision
        self.token = token

        # Per-file mtime at the time of the last successful upload; used to skip
        # unchanged files on subsequent scheduled commits.
        self.last_uploaded: Dict[Path, float] = {}
        self.last_push_time: float | None = None

        if not every > 0:
            raise ValueError(f"'every' must be a positive integer, not '{every}'.")
        self.lock = Lock()
        self.every = every
        self.squash_history = squash_history

        # Bug fix: initialize the stop flag *before* starting the worker thread and
        # registering the atexit hook. Both `_run_scheduler` and `_push_to_hub` read
        # `self.__stopped`; in the original ordering the thread could run before the
        # assignment below and raise AttributeError.
        self.__stopped = False

        logger.info(
            f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
        )
        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
        self._scheduler_thread.start()
        # Push one last time on interpreter shutdown so trailing changes are not lost.
        atexit.register(self._push_to_hub)

    def stop(self) -> None:
        """Stop the scheduler.

        A stopped scheduler cannot be restarted. Mostly for tests purposes.
        """
        self.__stopped = True

    def __enter__(self) -> "CommitScheduler":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Upload last changes before exiting
        self.trigger().result()
        self.stop()
        return

    def _run_scheduler(self) -> None:
        """Dumb thread waiting between each scheduled push to Hub."""
        while True:
            self.last_future = self.trigger()
            time.sleep(self.every * 60)
            if self.__stopped:
                break

    def trigger(self) -> Future:
        """Trigger a `push_to_hub` and return a future.

        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
        immediately, without waiting for the next scheduled commit.
        """
        return self.api.run_as_future(self._push_to_hub)

    def _push_to_hub(self) -> CommitInfo | None:
        if self.__stopped:  # If stopped, already scheduled commits are ignored
            return None

        logger.info("(Background) scheduled commit triggered.")
        try:
            value = self.push_to_hub()
            if self.squash_history:
                logger.info("(Background) squashing repo history.")
                self.api.super_squash_history(
                    repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
                )
            return value
        except Exception as e:
            logger.error(
                f"Error while pushing to Hub: {e}"
            )  # Depending on the setup, error might be silenced
            raise

    def push_to_hub(self) -> CommitInfo | None:
        """
        Push folder to the Hub and return the commit info.

        <Tip warning={true}>

        This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
        queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
        issues.

        </Tip>

        The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
        uploads only changed files. If no changes are found, the method returns without committing anything. If you want
        to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
        for example to compress data together in a single file before committing. For more details and examples, check
        out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
        """
        # Check files to upload (with lock)
        with self.lock:
            if self.on_before_commit is not None:
                self.on_before_commit()

            logger.debug("Listing files to upload for scheduled commit.")

            # List files from folder (taken from `_prepare_upload_folder_additions`)
            relpath_to_abspath = {
                path.relative_to(self.folder_path).as_posix(): path
                for path in sorted(
                    self.folder_path.glob("**/*")
                )  # sorted to be deterministic
                if path.is_file()
            }
            prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""

            # Filter with pattern + filter out unchanged files + retrieve current file size
            files_to_upload: List[_FileToUpload] = []
            for relpath in filter_repo_objects(
                relpath_to_abspath.keys(),
                allow_patterns=self.allow_patterns,
                ignore_patterns=self.ignore_patterns,
            ):
                local_path = relpath_to_abspath[relpath]
                stat = local_path.stat()
                if (
                    self.last_uploaded.get(local_path) is None
                    or self.last_uploaded[local_path] != stat.st_mtime
                ):
                    files_to_upload.append(
                        _FileToUpload(
                            local_path=local_path,
                            path_in_repo=prefix + relpath,
                            size_limit=stat.st_size,
                            last_modified=stat.st_mtime,
                        )
                    )

        # Return if nothing to upload
        if len(files_to_upload) == 0:
            logger.debug("Dropping schedule commit: no changed file to upload.")
            return None

        # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
        logger.debug("Removing unchanged files since previous scheduled commit.")
        add_operations = [
            CommitOperationAdd(
                # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
                # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
                path_or_fileobj=file_to_upload.local_path,
                path_in_repo=file_to_upload.path_in_repo,
            )
            for file_to_upload in files_to_upload
        ]

        # Upload files (append mode expected - no need for lock)
        logger.debug("Uploading files for scheduled commit.")
        commit_info = self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=add_operations,
            commit_message="Scheduled Commit",
            revision=self.revision,
        )

        # Record mtimes only after the commit succeeded so failed pushes are retried.
        for file in files_to_upload:
            self.last_uploaded[file.local_path] = file.last_modified

        self.last_push_time = time.time()

        return commit_info
315
+ class PartialFileIO(BytesIO):
316
+ """A file-like object that reads only the first part of a file.
317
+
318
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
319
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
320
+
321
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
322
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
323
+
324
+ Only supports `read`, `tell` and `seek` methods.
325
+
326
+ Args:
327
+ file_path (`str` or `Path`):
328
+ Path to the file to read.
329
+ size_limit (`int`):
330
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
331
+ will be read (and uploaded).
332
+ """
333
+
334
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
335
+ self._file_path = Path(file_path)
336
+ self._file = self._file_path.open("rb")
337
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
338
+
339
+ def __del__(self) -> None:
340
+ self._file.close()
341
+ return super().__del__()
342
+
343
+ def __repr__(self) -> str:
344
+ return (
345
+ f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
346
+ )
347
+
348
+ def __len__(self) -> int:
349
+ return self._size_limit
350
+
351
+ def __getattribute__(self, name: str):
352
+ if name.startswith("_") or name in (
353
+ "read",
354
+ "tell",
355
+ "seek",
356
+ ): # only 3 public methods supported
357
+ return super().__getattribute__(name)
358
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
359
+
360
+ def tell(self) -> int:
361
+ """Return the current file position."""
362
+ return self._file.tell()
363
+
364
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
365
+ """Change the stream position to the given offset.
366
+
367
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
368
+ """
369
+ if __whence == SEEK_END:
370
+ # SEEK_END => set from the truncated end
371
+ __offset = len(self) + __offset
372
+ __whence = SEEK_SET
373
+
374
+ pos = self._file.seek(__offset, __whence)
375
+ if pos > self._size_limit:
376
+ return self._file.seek(self._size_limit)
377
+ return pos
378
+
379
+ def read(self, __size: int | None = -1) -> bytes:
380
+ """Read at most `__size` bytes from the file.
381
+
382
+ Behavior is the same as a regular file, except that it is capped to the size limit.
383
+ """
384
+ current = self._file.tell()
385
+ if __size is None or __size < 0:
386
+ # Read until file limit
387
+ truncated_size = self._size_limit - current
388
+ else:
389
+ # Read until file limit or __size
390
+ truncated_size = min(__size, self._size_limit - current)
391
+ return self._file.read(truncated_size)
trackio/context_vars.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import contextvars
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from trackio.run import Run

# Context-local state shared across the trackio package. `contextvars` is used
# (rather than plain module globals) so that each execution context can hold
# its own values independently.

# The currently active Run object, if any.
current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
    "current_run", default=None
)
# Name of the currently active project, if any.
current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_project", default=None
)
# URL/identifier of the current tracking server — presumably set when a local
# dashboard is launched; confirm against callers.
current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_server", default=None
)
# Hugging Face Space ID currently in use, if logging to a Space.
current_space_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_space_id", default=None
)
# Gradio share-server address, if any — presumably for shared dashboard links;
# confirm against callers.
current_share_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_share_server", default=None
)
trackio/deploy.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.metadata
2
+ import io
3
+ import os
4
+ import sys
5
+ import threading
6
+ import time
7
+ from importlib.resources import files
8
+ from pathlib import Path
9
+
10
+ if sys.version_info >= (3, 11):
11
+ import tomllib
12
+ else:
13
+ import tomli as tomllib
14
+
15
+ import gradio
16
+ import huggingface_hub
17
+ from gradio_client import Client, handle_file
18
+ from httpx import ReadTimeout
19
+ from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError
20
+
21
+ import trackio
22
+ from trackio.sqlite_storage import SQLiteStorage
23
+ from trackio.utils import get_or_create_project_hash, preprocess_space_and_dataset_ids
24
+
25
+ SPACE_HOST_URL = "https://{user_name}-{space_name}.hf.space/"
26
+ SPACE_URL = "https://huggingface.co/spaces/{space_id}"
27
+
28
+
29
def _get_source_install_dependencies() -> str:
    """Get trackio dependencies from pyproject.toml for source installs."""
    # The pyproject.toml lives one level above the installed `trackio` package.
    package_root = Path(files("trackio")).parent
    with open(package_root / "pyproject.toml", "rb") as f:
        project_table = tomllib.load(f)["project"]
    base_deps = project_table["dependencies"]
    spaces_extras = project_table.get("optional-dependencies", {}).get("spaces", [])
    # One requirement per line, suitable for writing to requirements.txt.
    return "\n".join(base_deps + spaces_extras)
42
def _is_trackio_installed_from_source() -> bool:
    """Check if trackio is installed from source/editable install vs PyPI."""
    try:
        # Regular PyPI installs live under site-packages; anything else is a
        # source checkout.
        if "site-packages" not in trackio.__file__:
            return True

        # Editable installs ship a `.pth` file in their distribution metadata.
        dist = importlib.metadata.distribution("trackio")
        if dist.files and any(".pth" in str(f) for f in dist.files):
            return True

        return False
    except (
        AttributeError,
        importlib.metadata.PackageNotFoundError,
        importlib.metadata.MetadataError,
        ValueError,
        TypeError,
    ):
        # When metadata is unavailable or malformed, assume a source install.
        return True
67
def deploy_as_space(
    space_id: str,
    space_storage: huggingface_hub.SpaceStorage | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
):
    """
    Deploy the Trackio dashboard to a Hugging Face Space.

    Creates the Space repo (prompting for a login if the current token lacks
    write access), then uploads a README, a requirements.txt, and -- for source
    installs -- the trackio package itself plus an `app.py` entry point.
    Finally configures the Space's secret (`HF_TOKEN`) and variables.

    Args:
        space_id (`str`):
            The ID of the Space to create (e.g. `"username/space_name"`).
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier for the Space.
        dataset_id (`str`, *optional*):
            If provided, stored as the `TRACKIO_DATASET_ID` Space variable.
        private (`bool`, *optional*):
            Whether to make the Space private. Ignored if the repo exists.
    """
    if (
        os.getenv("SYSTEM") == "spaces"
    ):  # in case a repo with this function is uploaded to spaces
        return

    trackio_path = files("trackio")

    hf_api = huggingface_hub.HfApi()

    try:
        huggingface_hub.create_repo(
            space_id,
            private=private,
            space_sdk="gradio",
            space_storage=space_storage,
            repo_type="space",
            exist_ok=True,
        )
    except HfHubHTTPError as e:
        if e.response.status_code in [401, 403]:  # unauthorized or forbidden
            # Interactively log in with a write token, then retry once.
            print("Need 'write' access token to create a Spaces repo.")
            huggingface_hub.login(add_to_git_credential=False)
            huggingface_hub.create_repo(
                space_id,
                private=private,
                space_sdk="gradio",
                space_storage=space_storage,
                repo_type="space",
                exist_ok=True,
            )
        else:
            raise ValueError(f"Failed to create Space: {e}")

    # Upload the README with the current Gradio version substituted in.
    with open(Path(trackio_path, "README.md"), "r") as f:
        readme_content = f.read()
        readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
        readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
        hf_api.upload_file(
            path_or_fileobj=readme_buffer,
            path_in_repo="README.md",
            repo_id=space_id,
            repo_type="space",
        )

    # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
    # Make sure necessary dependencies are installed by creating a requirements.txt.
    is_source_install = _is_trackio_installed_from_source()

    if is_source_install:
        requirements_content = _get_source_install_dependencies()
    else:
        requirements_content = f"trackio[spaces]=={trackio.__version__}"

    requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
    hf_api.upload_file(
        path_or_fileobj=requirements_buffer,
        path_in_repo="requirements.txt",
        repo_id=space_id,
        repo_type="space",
    )

    huggingface_hub.utils.disable_progress_bars()

    if is_source_install:
        # Ship the local trackio sources so the Space runs this exact code.
        hf_api.upload_folder(
            repo_id=space_id,
            repo_type="space",
            folder_path=trackio_path,
            path_in_repo="trackio",
            ignore_patterns=["README.md"],
        )

    # Minimal Gradio entry point: the dashboard is launched by trackio.show().
    app_file_content = """import trackio
trackio.show()"""
    app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
    hf_api.upload_file(
        path_or_fileobj=app_file_buffer,
        path_in_repo="app.py",
        repo_id=space_id,
        repo_type="space",
    )

    # Forward the local token as a Space secret so the Space can write back.
    if hf_token := huggingface_hub.utils.get_token():
        huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
    if dataset_id is not None:
        huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)

    # Propagate optional local configuration to the Space as variables.
    if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
        huggingface_hub.add_space_variable(
            space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
        )
    if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
        huggingface_hub.add_space_variable(
            space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
        )

    if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
        huggingface_hub.add_space_variable(space_id, "TRACKIO_PLOT_ORDER", plot_order)

    if theme := os.environ.get("TRACKIO_THEME"):
        huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)

    huggingface_hub.add_space_variable(space_id, "GRADIO_MCP_SERVER", "True")
178
def create_space_if_not_exists(
    space_id: str,
    space_storage: huggingface_hub.SpaceStorage | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
) -> None:
    """
    Creates a new Hugging Face Space if it does not exist.

    Args:
        space_id (`str`):
            The ID of the Space to create.
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier for the Space.
        dataset_id (`str`, *optional*):
            The ID of the Dataset to add to the Space as a space variable.
        private (`bool`, *optional*):
            Whether to make the Space private. If `None` (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
    """
    if "/" not in space_id:
        raise ValueError(
            f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
        )
    if dataset_id is not None and "/" not in dataset_id:
        raise ValueError(
            f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
        )
    try:
        # If the Space already exists, just refresh its variables and return.
        huggingface_hub.repo_info(space_id, repo_type="space")
        print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
        if dataset_id is not None:
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_DATASET_ID", dataset_id
            )
        if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
            )
        if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
            )

        if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_PLOT_ORDER", plot_order
            )

        if theme := os.environ.get("TRACKIO_THEME"):
            huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)
        return
    except RepositoryNotFoundError:
        # Space does not exist yet; fall through to creation below.
        pass
    except HfHubHTTPError as e:
        if e.response.status_code in [401, 403]:  # unauthorized or forbidden
            print("Need 'write' access token to create a Spaces repo.")
            huggingface_hub.login(add_to_git_credential=False)
            # NOTE(review): unlike the success branch above, this call is not
            # guarded by `dataset_id is not None` (passing None may error), and
            # `repo_info` is not retried after login before falling through to
            # Space creation — confirm this is intended.
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_DATASET_ID", dataset_id
            )
        else:
            raise ValueError(f"Failed to create Space: {e}")

    print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
    deploy_as_space(space_id, space_storage, dataset_id, private)
247
def wait_until_space_exists(
    space_id: str,
) -> None:
    """
    Blocks the current thread until the Space exists.

    Args:
        space_id (`str`):
            The ID of the Space to wait for.

    Raises:
        `TimeoutError`: If waiting for the Space takes longer than expected.
    """
    hf_api = huggingface_hub.HfApi()
    # Exponential backoff starting at 1s, capped at 60s, across 30 attempts.
    backoff = 1
    attempts_remaining = 30
    while attempts_remaining > 0:
        try:
            hf_api.space_info(space_id)
            return
        except (huggingface_hub.utils.HfHubHTTPError, ReadTimeout):
            time.sleep(backoff)
            backoff = min(backoff * 2, 60)
        attempts_remaining -= 1
    raise TimeoutError("Waiting for space to exist took longer than expected")
272
def upload_db_to_space(project: str, space_id: str, force: bool = False) -> None:
    """
    Uploads the database of a local Trackio project to a Hugging Face Space.

    This uses the Gradio Client to upload since we do not want to trigger a new build of
    the Space, which would happen if we used `huggingface_hub.upload_file`.

    Args:
        project (`str`):
            The name of the project to upload.
        space_id (`str`):
            The ID of the Space to upload to.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, overwrites the existing database without prompting. If `False`,
            prompts for confirmation.
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    client = Client(space_id, verbose=False, httpx_kwargs={"timeout": 90})

    if not force:
        # Ask before clobbering a database that already exists on the Space.
        try:
            existing_projects = client.predict(api_name="/get_all_projects")
            if project in existing_projects:
                response = input(
                    f"Database for project '{project}' already exists on Space '{space_id}'. "
                    f"Overwrite it? (y/N): "
                )
                if response.lower() not in ["y", "yes"]:
                    print("* Upload cancelled.")
                    return
        except Exception as e:
            # Best-effort check: if the Space can't be queried, proceed anyway.
            print(f"* Warning: Could not check if project exists on Space: {e}")
            print("* Proceeding with upload...")

    # The token is forwarded so the Space endpoint can authorize the write.
    client.predict(
        api_name="/upload_db_to_space",
        project=project,
        uploaded_db=handle_file(db_path),
        hf_token=huggingface_hub.utils.get_token(),
    )
314
def sync(
    project: str,
    space_id: str | None = None,
    private: bool | None = None,
    force: bool = False,
    run_in_background: bool = False,
) -> str:
    """
    Syncs a local Trackio project's database to a Hugging Face Space.
    If the Space does not exist, it will be created.

    Args:
        project (`str`): The name of the project to upload.
        space_id (`str`, *optional*): The ID of the Space to upload to (e.g., `"username/space_id"`).
            If not provided, a random space_id (e.g. "username/project-2ac3z2aA") will be used.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, overwrite the existing database without prompting for confirmation.
            If `False`, prompt the user before overwriting an existing database.
        run_in_background (`bool`, *optional*, defaults to `False`):
            If `True`, the Space creation and database upload will be run in a background thread.
            If `False`, all the steps will be run synchronously.
    Returns:
        `str`: The Space ID of the synced project.
    """
    if space_id is None:
        # Derive a stable, per-project Space name from the project hash.
        space_id = f"{project}-{get_or_create_project_hash(project)}"
    space_id, _ = preprocess_space_and_dataset_ids(space_id, None)

    def space_creation_and_upload(
        space_id: str, private: bool | None = None, force: bool = False
    ):
        # Create-then-wait-then-upload pipeline; runs inline or in a thread.
        print(
            f"* Syncing local Trackio project to: {SPACE_URL.format(space_id=space_id)} (please wait...)"
        )
        create_space_if_not_exists(space_id, private=private)
        wait_until_space_exists(space_id)
        upload_db_to_space(project, space_id, force=force)
        print(f"* Synced successfully to space: {SPACE_URL.format(space_id=space_id)}")

    if run_in_background:
        threading.Thread(
            target=space_creation_and_upload, args=(space_id, private, force)
        ).start()
    else:
        space_creation_and_upload(space_id, private, force)
    return space_id
trackio/dummy_commit_scheduler.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Stand-ins that mimic the interface of huggingface_hub's CommitScheduler so
# callers can follow one code path whether or not Hub syncing is enabled.
class DummyCommitSchedulerLock:
    """No-op context manager matching the ``scheduler.lock`` interface."""

    def __enter__(self):
        # Nothing to acquire; yield None, as a real lock's __enter__ would.
        return None

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Nothing to release; returning None means exceptions propagate.
        pass


class DummyCommitScheduler:
    """No-op replacement for ``CommitScheduler`` used when no Hub sync is needed."""

    def __init__(self):
        # Expose a ``lock`` attribute with context-manager semantics.
        self.lock = DummyCommitSchedulerLock()
trackio/gpu.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import warnings
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ if TYPE_CHECKING:
7
+ from trackio.run import Run
8
+
9
+ pynvml: Any = None
10
+ PYNVML_AVAILABLE = False
11
+ _nvml_initialized = False
12
+ _nvml_lock = threading.Lock()
13
+ _energy_baseline: dict[int, float] = {}
14
+
15
+
16
+ def _ensure_pynvml():
17
+ global PYNVML_AVAILABLE, pynvml
18
+ if PYNVML_AVAILABLE:
19
+ return pynvml
20
+ try:
21
+ import pynvml as _pynvml
22
+
23
+ pynvml = _pynvml
24
+ PYNVML_AVAILABLE = True
25
+ return pynvml
26
+ except ImportError:
27
+ raise ImportError(
28
+ "nvidia-ml-py is required for GPU monitoring. "
29
+ "Install it with: pip install nvidia-ml-py"
30
+ )
31
+
32
+
33
def _init_nvml() -> bool:
    """Initialize NVML exactly once; return True on success, False otherwise."""
    global _nvml_initialized
    with _nvml_lock:
        if not _nvml_initialized:
            try:
                # Import (if needed) and initialize in one step.
                _ensure_pynvml().nvmlInit()
            except Exception:
                # Missing library or no driver: report unavailable, don't raise.
                return False
            _nvml_initialized = True
        return True
45
+
46
+
47
def _shutdown_nvml():
    """Tear down NVML if it was initialized; shutdown errors are ignored."""
    global _nvml_initialized
    with _nvml_lock:
        if _nvml_initialized and pynvml is not None:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                # Best-effort: NVML may already have been torn down.
                pass
        # Always clear the flag so a later call can re-initialize.
        _nvml_initialized = False
56
+
57
+
58
def get_gpu_count() -> tuple[int, list[int]]:
    """
    Get the number of GPUs visible to this process and their physical indices.
    Respects the CUDA_VISIBLE_DEVICES environment variable.

    Returns:
        Tuple of (count, physical_indices) where:
            - count: Number of visible GPUs
            - physical_indices: List mapping logical index to physical GPU index.
              e.g., if CUDA_VISIBLE_DEVICES=2,3 returns (2, [2, 3])
              meaning logical GPU 0 = physical GPU 2, logical GPU 1 = physical GPU 3
    """
    if not _init_nvml():
        return 0, []

    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cuda_visible is not None:
        # An empty (or whitespace-only) value means "no devices visible",
        # matching CUDA's own interpretation of CUDA_VISIBLE_DEVICES="".
        if not cuda_visible.strip():
            return 0, []
        try:
            indices = [int(x.strip()) for x in cuda_visible.split(",") if x.strip()]
            # CUDA ignores an invalid (negative) id and everything after it,
            # e.g. "0,-1,2" exposes only device 0.
            valid: list[int] = []
            for idx in indices:
                if idx < 0:
                    break
                valid.append(idx)
            return len(valid), valid
        except ValueError:
            # Non-integer entries (e.g. GPU UUIDs): fall back to counting all
            # devices rather than disabling GPU monitoring entirely.
            pass

    try:
        total = pynvml.nvmlDeviceGetCount()
        return total, list(range(total))
    except Exception:
        return 0, []
86
+
87
+
88
def gpu_available() -> bool:
    """
    Check if GPU monitoring is available.

    Returns True if nvidia-ml-py is installed and at least one NVIDIA GPU is
    detected. Used for auto-detection of GPU logging; never raises.
    """
    try:
        _ensure_pynvml()
        num_gpus, _ = get_gpu_count()
    except Exception:
        # Covers ImportError (library missing) and any NVML failure.
        return False
    return num_gpus > 0
103
+
104
+
105
def reset_energy_baseline():
    """Drop all stored per-GPU energy baselines; called when a new run starts."""
    global _energy_baseline
    # Rebind (rather than mutate) so the new run starts from a fresh mapping.
    _energy_baseline = {}
109
+
110
+
111
def collect_gpu_metrics(device: int | None = None) -> dict:
    """
    Collect GPU metrics for visible GPUs.

    Each NVML query below is wrapped in its own try/except because support
    varies by GPU model and driver; unsupported queries are simply skipped,
    so the returned dict contains only the metrics the hardware exposes.

    Args:
        device: CUDA device index to collect metrics from. If None, collects
            from all GPUs visible to this process (respects CUDA_VISIBLE_DEVICES).
            The device index is the logical CUDA index (0, 1, 2...), not the
            physical GPU index.

    Returns:
        Dictionary of GPU metrics. Keys use logical device indices (gpu/0/, gpu/1/, etc.)
        which correspond to CUDA device indices, not physical GPU indices.
    """
    if not _init_nvml():
        return {}

    gpu_count, visible_gpus = get_gpu_count()
    if gpu_count == 0:
        return {}

    # Build (logical, physical) index pairs: NVML handles are looked up by
    # physical index, but metric keys use the logical CUDA index.
    if device is not None:
        if device < 0 or device >= gpu_count:
            return {}
        gpu_indices = [(device, visible_gpus[device])]
    else:
        gpu_indices = list(enumerate(visible_gpus))

    metrics = {}
    # Accumulators for the aggregate gpu/* summary metrics emitted at the end.
    total_util = 0.0
    total_mem_used_gib = 0.0
    total_power = 0.0
    max_temp = 0.0
    valid_util_count = 0

    for logical_idx, physical_idx in gpu_indices:
        prefix = f"gpu/{logical_idx}"
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)

            # Compute / memory-controller utilization (percent).
            try:
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                metrics[f"{prefix}/utilization"] = util.gpu
                metrics[f"{prefix}/memory_utilization"] = util.memory
                total_util += util.gpu
                valid_util_count += 1
            except Exception:
                pass

            # Memory usage, reported in GiB plus a used/total ratio.
            try:
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                mem_used_gib = mem.used / (1024**3)
                mem_total_gib = mem.total / (1024**3)
                metrics[f"{prefix}/allocated_memory"] = mem_used_gib
                metrics[f"{prefix}/total_memory"] = mem_total_gib
                if mem.total > 0:
                    metrics[f"{prefix}/memory_usage"] = mem.used / mem.total
                total_mem_used_gib += mem_used_gib
            except Exception:
                pass

            # Power draw: NVML reports milliwatts; convert to watts.
            try:
                power_mw = pynvml.nvmlDeviceGetPowerUsage(handle)
                power_w = power_mw / 1000.0
                metrics[f"{prefix}/power"] = power_w
                total_power += power_w
            except Exception:
                pass

            # Power limit and percent-of-limit (only if power was read above).
            try:
                power_limit_mw = pynvml.nvmlDeviceGetPowerManagementLimit(handle)
                power_limit_w = power_limit_mw / 1000.0
                metrics[f"{prefix}/power_limit"] = power_limit_w
                if power_limit_w > 0 and f"{prefix}/power" in metrics:
                    metrics[f"{prefix}/power_percent"] = (
                        metrics[f"{prefix}/power"] / power_limit_w
                    ) * 100
            except Exception:
                pass

            # Core temperature in degrees Celsius.
            try:
                temp = pynvml.nvmlDeviceGetTemperature(
                    handle, pynvml.NVML_TEMPERATURE_GPU
                )
                metrics[f"{prefix}/temp"] = temp
                max_temp = max(max_temp, temp)
            except Exception:
                pass

            # Clock frequencies (MHz) for SM and memory.
            try:
                sm_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)
                metrics[f"{prefix}/sm_clock"] = sm_clock
            except Exception:
                pass

            try:
                mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM)
                metrics[f"{prefix}/memory_clock"] = mem_clock
            except Exception:
                pass

            # Fan speed (percent); raises on fanless (e.g. passively cooled) parts.
            try:
                fan_speed = pynvml.nvmlDeviceGetFanSpeed(handle)
                metrics[f"{prefix}/fan_speed"] = fan_speed
            except Exception:
                pass

            # Performance state (P-state, 0 = max performance).
            try:
                pstate = pynvml.nvmlDeviceGetPerformanceState(handle)
                metrics[f"{prefix}/performance_state"] = pstate
            except Exception:
                pass

            # Energy since the run started: NVML gives total mJ since boot, so
            # the first reading per GPU is stored as a baseline and subtracted.
            try:
                energy_mj = pynvml.nvmlDeviceGetTotalEnergyConsumption(handle)
                if logical_idx not in _energy_baseline:
                    _energy_baseline[logical_idx] = energy_mj
                energy_consumed_mj = energy_mj - _energy_baseline[logical_idx]
                metrics[f"{prefix}/energy_consumed"] = energy_consumed_mj / 1000.0
            except Exception:
                pass

            # PCIe throughput, converted from KB/s to MB/s.
            try:
                pcie_tx = pynvml.nvmlDeviceGetPcieThroughput(
                    handle, pynvml.NVML_PCIE_UTIL_TX_BYTES
                )
                pcie_rx = pynvml.nvmlDeviceGetPcieThroughput(
                    handle, pynvml.NVML_PCIE_UTIL_RX_BYTES
                )
                metrics[f"{prefix}/pcie_tx"] = pcie_tx / 1024.0
                metrics[f"{prefix}/pcie_rx"] = pcie_rx / 1024.0
            except Exception:
                pass

            # Throttle reasons: a bitmask decoded into individual 0/1 flags.
            try:
                throttle = pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(handle)
                metrics[f"{prefix}/throttle_thermal"] = int(
                    bool(throttle & pynvml.nvmlClocksThrottleReasonSwThermalSlowdown)
                )
                metrics[f"{prefix}/throttle_power"] = int(
                    bool(throttle & pynvml.nvmlClocksThrottleReasonSwPowerCap)
                )
                metrics[f"{prefix}/throttle_hw_slowdown"] = int(
                    bool(throttle & pynvml.nvmlClocksThrottleReasonHwSlowdown)
                )
                metrics[f"{prefix}/throttle_apps"] = int(
                    bool(
                        throttle
                        & pynvml.nvmlClocksThrottleReasonApplicationsClocksSetting
                    )
                )
            except Exception:
                pass

            # ECC error counters (volatile, i.e. since last driver reload);
            # only available on ECC-capable GPUs.
            try:
                ecc_corrected = pynvml.nvmlDeviceGetTotalEccErrors(
                    handle,
                    pynvml.NVML_MEMORY_ERROR_TYPE_CORRECTED,
                    pynvml.NVML_VOLATILE_ECC,
                )
                metrics[f"{prefix}/corrected_memory_errors"] = ecc_corrected
            except Exception:
                pass

            try:
                ecc_uncorrected = pynvml.nvmlDeviceGetTotalEccErrors(
                    handle,
                    pynvml.NVML_MEMORY_ERROR_TYPE_UNCORRECTED,
                    pynvml.NVML_VOLATILE_ECC,
                )
                metrics[f"{prefix}/uncorrected_memory_errors"] = ecc_uncorrected
            except Exception:
                pass

        except Exception:
            # Could not even get a handle for this GPU; skip it entirely.
            continue

    # Aggregate summary metrics across all sampled GPUs.
    if valid_util_count > 0:
        metrics["gpu/mean_utilization"] = total_util / valid_util_count
    if total_mem_used_gib > 0:
        metrics["gpu/total_allocated_memory"] = total_mem_used_gib
    if total_power > 0:
        metrics["gpu/total_power"] = total_power
    if max_temp > 0:
        metrics["gpu/max_temp"] = max_temp

    return metrics
298
+
299
+
300
class GpuMonitor:
    """Background daemon thread that periodically logs GPU metrics to a run."""

    def __init__(self, run: "Run", interval: float = 10.0):
        self._run = run
        self._interval = interval
        self._stop_flag = threading.Event()
        self._thread: "threading.Thread | None" = None

    def start(self):
        """Begin monitoring; warns and does nothing when no GPU is present."""
        num_gpus, _ = get_gpu_count()
        if num_gpus == 0:
            warnings.warn(
                "auto_log_gpu=True but no NVIDIA GPUs detected. GPU logging disabled."
            )
            return

        # Energy metrics are reported relative to the start of this run.
        reset_energy_baseline()
        worker = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread = worker
        worker.start()

    def stop(self):
        """Signal the monitor loop to exit and wait briefly for it to finish."""
        self._stop_flag.set()
        worker = self._thread
        if worker is not None:
            worker.join(timeout=2.0)

    def _monitor_loop(self):
        # Poll until stop() is called. Individual failures are swallowed so a
        # transient NVML error never kills the monitoring thread.
        while not self._stop_flag.is_set():
            try:
                snapshot = collect_gpu_metrics()
                if snapshot:
                    self._run.log_system(snapshot)
            except Exception:
                pass

            # Event.wait doubles as an interruptible sleep.
            self._stop_flag.wait(timeout=self._interval)
334
+
335
+
336
def log_gpu(run: "Run | None" = None, device: int | None = None) -> dict:
    """
    Log GPU metrics to the current or specified run as system metrics.

    Args:
        run: Optional Run instance. If None, uses the current run from context.
        device: CUDA device index to collect metrics from. If None, collects
            from all GPUs visible to this process (respects CUDA_VISIBLE_DEVICES).

    Returns:
        dict: The GPU metrics that were logged.

    Example:
        ```python
        import trackio

        run = trackio.init(project="my-project")
        trackio.log({"loss": 0.5})
        trackio.log_gpu()  # logs all visible GPUs
        trackio.log_gpu(device=0)  # logs only CUDA device 0
        ```
    """
    # Imported here to avoid a circular import at module load time.
    from trackio import context_vars

    active_run = run if run is not None else context_vars.current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before trackio.log_gpu().")

    metrics = collect_gpu_metrics(device=device)
    if metrics:
        active_run.log_system(metrics)
    return metrics
trackio/histogram.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ import numpy as np
4
+
5
+
6
+ class Histogram:
7
+ """
8
+ Histogram data type for Trackio, compatible with wandb.Histogram.
9
+
10
+ Args:
11
+ sequence (`np.ndarray` or `Sequence[float]` or `Sequence[int]`, *optional*):
12
+ Sequence of values to create the histogram from.
13
+ np_histogram (`tuple`, *optional*):
14
+ Pre-computed NumPy histogram as a `(hist, bins)` tuple.
15
+ num_bins (`int`, *optional*, defaults to `64`):
16
+ Number of bins for the histogram (maximum `512`).
17
+
18
+ Example:
19
+ ```python
20
+ import trackio
21
+ import numpy as np
22
+
23
+ # Create histogram from sequence
24
+ data = np.random.randn(1000)
25
+ trackio.log({"distribution": trackio.Histogram(data)})
26
+
27
+ # Create histogram from numpy histogram
28
+ hist, bins = np.histogram(data, bins=30)
29
+ trackio.log({"distribution": trackio.Histogram(np_histogram=(hist, bins))})
30
+
31
+ # Specify custom number of bins
32
+ trackio.log({"distribution": trackio.Histogram(data, num_bins=50)})
33
+ ```
34
+ """
35
+
36
+ TYPE = "trackio.histogram"
37
+
38
+ def __init__(
39
+ self,
40
+ sequence: np.ndarray | Sequence[float] | Sequence[int] | None = None,
41
+ np_histogram: tuple | None = None,
42
+ num_bins: int = 64,
43
+ ):
44
+ if sequence is None and np_histogram is None:
45
+ raise ValueError("Must provide either sequence or np_histogram")
46
+
47
+ if sequence is not None and np_histogram is not None:
48
+ raise ValueError("Cannot provide both sequence and np_histogram")
49
+
50
+ num_bins = min(num_bins, 512)
51
+
52
+ if np_histogram is not None:
53
+ self.histogram, self.bins = np_histogram
54
+ self.histogram = np.asarray(self.histogram)
55
+ self.bins = np.asarray(self.bins)
56
+ else:
57
+ data = np.asarray(sequence).flatten()
58
+ data = data[np.isfinite(data)]
59
+ if len(data) == 0:
60
+ self.histogram = np.array([])
61
+ self.bins = np.array([])
62
+ else:
63
+ self.histogram, self.bins = np.histogram(data, bins=num_bins)
64
+
65
+ def _to_dict(self) -> dict:
66
+ """Convert histogram to dictionary for storage."""
67
+ return {
68
+ "_type": self.TYPE,
69
+ "bins": self.bins.tolist(),
70
+ "values": self.histogram.tolist(),
71
+ }
trackio/imports.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from trackio import deploy, utils
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+
9
+
10
def import_csv(
    csv_path: str | Path,
    project: str,
    name: str | None = None,
    space_id: str | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
    force: bool = False,
) -> None:
    """
    Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
    column, may optionally contain a `"timestamp"` column, and any other columns will be
    treated as metrics. It should also include a header row with the column names.

    TODO: call init() and return a Run object so that the user can continue to log metrics to it.

    Args:
        csv_path (`str` or `Path`):
            The str or Path to the CSV file to import.
        project (`str`):
            The name of the project to import the CSV file into. Must not be an existing
            project.
        name (`str`, *optional*):
            The name of the Run to import the CSV file into. If not provided, a default
            name (the CSV file's stem) will be used.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of a
            local directory. Should be a complete Space name like `"username/reponame"`
            or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
            be created in the currently-logged-in Hugging Face user's namespace. If the
            Space does not exist, it will be created. If the Space already exists, the
            project will be logged to it.
        dataset_id (`str`, *optional*):
            If provided, a persistent Hugging Face Dataset will be created and the
            metrics will be synced to it every 5 minutes. Should be a complete Dataset
            name like `"username/datasetname"` or `"orgname/datasetname"`, or just
            `"datasetname"` in which case the Dataset will be created in the
            currently-logged-in Hugging Face user's namespace. If the Dataset does not
            exist, it will be created. If the Dataset already exists, the project will
            be appended to it. If not provided, the metrics will be logged to a local
            SQLite database, unless a `space_id` is provided, in which case a Dataset
            will be automatically created with the same name as the Space but with the
            `"_dataset"` suffix.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        force (`bool`, *optional*, defaults to `False`):
            Only used when `space_id` is provided. If `True`, overwrite the existing
            database on the Space without prompting for confirmation.
    """
    # Refuse to import into an existing project so existing runs are never clobbered.
    if SQLiteStorage.get_runs(project):
        raise ValueError(
            f"Project '{project}' already exists. Cannot import CSV into existing project."
        )

    csv_path = Path(csv_path)
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    df = pd.read_csv(csv_path)
    if df.empty:
        raise ValueError("CSV file is empty")

    # Normalize column names (e.g. strip decorations) before looking for "step".
    column_mapping = utils.simplify_column_names(df.columns.tolist())
    df = df.rename(columns=column_mapping)

    # Locate the step column case-insensitively.
    step_column = None
    for col in df.columns:
        if col.lower() == "step":
            step_column = col
            break

    if step_column is None:
        raise ValueError("CSV file must contain a 'step' or 'Step' column")

    if name is None:
        name = csv_path.stem

    metrics_list = []
    steps = []
    timestamps = []

    # Keep only columns that are fully numeric; everything else is skipped.
    numeric_columns = []
    for column in df.columns:
        if column == step_column:
            continue
        if column == "timestamp":
            continue

        try:
            pd.to_numeric(df[column], errors="raise")
            numeric_columns.append(column)
        except (ValueError, TypeError):
            continue

    for _, row in df.iterrows():
        metrics = {}
        for column in numeric_columns:
            value = row[column]
            if bool(pd.notna(value)):
                metrics[column] = float(value)

        # Rows with no numeric values contribute nothing.
        if metrics:
            metrics_list.append(metrics)
            steps.append(int(row[step_column]))

            if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
                timestamps.append(str(row["timestamp"]))
            else:
                timestamps.append("")

    if metrics_list:
        SQLiteStorage.bulk_log(
            project=project,
            run=name,
            metrics_list=metrics_list,
            steps=steps,
            timestamps=timestamps,
        )

    print(
        f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
    )
    print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")

    space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
    if dataset_id is not None:
        os.environ["TRACKIO_DATASET_ID"] = dataset_id
        print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")

    if space_id is None:
        utils.print_dashboard_instructions(project)
    else:
        # Push the freshly imported database to the (possibly new) Space.
        deploy.create_space_if_not_exists(
            space_id=space_id, dataset_id=dataset_id, private=private
        )
        deploy.wait_until_space_exists(space_id=space_id)
        deploy.upload_db_to_space(project=project, space_id=space_id, force=force)
        print(
            f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
        )
151
+
152
+
153
def import_tf_events(
    log_dir: str | Path,
    project: str,
    name: str | None = None,
    space_id: str | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
    force: bool = False,
) -> None:
    """
    Imports TensorFlow Events files from a directory into a Trackio project. Each
    subdirectory in the log directory will be imported as a separate run.

    Args:
        log_dir (`str` or `Path`):
            The str or Path to the directory containing TensorFlow Events files.
        project (`str`):
            The name of the project to import the TensorFlow Events files into. Must not
            be an existing project.
        name (`str`, *optional*):
            The name prefix for runs (if not provided, will use directory names). Each
            subdirectory will create a separate run.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of a
            local directory. Should be a complete Space name like `"username/reponame"`
            or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
            be created in the currently-logged-in Hugging Face user's namespace. If the
            Space does not exist, it will be created. If the Space already exists, the
            project will be logged to it.
        dataset_id (`str`, *optional*):
            If provided, a persistent Hugging Face Dataset will be created and the
            metrics will be synced to it every 5 minutes. Should be a complete Dataset
            name like `"username/datasetname"` or `"orgname/datasetname"`, or just
            `"datasetname"` in which case the Dataset will be created in the
            currently-logged-in Hugging Face user's namespace. If the Dataset does not
            exist, it will be created. If the Dataset already exists, the project will
            be appended to it. If not provided, the metrics will be logged to a local
            SQLite database, unless a `space_id` is provided, in which case a Dataset
            will be automatically created with the same name as the Space but with the
            `"_dataset"` suffix.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        force (`bool`, *optional*, defaults to `False`):
            Only used when `space_id` is provided. If `True`, overwrite the existing
            database on the Space without prompting for confirmation.
    """
    # tbparse is an optional dependency; fail with an actionable message.
    try:
        from tbparse import SummaryReader
    except ImportError:
        raise ImportError(
            "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
        )

    # Refuse to import into an existing project so existing runs are never clobbered.
    if SQLiteStorage.get_runs(project):
        raise ValueError(
            f"Project '{project}' already exists. Cannot import TF events into existing project."
        )

    path = Path(log_dir)
    if not path.exists():
        raise FileNotFoundError(f"TF events directory not found: {path}")

    # Use tbparse to read all tfevents files in the directory structure
    reader = SummaryReader(str(path), extra_columns={"dir_name"})
    df = reader.scalars

    if df.empty:
        raise ValueError(f"No TensorFlow events data found in {path}")

    total_imported = 0
    imported_runs = []

    # Group by dir_name to create separate runs
    for dir_name, group_df in df.groupby("dir_name"):
        try:
            # Determine run name based on directory name
            if dir_name == "":
                run_name = "main"  # For files in the root directory
            else:
                run_name = dir_name  # Use directory name

            if name:
                run_name = f"{name}_{run_name}"

            if group_df.empty:
                print(f"* Skipping directory {dir_name}: no scalar data found")
                continue

            metrics_list = []
            steps = []
            timestamps = []

            for _, row in group_df.iterrows():
                # Convert row values to appropriate types
                tag = str(row["tag"])
                value = float(row["value"])
                step = int(row["step"])

                metrics = {tag: value}
                metrics_list.append(metrics)
                steps.append(step)

                # Use wall_time if present, else fallback
                if "wall_time" in group_df.columns and not bool(
                    pd.isna(row["wall_time"])
                ):
                    timestamps.append(str(row["wall_time"]))
                else:
                    timestamps.append("")

            if metrics_list:
                SQLiteStorage.bulk_log(
                    project=project,
                    run=str(run_name),
                    metrics_list=metrics_list,
                    steps=steps,
                    timestamps=timestamps,
                )

                total_imported += len(metrics_list)
                imported_runs.append(run_name)

                print(
                    f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
                )
                print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")

        except Exception as e:
            # One bad subdirectory should not abort the whole import.
            print(f"* Error processing directory {dir_name}: {e}")
            continue

    if not imported_runs:
        raise ValueError("No valid TensorFlow events data could be imported")

    print(f"* Total imported events: {total_imported}")
    print(f"* Created runs: {', '.join(imported_runs)}")

    space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
    if dataset_id is not None:
        os.environ["TRACKIO_DATASET_ID"] = dataset_id
        print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")

    if space_id is None:
        utils.print_dashboard_instructions(project)
    else:
        # Push the freshly imported database to the (possibly new) Space.
        deploy.create_space_if_not_exists(
            space_id, dataset_id=dataset_id, private=private
        )
        deploy.wait_until_space_exists(space_id)
        deploy.upload_db_to_space(project, space_id, force=force)
        print(
            f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
        )
trackio/media/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Media module for Trackio.
3
+
4
+ This module contains all media-related functionality including:
5
+ - TrackioImage, TrackioVideo, TrackioAudio classes
6
+ - Video writing utilities
7
+ - Audio conversion utilities
8
+ """
9
+
10
+ from trackio.media.audio import TrackioAudio
11
+ from trackio.media.image import TrackioImage
12
+ from trackio.media.media import TrackioMedia
13
+ from trackio.media.utils import get_project_media_path
14
+ from trackio.media.video import TrackioVideo
15
+
16
+ write_audio = TrackioAudio.write_audio
17
+ write_video = TrackioVideo.write_video
18
+
19
+ __all__ = [
20
+ "TrackioMedia",
21
+ "TrackioImage",
22
+ "TrackioVideo",
23
+ "TrackioAudio",
24
+ "get_project_media_path",
25
+ "write_video",
26
+ "write_audio",
27
+ ]
trackio/media/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (824 Bytes). View file
 
trackio/media/__pycache__/audio.cpython-312.pyc ADDED
Binary file (8.99 kB). View file
 
trackio/media/__pycache__/image.cpython-312.pyc ADDED
Binary file (4.74 kB). View file
 
trackio/media/__pycache__/media.cpython-312.pyc ADDED
Binary file (4.48 kB). View file
 
trackio/media/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.56 kB). View file
 
trackio/media/__pycache__/video.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
trackio/media/audio.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Literal
6
+
7
+ import numpy as np
8
+ from pydub import AudioSegment
9
+
10
+ from trackio.media.media import TrackioMedia
11
+ from trackio.media.utils import check_ffmpeg_installed, check_path
12
+
13
SUPPORTED_FORMATS = ["wav", "mp3"]
AudioFormatType = Literal["wav", "mp3"]
TrackioAudioSourceType = str | Path | np.ndarray


class TrackioAudio(TrackioMedia):
    """
    Initializes an Audio object.

    Example:
        ```python
        import trackio
        import numpy as np

        # Generate a 1-second 440 Hz sine wave (mono)
        sr = 16000
        t = np.linspace(0, 1, sr, endpoint=False)
        wave = 0.2 * np.sin(2 * np.pi * 440 * t)
        audio = trackio.Audio(wave, caption="A4 sine", sample_rate=sr, format="wav")
        trackio.log({"tone": audio})

        # Stereo from numpy array (shape: samples, 2)
        stereo = np.stack([wave, wave], axis=1)
        audio = trackio.Audio(stereo, caption="Stereo", sample_rate=sr, format="mp3")
        trackio.log({"stereo": audio})

        # From an existing file
        audio = trackio.Audio("path/to/audio.wav", caption="From file")
        trackio.log({"file_audio": audio})
        ```

    Args:
        value (`str`, `Path`, or `numpy.ndarray`, *optional*):
            A path to an audio file, or a numpy array.
            The array should be shaped `(samples,)` for mono or `(samples, 2)` for stereo.
            Float arrays will be peak-normalized and converted to 16-bit PCM; integer arrays will be converted to 16-bit PCM as needed.
        caption (`str`, *optional*):
            A string caption for the audio.
        sample_rate (`int`, *optional*):
            Sample rate in Hz. Required when `value` is a numpy array.
        format (`Literal["wav", "mp3"]`, *optional*):
            Audio format used when `value` is a numpy array. Default is "wav".
    """

    TYPE = "trackio.audio"

    def __init__(
        self,
        value: TrackioAudioSourceType,
        caption: str | None = None,
        sample_rate: int | None = None,
        format: AudioFormatType | None = None,
    ):
        super().__init__(value, caption)
        if isinstance(value, np.ndarray):
            # Raw samples need a sample rate to be encodable; format defaults
            # to lossless WAV.
            if sample_rate is None:
                raise ValueError("Sample rate is required when value is an ndarray")
            if format is None:
                format = "wav"
        self._format = format
        self._sample_rate = sample_rate

    def _save_media(self, file_path: Path):
        """Write the audio payload to `file_path`: encode an array, or copy a file."""
        if isinstance(self._value, np.ndarray):
            TrackioAudio.write_audio(
                data=self._value,
                sample_rate=self._sample_rate,
                filename=file_path,
                format=self._format,
            )
        elif isinstance(self._value, str | Path):
            if os.path.isfile(self._value):
                shutil.copy(self._value, file_path)
            else:
                raise ValueError(f"File not found: {self._value}")

    @staticmethod
    def ensure_int16_pcm(data: np.ndarray) -> np.ndarray:
        """
        Convert input audio array to contiguous int16 PCM.
        Peak normalization is applied to floating inputs.
        The caller's array is never modified in place.

        Raises:
            ValueError: If `data` is not 1-D or 2-D.
            TypeError: If the dtype cannot be converted to int16 PCM.
        """
        arr = np.asarray(data)
        if arr.ndim not in (1, 2):
            raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])")

        if arr.dtype != np.int16:
            warnings.warn(
                f"Converting {arr.dtype} audio to int16 PCM; pass int16 to avoid conversion.",
                stacklevel=2,
            )

        # Replace NaN/inf with finite values. Note: `copy=False` here would
        # mutate the caller's array in place when it contains non-finite
        # floats, so the default copying behavior is used deliberately.
        arr = np.nan_to_num(arr)

        # Floating types: normalize to peak 1.0, then scale to int16
        if np.issubdtype(arr.dtype, np.floating):
            max_abs = float(np.max(np.abs(arr))) if arr.size else 0.0
            if max_abs > 0.0:
                arr = arr / max_abs
            out = (arr * 32767.0).clip(-32768, 32767).astype(np.int16, copy=False)
            return np.ascontiguousarray(out)

        # Integer types: rescale/shift each supported dtype into int16 range.
        converters: dict[np.dtype, callable] = {
            np.dtype(np.int16): lambda a: a,
            np.dtype(np.int32): lambda a: (
                (a.astype(np.int32) // 65536).astype(np.int16, copy=False)
            ),
            np.dtype(np.uint16): lambda a: (
                (a.astype(np.int32) - 32768).astype(np.int16, copy=False)
            ),
            np.dtype(np.uint8): lambda a: (
                (a.astype(np.int32) * 257 - 32768).astype(np.int16, copy=False)
            ),
            np.dtype(np.int8): lambda a: (
                (a.astype(np.int32) * 256).astype(np.int16, copy=False)
            ),
        }

        conv = converters.get(arr.dtype)
        if conv is not None:
            return np.ascontiguousarray(conv(arr))
        raise TypeError(f"Unsupported audio dtype: {arr.dtype}")

    @staticmethod
    def write_audio(
        data: np.ndarray,
        sample_rate: int,
        filename: str | Path,
        format: AudioFormatType = "wav",
    ) -> None:
        """
        Encode `data` as 16-bit PCM audio and write it to `filename`.

        Args:
            data: Audio samples, `(samples,)` mono or `(samples, channels)`.
            sample_rate: Sample rate in Hz; must be a positive int.
            format: Output container, one of `SUPPORTED_FORMATS`.

        Raises:
            ValueError: If `sample_rate` or `format` is invalid.
        """
        if not isinstance(sample_rate, int) or sample_rate <= 0:
            raise ValueError(f"Invalid sample_rate: {sample_rate}")
        if format not in SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
            )

        check_path(filename)

        pcm = TrackioAudio.ensure_int16_pcm(data)

        # Non-WAV output is delegated to ffmpeg via pydub, so verify it exists.
        if format != "wav":
            check_ffmpeg_installed()

        channels = 1 if pcm.ndim == 1 else pcm.shape[1]
        audio = AudioSegment(
            pcm.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,  # int16
            channels=channels,
        )

        file = audio.export(str(filename), format=format)
        file.close()
trackio/media/image.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from PIL import Image as PILImage
7
+
8
+ from trackio.media.media import TrackioMedia
9
+
10
# Union of accepted image sources: a filesystem path, a uint8 pixel array, or a PIL image.
TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
11
+
12
+
13
class TrackioImage(TrackioMedia):
    """
    Initializes an Image object.

    Example:
    ```python
    import trackio
    import numpy as np
    from PIL import Image

    # Create an image from numpy array
    image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    image = trackio.Image(image_data, caption="Random image")
    trackio.log({"my_image": image})

    # Create an image from PIL Image
    pil_image = Image.new('RGB', (100, 100), color='red')
    image = trackio.Image(pil_image, caption="Red square")
    trackio.log({"red_image": image})

    # Create an image from file path
    image = trackio.Image("path/to/image.jpg", caption="Photo from file")
    trackio.log({"file_image": image})
    ```

    Args:
        value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`, *optional*):
            A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
            If numpy array, should be of type `np.uint8` with RGB values in the range `[0, 255]`.
        caption (`str`, *optional*):
            A string caption for the image.
    """

    TYPE = "trackio.image"

    def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
        super().__init__(value, caption)
        self._format: str | None = None

        if not isinstance(self._value, TrackioImageSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray) and self._value.dtype != np.uint8:
            raise ValueError(
                f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
            )
        # In-memory sources are encoded as PNG when saved; file sources keep
        # whatever container their path suffix indicates.
        in_memory = isinstance(self._value, (np.ndarray, PILImage.Image))
        if in_memory and self._format is None:
            self._format = "png"

    def _as_pil(self) -> PILImage.Image | None:
        """Return an RGBA PIL image for in-memory sources, or None for paths."""
        source = self._value
        try:
            if isinstance(source, np.ndarray):
                pixels = np.asarray(source).astype("uint8")
                return PILImage.fromarray(pixels).convert("RGBA")
            if isinstance(source, PILImage.Image):
                return source.convert("RGBA")
        except Exception as e:
            raise ValueError(f"Failed to process image data: {self._value}") from e
        return None

    def _save_media(self, file_path: Path):
        """Encode in-memory images, or copy an existing image file into place."""
        pil = self._as_pil()
        if pil:
            pil.save(file_path, format=self._format)
            return
        if isinstance(self._value, (str, Path)):
            if not os.path.isfile(self._value):
                raise ValueError(f"File not found: {self._value}")
            shutil.copy(self._value, file_path)
trackio/media/media.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from abc import ABC, abstractmethod
4
+ from pathlib import Path
5
+
6
+ from trackio.media.utils import get_project_media_path
7
+ from trackio.utils import MEDIA_DIR
8
+
9
+
10
+ class TrackioMedia(ABC):
11
+ """
12
+ Abstract base class for Trackio media objects
13
+ Provides shared functionality for file handling and serialization.
14
+ """
15
+
16
+ TYPE: str
17
+
18
+ def __init_subclass__(cls, **kwargs):
19
+ """Ensure subclasses define the TYPE attribute."""
20
+ super().__init_subclass__(**kwargs)
21
+ if not hasattr(cls, "TYPE") or cls.TYPE is None:
22
+ raise TypeError(f"Class {cls.__name__} must define TYPE attribute")
23
+
24
+ def __init__(self, value, caption: str | None = None):
25
+ """
26
+ Saves the value and caption, and if the value is a file path, checks if the file exists.
27
+ """
28
+ self.caption = caption
29
+ self._value = value
30
+ self._file_path: Path | None = None
31
+
32
+ if isinstance(self._value, str | Path):
33
+ if not os.path.isfile(self._value):
34
+ raise ValueError(f"File not found: {self._value}")
35
+
36
+ def _file_extension(self) -> str:
37
+ if self._file_path:
38
+ return self._file_path.suffix[1:].lower()
39
+ if isinstance(self._value, str | Path):
40
+ path = Path(self._value)
41
+ return path.suffix[1:].lower()
42
+ if hasattr(self, "_format") and self._format:
43
+ return self._format
44
+ return "unknown"
45
+
46
+ def _get_relative_file_path(self) -> Path | None:
47
+ return self._file_path
48
+
49
+ def _get_absolute_file_path(self) -> Path | None:
50
+ if self._file_path:
51
+ return MEDIA_DIR / self._file_path
52
+ return None
53
+
54
+ def _save(self, project: str, run: str, step: int = 0):
55
+ if self._file_path:
56
+ return
57
+
58
+ media_dir = get_project_media_path(project=project, run=run, step=step)
59
+ filename = f"{uuid.uuid4()}.{self._file_extension()}"
60
+ file_path = media_dir / filename
61
+
62
+ self._save_media(file_path)
63
+ self._file_path = file_path.relative_to(MEDIA_DIR)
64
+
65
+ @abstractmethod
66
+ def _save_media(self, file_path: Path):
67
+ """
68
+ Performs the actual media saving logic.
69
+ """
70
+ pass
71
+
72
+ def _to_dict(self) -> dict:
73
+ if not self._file_path:
74
+ raise ValueError("Media must be saved to file before serialization")
75
+ return {
76
+ "_type": self.TYPE,
77
+ "file_path": str(self._get_relative_file_path()),
78
+ "caption": self.caption,
79
+ }
trackio/media/utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+
4
+ from trackio.utils import MEDIA_DIR
5
+
6
+
7
+ def check_path(file_path: str | Path) -> None:
8
+ """Raise an error if the parent directory does not exist."""
9
+ file_path = Path(file_path)
10
+ if not file_path.parent.exists():
11
+ try:
12
+ file_path.parent.mkdir(parents=True, exist_ok=True)
13
+ except OSError as e:
14
+ raise ValueError(
15
+ f"Failed to create parent directory {file_path.parent}: {e}"
16
+ )
17
+
18
+
19
def check_ffmpeg_installed() -> None:
    """Raise a RuntimeError when the ``ffmpeg`` executable is not on the PATH."""
    ffmpeg_location = shutil.which("ffmpeg")
    if ffmpeg_location is not None:
        return
    raise RuntimeError(
        "ffmpeg is required to write video but was not found on your system. "
        "Please install ffmpeg and ensure it is available on your PATH."
    )
26
+
27
+
28
def get_project_media_path(
    project: str,
    run: str | None = None,
    step: int | None = None,
    relative_path: str | Path | None = None,
) -> Path:
    """
    Get the full path where uploaded files are stored for a Trackio project (and create the directory if it doesn't exist).
    If a run is not provided, the files are stored in a project-level directory with the given relative path.

    Args:
        project: The project name
        run: The run name
        step: The step number
        relative_path: The relative path within the directory (only used if run is not provided)

    Returns:
        The full path to the media file
    """
    if step is not None and run is None:
        raise ValueError("Uploading files at a specific step requires a run")

    # Layout: MEDIA_DIR/<project>/<run>[/<step>] for run media,
    # or MEDIA_DIR/<project>/files[/<relative_path>] for project files.
    segments = [project]
    if run:
        segments.append(run)
        if step is not None:
            segments.append(str(step))
    else:
        segments.append("files")
        if relative_path:
            segments.append(relative_path)

    path = MEDIA_DIR.joinpath(*segments)
    path.mkdir(parents=True, exist_ok=True)
    return path
trackio/media/video.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ from pathlib import Path
5
+ from typing import Literal
6
+
7
+ import numpy as np
8
+
9
+ from trackio.media.media import TrackioMedia
10
+ from trackio.media.utils import check_ffmpeg_installed, check_path
11
+
12
# Accepted video sources: a path to an existing file, or a uint8 pixel array.
TrackioVideoSourceType = str | Path | np.ndarray
# Container formats supported when encoding from an ndarray.
TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
# ffmpeg codec names used per container (see TrackioVideo._codec).
VideoCodec = Literal["h264", "vp9", "gif"]
15
+
16
+
17
class TrackioVideo(TrackioMedia):
    """
    Initializes a Video object.

    Example:
    ```python
    import trackio
    import numpy as np

    # Create a simple video from numpy array
    frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8)
    video = trackio.Video(frames, caption="Random video", fps=30)

    # Create a batch of videos
    batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)
    batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15)

    # Create video from file path
    video = trackio.Video("path/to/video.mp4", caption="Video from file")
    ```

    Args:
        value (`str`, `Path`, or `numpy.ndarray`, *optional*):
            A path to a video file, or a numpy array.
            If numpy array, should be of type `np.uint8` with RGB values in the range `[0, 255]`.
            It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width).
            For the latter, the videos will be tiled into a grid.
        caption (`str`, *optional*):
            A string caption for the video.
        fps (`int`, *optional*):
            Frames per second for the video. Only used when value is an ndarray. Default is `24`.
        format (`Literal["gif", "mp4", "webm"]`, *optional*):
            Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif".
    """

    TYPE = "trackio.video"

    def __init__(
        self,
        value: TrackioVideoSourceType,
        caption: str | None = None,
        fps: int | None = None,
        format: TrackioVideoFormatType | None = None,
    ):
        super().__init__(value, caption)

        if not isinstance(self._value, TrackioVideoSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioVideoSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray):
            if self._value.dtype != np.uint8:
                raise ValueError(
                    f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
                )
            # Encoding defaults apply only to raw arrays; file sources keep
            # their original container (format stays None in that case).
            if format is None:
                format = "gif"
            if fps is None:
                fps = 24
        self._fps = fps
        self._format = format

    @staticmethod
    def _check_array_format(video: np.ndarray) -> None:
        """Raise an error if the array is not in the expected format."""
        # write_video expects frames already converted to channels-last RGB.
        if not (video.ndim == 4 and video.shape[-1] == 3):
            raise ValueError(
                f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. "
                f"Input has {video.ndim} dimensions, expected 4."
            )
        if video.dtype != np.uint8:
            raise TypeError(
                f"Expected dtype=uint8, got {video.dtype}. "
                "Please convert your video data to uint8 format."
            )

    @staticmethod
    def write_video(
        file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec
    ) -> None:
        """RGB uint8 only, shape (F, H, W, 3)."""
        check_ffmpeg_installed()
        check_path(file_path)

        if codec not in {"h264", "vp9", "gif"}:
            raise ValueError("Unsupported codec. Use h264, vp9, or gif.")

        arr = np.asarray(video)
        TrackioVideo._check_array_format(arr)

        # Frames are streamed to ffmpeg's stdin as raw rgb24 bytes, so the
        # buffer must be contiguous for tobytes() to be correct.
        frames = np.ascontiguousarray(arr)
        _, height, width, _ = frames.shape
        out_path = str(file_path)

        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "rawvideo",
            "-s",
            f"{width}x{height}",
            "-pix_fmt",
            "rgb24",
            "-r",
            str(fps),
            "-i",
            "-",
            "-an",
        ]

        if codec == "gif":
            # Two-pass palette filter for better GIF color quality.
            video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
            cmd += [
                "-vf",
                video_filter,
                "-loop",
                "0",
            ]
        elif codec == "h264":
            cmd += [
                "-vcodec",
                "libx264",
                "-pix_fmt",
                "yuv420p",
                "-movflags",
                "+faststart",
            ]
        elif codec == "vp9":
            # Rough bits-per-pixel heuristic to choose a target bitrate.
            bpp = 0.08
            bps = int(width * height * fps * bpp)
            if bps >= 1_000_000:
                bitrate = f"{round(bps / 1_000_000)}M"
            elif bps >= 1_000:
                bitrate = f"{round(bps / 1_000)}k"
            else:
                bitrate = str(max(bps, 1))
            cmd += [
                "-vcodec",
                "libvpx-vp9",
                "-b:v",
                bitrate,
                "-pix_fmt",
                "yuv420p",
            ]
        cmd += [out_path]
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
        try:
            for frame in frames:
                proc.stdin.write(frame.tobytes())
        finally:
            # Closing stdin signals EOF so ffmpeg can finalize the output file.
            if proc.stdin:
                proc.stdin.close()
        # NOTE(review): if ffmpeg exits early, the write loop above may raise
        # BrokenPipeError before stderr is captured — confirm this is acceptable.
        stderr = (
            proc.stderr.read().decode("utf-8", errors="ignore")
            if proc.stderr
            else ""
        )
        ret = proc.wait()
        if ret != 0:
            raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}")

    @property
    def _codec(self) -> str:
        """Map the output container format to the ffmpeg codec name."""
        match self._format:
            case "gif":
                return "gif"
            case "mp4":
                return "h264"
            case "webm":
                return "vp9"
            case _:
                raise ValueError(f"Unsupported format: {self._format}")

    def _save_media(self, file_path: Path):
        """Encode array sources with ffmpeg, or copy an existing video file."""
        if isinstance(self._value, np.ndarray):
            video = TrackioVideo._process_ndarray(self._value)
            TrackioVideo.write_video(file_path, video, fps=self._fps, codec=self._codec)
        elif isinstance(self._value, str | Path):
            if os.path.isfile(self._value):
                shutil.copy(self._value, file_path)
            else:
                raise ValueError(f"File not found: {self._value}")

    @staticmethod
    def _process_ndarray(value: np.ndarray) -> np.ndarray:
        """Validate a raw array and convert it to tiled (F, H, W, C) frames."""
        # Verify value is either 4D (single video) or 5D array (batched videos).
        # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
        if value.ndim < 4:
            raise ValueError(
                "Video requires at least 4 dimensions (frames, channels, height, width)"
            )
        if value.ndim > 5:
            raise ValueError(
                "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
            )
        if value.ndim == 4:
            # Reshape to 5D with single batch: (1, frames, channels, height, width)
            value = value[np.newaxis, ...]

        value = TrackioVideo._tile_batched_videos(value)
        return value

    @staticmethod
    def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
        """
        Tiles a batch of videos into a grid of videos.

        Input format: (batch, frames, channels, height, width) - original FCHW format
        Output format: (frames, total_height, total_width, channels)
        """
        batch_size, frames, channels, height, width = video.shape

        # Pad the batch with black videos up to the next power of two so the
        # grid below can be split evenly into power-of-two rows/columns.
        next_pow2 = 1 << (batch_size - 1).bit_length()
        if batch_size != next_pow2:
            pad_len = next_pow2 - batch_size
            pad_shape = (pad_len, frames, channels, height, width)
            padding = np.zeros(pad_shape, dtype=video.dtype)
            video = np.concatenate((video, padding), axis=0)
            batch_size = next_pow2

        # Rows and columns are powers of two with n_rows <= n_cols.
        n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
        n_cols = batch_size // n_rows

        # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
        video = video.reshape(n_rows, n_cols, frames, channels, height, width)

        # Rearrange dimensions to (frames, total_height, total_width, channels)
        video = video.transpose(2, 0, 4, 1, 5, 3)
        video = video.reshape(frames, n_rows * height, n_cols * width, channels)
        return video
trackio/package.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "name": "trackio",
3
+ "version": "0.13.1",
4
+ "description": "",
5
+ "python": "true"
6
+ }
trackio/py.typed ADDED
File without changes
trackio/run.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+ import warnings
4
+ from datetime import datetime, timezone
5
+
6
+ import huggingface_hub
7
+ from gradio_client import Client, handle_file
8
+
9
+ from trackio import utils
10
+ from trackio.gpu import GpuMonitor
11
+ from trackio.histogram import Histogram
12
+ from trackio.media import TrackioMedia
13
+ from trackio.sqlite_storage import SQLiteStorage
14
+ from trackio.table import Table
15
+ from trackio.typehints import LogEntry, SystemLogEntry, UploadEntry
16
+ from trackio.utils import _get_default_namespace
17
+
18
# Seconds between flushes of queued logs/uploads by the background sender thread.
BATCH_SEND_INTERVAL = 0.5
19
+
20
+
21
class Run:
    """A single tracked run that queues metrics/media and sends them in batches
    to a Trackio dashboard via a background gradio_client connection."""

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        group: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
        auto_log_gpu: bool = False,
        gpu_log_interval: float = 10.0,
    ):
        self.url = url
        self.project = project
        # Guards the client handle and the three outgoing queues, which are
        # shared between user threads and the background sender thread.
        self._client_lock = threading.Lock()
        self._client_thread = None
        self._client = client
        self._space_id = space_id
        self.name = name or utils.generate_readable_name(
            SQLiteStorage.get_runs(project), space_id
        )
        self.group = group
        self.config = utils.to_json_safe(config or {})

        # Keys starting with "_" are reserved for the internal metadata set below.
        if isinstance(self.config, dict):
            for key in self.config:
                if key.startswith("_"):
                    raise ValueError(
                        f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
                    )

        self.config["_Username"] = self._get_username()
        self.config["_Created"] = datetime.now(timezone.utc).isoformat()
        self.config["_Group"] = self.group

        # Outgoing work queued by log()/log_system()/_queue_upload(), flushed
        # by _batch_sender() running on the client thread.
        self._queued_logs: list[LogEntry] = []
        self._queued_system_logs: list[SystemLogEntry] = []
        self._queued_uploads: list[UploadEntry] = []
        self._stop_flag = threading.Event()
        self._config_logged = False

        # Connect to the dashboard off the main thread so logging never blocks startup.
        self._client_thread = threading.Thread(target=self._init_client_background)
        self._client_thread.daemon = True
        self._client_thread.start()

        self._gpu_monitor: "GpuMonitor | None" = None
        if auto_log_gpu:
            self._gpu_monitor = GpuMonitor(self, interval=gpu_log_interval)
            self._gpu_monitor.start()

    def _get_username(self) -> str | None:
        """Get the current HuggingFace username if logged in, otherwise None."""
        try:
            return _get_default_namespace()
        except Exception:
            return None

    def _batch_sender(self):
        """Send batched logs every BATCH_SEND_INTERVAL."""
        # Keeps draining until stop is requested AND both log queues are empty,
        # so finish() does not lose queued entries.
        while (
            not self._stop_flag.is_set()
            or len(self._queued_logs) > 0
            or len(self._queued_system_logs) > 0
        ):
            if not self._stop_flag.is_set():
                time.sleep(BATCH_SEND_INTERVAL)

            with self._client_lock:
                if self._client is None:
                    return
                if self._queued_logs:
                    logs_to_send = self._queued_logs.copy()
                    self._queued_logs.clear()
                    self._client.predict(
                        api_name="/bulk_log",
                        logs=logs_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                if self._queued_system_logs:
                    system_logs_to_send = self._queued_system_logs.copy()
                    self._queued_system_logs.clear()
                    self._client.predict(
                        api_name="/bulk_log_system",
                        logs=system_logs_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                if self._queued_uploads:
                    uploads_to_send = self._queued_uploads.copy()
                    self._queued_uploads.clear()
                    self._client.predict(
                        api_name="/bulk_upload_media",
                        uploads=uploads_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )

    def _init_client_background(self):
        """Establish the dashboard client (with Fibonacci backoff), then run the sender loop."""
        if self._client is None:
            fib = utils.fibo()
            for sleep_coefficient in fib:
                try:
                    client = Client(self.url, verbose=False)

                    with self._client_lock:
                        self._client = client
                    break
                except Exception:
                    # Server may not be up yet; retry after a growing delay.
                    pass
                if sleep_coefficient is not None:
                    time.sleep(0.1 * sleep_coefficient)

        self._batch_sender()

    def _queue_upload(
        self,
        file_path,
        step: int | None,
        relative_path: str | None = None,
        use_run_name: bool = True,
    ):
        """
        Queues a media file for upload to a Space.

        Args:
            file_path:
                The path to the file to upload.
            step (`int` or `None`, *optional*):
                The step number associated with this upload.
            relative_path (`str` or `None`, *optional*):
                The relative path within the project's files directory. Used when
                uploading files via `trackio.save()`.
            use_run_name (`bool`, *optional*):
                Whether to use the run name for the uploaded file. This is set to
                `False` when uploading files via `trackio.save()`.
        """
        upload_entry: UploadEntry = {
            "project": self.project,
            "run": self.name if use_run_name else None,
            "step": step,
            "relative_path": relative_path,
            "uploaded_file": handle_file(file_path),
        }
        with self._client_lock:
            self._queued_uploads.append(upload_entry)

    def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
        """
        Serialize media in metrics and upload to space if needed.
        """
        value._save(self.project, self.name, step if step is not None else 0)
        if self._space_id:
            self._queue_upload(value._get_absolute_file_path(), step)
        return value._to_dict()

    def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None):
        """
        Scan a serialized table for media objects and queue them for upload to space.
        """
        if not self._space_id:
            return

        table_data = table_dict.get("_value", [])
        for row in table_data:
            for value in row.values():
                if isinstance(value, dict) and value.get("_type") in [
                    "trackio.image",
                    "trackio.video",
                    "trackio.audio",
                ]:
                    file_path = value.get("file_path")
                    if file_path:
                        from trackio.utils import MEDIA_DIR

                        absolute_path = MEDIA_DIR / file_path
                        self._queue_upload(absolute_path, step)
                elif isinstance(value, list):
                    # Cells may hold lists of media dicts; queue each one.
                    for item in value:
                        if isinstance(item, dict) and item.get("_type") in [
                            "trackio.image",
                            "trackio.video",
                            "trackio.audio",
                        ]:
                            file_path = item.get("file_path")
                            if file_path:
                                from trackio.utils import MEDIA_DIR

                                absolute_path = MEDIA_DIR / file_path
                                self._queue_upload(absolute_path, step)

    def log(self, metrics: dict, step: int | None = None):
        """Queue a metrics dict (tables/histograms/media are serialized first)."""
        renamed_keys = []
        new_metrics = {}

        # Reserved or dunder-prefixed keys are renamed rather than rejected.
        for k, v in metrics.items():
            if k in utils.RESERVED_KEYS or k.startswith("__"):
                new_key = f"__{k}"
                renamed_keys.append(k)
                new_metrics[new_key] = v
            else:
                new_metrics[k] = v

        if renamed_keys:
            warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")

        metrics = new_metrics
        for key, value in metrics.items():
            if isinstance(value, Table):
                metrics[key] = value._to_dict(
                    project=self.project, run=self.name, step=step
                )
                self._scan_and_queue_media_uploads(metrics[key], step)
            elif isinstance(value, Histogram):
                metrics[key] = value._to_dict()
            elif isinstance(value, TrackioMedia):
                metrics[key] = self._process_media(value, step)
        metrics = utils.serialize_values(metrics)

        # The run config is attached to the first log entry only.
        config_to_log = None
        if not self._config_logged and self.config:
            config_to_log = utils.to_json_safe(self.config)
            self._config_logged = True

        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
            "config": config_to_log,
        }

        with self._client_lock:
            self._queued_logs.append(log_entry)

    def log_system(self, metrics: dict):
        """
        Log system metrics (GPU, etc.) without a step number.
        These metrics use timestamps for the x-axis instead of steps.
        """
        metrics = utils.serialize_values(metrics)
        timestamp = datetime.now(timezone.utc).isoformat()

        system_log_entry: SystemLogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "timestamp": timestamp,
        }

        with self._client_lock:
            self._queued_system_logs.append(system_log_entry)

    def finish(self):
        """Cleanup when run is finished."""
        if self._gpu_monitor is not None:
            self._gpu_monitor.stop()

        self._stop_flag.set()

        # Give the sender one final flush window before joining.
        time.sleep(2 * BATCH_SEND_INTERVAL)

        if self._client_thread is not None:
            print("* Run finished. Uploading logs to Trackio (please wait...)")
            self._client_thread.join()
trackio/sqlite_storage.py ADDED
@@ -0,0 +1,874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import sqlite3
4
+ import time
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from threading import Lock
8
+
9
+ try:
10
+ import fcntl
11
+ except ImportError: # fcntl is not available on Windows
12
+ fcntl = None
13
+
14
+ import huggingface_hub as hf
15
+ import orjson
16
+ import pandas as pd
17
+
18
+ from trackio.commit_scheduler import CommitScheduler
19
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
20
+ from trackio.utils import (
21
+ TRACKIO_DIR,
22
+ deserialize_values,
23
+ serialize_values,
24
+ )
25
+
26
# File extension used for the per-project SQLite database files.
DB_EXT = ".db"
27
+
28
+
29
class ProcessLock:
    """A file-based lock that works across processes. Is a no-op on Windows.

    Uses ``fcntl.flock`` on POSIX; on Windows (where ``fcntl`` is unavailable)
    both enter and exit do nothing.
    """

    def __init__(self, lockfile_path: Path):
        self.lockfile_path = lockfile_path
        self.lockfile = None
        self.is_windows = platform.system() == "Windows"

    def __enter__(self):
        """Acquire the lock with retry logic.

        Raises:
            IOError: If the lock cannot be acquired within ~10 seconds.
        """
        if self.is_windows:
            return self
        self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
        self.lockfile = open(self.lockfile_path, "w")

        max_retries = 100  # 100 attempts * 0.1s sleep ≈ 10 seconds total
        try:
            for attempt in range(max_retries):
                try:
                    fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    return self
                except IOError:
                    if attempt < max_retries - 1:
                        time.sleep(0.1)
                    else:
                        raise IOError(
                            "Could not acquire database lock after 10 seconds"
                        )
        except BaseException:
            # Bug fix: don't leak the open file handle when acquisition fails
            # (timeout, KeyboardInterrupt, etc.).
            self.lockfile.close()
            self.lockfile = None
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the lock."""
        if self.is_windows:
            return

        if self.lockfile:
            fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
            self.lockfile.close()
            self.lockfile = None
63
+
64
+
65
class SQLiteStorage:
    """Static helpers for reading/writing Trackio project data in per-project SQLite DBs."""

    # Whether an import of backing dataset data has been attempted
    # (export_to_parquet is skipped until then).
    _dataset_import_attempted = False
    # Active commit scheduler syncing local DBs, if any.
    _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
    # Protects creation/replacement of _current_scheduler.
    _scheduler_lock = Lock()
69
+
70
    @staticmethod
    def _get_connection(db_path: Path) -> sqlite3.Connection:
        """Open a SQLite connection tuned for many small writes (WAL, Row factory)."""
        conn = sqlite3.connect(str(db_path), timeout=30.0)
        # Keep WAL for concurrency + performance on many small writes
        conn.execute("PRAGMA journal_mode = WAL")
        # ---- Minimal perf tweaks for many tiny transactions ----
        # NORMAL = fsync at critical points only (safer than OFF, much faster than FULL)
        conn.execute("PRAGMA synchronous = NORMAL")
        # Keep temp data in memory to avoid disk hits during small writes
        conn.execute("PRAGMA temp_store = MEMORY")
        # Give SQLite a bit more room for cache (negative = KB, engine-managed)
        conn.execute("PRAGMA cache_size = -20000")
        # --------------------------------------------------------
        conn.row_factory = sqlite3.Row
        return conn
85
+
86
+ @staticmethod
87
+ def _get_process_lock(project: str) -> ProcessLock:
88
+ lockfile_path = TRACKIO_DIR / f"{project}.lock"
89
+ return ProcessLock(lockfile_path)
90
+
91
+ @staticmethod
92
+ def get_project_db_filename(project: str) -> str:
93
+ """Get the database filename for a specific project."""
94
+ safe_project_name = "".join(
95
+ c for c in project if c.isalnum() or c in ("-", "_")
96
+ ).rstrip()
97
+ if not safe_project_name:
98
+ safe_project_name = "default"
99
+ return f"{safe_project_name}{DB_EXT}"
100
+
101
+ @staticmethod
102
+ def get_project_db_path(project: str) -> Path:
103
+ """Get the database path for a specific project."""
104
+ filename = SQLiteStorage.get_project_db_filename(project)
105
+ return TRACKIO_DIR / filename
106
+
107
    @staticmethod
    def init_db(project: str) -> Path:
        """
        Initialize the SQLite database with required tables.
        Returns the database path.
        """
        db_path = SQLiteStorage.get_project_db_path(project)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        # Cross-process lock: schema creation must not race with other writers.
        with SQLiteStorage._get_process_lock(project):
            with sqlite3.connect(str(db_path), timeout=30.0) as conn:
                # Same pragmas as _get_connection: WAL for concurrency, plus
                # relaxed sync / in-memory temp / larger cache for small writes.
                conn.execute("PRAGMA journal_mode = WAL")
                conn.execute("PRAGMA synchronous = NORMAL")
                conn.execute("PRAGMA temp_store = MEMORY")
                conn.execute("PRAGMA cache_size = -20000")
                cursor = conn.cursor()
                # Step-indexed training metrics (serialized as text per row).
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS metrics (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp TEXT NOT NULL,
                        run_name TEXT NOT NULL,
                        step INTEGER NOT NULL,
                        metrics TEXT NOT NULL
                    )
                    """
                )
                # One config blob per run (UNIQUE enforces the 1:1 mapping).
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS configs (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        run_name TEXT NOT NULL,
                        config TEXT NOT NULL,
                        created_at TEXT NOT NULL,
                        UNIQUE(run_name)
                    )
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_metrics_run_step
                    ON metrics(run_name, step)
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_configs_run_name
                    ON configs(run_name)
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_metrics_run_timestamp
                    ON metrics(run_name, timestamp)
                    """
                )
                # Timestamp-indexed system metrics (GPU etc.), no step column.
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS system_metrics (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp TEXT NOT NULL,
                        run_name TEXT NOT NULL,
                        metrics TEXT NOT NULL
                    )
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_system_metrics_run_timestamp
                    ON system_metrics(run_name, timestamp)
                    """
                )
                conn.commit()
        return db_path
180
+
181
@staticmethod
def export_to_parquet():
    """
    Export every project's SQLite DB to Parquet files in TRACKIO_DIR.

    For each "<name>.db" file, writes "<name>.parquet" (metrics table)
    and "<name>_system.parquet" (system_metrics table), but only when the
    DB is newer than the existing Parquet file or the Parquet file is
    missing. The JSON-encoded "metrics" column is expanded into one
    column per metric key. No-op until a dataset import has been
    attempted (so remote data isn't clobbered before load_from_dataset
    runs).

    Fix vs. previous version: removed the dead statement
    `system_parquet_path = db_path.with_suffix("") / ""` which built a
    malformed path (trailing separator) that was immediately overwritten.
    """
    if not SQLiteStorage._dataset_import_attempted:
        return
    if not TRACKIO_DIR.exists():
        return

    def expand_metrics(frame: pd.DataFrame) -> pd.DataFrame:
        # Expand the JSON "metrics" column into one real column per key.
        expanded = pd.DataFrame(
            frame["metrics"]
            .apply(lambda x: deserialize_values(orjson.loads(x)))
            .values.tolist(),
            index=frame.index,
        )
        frame = frame.drop(columns=["metrics"])
        for col in expanded.columns:
            frame[col] = expanded[col]
        return frame

    def needs_refresh(target: Path) -> bool:
        # Re-export when the parquet is missing or older than the DB.
        return (not target.exists()) or (
            db_path.stat().st_mtime > target.stat().st_mtime
        )

    db_names = [f for f in os.listdir(TRACKIO_DIR) if f.endswith(DB_EXT)]
    for db_name in db_names:
        db_path = TRACKIO_DIR / db_name
        parquet_path = db_path.with_suffix(".parquet")
        system_parquet_path = TRACKIO_DIR / (db_path.stem + "_system.parquet")

        if needs_refresh(parquet_path):
            with sqlite3.connect(str(db_path)) as conn:
                df = pd.read_sql("SELECT * FROM metrics", conn)
            if not df.empty:
                df = expand_metrics(df)
                df.to_parquet(
                    parquet_path,
                    write_page_index=True,
                    use_content_defined_chunking=True,
                )

        if needs_refresh(system_parquet_path):
            with sqlite3.connect(str(db_path)) as conn:
                try:
                    sys_df = pd.read_sql("SELECT * FROM system_metrics", conn)
                except Exception:
                    # Older DBs may lack the system_metrics table.
                    sys_df = pd.DataFrame()
            if not sys_df.empty:
                sys_df = expand_metrics(sys_df)
                sys_df.to_parquet(
                    system_parquet_path,
                    write_page_index=True,
                    use_content_defined_chunking=True,
                )
245
+
246
+ @staticmethod
247
+ def _cleanup_wal_sidecars(db_path: Path) -> None:
248
+ """Remove leftover -wal/-shm files for a DB basename (prevents disk I/O errors)."""
249
+ for suffix in ("-wal", "-shm"):
250
+ sidecar = Path(str(db_path) + suffix)
251
+ try:
252
+ if sidecar.exists():
253
+ sidecar.unlink()
254
+ except Exception:
255
+ pass
256
+
257
@staticmethod
def import_from_parquet():
    """
    Rebuild local SQLite databases from downloaded Parquet files.

    Inverse of export_to_parquet: "<name>.parquet" repopulates the
    metrics table of "<name>.db", and "<name>_system.parquet" repopulates
    its system_metrics table. Existing table contents are replaced
    (`if_exists="replace"`).
    """
    if not TRACKIO_DIR.exists():
        return

    all_paths = os.listdir(TRACKIO_DIR)
    parquet_names = [
        f
        for f in all_paths
        if f.endswith(".parquet") and not f.endswith("_system.parquet")
    ]
    for pq_name in parquet_names:
        parquet_path = TRACKIO_DIR / pq_name
        db_path = parquet_path.with_suffix(DB_EXT)

        # Stale WAL/SHM sidecars from a previous run can cause
        # "disk I/O error" once the DB is rewritten underneath them.
        SQLiteStorage._cleanup_wal_sidecars(db_path)

        df = pd.read_parquet(parquet_path)
        if "metrics" not in df.columns:
            # Exported parquet stores each metric as its own column;
            # re-pack those columns into the JSON "metrics" column.
            # Assumes id/timestamp/run_name/step are all present, which
            # holds for files written by export_to_parquet -- TODO confirm
            # for externally produced parquet files.
            metrics = df.copy()
            other_cols = ["id", "timestamp", "run_name", "step"]
            df = df[other_cols]
            for col in other_cols:
                del metrics[col]
            metrics = orjson.loads(metrics.to_json(orient="records"))
            df["metrics"] = [orjson.dumps(serialize_values(row)) for row in metrics]

        with sqlite3.connect(str(db_path), timeout=30.0) as conn:
            df.to_sql("metrics", conn, if_exists="replace", index=False)
            conn.commit()

    system_parquet_names = [f for f in all_paths if f.endswith("_system.parquet")]
    for pq_name in system_parquet_names:
        parquet_path = TRACKIO_DIR / pq_name
        # "<stem>_system.parquet" maps back to "<stem>.db".
        db_name = pq_name.replace("_system.parquet", DB_EXT)
        db_path = TRACKIO_DIR / db_name

        df = pd.read_parquet(parquet_path)
        if "metrics" not in df.columns:
            metrics = df.copy()
            # Unlike the metrics table, columns are filtered defensively
            # here since system exports carry no "step" column.
            other_cols = ["id", "timestamp", "run_name"]
            df = df[[c for c in other_cols if c in df.columns]]
            for col in other_cols:
                if col in metrics.columns:
                    del metrics[col]
            metrics = orjson.loads(metrics.to_json(orient="records"))
            df["metrics"] = [orjson.dumps(serialize_values(row)) for row in metrics]

        with sqlite3.connect(str(db_path), timeout=30.0) as conn:
            df.to_sql("system_metrics", conn, if_exists="replace", index=False)
            conn.commit()
312
+
313
@staticmethod
def get_scheduler():
    """
    Return the process-wide commit scheduler, creating it on first use.

    A real CommitScheduler syncing TRACKIO_DIR to the HF dataset is
    created only when both TRACKIO_DATASET_ID and SPACE_REPO_NAME env
    vars are set; otherwise a DummyCommitScheduler is returned. The
    instance is cached; creation is guarded by a lock.
    """
    with SQLiteStorage._scheduler_lock:
        if SQLiteStorage._current_scheduler is not None:
            return SQLiteStorage._current_scheduler

        dataset_id = os.environ.get("TRACKIO_DATASET_ID")
        space_repo_name = os.environ.get("SPACE_REPO_NAME")
        if dataset_id is None or space_repo_name is None:
            scheduler = DummyCommitScheduler()
        else:
            scheduler = CommitScheduler(
                repo_id=dataset_id,
                repo_type="dataset",
                folder_path=TRACKIO_DIR,
                private=True,
                allow_patterns=["*.parquet", "*_system.parquet", "media/**/*"],
                squash_history=True,
                token=os.environ.get("HF_TOKEN"),
                # Refresh parquet exports right before each commit.
                on_before_commit=SQLiteStorage.export_to_parquet,
            )
        SQLiteStorage._current_scheduler = scheduler
        return scheduler
340
+
341
@staticmethod
def log(project: str, run: str, metrics: dict, step: int | None = None):
    """
    Append a single metrics row for `run`, creating the DB if needed.

    When `step` is None the next step after the run's current maximum is
    used (0 for a brand-new run). A cross-process lock prevents
    "database is locked" errors from concurrent writers.

    Kept for backwards compatibility with older clients talking to a
    newer Trackio Spaces dashboard; newer code uses bulk_log.
    """
    db_path = SQLiteStorage.init_db(project)
    with SQLiteStorage._get_process_lock(project):
        with SQLiteStorage._get_connection(db_path) as conn:
            cursor = conn.cursor()
            if step is None:
                cursor.execute(
                    "SELECT MAX(step) FROM metrics WHERE run_name = ?",
                    (run,),
                )
                max_step = cursor.fetchone()[0]
                step = 0 if max_step is None else max_step + 1
            cursor.execute(
                "INSERT INTO metrics (timestamp, run_name, step, metrics) "
                "VALUES (?, ?, ?, ?)",
                (
                    datetime.now().isoformat(),
                    run,
                    step,
                    orjson.dumps(serialize_values(metrics)),
                ),
            )
            conn.commit()
385
+
386
@staticmethod
def bulk_log(
    project: str,
    run: str,
    metrics_list: list[dict],
    steps: list[int] | None = None,
    timestamps: list[str] | None = None,
    config: dict | None = None,
):
    """
    Append many metrics rows for `run` in one transaction.

    Ensures the database and tables exist, and holds a cross-process
    lock to avoid "database is locked" errors from concurrent writers.

    Args:
        project: Project whose database receives the rows.
        run: Run name the rows belong to.
        metrics_list: One dict of metric values per row; empty list is a
            no-op.
        steps: Optional per-row steps. If None, rows get 0..len-1. Any
            None entries are filled sequentially starting after the
            run's current max step.
        timestamps: Optional per-row ISO timestamps; defaults to "now"
            for every row.
        config: Optional run config, upserted into the configs table.

    Raises:
        ValueError: if metrics_list, steps, and timestamps end up with
            different lengths. (Validation intentionally happens after
            step inference, inside the lock, since inference can extend
            `steps`.)
    """
    if not metrics_list:
        return

    if timestamps is None:
        timestamps = [datetime.now().isoformat()] * len(metrics_list)

    db_path = SQLiteStorage.init_db(project)
    with SQLiteStorage._get_process_lock(project):
        with SQLiteStorage._get_connection(db_path) as conn:
            cursor = conn.cursor()

            if steps is None:
                # No steps supplied: number rows from zero.
                steps = list(range(len(metrics_list)))
            elif any(s is None for s in steps):
                # Fill in None entries by continuing from the run's
                # current maximum step.
                cursor.execute(
                    "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
                )
                last_step = cursor.fetchone()[0]
                current_step = 0 if last_step is None else last_step + 1
                processed_steps = []
                for step in steps:
                    if step is None:
                        processed_steps.append(current_step)
                        current_step += 1
                    else:
                        processed_steps.append(step)
                steps = processed_steps

            if len(metrics_list) != len(steps) or len(metrics_list) != len(
                timestamps
            ):
                raise ValueError(
                    "metrics_list, steps, and timestamps must have the same length"
                )

            data = []
            for i, metrics in enumerate(metrics_list):
                data.append(
                    (
                        timestamps[i],
                        run,
                        steps[i],
                        orjson.dumps(serialize_values(metrics)),
                    )
                )

            cursor.executemany(
                """
                INSERT INTO metrics
                (timestamp, run_name, step, metrics)
                VALUES (?, ?, ?, ?)
                """,
                data,
            )

            if config:
                # One config row per run (UNIQUE(run_name)): upsert.
                current_timestamp = datetime.now().isoformat()
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO configs
                    (run_name, config, created_at)
                    VALUES (?, ?, ?)
                    """,
                    (
                        run,
                        orjson.dumps(serialize_values(config)),
                        current_timestamp,
                    ),
                )

            conn.commit()
471
+
472
@staticmethod
def bulk_log_system(
    project: str,
    run: str,
    metrics_list: list[dict],
    timestamps: list[str] | None = None,
):
    """
    Append many system-metric rows (GPU stats, etc.) for `run`.

    System metrics carry no step; the timestamp serves as the x-axis.
    When `timestamps` is omitted, the current time is used for every row.

    Raises:
        ValueError: if metrics_list and timestamps differ in length.
    """
    if not metrics_list:
        return

    if timestamps is None:
        timestamps = [datetime.now().isoformat()] * len(metrics_list)
    if len(timestamps) != len(metrics_list):
        raise ValueError("metrics_list and timestamps must have the same length")

    rows = [
        (ts, run, orjson.dumps(serialize_values(m)))
        for ts, m in zip(timestamps, metrics_list)
    ]

    db_path = SQLiteStorage.init_db(project)
    with SQLiteStorage._get_process_lock(project):
        with SQLiteStorage._get_connection(db_path) as conn:
            cursor = conn.cursor()
            cursor.executemany(
                "INSERT INTO system_metrics (timestamp, run_name, metrics) "
                "VALUES (?, ?, ?)",
                rows,
            )
            conn.commit()
515
+
516
@staticmethod
def get_system_logs(project: str, run: str) -> list[dict]:
    """
    Return all system-metric rows for `run`, ordered by timestamp.

    Each dict holds the deserialized metric values plus a "timestamp"
    key; system metrics carry no step. A missing DB or missing
    system_metrics table yields [].
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return []

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT timestamp, metrics FROM system_metrics "
                "WHERE run_name = ? ORDER BY timestamp",
                (run,),
            )
            rows = cursor.fetchall()
        except sqlite3.OperationalError as e:
            # DBs created before system metrics existed lack the table.
            if "no such table: system_metrics" in str(e):
                return []
            raise

    logs = []
    for row in rows:
        entry = deserialize_values(orjson.loads(row["metrics"]))
        entry["timestamp"] = row["timestamp"]
        logs.append(entry)
    return logs
548
+
549
@staticmethod
def get_all_system_metrics_for_run(project: str, run: str) -> list[str]:
    """
    Return the sorted names of all system metrics logged for `run`.

    The "timestamp" key is excluded. A missing DB or missing
    system_metrics table yields [].
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return []

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT metrics FROM system_metrics "
                "WHERE run_name = ? ORDER BY timestamp",
                (run,),
            )
            rows = cursor.fetchall()
        except sqlite3.OperationalError as e:
            if "no such table: system_metrics" in str(e):
                return []
            raise

    names: set[str] = set()
    for row in rows:
        payload = deserialize_values(orjson.loads(row["metrics"]))
        names.update(key for key in payload if key != "timestamp")
    return sorted(names)
582
+
583
@staticmethod
def has_system_metrics(project: str) -> bool:
    """
    Return True if the project has at least one system_metrics row.

    Uses `SELECT 1 ... LIMIT 1`, which stops at the first row, instead
    of the previous `SELECT COUNT(*) ... LIMIT 1`: COUNT(*) scanned the
    whole table and its LIMIT applied to the single aggregate row, so it
    was a no-op. A missing DB or missing table means False.
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return False

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT 1 FROM system_metrics LIMIT 1")
            return cursor.fetchone() is not None
        except sqlite3.OperationalError:
            # Table may not exist in databases created by older versions.
            return False
598
+
599
@staticmethod
def get_logs(project: str, run: str) -> list[dict]:
    """
    Return all metric rows for `run`, ordered by timestamp.

    Each dict holds the deserialized metric values plus "timestamp"
    (ISO string) and "step" (int) keys. A missing DB yields [].
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return []

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT timestamp, step, metrics FROM metrics "
            "WHERE run_name = ? ORDER BY timestamp",
            (run,),
        )
        rows = cursor.fetchall()

    logs = []
    for row in rows:
        entry = deserialize_values(orjson.loads(row["metrics"]))
        entry["timestamp"] = row["timestamp"]
        entry["step"] = row["step"]
        logs.append(entry)
    return logs
627
+
628
@staticmethod
def load_from_dataset():
    """
    Download parquet/media files from the HF dataset and import them.

    Only performs work when both TRACKIO_DATASET_ID and SPACE_REPO_NAME
    env vars are set. Files already present under TRACKIO_DIR are not
    re-downloaded. After any download, the local SQLite databases are
    rebuilt via import_from_parquet. The _dataset_import_attempted flag
    is set even when the env vars are missing, so later callers don't
    retry.
    """
    dataset_id = os.environ.get("TRACKIO_DATASET_ID")
    space_repo_name = os.environ.get("SPACE_REPO_NAME")
    if dataset_id is not None and space_repo_name is not None:
        hfapi = hf.HfApi()
        updated = False
        if not TRACKIO_DIR.exists():
            TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
        # Hold the scheduler lock so a concurrent commit doesn't see a
        # half-downloaded directory.
        with SQLiteStorage.get_scheduler().lock:
            try:
                files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
                for file in files:
                    # Download parquet and media assets
                    if not (file.endswith(".parquet") or file.startswith("media/")):
                        continue
                    if (TRACKIO_DIR / file).exists():
                        continue
                    hf.hf_hub_download(
                        dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
                    )
                    updated = True
            except hf.errors.EntryNotFoundError:
                # Repo exists but has no matching entries yet.
                pass
            except hf.errors.RepositoryNotFoundError:
                # Dataset repo hasn't been created yet.
                pass
        if updated:
            SQLiteStorage.import_from_parquet()
    SQLiteStorage._dataset_import_attempted = True
657
+
658
@staticmethod
def get_projects() -> list[str]:
    """
    List all projects, one per "*.db" file in the trackio directory.

    Triggers a one-time sync from the HF dataset if it hasn't happened
    yet. A missing trackio directory yields [].
    """
    if not SQLiteStorage._dataset_import_attempted:
        SQLiteStorage.load_from_dataset()

    if not TRACKIO_DIR.exists():
        return []
    return sorted({db_file.stem for db_file in TRACKIO_DIR.glob(f"*{DB_EXT}")})
674
+
675
@staticmethod
def get_runs(project: str) -> list[str]:
    """Return the distinct run names in the project's metrics table ([] if no DB)."""
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return []

    with SQLiteStorage._get_connection(db_path) as conn:
        rows = conn.execute("SELECT DISTINCT run_name FROM metrics").fetchall()
    return [row[0] for row in rows]
688
+
689
@staticmethod
def get_max_steps_for_runs(project: str) -> dict[str, int]:
    """Map each run name in the project to its largest logged step ({} if no DB)."""
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return {}

    with SQLiteStorage._get_connection(db_path) as conn:
        rows = conn.execute(
            "SELECT run_name, MAX(step) as max_step FROM metrics GROUP BY run_name"
        ).fetchall()
    return {row["run_name"]: row["max_step"] for row in rows}
711
+
712
@staticmethod
def store_config(project: str, run: str, config: dict) -> None:
    """
    Insert or overwrite the stored config for `run`.

    The configs table has UNIQUE(run_name), so this is an upsert: one
    config row per run. Creates the DB if needed and takes the
    cross-process lock for the write.
    """
    db_path = SQLiteStorage.init_db(project)
    with SQLiteStorage._get_process_lock(project):
        with SQLiteStorage._get_connection(db_path) as conn:
            conn.execute(
                "INSERT OR REPLACE INTO configs (run_name, config, created_at) "
                "VALUES (?, ?, ?)",
                (
                    run,
                    orjson.dumps(serialize_values(config)),
                    datetime.now().isoformat(),
                ),
            )
            conn.commit()
731
+
732
@staticmethod
def get_run_config(project: str, run: str) -> dict | None:
    """
    Return the stored config dict for `run`, or None if absent.

    A missing DB or missing configs table also yields None.
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return None

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT config FROM configs WHERE run_name = ?", (run,)
            )
            row = cursor.fetchone()
        except sqlite3.OperationalError as e:
            # Older databases may predate the configs table.
            if "no such table: configs" in str(e):
                return None
            raise

    if not row:
        return None
    return deserialize_values(orjson.loads(row["config"]))
758
+
759
@staticmethod
def delete_run(project: str, run: str) -> bool:
    """
    Delete a run's metrics, config, and system metrics.

    Returns True on success (even if the run had no rows); False when
    the DB is missing or a SQLite error occurs. A missing
    system_metrics table is tolerated (older databases).
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return False

    with SQLiteStorage._get_process_lock(project):
        with SQLiteStorage._get_connection(db_path) as conn:
            cursor = conn.cursor()
            try:
                cursor.execute("DELETE FROM metrics WHERE run_name = ?", (run,))
                cursor.execute("DELETE FROM configs WHERE run_name = ?", (run,))
                try:
                    cursor.execute(
                        "DELETE FROM system_metrics WHERE run_name = ?", (run,)
                    )
                except sqlite3.OperationalError:
                    # Table may not exist in older databases; not fatal.
                    pass
                conn.commit()
            except sqlite3.Error:
                return False
    return True
782
+
783
@staticmethod
def get_all_run_configs(project: str) -> dict[str, dict]:
    """
    Return {run_name: config} for every run with a stored config.

    A missing DB or missing configs table yields {}.
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return {}

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT run_name, config FROM configs")
            rows = cursor.fetchall()
        except sqlite3.OperationalError as e:
            if "no such table: configs" in str(e):
                return {}
            raise

    return {
        row["run_name"]: deserialize_values(orjson.loads(row["config"]))
        for row in rows
    }
808
+
809
@staticmethod
def get_metric_values(project: str, run: str, metric_name: str) -> list[dict]:
    """
    Return every logged value of one metric for a run, ordered by timestamp.

    Each item is {"timestamp": ..., "step": ..., "value": ...}; rows in
    which the metric is absent are skipped. A missing DB yields [].
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return []

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT timestamp, step, metrics FROM metrics "
            "WHERE run_name = ? ORDER BY timestamp",
            (run,),
        )
        rows = cursor.fetchall()

    values = []
    for row in rows:
        payload = deserialize_values(orjson.loads(row["metrics"]))
        if metric_name in payload:
            values.append(
                {
                    "timestamp": row["timestamp"],
                    "step": row["step"],
                    "value": payload[metric_name],
                }
            )
    return values
842
+
843
@staticmethod
def get_all_metrics_for_run(project: str, run: str) -> list[str]:
    """
    Return the sorted names of all metrics ever logged for `run`.

    The reserved "timestamp" and "step" keys are excluded. A missing DB
    yields [].
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    if not db_path.exists():
        return []

    with SQLiteStorage._get_connection(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT metrics FROM metrics WHERE run_name = ? ORDER BY timestamp",
            (run,),
        )
        rows = cursor.fetchall()

    names: set[str] = set()
    for row in rows:
        payload = deserialize_values(orjson.loads(row["metrics"]))
        names.update(key for key in payload if key not in ("timestamp", "step"))
    return sorted(names)
871
+
872
def finish(self):
    """Cleanup hook invoked when a run finishes; currently a no-op."""
    pass
trackio/table.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Literal
3
+
4
+ from pandas import DataFrame
5
+
6
+ from trackio.media.media import TrackioMedia
7
+ from trackio.utils import MEDIA_DIR
8
+
9
+
10
+ class Table:
11
+ """
12
+ Initializes a Table object.
13
+
14
+ Tables can be used to log tabular data including images, numbers, and text.
15
+
16
+ Args:
17
+ columns (`list[str]`, *optional*):
18
+ Names of the columns in the table. Optional if `data` is provided. Not
19
+ expected if `dataframe` is provided. Currently ignored.
20
+ data (`list[list[Any]]`, *optional*):
21
+ 2D row-oriented array of values. Each value can be a number, a string
22
+ (treated as Markdown and truncated if too long), or a `Trackio.Image` or
23
+ list of `Trackio.Image` objects.
24
+ dataframe (`pandas.DataFrame`, *optional*):
25
+ DataFrame used to create the table. When set, `data` and `columns`
26
+ arguments are ignored.
27
+ rows (`list[list[Any]]`, *optional*):
28
+ Currently ignored.
29
+ optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
30
+ Currently ignored.
31
+ allow_mixed_types (`bool`, *optional*, defaults to `False`):
32
+ Currently ignored.
33
+ log_mode: (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
34
+ Currently ignored.
35
+ """
36
+
37
+ TYPE = "trackio.table"
38
+
39
+ def __init__(
40
+ self,
41
+ columns: list[str] | None = None,
42
+ data: list[list[Any]] | None = None,
43
+ dataframe: DataFrame | None = None,
44
+ rows: list[list[Any]] | None = None,
45
+ optional: bool | list[bool] = True,
46
+ allow_mixed_types: bool = False,
47
+ log_mode: Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] | None = "IMMUTABLE",
48
+ ):
49
+ # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
50
+ # for now (like `rows`) they are included for API compat but don't do anything.
51
+ if dataframe is None:
52
+ self.data = DataFrame(data) if data is not None else DataFrame()
53
+ else:
54
+ self.data = dataframe
55
+
56
+ def _has_media_objects(self, dataframe: DataFrame) -> bool:
57
+ """Check if dataframe contains any TrackioMedia objects or lists of TrackioMedia objects."""
58
+ for col in dataframe.columns:
59
+ if dataframe[col].apply(lambda x: isinstance(x, TrackioMedia)).any():
60
+ return True
61
+ if (
62
+ dataframe[col]
63
+ .apply(
64
+ lambda x: isinstance(x, list)
65
+ and len(x) > 0
66
+ and isinstance(x[0], TrackioMedia)
67
+ )
68
+ .any()
69
+ ):
70
+ return True
71
+ return False
72
+
73
+ def _process_data(self, project: str, run: str, step: int = 0):
74
+ """Convert dataframe to dict format, processing any TrackioMedia objects if present."""
75
+ df = self.data
76
+ if not self._has_media_objects(df):
77
+ return df.to_dict(orient="records")
78
+
79
+ processed_df = df.copy()
80
+ for col in processed_df.columns:
81
+ for idx in processed_df.index:
82
+ value = processed_df.at[idx, col]
83
+ if isinstance(value, TrackioMedia):
84
+ value._save(project, run, step)
85
+ processed_df.at[idx, col] = value._to_dict()
86
+ if (
87
+ isinstance(value, list)
88
+ and len(value) > 0
89
+ and isinstance(value[0], TrackioMedia)
90
+ ):
91
+ [v._save(project, run, step) for v in value]
92
+ processed_df.at[idx, col] = [v._to_dict() for v in value]
93
+
94
+ return processed_df.to_dict(orient="records")
95
+
96
+ @staticmethod
97
+ def to_display_format(table_data: list[dict]) -> list[dict]:
98
+ """
99
+ Converts stored table data to display format for UI rendering.
100
+
101
+ Note:
102
+ This does not use the `self.data` attribute, but instead uses the
103
+ `table_data` parameter, which is what the UI receives.
104
+
105
+ Args:
106
+ table_data (`list[dict]`):
107
+ List of dictionaries representing table rows (from stored `_value`).
108
+
109
+ Returns:
110
+ `list[dict]`: Table data with images converted to markdown syntax and long
111
+ text truncated.
112
+ """
113
+ truncate_length = int(os.getenv("TRACKIO_TABLE_TRUNCATE_LENGTH", "250"))
114
+
115
+ def convert_image_to_markdown(image_data: dict) -> str:
116
+ relative_path = image_data.get("file_path", "")
117
+ caption = image_data.get("caption", "")
118
+ absolute_path = MEDIA_DIR / relative_path
119
+ return f'<img src="/gradio_api/file={absolute_path}" alt="{caption}" />'
120
+
121
+ processed_data = []
122
+ for row in table_data:
123
+ processed_row = {}
124
+ for key, value in row.items():
125
+ if isinstance(value, dict) and value.get("_type") == "trackio.image":
126
+ processed_row[key] = convert_image_to_markdown(value)
127
+ elif (
128
+ isinstance(value, list)
129
+ and len(value) > 0
130
+ and isinstance(value[0], dict)
131
+ and value[0].get("_type") == "trackio.image"
132
+ ):
133
+ # This assumes that if the first item is an image, all items are images. Ok for now since we don't support mixed types in a single cell.
134
+ processed_row[key] = (
135
+ '<div style="display: flex; gap: 10px;">'
136
+ + "".join([convert_image_to_markdown(item) for item in value])
137
+ + "</div>"
138
+ )
139
+ elif isinstance(value, str) and len(value) > truncate_length:
140
+ truncated = value[:truncate_length]
141
+ full_text = value.replace("<", "&lt;").replace(">", "&gt;")
142
+ processed_row[key] = (
143
+ f'<details style="display: inline;">'
144
+ f'<summary style="display: inline; cursor: pointer;">{truncated}…<span><em>(truncated, click to expand)</em></span></summary>'
145
+ f'<div style="margin-top: 10px; padding: 10px; background: #f5f5f5; border-radius: 4px; max-height: 400px; overflow: auto;">'
146
+ f'<pre style="white-space: pre-wrap; word-wrap: break-word; margin: 0;">{full_text}</pre>'
147
+ f"</div>"
148
+ f"</details>"
149
+ )
150
+ else:
151
+ processed_row[key] = value
152
+ processed_data.append(processed_row)
153
+ return processed_data
154
+
155
+ def _to_dict(self, project: str, run: str, step: int = 0):
156
+ """
157
+ Converts the table to a dictionary representation.
158
+
159
+ Args:
160
+ project (`str`):
161
+ Project name for saving media files.
162
+ run (`str`):
163
+ Run name for saving media files.
164
+ step (`int`, *optional*, defaults to `0`):
165
+ Step number for saving media files.
166
+ """
167
+ data = self._process_data(project, run, step)
168
+ return {
169
+ "_type": self.TYPE,
170
+ "_value": data,
171
+ }
trackio/typehints.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, TypedDict
2
+
3
+ from gradio import FileData
4
+
5
+
6
class LogEntry(TypedDict):
    """A queued metrics-log request for a single run."""

    project: str  # project the run belongs to
    run: str  # run name
    metrics: dict[str, Any]  # metric name -> value payload
    step: int | None  # explicit step, or None to let storage assign the next step
    config: dict[str, Any] | None  # optional run config stored alongside the metrics
12
+
13
+
14
class SystemLogEntry(TypedDict):
    """A queued system-metrics (e.g. GPU) log request for a run."""

    project: str  # project the run belongs to
    run: str  # run name
    metrics: dict[str, Any]  # system metric name -> value
    timestamp: str  # ISO timestamp; system metrics carry no step
19
+
20
+
21
class UploadEntry(TypedDict):
    """A queued file-upload request."""

    project: str  # project the upload belongs to
    run: str | None  # run name, if the upload is tied to a run
    step: int | None  # step number, if the upload is tied to a step
    relative_path: str | None  # destination path -- presumably relative to the media dir; TODO confirm
    uploaded_file: FileData  # Gradio file payload