Spaces:
Sleeping
Sleeping
| import shutil | |
| import threading | |
| import time | |
| import uuid | |
| import warnings | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| import huggingface_hub | |
| from gradio_client import Client, handle_file | |
| from trackio import utils | |
| from trackio.gpu import GpuMonitor | |
| from trackio.histogram import Histogram | |
| from trackio.markdown import Markdown | |
| from trackio.media import TrackioMedia, get_project_media_path | |
| from trackio.sqlite_storage import SQLiteStorage | |
| from trackio.table import Table | |
| from trackio.typehints import LogEntry, SystemLogEntry, UploadEntry | |
| from trackio.utils import _get_default_namespace | |
# Seconds between background flushes of the queued log/upload batches.
BATCH_SEND_INTERVAL = 0.5
# Upper bound (seconds) on the exponential backoff applied after failed
# remote sends in _batch_sender.
MAX_BACKOFF = 30
class Run:
    """A single Trackio logging run.

    Buffers metric/system-log/upload entries in memory and flushes them from a
    background daemon thread: local runs write straight to SQLite, while
    Space-backed runs send batches to the Space via a gradio_client Client
    (created lazily with retries in a background thread).
    """

    def __init__(
        self,
        url: str | None,
        project: str,
        client: Client | None,
        name: str | None = None,
        group: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
        auto_log_gpu: bool = False,
        gpu_log_interval: float = 10.0,
    ):
        """
        Initialize a Run for logging metrics to Trackio.
        Args:
            url: The URL of the Trackio server (local Gradio app or HF Space).
            project: The name of the project to log metrics to.
            client: A pre-configured gradio_client.Client instance, or None to
                create one automatically in a background thread with retry logic.
                Passing None is recommended for normal usage. Passing a client
                is useful for testing (e.g., injecting a mock client).
            name: The name of this run. If None, a readable name like
                "brave-sunset-0" is auto-generated. If space_id is provided,
                generates a "username-timestamp" format instead.
            group: Optional group name to organize related runs together.
            config: A dictionary of configuration/hyperparameters for this run.
                Keys starting with '_' are reserved for internal use.
            space_id: The HF Space ID if logging to a Space (e.g., "user/space").
                If provided, media files will be uploaded to the Space.
            auto_log_gpu: Whether to automatically log GPU metrics (utilization,
                memory, temperature) at regular intervals.
            gpu_log_interval: The interval in seconds between GPU metric logs.
                Only used when auto_log_gpu is True.
        """
        self.url = url
        self.project = project
        # Guards self._client and all _queued_* lists across threads.
        self._client_lock = threading.Lock()
        self._client_thread = None
        self._client = client
        self._space_id = space_id
        self.name = name or utils.generate_readable_name(
            SQLiteStorage.get_runs(project), space_id
        )
        self.group = group
        self.config = utils.to_json_safe(config or {})
        if isinstance(self.config, dict):
            # Leading-underscore keys are reserved for the internal metadata
            # added below, so reject them before merging.
            for key in self.config:
                if key.startswith("_"):
                    raise ValueError(
                        f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
                    )
            self.config["_Username"] = self._get_username()
            self.config["_Created"] = datetime.now(timezone.utc).isoformat()
            self.config["_Group"] = self.group
        self._queued_logs: list[LogEntry] = []
        self._queued_system_logs: list[SystemLogEntry] = []
        self._queued_uploads: list[UploadEntry] = []
        self._stop_flag = threading.Event()
        self._config_logged = False
        # Continue step numbering from any previously stored data for this run.
        max_step = SQLiteStorage.get_max_step_for_run(self.project, self.name)
        self._next_step = 0 if max_step is None else max_step + 1
        self._has_local_buffer = False
        self._is_local = space_id is None
        if self._is_local:
            # Local mode: flush queued entries directly into SQLite.
            self._local_sender_thread = threading.Thread(
                target=self._local_batch_sender
            )
            self._local_sender_thread.daemon = True
            self._local_sender_thread.start()
        else:
            # Space mode: connect the client (with retries), then batch-send.
            self._client_thread = threading.Thread(target=self._init_client_background)
            self._client_thread.daemon = True
            self._client_thread.start()
        self._gpu_monitor: "GpuMonitor | None" = None
        if auto_log_gpu:
            self._gpu_monitor = GpuMonitor(self, interval=gpu_log_interval)
            self._gpu_monitor.start()
