# cross_gpu_stream.py
"""
Cross-GPU Stream Manager for coordinating operations across multiple GPUs
Uses DuckDB with HuggingFace for persistent storage
"""
from typing import Dict, List, Optional, Any
import threading
import time
import logging
import duckdb
import json
from config import get_hf_token_cached
# Initialize token from .env
class CrossGPUStreamManager:
    """Coordinate data-transfer "streams" between GPUs, persisting state in DuckDB.

    A stream links a source GPU to a target GPU and accumulates statistics
    (transfer count, total bytes). Individual transfers are tracked in a
    separate ``transfers`` table keyed by a per-stream sequence number.

    All timestamps are epoch seconds (floats from ``time.time()``); callers
    do float arithmetic on them (see :meth:`get_stream_stats`).

    Thread-safety: stream-table and transfer-table access is serialized with
    two locks. Whenever both locks are needed they are acquired in the fixed
    order ``stream_lock`` -> ``transfer_lock`` to avoid deadlock.
    """

    def __init__(self, db_path: str = "hf://datasets/Fred808/helium/storage.json"):
        """Open (or create) the backing DuckDB database.

        Args:
            db_path: DuckDB database location. Defaults to a HuggingFace
                dataset path accessed via DuckDB's httpfs extension.
        """
        self.stream_lock = threading.Lock()
        self.transfer_lock = threading.Lock()
        self.db_path = db_path
        # Fix: the original read self.HF_TOKEN, which was never assigned
        # (guaranteed AttributeError in _init_db). The imported helper was
        # never called; the stray "Initialize token from .env" comment shows
        # the intent was to fetch it here.
        self.hf_token = get_hf_token_cached()
        self.con = self._init_db()

    def _init_db(self) -> duckdb.DuckDBPyConnection:
        """Initialize the database connection, httpfs credentials, and schema."""
        con = duckdb.connect(self.db_path)
        # Configure HuggingFace access through DuckDB's S3-compatible httpfs.
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        con.execute("SET s3_endpoint='hf.co';")
        con.execute("SET s3_use_ssl=true;")
        con.execute("SET s3_url_style='path';")
        # NOTE(review): the token is interpolated into SQL; a token containing
        # a single quote would break the statement. Fine for well-formed HF
        # tokens, but worth hardening if the source is ever less controlled.
        con.execute(f"SET s3_access_key_id='{self.hf_token}';")
        con.execute(f"SET s3_secret_access_key='{self.hf_token}';")
        # Timestamps are stored as DOUBLE (epoch seconds): every writer binds
        # time.time() floats and every reader does float arithmetic on the
        # values. The previous TIMESTAMP columns did not match that usage.
        con.execute("""
            CREATE TABLE IF NOT EXISTS streams (
                stream_id VARCHAR PRIMARY KEY,
                source_gpu VARCHAR,
                target_gpu VARCHAR,
                state VARCHAR,
                created_at DOUBLE,
                last_active DOUBLE,
                transfer_count INTEGER,
                total_bytes_transferred BIGINT
            )
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS transfers (
                transfer_id VARCHAR PRIMARY KEY,
                stream_id VARCHAR,
                transfer_size BIGINT,
                started_at DOUBLE,
                completed_at DOUBLE,
                state VARCHAR
            )
        """)
        return con

    def create_stream(self, stream_id: str, source_gpu: str, target_gpu: str) -> Dict[str, Any]:
        """Create a new cross-GPU stream, or return the existing one.

        Idempotent: if ``stream_id`` already exists, its current row is
        returned unchanged (keyed with 'id', matching the create path).

        Args:
            stream_id: Unique identifier for the stream.
            source_gpu: Identifier of the sending GPU.
            target_gpu: Identifier of the receiving GPU.

        Returns:
            Dict with keys 'id', 'source_gpu', 'target_gpu', 'state',
            'created_at', 'last_active', 'transfer_count',
            'total_bytes_transferred'.
        """
        with self.stream_lock:
            row = self.con.execute(
                "SELECT * FROM streams WHERE stream_id = ?", [stream_id]
            ).fetchone()
            if row:
                # Column order matches the CREATE TABLE statement above.
                keys = ['id', 'source_gpu', 'target_gpu', 'state', 'created_at',
                        'last_active', 'transfer_count', 'total_bytes_transferred']
                return dict(zip(keys, row))
            now = time.time()
            self.con.execute("""
                INSERT INTO streams (
                    stream_id, source_gpu, target_gpu, state,
                    created_at, last_active, transfer_count, total_bytes_transferred
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [stream_id, source_gpu, target_gpu, 'initialized', now, now, 0, 0])
            logging.info(f"Created cross-GPU stream {stream_id} from {source_gpu} to {target_gpu}")
            return {
                'id': stream_id,
                'source_gpu': source_gpu,
                'target_gpu': target_gpu,
                'state': 'initialized',
                'created_at': now,
                'last_active': now,
                'transfer_count': 0,
                'total_bytes_transferred': 0
            }

    def start_transfer(self, stream_id: str, transfer_size: int) -> str:
        """Start a new data transfer on a stream.

        Args:
            stream_id: Stream to attach the transfer to.
            transfer_size: Size in bytes of the payload to move.

        Returns:
            The new transfer id ("transfer_<stream>_<seq>", where <seq> is
            the stream's transfer count at the time of the call).

        Raises:
            ValueError: If the stream does not exist.
        """
        # Lock order: stream_lock outer, transfer_lock inner (see class doc).
        with self.stream_lock:
            row = self.con.execute(
                "SELECT transfer_count FROM streams WHERE stream_id = ?", [stream_id]
            ).fetchone()
            if not row:
                raise ValueError(f"Stream {stream_id} does not exist")
            transfer_count = row[0]
            transfer_id = f"transfer_{stream_id}_{transfer_count}"
            now = time.time()
            with self.transfer_lock:
                self.con.execute("""
                    INSERT INTO transfers (
                        transfer_id, stream_id, transfer_size,
                        started_at, state
                    ) VALUES (?, ?, ?, ?, ?)
                """, [transfer_id, stream_id, transfer_size, now, 'in_progress'])
                # Bump the sequence counter and refresh activity in one pass.
                self.con.execute("""
                    UPDATE streams
                    SET transfer_count = transfer_count + 1,
                        last_active = ?
                    WHERE stream_id = ?
                """, [now, stream_id])
            logging.info(f"Started transfer {transfer_id} on stream {stream_id}")
            return transfer_id

    def complete_transfer(self, transfer_id: str) -> None:
        """Mark a transfer as complete and fold its size into stream statistics.

        Args:
            transfer_id: Id returned by :meth:`start_transfer`.

        Raises:
            ValueError: If the transfer is unknown or not 'in_progress'.
        """
        # Fix: the original acquired transfer_lock before stream_lock here,
        # while start_transfer acquires them in the opposite order — a
        # classic deadlock. Acquire in the canonical stream -> transfer order.
        with self.stream_lock:
            with self.transfer_lock:
                transfer = self.con.execute("""
                    SELECT stream_id, transfer_size FROM transfers
                    WHERE transfer_id = ? AND state = 'in_progress'
                """, [transfer_id]).fetchone()
                if not transfer:
                    raise ValueError(f"Transfer {transfer_id} not found or not in progress")
                stream_id, transfer_size = transfer
                now = time.time()
                self.con.execute("""
                    UPDATE transfers
                    SET state = 'completed',
                        completed_at = ?
                    WHERE transfer_id = ?
                """, [now, transfer_id])
                self.con.execute("""
                    UPDATE streams
                    SET total_bytes_transferred = total_bytes_transferred + ?,
                        last_active = ?
                    WHERE stream_id = ?
                """, [transfer_size, now, stream_id])
            logging.info(f"Completed transfer {transfer_id}")

    def get_stream_stats(self, stream_id: str) -> Dict[str, Any]:
        """Get current statistics for a stream.

        Returns:
            Dict with 'transfer_count', 'total_bytes_transferred',
            'uptime' (seconds since creation) and 'last_active_ago'
            (seconds since last activity).

        Raises:
            ValueError: If the stream does not exist.
        """
        with self.stream_lock:
            row = self.con.execute("""
                SELECT transfer_count, total_bytes_transferred,
                       created_at, last_active
                FROM streams WHERE stream_id = ?
            """, [stream_id]).fetchone()
            if not row:
                raise ValueError(f"Stream {stream_id} does not exist")
            transfer_count, total_bytes, created_at, last_active = row
            now = time.time()
            return {
                'transfer_count': transfer_count,
                'total_bytes_transferred': total_bytes,
                'uptime': now - created_at,
                'last_active_ago': now - last_active
            }

    def cleanup_inactive_streams(self, timeout: float = 300.0) -> List[str]:
        """Delete streams idle longer than ``timeout`` seconds.

        Args:
            timeout: Inactivity threshold in seconds (default 300).

        Returns:
            List of the deleted stream ids (empty if none qualified).
        """
        cutoff_time = time.time() - timeout
        with self.stream_lock:
            # DELETE ... RETURNING gives us the removed ids in one statement.
            results = self.con.execute("""
                DELETE FROM streams
                WHERE last_active < ?
                RETURNING stream_id
            """, [cutoff_time]).fetchall()
            cleaned_streams = [r[0] for r in results]
            for stream_id in cleaned_streams:
                logging.info(f"Cleaned up inactive stream {stream_id}")
            return cleaned_streams