import json
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import duckdb
import numpy as np
from huggingface_hub import HfApi, HfFileSystem

from config import get_hf_token_cached

from .memory import RegisterFile


@dataclass
class WarpBarrier:
    """Bookkeeping for one synchronization barrier shared by a set of warps.

    ``arrived`` counts the participants that have reached the barrier and
    ``completed`` becomes True once ``arrived`` reaches ``num_warps``.
    """

    # Identifier under which the barrier is stored in ``Warp.barriers``.
    barrier_id: str
    # Number of participants that must arrive before the barrier completes.
    num_warps: int
    arrived: int = 0
    completed: bool = False
    # A plain ``threading.Lock()`` / ``threading.Condition()`` default is
    # evaluated once, when the class is created, so every barrier would share
    # the same two objects. ``default_factory`` builds a fresh pair per
    # instance. ``compare=False`` keeps two barriers with equal counters
    # comparing equal even though their primitives are now distinct objects.
    lock: threading.Lock = field(default_factory=threading.Lock, compare=False)
    condition: threading.Condition = field(
        default_factory=threading.Condition, compare=False
    )


class ShuffleMode(Enum):
    """Lane-exchange patterns understood by ``Warp.shuffle``."""

    # Each lane reads from lane (i - offset), wrapped modulo the warp size.
    UP = "up"
    # Each lane reads from lane (i + offset), wrapped modulo the warp size.
    DOWN = "down"
    # Each lane reads from lane (i ^ offset).
    XOR = "xor"
    # Every lane reads from the single lane whose index is ``offset``.
    IDX = "idx"
    # The value held by lane ``offset`` is copied to every lane.
    BCAST = "bcast"


class VotingMode(Enum):
    """Reduction kinds understood by ``Warp.vote``."""

    # True when the predicate holds for every participating thread.
    ALL = "all"
    # True when the predicate holds for at least one participating thread.
    ANY = "any"
    # Integer bitmask with bit i set for each participating thread i whose
    # predicate holds.
    BALLOT = "ballot"
    # Number of participating threads whose predicate holds.
    COUNT = "count"


class WarpState(Enum):
    """Lifecycle states a ``Warp`` can be in."""

    # Eligible to be picked by ``WarpScheduler.schedule``.
    READY = "ready"
    # Set by ``WarpScheduler.execute_warps`` while the warp's kernel runs.
    RUNNING = "running"
    # Set by ``WarpScheduler.suspend_warp``.
    BLOCKED = "blocked"
    # Declared for callers; nothing in this module assigns it.
    YIELDED = "yielded"
    # The state ``WarpScheduler.schedule`` requires of every dependency.
    COMPLETED = "completed"


