# File size: 8,122 Bytes
# 7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""

Cross-GPU Stream Manager for coordinating operations across multiple GPUs

Uses DuckDB with HuggingFace for persistent storage

"""
from typing import Dict, List, Optional, Any
import threading
import time
import logging
import duckdb
import json
from config import get_hf_token_cached

# Initialize token from .env



class CrossGPUStreamManager:
    """Coordinate data-transfer streams between GPUs.

    Stream and per-transfer metadata are persisted in a DuckDB database so
    state survives process restarts. All timestamps are stored as DOUBLE
    epoch seconds (``time.time()``), matching the float arithmetic performed
    in :meth:`get_stream_stats` and :meth:`cleanup_inactive_streams`.

    Lock discipline: whenever both locks are needed, ``stream_lock`` is
    acquired before ``transfer_lock``. (The original ``complete_transfer``
    acquired them in the opposite order to ``start_transfer``, a deadlock
    risk under concurrent use.)
    """

    def __init__(self, db_path: str = "hf://datasets/Fred808/helium/storage.json"):
        """Create the manager and open/initialize the backing database.

        Args:
            db_path: Location handed to ``duckdb.connect``.
        """
        self.stream_lock = threading.Lock()
        self.transfer_lock = threading.Lock()
        self.db_path = db_path
        # BUG FIX: _init_db() reads self.HF_TOKEN, but the original never
        # assigned it, so construction raised AttributeError. The token comes
        # from the .env-backed cache helper imported at module level.
        self.HF_TOKEN = get_hf_token_cached()
        self.con = self._init_db()

    def _init_db(self) -> duckdb.DuckDBPyConnection:
        """Open the DuckDB connection, configure HF access, ensure schema.

        Returns:
            An open DuckDB connection with both tables created.
        """
        # NOTE(review): duckdb.connect() normally takes a local database
        # path; using an hf:// URL (a .json, no less) as the database file
        # itself is unusual — confirm this actually opens as intended.
        con = duckdb.connect(self.db_path)

        # Configure HuggingFace access through the httpfs extension,
        # presented as an S3-compatible endpoint.
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        con.execute("SET s3_endpoint='hf.co';")
        con.execute("SET s3_use_ssl=true;")
        con.execute("SET s3_url_style='path';")
        # DuckDB SET statements do not accept bound parameters, so the token
        # must be spliced in; escape single quotes so it cannot terminate the
        # SQL string literal early.
        token = self.HF_TOKEN.replace("'", "''")
        con.execute(f"SET s3_access_key_id='{token}';")
        con.execute(f"SET s3_secret_access_key='{token}';")

        # Streams table. BUG FIX: time columns are DOUBLE (epoch seconds),
        # not TIMESTAMP — the code inserts time.time() floats and later does
        # float subtraction on the fetched values, which TIMESTAMP breaks.
        con.execute("""
            CREATE TABLE IF NOT EXISTS streams (
                stream_id VARCHAR PRIMARY KEY,
                source_gpu VARCHAR,
                target_gpu VARCHAR,
                state VARCHAR,
                created_at DOUBLE,
                last_active DOUBLE,
                transfer_count INTEGER,
                total_bytes_transferred BIGINT
            )
        """)

        # Transfers table: one row per start_transfer() call.
        con.execute("""
            CREATE TABLE IF NOT EXISTS transfers (
                transfer_id VARCHAR PRIMARY KEY,
                stream_id VARCHAR,
                transfer_size BIGINT,
                started_at DOUBLE,
                completed_at DOUBLE,
                state VARCHAR
            )
        """)

        return con

    def create_stream(self, stream_id: str, source_gpu: str, target_gpu: str) -> Dict[str, Any]:
        """Create a new cross-GPU stream, or return the existing one.

        Args:
            stream_id: Unique identifier for the stream.
            source_gpu: Identifier of the sending GPU.
            target_gpu: Identifier of the receiving GPU.

        Returns:
            A dict describing the stream (idempotent: if the stream already
            exists, its current row is returned instead of being recreated).
        """
        with self.stream_lock:
            # Idempotency check: reuse an existing row rather than failing
            # on the PRIMARY KEY constraint.
            result = self.con.execute("""
                SELECT * FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()

            if result:
                return dict(zip(['id', 'source_gpu', 'target_gpu', 'state', 'created_at',
                               'last_active', 'transfer_count', 'total_bytes_transferred'], result))

            now = time.time()
            self.con.execute("""
                INSERT INTO streams (
                    stream_id, source_gpu, target_gpu, state,
                    created_at, last_active, transfer_count, total_bytes_transferred
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [stream_id, source_gpu, target_gpu, 'initialized', now, now, 0, 0])

            logging.info(f"Created cross-GPU stream {stream_id} from {source_gpu} to {target_gpu}")

            return {
                'id': stream_id,
                'source_gpu': source_gpu,
                'target_gpu': target_gpu,
                'state': 'initialized',
                'created_at': now,
                'last_active': now,
                'transfer_count': 0,
                'total_bytes_transferred': 0
            }

    def start_transfer(self, stream_id: str, transfer_size: int) -> str:
        """Start a new data transfer on a stream.

        Args:
            stream_id: The stream to transfer on; must already exist.
            transfer_size: Size of the transfer in bytes.

        Returns:
            The generated transfer id (``transfer_<stream>_<n>``).

        Raises:
            ValueError: If the stream does not exist.
        """
        with self.stream_lock:
            result = self.con.execute("""
                SELECT transfer_count FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()

            if not result:
                raise ValueError(f"Stream {stream_id} does not exist")

            # The per-stream counter makes transfer ids unique; holding
            # stream_lock for the whole method keeps read + increment atomic.
            transfer_count = result[0]
            transfer_id = f"transfer_{stream_id}_{transfer_count}"
            now = time.time()

            # Canonical lock order: stream_lock (held) -> transfer_lock.
            with self.transfer_lock:
                self.con.execute("""
                    INSERT INTO transfers (
                        transfer_id, stream_id, transfer_size,
                        started_at, state
                    ) VALUES (?, ?, ?, ?, ?)
                """, [transfer_id, stream_id, transfer_size, now, 'in_progress'])

            self.con.execute("""
                UPDATE streams
                SET transfer_count = transfer_count + 1,
                    last_active = ?
                WHERE stream_id = ?
            """, [now, stream_id])

            logging.info(f"Started transfer {transfer_id} on stream {stream_id}")
            return transfer_id

    def complete_transfer(self, transfer_id: str) -> None:
        """Mark a transfer as complete and update stream statistics.

        Raises:
            ValueError: If the transfer does not exist or is not in progress.
        """
        # BUG FIX: the original acquired transfer_lock then stream_lock here,
        # the reverse of start_transfer — an inconsistent ordering that can
        # deadlock. Acquire in the canonical stream_lock -> transfer_lock
        # order instead.
        with self.stream_lock:
            with self.transfer_lock:
                transfer = self.con.execute("""
                    SELECT stream_id, transfer_size FROM transfers
                    WHERE transfer_id = ? AND state = 'in_progress'
                """, [transfer_id]).fetchone()

                if not transfer:
                    raise ValueError(f"Transfer {transfer_id} not found or not in progress")

                stream_id, transfer_size = transfer
                now = time.time()

                self.con.execute("""
                    UPDATE transfers
                    SET state = 'completed',
                        completed_at = ?
                    WHERE transfer_id = ?
                """, [now, transfer_id])

            # Roll the completed bytes into the owning stream's totals;
            # this touches only the streams table, so stream_lock suffices.
            self.con.execute("""
                UPDATE streams
                SET total_bytes_transferred = total_bytes_transferred + ?,
                    last_active = ?
                WHERE stream_id = ?
            """, [transfer_size, now, stream_id])

        logging.info(f"Completed transfer {transfer_id}")

    def get_stream_stats(self, stream_id: str) -> Dict[str, Any]:
        """Get statistics for a specific stream.

        Returns:
            Dict with ``transfer_count``, ``total_bytes_transferred``,
            ``uptime`` (seconds since creation) and ``last_active_ago``
            (seconds since the last activity).

        Raises:
            ValueError: If the stream does not exist.
        """
        with self.stream_lock:
            result = self.con.execute("""
                SELECT transfer_count, total_bytes_transferred,
                       created_at, last_active
                FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()

            if not result:
                raise ValueError(f"Stream {stream_id} does not exist")

            transfer_count, total_bytes, created_at, last_active = result
            now = time.time()

            # created_at / last_active are epoch-second floats, so plain
            # subtraction yields durations in seconds.
            return {
                'transfer_count': transfer_count,
                'total_bytes_transferred': total_bytes,
                'uptime': now - created_at,
                'last_active_ago': now - last_active
            }

    def cleanup_inactive_streams(self, timeout: float = 300.0) -> List[str]:
        """Remove streams inactive for longer than ``timeout`` seconds.

        Args:
            timeout: Inactivity threshold in seconds (default 5 minutes).

        Returns:
            The ids of the streams that were deleted.
        """
        cutoff_time = time.time() - timeout

        with self.stream_lock:
            # DELETE ... RETURNING gives us the removed ids in one statement.
            # NOTE(review): rows in `transfers` referencing these streams are
            # left behind — confirm whether they should be purged as well.
            results = self.con.execute("""
                DELETE FROM streams
                WHERE last_active < ?
                RETURNING stream_id
            """, [cutoff_time]).fetchall()

            cleaned_streams = [r[0] for r in results]

            for stream_id in cleaned_streams:
                logging.info(f"Cleaned up inactive stream {stream_id}")

            return cleaned_streams