# File: INV/vram/remote_storage.py
# (Originally published on Hugging Face by user Fred808, commit 7a0c684, "Upload 256 files".)
"""
Remote storage implementation using DuckDB and HuggingFace Datasets.
Provides distributed storage for NAND cell states and memory operations.
"""
import duckdb
import numpy as np
import json
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
import threading
import time
from config import get_hf_token_cached
# Initialize token from .env
@dataclass
class CellState:
    """Snapshot of one simulated NAND cell's physical and logical state.

    Mirrors one row of the ``cell_states`` table in RemoteStorageManager;
    ``quantum_state`` is serialized to JSON when persisted.
    """
    cell_id: str                 # unique cell identifier (primary key in cell_states)
    block_id: int                # NAND block this cell belongs to
    page_id: int                 # page within the block
    value: int                   # stored logical value
    trapped_electrons: int       # simulated trapped-charge count
    wear_count: int              # program/erase wear counter
    retention_loss: float        # accumulated retention degradation
    temperature: float           # cell temperature reading
    voltage_level: float         # current voltage level
    quantum_state: List[float]   # per-cell state vector; stored as a JSON array
    timestamp: float             # epoch seconds of last update (time.time())
class RemoteStorageManager:
    """Manages remote storage operations using DuckDB and HuggingFace datasets.

    Implemented as a thread-safe singleton: the first instantiation opens the
    DuckDB connection, configures httpfs/HF auth and creates the schema; every
    later instantiation returns the same object (its ``dataset_path`` argument
    is silently ignored).

    All public methods re-validate the connection and serialize DB access
    through the class-level lock.
    """

    _instance = None            # the singleton instance
    _lock = threading.Lock()    # guards singleton creation and all DB access

    def __new__(cls, dataset_path: str = "hf://datasets/Fred808/helium/storage.json"):
        # Locked check-then-create so concurrent first calls cannot race.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._init_storage(dataset_path)
        return cls._instance

    def _configure_connection(self):
        """(Re)open the DuckDB connection and apply httpfs / HF auth settings.

        Shared by ``_init_storage`` and ``ensure_connection`` (the original
        duplicated this setup verbatim in both places).
        """
        self.con = duckdb.connect(self.dataset_path)
        # Initialize DuckDB with httpfs support
        self.con.execute("INSTALL httpfs;")
        self.con.execute("LOAD httpfs;")
        # Configure S3-style access for HuggingFace
        self.con.execute("SET s3_endpoint='hf.co';")
        self.con.execute("SET s3_use_ssl=true;")
        self.con.execute("SET s3_url_style='path';")
        # DuckDB SET statements cannot take prepared-statement parameters, so
        # the token is interpolated; escape single quotes so a malformed token
        # cannot break out of the string literal.
        token = self.HF_TOKEN.replace("'", "''")
        self.con.execute(f"SET s3_access_key_id='{token}';")
        self.con.execute(f"SET s3_secret_access_key='{token}';")

    def _init_storage(self, dataset_path: str):
        """Initialize direct connection to remote storage and create the schema."""
        self.dataset_path = dataset_path
        # Fix: self.HF_TOKEN was referenced but never assigned anywhere, so the
        # first connection attempt raised AttributeError. get_hf_token_cached
        # was imported for exactly this purpose but never called.
        self.HF_TOKEN = get_hf_token_cached()
        self._configure_connection()
        # Create tables for cell states
        self.con.execute("""
            CREATE TABLE IF NOT EXISTS cell_states (
                cell_id VARCHAR,
                block_id INTEGER,
                page_id INTEGER,
                value INTEGER,
                trapped_electrons INTEGER,
                wear_count INTEGER,
                retention_loss DOUBLE,
                temperature DOUBLE,
                voltage_level DOUBLE,
                quantum_state JSON,
                timestamp DOUBLE,
                PRIMARY KEY (cell_id)
            )
        """)
        # Create tables for PCIe interface
        self.con.execute("""
            CREATE TABLE IF NOT EXISTS pcie_interface_state (
                id VARCHAR PRIMARY KEY,
                version VARCHAR,
                lanes INTEGER,
                max_gbps DOUBLE,
                active_lanes INTEGER,
                lane_groups JSON,
                lane_errors JSON,
                qos_profiles JSON,
                bandwidth_allocations JSON,
                timestamp TIMESTAMP
            )
        """)
        self.con.execute("""
            CREATE TABLE IF NOT EXISTS pcie_transfers (
                id VARCHAR PRIMARY KEY,
                timestamp TIMESTAMP,
                size_bytes BIGINT,
                direction VARCHAR,
                qos_profile_id INTEGER,
                transfer_time DOUBLE,
                lanes_active INTEGER,
                bandwidth_achieved DOUBLE
            )
        """)
        self.con.execute("""
            CREATE TABLE IF NOT EXISTS dma_operations (
                id VARCHAR PRIMARY KEY,
                timestamp TIMESTAMP,
                source_addr BIGINT,
                dest_addr BIGINT,
                size_bytes BIGINT,
                priority INTEGER,
                completion_time DOUBLE,
                status VARCHAR
            )
        """)
        self.con.execute("""
            CREATE TABLE IF NOT EXISTS qos_metrics (
                timestamp TIMESTAMP,
                profile_id INTEGER,
                bandwidth_allocated DOUBLE,
                bandwidth_used DOUBLE,
                latency_measured DOUBLE,
                latency_target DOUBLE
            )
        """)
        # Create indexes
        self.con.execute("CREATE INDEX IF NOT EXISTS idx_block ON cell_states(block_id)")
        self.con.execute("CREATE INDEX IF NOT EXISTS idx_page ON cell_states(page_id)")
        self.con.execute("CREATE INDEX IF NOT EXISTS idx_pcie_timestamp ON pcie_transfers(timestamp)")
        self.con.execute("CREATE INDEX IF NOT EXISTS idx_dma_timestamp ON dma_operations(timestamp)")
        self.con.execute("CREATE INDEX IF NOT EXISTS idx_qos_profile ON qos_metrics(profile_id)")

    def ensure_connection(self):
        """Ensure connection is active and reconnect if needed."""
        try:
            self.con.execute("SELECT 1")
        except Exception:
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt / SystemExit. Any failure here is treated as a
            # dead connection and triggers a full reconnect.
            self._configure_connection()

    @staticmethod
    def _row_to_cell_state(row) -> CellState:
        """Convert a cell_states result row into a CellState.

        Column order must match the cell_states CREATE TABLE statement.
        (This mapping was duplicated in get_cell_state and get_block_states.)
        """
        return CellState(
            cell_id=row[0],
            block_id=row[1],
            page_id=row[2],
            value=row[3],
            trapped_electrons=row[4],
            wear_count=row[5],
            retention_loss=row[6],
            temperature=row[7],
            voltage_level=row[8],
            quantum_state=json.loads(row[9]),
            timestamp=row[10],
        )

    def store_cell_state(self, state: CellState):
        """Store (insert or overwrite) a cell state in the remote DB."""
        self.ensure_connection()
        with self._lock:
            self.con.execute("""
                INSERT OR REPLACE INTO cell_states
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                state.cell_id,
                state.block_id,
                state.page_id,
                state.value,
                state.trapped_electrons,
                state.wear_count,
                state.retention_loss,
                state.temperature,
                state.voltage_level,
                json.dumps(state.quantum_state),
                state.timestamp
            ))

    def get_cell_state(self, cell_id: str) -> Optional[CellState]:
        """Retrieve one cell state by id, or None if it does not exist."""
        self.ensure_connection()
        with self._lock:
            result = self.con.execute("""
                SELECT * FROM cell_states
                WHERE cell_id = ?
            """, [cell_id]).fetchone()
        if result:
            return self._row_to_cell_state(result)
        return None

    def get_block_states(self, block_id: int) -> List[CellState]:
        """Retrieve all cell states for a block, ordered by page then cell id."""
        self.ensure_connection()
        with self._lock:
            results = self.con.execute("""
                SELECT * FROM cell_states
                WHERE block_id = ?
                ORDER BY page_id, cell_id
            """, [block_id]).fetchall()
        return [self._row_to_cell_state(r) for r in results]

    def update_cell_value(self, cell_id: str, value: int, quantum_state: List[float]):
        """Update a cell's value and quantum state, stamping the current time."""
        self.ensure_connection()
        with self._lock:
            self.con.execute("""
                UPDATE cell_states
                SET value = ?,
                    quantum_state = ?,
                    timestamp = ?
                WHERE cell_id = ?
            """, (value, json.dumps(quantum_state), time.time(), cell_id))

    def get_block_wear_stats(self, block_id: int) -> Dict[str, float]:
        """Get wear statistics for a block.

        NOTE(review): for a block with no cells the aggregate values come back
        as None (cell_count 0) — callers must tolerate that.
        """
        self.ensure_connection()
        with self._lock:
            result = self.con.execute("""
                SELECT
                    AVG(wear_count) as avg_wear,
                    MAX(wear_count) as max_wear,
                    AVG(retention_loss) as avg_retention_loss,
                    COUNT(*) as cell_count
                FROM cell_states
                WHERE block_id = ?
            """, [block_id]).fetchone()
        return {
            'avg_wear': result[0],
            'max_wear': result[1],
            'avg_retention_loss': result[2],
            'cell_count': result[3]
        }

    def cleanup_old_states(self, max_age_hours: float = 24.0):
        """Delete rows older than max_age_hours from all time-stamped tables.

        NOTE(review): cell_states.timestamp is a DOUBLE (epoch seconds) while
        the other tables declare TIMESTAMP columns; comparing all of them
        against a float cutoff relies on implicit casting — verify the
        TIMESTAMP tables are actually populated with epoch floats.
        """
        self.ensure_connection()
        with self._lock:
            cutoff_time = time.time() - (max_age_hours * 3600)
            tables = ['cell_states', 'pcie_transfers', 'dma_operations', 'qos_metrics']
            for table in tables:
                # Table names come from the fixed list above, never from user
                # input, so the f-string here is safe.
                self.con.execute(f"""
                    DELETE FROM {table}
                    WHERE timestamp < ?
                """, [cutoff_time])

    def store_interface_state(self, state: Dict):
        """Store a PCIe interface state snapshot (upsert keyed on 'id')."""
        self.ensure_connection()
        with self._lock:
            self.con.execute("""
                INSERT OR REPLACE INTO pcie_interface_state
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                # Falls back to the current time as a synthetic id.
                state.get('id', str(time.time())),
                state['version'],
                state['lanes'],
                state['max_gbps'],
                state['active_lanes'],
                json.dumps(state['lane_groups']),
                json.dumps(state['lane_errors']),
                json.dumps(state['qos_profiles']),
                json.dumps(state['bandwidth_allocations']),
                state['timestamp']
            ))

    def store_transfer(self, transfer: Dict):
        """Store PCIe transfer details.

        NOTE(review): the row id is str(time.time()); two transfers recorded
        within the same clock tick would collide on the primary key.
        """
        self.ensure_connection()
        with self._lock:
            self.con.execute("""
                INSERT INTO pcie_transfers
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                str(time.time()),
                transfer['timestamp'],
                transfer['size_bytes'],
                transfer['direction'],
                transfer['qos_profile_id'],
                transfer['transfer_time'],
                transfer['lanes_active'],
                transfer['bandwidth_achieved']
            ))

    def store_dma_operation(self, dma: Dict):
        """Store DMA operation details (same time-based id caveat as transfers)."""
        self.ensure_connection()
        with self._lock:
            self.con.execute("""
                INSERT INTO dma_operations
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                str(time.time()),
                dma['timestamp'],
                dma['source_addr'],
                dma['dest_addr'],
                dma['size_bytes'],
                dma['priority'],
                dma['completion_time'],
                dma['status']
            ))

    def store_qos_metrics(self, metrics: Dict):
        """Store one QoS measurement row."""
        self.ensure_connection()
        with self._lock:
            self.con.execute("""
                INSERT INTO qos_metrics
                VALUES (?, ?, ?, ?, ?, ?)
            """, (
                metrics['timestamp'],
                metrics['profile_id'],
                metrics['bandwidth_allocated'],
                metrics['bandwidth_used'],
                metrics['latency_measured'],
                metrics['latency_target']
            ))

    def get_transfer_stats(self, time_window: float = 3600) -> Dict:
        """Get transfer statistics for the last time_window seconds."""
        self.ensure_connection()
        with self._lock:
            cutoff = time.time() - time_window
            result = self.con.execute("""
                SELECT
                    COUNT(*) as total_transfers,
                    SUM(size_bytes) as total_bytes,
                    AVG(transfer_time) as avg_transfer_time,
                    AVG(bandwidth_achieved) as avg_bandwidth
                FROM pcie_transfers
                WHERE timestamp > ?
            """, [cutoff]).fetchone()
        return {
            'total_transfers': result[0],
            'total_bytes': result[1],
            'avg_transfer_time': result[2],
            'avg_bandwidth': result[3]
        }

    def get_dma_stats(self, time_window: float = 3600) -> Dict:
        """Get DMA operation statistics for the last time_window seconds."""
        self.ensure_connection()
        with self._lock:
            cutoff = time.time() - time_window
            result = self.con.execute("""
                SELECT
                    COUNT(*) as total_operations,
                    SUM(size_bytes) as total_bytes,
                    AVG(completion_time) as avg_completion_time,
                    COUNT(CASE WHEN status = 'completed' THEN 1 END) as successful_ops
                FROM dma_operations
                WHERE timestamp > ?
            """, [cutoff]).fetchone()
        return {
            'total_operations': result[0],
            'total_bytes': result[1],
            'avg_completion_time': result[2],
            'successful_ops': result[3]
        }

    def get_qos_compliance(self, time_window: float = 3600) -> Dict:
        """Get per-profile QoS compliance metrics for the last time_window seconds.

        Returns a mapping of profile_id -> {'bandwidth_utilization', 'sla_compliance'}.
        """
        self.ensure_connection()
        with self._lock:
            cutoff = time.time() - time_window
            result = self.con.execute("""
                SELECT
                    profile_id,
                    AVG(bandwidth_used / bandwidth_allocated) as bandwidth_utilization,
                    AVG(CASE WHEN latency_measured <= latency_target THEN 1 ELSE 0 END) as sla_compliance
                FROM qos_metrics
                WHERE timestamp > ?
                GROUP BY profile_id
            """, [cutoff]).fetchall()
        return {
            row[0]: {
                'bandwidth_utilization': row[1],
                'sla_compliance': row[2]
            }
            for row in result
        }