| def _get_username(self) -> str | None: | |
| try: | |
| return _get_default_namespace() | |
| except Exception: | |
| return None | |
| def _local_batch_sender(self): | |
| while ( | |
| not self._stop_flag.is_set() | |
| or len(self._queued_logs) > 0 | |
| or len(self._queued_system_logs) > 0 | |
| ): | |
| if not self._stop_flag.is_set(): | |
| time.sleep(BATCH_SEND_INTERVAL) | |
| with self._client_lock: | |
| if self._queued_logs: | |
| logs_to_send = self._queued_logs.copy() | |
| self._queued_logs.clear() | |
| self._write_logs_to_sqlite(logs_to_send) | |
| if self._queued_system_logs: | |
| system_logs_to_send = self._queued_system_logs.copy() | |
| self._queued_system_logs.clear() | |
| self._write_system_logs_to_sqlite(system_logs_to_send) | |
| def _write_logs_to_sqlite(self, logs: list[LogEntry]): | |
| logs_by_run: dict[tuple, dict] = {} | |
| for entry in logs: | |
| key = (entry["project"], entry["run"]) | |
| if key not in logs_by_run: | |
| logs_by_run[key] = { | |
| "metrics": [], | |
| "steps": [], | |
| "log_ids": [], | |
| "config": None, | |
| } | |
| logs_by_run[key]["metrics"].append(entry["metrics"]) | |
| logs_by_run[key]["steps"].append(entry.get("step")) | |
| logs_by_run[key]["log_ids"].append(entry.get("log_id")) | |
| if entry.get("config") and logs_by_run[key]["config"] is None: | |
| logs_by_run[key]["config"] = entry["config"] | |
| for (project, run), data in logs_by_run.items(): | |
| has_log_ids = any(lid is not None for lid in data["log_ids"]) | |
| SQLiteStorage.bulk_log( | |
| project=project, | |
| run=run, | |
| metrics_list=data["metrics"], | |
| steps=data["steps"], | |
| config=data["config"], | |
| log_ids=data["log_ids"] if has_log_ids else None, | |
| ) | |
| def _write_system_logs_to_sqlite(self, logs: list[SystemLogEntry]): | |
| logs_by_run: dict[tuple, dict] = {} | |
| for entry in logs: | |
| key = (entry["project"], entry["run"]) | |
| if key not in logs_by_run: | |
| logs_by_run[key] = {"metrics": [], "timestamps": [], "log_ids": []} | |
| logs_by_run[key]["metrics"].append(entry["metrics"]) | |
| logs_by_run[key]["timestamps"].append(entry.get("timestamp")) | |
| logs_by_run[key]["log_ids"].append(entry.get("log_id")) | |
| for (project, run), data in logs_by_run.items(): | |
| has_log_ids = any(lid is not None for lid in data["log_ids"]) | |
| SQLiteStorage.bulk_log_system( | |
| project=project, | |
| run=run, | |
| metrics_list=data["metrics"], | |
| timestamps=data["timestamps"], | |
| log_ids=data["log_ids"] if has_log_ids else None, | |
| ) | |
    def _batch_sender(self):
        """Background loop that sends queued batches to the Space's API.

        Sends logs, system logs, and media uploads via the gradio client. On
        any send failure the batch is persisted locally (so nothing is lost)
        and the loop backs off exponentially, capped at MAX_BACKOFF seconds.
        Runs until the stop flag is set AND all three queues are drained.
        """
        consecutive_failures = 0
        while (
            not self._stop_flag.is_set()
            or len(self._queued_logs) > 0
            or len(self._queued_system_logs) > 0
            or len(self._queued_uploads) > 0
        ):
            if not self._stop_flag.is_set():
                # Exponential backoff after failures; normal cadence otherwise.
                if consecutive_failures:
                    sleep_time = min(
                        BATCH_SEND_INTERVAL * (2**consecutive_failures), MAX_BACKOFF
                    )
                else:
                    sleep_time = BATCH_SEND_INTERVAL
                time.sleep(sleep_time)
            with self._client_lock:
                # NOTE(review): _init_client_background sets the client before
                # calling this method, so None here is unexpected — presumably
                # a defensive guard; confirm before relying on it.
                if self._client is None:
                    return
                failed = False
                if self._queued_logs:
                    logs_to_send = self._queued_logs.copy()
                    self._queued_logs.clear()
                    try:
                        self._client.predict(
                            api_name="/bulk_log",
                            logs=logs_to_send,
                            hf_token=huggingface_hub.utils.get_token(),
                        )
                    except Exception:
                        # Don't drop the batch — buffer it locally for retry.
                        self._persist_logs_locally(logs_to_send)
                        failed = True
                if self._queued_system_logs:
                    system_logs_to_send = self._queued_system_logs.copy()
                    self._queued_system_logs.clear()
                    try:
                        self._client.predict(
                            api_name="/bulk_log_system",
                            logs=system_logs_to_send,
                            hf_token=huggingface_hub.utils.get_token(),
                        )
                    except Exception:
                        self._persist_system_logs_locally(system_logs_to_send)
                        failed = True
                if self._queued_uploads:
                    uploads_to_send = self._queued_uploads.copy()
                    self._queued_uploads.clear()
                    try:
                        self._client.predict(
                            api_name="/bulk_upload_media",
                            uploads=uploads_to_send,
                            hf_token=huggingface_hub.utils.get_token(),
                        )
                    except Exception:
                        self._persist_uploads_locally(uploads_to_send)
                        failed = True
                if failed:
                    consecutive_failures += 1
                else:
                    consecutive_failures = 0
                # Once sending works again, replay anything buffered locally.
                if self._has_local_buffer:
                    self._flush_local_buffer()
