|
|
"""
|
|
|
Cross-GPU Stream Manager for coordinating operations across multiple GPUs
|
|
|
Uses DuckDB with HuggingFace for persistent storage
|
|
|
"""
|
|
|
from typing import Dict, List, Optional, Any
|
|
|
import threading
|
|
|
import time
|
|
|
import logging
|
|
|
import duckdb
|
|
|
import json
|
|
|
from config import get_hf_token_cached
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossGPUStreamManager:
    """Coordinate data-transfer streams between GPUs, persisting state in DuckDB.

    Two tables back the manager:
      * ``streams``   -- one row per cross-GPU stream (source/target GPU,
        lifecycle state, aggregate transfer statistics).
      * ``transfers`` -- one row per individual transfer on a stream.

    Thread-safety: all public methods serialize access through two locks.
    Lock-ordering discipline: when BOTH locks are needed, always acquire
    ``stream_lock`` BEFORE ``transfer_lock`` (prevents deadlock between
    ``start_transfer`` and ``complete_transfer``).

    Timestamps are stored as raw ``time.time()`` epoch floats (DOUBLE
    columns) because callers do float arithmetic on them directly.
    """

    def __init__(self, db_path: str = "hf://datasets/Fred808/helium/storage.json"):
        """Create the manager and open/initialize the backing database.

        Args:
            db_path: Path handed to ``duckdb.connect``.
                NOTE(review): ``duckdb.connect`` expects a local DuckDB
                database path; an ``hf://...json`` URL is unlikely to open
                as a database — confirm against actual deployment.
        """
        self.stream_lock = threading.Lock()
        self.transfer_lock = threading.Lock()
        self.db_path = db_path
        self.con = self._init_db()

    def _init_db(self) -> duckdb.DuckDBPyConnection:
        """Open the DuckDB connection, configure HF/S3 access, and create schema.

        Returns:
            An open DuckDB connection with ``httpfs`` loaded and both
            tables guaranteed to exist.
        """
        con = duckdb.connect(self.db_path)

        # Configure S3-compatible access to the HuggingFace Hub.
        # Bug fix: the original referenced self.HF_TOKEN, which was never
        # assigned; the token comes from config.get_hf_token_cached().
        hf_token = get_hf_token_cached()
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        con.execute("SET s3_endpoint='hf.co';")
        con.execute("SET s3_use_ssl=true;")
        con.execute("SET s3_url_style='path';")
        # NOTE(review): token is interpolated into the SET statement text;
        # DuckDB SET does not take bound parameters, so ensure the token
        # never contains a single quote.
        con.execute(f"SET s3_access_key_id='{hf_token}';")
        con.execute(f"SET s3_secret_access_key='{hf_token}';")

        # Epoch-float columns are DOUBLE (not TIMESTAMP) because every
        # reader/writer in this class uses time.time() floats directly.
        con.execute("""
            CREATE TABLE IF NOT EXISTS streams (
                stream_id VARCHAR PRIMARY KEY,
                source_gpu VARCHAR,
                target_gpu VARCHAR,
                state VARCHAR,
                created_at DOUBLE,
                last_active DOUBLE,
                transfer_count INTEGER,
                total_bytes_transferred BIGINT
            )
        """)

        con.execute("""
            CREATE TABLE IF NOT EXISTS transfers (
                transfer_id VARCHAR PRIMARY KEY,
                stream_id VARCHAR,
                transfer_size BIGINT,
                started_at DOUBLE,
                completed_at DOUBLE,
                state VARCHAR
            )
        """)

        return con

    def create_stream(self, stream_id: str, source_gpu: str, target_gpu: str) -> Dict[str, Any]:
        """Create a new cross-GPU stream, or return the existing row if present.

        Args:
            stream_id: Unique identifier for the stream.
            source_gpu: Identifier of the GPU data originates from.
            target_gpu: Identifier of the GPU data is sent to.

        Returns:
            A dict describing the stream (keys: id, source_gpu, target_gpu,
            state, created_at, last_active, transfer_count,
            total_bytes_transferred).
        """
        with self.stream_lock:
            # Idempotent: an existing stream is returned rather than recreated.
            result = self.con.execute("""
                SELECT * FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()

            if result:
                return dict(zip(['id', 'source_gpu', 'target_gpu', 'state', 'created_at',
                                 'last_active', 'transfer_count', 'total_bytes_transferred'], result))

            now = time.time()
            self.con.execute("""
                INSERT INTO streams (
                    stream_id, source_gpu, target_gpu, state,
                    created_at, last_active, transfer_count, total_bytes_transferred
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [stream_id, source_gpu, target_gpu, 'initialized', now, now, 0, 0])

            logging.info(f"Created cross-GPU stream {stream_id} from {source_gpu} to {target_gpu}")

            return {
                'id': stream_id,
                'source_gpu': source_gpu,
                'target_gpu': target_gpu,
                'state': 'initialized',
                'created_at': now,
                'last_active': now,
                'transfer_count': 0,
                'total_bytes_transferred': 0
            }

    def start_transfer(self, stream_id: str, transfer_size: int) -> str:
        """Start a new data transfer on a stream.

        Args:
            stream_id: Stream to attach the transfer to (must exist).
            transfer_size: Size of the transfer in bytes.

        Returns:
            The generated transfer id ("transfer_<stream>_<n>").

        Raises:
            ValueError: If ``stream_id`` does not exist.
        """
        # Lock order: stream_lock (outer) then transfer_lock (inner).
        with self.stream_lock:
            result = self.con.execute("""
                SELECT transfer_count FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()

            if not result:
                raise ValueError(f"Stream {stream_id} does not exist")

            transfer_count = result[0]
            transfer_id = f"transfer_{stream_id}_{transfer_count}"
            now = time.time()

            with self.transfer_lock:
                self.con.execute("""
                    INSERT INTO transfers (
                        transfer_id, stream_id, transfer_size,
                        started_at, state
                    ) VALUES (?, ?, ?, ?, ?)
                """, [transfer_id, stream_id, transfer_size, now, 'in_progress'])

            # Bump the per-stream counter so the next transfer id is unique.
            self.con.execute("""
                UPDATE streams
                SET transfer_count = transfer_count + 1,
                    last_active = ?
                WHERE stream_id = ?
            """, [now, stream_id])

            logging.info(f"Started transfer {transfer_id} on stream {stream_id}")
            return transfer_id

    def complete_transfer(self, transfer_id: str) -> None:
        """Mark a transfer as complete and fold its size into stream statistics.

        Args:
            transfer_id: Id returned by ``start_transfer``.

        Raises:
            ValueError: If the transfer is unknown or not 'in_progress'.
        """
        # Bug fix: the original acquired transfer_lock then stream_lock,
        # the reverse of start_transfer's order — a classic deadlock.
        # Acquire stream_lock first to keep a single global lock order.
        with self.stream_lock:
            with self.transfer_lock:
                transfer = self.con.execute("""
                    SELECT stream_id, transfer_size FROM transfers
                    WHERE transfer_id = ? AND state = 'in_progress'
                """, [transfer_id]).fetchone()

                if not transfer:
                    raise ValueError(f"Transfer {transfer_id} not found or not in progress")

                stream_id, transfer_size = transfer
                now = time.time()

                self.con.execute("""
                    UPDATE transfers
                    SET state = 'completed',
                        completed_at = ?
                    WHERE transfer_id = ?
                """, [now, transfer_id])

            # Stream aggregate update still under stream_lock.
            self.con.execute("""
                UPDATE streams
                SET total_bytes_transferred = total_bytes_transferred + ?,
                    last_active = ?
                WHERE stream_id = ?
            """, [transfer_size, now, stream_id])

            logging.info(f"Completed transfer {transfer_id}")

    def get_stream_stats(self, stream_id: str) -> Dict[str, Any]:
        """Get statistics for a specific stream.

        Args:
            stream_id: Stream to report on.

        Returns:
            Dict with transfer_count, total_bytes_transferred, uptime
            (seconds since creation) and last_active_ago (seconds since
            last activity).

        Raises:
            ValueError: If ``stream_id`` does not exist.
        """
        with self.stream_lock:
            result = self.con.execute("""
                SELECT transfer_count, total_bytes_transferred,
                       created_at, last_active
                FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()

            if not result:
                raise ValueError(f"Stream {stream_id} does not exist")

            transfer_count, total_bytes, created_at, last_active = result
            now = time.time()

            return {
                'transfer_count': transfer_count,
                'total_bytes_transferred': total_bytes,
                'uptime': now - created_at,
                'last_active_ago': now - last_active
            }

    def cleanup_inactive_streams(self, timeout: float = 300.0) -> List[str]:
        """Remove streams that have been inactive for the specified timeout.

        Args:
            timeout: Inactivity threshold in seconds (default 300).

        Returns:
            List of the stream ids that were removed.
        """
        current_time = time.time()
        cutoff_time = current_time - timeout

        with self.stream_lock:
            # NOTE(review): orphaned rows in `transfers` for deleted streams
            # are intentionally left in place (original behavior) — confirm
            # whether they should be purged too.
            results = self.con.execute("""
                DELETE FROM streams
                WHERE last_active < ?
                RETURNING stream_id
            """, [cutoff_time]).fetchall()

            cleaned_streams = [r[0] for r in results]

            for stream_id in cleaned_streams:
                logging.info(f"Cleaned up inactive stream {stream_id}")

            return cleaned_streams
|
|
|
|