# vram/interface.py
# NOTE: removed Hugging Face upload-page residue that preceded this module
# ("Fred808's picture / Upload 256 files / 7a0c684 verified") — it was not
# valid Python and broke parsing.
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

from .ftl import AdvancedFTL
from .remote_storage import RemoteStorageManager
@dataclass
class QoSParameters:
    """Quality of Service parameters for one bandwidth profile.

    Consumed by PCIeInterface: each profile is first guaranteed its
    minimum bandwidth, then any spare link bandwidth is split in
    proportion to the profile weights.
    """
    priority: int  # 0-7, higher is more important (>= 6 marks VRAM pages "hot")
    bandwidth_min: float  # Minimum guaranteed bandwidth in GB/s
    latency_max: float  # Maximum acceptable latency in microseconds
    bandwidth_weight: float  # Relative share when distributing spare bandwidth
@dataclass
class DMARequest:
    """Parameters describing a single DMA transfer request.

    Attributes:
        source_addr: Source virtual address of the transfer.
        dest_addr: Destination virtual address of the transfer.
        size: Transfer size in bytes.
        priority: Scheduling priority; higher values are processed first.
        is_async: Whether the caller treats the transfer as asynchronous.
        callback: Optional no-argument callable invoked after the whole
            request has been transferred.
    """
    source_addr: int
    dest_addr: int
    size: int
    priority: int
    is_async: bool
    # Fixed: annotate with typing.Callable rather than the builtin callable(),
    # which is a predicate function and is rejected as a type by checkers.
    callback: Optional[Callable[[], None]] = None