| def _persist_logs_locally(self, logs: list[LogEntry]): | |
| if not self._space_id: | |
| return | |
| logs_by_run: dict[tuple, dict] = {} | |
| for entry in logs: | |
| key = (entry["project"], entry["run"]) | |
| if key not in logs_by_run: | |
| logs_by_run[key] = { | |
| "metrics": [], | |
| "steps": [], | |
| "log_ids": [], | |
| "config": None, | |
| } | |
| logs_by_run[key]["metrics"].append(entry["metrics"]) | |
| logs_by_run[key]["steps"].append(entry.get("step")) | |
| logs_by_run[key]["log_ids"].append(entry.get("log_id")) | |
| if entry.get("config") and logs_by_run[key]["config"] is None: | |
| logs_by_run[key]["config"] = entry["config"] | |
| for (project, run), data in logs_by_run.items(): | |
| SQLiteStorage.bulk_log( | |
| project=project, | |
| run=run, | |
| metrics_list=data["metrics"], | |
| steps=data["steps"], | |
| log_ids=data["log_ids"], | |
| config=data["config"], | |
| space_id=self._space_id, | |
| ) | |
| self._has_local_buffer = True | |
| def _persist_system_logs_locally(self, logs: list[SystemLogEntry]): | |
| if not self._space_id: | |
| return | |
| logs_by_run: dict[tuple, dict] = {} | |
| for entry in logs: | |
| key = (entry["project"], entry["run"]) | |
| if key not in logs_by_run: | |
| logs_by_run[key] = {"metrics": [], "timestamps": [], "log_ids": []} | |
| logs_by_run[key]["metrics"].append(entry["metrics"]) | |
| logs_by_run[key]["timestamps"].append(entry.get("timestamp")) | |
| logs_by_run[key]["log_ids"].append(entry.get("log_id")) | |
| for (project, run), data in logs_by_run.items(): | |
| SQLiteStorage.bulk_log_system( | |
| project=project, | |
| run=run, | |
| metrics_list=data["metrics"], | |
| timestamps=data["timestamps"], | |
| log_ids=data["log_ids"], | |
| space_id=self._space_id, | |
| ) | |
| self._has_local_buffer = True | |
| def _persist_uploads_locally(self, uploads: list[UploadEntry]): | |
| if not self._space_id: | |
| return | |
| for entry in uploads: | |
| file_data = entry.get("uploaded_file") | |
| file_path = "" | |
| if isinstance(file_data, dict): | |
| file_path = file_data.get("path", "") | |
| elif hasattr(file_data, "path"): | |
| file_path = str(file_data.path) | |
| else: | |
| file_path = str(file_data) | |
| SQLiteStorage.add_pending_upload( | |
| project=entry["project"], | |
| space_id=self._space_id, | |
| run_name=entry.get("run"), | |
| step=entry.get("step"), | |
| file_path=file_path, | |
| relative_path=entry.get("relative_path"), | |
| ) | |
| self._has_local_buffer = True | |
    def _flush_local_buffer(self):
        """Replay locally buffered logs/system-logs/uploads to the Space.

        Called from _batch_sender once sending succeeds again. Any failure
        leaves _has_local_buffer True so the flush is retried on a later
        cycle; the broad except is deliberate best-effort, not error hiding.
        """
        try:
            buffered_logs = SQLiteStorage.get_pending_logs(self.project)
            if buffered_logs:
                self._client.predict(
                    api_name="/bulk_log",
                    logs=buffered_logs["logs"],
                    hf_token=huggingface_hub.utils.get_token(),
                )
                # Only clear the buffered rows after a successful send.
                SQLiteStorage.clear_pending_logs(self.project, buffered_logs["ids"])
            buffered_sys = SQLiteStorage.get_pending_system_logs(self.project)
            if buffered_sys:
                self._client.predict(
                    api_name="/bulk_log_system",
                    logs=buffered_sys["logs"],
                    hf_token=huggingface_hub.utils.get_token(),
                )
                SQLiteStorage.clear_pending_system_logs(
                    self.project, buffered_sys["ids"]
                )
            buffered_uploads = SQLiteStorage.get_pending_uploads(self.project)
            if buffered_uploads:
                upload_entries = []
                for u in buffered_uploads["uploads"]:
                    fp = u["file_path"]
                    # Skip files that vanished since they were buffered.
                    if Path(fp).exists():
                        upload_entries.append(
                            {
                                "project": u["project"],
                                "run": u["run"],
                                "step": u["step"],
                                "relative_path": u["relative_path"],
                                "uploaded_file": handle_file(fp),
                            }
                        )
                if upload_entries:
                    self._client.predict(
                        api_name="/bulk_upload_media",
                        uploads=upload_entries,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                SQLiteStorage.clear_pending_uploads(
                    self.project, buffered_uploads["ids"]
                )
            self._has_local_buffer = False
        except Exception:
            # Best-effort: leave _has_local_buffer set so we retry later.
            pass
| def _init_client_background(self): | |
| if self._client is None: | |
| fib = utils.fibo() | |
| for sleep_coefficient in fib: | |
| try: | |
| client = Client(self.url, verbose=False) | |
| with self._client_lock: | |
| self._client = client | |
| break | |
| except Exception: | |
| pass | |
| if sleep_coefficient is not None: | |
| time.sleep(0.1 * sleep_coefficient) | |
| self._batch_sender() | |
| def _queue_upload( | |
| self, | |
| file_path, | |
| step: int | None, | |
| relative_path: str | None = None, | |
| use_run_name: bool = True, | |
| ): | |
| if self._is_local: | |
| self._save_upload_locally(file_path, step, relative_path, use_run_name) | |
| else: | |
| upload_entry: UploadEntry = { | |
| "project": self.project, | |
| "run": self.name if use_run_name else None, | |
| "step": step, | |
| "relative_path": relative_path, | |
| "uploaded_file": handle_file(file_path), | |
| } | |
| with self._client_lock: | |
| self._queued_uploads.append(upload_entry) | |
| def _save_upload_locally( | |
| self, | |
| file_path, | |
| step: int | None, | |
| relative_path: str | None = None, | |
| use_run_name: bool = True, | |
| ): | |
| media_path = get_project_media_path( | |
| project=self.project, | |
| run=self.name if use_run_name else None, | |
| step=step, | |
| relative_path=relative_path, | |
| ) | |
| src = Path(file_path) | |
| if src.exists() and str(src.resolve()) != str(Path(media_path).resolve()): | |
| shutil.copy(str(src), str(media_path)) | |
| def _process_media(self, value: TrackioMedia, step: int | None) -> dict: | |
| value._save(self.project, self.name, step if step is not None else 0) | |
| if self._space_id: | |
| self._queue_upload(value._get_absolute_file_path(), step) | |
| return value._to_dict() | |
| def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None): | |
| if not self._space_id: | |
| return | |
| table_data = table_dict.get("_value", []) | |
| for row in table_data: | |
| for value in row.values(): | |
| if isinstance(value, dict) and value.get("_type") in [ | |
| "trackio.image", | |
| "trackio.video", | |
| "trackio.audio", | |
| ]: | |
| file_path = value.get("file_path") | |
| if file_path: | |
| from trackio.utils import MEDIA_DIR | |
| absolute_path = MEDIA_DIR / file_path | |
| self._queue_upload(absolute_path, step) | |
| elif isinstance(value, list): | |
| for item in value: | |
| if isinstance(item, dict) and item.get("_type") in [ | |
| "trackio.image", | |
| "trackio.video", | |
| "trackio.audio", | |
| ]: | |
| file_path = item.get("file_path") | |
| if file_path: | |
| from trackio.utils import MEDIA_DIR | |
| absolute_path = MEDIA_DIR / file_path | |
| self._queue_upload(absolute_path, step) | |
| def _ensure_sender_alive(self): | |
| if self._is_local: | |
| if ( | |
| hasattr(self, "_local_sender_thread") | |
| and not self._local_sender_thread.is_alive() | |
| and not self._stop_flag.is_set() | |
| ): | |
| self._local_sender_thread = threading.Thread( | |
| target=self._local_batch_sender | |
| ) | |
| self._local_sender_thread.daemon = True | |
| self._local_sender_thread.start() | |
| else: | |
| if ( | |
| self._client_thread is not None | |
| and not self._client_thread.is_alive() | |
| and not self._stop_flag.is_set() | |
| ): | |
| self._client_thread = threading.Thread( | |
| target=self._init_client_background | |
| ) | |
| self._client_thread.daemon = True | |
| self._client_thread.start() | |
    def log(self, metrics: dict, step: int | None = None):
        """Queue a metrics dict for logging at the given (or next automatic) step.

        Reserved metric keys are renamed with a "__" prefix (with a warning);
        Table/Histogram/Markdown/media values are converted to serializable
        dicts before the entry is queued for the background sender.
        """
        renamed_keys = []
        new_metrics = {}
        for k, v in metrics.items():
            # Prefix keys that collide with reserved names (or already start
            # with "__") so they cannot clash with internal columns.
            if k in utils.RESERVED_KEYS or k.startswith("__"):
                new_key = f"__{k}"
                renamed_keys.append(k)
                new_metrics[new_key] = v
            else:
                new_metrics[k] = v
        if renamed_keys:
            warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")
        metrics = new_metrics
        # Convert rich values in place (values only — no keys added/removed,
        # so mutating during items() iteration is safe).
        for key, value in metrics.items():
            if isinstance(value, Table):
                metrics[key] = value._to_dict(
                    project=self.project, run=self.name, step=step
                )
                # Tables may embed media cells that also need uploading.
                self._scan_and_queue_media_uploads(metrics[key], step)
            elif isinstance(value, Histogram):
                metrics[key] = value._to_dict()
            elif isinstance(value, Markdown):
                metrics[key] = value._to_dict()
            elif isinstance(value, TrackioMedia):
                metrics[key] = self._process_media(value, step)
        metrics = utils.serialize_values(metrics)
        # Auto-advance the step counter; explicit steps also push it forward.
        if step is None:
            step = self._next_step
        self._next_step = max(self._next_step, step + 1)
        # The run config rides along with the first logged entry only.
        config_to_log = None
        if not self._config_logged and self.config:
            config_to_log = utils.to_json_safe(self.config)
            self._config_logged = True
        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
            "config": config_to_log,
            "log_id": uuid.uuid4().hex,
        }
        with self._client_lock:
            self._queued_logs.append(log_entry)
        self._ensure_sender_alive()
