File size: 8,937 Bytes
3b4ee4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import threading
import time

import huggingface_hub
from gradio_client import Client

from trackio.sqlite_storage import SQLiteStorage
from trackio.typehints import LogEntry
from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name


class Run:
    """A single logging run for a project backed by a Trackio Space.

    Metric logs are queued locally and uploaded in batches by a daemon
    thread, so `log()` never blocks on network I/O. When no Gradio
    client is supplied, one is created lazily in the background with
    Fibonacci backoff (the Space may still be starting up).
    """

    # Seconds between batched uploads by the background sender thread.
    _BATCH_INTERVAL = 0.5

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        config: dict | None = None,
    ):
        """
        Args:
            url: URL of the Trackio Space that receives the logs.
            project: Name of the project this run belongs to.
            client: A pre-built Gradio client, or None to have one
                created lazily on the background thread.
            name: Run name; a readable name is generated when omitted.
            config: Optional configuration dict stored on the run.
        """
        self.url = url
        self.project = project
        self._client_lock = threading.Lock()
        self._client = client
        self.name = name or generate_readable_name(SQLiteStorage.get_runs(project))
        self.config = config or {}
        self._queued_logs: list[LogEntry] = []
        self._stop_flag = threading.Event()

        # Daemon thread: initializes the client (if needed), then runs
        # the batch sender until finish() sets the stop flag.
        self._client_thread = threading.Thread(
            target=self._init_client_background, daemon=True
        )
        self._client_thread.start()

    def _init_client_background(self):
        """Create the Gradio client with Fibonacci backoff, then run the
        batch-sending loop. Executes on the background thread."""
        if self._client is None:
            for sleep_coefficient in fibo():
                try:
                    client = Client(self.url, verbose=False)
                except Exception:
                    # The Space may not be reachable yet; retry below.
                    pass
                else:
                    with self._client_lock:
                        self._client = client
                    break
                if sleep_coefficient is not None:
                    time.sleep(0.1 * sleep_coefficient)
        self._batch_sender()

    def _send_logs(self, logs: list[LogEntry]) -> None:
        """Best-effort upload of one batch of logs to the Space.

        Failures are reported to stdout but never raised, so a flaky
        network cannot crash the sender thread. The full token is never
        printed.
        """
        try:
            hf_token = huggingface_hub.utils.get_token()
            self._client.predict(
                api_name="/bulk_log",
                logs=logs,
                hf_token=hf_token,
            )
        except Exception as e:
            print(f"* Failed to upload {len(logs)} log(s) to {self.url}: {e}")

    def _batch_sender(self):
        """Send queued logs every 500ms until stopped and drained.

        The loop keeps running after the stop flag is set until the
        queue is empty, so logs queued just before finish() are still
        delivered.
        """
        while not self._stop_flag.is_set() or self._queued_logs:
            if not self._stop_flag.is_set():
                time.sleep(self._BATCH_INTERVAL)

            with self._client_lock:
                if self._client is None:
                    if self._stop_flag.is_set():
                        # Client creation never succeeded and the run is
                        # finishing: no client will ever appear, so drop
                        # the remaining logs rather than spin forever.
                        break
                    continue
                if not self._queued_logs:
                    continue
                logs_to_send = self._queued_logs.copy()
                self._queued_logs.clear()

            # Upload outside the lock so log() callers are never blocked
            # on network I/O.
            self._send_logs(logs_to_send)

    def log(self, metrics: dict, step: int | None = None):
        """Queue a dictionary of metrics for batched upload.

        Args:
            metrics: Mapping of metric name to value. Names must not be
                reserved keys and must not start with "__".
            step: Optional step index to associate with the metrics.

        Raises:
            ValueError: If a metric name is reserved.
        """
        for k in metrics.keys():
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )

        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
        }

        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Flush all queued logs and stop the background sender thread.

        Setting the stop flag makes the sender drain the queue and exit,
        so joining the thread guarantees every queued log has been
        uploaded (or dropped if a client could never be created).
        """
        self._stop_flag.set()

        if self._client_thread is not None:
            print(f"* Uploading logs to Trackio Space: {self.url} (please wait...)")
            self._client_thread.join()