class PCIeInterface:
    """Simulated PCIe link with QoS-aware bandwidth sharing, a batched DMA
    engine, and FTL-backed VRAM allocation.

    Configuration changes and per-transfer metrics are mirrored to a remote
    backend (RemoteStorageManager); logical-to-physical page mapping is
    delegated to AdvancedFTL.
    """

    # Per-generation link parameters: per-lane bandwidth (GB/s), line-encoding
    # efficiency, and base latency (documented as microseconds in QoSParameters;
    # see NOTE in transfer_time about the unit mismatch).
    PCIE_VERSIONS = {
        '4.0': {'bandwidth': 16.0, 'encoding': 128/130, 'base_latency': 0.5},
        '5.0': {'bandwidth': 32.0, 'encoding': 128/130, 'base_latency': 0.4},
        '6.0': {'bandwidth': 64.0, 'encoding': 242/256, 'base_latency': 0.3}
    }

    def __init__(self, version='6.0', lanes=16, max_gbps=None):
        """Create the interface and persist its initial state remotely.

        Args:
            version: PCIe generation key into PCIE_VERSIONS.
            lanes: Number of physical lanes to bond.
            max_gbps: Optional total-bandwidth override; when omitted it is
                derived from per-lane bandwidth, lane count and encoding.

        Raises:
            KeyError: If `version` is not a known PCIe generation.
        """
        self.version = version
        self.lanes = lanes
        self.spec = self.PCIE_VERSIONS[version]
        self.max_gbps = max_gbps or self.spec['bandwidth'] * lanes * self.spec['encoding']

        # Storage backend and VRAM geometry.
        self.storage = RemoteStorageManager()
        self.total_vram = 16 * 1024 * 1024 * 1024  # 16GB default
        self.page_size = 4096                      # 4KB pages
        self.block_size = 256 * self.page_size     # 1MB blocks

        # Flash translation layer sized from the VRAM geometry.
        total_blocks = self.total_vram // self.block_size
        pages_per_block = self.block_size // self.page_size
        self.ftl = AdvancedFTL(total_blocks=total_blocks, pages_per_block=pages_per_block)

        # Lane bonding and error tracking.
        self.active_lanes = lanes
        self.lane_groups: List[int] = self._initialize_lane_groups()
        self.lane_errors = [0] * lanes
        self.disabled_lanes: set = set()  # lanes taken offline by _disable_lane

        # QoS and bandwidth management.
        self.active_transfers: Dict[int, DMARequest] = {}
        self.qos_profiles: Dict[int, QoSParameters] = {}
        self.bandwidth_allocations: Dict[int, float] = {}

        # DMA engine.
        self.dma_queue: List[DMARequest] = []
        self.dma_active = False
        self.dma_batch_size = 1024 * 1024  # 1MB batches

        # BUG FIX: the state snapshot was previously taken *before* the lane
        # attributes it reads were assigned, raising AttributeError on every
        # construction; take it last, once the object is fully initialized.
        self._init_interface_state()

    def _init_interface_state(self):
        """Persist the current interface configuration to remote storage."""
        interface_state = {
            'version': self.version,
            'lanes': self.lanes,
            'max_gbps': self.max_gbps,
            'active_lanes': self.active_lanes,
            'lane_groups': self.lane_groups,
            'lane_errors': self.lane_errors,
            'qos_profiles': {},
            'bandwidth_allocations': {},
            'timestamp': datetime.now().isoformat()
        }
        self.storage.store_interface_state(interface_state)

    def _initialize_lane_groups(self) -> List[int]:
        """Partition the lanes into bonding groups of up to 4 lanes each.

        A trailing partial group is still created at full size, matching the
        grouping used by _disable_lane (lane_idx // 4).
        """
        lanes_per_group = 4
        return [lanes_per_group for _ in range(0, self.lanes, lanes_per_group)]

    def add_qos_profile(self, profile_id: int, params: QoSParameters):
        """Register or replace a QoS profile and rebalance bandwidth."""
        self.qos_profiles[profile_id] = params
        self._rebalance_bandwidth()

    def _rebalance_bandwidth(self):
        """Recompute per-profile bandwidth allocations and log QoS metrics.

        Every profile first receives its guaranteed minimum; any remaining
        link bandwidth is then distributed proportionally to each profile's
        bandwidth_weight.
        """
        total_weight = sum(p.bandwidth_weight for p in self.qos_profiles.values())
        available_bandwidth = self.max_gbps

        # Guarantee each profile its minimum bandwidth first.
        for profile_id, params in self.qos_profiles.items():
            self.bandwidth_allocations[profile_id] = params.bandwidth_min
            available_bandwidth -= params.bandwidth_min

        # Distribute whatever is left in proportion to the profile weights.
        if available_bandwidth > 0 and total_weight > 0:
            for profile_id, params in self.qos_profiles.items():
                share = (params.bandwidth_weight / total_weight) * available_bandwidth
                self.bandwidth_allocations[profile_id] += share

        # BUG FIX: metrics were previously logged only for the profile bound
        # to the last loop iteration (and raised NameError when no profiles
        # existed); log one record per registered profile instead.
        for profile_id, params in self.qos_profiles.items():
            qos_data = {
                'timestamp': datetime.now().isoformat(),
                'profile_id': profile_id,
                'bandwidth_allocated': self.bandwidth_allocations[profile_id],
                'bandwidth_used': 0.0,  # Will be updated as bandwidth is used
                'latency_measured': 0.0,  # Will be updated as transfers occur
                'latency_target': params.latency_max
            }
            self.storage.store_qos_metrics(qos_data)

    def _log_transfer(self, size_bytes: int, direction: str, qos_profile_id: Optional[int],
                      transfer_time: float, bandwidth: float):
        """Log a single transfer record to remote storage."""
        transfer_data = {
            'timestamp': datetime.now().isoformat(),
            'size_bytes': size_bytes,
            'direction': direction,
            'qos_profile_id': qos_profile_id,
            'transfer_time': transfer_time,
            'lanes_active': self.active_lanes,
            'bandwidth_achieved': bandwidth
        }
        self.storage.store_transfer(transfer_data)

    def transfer_time(self, size_bytes: int, qos_profile_id: Optional[int] = None) -> float:
        """Estimate the duration of a transfer of `size_bytes`.

        When the QoS profile has a bandwidth allocation, that allocation caps
        the effective bandwidth; otherwise the full link bandwidth is used.
        The estimate is also logged to remote storage.

        NOTE(review): spec['base_latency'] is documented in microseconds but
        is added directly to the seconds-scale bandwidth term — confirm the
        intended units before relying on absolute values.
        """
        effective_bandwidth = self.max_gbps
        if qos_profile_id is not None and qos_profile_id in self.bandwidth_allocations:
            effective_bandwidth = self.bandwidth_allocations[qos_profile_id]

        gb = size_bytes / 1e9
        duration = gb / effective_bandwidth
        duration /= self.spec['encoding']  # account for line-encoding overhead
        total_time = duration + self.spec['base_latency']

        self._log_transfer(size_bytes, 'calculate', qos_profile_id, total_time, effective_bandwidth)
        return total_time

    def initiate_dma_transfer(self, request: DMARequest) -> bool:
        """Queue a DMA request and drain the queue if the engine is idle.

        Returns:
            True once the request has been queued (and possibly processed).
        """
        self.dma_queue.append(request)
        if not self.dma_active:
            self._process_dma_queue()
        return True

    def _process_dma_queue(self):
        """Drain the DMA queue in priority order, batching each request."""
        if not self.dma_queue:
            self.dma_active = False
            return

        self.dma_active = True
        try:
            # Highest priority first.
            self.dma_queue.sort(key=lambda r: r.priority, reverse=True)
            while self.dma_queue:
                request = self.dma_queue[0]
                # Transfer in fixed-size batches for better efficiency.
                remaining = request.size
                while remaining > 0:
                    batch = min(remaining, self.dma_batch_size)
                    self._execute_dma_batch(request, batch)
                    remaining -= batch
                if request.callback:
                    request.callback()
                self.dma_queue.pop(0)
        finally:
            # BUG FIX: the engine previously stayed flagged active after the
            # queue drained, so later requests were queued but never run.
            self.dma_active = False

    def _execute_dma_batch(self, request: DMARequest, batch_size: int):
        """Execute one DMA batch, simulating the transfer and logging it.

        Raises:
            RuntimeError: If either endpoint address has no FTL mapping.
        """
        start_time = time.time()

        # Validate both endpoints through the FTL before moving data.
        source_phys = self.ftl.get_phys(request.source_addr // self.page_size)
        dest_phys = self.ftl.get_phys(request.dest_addr // self.page_size)
        if source_phys is None or dest_phys is None:
            raise RuntimeError("Invalid memory address in DMA transfer")

        # Simulate the transfer taking its estimated duration.
        time.sleep(self.transfer_time(batch_size))

        dma_data = {
            'timestamp': datetime.now().isoformat(),
            'source_addr': request.source_addr,
            'dest_addr': request.dest_addr,
            'size_bytes': batch_size,
            'priority': request.priority,
            'completion_time': time.time() - start_time,
            'status': 'completed'
        }
        self.storage.store_dma_operation(dma_data)

    def allocate_vram(self, size: int, qos: Optional[QoSParameters] = None) -> Optional[int]:
        """Allocate VRAM backed by a single FTL block.

        Args:
            size: Size in bytes to allocate (rounded up to whole pages).
            qos: Optional Quality of Service parameters; priority >= 6 marks
                the pages as hot for the FTL.

        Returns:
            Virtual address of the allocation, or None on failure (the error
            is logged to remote storage).
        """
        try:
            # Round up to whole pages.
            pages_needed = (size + self.page_size - 1) // self.page_size

            # BUG FIX: allocations are backed by exactly one block; a request
            # larger than one block previously mapped physical pages beyond
            # the block's range. Fail such requests explicitly.
            if pages_needed > self.ftl.pages_per_block:
                raise RuntimeError("Allocation exceeds maximum block size")

            block_id = self.ftl.get_free_block()
            if block_id is None:
                # Try to reclaim space before giving up.
                self._run_garbage_collection()
                block_id = self.ftl.get_free_block()
                if block_id is None:
                    raise RuntimeError("Out of VRAM")

            virt_addr = block_id * self.block_size
            # High-priority QoS traffic is mapped as "hot" for the FTL
            # (explicit bool rather than the truthy `qos and ...` expression).
            is_hot = qos is not None and qos.priority >= 6
            for i in range(pages_needed):
                lba = (virt_addr // self.page_size) + i
                phys = (block_id * self.ftl.pages_per_block) + i
                self.ftl.map(lba, phys, is_hot)
            return virt_addr
        except Exception as e:
            self.storage.log_error("VRAM allocation failed", str(e))
            return None

    def free_vram(self, virt_addr: int, size: int) -> bool:
        """Free previously allocated VRAM.

        Args:
            virt_addr: Virtual address to free.
            size: Size in bytes to free (rounded up to whole pages).

        Returns:
            True on success; False on failure (error logged remotely).
        """
        try:
            start_page = virt_addr // self.page_size
            pages_to_free = (size + self.page_size - 1) // self.page_size

            # BUG FIX: the old loop garbage-collected a block once per page
            # *while* still reading mappings from it, so later lookups could
            # already be invalidated. Resolve all backing blocks first, then
            # collect each distinct block exactly once.
            backing_blocks = set()
            for i in range(pages_to_free):
                phys = self.ftl.get_phys(start_page + i)
                if phys is not None:
                    backing_blocks.add(phys // self.ftl.pages_per_block)
            for block_id in backing_blocks:
                self.ftl.garbage_collect(block_id)
            return True
        except Exception as e:
            self.storage.log_error("VRAM free failed", str(e))
            return False

    def _run_garbage_collection(self) -> None:
        """Reclaim blocks when free space drops below 10% of capacity."""
        stats = self.ftl.get_stats()
        if stats.get('free_blocks', 0) > stats.get('total_blocks', 0) * 0.1:
            return  # Still enough free blocks
        # Sweep every block; the FTL decides what is actually reclaimable.
        for block in range(stats.get('total_blocks', 0)):
            self.ftl.garbage_collect(block)

    def get_vram_stats(self) -> Dict[str, Any]:
        """Return combined FTL and PCIe link statistics."""
        ftl_stats = self.ftl.get_stats()
        stats = {
            "total_vram": self.total_vram,
            "page_size": self.page_size,
            "block_size": self.block_size,
            "used_blocks": ftl_stats.get('total_blocks', 0) - ftl_stats.get('free_blocks', 0),
            "free_blocks": ftl_stats.get('free_blocks', 0),
            "wear_leveling": ftl_stats.get('avg_erase_count', 0),
            # Hit ratio as a percentage; max(..., 1) avoids division by zero.
            "cache_hit_ratio": (
                ftl_stats.get('cache_hits', 0) /
                max(ftl_stats.get('cache_hits', 0) + ftl_stats.get('cache_misses', 0), 1)
            ) * 100
        }
        stats.update({
            "pcie_bandwidth": self.max_gbps,
            "active_lanes": self.active_lanes,
            "lane_errors": sum(self.lane_errors)
        })
        return stats

    def optimize_lanes(self) -> None:
        """Disable lanes whose error count exceeds the threshold, then
        rebalance the remaining lane groups."""
        error_threshold = 10
        for lane_idx, errors in enumerate(self.lane_errors):
            if errors > error_threshold:
                self._disable_lane(lane_idx)
        self._rebalance_lanes()

    def _disable_lane(self, lane_idx: int) -> None:
        """Take a problematic lane offline and refresh link bandwidth."""
        # BUG FIX: repeated optimize_lanes() calls previously re-disabled the
        # same lane, driving group counts and active_lanes negative.
        if lane_idx in self.disabled_lanes:
            return
        group_idx = lane_idx // 4
        if group_idx < len(self.lane_groups):
            self.disabled_lanes.add(lane_idx)
            self.lane_groups[group_idx] -= 1
            self.active_lanes -= 1
            self._update_max_bandwidth()

    def _rebalance_lanes(self) -> None:
        """Drop empty bonding groups and recompute link bandwidth.

        NOTE(review): optimize_lanes() referenced this method but no
        definition was visible in this module; this minimal implementation
        is supplied so the call resolves — confirm against any external
        definition of lane rebalancing.
        """
        self.lane_groups = [group for group in self.lane_groups if group > 0]
        self._update_max_bandwidth()

    def _update_max_bandwidth(self) -> None:
        """Recompute max_gbps from the active lane count, then rebalance QoS."""
        lane_bandwidth = self.PCIE_VERSIONS[self.version]['bandwidth']
        self.max_gbps = lane_bandwidth * self.active_lanes * self.spec['encoding']
        self._rebalance_bandwidth()