| def log_system(self, metrics: dict): | |
| metrics = utils.serialize_values(metrics) | |
| timestamp = datetime.now(timezone.utc).isoformat() | |
| system_log_entry: SystemLogEntry = { | |
| "project": self.project, | |
| "run": self.name, | |
| "metrics": metrics, | |
| "timestamp": timestamp, | |
| "log_id": uuid.uuid4().hex, | |
| } | |
| with self._client_lock: | |
| self._queued_system_logs.append(system_log_entry) | |
| self._ensure_sender_alive() | |
| def finish(self): | |
| if self._gpu_monitor is not None: | |
| self._gpu_monitor.stop() | |
| self._stop_flag.set() | |
| if self._is_local: | |
| if hasattr(self, "_local_sender_thread"): | |
| print("* Run finished. Uploading logs to Trackio (please wait...)") | |
| self._local_sender_thread.join(timeout=30) | |
| if self._local_sender_thread.is_alive(): | |
| warnings.warn( | |
| "Could not flush all logs within 30s. Some data may be buffered locally." | |
| ) | |
| else: | |
| if self._client_thread is not None: | |
| print( | |
| "* Run finished. Uploading logs to Trackio Space (please wait...)" | |
| ) | |
| self._client_thread.join(timeout=30) | |
| if self._client_thread.is_alive(): | |
| warnings.warn( | |
| "Could not flush all logs within 30s. Some data may be buffered locally." | |
| ) | |