# virtual_gpu_driver/src/cross_gpu_stream.py
# (Hugging Face upload metadata: uploaded by Fred808, commit 7a0c684, "Upload 256 files")
import json
import logging
import time
from dataclasses import dataclass
from queue import Queue
from threading import Lock, RLock
from typing import Any, Dict, List, Optional

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from config import get_db_url, get_hf_token_cached
@dataclass
class CrossGPUEvent:
"""Represents a cross-GPU synchronization event"""
event_id: str
timestamp: float
source_gpu: int
target_gpu: Optional[int]
completed: bool = False
transfer_size: Optional[int] = None
nvlink_path: Optional[List[str]] = None
class CrossGPUStream:
"""Represents a stream that can execute operations across multiple GPUs"""
DB_URL = "hf://datasets/Fred808/helium/storage.json"
def __init__(self, stream_id: int, db_url: Optional[str] = None):
self.stream_id = stream_id
self.events: List[CrossGPUEvent] = []
self.operation_queue: Queue = Queue()
self.lock = Lock()
self.is_active = True
self.db_url = db_url or self.DB_URL
self.max_retries = 3
self._connect_with_retries()
self._setup_database()
def _connect_with_retries(self):
"""Establish database connection with retry logic"""
for attempt in range(self.max_retries):
try:
self.conn = self._init_db_connection()
return
except Exception as e:
if attempt == self.max_retries - 1:
raise RuntimeError(f"Failed to initialize database after {self.max_retries} attempts: {str(e)}")
time.sleep(1)
def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
"""Initialize database connection with HuggingFace configuration"""
# Convert HF URL to S3 path
_, _, owner, dataset, db_file = self.db_url.split('/', 4)
db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"
# Connect to remote database
conn = duckdb.connect(db_path)
conn.execute("INSTALL httpfs;")
conn.execute("LOAD httpfs;")
conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
conn.execute("SET s3_use_ssl=true;")
conn.execute("SET s3_url_style='path';")
conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';")
conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';")
return conn
def _setup_database(self):
"""Initialize database tables"""
# Events table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS stream_events (
event_id VARCHAR PRIMARY KEY,
stream_id BIGINT,
timestamp DOUBLE,
source_gpu INTEGER,
target_gpu INTEGER,
completed BOOLEAN,
transfer_size BIGINT,
nvlink_path JSON,
state_json JSON
)
""")
# Operations table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS stream_operations (
operation_id BIGINT PRIMARY KEY,
stream_id BIGINT,
operation_data JSON,
status VARCHAR,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
completed_at TIMESTAMP,
error_message VARCHAR
)
""")
# Sync points table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS sync_points (
sync_id VARCHAR PRIMARY KEY,
stream_id BIGINT,
gpu_ids VARCHAR,
status VARCHAR DEFAULT 'pending',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
completed_at TIMESTAMP,
error_message VARCHAR
)
""")
def record_event(self, source_gpu: int, target_gpu: Optional[int] = None,
transfer_size: Optional[int] = None) -> CrossGPUEvent:
"""Record a cross-GPU event in the stream"""
with self.lock:
event_id = f"event_{self.stream_id}_{len(self.events)}_{time.time_ns()}"
event = CrossGPUEvent(
event_id=event_id,
timestamp=time.time(),
source_gpu=source_gpu,
target_gpu=target_gpu,
transfer_size=transfer_size
)
# Store event in database
self.conn.execute("""
INSERT INTO stream_events (
event_id, stream_id, timestamp, source_gpu, target_gpu,
completed, transfer_size, nvlink_path, state_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", [
event_id, self.stream_id, event.timestamp, source_gpu,
target_gpu, False, transfer_size, None,
{"status": "created"}
])
self.events.append(event)
return event
def wait_event(self, event: CrossGPUEvent):
"""Wait for a cross-GPU event to complete"""
while True:
# Check database for event completion
result = self.conn.execute("""
SELECT completed
FROM stream_events
WHERE event_id = ?
""", [event.event_id]).fetchall()
if result and result[0][0]:
event.completed = True
break
if event.completed:
break
time.sleep(0.001) # Small sleep to prevent busy waiting
def synchronize(self):
"""Synchronize all operations in the stream"""
with self.lock:
for event in self.events:
self.wait_event(event)
self.events.clear()
# Clear completed events from database
self.conn.execute("""
DELETE FROM stream_events
WHERE stream_id = ? AND completed = true
""", [self.stream_id])
def add_operation(self, operation: Dict[str, Any]):
"""Add a cross-GPU operation to the stream"""
with self.lock:
# Record operation in database
self.conn.execute("""
INSERT INTO stream_operations (
stream_id, operation_data, status
) VALUES (?, ?, ?)
""", [self.stream_id, operation, "pending"])
# If operation involves data transfer, find optimal path
if operation.get('type') == 'transfer':
src_gpu = operation['source_gpu']
dst_gpu = operation['target_gpu']
size = operation['size']
# Query for best transfer path
paths = self.conn.execute("""
WITH RECURSIVE
gpu_path(src, dst, path, total_bandwidth, hops) AS (
-- Direct connections
SELECT
chip_a_id,
chip_b_id,
chip_a_id || ',' || chip_b_id,
bandwidth_tbps,
1
FROM optical_interconnects
WHERE (chip_a_id = ? AND chip_b_id = ?)
OR (chip_a_id = ? AND chip_b_id = ?)
UNION ALL
-- Multi-hop paths
SELECT
p.src,
i.chip_b_id,
p.path || ',' || i.chip_b_id,
LEAST(p.total_bandwidth, i.bandwidth_tbps),
p.hops + 1
FROM gpu_path p
JOIN optical_interconnects i ON p.dst = i.chip_a_id
WHERE p.hops < 3 -- Limit path length
AND i.chip_b_id != ALL(string_to_array(p.path, ',')::integer[])
)
SELECT path, total_bandwidth
FROM gpu_path
WHERE src = ? AND dst = ?
ORDER BY total_bandwidth DESC, hops ASC
LIMIT 1
""", (src_gpu, dst_gpu, dst_gpu, src_gpu, src_gpu, dst_gpu)).fetchone()
if paths:
operation['nvlink_path'] = paths[0].split(',')
operation['expected_bandwidth'] = float(paths[1])
self.operation_queue.put(operation)
def execute_next(self) -> bool:
"""Execute the next operation in the queue"""
try:
with self.lock:
if self.operation_queue.empty():
return False
operation = self.operation_queue.get()
event = self.record_event(
source_gpu=operation.get('source_gpu'),
target_gpu=operation.get('target_gpu'),
transfer_size=operation.get('size')
)
try:
if operation.get('type') == 'transfer':
self._execute_transfer(operation)
elif operation.get('type') == 'compute':
self._execute_compute(operation)
elif operation.get('type') == 'sync':
self._execute_sync(operation)
# Mark event as completed
event.completed = True
if self.hal:
self.hal.execute("""
UPDATE stream_events
SET completed = TRUE,
nvlink_path = ?,
state_json = ?
WHERE event_id = ?
""", (
','.join(operation.get('nvlink_path', [])),
json.dumps({"status": "completed"}),
event.event_id
))
return True
except Exception as e:
logging.error(f"Error in stream {self.stream_id}: {str(e)}")
if self.hal:
self.hal.execute("""
UPDATE stream_events
SET state_json = ?
WHERE event_id = ?
""", (
json.dumps({"status": "error", "error": str(e)}),
event.event_id
))
return False
except Exception as e:
logging.error(f"Error in stream {self.stream_id}: {str(e)}")
return False
def _execute_transfer(self, operation: Dict[str, Any]):
"""Execute a data transfer operation using NVLink path"""
if not self.hal:
raise RuntimeError("No HAL connection available for transfer")
src_gpu = operation['source_gpu']
dst_gpu = operation['target_gpu']
size = operation['size']
nvlink_path = operation.get('nvlink_path', [])
# Record transfer in HAL database
self.hal.execute("""
INSERT INTO memory_transfers (
source_chip, target_chip, size_bytes, nvlink_path,
start_time, bandwidth_used
) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, ?)
""", (
src_gpu, dst_gpu, size,
','.join(map(str, nvlink_path)),
operation.get('expected_bandwidth', 0)
))
def _execute_compute(self, operation: Dict[str, Any]):
"""Execute a compute operation on specified GPU"""
if not self.hal:
raise RuntimeError("No HAL connection available for compute")
gpu_id = operation['gpu_id']
# Update GPU state in HAL
self.hal.execute("""
UPDATE gpu_chips
SET state_json = json_patch(
state_json,
json_object(
'current_operation', json(?),
'last_updated', json(CURRENT_TIMESTAMP)
)
)
WHERE chip_id = ?
""", (json.dumps(operation), gpu_id))
def _execute_sync(self, operation: Dict[str, Any]):
"""Execute a synchronization operation"""
gpu_ids = operation.get('gpu_ids', [])
if not gpu_ids:
return
# Create sync point in database
sync_id = f"sync_{self.stream_id}_{time.time_ns()}"
self.conn.execute("""
INSERT INTO sync_points (
sync_id, stream_id, gpu_ids
) VALUES (?, ?, ?)
""", [sync_id, self.stream_id, ','.join(map(str, gpu_ids))])
# Wait for all GPUs to reach sync point
while True:
result = self.conn.execute("""
SELECT COUNT(*)
FROM gpu_chips
WHERE chip_id = ANY(string_split(?, ',')::INTEGER[])
AND state_json->>'sync_id' = ?
""", [','.join(map(str, gpu_ids)), sync_id]).fetchall()
if result[0][0] == len(gpu_ids):
# All GPUs reached sync point
self.conn.execute("""
UPDATE sync_points
SET status = 'completed',
completed_at = CURRENT_TIMESTAMP
WHERE sync_id = ?
""", [sync_id])
break
time.sleep(0.001)
class CrossGPUStreamManager:
"""Manages multiple cross-GPU streams"""
DB_URL = "hf://datasets/Fred808/helium/storage.json"
def __init__(self, db_url: Optional[str] = None):
self.streams: List[CrossGPUStream] = []
self.db_url = db_url or self.DB_URL
self.max_retries = 3
self._connect_with_retries()
# Create default stream
self.default_stream = self.create_stream()
def _connect_with_retries(self):
"""Establish database connection with retry logic"""
for attempt in range(self.max_retries):
try:
self.conn = self._init_db_connection()
self._setup_database()
return
except Exception as e:
if attempt == self.max_retries - 1:
raise RuntimeError(f"Failed to initialize database after {self.max_retries} attempts: {str(e)}")
time.sleep(1)
def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
"""Initialize database connection with HuggingFace configuration"""
# Convert HF URL to S3 path
_, _, owner, dataset, db_file = self.db_url.split('/', 4)
db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"
# Connect to remote database
conn = duckdb.connect(db_path)
conn.execute("INSTALL httpfs;")
conn.execute("LOAD httpfs;")
conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
conn.execute("SET s3_use_ssl=true;")
conn.execute("SET s3_url_style='path';")
conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';")
conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';")
return conn
def _setup_database(self):
"""Initialize database tables"""
self.conn.execute("""
CREATE TABLE IF NOT EXISTS streams (
stream_id BIGINT PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT true,
state_json JSON
)
""")
def create_stream(self) -> CrossGPUStream:
"""Create a new cross-GPU stream"""
stream_id = len(self.streams)
# Register stream in database
self.conn.execute("""
INSERT INTO streams (stream_id, state_json)
VALUES (?, ?)
""", [stream_id, {"status": "created"}])
# Create stream object with same DB connection
stream = CrossGPUStream(stream_id, self.db_url)
self.streams.append(stream)
return stream
def get_stream(self, stream_id: int) -> Optional[CrossGPUStream]:
"""Get a stream by its ID"""
if 0 <= stream_id < len(self.streams):
return self.streams[stream_id]
return None
def synchronize_all(self):
"""Synchronize all streams"""
for stream in self.streams:
stream.synchronize()
def synchronize_stream(self, stream_id: int):
"""Synchronize a specific stream"""
stream = self.get_stream(stream_id)
if stream:
stream.synchronize()
def execute_streams(self):
"""Execute operations in all streams"""
while True:
executed = False
for stream in self.streams:
if stream.execute_next():
executed = True
if not executed:
break
def get_optimal_transfer_path(self, src_gpu: int, dst_gpu: int, size: int) -> List[str]:
"""Get optimal NVLink path for data transfer"""
result = self.conn.execute("""
WITH RECURSIVE
gpu_path(src, dst, path, total_bandwidth, hops) AS (
-- Direct connections
SELECT
chip_a_id,
chip_b_id,
CAST(chip_a_id AS VARCHAR) || ',' || CAST(chip_b_id AS VARCHAR),
bandwidth_tbps,
1
FROM optical_interconnects
WHERE (chip_a_id = ? AND chip_b_id = ?)
OR (chip_a_id = ? AND chip_b_id = ?)
UNION ALL
-- Multi-hop paths
SELECT
p.src,
i.chip_b_id,
p.path || ',' || CAST(i.chip_b_id AS VARCHAR),
LEAST(p.total_bandwidth, i.bandwidth_tbps),
p.hops + 1
FROM gpu_path p
JOIN optical_interconnects i ON p.dst = i.chip_a_id
WHERE p.hops < 3 -- Limit path length
AND NOT list_contains(string_split(p.path, ','), CAST(i.chip_b_id AS VARCHAR))
)
SELECT path, total_bandwidth
FROM gpu_path
WHERE src = ? AND dst = ?
ORDER BY total_bandwidth DESC, hops ASC
LIMIT 1
""", [src_gpu, dst_gpu, dst_gpu, src_gpu, src_gpu, dst_gpu]).fetchall()
if result:
return result[0][0].split(',')
return []
def add_cross_gpu_operation(self, stream_id: int, operation: Dict[str, Any]):
"""Add a cross-GPU operation to a specific stream"""
stream = self.get_stream(stream_id)
if not stream:
raise ValueError(f"Invalid stream ID: {stream_id}")
# If it's a transfer operation, find optimal path
if operation.get('type') == 'transfer':
path = self.get_optimal_transfer_path(
operation['source_gpu'],
operation['target_gpu'],
operation['size']
)
operation['nvlink_path'] = path
stream.add_operation(operation)