Spaces:
Running
Running
File size: 2,102 Bytes
34d520a 03a744c 34d520a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | from __future__ import annotations
from functools import lru_cache
from typing import Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def torch_dtype(name: str) -> torch.dtype:
    """Translate a dtype name into the matching ``torch.dtype``.

    Matching is case-insensitive. The recognized spellings are
    ``float16``/``fp16``, ``bfloat16``/``bf16`` and ``float32``/``fp32``.

    Args:
        name: Textual dtype name, in any letter case.

    Returns:
        The ``torch.dtype`` registered for ``name``.

    Raises:
        ValueError: If ``name`` is not one of the recognized spellings.
    """
    spellings_by_dtype = (
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
        (torch.float32, ("float32", "fp32")),
    )
    lookup = {
        spelling: dtype
        for dtype, spellings in spellings_by_dtype
        for spelling in spellings
    }
    key = name.lower()
    try:
        resolved = lookup[key]
    except KeyError as exc:
        # Report the caller's original spelling, not the lowercased key.
        raise ValueError(f"Unsupported torch dtype {name!r}") from exc
    return resolved
def resolve_device(requested: str) -> torch.device:
    """Turn a device request string into a concrete ``torch.device``.

    Args:
        requested: Either ``"auto"`` or a device string accepted by
            ``torch.device`` (for example ``"cpu"`` or ``"cuda:0"``).

    Returns:
        For ``"auto"``: the CUDA device when CUDA is available, otherwise the
        CPU device. For anything else: ``torch.device(requested)``.

    Raises:
        RuntimeError: If a ``cuda*`` device is requested while
            ``torch.cuda.is_available()`` is false.
    """
    if requested != "auto":
        wants_cuda = requested.startswith("cuda")
        # Only probe CUDA when the caller explicitly asked for it.
        if wants_cuda and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested but torch.cuda.is_available() is false.")
        return torch.device(requested)

    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
@lru_cache(maxsize=None)
def load_tokenizer(model_id: str) -> Any:
    """Load the tokenizer for ``model_id`` and make sure it has a pad token.

    Results are memoized per ``model_id`` with no size limit, so repeated
    calls with the same id return the same tokenizer object.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer, after ``ensure_pad_token`` has been applied.

    Raises:
        ValueError: Propagated from ``ensure_pad_token`` when the tokenizer
            has neither a pad token id nor an EOS token id.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    # Patch before returning so the cached object is already normalized.
    ensure_pad_token(tok)
    return tok
def load_model_and_tokenizer(model_id: str, *, device: str, dtype: str) -> tuple[torch.nn.Module, Any]:
    """Load a causal language model and its tokenizer, ready for inference.

    The device is resolved first, then the tokenizer is loaded, then the dtype
    name is parsed. On a CPU device a float16 or bfloat16 request is replaced
    by float32. The model is moved to the resolved device, switched to eval
    mode, and a one-line status message is printed.

    Args:
        model_id: Identifier passed to the tokenizer and model loaders.
        device: ``"auto"`` or an explicit device string; see ``resolve_device``.
        dtype: Dtype name understood by ``torch_dtype``.

    Returns:
        A ``(model, tokenizer)`` pair.

    Raises:
        RuntimeError: From ``resolve_device`` when CUDA is requested but
            unavailable.
        ValueError: From ``torch_dtype`` for an unknown dtype name, or from
            the tokenizer loader when no pad token can be derived.
    """
    torch_device = resolve_device(device)
    tokenizer = load_tokenizer(model_id)

    requested_dtype = torch_dtype(dtype)
    half_precision = requested_dtype in (torch.float16, torch.bfloat16)
    # Half-precision requests are promoted to float32 when running on CPU.
    effective_dtype = (
        torch.float32
        if torch_device.type == "cpu" and half_precision
        else requested_dtype
    )

    load_kwargs = {
        "torch_dtype": effective_dtype,
        "low_cpu_mem_usage": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    model.to(torch_device)
    model.eval()

    print(
        f"Loaded {model_id} on {torch_device} "
        f"(requested_device={device}, cuda_available={torch.cuda.is_available()})",
        flush=True,
    )
    return model, tokenizer
def ensure_pad_token(tokenizer: Any) -> None:
    """Give ``tokenizer`` a pad token when it does not already have one.

    If ``tokenizer.pad_token_id`` is already set, nothing is changed.
    Otherwise the EOS token is assigned to ``tokenizer.pad_token``. The
    tokenizer is modified in place.

    Args:
        tokenizer: Object exposing ``pad_token_id``, ``eos_token_id``,
            ``eos_token`` and a writable ``pad_token`` attribute.

    Raises:
        ValueError: If both ``pad_token_id`` and ``eos_token_id`` are ``None``.
    """
    if tokenizer.pad_token_id is not None:
        return
    if tokenizer.eos_token_id is None:
        raise ValueError("Tokenizer has neither pad_token_id nor eos_token_id.")
    # Fall back to EOS as the pad token.
    tokenizer.pad_token = tokenizer.eos_token
|