test_484 / trackio /__init__.py
abidlabs's picture
abidlabs HF Staff
Upload folder using huggingface_hub
a28b86d verified
import atexit
import glob
import json
import logging
import os
import shutil
import warnings
import webbrowser
from pathlib import Path
from typing import Any
import huggingface_hub
from gradio.themes import ThemeClass
from gradio.utils import TupleNoPrint
from gradio_client import Client, handle_file
from huggingface_hub import SpaceStorage
from huggingface_hub.errors import LocalTokenNotFoundError
from trackio import context_vars, deploy, utils
from trackio.alerts import AlertLevel
from trackio.api import Api
from trackio.apple_gpu import apple_gpu_available
from trackio.apple_gpu import log_apple_gpu as _log_apple_gpu
from trackio.deploy import freeze, sync
from trackio.frontend_server import mount_frontend
from trackio.gpu import gpu_available
from trackio.gpu import log_gpu as _log_nvidia_gpu
from trackio.histogram import Histogram
from trackio.imports import import_csv, import_tf_events
from trackio.markdown import Markdown
from trackio.media import (
TrackioAudio,
TrackioImage,
TrackioVideo,
get_project_media_path,
)
from trackio.run import Run
from trackio.server import make_trackio_server
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.typehints import UploadEntry
from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
logging.getLogger("httpx").setLevel(logging.WARNING)
warnings.filterwarnings(
"ignore",
message="Empty session being created. Install gradio\\[oauth\\]",
category=UserWarning,
module="gradio.helpers",
)
__version__ = json.loads(Path(__file__).parent.joinpath("package.json").read_text())[
"version"
]
__all__ = [
"init",
"log",
"log_system",
"log_gpu",
"finish",
"alert",
"AlertLevel",
"show",
"sync",
"freeze",
"delete_project",
"import_csv",
"import_tf_events",
"save",
"Image",
"Video",
"Audio",
"Table",
"Histogram",
"Markdown",
"Api",
]
Image = TrackioImage
Video = TrackioVideo
Audio = TrackioAudio
config = {}
_atexit_registered = False
_projects_notified_auto_log_hw: set[str] = set()
def _cleanup_current_run():
run = context_vars.current_run.get()
if run is not None:
try:
run.finish()
except Exception:
pass
def init(
project: str,
name: str | None = None,
group: str | None = None,
space_id: str | None = None,
space_storage: SpaceStorage | None = None,
dataset_id: str | None = None,
bucket_id: str | None = None,
config: dict | None = None,
resume: str = "never",
settings: Any = None,
private: bool | None = None,
embed: bool = True,
auto_log_gpu: bool | None = None,
gpu_log_interval: float = 10.0,
webhook_url: str | None = None,
webhook_min_level: AlertLevel | str | None = None,
) -> Run:
"""
Creates a new Trackio project and returns a [`Run`] object.
Args:
project (`str`):
The name of the project (can be an existing project to continue tracking or
a new project to start tracking from scratch).
name (`str`, *optional*):
The name of the run (if not provided, a default name will be generated).
group (`str`, *optional*):
The name of the group which this run belongs to in order to help organize
related runs together. You can toggle the entire group's visibilitiy in the
dashboard.
space_id (`str`, *optional*):
If provided, the project will be logged to a Hugging Face Space instead of
a local directory. Should be a complete Space name like
`"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
case the Space will be created in the currently-logged-in Hugging Face
user's namespace. If the Space does not exist, it will be created. If the
Space already exists, the project will be logged to it. Can also be set
via the `TRACKIO_SPACE_ID` environment variable. You cannot log to a
Space that has been **frozen** (converted to the static SDK); use
``trackio.sync(..., sdk="static")`` only after you are done logging.
space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
Choice of persistent storage tier.
dataset_id (`str`, *optional*):
Deprecated. Use `bucket_id` instead.
bucket_id (`str`, *optional*):
The ID of the Hugging Face Bucket to use for metric persistence. By default,
when a `space_id` is provided and `bucket_id` is not explicitly set, a
bucket is auto-generated from the space_id. Buckets provide
S3-like storage without git overhead - the SQLite database is stored directly
via `hf-mount` in the Space. Specify a Bucket with name like
`"username/bucketname"` or just `"bucketname"`.
config (`dict`, *optional*):
A dictionary of configuration options. Provided for compatibility with
`wandb.init()`.
resume (`str`, *optional*, defaults to `"never"`):
Controls how to handle resuming a run. Can be one of:
- `"must"`: Must resume the run with the given name, raises error if run
doesn't exist
- `"allow"`: Resume the run if it exists, otherwise create a new run
- `"never"`: Never resume a run, always create a new one
private (`bool`, *optional*):
Whether to make the Space private. If None (default), the repo will be
public unless the organization's default is private. This value is ignored
if the repo already exists.
settings (`Any`, *optional*):
Not used. Provided for compatibility with `wandb.init()`.
embed (`bool`, *optional*, defaults to `True`):
If running inside a Jupyter/Colab notebook, whether the dashboard should
automatically be embedded in the cell when trackio.init() is called. For
local runs, this launches a local Gradio app and embeds it. For Space runs,
this embeds the Space URL. In Colab, the local dashboard will be accessible
via a public share URL (default Gradio behavior).
auto_log_gpu (`bool` or `None`, *optional*, defaults to `None`):
Controls automatic GPU metrics logging. If `None` (default), GPU logging
is automatically enabled when `nvidia-ml-py` is installed and an NVIDIA
GPU or Apple M series is detected. Set to `True` to force enable or
`False` to disable.
gpu_log_interval (`float`, *optional*, defaults to `10.0`):
The interval in seconds between automatic GPU metric logs.
Only used when `auto_log_gpu=True`.
webhook_url (`str`, *optional*):
A webhook URL to POST alert payloads to when `trackio.alert()` is
called. Supports Slack and Discord webhook URLs natively (payloads
are formatted automatically). Can also be set via the
`TRACKIO_WEBHOOK_URL` environment variable. Individual alerts can
override this URL by passing `webhook_url` to `trackio.alert()`.
webhook_min_level (`AlertLevel` or `str`, *optional*):
Minimum alert level that should trigger webhook delivery.
For example, `AlertLevel.WARN` sends only `WARN` and `ERROR`
alerts to the webhook destination. Can also be set via
`TRACKIO_WEBHOOK_MIN_LEVEL`.
Returns:
`Run`: A [`Run`] object that can be used to log metrics and finish the run.
"""
if settings is not None:
warnings.warn(
"* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
)
space_id = space_id or os.environ.get("TRACKIO_SPACE_ID")
bucket_id = bucket_id or os.environ.get("TRACKIO_BUCKET_ID")
if space_id is None and dataset_id is not None:
raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
if dataset_id is not None and bucket_id is not None:
raise ValueError("Cannot provide both `dataset_id` and `bucket_id`.")
try:
space_id, dataset_id, bucket_id = utils.preprocess_space_and_dataset_ids(
space_id, dataset_id, bucket_id
)
except LocalTokenNotFoundError as e:
raise LocalTokenNotFoundError(
f"You must be logged in to Hugging Face locally when `space_id` is provided to deploy to a Space. {e}"
) from e
if space_id is None and bucket_id is not None:
warnings.warn(
"trackio.init() has `bucket_id` set but `space_id` is None: metrics will be logged "
"locally only. Pass `space_id` to create or use a Hugging Face Space, which will be "
"attached to the Hugging Face Bucket.",
UserWarning,
stacklevel=2,
)
if space_id is not None:
deploy.raise_if_space_is_frozen_for_logging(space_id)
url = context_vars.current_server.get()
if space_id is not None:
if url is None:
url = space_id
context_vars.current_server.set(url)
context_vars.current_space_id.set(space_id)
_should_embed_local = False
if (
context_vars.current_project.get() is None
or context_vars.current_project.get() != project
):
print(f"* Trackio project initialized: {project}")
if bucket_id is not None:
os.environ["TRACKIO_BUCKET_ID"] = bucket_id
bucket_url = f"https://huggingface.co/buckets/{bucket_id}"
print(
f"* Trackio metrics will be synced to Hugging Face Bucket: {bucket_url}"
)
elif dataset_id is not None:
os.environ["TRACKIO_DATASET_ID"] = dataset_id
print(
f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
)
if space_id is None:
print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
_should_embed_local = embed and utils.is_in_notebook()
if not _should_embed_local:
utils.print_dashboard_instructions(project)
else:
deploy.create_space_if_not_exists(
space_id,
space_storage,
dataset_id,
bucket_id,
private,
)
user_name, space_name = space_id.split("/")
space_url = deploy.SPACE_HOST_URL.format(
user_name=user_name, space_name=space_name
)
if utils.is_in_notebook() and embed:
utils.embed_url_in_notebook(space_url)
context_vars.current_project.set(project)
if resume == "must":
if name is None:
raise ValueError("Must provide a run name when resume='must'")
if name not in SQLiteStorage.get_runs(project):
raise ValueError(f"Run '{name}' does not exist in project '{project}'")
resumed = True
elif resume == "allow":
resumed = name is not None and name in SQLiteStorage.get_runs(project)
elif resume == "never":
if name is not None and name in SQLiteStorage.get_runs(project):
warnings.warn(
f"* Warning: resume='never' but a run '{name}' already exists in "
f"project '{project}'. Generating a new name and instead. If you want "
"to resume this run, call init() with resume='must' or resume='allow'."
)
name = None
resumed = False
else:
raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
if auto_log_gpu is None:
nvidia_available = gpu_available()
apple_available = apple_gpu_available()
auto_log_gpu = nvidia_available or apple_available
if project not in _projects_notified_auto_log_hw:
if nvidia_available:
print("* NVIDIA GPU detected, enabling automatic GPU metrics logging")
elif apple_available:
print(
"* Apple Silicon detected, enabling automatic system metrics logging"
)
if nvidia_available or apple_available:
_projects_notified_auto_log_hw.add(project)
run = Run(
url=url,
project=project,
client=None,
name=name,
group=group,
config=config,
space_id=space_id,
auto_log_gpu=auto_log_gpu,
gpu_log_interval=gpu_log_interval,
webhook_url=webhook_url,
webhook_min_level=webhook_min_level,
)
if space_id is not None:
SQLiteStorage.set_project_metadata(project, "space_id", space_id)
if SQLiteStorage.has_pending_data(project):
run._has_local_buffer = True
global _atexit_registered
if not _atexit_registered:
atexit.register(_cleanup_current_run)
_atexit_registered = True
if resumed:
print(f"* Resumed existing run: {run.name}")
else:
print(f"* Created new run: {run.name}")
context_vars.current_run.set(run)
globals()["config"] = run.config
if _should_embed_local:
show(project=project, open_browser=False, block_thread=False)
return run
def log(metrics: dict, step: int | None = None) -> None:
"""
Logs metrics to the current run.
Args:
metrics (`dict`):
A dictionary of metrics to log.
step (`int`, *optional*):
The step number. If not provided, the step will be incremented
automatically.
"""
run = context_vars.current_run.get()
if run is None:
raise RuntimeError("Call trackio.init() before trackio.log().")
run.log(
metrics=metrics,
step=step,
)
def log_system(metrics: dict) -> None:
"""
Logs system metrics (GPU, etc.) to the current run using timestamps instead of steps.
Args:
metrics (`dict`):
A dictionary of system metrics to log.
"""
run = context_vars.current_run.get()
if run is None:
raise RuntimeError("Call trackio.init() before trackio.log_system().")
run.log_system(metrics=metrics)
def log_gpu(run: Run | None = None, device: int | None = None) -> dict:
"""
Log GPU metrics to the current or specified run as system metrics.
Automatically detects whether an NVIDIA or Apple GPU is available and calls
the appropriate logging method.
Args:
run: Optional Run instance. If None, uses current run from context.
device: CUDA device index to collect metrics from (NVIDIA GPUs only).
If None, collects from all GPUs visible to this process.
This parameter is ignored for Apple GPUs.
Returns:
dict: The GPU metrics that were logged.
Example:
```python
import trackio
run = trackio.init(project="my-project")
trackio.log({"loss": 0.5})
trackio.log_gpu()
trackio.log_gpu(device=0)
```
"""
if run is None:
run = context_vars.current_run.get()
if run is None:
raise RuntimeError("Call trackio.init() before trackio.log_gpu().")
if gpu_available():
return _log_nvidia_gpu(run=run, device=device)
elif apple_gpu_available():
return _log_apple_gpu(run=run)
else:
warnings.warn(
"No GPU detected. Install nvidia-ml-py for NVIDIA GPU support "
"or psutil for Apple Silicon support."
)
return {}
def finish():
"""
Finishes the current run.
"""
run = context_vars.current_run.get()
if run is None:
raise RuntimeError("Call trackio.init() before trackio.finish().")
run.finish()
def alert(
title: str,
text: str | None = None,
level: AlertLevel = AlertLevel.WARN,
webhook_url: str | None = None,
) -> None:
"""
Fires an alert immediately on the current run. The alert is printed to the
terminal, stored in the database, and displayed in the dashboard. If a
webhook URL is configured (via `trackio.init()`, the `TRACKIO_WEBHOOK_URL`
environment variable, or the `webhook_url` parameter here), the alert is
also POSTed to that URL.
Args:
title (`str`):
A short title for the alert.
text (`str`, *optional*):
A longer description with details about the alert.
level (`AlertLevel`, *optional*, defaults to `AlertLevel.WARN`):
The severity level. One of `AlertLevel.INFO`, `AlertLevel.WARN`,
or `AlertLevel.ERROR`.
webhook_url (`str`, *optional*):
A webhook URL to send this specific alert to. Overrides any
URL set in `trackio.init()` or the `TRACKIO_WEBHOOK_URL`
environment variable. Supports Slack and Discord webhook
URLs natively.
"""
run = context_vars.current_run.get()
if run is None:
raise RuntimeError("Call trackio.init() before trackio.alert().")
run.alert(title=title, text=text, level=level, webhook_url=webhook_url)
def delete_project(project: str, force: bool = False) -> bool:
"""
Deletes a project by removing its local SQLite database.
Args:
project (`str`):
The name of the project to delete.
force (`bool`, *optional*, defaults to `False`):
If `True`, deletes the project without prompting for confirmation.
If `False`, prompts the user to confirm before deleting.
Returns:
`bool`: `True` if the project was deleted, `False` otherwise.
"""
db_path = SQLiteStorage.get_project_db_path(project)
if not db_path.exists():
print(f"* Project '{project}' does not exist.")
return False
if not force:
response = input(
f"Are you sure you want to delete project '{project}'? "
f"This will permanently delete all runs and metrics. (y/N): "
)
if response.lower() not in ["y", "yes"]:
print("* Deletion cancelled.")
return False
try:
db_path.unlink()
for suffix in ("-wal", "-shm"):
sidecar = Path(str(db_path) + suffix)
if sidecar.exists():
sidecar.unlink()
print(f"* Project '{project}' has been deleted.")
return True
except Exception as e:
print(f"* Error deleting project '{project}': {e}")
return False
def save(
glob_str: str | Path,
project: str | None = None,
) -> str:
"""
Saves files to a project (not linked to a specific run). If Trackio is running
locally, the file(s) will be copied to the project's files directory. If Trackio is
running in a Space, the file(s) will be uploaded to the Space's files directory.
Args:
glob_str (`str` or `Path`):
The file path or glob pattern to save. Can be a single file or a pattern
matching multiple files (e.g., `"*.py"`, `"models/**/*.pth"`).
project (`str`, *optional*):
The name of the project to save files to. If not provided, uses the current
project from `trackio.init()`. If no project is initialized, raises an
error.
Returns:
`str`: The path where the file(s) were saved (project's files directory).
Example:
```python
import trackio
trackio.init(project="my-project")
trackio.save("config.yaml")
trackio.save("models/*.pth")
```
"""
if project is None:
project = context_vars.current_project.get()
if project is None:
raise RuntimeError(
"No project specified. Either call trackio.init() first or provide a "
"project parameter to trackio.save()."
)
glob_str = Path(glob_str)
base_path = Path.cwd().resolve()
matched_files = []
if glob_str.is_file():
matched_files = [glob_str.resolve()]
else:
pattern = str(glob_str)
if not glob_str.is_absolute():
pattern = str((Path.cwd() / glob_str).resolve())
matched_files = [
Path(f).resolve()
for f in glob.glob(pattern, recursive=True)
if Path(f).is_file()
]
if not matched_files:
raise ValueError(f"No files found matching pattern: {glob_str}")
current_run = context_vars.current_run.get()
is_local = (
current_run._is_local
if current_run is not None
else (context_vars.current_space_id.get() is None)
)
if is_local:
for file_path in matched_files:
try:
relative_to_base = file_path.relative_to(base_path)
except ValueError:
relative_to_base = Path(file_path.name)
if current_run is not None:
current_run._queue_upload(
file_path,
step=None,
relative_path=str(relative_to_base.parent),
use_run_name=False,
)
else:
media_path = get_project_media_path(
project=project,
run=None,
step=None,
relative_path=str(relative_to_base),
)
shutil.copy(str(file_path), str(media_path))
else:
url = context_vars.current_server.get()
upload_entries = []
for file_path in matched_files:
try:
relative_to_base = file_path.relative_to(base_path)
except ValueError:
relative_to_base = Path(file_path.name)
if current_run is not None:
current_run._queue_upload(
file_path,
step=None,
relative_path=str(relative_to_base.parent),
use_run_name=False,
)
else:
upload_entry: UploadEntry = {
"project": project,
"run": None,
"step": None,
"relative_path": str(relative_to_base),
"uploaded_file": handle_file(file_path),
}
upload_entries.append(upload_entry)
if upload_entries:
if url is None:
raise RuntimeError(
"No server available. Call trackio.init() before trackio.save() to start the server."
)
try:
client = Client(url, verbose=False, httpx_kwargs={"timeout": 90})
client.predict(
api_name="/bulk_upload_media",
uploads=upload_entries,
hf_token=huggingface_hub.utils.get_token(),
)
except Exception as e:
warnings.warn(
f"Failed to upload files: {e}. "
"Files may not be available in the dashboard."
)
return str(utils.MEDIA_DIR / project / "files")
def show(
project: str | None = None,
*,
theme: str | ThemeClass | None = None,
mcp_server: bool | None = None,
footer: bool = True,
color_palette: list[str] | None = None,
open_browser: bool = True,
block_thread: bool | None = None,
host: str | None = None,
):
"""
Launches the Trackio dashboard.
Args:
project (`str`, *optional*):
The name of the project whose runs to show. If not provided, all projects
will be shown and the user can select one.
theme (`str` or `ThemeClass`, *optional*):
A Gradio Theme to use for the dashboard instead of the default Gradio theme,
can be a built-in theme (e.g. `'soft'`, `'citrus'`), a theme from the Hub
(e.g. `"gstaff/xkcd"`), or a custom Theme class. If not provided, the
`TRACKIO_THEME` environment variable will be used, or if that is not set,
the default Gradio theme will be used.
mcp_server (`bool`, *optional*):
If `True`, the Trackio dashboard will be set up as an MCP server and certain
functions will be added as MCP tools. If `None` (default behavior), then the
`GRADIO_MCP_SERVER` environment variable will be used to determine if the
MCP server should be enabled (which is `"True"` on Hugging Face Spaces).
footer (`bool`, *optional*, defaults to `True`):
Whether to show the Gradio footer. When `False`, the footer will be hidden.
This can also be controlled via the `footer` query parameter in the URL.
color_palette (`list[str]`, *optional*):
A list of hex color codes to use for plot lines. If not provided, the
`TRACKIO_COLOR_PALETTE` environment variable will be used (comma-separated
hex codes), or if that is not set, the default color palette will be used.
Example: `['#FF0000', '#00FF00', '#0000FF']`
open_browser (`bool`, *optional*, defaults to `True`):
If `True` and not in a notebook, a new browser tab will be opened with the
dashboard. If `False`, the browser will not be opened.
block_thread (`bool`, *optional*):
If `True`, the main thread will be blocked until the dashboard is closed.
If `None` (default behavior), then the main thread will not be blocked if the
dashboard is launched in a notebook, otherwise the main thread will be blocked.
host (`str`, *optional*):
The host to bind the server to. If not provided, defaults to `'127.0.0.1'`
(localhost only). Set to `'0.0.0.0'` to allow remote access.
Returns:
`app`: The Gradio app object corresponding to the dashboard launched by Trackio.
`url`: The local URL of the dashboard.
`share_url`: The public share URL of the dashboard.
`full_url`: The full URL of the dashboard including the write token (will use the public share URL if launched publicly, otherwise the local URL).
"""
if color_palette is not None:
os.environ["TRACKIO_COLOR_PALETTE"] = ",".join(color_palette)
theme = theme or os.environ.get("TRACKIO_THEME")
_mcp_server = (
mcp_server
if mcp_server is not None
else os.environ.get("GRADIO_MCP_SERVER", "False") == "True"
)
server = make_trackio_server()
mount_frontend(server)
_, url, share_url = server.launch(
quiet=True,
inline=False,
prevent_thread_lock=True,
favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
mcp_server=_mcp_server,
theme=theme,
server_name=host,
)
base_url = share_url + "/" if share_url else url
dashboard_url = base_url.rstrip("/") + "/"
if project:
dashboard_url += f"?project={project}"
full_url = utils.get_full_url(
base_url.rstrip("/"),
project=project,
write_token=server.write_token,
footer=footer,
)
if not utils.is_in_notebook():
print(f"* Trackio UI launched at: {dashboard_url}")
print(f"* Gradio API available at: {base_url}")
if open_browser:
webbrowser.open(dashboard_url)
block_thread = block_thread if block_thread is not None else True
else:
utils.embed_url_in_notebook(dashboard_url)
block_thread = block_thread if block_thread is not None else False
if block_thread:
utils.block_main_thread_until_keyboard_interrupt()
return TupleNoPrint((server, url, share_url, full_url))