Spaces:
Sleeping
Sleeping
File size: 2,883 Bytes
fda8fb3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
from app.core.model_support import ModelSupport, describe_model_support
@dataclass(slots=True)
class ModelBundle:
    """A loaded causal LM together with its tokenizer and placement settings.

    ``load_model_bundle`` in this module constructs instances of this class.
    """

    # Model identifier that was passed to ``from_pretrained``.
    model_name: str
    # The causal-LM module; ``load_model_bundle`` moves it to ``device`` and calls ``eval()``.
    model: PreTrainedModel
    # Tokenizer for ``model_name``; ``load_model_bundle`` sets pad_token = eos_token when none is defined.
    tokenizer: PreTrainedTokenizerBase
    # Device the model was moved to (as chosen by ``resolve_device``).
    device: torch.device
    # dtype requested when loading the weights (as chosen by ``resolve_dtype``).
    dtype: torch.dtype
    # Result of ``describe_model_support(model)``; its fields are defined in app.core.model_support.
    capability: ModelSupport
def resolve_dtype(preference: str, device: torch.device) -> torch.dtype:
if preference == "float32":
return torch.float32
if preference == "float16":
return torch.float16
if preference == "bfloat16":
return torch.bfloat16
if device.type == "cuda":
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
if device.type == "mps":
return torch.float16
return torch.float32
def resolve_device(preference: str = "auto") -> torch.device:
if preference == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("CUDA requested but not available.")
return torch.device("cuda")
if preference == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError("MPS requested but not available.")
return torch.device("mps")
if preference == "cpu":
return torch.device("cpu")
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
@lru_cache(maxsize=2)
def load_model_bundle(
    model_name: str,
    device_preference: str = "auto",
    dtype_preference: str = "auto",
    attn_implementation: str = "eager",
    trust_remote_code: bool = True,
    low_cpu_mem_usage: bool = True,
) -> ModelBundle:
    """Load a causal LM and its tokenizer, placed and typed per the preferences.

    Results are memoised by ``functools.lru_cache`` (at most two entries),
    keyed on the exact argument values. Distinct spellings that resolve to
    the same placement therefore occupy separate cache slots, e.g.
    ``device_preference="auto"`` vs ``"cuda"``.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.
        device_preference: Passed to ``resolve_device``.
        dtype_preference: Passed to ``resolve_dtype`` together with the
            resolved device.
        attn_implementation: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        low_cpu_mem_usage: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A ``ModelBundle`` whose model has been moved to the resolved device
        and switched to eval mode.
    """
    target_device = resolve_device(device_preference)
    target_dtype = resolve_dtype(dtype_preference, target_device)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    # Fall back to EOS when the tokenizer defines no pad token, presumably so
    # that padded/batched encoding works downstream.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): trust_remote_code defaults to True, which lets transformers
    # execute Python code shipped in the model repository; confirm every
    # model_name reaching this function comes from a trusted source.
    load_options = {
        "trust_remote_code": trust_remote_code,
        "attn_implementation": attn_implementation,
        "torch_dtype": target_dtype,
        "low_cpu_mem_usage": low_cpu_mem_usage,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    model.to(target_device)
    model.eval()

    return ModelBundle(
        model_name=model_name,
        model=model,
        tokenizer=tokenizer,
        device=target_device,
        dtype=target_dtype,
        capability=describe_model_support(model),
    )
def compute_attribution_analysis(**kwargs):
    """Forward keyword arguments to the runtime pipeline's attribution analysis.

    The import of ``app.core.runtime_pipeline`` is deferred until call time.
    This is presumably to avoid a circular import between the two modules
    (TODO confirm).
    """
    from app.core.runtime_pipeline import compute_attribution_analysis as pipeline_impl

    return pipeline_impl(**kwargs)
|