from __future__ import annotations from dataclasses import dataclass from functools import lru_cache import torch from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase from app.core.model_support import ModelSupport, describe_model_support @dataclass(slots=True) class ModelBundle: model_name: str model: PreTrainedModel tokenizer: PreTrainedTokenizerBase device: torch.device dtype: torch.dtype capability: ModelSupport def resolve_dtype(preference: str, device: torch.device) -> torch.dtype: if preference == "float32": return torch.float32 if preference == "float16": return torch.float16 if preference == "bfloat16": return torch.bfloat16 if device.type == "cuda": return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 if device.type == "mps": return torch.float16 return torch.float32 def resolve_device(preference: str = "auto") -> torch.device: if preference == "cuda": if not torch.cuda.is_available(): raise RuntimeError("CUDA requested but not available.") return torch.device("cuda") if preference == "mps": if not torch.backends.mps.is_available(): raise RuntimeError("MPS requested but not available.") return torch.device("mps") if preference == "cpu": return torch.device("cpu") if torch.cuda.is_available(): return torch.device("cuda") if torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") @lru_cache(maxsize=2) def load_model_bundle( model_name: str, device_preference: str = "auto", dtype_preference: str = "auto", attn_implementation: str = "eager", trust_remote_code: bool = True, low_cpu_mem_usage: bool = True, ) -> ModelBundle: device = resolve_device(device_preference) dtype = resolve_dtype(dtype_preference, device) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation, torch_dtype=dtype, low_cpu_mem_usage=low_cpu_mem_usage, ) model.to(device) model.eval() capability = describe_model_support(model) return ModelBundle( model_name=model_name, model=model, tokenizer=tokenizer, device=device, dtype=dtype, capability=capability, ) def compute_attribution_analysis(**kwargs): from app.core.runtime_pipeline import compute_attribution_analysis as _compute_attribution_analysis return _compute_attribution_analysis(**kwargs)