# tempo-snn-v2 / src/physics.py
# KD099's picture
# Upload src/physics.py with huggingface_hub
# 5d6dfb7 verified
"""
physics.py
==========
ReRAM / STT-MRAM physics sensor model with Arrhenius-grounded reliability.
Literature-grounded additions:
- RC thermal network (Zhang et al. IEEE Trans. Nanotech 2018)
- Arrhenius retention time (Cheshmikhani & Asadi 2018)
- Temperature-dependent endurance (Zhang et al.)
- Read disturb accumulation (STT-MRAM Testing Survey, arXiv 2020)
"""
import math
import warnings
from dataclasses import dataclass, field
from collections import deque
from typing import Dict, Optional
import numpy as np
@dataclass
class ThermalRCParameters:
    """Parameters of the two-node (junction / case) RC thermal network.

    Values follow the Zhang et al. compact model referenced in the module
    docstring. Resistances are in °C/W and capacitances in J/°C.
    """

    # Thermal resistances (°C/W)
    R_th_jc: float = 2.0    # junction -> case
    R_th_ca: float = 5.0    # case -> ambient

    # Thermal capacitances (J/°C)
    C_th_j: float = 0.01    # junction node
    C_th_c: float = 0.05    # case node

    # Ambient temperature the network relaxes towards (°C, matching the
    # °C-based resistances above).
    T_ambient: float = 25.0
@dataclass
class ArrheniusParameters:
    """Arrhenius coefficients for retention, endurance and read disturb.

    Activation energies are expressed in eV.
    """

    # --- Retention: tau = tau_0 * exp(Ea_ret / (k_B * T)) ---
    tau_0: float = 1e-9     # attempt time (s)
    Ea_ret: float = 0.4     # retention activation energy (eV), typical STT-MRAM

    # --- Endurance: N_end decreases with temperature (exp(+Ea/(kT))) ---
    N_0: float = 1e15       # cycles at the reference temperature
    Ea_end: float = 0.15    # endurance activation energy (eV); lower = less sensitive

    # --- Read disturb: per-read probability, thermally activated ---
    beta_0: float = 1e-12   # disturb probability per read at 300 K
    Ea_read: float = 0.2    # read-disturb activation energy (eV)
