File size: 8,122 Bytes
7a0c684 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
"""
Cross-GPU Stream Manager for coordinating operations across multiple GPUs
Uses DuckDB with HuggingFace for persistent storage
"""
from typing import Dict, List, Optional, Any
import threading
import time
import logging
import duckdb
import json
from config import get_hf_token_cached
# Initialize token from .env
class CrossGPUStreamManager:
    """Coordinate data streams and transfers between GPUs.

    State is persisted in a DuckDB database (optionally a remote HuggingFace
    path via the httpfs extension) so stream/transfer statistics survive
    process restarts.

    Thread-safety: ``stream_lock`` guards the ``streams`` table and
    ``transfer_lock`` guards the ``transfers`` table. When both are needed,
    ``stream_lock`` is ALWAYS acquired before ``transfer_lock`` — a
    consistent ordering that prevents the deadlock the original code had
    (``start_transfer`` took stream→transfer while ``complete_transfer``
    took transfer→stream).

    Timestamps are stored as DOUBLE epoch seconds (``time.time()``) because
    the code both inserts floats and performs float arithmetic on the
    retrieved values; a TIMESTAMP column would reject or mangle them.
    """

    def __init__(self, db_path: str = "hf://datasets/Fred808/helium/storage.json"):
        """Open (or create) the backing database and ensure the schema exists.

        Args:
            db_path: DuckDB database location; may be a local path or an
                ``hf://`` URL served through the httpfs extension.
        """
        self.stream_lock = threading.Lock()
        self.transfer_lock = threading.Lock()
        self.db_path = db_path
        # Bug fix: _init_db reads self.HF_TOKEN, but the original never
        # assigned it (get_hf_token_cached was imported yet never called),
        # so construction raised AttributeError.
        self.HF_TOKEN = get_hf_token_cached()
        self.con = self._init_db()

    def _init_db(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection, remote-access settings, and schema."""
        con = duckdb.connect(self.db_path)
        # Configure HuggingFace access through DuckDB's httpfs extension.
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        con.execute("SET s3_endpoint='hf.co';")
        con.execute("SET s3_use_ssl=true;")
        con.execute("SET s3_url_style='path';")
        # NOTE(review): the token is interpolated into a SET statement; it is
        # trusted config (not user input), but keep it out of logs.
        con.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';")
        con.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';")
        # Streams table: one row per cross-GPU stream. Time columns are
        # DOUBLE epoch seconds to match time.time() producers/consumers.
        con.execute("""
            CREATE TABLE IF NOT EXISTS streams (
                stream_id VARCHAR PRIMARY KEY,
                source_gpu VARCHAR,
                target_gpu VARCHAR,
                state VARCHAR,
                created_at DOUBLE,
                last_active DOUBLE,
                transfer_count INTEGER,
                total_bytes_transferred BIGINT
            )
        """)
        # Transfers table: one row per individual data transfer on a stream.
        con.execute("""
            CREATE TABLE IF NOT EXISTS transfers (
                transfer_id VARCHAR PRIMARY KEY,
                stream_id VARCHAR,
                transfer_size BIGINT,
                started_at DOUBLE,
                completed_at DOUBLE,
                state VARCHAR
            )
        """)
        return con

    def create_stream(self, stream_id: str, source_gpu: str, target_gpu: str) -> Dict[str, Any]:
        """Create a new cross-GPU stream, or return the existing one.

        Args:
            stream_id: Unique identifier for the stream.
            source_gpu: Identifier of the GPU data originates from.
            target_gpu: Identifier of the GPU data is sent to.

        Returns:
            A dict describing the stream (keys: id, source_gpu, target_gpu,
            state, created_at, last_active, transfer_count,
            total_bytes_transferred). Idempotent: calling again with an
            existing stream_id returns the stored record unchanged.
        """
        with self.stream_lock:
            # Idempotency: if the stream already exists, return it as-is.
            result = self.con.execute("""
                SELECT * FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()
            if result:
                return dict(zip(['id', 'source_gpu', 'target_gpu', 'state', 'created_at',
                                 'last_active', 'transfer_count', 'total_bytes_transferred'], result))
            now = time.time()
            self.con.execute("""
                INSERT INTO streams (
                    stream_id, source_gpu, target_gpu, state,
                    created_at, last_active, transfer_count, total_bytes_transferred
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [stream_id, source_gpu, target_gpu, 'initialized', now, now, 0, 0])
            logging.info(f"Created cross-GPU stream {stream_id} from {source_gpu} to {target_gpu}")
            return {
                'id': stream_id,
                'source_gpu': source_gpu,
                'target_gpu': target_gpu,
                'state': 'initialized',
                'created_at': now,
                'last_active': now,
                'transfer_count': 0,
                'total_bytes_transferred': 0
            }

    def start_transfer(self, stream_id: str, transfer_size: int) -> str:
        """Start a new data transfer on a stream.

        Args:
            stream_id: Stream to transfer on; must already exist.
            transfer_size: Size of the transfer in bytes.

        Returns:
            The new transfer's id, ``transfer_<stream_id>_<n>`` where ``n``
            is the stream's transfer count at start time.

        Raises:
            ValueError: If the stream does not exist.
        """
        # Lock order: stream_lock -> transfer_lock (see class docstring).
        with self.stream_lock:
            result = self.con.execute("""
                SELECT transfer_count FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()
            if not result:
                raise ValueError(f"Stream {stream_id} does not exist")
            transfer_count = result[0]
            transfer_id = f"transfer_{stream_id}_{transfer_count}"
            now = time.time()
            with self.transfer_lock:
                self.con.execute("""
                    INSERT INTO transfers (
                        transfer_id, stream_id, transfer_size,
                        started_at, state
                    ) VALUES (?, ?, ?, ?, ?)
                """, [transfer_id, stream_id, transfer_size, now, 'in_progress'])
            # Bump the stream's counter so the next transfer gets a fresh id.
            self.con.execute("""
                UPDATE streams
                SET transfer_count = transfer_count + 1,
                    last_active = ?
                WHERE stream_id = ?
            """, [now, stream_id])
            logging.info(f"Started transfer {transfer_id} on stream {stream_id}")
            return transfer_id

    def complete_transfer(self, transfer_id: str) -> None:
        """Mark a transfer as complete and update the owning stream's stats.

        Args:
            transfer_id: Transfer to complete; must be in 'in_progress' state.

        Raises:
            ValueError: If the transfer is missing or not in progress.
        """
        # Bug fix: the original took transfer_lock then stream_lock here,
        # the reverse of start_transfer's order — an ABBA deadlock. Acquire
        # both in the canonical stream -> transfer order instead.
        with self.stream_lock, self.transfer_lock:
            transfer = self.con.execute("""
                SELECT stream_id, transfer_size FROM transfers
                WHERE transfer_id = ? AND state = 'in_progress'
            """, [transfer_id]).fetchone()
            if not transfer:
                raise ValueError(f"Transfer {transfer_id} not found or not in progress")
            stream_id, transfer_size = transfer
            now = time.time()
            self.con.execute("""
                UPDATE transfers
                SET state = 'completed',
                    completed_at = ?
                WHERE transfer_id = ?
            """, [now, transfer_id])
            # Roll the completed bytes up into the stream's running totals.
            self.con.execute("""
                UPDATE streams
                SET total_bytes_transferred = total_bytes_transferred + ?,
                    last_active = ?
                WHERE stream_id = ?
            """, [transfer_size, now, stream_id])
            logging.info(f"Completed transfer {transfer_id}")

    def get_stream_stats(self, stream_id: str) -> Dict[str, Any]:
        """Get live statistics for a specific stream.

        Args:
            stream_id: Stream to report on.

        Returns:
            Dict with transfer_count, total_bytes_transferred, uptime
            (seconds since creation) and last_active_ago (seconds since the
            last activity).

        Raises:
            ValueError: If the stream does not exist.
        """
        with self.stream_lock:
            result = self.con.execute("""
                SELECT transfer_count, total_bytes_transferred,
                       created_at, last_active
                FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()
            if not result:
                raise ValueError(f"Stream {stream_id} does not exist")
            transfer_count, total_bytes, created_at, last_active = result
            now = time.time()
            return {
                'transfer_count': transfer_count,
                'total_bytes_transferred': total_bytes,
                'uptime': now - created_at,
                'last_active_ago': now - last_active
            }

    def cleanup_inactive_streams(self, timeout: float = 300.0) -> List[str]:
        """Remove streams inactive for longer than ``timeout`` seconds.

        Args:
            timeout: Inactivity threshold in seconds (default 300).

        Returns:
            The ids of the streams that were removed.
        """
        cutoff_time = time.time() - timeout
        with self.stream_lock:
            # DELETE ... RETURNING gives us the removed ids in one statement.
            results = self.con.execute("""
                DELETE FROM streams
                WHERE last_active < ?
                RETURNING stream_id
            """, [cutoff_time]).fetchall()
            cleaned_streams = [r[0] for r in results]
            for stream_id in cleaned_streams:
                logging.info(f"Cleaned up inactive stream {stream_id}")
            return cleaned_streams
|