"""Warp-level execution model: barriers, voting, shuffles and a warp scheduler.

Each ``Warp`` owns up to 32 simulated threads ("lanes"), each with its own
``RegisterFile``, and records itself in a DuckDB database.  ``WarpScheduler``
creates warps and picks which ones run based on priority and dependencies.
"""

from typing import Dict, List, Tuple, Any, Optional, Callable
import numpy as np
import threading
import time
import json
from dataclasses import dataclass, field
from enum import Enum

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from .memory import RegisterFile
from config import get_hf_token_cached

# Initialize token from .env


@dataclass
class WarpBarrier:
    """Represents a synchronization barrier for warps.

    ``arrived`` counts the participants that have reached the barrier;
    ``completed`` becomes True once ``arrived`` reaches ``num_warps``.
    """

    barrier_id: str
    num_warps: int
    arrived: int = 0
    completed: bool = False
    # Built per instance: a plain ``= threading.Lock()`` default is evaluated
    # once at class-definition time and would be shared by every barrier.
    # ``compare=False`` keeps equality based on the data fields only.
    lock: threading.Lock = field(default_factory=threading.Lock, compare=False)
    condition: threading.Condition = field(
        default_factory=threading.Condition, compare=False
    )


class ShuffleMode(Enum):
    """Different modes for warp shuffle operations"""

    UP = "up"        # Shuffle up relative to caller
    DOWN = "down"    # Shuffle down relative to caller
    XOR = "xor"      # Butterfly shuffle pattern
    IDX = "idx"      # Direct index-based shuffle
    BCAST = "bcast"  # Broadcast from source lane


class VotingMode(Enum):
    """Different modes for warp voting operations"""

    ALL = "all"        # True if predicate is true for all active threads
    ANY = "any"        # True if predicate is true for any active thread
    BALLOT = "ballot"  # Returns bitmask of true predicates
    COUNT = "count"    # Returns count of true predicates


class WarpState(Enum):
    """Possible states for a warp"""

    READY = "ready"          # Ready to execute
    RUNNING = "running"      # Currently executing
    BLOCKED = "blocked"      # Waiting for synchronization
    YIELDED = "yielded"      # Voluntarily yielded execution
    COMPLETED = "completed"  # Finished execution


