from typing import List, Optional, Dict, Any
from dataclasses import dataclass
import time
from queue import Queue
from threading import RLock  # re-entrant: execute_next() calls record_event() while holding the lock
import logging
import json

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from config import get_db_url, get_hf_token_cached


@dataclass
class CrossGPUEvent:
    """Represents a cross-GPU synchronization event."""

    event_id: str
    timestamp: float
    source_gpu: int
    target_gpu: Optional[int]
    completed: bool = False
    transfer_size: Optional[int] = None
    nvlink_path: Optional[List[str]] = None


class CrossGPUStream:
    """Represents a stream that can execute operations across multiple GPUs."""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, stream_id: int, db_url: Optional[str] = None):
        self.stream_id = stream_id
        self.events: List[CrossGPUEvent] = []
        self.operation_queue: Queue = Queue()
        self.lock = RLock()
        self.is_active = True
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self.hf_token = get_hf_token_cached()
        self._connect_with_retries()
        # The DuckDB connection doubles as the hardware-abstraction-layer
        # (HAL) handle used by the _execute_* methods below.
        self.hal = self.conn
        self._setup_database()

    def _connect_with_retries(self):
        """Establish the database connection, retrying on transient failures."""
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {e}"
                    ) from e
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize the database connection with HuggingFace configuration."""
        # self.db_url looks like "hf://datasets/<owner>/<dataset>/<file>".
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        conn.execute(f"SET s3_access_key_id='{self.hf_token}';")
        conn.execute(f"SET s3_secret_access_key='{self.hf_token}';")
        return conn

    def _setup_database(self):
        """Create the tables used to track events, operations, and sync points."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_events (
                event_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                timestamp DOUBLE,
                source_gpu INTEGER,
                target_gpu INTEGER,
                completed BOOLEAN,
                transfer_size BIGINT,
                nvlink_path JSON,
                state_json JSON
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_operations (
                operation_id BIGINT PRIMARY KEY,
                stream_id BIGINT,
                operation_data JSON,
                status VARCHAR,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS sync_points (
                sync_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                gpu_ids VARCHAR,
                status VARCHAR DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

    def record_event(self, source_gpu: int, target_gpu: Optional[int] = None,
                     transfer_size: Optional[int] = None) -> CrossGPUEvent:
        """Record a cross-GPU event in the stream."""
        with self.lock:
            event_id = f"event_{self.stream_id}_{len(self.events)}_{time.time_ns()}"
            event = CrossGPUEvent(
                event_id=event_id,
                timestamp=time.time(),
                source_gpu=source_gpu,
                target_gpu=target_gpu,
                transfer_size=transfer_size
            )

            self.conn.execute("""
                INSERT INTO stream_events (
                    event_id, stream_id, timestamp, source_gpu, target_gpu,
                    completed, transfer_size, nvlink_path, state_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, [
                event_id, self.stream_id, event.timestamp, source_gpu,
                target_gpu, False, transfer_size, None,
                json.dumps({"status": "created"})
            ])

            self.events.append(event)
            return event

    def wait_event(self, event: CrossGPUEvent):
        """Block until a cross-GPU event is marked completed in the database."""
        while not event.completed:
            result = self.conn.execute("""
                SELECT completed
                FROM stream_events
                WHERE event_id = ?
            """, [event.event_id]).fetchall()

            if result and result[0][0]:
                event.completed = True
                break

            time.sleep(0.001)

    def synchronize(self):
        """Synchronize all operations in the stream."""
        with self.lock:
            for event in self.events:
                self.wait_event(event)
            self.events.clear()

            # Drop completed events now that they have been observed.
            self.conn.execute("""
                DELETE FROM stream_events
                WHERE stream_id = ? AND completed = true
            """, [self.stream_id])

    def add_operation(self, operation: Dict[str, Any]):
        """Add a cross-GPU operation to the stream."""
        with self.lock:
            self.conn.execute("""
                INSERT INTO stream_operations (
                    operation_id, stream_id, operation_data, status
                ) VALUES (?, ?, ?, ?)
            """, [time.time_ns(), self.stream_id, json.dumps(operation), "pending"])

            if operation.get('type') == 'transfer':
                src_gpu = operation['source_gpu']
                dst_gpu = operation['target_gpu']
                size = operation['size']

                # Find the highest-bandwidth route between the two GPUs.
                paths = self.conn.execute("""
                    WITH RECURSIVE
                    gpu_path(src, dst, path, total_bandwidth, hops) AS (
                        -- Direct connections
                        SELECT
                            chip_a_id,
                            chip_b_id,
                            CAST(chip_a_id AS VARCHAR) || ',' || CAST(chip_b_id AS VARCHAR),
                            bandwidth_tbps,
                            1
                        FROM optical_interconnects
                        WHERE (chip_a_id = ? AND chip_b_id = ?)
                           OR (chip_a_id = ? AND chip_b_id = ?)

                        UNION ALL

                        -- Multi-hop paths
                        SELECT
                            p.src,
                            i.chip_b_id,
                            p.path || ',' || CAST(i.chip_b_id AS VARCHAR),
                            LEAST(p.total_bandwidth, i.bandwidth_tbps),
                            p.hops + 1
                        FROM gpu_path p
                        JOIN optical_interconnects i ON p.dst = i.chip_a_id
                        WHERE p.hops < 3  -- Limit path length
                          AND NOT list_contains(string_split(p.path, ','), CAST(i.chip_b_id AS VARCHAR))
                    )
                    SELECT path, total_bandwidth
                    FROM gpu_path
                    WHERE src = ? AND dst = ?
                    ORDER BY total_bandwidth DESC, hops ASC
                    LIMIT 1
                """, (src_gpu, dst_gpu, dst_gpu, src_gpu, src_gpu, dst_gpu)).fetchone()

                if paths:
                    operation['nvlink_path'] = paths[0].split(',')
                    operation['expected_bandwidth'] = float(paths[1])

            self.operation_queue.put(operation)
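
    # NOTE: the routing query above reads from an `optical_interconnects`
    # table that this module never creates. A minimal sketch of the schema it
    # implies (column names come from the query; the types are assumptions):
    #
    #     CREATE TABLE IF NOT EXISTS optical_interconnects (
    #         chip_a_id INTEGER,
    #         chip_b_id INTEGER,
    #         bandwidth_tbps DOUBLE
    #     );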

    def execute_next(self) -> bool:
        """Execute the next operation in the queue."""
        try:
            with self.lock:
                if self.operation_queue.empty():
                    return False

                operation = self.operation_queue.get()
                event = self.record_event(
                    source_gpu=operation.get('source_gpu'),
                    target_gpu=operation.get('target_gpu'),
                    transfer_size=operation.get('size')
                )

                try:
                    if operation.get('type') == 'transfer':
                        self._execute_transfer(operation)
                    elif operation.get('type') == 'compute':
                        self._execute_compute(operation)
                    elif operation.get('type') == 'sync':
                        self._execute_sync(operation)

                    event.completed = True
                    if self.hal:
                        self.hal.execute("""
                            UPDATE stream_events
                            SET completed = TRUE,
                                nvlink_path = ?,
                                state_json = ?
                            WHERE event_id = ?
                        """, (
                            json.dumps(operation.get('nvlink_path', [])),
                            json.dumps({"status": "completed"}),
                            event.event_id
                        ))

                    return True

                except Exception as e:
                    logging.error(f"Error in stream {self.stream_id}: {e}")
                    if self.hal:
                        self.hal.execute("""
                            UPDATE stream_events
                            SET state_json = ?
                            WHERE event_id = ?
                        """, (
                            json.dumps({"status": "error", "error": str(e)}),
                            event.event_id
                        ))
                    return False

        except Exception as e:
            logging.error(f"Error in stream {self.stream_id}: {e}")
            return False

    def _execute_transfer(self, operation: Dict[str, Any]):
        """Execute a data transfer operation using an NVLink path."""
        if not self.hal:
            raise RuntimeError("No HAL connection available for transfer")

        src_gpu = operation['source_gpu']
        dst_gpu = operation['target_gpu']
        size = operation['size']
        nvlink_path = operation.get('nvlink_path', [])

        # Record the transfer; the memory_transfers table is expected to
        # exist in the shared store (it is not created by this module).
        self.hal.execute("""
            INSERT INTO memory_transfers (
                source_chip, target_chip, size_bytes, nvlink_path,
                start_time, bandwidth_used
            ) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, ?)
        """, (
            src_gpu, dst_gpu, size,
            ','.join(map(str, nvlink_path)),
            operation.get('expected_bandwidth', 0)
        ))

    def _execute_compute(self, operation: Dict[str, Any]):
        """Execute a compute operation on the specified GPU."""
        if not self.hal:
            raise RuntimeError("No HAL connection available for compute")

        gpu_id = operation['gpu_id']

        # Merge the operation into the chip's state blob.
        self.hal.execute("""
            UPDATE gpu_chips
            SET state_json = json_merge_patch(
                state_json,
                json_object(
                    'current_operation', json(?),
                    'last_updated', to_json(CURRENT_TIMESTAMP)
                )
            )
            WHERE chip_id = ?
        """, (json.dumps(operation), gpu_id))

    def _execute_sync(self, operation: Dict[str, Any]):
        """Execute a synchronization operation across a set of GPUs."""
        gpu_ids = operation.get('gpu_ids', [])
        if not gpu_ids:
            return

        sync_id = f"sync_{self.stream_id}_{time.time_ns()}"
        self.conn.execute("""
            INSERT INTO sync_points (
                sync_id, stream_id, gpu_ids
            ) VALUES (?, ?, ?)
        """, [sync_id, self.stream_id, ','.join(map(str, gpu_ids))])

        # Poll until every participating GPU reports this sync point.
        while True:
            result = self.conn.execute("""
                SELECT COUNT(*)
                FROM gpu_chips
                WHERE list_contains(string_split(?, ',')::INTEGER[], chip_id)
                  AND state_json->>'sync_id' = ?
            """, [','.join(map(str, gpu_ids)), sync_id]).fetchall()

            if result[0][0] == len(gpu_ids):
                self.conn.execute("""
                    UPDATE sync_points
                    SET status = 'completed',
                        completed_at = CURRENT_TIMESTAMP
                    WHERE sync_id = ?
                """, [sync_id])
                break

            time.sleep(0.001)
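
    # NOTE: the sync poll above reads a `gpu_chips` table maintained outside
    # this module; the columns it touches are chip_id (INTEGER, assumed) and
    # state_json (JSON), with state_json->'sync_id' set by the GPUs themselves.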


class CrossGPUStreamManager:
    """Manages multiple cross-GPU streams."""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, db_url: Optional[str] = None):
        self.streams: List[CrossGPUStream] = []
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self.hf_token = get_hf_token_cached()
        self._connect_with_retries()

        # Stream 0 is the default stream.
        self.default_stream = self.create_stream()

    def _connect_with_retries(self):
        """Establish the database connection, retrying on transient failures."""
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                self._setup_database()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {e}"
                    ) from e
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize the database connection with HuggingFace configuration."""
        # self.db_url looks like "hf://datasets/<owner>/<dataset>/<file>".
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        conn.execute(f"SET s3_access_key_id='{self.hf_token}';")
        conn.execute(f"SET s3_secret_access_key='{self.hf_token}';")
        return conn

    def _setup_database(self):
        """Initialize database tables."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS streams (
                stream_id BIGINT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                is_active BOOLEAN DEFAULT true,
                state_json JSON
            )
        """)

    def create_stream(self) -> CrossGPUStream:
        """Create a new cross-GPU stream."""
        stream_id = len(self.streams)

        self.conn.execute("""
            INSERT INTO streams (stream_id, state_json)
            VALUES (?, ?)
        """, [stream_id, json.dumps({"status": "created"})])

        stream = CrossGPUStream(stream_id, self.db_url)
        self.streams.append(stream)
        return stream

    def get_stream(self, stream_id: int) -> Optional[CrossGPUStream]:
        """Get a stream by its ID."""
        if 0 <= stream_id < len(self.streams):
            return self.streams[stream_id]
        return None

    def synchronize_all(self):
        """Synchronize all streams."""
        for stream in self.streams:
            stream.synchronize()

    def synchronize_stream(self, stream_id: int):
        """Synchronize a specific stream."""
        stream = self.get_stream(stream_id)
        if stream:
            stream.synchronize()

    def execute_streams(self):
        """Round-robin execute queued operations until every stream is drained."""
        while True:
            executed = False
            for stream in self.streams:
                if stream.execute_next():
                    executed = True
            if not executed:
                break

    def get_optimal_transfer_path(self, src_gpu: int, dst_gpu: int, size: int) -> List[str]:
        """Get the optimal NVLink path for a data transfer."""
        result = self.conn.execute("""
            WITH RECURSIVE
            gpu_path(src, dst, path, total_bandwidth, hops) AS (
                -- Direct connections
                SELECT
                    chip_a_id,
                    chip_b_id,
                    CAST(chip_a_id AS VARCHAR) || ',' || CAST(chip_b_id AS VARCHAR),
                    bandwidth_tbps,
                    1
                FROM optical_interconnects
                WHERE (chip_a_id = ? AND chip_b_id = ?)
                   OR (chip_a_id = ? AND chip_b_id = ?)

                UNION ALL

                -- Multi-hop paths
                SELECT
                    p.src,
                    i.chip_b_id,
                    p.path || ',' || CAST(i.chip_b_id AS VARCHAR),
                    LEAST(p.total_bandwidth, i.bandwidth_tbps),
                    p.hops + 1
                FROM gpu_path p
                JOIN optical_interconnects i ON p.dst = i.chip_a_id
                WHERE p.hops < 3  -- Limit path length
                  AND NOT list_contains(string_split(p.path, ','), CAST(i.chip_b_id AS VARCHAR))
            )
            SELECT path, total_bandwidth
            FROM gpu_path
            WHERE src = ? AND dst = ?
            ORDER BY total_bandwidth DESC, hops ASC
            LIMIT 1
        """, [src_gpu, dst_gpu, dst_gpu, src_gpu, src_gpu, dst_gpu]).fetchall()

        if result:
            return result[0][0].split(',')
        return []
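
    # For example, given links (0,1) and (1,3) in optical_interconnects, the
    # recursive query above would return the path '0,1,3', whose
    # total_bandwidth is the minimum bandwidth along the two hops
    # (illustrative values only).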

    def add_cross_gpu_operation(self, stream_id: int, operation: Dict[str, Any]):
        """Add a cross-GPU operation to a specific stream."""
        stream = self.get_stream(stream_id)
        if not stream:
            raise ValueError(f"Invalid stream ID: {stream_id}")

        # Pre-compute the NVLink route for transfers before queueing.
        if operation.get('type') == 'transfer':
            path = self.get_optimal_transfer_path(
                operation['source_gpu'],
                operation['target_gpu'],
                operation['size']
            )
            operation['nvlink_path'] = path

        stream.add_operation(operation)
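

# Minimal usage sketch. Assumes the HuggingFace-hosted store is reachable and
# that the optical_interconnects/gpu_chips/memory_transfers tables referenced
# above exist; the GPU ids and transfer size are illustrative.
if __name__ == "__main__":
    manager = CrossGPUStreamManager()
    stream = manager.create_stream()
    manager.add_cross_gpu_operation(stream.stream_id, {
        "type": "transfer",
        "source_gpu": 0,
        "target_gpu": 1,
        "size": 64 * 1024 * 1024,  # 64 MiB
    })
    manager.execute_streams()
    manager.synchronize_all()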