File size: 1,581 Bytes
11f2119 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import re
class ModelDiagnostics:
    """Static helpers for inspecting LLM memory requirements and structure."""

    # Pattern for parameter-count strings: a number plus an optional
    # 'b' (billions) or 'm' (millions) suffix, e.g. '7B', '0.5b', '350M'.
    _PARAM_RE = re.compile(r"\s*(\d+(?:\.\d+)?)\s*([bm]?)\s*\Z")

    @staticmethod
    def estimate_vram(param_str):
        """
        Estimate VRAM usage from a parameter-count string (e.g. '7B', '0.5B', '350M').

        Formula: params (in billions) * bytes-per-precision * 1.2,
        where the 1.2 factor adds 20% overhead for context window/activations.

        Returns:
            dict with formatted GB strings for FP32/FP16/INT8 plus the raw
            'params_in_billions' float, or None if param_str is not a
            recognizable parameter count.
        """
        if not isinstance(param_str, str):
            return None
        # Parse with a strict regex instead of blind .replace() calls so that
        # malformed input yields None rather than a bogus estimate.
        match = ModelDiagnostics._PARAM_RE.match(param_str.lower())
        if match is None:
            return None
        val = float(match.group(1))
        if match.group(2) == 'm':
            val /= 1000.0  # normalize millions to billions

        overhead = 1.2  # 20% overhead for context window/activations

        fp16_gb = val * 2 * overhead  # 2 bytes per param
        int8_gb = val * 1 * overhead  # 1 byte per param
        fp32_gb = val * 4 * overhead  # 4 bytes per param

        return {
            "FP32 (Training/Full)": f"{fp32_gb:.2f} GB",
            "FP16 (Inference)": f"{fp16_gb:.2f} GB",
            "INT8 (Quantized)": f"{int8_gb:.2f} GB",
            "params_in_billions": val
        }

    @staticmethod
    def get_layer_structure(model):
        """
        Return the raw string representation of the PyTorch model modules.

        Returns a placeholder message when no model is loaded (falsy input).
        """
        if model:
            # str() on an nn.Module yields its module tree; we rely on that
            # representation directly rather than walking children ourselves.
            return str(model)
        return "Model not loaded."