class Warp:
    """Represents a group of threads that execute together with advanced synchronization"""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, warp_id: int, num_threads: int = 32, db_url: Optional[str] = None):
        """Create the warp, connect to the database and register the warp in it.

        Args:
            warp_id: Identifier of this warp (stored as a string in the database).
            num_threads: Requested lane count; capped at 32.
            db_url: Database URL; defaults to ``Warp.DB_URL``.

        Raises:
            RuntimeError: If the database connection cannot be established.
        """
        self.warp_id = warp_id
        self.num_threads = min(num_threads, 32)  # Max 32 threads per warp
        self.active_mask = (1 << self.num_threads) - 1  # All threads active initially
        self.predicate_mask = (1 << self.num_threads) - 1  # For predicated execution
        self.registers = [RegisterFile() for _ in range(self.num_threads)]
        self.state = WarpState.READY

        # Runtime state is set up here (not in _register_warp) so it does not
        # depend on the database round-trip.
        self.pc = 0  # Program counter

        # Synchronization
        self.barriers: Dict[str, WarpBarrier] = {}
        self.lock = threading.Lock()

        # Performance tracking
        self.cycles_executed = 0
        self.last_active_time = time.time()

        # Initialize database connection
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()

        # Register warp in database
        self._register_warp()

    def _connect_with_retries(self):
        """Establish database connection with retry logic.

        Tries ``self.max_retries`` times, sleeping one second between attempts.

        Raises:
            RuntimeError: If every attempt fails; the last error is chained.
        """
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)

    @staticmethod
    def _resolve_db_path(db_url: str) -> str:
        """Translate ``scheme://[datasets/]owner/dataset/file`` into the S3 cache path.

        Raises:
            ValueError: If the URL has no scheme or fewer than three path parts.
        """
        _, sep, remainder = db_url.partition("://")
        if not sep:
            raise ValueError(f"Unsupported database URL: {db_url!r}")
        # The repo-type segment is not part of owner/dataset/file.
        repo_prefix = "datasets/"
        if remainder.startswith(repo_prefix):
            remainder = remainder[len(repo_prefix):]
        parts = remainder.split("/", 2)
        if len(parts) != 3 or not all(parts):
            raise ValueError(
                f"Expected '<scheme>://datasets/<owner>/<dataset>/<file>', got {db_url!r}"
            )
        owner, dataset, db_file = parts
        return f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection with HuggingFace configuration.

        Returns:
            An open DuckDB connection with the httpfs extension loaded and the
            S3 settings applied.
        """
        # Convert HF URL to S3 path
        db_path = self._resolve_db_path(self.db_url)

        # An explicitly provided HF_TOKEN attribute wins; otherwise use the
        # imported helper.
        # NOTE(review): assumes get_hf_token_cached() takes no arguments and
        # returns the token string - confirm against the config module.
        token = getattr(self, "HF_TOKEN", None) or get_hf_token_cached()
        # Escape single quotes because SET statements are built as SQL text.
        quoted_token = str(token).replace("'", "''")

        # Connect to remote database
        # NOTE(review): httpfs is installed/loaded only after connecting to an
        # s3:// path - verify this ordering works with the deployed DuckDB.
        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        conn.execute(f"SET s3_access_key_id='{quoted_token}';")
        conn.execute(f"SET s3_secret_access_key='{quoted_token}';")
        return conn

    def _setup_database(self):
        """Initialize database tables (warps, warp_barriers, warp_registers)."""
        # Warp state table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS warps (
                warp_id VARCHAR PRIMARY KEY,
                num_threads INTEGER,
                active_mask BIGINT,
                predicate_mask BIGINT,
                state VARCHAR,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP,
                state_json JSON
            )
        """)

        # Barrier table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS warp_barriers (
                barrier_id VARCHAR PRIMARY KEY,
                num_warps INTEGER,
                arrived_count INTEGER DEFAULT 0,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
        """)

        # Register table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS warp_registers (
                warp_id VARCHAR,
                thread_id INTEGER,
                register_id INTEGER,
                value BLOB,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (warp_id, thread_id, register_id)
            )
        """)

    def _register_warp(self):
        """Register warp in database by inserting its row into ``warps``."""
        self.conn.execute("""
            INSERT INTO warps (
                warp_id, num_threads, active_mask, predicate_mask,
                state, state_json
            ) VALUES (?, ?, ?, ?, ?, ?)
        """, [
            str(self.warp_id),
            self.num_threads,
            self.active_mask,
            self.predicate_mask,
            self.state.value,
            # Serialise explicitly: the column is JSON, not a struct/map.
            json.dumps({"status": "initialized"}),
        ])

    def get_active_threads(self) -> List[int]:
        """Get indices of currently active threads"""
        return [i for i in range(self.num_threads) if self.active_mask & (1 << i)]

    def get_predicated_threads(self) -> List[int]:
        """Get indices of threads that pass predication (active and predicated)"""
        enabled = self.active_mask & self.predicate_mask
        return [i for i in range(self.num_threads) if enabled & (1 << i)]

    def set_active_mask(self, mask: int):
        """Set which threads are active (bits beyond ``num_threads`` are dropped)"""
        with self.lock:
            self.active_mask = mask & ((1 << self.num_threads) - 1)

    def set_predicate_mask(self, mask: int):
        """Set predication mask for conditional execution (bits beyond ``num_threads`` are dropped)"""
        with self.lock:
            self.predicate_mask = mask & ((1 << self.num_threads) - 1)

    def sync(self, barrier_id: Optional[str] = None, blocking: bool = True):
        """Arrive at a barrier and, by default, wait until it completes.

        Args:
            barrier_id: Barrier to arrive at.  When omitted, a private
                single-participant barrier is created, completed and dropped.
            blocking: When True (default) the call waits until the barrier is
                completed.  When False the arrival is only recorded, which is
                what a single-threaded driver needs.
        """
        ephemeral = not barrier_id
        if ephemeral:
            barrier_id = f"warp_{self.warp_id}_barrier_{time.time_ns()}"

        with self.lock:
            barrier = self.barriers.get(barrier_id)
            if barrier is None:
                barrier = WarpBarrier(barrier_id=barrier_id, num_warps=1)
                self.barriers[barrier_id] = barrier

        # The condition's own lock guards the counters.  Waiting on it releases
        # that lock, so later arrivals can still make progress.
        with barrier.condition:
            barrier.arrived += 1
            if barrier.arrived >= barrier.num_warps:
                barrier.completed = True
                barrier.condition.notify_all()
            elif blocking:
                barrier.condition.wait_for(lambda: barrier.completed)

        if ephemeral:
            # Nobody else knows the generated id; keeping it would only leak.
            with self.lock:
                self.barriers.pop(barrier_id, None)

    def vote(self, predicate: List[bool], mode: VotingMode = VotingMode.ALL) -> Any:
        """Perform voting operation across threads.

        Only lanes that are both active and predicated take part.

        Args:
            predicate: Per-lane booleans, indexed by lane id.
            mode: Which reduction to perform.

        Returns:
            ``bool`` for ALL/ANY, an ``int`` bitmask for BALLOT, an ``int``
            count for COUNT.  With no participating lanes the result is
            ``False`` (ALL/ANY) or ``0`` (BALLOT/COUNT).
        """
        active_threads = self.get_predicated_threads()
        if not active_threads:
            return 0 if mode in (VotingMode.BALLOT, VotingMode.COUNT) else False

        if mode == VotingMode.ALL:
            return all(predicate[i] for i in active_threads)
        elif mode == VotingMode.ANY:
            return any(predicate[i] for i in active_threads)
        elif mode == VotingMode.BALLOT:
            return sum(1 << i for i in active_threads if predicate[i])
        elif mode == VotingMode.COUNT:
            return sum(1 for i in active_threads if predicate[i])

    def _source_lane(self, mode: ShuffleMode, lane: int, offset: int) -> Optional[int]:
        """Return the lane that ``lane`` reads from, or None for an unknown mode."""
        if mode == ShuffleMode.UP:
            return (lane - offset) % self.num_threads  # wraps around the warp
        if mode == ShuffleMode.DOWN:
            return (lane + offset) % self.num_threads  # wraps around the warp
        if mode == ShuffleMode.XOR:
            return lane ^ offset
        if mode in (ShuffleMode.IDX, ShuffleMode.BCAST):
            return offset  # every lane reads the same source lane
        return None

    def shuffle(self, var: List[Any], mode: ShuffleMode, offset: int) -> List[Any]:
        """Exchange variables between threads using different shuffle patterns.

        Args:
            var: Per-lane values, indexed by lane id.
            mode: Shuffle pattern.
            offset: Lane delta (UP/DOWN), XOR mask (XOR) or source lane
                (IDX/BCAST).

        Returns:
            A copy of ``var`` in which each participating lane holds the value
            of its source lane.  A lane keeps its own value when the source
            lane is not participating.
        """
        active_threads = self.get_predicated_threads()
        participating = set(active_threads)  # O(1) membership inside the loop
        result = list(var)  # Create copy to store results

        for lane in active_threads:
            src_idx = self._source_lane(mode, lane, offset)
            # Membership also rejects out-of-range and negative sources.
            if src_idx in participating:
                result[lane] = var[src_idx]

        return result

    def execute(self, func: Callable[..., Any], *args, **kwargs):
        """Execute a function across all active threads.

        ``func`` is called once per active lane as ``func(ctx, *args, **kwargs)``
        where ``ctx`` has the keys ``thread_idx``, ``warp_id`` and ``registers``.

        Returns:
            The list of per-lane results, in lane order.
        """
        results = []
        for thread_idx in self.get_active_threads():
            # Set up thread context
            ctx = {
                'thread_idx': thread_idx,
                'warp_id': self.warp_id,
                'registers': self.registers[thread_idx],
            }
            # Execute thread
            results.append(func(ctx, *args, **kwargs))
        return results


