|
|
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

from .ftl import AdvancedFTL
from .remote_storage import RemoteStorageManager
|
|
|
|
|
|
@dataclass
class QoSParameters:
    """Quality of Service parameters"""
    priority: int  # scheduling priority; values >= 6 mark VRAM allocations as "hot" (see allocate_vram)
    bandwidth_min: float  # guaranteed minimum bandwidth granted before any proportional share
    latency_max: float  # latency target, logged as 'latency_target' in QoS metrics
    bandwidth_weight: float  # weight for the proportional split of leftover bandwidth
|
|
|
|
|
|
@dataclass
class DMARequest:
    """DMA transfer request details."""
    source_addr: int  # source virtual address (byte offset)
    dest_addr: int    # destination virtual address (byte offset)
    size: int         # total transfer size in bytes
    priority: int     # queue priority; higher values are processed first
    is_async: bool    # carried on the request; not consulted by the code in this file
    # Invoked with no arguments once the whole request has completed.
    # Fixed: previously annotated with the builtin predicate `callable`,
    # which is not a valid type — use typing.Callable instead.
    callback: Optional[Callable[[], None]] = None
|
|
|
|
|
|
class PCIeInterface:
    """Simulated PCIe interface with QoS-aware transfer timing, a prioritized
    DMA queue, and FTL-backed VRAM management; state and metrics are persisted
    through a RemoteStorageManager."""

    # Per-generation link characteristics:
    #   bandwidth    - per-lane rate (GT/s per the PCIe spec; used as Gbps here)
    #   encoding     - line-encoding efficiency (128b/130b for 4.0/5.0,
    #                  242B/256B FLIT for 6.0)
    #   base_latency - fixed latency term added in transfer_time()
    #                  (units unspecified in this file — TODO confirm)
    PCIE_VERSIONS = {
        '4.0': {'bandwidth': 16.0, 'encoding': 128/130, 'base_latency': 0.5},
        '5.0': {'bandwidth': 32.0, 'encoding': 128/130, 'base_latency': 0.4},
        '6.0': {'bandwidth': 64.0, 'encoding': 242/256, 'base_latency': 0.3}
    }
|
|
|
|
|
|
def __init__(self, version='6.0', lanes=16, max_gbps=None):
|
|
|
self.version = version
|
|
|
self.lanes = lanes
|
|
|
self.spec = self.PCIE_VERSIONS[version]
|
|
|
self.max_gbps = max_gbps or self.spec['bandwidth'] * lanes * self.spec['encoding']
|
|
|
|
|
|
|
|
|
self.storage = RemoteStorageManager()
|
|
|
self.total_vram = 16 * 1024 * 1024 * 1024
|
|
|
self.page_size = 4096
|
|
|
self.block_size = 256 * self.page_size
|
|
|
|
|
|
|
|
|
total_blocks = self.total_vram // self.block_size
|
|
|
pages_per_block = self.block_size // self.page_size
|
|
|
self.ftl = AdvancedFTL(total_blocks=total_blocks, pages_per_block=pages_per_block)
|
|
|
|
|
|
|
|
|
self._init_interface_state()
|
|
|
|
|
|
|
|
|
self.active_lanes = lanes
|
|
|
self.lane_groups: List[int] = self._initialize_lane_groups()
|
|
|
self.lane_errors = [0] * lanes
|
|
|
|
|
|
|
|
|
self.active_transfers: Dict[int, DMARequest] = {}
|
|
|
self.qos_profiles: Dict[int, QoSParameters] = {}
|
|
|
self.bandwidth_allocations: Dict[int, float] = {}
|
|
|
|
|
|
|
|
|
self.dma_queue: List[DMARequest] = []
|
|
|
self.dma_active = False
|
|
|
self.dma_batch_size = 1024 * 1024
|
|
|
|
|
|
def _init_interface_state(self):
|
|
|
"""Initialize interface state in remote storage"""
|
|
|
interface_state = {
|
|
|
'version': self.version,
|
|
|
'lanes': self.lanes,
|
|
|
'max_gbps': self.max_gbps,
|
|
|
'active_lanes': self.active_lanes,
|
|
|
'lane_groups': self.lane_groups,
|
|
|
'lane_errors': self.lane_errors,
|
|
|
'qos_profiles': {},
|
|
|
'bandwidth_allocations': {},
|
|
|
'timestamp': datetime.now().isoformat()
|
|
|
}
|
|
|
|
|
|
|
|
|
self.storage.store_interface_state(interface_state)
|
|
|
|
|
|
def _initialize_lane_groups(self) -> List[int]:
|
|
|
"""Initialize lane groups for bonding"""
|
|
|
groups = []
|
|
|
lanes_per_group = 4
|
|
|
for i in range(0, self.lanes, lanes_per_group):
|
|
|
groups.append(lanes_per_group)
|
|
|
return groups
|
|
|
|
|
|
    def add_qos_profile(self, profile_id: int, params: QoSParameters):
        """Add or update QoS profile"""
        self.qos_profiles[profile_id] = params
        # Re-derive per-profile bandwidth shares so the new/updated
        # profile takes effect immediately.
        self._rebalance_bandwidth()
