import shutil
import threading
import time
import uuid
import warnings
from datetime import datetime, timezone
from pathlib import Path

import huggingface_hub
from gradio_client import Client, handle_file

from trackio import utils
from trackio.gpu import GpuMonitor
from trackio.histogram import Histogram
from trackio.markdown import Markdown
from trackio.media import TrackioMedia, get_project_media_path
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.typehints import LogEntry, SystemLogEntry, UploadEntry
from trackio.utils import _get_default_namespace

BATCH_SEND_INTERVAL = 0.5
MAX_BACKOFF = 30


class Run:
    def __init__(
        self,
        url: str | None,
        project: str,
        client: Client | None,
        name: str | None = None,
        group: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
        auto_log_gpu: bool = False,
        gpu_log_interval: float = 10.0,
    ):
        """
        Initialize a Run for logging metrics to Trackio.

        Args:
            url: The URL of the Trackio server (local Gradio app or HF Space).
            project: The name of the project to log metrics to.
            client: A pre-configured gradio_client.Client instance, or None to
                create one automatically in a background thread with retry
                logic. Passing None is recommended for normal usage; passing a
                client is useful for testing (e.g., injecting a mock client).
            name: The name of this run. If None, a readable name like
                "brave-sunset-0" is auto-generated; if space_id is provided, a
                "username-timestamp" name is generated instead.
            group: Optional group name used to organize related runs together.
            config: A dictionary of configuration/hyperparameters for this run.
                Keys starting with '_' are reserved for internal use.
            space_id: The HF Space ID if logging to a Space (e.g.,
                "user/space"). If provided, media files will be uploaded to
                the Space.
            auto_log_gpu: Whether to automatically log GPU metrics
                (utilization, memory, temperature) at regular intervals.
            gpu_log_interval: The interval in seconds between GPU metric logs.
                Only used when auto_log_gpu is True.
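
        Example:
            A minimal sketch (illustrative only; in normal usage runs are
            created via `trackio.init()` rather than constructed directly):

                run = Run(
                    url=None,
                    project="my-project",
                    client=None,
                    config={"lr": 1e-3},  # a key starting with "_" would raise ValueError
                )
                run.log({"loss": 0.5})
                run.finish()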
""" self.url = url self.project = project self._client_lock = threading.Lock() self._client_thread = None self._client = client self._space_id = space_id self.name = name or utils.generate_readable_name( SQLiteStorage.get_runs(project), space_id ) self.group = group self.config = utils.to_json_safe(config or {}) if isinstance(self.config, dict): for key in self.config: if key.startswith("_"): raise ValueError( f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)" ) self.config["_Username"] = self._get_username() self.config["_Created"] = datetime.now(timezone.utc).isoformat() self.config["_Group"] = self.group self._queued_logs: list[LogEntry] = [] self._queued_system_logs: list[SystemLogEntry] = [] self._queued_uploads: list[UploadEntry] = [] self._stop_flag = threading.Event() self._config_logged = False max_step = SQLiteStorage.get_max_step_for_run(self.project, self.name) self._next_step = 0 if max_step is None else max_step + 1 self._has_local_buffer = False self._is_local = space_id is None if self._is_local: self._local_sender_thread = threading.Thread( target=self._local_batch_sender ) self._local_sender_thread.daemon = True self._local_sender_thread.start() else: self._client_thread = threading.Thread(target=self._init_client_background) self._client_thread.daemon = True self._client_thread.start() self._gpu_monitor: "GpuMonitor | None" = None if auto_log_gpu: self._gpu_monitor = GpuMonitor(self, interval=gpu_log_interval) self._gpu_monitor.start() def _get_username(self) -> str | None: try: return _get_default_namespace() except Exception: return None def _local_batch_sender(self): while ( not self._stop_flag.is_set() or len(self._queued_logs) > 0 or len(self._queued_system_logs) > 0 ): if not self._stop_flag.is_set(): time.sleep(BATCH_SEND_INTERVAL) with self._client_lock: if self._queued_logs: logs_to_send = self._queued_logs.copy() self._queued_logs.clear() self._write_logs_to_sqlite(logs_to_send) if self._queued_system_logs: system_logs_to_send = self._queued_system_logs.copy() self._queued_system_logs.clear() self._write_system_logs_to_sqlite(system_logs_to_send) def _write_logs_to_sqlite(self, logs: list[LogEntry]): logs_by_run: dict[tuple, dict] = {} for entry in logs: key = (entry["project"], entry["run"]) if key not in logs_by_run: logs_by_run[key] = { "metrics": [], "steps": [], "log_ids": [], "config": None, } logs_by_run[key]["metrics"].append(entry["metrics"]) logs_by_run[key]["steps"].append(entry.get("step")) logs_by_run[key]["log_ids"].append(entry.get("log_id")) if entry.get("config") and logs_by_run[key]["config"] is None: logs_by_run[key]["config"] = entry["config"] for (project, run), data in logs_by_run.items(): has_log_ids = any(lid is not None for lid in data["log_ids"]) SQLiteStorage.bulk_log( project=project, run=run, metrics_list=data["metrics"], steps=data["steps"], config=data["config"], log_ids=data["log_ids"] if has_log_ids else None, ) def _write_system_logs_to_sqlite(self, logs: list[SystemLogEntry]): logs_by_run: dict[tuple, dict] = {} for entry in logs: key = (entry["project"], entry["run"]) if key not in logs_by_run: logs_by_run[key] = {"metrics": [], "timestamps": [], "log_ids": []} logs_by_run[key]["metrics"].append(entry["metrics"]) logs_by_run[key]["timestamps"].append(entry.get("timestamp")) logs_by_run[key]["log_ids"].append(entry.get("log_id")) for (project, run), data in logs_by_run.items(): has_log_ids = any(lid is not None for lid in data["log_ids"]) SQLiteStorage.bulk_log_system( project=project, 
    def _batch_sender(self):
        consecutive_failures = 0
        while (
            not self._stop_flag.is_set()
            or len(self._queued_logs) > 0
            or len(self._queued_system_logs) > 0
            or len(self._queued_uploads) > 0
        ):
            if not self._stop_flag.is_set():
                if consecutive_failures:
                    # Back off exponentially after failures, capped at MAX_BACKOFF.
                    sleep_time = min(
                        BATCH_SEND_INTERVAL * (2**consecutive_failures), MAX_BACKOFF
                    )
                else:
                    sleep_time = BATCH_SEND_INTERVAL
                time.sleep(sleep_time)

            with self._client_lock:
                if self._client is None:
                    return
                failed = False
                if self._queued_logs:
                    logs_to_send = self._queued_logs.copy()
                    self._queued_logs.clear()
                    try:
                        self._client.predict(
                            api_name="/bulk_log",
                            logs=logs_to_send,
                            hf_token=huggingface_hub.utils.get_token(),
                        )
                    except Exception:
                        self._persist_logs_locally(logs_to_send)
                        failed = True
                if self._queued_system_logs:
                    system_logs_to_send = self._queued_system_logs.copy()
                    self._queued_system_logs.clear()
                    try:
                        self._client.predict(
                            api_name="/bulk_log_system",
                            logs=system_logs_to_send,
                            hf_token=huggingface_hub.utils.get_token(),
                        )
                    except Exception:
                        self._persist_system_logs_locally(system_logs_to_send)
                        failed = True
                if self._queued_uploads:
                    uploads_to_send = self._queued_uploads.copy()
                    self._queued_uploads.clear()
                    try:
                        self._client.predict(
                            api_name="/bulk_upload_media",
                            uploads=uploads_to_send,
                            hf_token=huggingface_hub.utils.get_token(),
                        )
                    except Exception:
                        self._persist_uploads_locally(uploads_to_send)
                        failed = True
                if failed:
                    consecutive_failures += 1
                else:
                    consecutive_failures = 0
                    # After a successful round, retry anything buffered locally.
                    if self._has_local_buffer:
                        self._flush_local_buffer()

    def _persist_logs_locally(self, logs: list[LogEntry]):
        if not self._space_id:
            return
        logs_by_run: dict[tuple, dict] = {}
        for entry in logs:
            key = (entry["project"], entry["run"])
            if key not in logs_by_run:
                logs_by_run[key] = {
                    "metrics": [],
                    "steps": [],
                    "log_ids": [],
                    "config": None,
                }
            logs_by_run[key]["metrics"].append(entry["metrics"])
            logs_by_run[key]["steps"].append(entry.get("step"))
            logs_by_run[key]["log_ids"].append(entry.get("log_id"))
            if entry.get("config") and logs_by_run[key]["config"] is None:
                logs_by_run[key]["config"] = entry["config"]

        for (project, run), data in logs_by_run.items():
            SQLiteStorage.bulk_log(
                project=project,
                run=run,
                metrics_list=data["metrics"],
                steps=data["steps"],
                log_ids=data["log_ids"],
                config=data["config"],
                space_id=self._space_id,
            )
        self._has_local_buffer = True

    def _persist_system_logs_locally(self, logs: list[SystemLogEntry]):
        if not self._space_id:
            return
        logs_by_run: dict[tuple, dict] = {}
        for entry in logs:
            key = (entry["project"], entry["run"])
            if key not in logs_by_run:
                logs_by_run[key] = {"metrics": [], "timestamps": [], "log_ids": []}
            logs_by_run[key]["metrics"].append(entry["metrics"])
            logs_by_run[key]["timestamps"].append(entry.get("timestamp"))
            logs_by_run[key]["log_ids"].append(entry.get("log_id"))

        for (project, run), data in logs_by_run.items():
            SQLiteStorage.bulk_log_system(
                project=project,
                run=run,
                metrics_list=data["metrics"],
                timestamps=data["timestamps"],
                log_ids=data["log_ids"],
                space_id=self._space_id,
            )
        self._has_local_buffer = True

    def _persist_uploads_locally(self, uploads: list[UploadEntry]):
        if not self._space_id:
            return
        for entry in uploads:
            file_data = entry.get("uploaded_file")
            file_path = ""
            if isinstance(file_data, dict):
                file_path = file_data.get("path", "")
            elif hasattr(file_data, "path"):
                file_path = str(file_data.path)
            else:
                file_path = str(file_data)
            SQLiteStorage.add_pending_upload(
                project=entry["project"],
                space_id=self._space_id,
                run_name=entry.get("run"),
                step=entry.get("step"),
                file_path=file_path,
                relative_path=entry.get("relative_path"),
            )
        self._has_local_buffer = True

    def _flush_local_buffer(self):
        try:
            buffered_logs = SQLiteStorage.get_pending_logs(self.project)
            if buffered_logs:
                self._client.predict(
                    api_name="/bulk_log",
                    logs=buffered_logs["logs"],
                    hf_token=huggingface_hub.utils.get_token(),
                )
                SQLiteStorage.clear_pending_logs(self.project, buffered_logs["ids"])

            buffered_sys = SQLiteStorage.get_pending_system_logs(self.project)
            if buffered_sys:
                self._client.predict(
                    api_name="/bulk_log_system",
                    logs=buffered_sys["logs"],
                    hf_token=huggingface_hub.utils.get_token(),
                )
                SQLiteStorage.clear_pending_system_logs(
                    self.project, buffered_sys["ids"]
                )

            buffered_uploads = SQLiteStorage.get_pending_uploads(self.project)
            if buffered_uploads:
                upload_entries = []
                for u in buffered_uploads["uploads"]:
                    fp = u["file_path"]
                    if Path(fp).exists():
                        upload_entries.append(
                            {
                                "project": u["project"],
                                "run": u["run"],
                                "step": u["step"],
                                "relative_path": u["relative_path"],
                                "uploaded_file": handle_file(fp),
                            }
                        )
                if upload_entries:
                    self._client.predict(
                        api_name="/bulk_upload_media",
                        uploads=upload_entries,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                SQLiteStorage.clear_pending_uploads(
                    self.project, buffered_uploads["ids"]
                )

            self._has_local_buffer = False
        except Exception:
            # Leave the buffer intact; the next successful send will retry.
            pass

    def _init_client_background(self):
        if self._client is None:
            fib = utils.fibo()
            for sleep_coefficient in fib:
                try:
                    client = Client(self.url, verbose=False)
                    with self._client_lock:
                        self._client = client
                    break
                except Exception:
                    pass
                if sleep_coefficient is not None:
                    time.sleep(0.1 * sleep_coefficient)
        self._batch_sender()

    def _queue_upload(
        self,
        file_path,
        step: int | None,
        relative_path: str | None = None,
        use_run_name: bool = True,
    ):
        if self._is_local:
            self._save_upload_locally(file_path, step, relative_path, use_run_name)
        else:
            upload_entry: UploadEntry = {
                "project": self.project,
                "run": self.name if use_run_name else None,
                "step": step,
                "relative_path": relative_path,
                "uploaded_file": handle_file(file_path),
            }
            with self._client_lock:
                self._queued_uploads.append(upload_entry)

    def _save_upload_locally(
        self,
        file_path,
        step: int | None,
        relative_path: str | None = None,
        use_run_name: bool = True,
    ):
        media_path = get_project_media_path(
            project=self.project,
            run=self.name if use_run_name else None,
            step=step,
            relative_path=relative_path,
        )
        src = Path(file_path)
        if src.exists() and str(src.resolve()) != str(Path(media_path).resolve()):
            shutil.copy(str(src), str(media_path))

    def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
        value._save(self.project, self.name, step if step is not None else 0)
        if self._space_id:
            self._queue_upload(value._get_absolute_file_path(), step)
        return value._to_dict()

    def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None):
        if not self._space_id:
            return
        table_data = table_dict.get("_value", [])
        for row in table_data:
            for value in row.values():
                if isinstance(value, dict) and value.get("_type") in [
                    "trackio.image",
                    "trackio.video",
                    "trackio.audio",
                ]:
                    file_path = value.get("file_path")
                    if file_path:
                        from trackio.utils import MEDIA_DIR

                        absolute_path = MEDIA_DIR / file_path
                        self._queue_upload(absolute_path, step)
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict) and item.get("_type") in [
                            "trackio.image",
                            "trackio.video",
                            "trackio.audio",
                        ]:
                            file_path = item.get("file_path")
                            if file_path:
                                from trackio.utils import MEDIA_DIR

                                absolute_path = MEDIA_DIR / file_path
                                self._queue_upload(absolute_path, step)
    def _ensure_sender_alive(self):
        if self._is_local:
            if (
                hasattr(self, "_local_sender_thread")
                and not self._local_sender_thread.is_alive()
                and not self._stop_flag.is_set()
            ):
                self._local_sender_thread = threading.Thread(
                    target=self._local_batch_sender
                )
                self._local_sender_thread.daemon = True
                self._local_sender_thread.start()
        else:
            if (
                self._client_thread is not None
                and not self._client_thread.is_alive()
                and not self._stop_flag.is_set()
            ):
                self._client_thread = threading.Thread(
                    target=self._init_client_background
                )
                self._client_thread.daemon = True
                self._client_thread.start()

    def log(self, metrics: dict, step: int | None = None):
        """Queue a dictionary of metrics for logging at the given step.

        If step is None, it auto-increments from the last logged step.
        """
        renamed_keys = []
        new_metrics = {}
        for k, v in metrics.items():
            if k in utils.RESERVED_KEYS or k.startswith("__"):
                new_key = f"__{k}"
                renamed_keys.append(k)
                new_metrics[new_key] = v
            else:
                new_metrics[k] = v
        if renamed_keys:
            warnings.warn(
                "The following metric keys are reserved and were logged with a "
                f"'__' prefix instead: {renamed_keys}"
            )
        metrics = new_metrics

        for key, value in metrics.items():
            if isinstance(value, Table):
                metrics[key] = value._to_dict(
                    project=self.project, run=self.name, step=step
                )
                self._scan_and_queue_media_uploads(metrics[key], step)
            elif isinstance(value, Histogram):
                metrics[key] = value._to_dict()
            elif isinstance(value, Markdown):
                metrics[key] = value._to_dict()
            elif isinstance(value, TrackioMedia):
                metrics[key] = self._process_media(value, step)
        metrics = utils.serialize_values(metrics)

        if step is None:
            step = self._next_step
        self._next_step = max(self._next_step, step + 1)

        config_to_log = None
        if not self._config_logged and self.config:
            config_to_log = utils.to_json_safe(self.config)
            self._config_logged = True

        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
            "config": config_to_log,
            "log_id": uuid.uuid4().hex,
        }
        with self._client_lock:
            self._queued_logs.append(log_entry)
        self._ensure_sender_alive()

    def log_system(self, metrics: dict):
        """Queue a dictionary of system metrics, timestamped with the current UTC time."""
        metrics = utils.serialize_values(metrics)
        timestamp = datetime.now(timezone.utc).isoformat()
        system_log_entry: SystemLogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "timestamp": timestamp,
            "log_id": uuid.uuid4().hex,
        }
        with self._client_lock:
            self._queued_system_logs.append(system_log_entry)
        self._ensure_sender_alive()

    def finish(self):
        """Stop GPU monitoring and flush queued logs, waiting up to 30 seconds."""
        if self._gpu_monitor is not None:
            self._gpu_monitor.stop()
        self._stop_flag.set()
        if self._is_local:
            if hasattr(self, "_local_sender_thread"):
                print("* Run finished. Uploading logs to Trackio (please wait...)")
                self._local_sender_thread.join(timeout=30)
                if self._local_sender_thread.is_alive():
                    warnings.warn(
                        "Could not flush all logs within 30s. Some data may be "
                        "buffered locally."
                    )
        else:
            if self._client_thread is not None:
                print(
                    "* Run finished. Uploading logs to Trackio Space (please wait...)"
                )
                self._client_thread.join(timeout=30)
                if self._client_thread.is_alive():
                    warnings.warn(
                        "Could not flush all logs within 30s. Some data may be "
                        "buffered locally."
                    )
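

# Usage sketch (illustrative, not part of the module): when `step` is omitted,
# `log` auto-increments from the last step recorded for this run, and
# `finish()` blocks (up to 30s) while the background sender flushes its queue.
#
#     run = Run(url=None, project="demo", client=None)  # local SQLite mode
#     run.log({"loss": 0.9})                            # step 0
#     run.log({"loss": 0.7})                            # step 1
#     run.log_system({"cpu_percent": 12.5})             # timestamped system metric
#     run.finish()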