k_B_eV = 8.617333e-5 # Boltzmann constant in eV/K
class PhysicsSensorModel:
"""
v3: RC thermal network + Arrhenius retention + temperature-dependent endurance
+ read disturb counter + write-error rate model (Werner/Prejbeanu STT-MRAM)
+ array-level parasitic effects (G-1, G-3).
"""
def __init__(self,
V_th_nominal=0.6,
T_ambient=25.0,
R_HRS=1e6,
R_LRS=1e4,
alpha_drift=0.003,
sigma_0=0.02,
alpha_thermal=0.08,
T_ref=25.0,
max_endurance_base=1e6,
# New: RC thermal + Arrhenius
thermal_params: Optional[ThermalRCParameters] = None,
arrhenius_params: Optional[ArrheniusParameters] = None,
# G-1: Write-error rate model (Werner et al. PRB 2005 / Prejbeanu IEDM 2013)
t_pulse_ns: float = 10.0, # write pulse width
Delta_E0: float = 60.0, # energy barrier at T_ref (k_B*T units)
# G-3: Array-level parasitic effects
R_line_ohm: float = 2.0, # BL/WL line resistance per cell (Ω)
N_cols: int = 512, # crossbar columns
N_rows: int = 512, # crossbar rows
sneak_ratio: float = 0.05, # sneak path current ratio
):
self.V_th_nominal = V_th_nominal
self.T_ambient = T_ambient
self.T_current = T_ambient
# Case temperature for RC network
self.T_case = T_ambient
self.R_HRS = R_HRS
self.R_LRS = R_LRS
self.alpha_drift = alpha_drift
self.sigma_0 = sigma_0
self.alpha_thermal = alpha_thermal
self.T_ref = T_ref
self.max_endurance_base = max_endurance_base
self.cycle_count = 0
self.write_cycles = 0
self.read_cycles = 0 # NEW: read disturb tracking
self.fault_history = deque(maxlen=200)
self.voltage_history = deque(maxlen=200)
self.temp_history = deque(maxlen=200)
self.thermal = thermal_params or ThermalRCParameters(T_ambient=T_ambient)
self.arrhenius = arrhenius_params or ArrheniusParameters()
# G-1: Write-error rate parameters
self.t_pulse_ns = t_pulse_ns
self.Delta_E0 = Delta_E0
# G-3: Array-level parasitic effects
self.R_line_ohm = R_line_ohm
self.N_cols = N_cols
self.N_rows = N_rows
self.sneak_ratio = sneak_ratio
# ---- RC Thermal Network (Zhang et al. compact model) ----
def update_temperature(self, workload_intensity: float,
compute_target: str = "PIM",
dt_s: float = 1e-3) -> float:
"""
Newtonian cooling via 2-node RC network.
P_gen depends on compute_target and workload_intensity.
"""
# Power generation (W) — PIM is higher due to crossbar current
power_rates = {"PIM": 0.3, "CPU": 0.08, "GPU": 0.15}
P_gen = power_rates.get(compute_target, 0.1) * workload_intensity
# Junction temperature update
dT_j = dt_s / self.thermal.C_th_j * (
P_gen - (self.T_current - self.T_case) / self.thermal.R_th_jc
)
# Case temperature update
dT_c = dt_s / self.thermal.C_th_c * (
(self.T_current - self.T_case) / self.thermal.R_th_jc -
(self.T_case - self.thermal.T_ambient) / self.thermal.R_th_ca
)
self.T_current = max(self.thermal.T_ambient, self.T_current + dT_j)
self.T_case = max(self.thermal.T_ambient, self.T_case + dT_c)
self.temp_history.append(self.T_current)
return self.T_current
def get_threshold_voltage(self, deterministic: bool = False) -> float:
dT = self.T_current - self.T_ref
drift = self.alpha_drift * dT
if deterministic:
jitter = 0.0
else:
sigma_sq = self.sigma_0 ** 2 * np.exp(self.alpha_thermal * dT)
jitter = np.random.normal(0.0, np.sqrt(max(sigma_sq, 0.0)))
V_th = self.V_th_nominal + drift + jitter
self.voltage_history.append(V_th)
self.cycle_count += 1
return float(V_th)
def get_fault_density(self) -> float:
"""
Includes:
- Arrhenius temperature acceleration
- Wear factor (write cycles vs temperature-dependent endurance)
- Read disturb accumulation
"""
base_fault_rate = 0.001
dT = self.T_current - self.T_ref
acceleration = np.exp(0.05 * dT)
# Temperature-dependent endurance (Arrhenius)
T_kelvin = self.T_current + 273.15
T_ref_k = self.T_ref + 273.15
N_endurance = (self.max_endurance_base *
math.exp(-self.arrhenius.Ea_end / k_B_eV *
(1.0 / T_kelvin - 1.0 / T_ref_k)))
wear_factor = 1.0 + (self.write_cycles / max(N_endurance, 1.0)) * 5.0
# Read disturb (Arrhenius)
beta_T = (self.arrhenius.beta_0 *
math.exp(self.arrhenius.Ea_read / k_B_eV *
(1.0 / T_ref_k - 1.0 / T_kelvin)))
read_disturb = beta_T * self.read_cycles
fault_density = min(
base_fault_rate * acceleration * wear_factor + read_disturb, 1.0)
self.fault_history.append(fault_density)
return float(fault_density)
def get_retention_time(self) -> float:
"""
Arrhenius retention time.
tau_ret = tau_0 * exp(Ea_ret / (k_B * T))
"""
T_kelvin = self.T_current + 273.15
tau = (self.arrhenius.tau_0 *
math.exp(self.arrhenius.Ea_ret / (k_B_eV * T_kelvin)))
return float(tau)
def get_resistance_ratio(self) -> float:
dT = self.T_current - self.T_ref
R_HRS_T = self.R_HRS * np.exp(-0.01 * dT)
R_LRS_T = self.R_LRS * np.exp(0.005 * dT)
return float(R_HRS_T / R_LRS_T)
def get_read_margin(self) -> float:
return float(np.clip((self.get_resistance_ratio() - 1) / 99, 0, 1))
def get_thermal_reliability(self) -> float:
t_factor = np.clip(1.0 - (self.T_current - self.T_ambient) / 75.0, 0, 1)
if len(self.voltage_history) >= 10:
recent_vth = list(self.voltage_history)[-10:]
vth_std = np.std(recent_vth)
v_factor = np.clip(1.0 - vth_std / 0.1, 0, 1)
else:
v_factor = 0.8
margin_factor = self.get_read_margin()
endurance_factor = np.clip(
1.0 - self.write_cycles / self.get_temperature_dependent_endurance(), 0, 1)
return (0.35 * t_factor + 0.25 * v_factor +
0.25 * margin_factor + 0.15 * endurance_factor)
def get_temperature_dependent_endurance(self) -> float:
"""Arrhenius temperature-dependent endurance (decreases with temperature)."""
T_kelvin = self.T_current + 273.15
T_ref_k = self.T_ref + 273.15
# At higher T, (1/T - 1/T_ref) < 0, giving exp(negative) < 1 → lower endurance
return (self.max_endurance_base *
math.exp(self.arrhenius.Ea_end / k_B_eV *
(1.0 / T_kelvin - 1.0 / T_ref_k)))
def record_write(self, num_writes: int = 1):
self.write_cycles += num_writes
def record_read(self, num_reads: int = 1):
"""NEW: track read disturb accumulation."""
self.read_cycles += num_reads
# ---- G-1: Write-Error Rate Model (Werner et al. PRB 2005 / Prejbeanu IEDM 2013) ----
def get_write_error_rate(self) -> float:
"""
Thermal-activation model for STT-MRAM write errors (Werner/Prejbeanu).
P_error = f(ΔE, T, t_pulse). ΔE ~ 15-40 k_B*T for practical devices.
Higher T → lower barrier → exponentially higher error rate.
Shorter pulse → incomplete switching → higher error.
"""
T_k = self.T_current + 273.15
T_ref_k = self.T_ref + 273.15
# Barrier scales inversely with temperature
delta_E = self.Delta_E0 * (T_ref_k / T_k)
# Short-pulse penalty: shorter than critical ~10 ns → errors rise
t_crit = 10.0
pulse_penalty = 1.0 + max(0.0, (t_crit - self.t_pulse_ns) / t_crit) * 2.0
# Base error rate at T_ref is ~1e-9; scales as exp(-delta_E)
p_error = 1e-6 * np.exp(-delta_E + self.Delta_E0) * pulse_penalty
return float(np.clip(p_error, 1e-12, 0.5))
def get_effective_write_yield(self, n_bits: int = 1_048_576) -> float:
"""Yield = (1 - P_error)^n_bits for an n_bits-wide write."""
per = self.get_write_error_rate()
return float((1.0 - per) ** n_bits)
# ---- G-3: Array-Level Parasitic Effects ----
def get_effective_read_voltage(self, V_applied: float = 0.2,
row_idx: int = 0, col_idx: int = 0) -> float:
"""
Voltage drop across BL/WL line resistance. Worst-case at far corner.
IR_drop = I_cell * R_line * (row + col). Sneak paths add parallel load.
"""
# Select cell resistance (average)
R_cell = (self.R_HRS + self.R_LRS) / 2.0
I_cell = V_applied / R_cell
# IR drop along lines increases with distance from driver
ir_drop = I_cell * self.R_line_ohm * (row_idx + col_idx)
# Sneak path loading: more unselected cells near far corner → more leakage
n_unselected = (self.N_rows - row_idx) * (self.N_cols - col_idx)
sneak_factor = 1.0 / (1.0 + self.sneak_ratio * n_unselected / max(1, self.N_rows + self.N_cols))
V_eff = (V_applied - ir_drop) * sneak_factor
return float(np.clip(V_eff, 0.01, V_applied))
def get_sneak_path_penalty(self) -> float:
"""Returns a fault-density multiplier from sneak path current."""
return 1.0 + self.sneak_ratio * (self.N_rows * self.N_cols) / 262144.0
def snapshot(self, deterministic: bool = True) -> Dict[str, float]:
snap = {
"temperature_c": self.T_current,
"temperature_case": self.T_case,
"v_threshold": self.get_threshold_voltage(deterministic=deterministic),
"fault_density": self.get_fault_density(),
"read_margin": self.get_read_margin(),
"reliability": self.get_thermal_reliability(),
"retention_time_s": self.get_retention_time(),
"endurance_remaining": self.get_temperature_dependent_endurance() - self.write_cycles,
# G-1
"write_error_rate": self.get_write_error_rate(),
"write_yield_1Mbit": self.get_effective_write_yield(),
# G-3
"effective_read_v": self.get_effective_read_voltage(),
"sneak_penalty": self.get_sneak_path_penalty(),
}
return snap