workbench / models /hf_components.py
GitHub Actions
Initial ZeroGPU deployment with spaces shim
7f9dfed
Raw
History Blame Contribute Delete
1.29 kB
from __future__ import annotations
from importlib import import_module
from models.model_catalog import ModelInfo
def load_tokenizer_and_causal_lm(
    model: ModelInfo,
    trust_remote_code: bool,
    device_map: str,
    torch_dtype: str,
):
    """Load a tokenizer and a causal language model for ``model.hf_id``.

    ``transformers`` is resolved with ``import_module`` when the function
    is called, not when this module is imported.

    Args:
        model: Catalog entry; only its ``hf_id`` attribute is read here.
        trust_remote_code: Forwarded unchanged to both ``from_pretrained``
            calls (tokenizer and model).
        device_map: Forwarded unchanged to the model's ``from_pretrained``.
        torch_dtype: Forwarded unchanged to the model's ``from_pretrained``.

    Returns:
        A ``(model_instance, tokenizer)`` tuple. Note the model comes
        first, which is the reverse of the order in the function name.
    """
    hf = import_module("transformers")
    repo_id = model.hf_id

    # The tokenizer is loaded before the model class is even looked up.
    text_tokenizer = hf.AutoTokenizer.from_pretrained(
        repo_id, trust_remote_code=trust_remote_code
    )

    model_kwargs = {
        "trust_remote_code": trust_remote_code,
        "device_map": device_map,
        "torch_dtype": torch_dtype,
    }
    causal_lm = hf.AutoModelForCausalLM.from_pretrained(repo_id, **model_kwargs)

    return causal_lm, text_tokenizer
def load_processor_and_image_text_model(
    model: ModelInfo,
    trust_remote_code: bool,
    device_map: str,
    torch_dtype: str,
):
    """Load a processor and an image-text-to-text model for ``model.hf_id``.

    ``transformers`` is resolved with ``import_module`` when the function
    is called, not when this module is imported.

    Args:
        model: Catalog entry; only its ``hf_id`` attribute is read here.
        trust_remote_code: Forwarded unchanged to both ``from_pretrained``
            calls (processor and model).
        device_map: Forwarded unchanged to the model's ``from_pretrained``.
        torch_dtype: Forwarded unchanged to the model's ``from_pretrained``.

    Returns:
        A ``(model_instance, processor)`` tuple. Note the model comes
        first, which is the reverse of the order in the function name.
    """
    hf = import_module("transformers")
    repo_id = model.hf_id

    # The processor is loaded before the model class is even looked up.
    input_processor = hf.AutoProcessor.from_pretrained(
        repo_id, trust_remote_code=trust_remote_code
    )

    model_kwargs = {
        "trust_remote_code": trust_remote_code,
        "device_map": device_map,
        "torch_dtype": torch_dtype,
    }
    vision_language_model = hf.AutoModelForImageTextToText.from_pretrained(
        repo_id, **model_kwargs
    )

    return vision_language_model, input_processor