class Warp:
    """A group of at most 32 threads that execute together.

    Each warp owns one ``RegisterFile`` per thread, an active mask and a
    predicate mask (bit ``i`` refers to thread ``i``), a dictionary of named
    barriers, and a DuckDB connection in which it records itself when it is
    constructed.
    """

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, warp_id: int, num_threads: int = 32, db_url: Optional[str] = None):
        """Create the warp, connect to the database and register the warp.

        Args:
            warp_id: Identifier of the warp; stored as a string in the
                ``warps`` table.
            num_threads: Requested thread count; values above 32 are clamped
                to 32.
            db_url: Database location; falls back to ``DB_URL`` when falsy.

        Raises:
            RuntimeError: If the database connection cannot be established
                within ``max_retries`` attempts.
        """
        self.warp_id = warp_id
        self.num_threads = min(num_threads, 32)
        all_threads = (1 << self.num_threads) - 1
        self.active_mask = all_threads
        self.predicate_mask = all_threads
        self.registers = [RegisterFile() for _ in range(self.num_threads)]
        self.state = WarpState.READY

        # Execution bookkeeping lives here rather than in ``_register_warp``
        # so that registering a warp in the database never resets the program
        # counter, the barrier table or the lock.
        self.pc = 0
        self.barriers: Dict[str, WarpBarrier] = {}
        self.lock = threading.Lock()
        self.cycles_executed = 0
        self.last_active_time = time.time()

        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()
        self._register_warp()

    def _connect_with_retries(self):
        """Open ``self.conn``, retrying up to ``max_retries`` times.

        Sleeps one second between attempts.

        Raises:
            RuntimeError: When the final attempt fails; the underlying
                exception is chained as ``__cause__``.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Open a DuckDB connection for ``self.db_url`` and configure httpfs.

        Returns:
            The configured connection.

        Raises:
            ValueError: If ``db_url`` has fewer than five ``/``-separated
                parts.
        """
        # With maxsplit=4 the default URL "hf://datasets/<a>/<b>/<c>" splits
        # into "hf:", "", "datasets", "<a>", "<b>/<c>", so ``owner`` receives
        # the literal "datasets" segment. The resulting path is unchanged
        # from the previous behaviour.
        # NOTE(review): confirm the intended bucket layout before altering it.
        _, _, owner, dataset, db_file = self.db_url.split('/', 4)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # The class never defines HF_TOKEN, so reading ``self.HF_TOKEN``
        # unconditionally raised AttributeError on every attempt. Honour the
        # attribute when a caller or subclass supplies it, otherwise use the
        # helper this module imports.
        # NOTE(review): assumes get_hf_token_cached() takes no arguments and
        # returns the token - confirm against ``config``.
        token = getattr(self, "HF_TOKEN", None) or get_hf_token_cached()
        # The token is embedded in a SQL string literal; double any single
        # quote so it cannot terminate the literal early.
        token_literal = str(token).replace("'", "''")

        conn = duckdb.connect(db_path)
        try:
            for statement in (
                "INSTALL httpfs;",
                "LOAD httpfs;",
                "SET s3_endpoint='s3.us-east-1.amazonaws.com';",
                "SET s3_use_ssl=true;",
                "SET s3_url_style='path';",
                f"SET s3_access_key_id='{token_literal}';",
                f"SET s3_secret_access_key='{token_literal}';",
            ):
                conn.execute(statement)
        except Exception:
            # Do not leak a half-configured connection on every failed retry.
            conn.close()
            raise
        return conn

    def _setup_database(self):
        """Create the ``warps``, ``warp_barriers`` and ``warp_registers`` tables if missing."""
        schema = (
            """
            CREATE TABLE IF NOT EXISTS warps (
                warp_id VARCHAR PRIMARY KEY,
                num_threads INTEGER,
                active_mask BIGINT,
                predicate_mask BIGINT,
                state VARCHAR,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP,
                state_json JSON
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS warp_barriers (
                barrier_id VARCHAR PRIMARY KEY,
                num_warps INTEGER,
                arrived_count INTEGER DEFAULT 0,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS warp_registers (
                warp_id VARCHAR,
                thread_id INTEGER,
                register_id INTEGER,
                value BLOB,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (warp_id, thread_id, register_id)
            )
            """,
        )
        for ddl in schema:
            self.conn.execute(ddl)

    def _register_warp(self):
        """Insert this warp's row into the ``warps`` table."""
        self.conn.execute(
            """
            INSERT INTO warps (
                warp_id, num_threads, active_mask, predicate_mask,
                state, state_json
            ) VALUES (?, ?, ?, ?, ?, ?)
            """,
            [
                str(self.warp_id),
                self.num_threads,
                self.active_mask,
                self.predicate_mask,
                self.state.value,
                # Serialise explicitly instead of relying on how the driver
                # maps a Python dict onto a JSON column.
                json.dumps({"status": "initialized"}),
            ],
        )

    def get_active_threads(self) -> List[int]:
        """Return the indices of threads whose bit is set in ``active_mask``."""
        return [i for i in range(self.num_threads) if self.active_mask & (1 << i)]

    def get_predicated_threads(self) -> List[int]:
        """Return the indices of threads set in both the active and predicate masks."""
        enabled = self.active_mask & self.predicate_mask
        return [i for i in range(self.num_threads) if enabled & (1 << i)]

    def set_active_mask(self, mask: int):
        """Replace the active mask, keeping only the low ``num_threads`` bits."""
        with self.lock:
            self.active_mask = mask & ((1 << self.num_threads) - 1)

    def set_predicate_mask(self, mask: int):
        """Replace the predicate mask, keeping only the low ``num_threads`` bits."""
        with self.lock:
            self.predicate_mask = mask & ((1 << self.num_threads) - 1)

    def sync(self, barrier_id: Optional[str] = None):
        """Arrive at a barrier and block until it completes.

        Args:
            barrier_id: Name of a barrier in ``self.barriers``. An unknown
                name creates a one-participant barrier. When falsy, a unique
                one-participant barrier is created for this call only.
        """
        anonymous = not barrier_id
        if anonymous:
            barrier_id = f"warp_{self.warp_id}_barrier_{time.time_ns()}"

        # ``self.lock`` only guards the barrier table. It is released before
        # waiting so other users of this warp are not stalled by the barrier.
        with self.lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                barrier = WarpBarrier(barrier_id=barrier_id, num_warps=1)
                self.barriers[barrier_id] = barrier

        # The condition's own lock protects the counters. ``wait()`` releases
        # it while blocked, so later arrivals can still register; waiting
        # while holding a second, separate lock would deadlock them.
        with barrier.condition:
            barrier.arrived += 1
            if barrier.arrived >= barrier.num_warps:
                barrier.completed = True
                barrier.condition.notify_all()
            else:
                while not barrier.completed:
                    barrier.condition.wait()

        if anonymous:
            # An auto-named barrier can never be referenced again; drop it so
            # ``self.barriers`` does not grow on every anonymous sync.
            with self.lock:
                self.barriers.pop(barrier_id, None)

    def vote(self, predicate: List[bool], mode: VotingMode = VotingMode.ALL) -> Any:
        """Reduce per-thread predicates over the participating threads.

        Only threads enabled in both masks take part.

        Args:
            predicate: Sequence indexed by thread index.
            mode: Reduction to apply.

        Returns:
            A bool for ``ALL``/``ANY``, an int bitmask for ``BALLOT`` and an
            int count for ``COUNT``. With no participating threads the result
            is ``0`` for ``BALLOT``/``COUNT`` and ``False`` otherwise. An
            unrecognised mode yields ``None``.
        """
        voters = self.get_predicated_threads()
        if not voters:
            # COUNT used to return False here; 0 compares equal to it but
            # matches the integer type the mode otherwise produces.
            return 0 if mode in (VotingMode.BALLOT, VotingMode.COUNT) else False

        if mode == VotingMode.ALL:
            return all(predicate[i] for i in voters)
        if mode == VotingMode.ANY:
            return any(predicate[i] for i in voters)
        if mode == VotingMode.BALLOT:
            return sum(1 << i for i in voters if predicate[i])
        if mode == VotingMode.COUNT:
            return sum(1 for i in voters if predicate[i])
        return None

    def shuffle(self, var: List[Any], mode: ShuffleMode, offset: int) -> List[Any]:
        """Exchange per-thread values between participating threads.

        Only threads enabled in both masks take part; a lane whose source
        lane is not participating keeps its own value.

        Args:
            var: Sequence indexed by thread index.
            mode: Exchange pattern.
            offset: Lane delta (``UP``/``DOWN``), XOR operand (``XOR``) or
                source lane index (``IDX``/``BCAST``).

        Returns:
            A new list; ``var`` itself is not modified.
        """
        lanes = self.get_predicated_threads()
        # Set membership replaces a list scan per lane.
        participating = set(lanes)
        result = list(var)

        if mode in (ShuffleMode.IDX, ShuffleMode.BCAST):
            # Both modes copy the value of lane ``offset`` to every lane.
            # Every participating index is below num_threads, so membership
            # also covers the range check.
            if offset in participating:
                broadcast = var[offset]
                for lane in lanes:
                    result[lane] = broadcast
            return result

        if mode == ShuffleMode.UP:
            def source_of(lane):
                return (lane - offset) % self.num_threads
        elif mode == ShuffleMode.DOWN:
            def source_of(lane):
                return (lane + offset) % self.num_threads
        elif mode == ShuffleMode.XOR:
            def source_of(lane):
                return lane ^ offset
        else:
            return result

        for lane in lanes:
            src_idx = source_of(lane)
            if src_idx in participating:
                result[lane] = var[src_idx]
        return result

    def execute(self, func: callable, *args, **kwargs):
        """Call ``func`` once per active thread and collect the results.

        ``func`` receives a context dict with the keys ``thread_idx``,
        ``warp_id`` and ``registers``, followed by ``*args`` and ``**kwargs``.
        The predicate mask is not consulted here.

        Returns:
            The return values of ``func`` in ascending thread order.
        """
        results = []
        for thread_idx in self.get_active_threads():
            ctx = {
                'thread_idx': thread_idx,
                'warp_id': self.warp_id,
                'registers': self.registers[thread_idx],
            }
            results.append(func(ctx, *args, **kwargs))
        return results


