from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
import time
from datetime import datetime

from .remote_storage import RemoteStorageManager
from .ftl import AdvancedFTL


@dataclass
class QoSParameters:
    """Quality of Service parameters"""
    priority: int            # 0-7, higher is more important
    bandwidth_min: float     # Minimum guaranteed bandwidth in GB/s
    latency_max: float       # Maximum acceptable latency in microseconds
    bandwidth_weight: float  # Weight for bandwidth allocation


@dataclass
class DMARequest:
    """DMA transfer request details"""
    source_addr: int
    dest_addr: int
    size: int
    priority: int
    is_async: bool
    callback: Optional[Callable[[], None]] = None


class PCIeInterface:
    PCIE_VERSIONS = {
        '4.0': {'bandwidth': 16.0, 'encoding': 128 / 130, 'base_latency': 0.5},
        '5.0': {'bandwidth': 32.0, 'encoding': 128 / 130, 'base_latency': 0.4},
        '6.0': {'bandwidth': 64.0, 'encoding': 242 / 256, 'base_latency': 0.3},
    }

    def __init__(self, version='6.0', lanes=16, max_gbps=None):
        self.version = version
        self.lanes = lanes
        self.spec = self.PCIE_VERSIONS[version]
        self.max_gbps = max_gbps or self.spec['bandwidth'] * lanes * self.spec['encoding']

        # Initialize storage components
        self.storage = RemoteStorageManager()
        self.total_vram = 16 * 1024 * 1024 * 1024   # 16GB default
        self.page_size = 4096                       # 4KB pages
        self.block_size = 256 * self.page_size      # 1MB blocks

        # Initialize FTL
        total_blocks = self.total_vram // self.block_size
        pages_per_block = self.block_size // self.page_size
        self.ftl = AdvancedFTL(total_blocks=total_blocks, pages_per_block=pages_per_block)

        # Initialize interface state in remote storage
        self._init_interface_state()

        # Lane bonding and management
        self.active_lanes = lanes
        self.lane_groups: List[int] = self._initialize_lane_groups()
        self.lane_errors = [0] * lanes

        # QoS and bandwidth management
        self.active_transfers: Dict[int, DMARequest] = {}
        self.qos_profiles: Dict[int, QoSParameters] = {}
        self.bandwidth_allocations: Dict[int, float] = {}

        # DMA engine
        self.dma_queue: List[DMARequest] = []
        self.dma_active = False
        self.dma_batch_size = 1024 * 1024  # 1MB batches

    def _init_interface_state(self):
        """Initialize interface state in remote storage"""
        interface_state = {
            'version': self.version,
            'lanes': self.lanes,
            'max_gbps': self.max_gbps,
            'active_lanes': self.active_lanes,
            'lane_groups': self.lane_groups,
            'lane_errors': self.lane_errors,
            'qos_profiles': {},
            'bandwidth_allocations': {},
            'timestamp': datetime.now().isoformat()
        }
        # Store initial state
        self.storage.store_interface_state(interface_state)

    def _initialize_lane_groups(self) -> List[int]:
        """Initialize lane groups for bonding"""
        groups = []
        lanes_per_group = 4
        for i in range(0, self.lanes, lanes_per_group):
            # Last group may hold fewer lanes when the lane count is not a multiple of 4
            groups.append(min(lanes_per_group, self.lanes - i))
        return groups

    def add_qos_profile(self, profile_id: int, params: QoSParameters):
        """Add or update QoS profile"""
        self.qos_profiles[profile_id] = params
        self._rebalance_bandwidth()

    def _rebalance_bandwidth(self):
        """Rebalance bandwidth allocations based on QoS profiles and log to remote DB"""
        total_weight = sum(p.bandwidth_weight for p in self.qos_profiles.values())
        available_bandwidth = self.max_gbps

        # Ensure minimum bandwidth for every profile first
        for profile_id, params in self.qos_profiles.items():
            self.bandwidth_allocations[profile_id] = params.bandwidth_min
            available_bandwidth -= params.bandwidth_min

        # Distribute remaining bandwidth by weight
        if available_bandwidth > 0 and total_weight > 0:
            for profile_id, params in self.qos_profiles.items():
                extra = (params.bandwidth_weight / total_weight) * available_bandwidth
                self.bandwidth_allocations[profile_id] += extra

        # Log QoS metrics for every profile to remote storage
        for profile_id, params in self.qos_profiles.items():
            qos_data = {
                'timestamp': datetime.now().isoformat(),
                'profile_id': profile_id,
                'bandwidth_allocated': self.bandwidth_allocations[profile_id],
                'bandwidth_used': 0.0,    # Will be updated as bandwidth is used
                'latency_measured': 0.0,  # Will be updated as transfers occur
                'latency_target': params.latency_max
            }
            self.storage.store_qos_metrics(qos_data)

    def _log_transfer(self, size_bytes: int, direction: str, qos_profile_id: Optional[int],
                      transfer_time: float, bandwidth: float):
        """Log transfer details to remote storage"""
        transfer_data = {
            'timestamp': datetime.now().isoformat(),
            'size_bytes': size_bytes,
            'direction': direction,
            'qos_profile_id': qos_profile_id,
            'transfer_time': transfer_time,
            'lanes_active': self.active_lanes,
            'bandwidth_achieved': bandwidth
        }
        self.storage.store_transfer(transfer_data)

    def transfer_time(self, size_bytes: int, qos_profile_id: Optional[int] = None) -> float:
        """Calculate transfer time with QoS consideration"""
        # Get effective bandwidth based on QoS
        effective_bandwidth = self.max_gbps
        if qos_profile_id is not None and qos_profile_id in self.bandwidth_allocations:
            effective_bandwidth = self.bandwidth_allocations[qos_profile_id]

        # Calculate transfer time; encoding overhead is already folded into
        # max_gbps (and therefore into the per-profile allocations)
        gb = size_bytes / 1e9
        transfer_time = gb / effective_bandwidth

        # Add base latency (treated as microseconds, matching QoSParameters.latency_max)
        total_time = transfer_time + self.spec['base_latency'] * 1e-6

        # Log to remote DB
        self._log_transfer(size_bytes, 'calculate', qos_profile_id, total_time, effective_bandwidth)
        return total_time

    def initiate_dma_transfer(self, request: DMARequest) -> bool:
        """Initialize DMA transfer with QoS awareness"""
        self.dma_queue.append(request)
        if not self.dma_active:
            self._process_dma_queue()
        return True

    def _process_dma_queue(self):
        """Process DMA queue with QoS prioritization"""
        if not self.dma_queue:
            self.dma_active = False
            return

        self.dma_active = True
        # Sort by priority
        self.dma_queue.sort(key=lambda x: x.priority, reverse=True)

        while self.dma_queue:
            request = self.dma_queue[0]
            # Process in batches for better efficiency
            remaining = request.size
            while remaining > 0:
                batch_size = min(remaining, self.dma_batch_size)
                self._execute_dma_batch(request, batch_size)
                remaining -= batch_size

            if request.callback:
                request.callback()
            self.dma_queue.pop(0)

        # Queue drained; allow future transfers to restart processing
        self.dma_active = False

    def _execute_dma_batch(self, request: DMARequest, batch_size: int):
        """Execute a single DMA batch transfer with remote logging"""
        start_time = time.time()

        # Validate addresses using FTL
        source_phys = self.ftl.get_phys(request.source_addr // self.page_size)
        dest_phys = self.ftl.get_phys(request.dest_addr // self.page_size)
        if source_phys is None or dest_phys is None:
            raise RuntimeError("Invalid memory address in DMA transfer")

        transfer_time = self.transfer_time(batch_size)

        # Simulate DMA transfer
        time.sleep(transfer_time)

        # Log DMA operation to remote storage
        dma_data = {
            'timestamp': datetime.now().isoformat(),
            'source_addr': request.source_addr,
            'dest_addr': request.dest_addr,
            'size_bytes': batch_size,
            'priority': request.priority,
            'completion_time': time.time() - start_time,
            'status': 'completed'
        }
        self.storage.store_dma_operation(dma_data)

    def allocate_vram(self, size: int, qos: Optional[QoSParameters] = None) -> Optional[int]:
        """
        Allocate VRAM with optional QoS parameters

        Args:
            size: Size in bytes to allocate
            qos: Quality of Service parameters

        Returns:
            Virtual address or None if allocation fails
        """
        try:
            # Round up to nearest page size
            pages_needed = (size + self.page_size - 1) // self.page_size
            if pages_needed > self.ftl.pages_per_block:
                # Allocations are backed by a single block, so larger requests
                # would spill past the block's physical pages
                raise RuntimeError("Allocation exceeds block size")

            # Get a free block from FTL
            block_id = self.ftl.get_free_block()
            if block_id is None:
                # Try garbage collection
                self._run_garbage_collection()
                block_id = self.ftl.get_free_block()
                if block_id is None:
                    raise RuntimeError("Out of VRAM")

            # Calculate virtual address
            virt_addr = block_id * self.block_size

            # Map pages in FTL
            for i in range(pages_needed):
                lba = (virt_addr // self.page_size) + i
                phys = (block_id * self.ftl.pages_per_block) + i
                # Mark as hot if high priority QoS
                is_hot = qos is not None and qos.priority >= 6
                self.ftl.map(lba, phys, is_hot)

            return virt_addr
        except Exception as e:
            self.storage.log_error("VRAM allocation failed", str(e))
            return None

    def free_vram(self, virt_addr: int, size: int) -> bool:
        """
        Free allocated VRAM

        Args:
            virt_addr: Virtual address to free
            size: Size in bytes to free

        Returns:
            True if successful
        """
        try:
            # Calculate pages to free
            start_page = virt_addr // self.page_size
            pages_to_free = (size + self.page_size - 1) // self.page_size

            # Collect the blocks backing these pages, then reclaim each block once
            blocks_to_collect = set()
            for i in range(pages_to_free):
                lba = start_page + i
                phys = self.ftl.get_phys(lba)
                if phys is not None:
                    blocks_to_collect.add(phys // self.ftl.pages_per_block)

            for block_id in blocks_to_collect:
                self.ftl.garbage_collect(block_id)

            return True
        except Exception as e:
            self.storage.log_error("VRAM free failed", str(e))
            return False

    def _run_garbage_collection(self) -> None:
        """Run garbage collection on VRAM blocks"""
        stats = self.ftl.get_stats()
        if stats.get('free_blocks', 0) > stats.get('total_blocks', 0) * 0.1:
            return  # Still enough free blocks

        # Sweep every block; the FTL reclaims those holding invalid pages
        for block in range(stats.get('total_blocks', 0)):
            self.ftl.garbage_collect(block)

    def get_vram_stats(self) -> Dict[str, Any]:
        """Get VRAM statistics"""
        ftl_stats = self.ftl.get_stats()
        stats = {
            "total_vram": self.total_vram,
            "page_size": self.page_size,
            "block_size": self.block_size,
            "used_blocks": ftl_stats.get('total_blocks', 0) - ftl_stats.get('free_blocks', 0),
            "free_blocks": ftl_stats.get('free_blocks', 0),
            "wear_leveling": ftl_stats.get('avg_erase_count', 0),
            "cache_hit_ratio": (
                ftl_stats.get('cache_hits', 0) /
                max(ftl_stats.get('cache_hits', 0) + ftl_stats.get('cache_misses', 0), 1)
            ) * 100
        }

        # Add PCIe stats
        stats.update({
            "pcie_bandwidth": self.max_gbps,
            "active_lanes": self.active_lanes,
            "lane_errors": sum(self.lane_errors)
        })
        return stats

    def optimize_lanes(self) -> None:
        """Optimize lane configuration based on errors and performance"""
        error_threshold = 10
        for i, errors in enumerate(self.lane_errors):
            if errors > error_threshold:
                self._disable_lane(i)
        # _disable_lane refreshes bandwidth as lanes drop; rebalance QoS
        # allocations once more so they reflect the final lane count
        self._rebalance_bandwidth()

    def _disable_lane(self, lane_idx: int) -> None:
        """Disable a problematic lane"""
        group_idx = lane_idx // 4
        if group_idx < len(self.lane_groups) and self.lane_groups[group_idx] > 0:
            self.lane_groups[group_idx] -= 1
            self.active_lanes -= 1
            self._update_max_bandwidth()

    def _update_max_bandwidth(self) -> None:
        """Update maximum bandwidth based on active lanes"""
        lane_bandwidth = self.PCIE_VERSIONS[self.version]['bandwidth']
        self.max_gbps = lane_bandwidth * self.active_lanes * self.spec['encoding']
        self._rebalance_bandwidth()
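

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the interface). It
# assumes this module lives in a package next to remote_storage and ftl, so it
# must be run as a module (e.g. `python -m <package>.<module>`) rather than
# directly, and it assumes RemoteStorageManager and AdvancedFTL accept the
# calls made above (store_interface_state, get_free_block, map, etc.).
if __name__ == "__main__":
    pcie = PCIeInterface(version='5.0', lanes=8)

    # Register a high-priority QoS profile and inspect its bandwidth allocation
    pcie.add_qos_profile(1, QoSParameters(priority=7,
                                          bandwidth_min=4.0,
                                          latency_max=10.0,
                                          bandwidth_weight=2.0))
    print("Profile 1 bandwidth:", pcie.bandwidth_allocations[1], "GB/s")

    # Allocate half a block of VRAM, estimate a QoS-aware transfer, then free it
    size = 512 * 1024
    addr = pcie.allocate_vram(size, qos=pcie.qos_profiles[1])
    if addr is not None:
        print("Estimated transfer time:",
              pcie.transfer_time(size, qos_profile_id=1), "s")
        pcie.free_vram(addr, size)

    print(pcie.get_vram_stats())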