# Scraped from a Hugging Face Space file view (commit 3b4ee4c, 8,937 bytes);
# page chrome ("Spaces: Sleeping", line-number gutter) removed.
import threading
import time
import huggingface_hub
from gradio_client import Client
from trackio.sqlite_storage import SQLiteStorage
from trackio.typehints import LogEntry
from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name
class Run:
    """A run that queues logged metrics and uploads them to a Trackio Space.

    A daemon thread lazily creates a ``gradio_client.Client`` for ``url``
    (retrying with Fibonacci backoff while the Space wakes up) and then
    flushes queued log entries to the Space's ``/bulk_log`` endpoint every
    500 ms until :meth:`finish` drains the queue.
    """

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        config: dict | None = None,
    ):
        """
        Args:
            url: URL (or Space id) of the Trackio Space to upload logs to.
            project: Name of the project this run belongs to.
            client: An already-connected ``Client`` to reuse, or ``None`` to
                create one lazily in the background thread.
            name: Run name; a readable name is generated when omitted.
            config: Optional run configuration stored on the run.
        """
        self.url = url
        self.project = project
        self._client_lock = threading.Lock()
        self._client = client
        self.name = name or generate_readable_name(SQLiteStorage.get_runs(project))
        self.config = config or {}
        self._queued_logs: list[LogEntry] = []
        self._stop_flag = threading.Event()
        # Daemon thread so an un-finished run never keeps the interpreter alive.
        self._client_thread = threading.Thread(
            target=self._init_client_background, daemon=True
        )
        self._client_thread.start()

    def _init_client_background(self):
        """Create the client if needed, then run the batch sender.

        Connection attempts back off by 0.1s times successive Fibonacci
        numbers; the loop also exits early if :meth:`finish` sets the stop
        flag, so ``finish()`` cannot hang on an unreachable Space.
        """
        if self._client is None:
            for sleep_coefficient in fibo():
                if self._stop_flag.is_set():
                    # finish() was called before we ever connected; give up.
                    break
                try:
                    client = Client(self.url, verbose=False)
                    with self._client_lock:
                        self._client = client
                    break
                except Exception:
                    # The Space may still be starting up; retry after backoff.
                    pass
                if sleep_coefficient is not None:
                    # Event.wait doubles as an interruptible sleep.
                    self._stop_flag.wait(0.1 * sleep_coefficient)
        self._batch_sender()

    def _send_queued_logs(self):
        """Flush every queued log entry via the Space's ``/bulk_log`` endpoint.

        Must be called with ``self._client_lock`` held. No-op without a client
        or queued logs. Failures are reported but swallowed so that logging
        stays best-effort and never crashes the training loop.
        """
        if not self._queued_logs or self._client is None:
            return
        logs_to_send = self._queued_logs.copy()
        self._queued_logs.clear()
        try:
            self._client.predict(
                api_name="/bulk_log",
                logs=logs_to_send,
                hf_token=huggingface_hub.utils.get_token(),
            )
        except Exception as e:
            # Never echo the token here; report only the failure itself.
            print(f"* Failed to upload {len(logs_to_send)} log(s) to Trackio Space: {e}")

    def _batch_sender(self):
        """Send batched logs every 500 ms; on shutdown, drain and exit."""
        while True:
            if not self._stop_flag.is_set():
                time.sleep(0.5)
            with self._client_lock:
                self._send_queued_logs()
                if self._stop_flag.is_set():
                    # Exit once drained — or immediately if no client ever
                    # connected (nothing more can be sent, and spinning here
                    # would make finish()'s join() hang forever).
                    if not self._queued_logs or self._client is None:
                        break

    def log(self, metrics: dict, step: int | None = None):
        """Queue a metrics dict for the background sender.

        Args:
            metrics: Mapping of metric name to value. Names must not be
                reserved keys and must not start with ``"__"``.
            step: Optional step index; the server assigns one when ``None``.

        Raises:
            ValueError: If a metric name is reserved.
        """
        for k in metrics.keys():
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )
        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
        }
        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Flush any remaining logs and stop the background thread.

        Setting the stop flag makes the batch sender drain the queue and
        exit (and aborts any in-progress connection retries), so joining the
        thread is sufficient — no fixed sleep needed.
        """
        self._stop_flag.set()
        if self._client_thread is not None:
            print(f"* Uploading logs to Trackio Space: {self.url} (please wait...)")
            self._client_thread.join()