class WarpScheduler:
    """Creates warps and selects which of them run, by priority and dependency."""

    def __init__(self, max_warps: int = 32, max_active_warps: int = 16):
        """
        Args:
            max_warps: Upper bound on the number of warps that can be created.
            max_active_warps: Upper bound on the warps returned by one
                ``schedule`` call.
        """
        self.max_warps = max_warps
        self.max_active_warps = max_active_warps
        self.warps: List["Warp"] = []
        # The dictionaries below are keyed by warp id, which is the warp's
        # index in ``self.warps``.
        self.active_warps: Dict[int, bool] = {}
        self.warp_priorities: Dict[int, int] = {}
        self.warp_dependencies: Dict[int, List[int]] = {}
        self.lock = threading.Lock()

    def _is_known(self, warp_id: int) -> bool:
        """Return True when ``warp_id`` indexes an existing warp."""
        return 0 <= warp_id < len(self.warps)

    def create_warp(self, num_threads: int = 32, priority: int = 0) -> "Warp":
        """Create a warp, mark it active and record its priority.

        Raises:
            RuntimeError: If ``max_warps`` warps already exist.
        """
        with self.lock:
            if len(self.warps) >= self.max_warps:
                raise RuntimeError("Maximum number of warps reached")

            warp_id = len(self.warps)
            warp = Warp(warp_id, num_threads)
            self.warps.append(warp)
            self.active_warps[warp_id] = True
            self.warp_priorities[warp_id] = priority
            self.warp_dependencies[warp_id] = []
            return warp

    def set_warp_priority(self, warp_id: int, priority: int):
        """Set the priority of a warp; unknown ids are ignored."""
        with self.lock:
            if self._is_known(warp_id):
                self.warp_priorities[warp_id] = priority

    def add_warp_dependency(self, warp_id: int, depends_on: int):
        """Require ``depends_on`` to be COMPLETED before ``warp_id`` is scheduled.

        Unknown ids are ignored. A self-dependency is ignored as well, since
        a warp waiting on itself could never be scheduled. A dependency that
        is already recorded is not added twice, so one
        ``remove_warp_dependency`` call removes it completely.
        """
        with self.lock:
            if not (self._is_known(warp_id) and self._is_known(depends_on)):
                return
            if depends_on == warp_id:
                return
            dependencies = self.warp_dependencies[warp_id]
            if depends_on not in dependencies:
                dependencies.append(depends_on)

    def remove_warp_dependency(self, warp_id: int, depends_on: int):
        """Drop a dependency; unknown ids and missing dependencies are ignored."""
        with self.lock:
            if self._is_known(warp_id):
                try:
                    self.warp_dependencies[warp_id].remove(depends_on)
                except ValueError:
                    # Nothing to remove: the dependency was never recorded.
                    pass

    def suspend_warp(self, warp_id: int):
        """Mark a warp inactive and BLOCKED; unknown ids are ignored."""
        with self.lock:
            if self._is_known(warp_id):
                self.active_warps[warp_id] = False
                self.warps[warp_id].state = WarpState.BLOCKED

    def resume_warp(self, warp_id: int):
        """Mark a warp active and READY; unknown ids are ignored."""
        with self.lock:
            if self._is_known(warp_id):
                self.active_warps[warp_id] = True
                self.warps[warp_id].state = WarpState.READY

    def synchronize_warps(self, warp_ids: List[int], barrier_id: Optional[str] = None):
        """Make a group of warps meet at one shared barrier.

        Args:
            warp_ids: Ids of the warps to synchronize. Unknown ids and
                repeated ids are skipped.
            barrier_id: Name for the barrier; generated when falsy.
        """
        if not barrier_id:
            barrier_id = f"barrier_{time.time_ns()}"

        # Size the barrier by the warps that will actually arrive. Counting
        # unknown or repeated ids would leave it waiting for arrivals that
        # can never happen.
        participants = []
        seen = set()
        for warp_id in warp_ids:
            if self._is_known(warp_id) and warp_id not in seen:
                seen.add(warp_id)
                participants.append(self.warps[warp_id])
        if not participants:
            return

        barrier = WarpBarrier(barrier_id=barrier_id, num_warps=len(participants))
        for warp in participants:
            warp.barriers[barrier_id] = barrier

        if len(participants) == 1:
            participants[0].sync(barrier_id)
            return

        # ``Warp.sync`` blocks until every participant has arrived, so calling
        # it for each warp in turn from this thread would stall on the first
        # one. Give every arrival its own thread and wait for all of them.
        arrivals = [
            threading.Thread(target=warp.sync, args=(barrier_id,))
            for warp in participants
        ]
        for arrival in arrivals:
            arrival.start()
        for arrival in arrivals:
            arrival.join()

    def schedule(self) -> List["Warp"]:
        """Pick the warps to run next.

        A warp qualifies when it is marked active, is READY and every warp it
        depends on is COMPLETED.

        Returns:
            At most ``max_active_warps`` warps, highest priority first; ties
            keep creation order.
        """
        with self.lock:
            ready_warps = []
            for warp_id, warp in enumerate(self.warps):
                if not self.active_warps.get(warp_id, False):
                    continue

                dependencies_met = all(
                    self.warps[dep_id].state == WarpState.COMPLETED
                    for dep_id in self.warp_dependencies.get(warp_id, [])
                )
                if dependencies_met and warp.state == WarpState.READY:
                    ready_warps.append((warp_id, self.warp_priorities.get(warp_id, 0)))

            # list.sort is stable, so equal priorities stay in id order.
            ready_warps.sort(key=lambda entry: entry[1], reverse=True)
            return [
                self.warps[warp_id]
                for warp_id, _ in ready_warps[:self.max_active_warps]
            ]

    def execute_warps(self, func: callable, *args, **kwargs):
        """Run ``func`` on every scheduled warp and concatenate the results.

        Each warp is RUNNING while its kernel executes and READY afterwards.

        Returns:
            The per-thread results of all scheduled warps, in schedule order.
        """
        results = []
        for warp in self.schedule():
            warp.state = WarpState.RUNNING
            try:
                results.extend(warp.execute(func, *args, **kwargs))
                warp.last_active_time = time.time()
                warp.cycles_executed += 1
            finally:
                # Restore READY even when the kernel raises; a warp left in
                # RUNNING would never be selected by ``schedule`` again.
                warp.state = WarpState.READY
        return results