# model-memory-usage-mod / src/model_utils.py
# Uploaded by John6666 ("Upload 10 files", revision b25b2b1, verified)
# Utilities related to loading in and working with models/specific models
from urllib.parse import unquote, urlparse
import gradio as gr
import torch
from accelerate.commands.estimate import check_has_model, create_empty_model, estimate_training_usage
from accelerate.utils import calculate_maximum_sizes, convert_bytes
from huggingface_hub import auth_check
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
# Divisor applied to a byte count to get the footprint at each dtype, computed as
# 32 bits divided by the dtype's bit width. NOTE(review): this assumes the sizes
# reported for the meta-initialised model are float32 sizes — confirm.
DTYPE_MODIFIER = {
    dtype: 32 // bits
    for dtype, bits in (("float32", 32), ("float16/bfloat16", 16), ("int8", 8), ("int4", 4))
}
def extract_from_url(name: str):
    """Return a ``owner/name`` repo id when ``name`` is a URL; otherwise return ``name`` unchanged.

    A string counts as a URL only when it parses with both a scheme and a network
    location. The percent-decoded path is split into non-empty segments. A leading
    ``models``/``datasets``/``spaces`` segment is dropped when at least two segments
    follow it. The first (up to) two segments are then returned joined by ``/``.
    Unparseable input and URLs with an empty path are returned as given.
    """
    try:
        parsed = urlparse(name)
    except Exception:
        return name

    if not (parsed.scheme and parsed.netloc):
        return name

    segments = [segment for segment in unquote(parsed.path).strip("/").split("/") if segment]
    if not segments:
        return name

    # Drop a leading repo-type prefix only when owner and repo follow it.
    if len(segments) >= 3 and segments[0] in {"models", "datasets", "spaces"}:
        segments = segments[1:]

    return "/".join(segments[:2])
def translate_llama(text: str):
    """Map a Llama-2 / CodeLlama repo id to its ``-hf`` counterpart by appending the suffix when absent."""
    return text if text.endswith("-hf") else f"{text}-hf"
def normalize_model_name(model_name: str):
    """Turn user input (a repo id or a Hub URL) into the repo id to load.

    Surrounding whitespace is stripped and URLs are reduced to ``owner/name`` via
    ``extract_from_url``. Trailing slashes are removed *before* the Llama-2 /
    CodeLlama ``-hf`` suffix is considered. Otherwise ``"meta-llama/Llama-2-7b/"``
    would become ``"meta-llama/Llama-2-7b/-hf"`` instead of
    ``"meta-llama/Llama-2-7b-hf"``.
    """
    model_name = extract_from_url(model_name.strip()).rstrip("/")
    if "meta-llama/Llama-2-" in model_name or "meta-llama/CodeLlama-" in model_name:
        model_name = translate_llama(model_name)
    return model_name
def classify_loader_error(model_name: str, error: Exception):
    """Build (but do not raise) a user-facing ``gr.Error`` explaining why ``model_name`` failed.

    The category is chosen by case-insensitive keyword matching on ``str(error)``.
    The first matching rule wins, in this order: timeout, credentials,
    connectivity, missing module/import, remote code, config auto-detection.
    Anything else gets a generic message that embeds the raw error.
    """
    text = str(error).lower()

    def has_any(*needles):
        # True when any keyword occurs in the lower-cased error text.
        return any(needle in text for needle in needles)

    if has_any("timed out", "timeout"):
        detail = (
            f"Model `{model_name}` timed out during the Hub access or static initialization step. "
            "Please try again, try a narrower model repo, or select the library manually."
        )
    elif has_any("401", "403", "unauthorized", "forbidden", "permission"):
        detail = (
            f"Model `{model_name}` could not be accessed with the current credentials. "
            "Please sign in with Hugging Face or paste a token that has access to this repo."
        )
    elif has_any("connection", "temporarily unavailable", "service unavailable"):
        detail = (
            f"Model `{model_name}` could not be reached from this Space right now. "
            "Please retry in a moment."
        )
    elif has_any("no module named", "cannot import name"):
        detail = (
            f"Model `{model_name}` requires custom code or extra dependencies that are not available in this Space. "
            f"This often means the repository depends on a package that is not installed here. Error: `{error}`"
        )
    elif has_any("trust_remote_code", "remote code"):
        detail = (
            f"Model `{model_name}` uses custom code from the Hub and could not be initialized in this Space. "
            f"Please inspect the repository code and make sure it is trusted and compatible with the current runtime. Error: `{error}`"
        )
    elif "config" in text and "auto" in text:
        detail = (
            f"Model `{model_name}` could not be resolved through the current library auto-detection path. "
            f"Please try selecting `transformers` or `timm` manually. Error: `{error}`"
        )
    else:
        detail = (
            f"Model `{model_name}` had an error during static initialization in this Space. "
            f"Please open a discussion on the model page and include this message: `{error}`"
        )
    return gr.Error(detail)
def raise_model_error(model_name: str, error: Exception):
    """Raise the user-facing error that ``classify_loader_error`` builds for ``error``.

    Raises:
        gr.Error: always. ``error`` is attached as the explicit ``__cause__``, so the
            traceback shows the failure the message describes. This matters when the
            call happens while handling a *different* exception, e.g. a failed retry.
    """
    raise classify_loader_error(model_name, error) from error