|
|
|
|
|
|
def _rebalance_bandwidth(self):
|
|
|
"""Rebalance bandwidth allocations based on QoS profiles and log to remote DB"""
|
|
|
total_weight = sum(p.bandwidth_weight for p in self.qos_profiles.values())
|
|
|
available_bandwidth = self.max_gbps
|
|
|
|
|
|
for profile_id, params in self.qos_profiles.items():
|
|
|
|
|
|
self.bandwidth_allocations[profile_id] = params.bandwidth_min
|
|
|
available_bandwidth -= params.bandwidth_min
|
|
|
|
|
|
|
|
|
if available_bandwidth > 0 and total_weight > 0:
|
|
|
for profile_id, params in self.qos_profiles.items():
|
|
|
extra = (params.bandwidth_weight / total_weight) * available_bandwidth
|
|
|
self.bandwidth_allocations[profile_id] += extra
|
|
|
|
|
|
|
|
|
qos_data = {
|
|
|
'timestamp': datetime.now().isoformat(),
|
|
|
'profile_id': profile_id,
|
|
|
'bandwidth_allocated': self.bandwidth_allocations[profile_id],
|
|
|
'bandwidth_used': 0.0,
|
|
|
'latency_measured': 0.0,
|
|
|
'latency_target': params.latency_max
|
|
|
}
|
|
|
self.storage.store_qos_metrics(qos_data)
|
|
|
|
|
|
def _log_transfer(self, size_bytes: int, direction: str, qos_profile_id: Optional[int],
|
|
|
transfer_time: float, bandwidth: float):
|
|
|
"""Log transfer details to remote storage"""
|
|
|
transfer_data = {
|
|
|
'timestamp': datetime.now().isoformat(),
|
|
|
'size_bytes': size_bytes,
|
|
|
'direction': direction,
|
|
|
'qos_profile_id': qos_profile_id,
|
|
|
'transfer_time': transfer_time,
|
|
|
'lanes_active': self.active_lanes,
|
|
|
'bandwidth_achieved': bandwidth
|
|
|
}
|
|
|
self.storage.store_transfer(transfer_data)
|
|
|
|
|
|
def transfer_time(self, size_bytes: int, qos_profile_id: Optional[int] = None) -> float:
|
|
|
"""Calculate transfer time with QoS consideration"""
|
|
|
|
|
|
effective_bandwidth = self.max_gbps
|
|
|
if qos_profile_id is not None and qos_profile_id in self.bandwidth_allocations:
|
|
|
effective_bandwidth = self.bandwidth_allocations[qos_profile_id]
|
|
|
|
|
|
|
|
|
gb = size_bytes / 1e9
|
|
|
transfer_time = gb / effective_bandwidth
|
|
|
|
|
|
|
|
|
transfer_time /= self.spec['encoding']
|
|
|
|
|
|
|
|
|
total_time = transfer_time + self.spec['base_latency']
|
|
|
|
|
|
|
|
|
self._log_transfer(size_bytes, 'calculate', qos_profile_id, total_time, effective_bandwidth)
|
|
|
|
|
|
return total_time
|
|
|
|
|
|
def initiate_dma_transfer(self, request: DMARequest) -> bool:
|
|
|
"""Initialize DMA transfer with QoS awareness"""
|
|
|
self.dma_queue.append(request)
|
|
|
if not self.dma_active:
|
|
|
self._process_dma_queue()
|
|
|
return True
|
|
|
|
|
|
def _process_dma_queue(self):
|
|
|
"""Process DMA queue with QoS prioritization"""
|
|
|
if not self.dma_queue:
|
|
|
self.dma_active = False
|
|
|
return
|
|
|
|
|
|
self.dma_active = True
|
|
|
|
|
|
self.dma_queue.sort(key=lambda x: x.priority, reverse=True)
|
|
|
|
|
|
while self.dma_queue:
|
|
|
request = self.dma_queue[0]
|
|
|
|
|
|
remaining = request.size
|
|
|
while remaining > 0:
|
|
|
batch_size = min(remaining, self.dma_batch_size)
|
|
|
self._execute_dma_batch(request, batch_size)
|
|
|
remaining -= batch_size
|
|
|
|
|
|
if request.callback:
|
|
|
request.callback()
|
|
|
self.dma_queue.pop(0)
|
|
|
|
|
|
def _execute_dma_batch(self, request: DMARequest, batch_size: int):
|
|
|
"""Execute a single DMA batch transfer with remote logging"""
|
|
|
start_time = time.time()
|
|
|
|
|
|
|
|
|
source_phys = self.ftl.get_phys(request.source_addr // self.page_size)
|
|
|
dest_phys = self.ftl.get_phys(request.dest_addr // self.page_size)
|
|
|
|
|
|
if source_phys is None or dest_phys is None:
|
|
|
raise RuntimeError("Invalid memory address in DMA transfer")
|
|
|
|
|
|
transfer_time = self.transfer_time(batch_size)
|
|
|
|
|
|
|
|
|
time.sleep(transfer_time)
|
|
|
|
|
|
|
|
|
dma_data = {
|
|
|
'timestamp': datetime.now().isoformat(),
|
|
|
'source_addr': request.source_addr,
|
|
|
'dest_addr': request.dest_addr,
|
|
|
'size_bytes': batch_size,
|
|
|
'priority': request.priority,
|
|
|
'completion_time': time.time() - start_time,
|
|
|
'status': 'completed'
|
|
|
}
|
|
|
self.storage.store_dma_operation(dma_data)
|
|
|
|
|
|
def allocate_vram(self, size: int, qos: Optional[QoSParameters] = None) -> Optional[int]:
|
|
|
"""
|
|
|
Allocate VRAM with optional QoS parameters
|
|
|
Args:
|
|
|
size: Size in bytes to allocate
|
|
|
qos: Quality of Service parameters
|
|
|
Returns:
|
|
|
Virtual address or None if allocation fails
|
|
|
"""
|
|
|
try:
|
|
|
|
|
|
pages_needed = (size + self.page_size - 1) // self.page_size
|
|
|
|
|
|
|
|
|
block_id = self.ftl.get_free_block()
|
|
|
if block_id is None:
|
|
|
|
|
|
self._run_garbage_collection()
|
|
|
block_id = self.ftl.get_free_block()
|
|
|
if block_id is None:
|
|
|
raise RuntimeError("Out of VRAM")
|
|
|
|
|
|
|
|
|
virt_addr = block_id * self.block_size
|
|
|
|
|
|
|
|
|
for i in range(pages_needed):
|
|
|
lba = (virt_addr // self.page_size) + i
|
|
|
phys = (block_id * self.ftl.pages_per_block) + i
|
|
|
|
|
|
is_hot = qos and qos.priority >= 6
|
|
|
self.ftl.map(lba, phys, is_hot)
|
|
|
|
|
|
return virt_addr
|
|
|
|
|
|
except Exception as e:
|
|
|
self.storage.log_error("VRAM allocation failed", str(e))
|
|
|
return None
|
|
|
|
|
|
def free_vram(self, virt_addr: int, size: int) -> bool:
|
|
|
"""
|
|
|
Free allocated VRAM
|
|
|
Args:
|
|
|
virt_addr: Virtual address to free
|
|
|
size: Size in bytes to free
|
|
|
Returns:
|
|
|
True if successful
|
|
|
"""
|
|
|
try:
|
|
|
|
|
|
start_page = virt_addr // self.page_size
|
|
|
pages_to_free = (size + self.page_size - 1) // self.page_size
|
|
|
|
|
|
|
|
|
for i in range(pages_to_free):
|
|
|
lba = start_page + i
|
|
|
phys = self.ftl.get_phys(lba)
|
|
|
if phys is not None:
|
|
|
block_id = phys // self.ftl.pages_per_block
|
|
|
self.ftl.garbage_collect(block_id)
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
self.storage.log_error("VRAM free failed", str(e))
|
|
|
return False
|
|
|
|
|
|
def _run_garbage_collection(self) -> None:
|
|
|
"""Run garbage collection on VRAM blocks"""
|
|
|
stats = self.ftl.get_stats()
|
|
|
if stats.get('free_blocks', 0) > stats.get('total_blocks', 0) * 0.1:
|
|
|
return
|
|
|
|
|
|
|
|
|
for block in range(stats.get('total_blocks', 0)):
|
|
|
self.ftl.garbage_collect(block)
|
|
|
|
|
|
def get_vram_stats(self) -> Dict[str, Any]:
|
|
|
"""Get VRAM statistics"""
|
|
|
ftl_stats = self.ftl.get_stats()
|
|
|
stats = {
|
|
|
"total_vram": self.total_vram,
|
|
|
"page_size": self.page_size,
|
|
|
"block_size": self.block_size,
|
|
|
"used_blocks": ftl_stats.get('total_blocks', 0) - ftl_stats.get('free_blocks', 0),
|
|
|
"free_blocks": ftl_stats.get('free_blocks', 0),
|
|
|
"wear_leveling": ftl_stats.get('avg_erase_count', 0),
|
|
|
"cache_hit_ratio": (
|
|
|
ftl_stats.get('cache_hits', 0) /
|
|
|
max(ftl_stats.get('cache_hits', 0) + ftl_stats.get('cache_misses', 0), 1)
|
|
|
) * 100
|
|
|
}
|
|
|
|
|
|
|
|
|
stats.update({
|
|
|
"pcie_bandwidth": self.max_gbps,
|
|
|
"active_lanes": self.active_lanes,
|
|
|
"lane_errors": sum(self.lane_errors)
|
|
|
})
|
|
|
|
|
|
return stats
|
|
|
|
|
|
    def optimize_lanes(self) -> None:
        """Optimize lane configuration based on errors and performance"""
        # Disable any lane whose accumulated error count exceeds the threshold.
        error_threshold = 10
        for i, errors in enumerate(self.lane_errors):
            if errors > error_threshold:
                # NOTE(review): lane_errors counters are never reset, so an
                # over-threshold lane is disabled again on every call —
                # confirm whether counts should be cleared after disabling.
                self._disable_lane(i)
        # NOTE(review): _rebalance_lanes is not defined in this portion of
        # the file — confirm it exists elsewhere (possibly
        # _rebalance_bandwidth was intended).
        self._rebalance_lanes()
|
|
|
|
|
|
def _disable_lane(self, lane_idx: int) -> None:
|
|
|
"""Disable a problematic lane"""
|
|
|
group_idx = lane_idx // 4
|
|
|
if group_idx < len(self.lane_groups):
|
|
|
self.lane_groups[group_idx] -= 1
|
|
|
self.active_lanes -= 1
|
|
|
self._update_max_bandwidth()
|
|
|
|
|
|
def _update_max_bandwidth(self) -> None:
|
|
|
"""Update maximum bandwidth based on active lanes"""
|
|
|
lane_bandwidth = self.PCIE_VERSIONS[self.version]['bandwidth']
|
|
|
self.max_gbps = lane_bandwidth * self.active_lanes * self.spec['encoding']
|
|
|
self._rebalance_bandwidth()
|
|
|
|