class WarpScheduler:
    """Advanced warp scheduler with priority and dependency handling"""

    def __init__(self, max_warps: int = 32, max_active_warps: int = 16):
        """Create an empty scheduler.

        Args:
            max_warps: Upper bound on warps that may be created.
            max_active_warps: Upper bound on warps returned by ``schedule``.
        """
        self.max_warps = max_warps
        self.max_active_warps = max_active_warps
        self.warps: List[Warp] = []
        self.active_warps: Dict[int, bool] = {}
        self.warp_priorities: Dict[int, int] = {}
        self.warp_dependencies: Dict[int, List[int]] = {}
        self.lock = threading.Lock()

    def _is_valid_warp_id(self, warp_id: int) -> bool:
        """Return True when ``warp_id`` indexes an existing warp."""
        return 0 <= warp_id < len(self.warps)

    def create_warp(self, num_threads: int = 32, priority: int = 0) -> Warp:
        """Create a new warp with specified priority.

        Raises:
            RuntimeError: If ``max_warps`` warps already exist.
        """
        with self.lock:
            if len(self.warps) >= self.max_warps:
                raise RuntimeError("Maximum number of warps reached")

            warp_id = len(self.warps)
            warp = Warp(warp_id, num_threads)
            self.warps.append(warp)
            self.active_warps[warp_id] = True
            self.warp_priorities[warp_id] = priority
            self.warp_dependencies[warp_id] = []
            return warp

    def set_warp_priority(self, warp_id: int, priority: int):
        """Set execution priority for a warp (unknown ids are ignored)"""
        with self.lock:
            if self._is_valid_warp_id(warp_id):
                self.warp_priorities[warp_id] = priority

    def add_warp_dependency(self, warp_id: int, depends_on: int):
        """Add execution dependency between warps (unknown ids are ignored)"""
        with self.lock:
            if self._is_valid_warp_id(warp_id) and self._is_valid_warp_id(depends_on):
                self.warp_dependencies[warp_id].append(depends_on)

    def remove_warp_dependency(self, warp_id: int, depends_on: int):
        """Remove execution dependency between warps (missing entries are ignored)"""
        with self.lock:
            if self._is_valid_warp_id(warp_id):
                try:
                    self.warp_dependencies[warp_id].remove(depends_on)
                except ValueError:
                    # Removing a dependency that was never added is a no-op.
                    pass

    def suspend_warp(self, warp_id: int):
        """Suspend a warp from execution"""
        with self.lock:
            if self._is_valid_warp_id(warp_id):
                self.active_warps[warp_id] = False
                self.warps[warp_id].state = WarpState.BLOCKED

    def resume_warp(self, warp_id: int):
        """Resume a suspended warp"""
        with self.lock:
            if self._is_valid_warp_id(warp_id):
                self.active_warps[warp_id] = True
                self.warps[warp_id].state = WarpState.READY

    def synchronize_warps(self, warp_ids: List[int], barrier_id: Optional[str] = None):
        """Synchronize a group of warps on one shared barrier.

        Unknown warp ids are ignored and do not count towards the barrier size.
        Arrivals are recorded from the calling thread, one warp after another,
        so they are registered without blocking.

        Args:
            warp_ids: Ids of the warps that take part.
            barrier_id: Barrier name; generated when omitted.
        """
        if not barrier_id:
            barrier_id = f"barrier_{time.time_ns()}"

        participants = [
            self.warps[warp_id] for warp_id in warp_ids
            if self._is_valid_warp_id(warp_id)
        ]
        if not participants:
            return

        # Create barrier sized to the warps that will actually arrive
        barrier = WarpBarrier(barrier_id=barrier_id, num_warps=len(participants))

        # Register barrier with each warp
        for warp in participants:
            warp.barriers[barrier_id] = barrier

        # Record every arrival; the last one completes the barrier
        for warp in participants:
            warp.sync(barrier_id, blocking=False)

    def schedule(self) -> List[Warp]:
        """Schedule warps for execution based on priority and dependencies.

        A warp is eligible when it is not suspended, is in the READY state and
        every warp it depends on is COMPLETED.

        Returns:
            At most ``max_active_warps`` eligible warps, highest priority first.
        """
        with self.lock:
            ready_warps = []

            # Get warps that are ready to execute
            for warp_id, warp in enumerate(self.warps):
                if not self.active_warps.get(warp_id, False):
                    continue

                # Check dependencies
                dependencies_met = all(
                    self.warps[dep_id].state == WarpState.COMPLETED
                    for dep_id in self.warp_dependencies.get(warp_id, [])
                )

                if dependencies_met and warp.state == WarpState.READY:
                    ready_warps.append((warp_id, self.warp_priorities.get(warp_id, 0)))

            # Sort by priority (higher numbers = higher priority)
            ready_warps.sort(key=lambda x: x[1], reverse=True)

            # Return warps up to max_active_warps
            return [self.warps[warp_id] for warp_id, _ in ready_warps[:self.max_active_warps]]

    def execute_warps(self, func: Callable[..., Any], *args, **kwargs):
        """Execute function across all active warps with scheduling.

        Returns:
            The concatenated per-lane results of every scheduled warp.

        Any exception raised by ``func`` propagates, but the warp is returned
        to READY first so it is not left stuck in RUNNING.
        """
        results = []
        for warp in self.schedule():
            warp.state = WarpState.RUNNING
            try:
                results.extend(warp.execute(func, *args, **kwargs))
                warp.last_active_time = time.time()
                warp.cycles_executed += 1
            finally:
                warp.state = WarpState.READY
        return results