def preflight_model_access_normalized(normalized_name: str, access_token: str | None):
    """Check Hub access for an already-normalized repo id before trying to load it.

    Returns:
        ``normalized_name`` when ``auth_check`` succeeds. It is also returned when the
        check fails in a way not classified as a timeout, credential or connectivity
        problem, so the loader path still gets a chance.

    Raises:
        gr.Error: for gated repos, missing repos, and failures that
            ``classify_loader_error`` classifies as timeout / credentials / connectivity.
    """
    try:
        auth_check(normalized_name, token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{normalized_name}` is a gated model. Please sign in with Hugging Face or pass an access token that already has access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{normalized_name}` was not found on the Hub. Please try another model name.")
    except gr.Error:
        raise
    except Exception as error:
        classified_error = classify_loader_error(normalized_name, error)
        rendered = str(classified_error).lower()
        # The fatal category messages all read "Model `<name>` <phrase> ...". Match the
        # phrase right after the backticked name rather than anywhere in the text:
        # other categories embed the raw error, which may contain these words by chance.
        lead = f"model `{normalized_name.lower()}` "
        fatal_phrases = ("timed out", "could not be accessed", "could not be reached")
        if any(lead + phrase in rendered for phrase in fatal_phrases):
            raise classified_error from error
        # Otherwise fall back to the loader path for transient Hub metadata issues.
    return normalized_name
def preflight_model_access(model_name: str, access_token: str | None):
    """Normalize raw user input, then run the Hub access preflight on the result."""
    canonical_name = normalize_model_name(model_name)
    return preflight_model_access_normalized(canonical_name, access_token)
def get_model_normalized(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Finds and grabs model from the Hub, and initializes on `meta`.

    Args:
        model_name: repo id, expected to be normalized already (the ``get_model``
            wrapper applies ``normalize_model_name`` first).
        library: library name passed to ``create_empty_model``; ``"auto"`` is turned
            into ``None`` so the library is inferred.
        access_token: Hub token forwarded to the preflight and to ``create_empty_model``.
        skip_auth_check: when True, skip ``preflight_model_access_normalized``.

    Returns:
        The model produced by ``create_empty_model``.

    Raises:
        gr.Error: for every failure path handled below.
    """
    if library == "auto":
        library = None
    if not skip_auth_check:
        preflight_model_access_normalized(model_name, access_token)
    try:
        # NOTE(review): trust_remote_code=True allows repo-supplied custom code to run
        # during initialization — confirm this is acceptable for arbitrary user input.
        model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
    # Handler order matters. GatedRepoError is presumably a subclass of
    # RepositoryNotFoundError in huggingface_hub, so it must come first. Hub errors may
    # also derive from OSError, so both must precede the (RuntimeError, OSError)
    # handler — verify when upgrading huggingface_hub.
    except GatedRepoError:
        raise gr.Error(
            f"Model `{model_name}` is a gated model, please ensure to pass in your access token or sign in with Hugging Face and try again if you have access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{model_name}` was not found on the Hub, please try another model name.")
    except ValueError:
        # Every ValueError is reported as missing library metadata.
        raise gr.Error(
            f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
        )
    except (RuntimeError, OSError) as error:
        # check_has_model returns a library name, or "unknown" when it cannot tell.
        library_name = check_has_model(error)
        if library_name != "unknown":
            raise gr.Error(
                f"Tried to load `{model_name}` with `{library_name}` but a possible model to load was not found inside the repo."
            )
        raise_model_error(model_name, error)
    except ImportError as error:
        # Retry once without remote code. If that also fails, report the ORIGINAL
        # ImportError, not the retry's exception.
        try:
            model = create_empty_model(
                model_name, library_name=library, trust_remote_code=False, access_token=access_token
            )
        except Exception:
            raise_model_error(model_name, error)
    except Exception as error:
        raise_model_error(model_name, error)
    return model
def get_model(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Normalize raw user input (URL / Llama-2 handling) and load it via ``get_model_normalized``."""
    canonical_name = normalize_model_name(model_name)
    return get_model_normalized(canonical_name, library, access_token, skip_auth_check=skip_auth_check)
def calculate_memory(model: torch.nn.Module, options: list):
    """Estimate memory for a model initialized on the `meta` device, one row per dtype in ``options``.

    Each row holds the dtype name, the largest layer and total size (raw sizes divided
    by ``DTYPE_MODIFIER[dtype]``, then formatted with ``convert_bytes``), and the
    unconverted result of ``estimate_training_usage``. For ``"float16/bfloat16"``,
    ``"float16"`` is passed as the estimator's precision name.
    """
    total_bytes, largest_layer = calculate_maximum_sizes(model)

    rows = []
    for dtype in options:
        divisor = DTYPE_MODIFIER[dtype]
        precision = "float16" if dtype == "float16/bfloat16" else dtype
        # The training estimate is computed from the undivided total size.
        training_usage = estimate_training_usage(total_bytes, precision)

        total_text = convert_bytes(total_bytes / divisor)
        largest_text = convert_bytes(largest_layer[0] / divisor)

        row = {}
        row["dtype"] = dtype
        row["Largest Layer or Residual Group"] = largest_text
        row["Total Size"] = total_text
        row["Training using Adam (Peak vRAM)"] = training_usage
        rows.append(row)
    return rows