# a0y0346
# Phase 1: Core structure with model configs and placeholder tabs
# 341bde8
"""
Model loading and caching for FlashAttention Explorer.
Uses real HuggingFace models with SDPA attention implementation.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple, Optional
import os
from .constants import MODEL_CONFIGS
# Global cache to avoid reloading models
_model_cache: dict = {}
_tokenizer_cache: dict = {}
def get_device() -> str:
    """Return "cuda" when torch can see a GPU, otherwise "cpu"."""
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    return device
def load_model(model_name: str, force_reload: bool = False) -> AutoModelForCausalLM:
    """
    Load a causal-LM, caching it so repeated calls skip the download.

    Args:
        model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M")
        force_reload: If True, reload even if cached

    Returns:
        Loaded model, in eval mode, on the appropriate device

    Raises:
        ValueError: If model_name is not a key of MODEL_CONFIGS.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MODEL_CONFIGS.keys())}")

    cached = _model_cache.get(model_name)
    if cached is not None and not force_reload:
        return cached

    model_id = MODEL_CONFIGS[model_name]["model_id"]
    # Gated repos (e.g. Llama) need an auth token; None is fine for open models.
    hf_token = os.environ.get("HF_TOKEN", None)
    use_cuda = torch.cuda.is_available()

    # attn_implementation="sdpa" enables switching between PyTorch SDPA
    # backends (flash / mem-efficient / math) at runtime.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto" if use_cuda else None,
        attn_implementation="sdpa",
        token=hf_token,
        trust_remote_code=True,
    )

    if not use_cuda:
        # No device_map was used above, so place the model explicitly.
        model = model.to("cpu")
    model.eval()

    _model_cache[model_name] = model
    return model
def load_tokenizer(model_name: str, force_reload: bool = False) -> AutoTokenizer:
    """
    Load a tokenizer with caching, mirroring load_model's behavior.

    Args:
        model_name: Key from MODEL_CONFIGS
        force_reload: If True, reload even if cached. New parameter with a
            default, so existing callers are unaffected; added for
            consistency with load_model.

    Returns:
        Loaded tokenizer with a pad token guaranteed to exist

    Raises:
        ValueError: If model_name is not a key of MODEL_CONFIGS
            (same fail-fast message as load_model, instead of a bare
            KeyError from the config lookup).
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MODEL_CONFIGS.keys())}")
    if model_name in _tokenizer_cache and not force_reload:
        return _tokenizer_cache[model_name]
    config = MODEL_CONFIGS[model_name]
    model_id = config["model_id"]
    # Gated repos (e.g. Llama) need an auth token; None is fine for open models.
    token = os.environ.get("HF_TOKEN", None)
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=token,
        trust_remote_code=True,
    )
    # Some causal-LM tokenizers ship without a pad token; fall back to EOS
    # so batched encoding with padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    _tokenizer_cache[model_name] = tokenizer
    return tokenizer
def load_model_and_tokenizer(
    model_name: str
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Convenience wrapper: fetch the (cached) model and tokenizer together.

    Args:
        model_name: Key from MODEL_CONFIGS

    Returns:
        Tuple of (model, tokenizer)
    """
    return load_model(model_name), load_tokenizer(model_name)
def get_model_memory_footprint(model_name: str) -> dict:
    """
    Estimate the FP16 memory footprint of a model from its config.

    The estimate counts embeddings (vocab * hidden), attention projections
    (Q/K/V/O, i.e. 4 * hidden^2 per layer), FFN weights (8 * hidden^2 per
    layer, assuming the typical 4x expansion), and the LM head
    (vocab * hidden). Norms and biases are ignored as negligible.

    Args:
        model_name: Key from MODEL_CONFIGS

    Returns:
        Dict with total parameter count (millions), total memory (GB),
        and a per-component breakdown in GB.
    """
    config = MODEL_CONFIGS[model_name]
    hidden = config["hidden_dim"]
    layers = config["layers"]
    vocab = config["vocab_size"]

    # Parameter counts per component (see docstring for the assumptions).
    component_params = {
        "embeddings": vocab * hidden,
        "attention": 4 * hidden * hidden * layers,  # Q, K, V, O projections
        "ffn": 8 * hidden * hidden * layers,        # up + down at 4x expansion
        "lm_head": vocab * hidden,
    }
    total_params = sum(component_params.values())

    bytes_per_param = 2  # FP16

    def _gb(num_params: int) -> float:
        return (num_params * bytes_per_param) / (1024 ** 3)

    return {
        "total_params_millions": total_params / 1e6,
        "model_memory_gb": _gb(total_params),
        "breakdown": {
            f"{name}_gb": _gb(params) for name, params in component_params.items()
        },
    }
def calculate_kv_cache_size(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """
    Compute KV-cache memory for a sequence length, plus the MHA equivalent.

    Args:
        model_name: Key from MODEL_CONFIGS
        seq_len: Sequence length
        batch_size: Batch size
        dtype_bytes: Bytes per element (2 for FP16, 4 for FP32)

    Returns:
        Dict with the GQA cache size, the hypothetical MHA cache size
        (every query head owning its own K/V), the head-count savings
        ratio, and the absolute savings — all sizes in GB.
    """
    config = MODEL_CONFIGS[model_name]
    bytes_per_gb = 1024 ** 3

    # Per layer there are two tensors (K and V), each holding
    # heads * seq_len * head_dim * batch_size elements.
    def _cache_gb(num_heads: int) -> float:
        elements = 2 * config["layers"] * num_heads * seq_len * config["head_dim"] * batch_size
        return (elements * dtype_bytes) / bytes_per_gb

    gqa_gb = _cache_gb(config["kv_heads"])
    mha_gb = _cache_gb(config["q_heads"])

    return {
        "gqa_cache_gb": gqa_gb,
        "mha_cache_gb": mha_gb,
        "savings_ratio": config["q_heads"] / config["kv_heads"],
        "savings_gb": mha_gb - gqa_gb,
    }
def clear_model_cache():
    """Drop every cached model and tokenizer, then release CUDA memory."""
    # .clear() mutates the dicts in place, so no `global` declaration is needed.
    _model_cache.clear()
    _tokenizer_cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def get_available_models() -> list:
    """Return the model names usable as keys throughout this module."""
    return [*MODEL_CONFIGS]