""" Model loading and caching for FlashAttention Explorer. Uses real HuggingFace models with SDPA attention implementation. """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer from typing import Tuple, Optional import os from .constants import MODEL_CONFIGS # Global cache to avoid reloading models _model_cache: dict = {} _tokenizer_cache: dict = {} def get_device() -> str: """Get the appropriate device (cuda if available, else cpu).""" return "cuda" if torch.cuda.is_available() else "cpu" def load_model(model_name: str, force_reload: bool = False) -> AutoModelForCausalLM: """ Load a model with caching to avoid redundant downloads. Args: model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M") force_reload: If True, reload even if cached Returns: Loaded model on appropriate device """ if model_name not in MODEL_CONFIGS: raise ValueError(f"Unknown model: {model_name}. Available: {list(MODEL_CONFIGS.keys())}") if model_name in _model_cache and not force_reload: return _model_cache[model_name] config = MODEL_CONFIGS[model_name] model_id = config["model_id"] # Check if we need token for gated models (Llama) token = os.environ.get("HF_TOKEN", None) # Load model with SDPA attention for backend switching model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float16, device_map="auto" if torch.cuda.is_available() else None, attn_implementation="sdpa", # Enable PyTorch SDPA backends token=token, trust_remote_code=True, ) # Move to device if not using device_map if not torch.cuda.is_available(): model = model.to("cpu") model.eval() _model_cache[model_name] = model return model def load_tokenizer(model_name: str) -> AutoTokenizer: """ Load tokenizer with caching. Args: model_name: Key from MODEL_CONFIGS Returns: Loaded tokenizer """ if model_name in _tokenizer_cache: return _tokenizer_cache[model_name] config = MODEL_CONFIGS[model_name] model_id = config["model_id"] token = os.environ.get("HF_TOKEN", None) tokenizer = AutoTokenizer.from_pretrained( model_id, token=token, trust_remote_code=True, ) # Ensure padding token exists if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token _tokenizer_cache[model_name] = tokenizer return tokenizer def load_model_and_tokenizer( model_name: str ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: """ Load both model and tokenizer. Args: model_name: Key from MODEL_CONFIGS Returns: Tuple of (model, tokenizer) """ model = load_model(model_name) tokenizer = load_tokenizer(model_name) return model, tokenizer def get_model_memory_footprint(model_name: str) -> dict: """ Calculate theoretical memory footprint for a model. Args: model_name: Key from MODEL_CONFIGS Returns: Dict with memory breakdown in GB """ config = MODEL_CONFIGS[model_name] # Approximate parameter count # Embedding: vocab_size * hidden_dim # Attention per layer: 4 * hidden_dim^2 (Q, K, V, O projections) # FFN per layer: ~8 * hidden_dim^2 (typical 4x expansion) # LM head: vocab_size * hidden_dim hidden = config["hidden_dim"] layers = config["layers"] vocab = config["vocab_size"] embedding_params = vocab * hidden attention_params = 4 * hidden * hidden * layers ffn_params = 8 * hidden * hidden * layers lm_head_params = vocab * hidden total_params = embedding_params + attention_params + ffn_params + lm_head_params # FP16 = 2 bytes per parameter memory_gb = (total_params * 2) / (1024 ** 3) return { "total_params_millions": total_params / 1e6, "model_memory_gb": memory_gb, "breakdown": { "embeddings_gb": (embedding_params * 2) / (1024 ** 3), "attention_gb": (attention_params * 2) / (1024 ** 3), "ffn_gb": (ffn_params * 2) / (1024 ** 3), "lm_head_gb": (lm_head_params * 2) / (1024 ** 3), } } def calculate_kv_cache_size( model_name: str, seq_len: int, batch_size: int = 1, dtype_bytes: int = 2 # FP16 ) -> dict: """ Calculate KV cache memory for given sequence length. Args: model_name: Key from MODEL_CONFIGS seq_len: Sequence length batch_size: Batch size dtype_bytes: Bytes per element (2 for FP16, 4 for FP32) Returns: Dict with KV cache size information """ config = MODEL_CONFIGS[model_name] layers = config["layers"] kv_heads = config["kv_heads"] head_dim = config["head_dim"] # KV cache size: 2 (K and V) * layers * kv_heads * seq_len * head_dim * batch_size * dtype_bytes kv_cache_bytes = 2 * layers * kv_heads * seq_len * head_dim * batch_size * dtype_bytes kv_cache_gb = kv_cache_bytes / (1024 ** 3) # Calculate what it would be with MHA (all heads have own KV) q_heads = config["q_heads"] mha_cache_bytes = 2 * layers * q_heads * seq_len * head_dim * batch_size * dtype_bytes mha_cache_gb = mha_cache_bytes / (1024 ** 3) return { "gqa_cache_gb": kv_cache_gb, "mha_cache_gb": mha_cache_gb, "savings_ratio": q_heads / kv_heads, "savings_gb": mha_cache_gb - kv_cache_gb, } def clear_model_cache(): """Clear all cached models to free memory.""" global _model_cache, _tokenizer_cache _model_cache.clear() _tokenizer_cache.clear() if torch.cuda.is_available(): torch.cuda.empty_cache() def get_available_models() -> list: """Return list of available model names.""" return list(MODEL_CONFIGS.keys())