| |
| from urllib.parse import unquote, urlparse |
|
|
| import gradio as gr |
| import torch |
| from accelerate.commands.estimate import check_has_model, create_empty_model, estimate_training_usage |
| from accelerate.utils import calculate_maximum_sizes, convert_bytes |
| from huggingface_hub import auth_check |
| from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError |
|
|
|
|
# Divisor applied to the float32 byte count for each selectable dtype:
# float32 is the baseline (1), half precision halves the footprint (2),
# int8 quarters it (4) and int4 is an eighth (8).
DTYPE_MODIFIER = {"float32": 1, "float16/bfloat16": 2, "int8": 4, "int4": 8}
|
|
|
|
def extract_from_url(name: str):
    """If `name` is a full URL (e.g. a huggingface.co model page), reduce it to
    a `namespace/repo` model id; otherwise return `name` unchanged."""
    try:
        parsed = urlparse(name)
        looks_like_url = bool(parsed.scheme) and bool(parsed.netloc)
    except Exception:
        looks_like_url = False

    if not looks_like_url:
        return name

    trimmed_path = unquote(parsed.path).strip("/")
    if not trimmed_path:
        # URL with no path component — nothing to extract.
        return name

    segments = [segment for segment in trimmed_path.split("/") if segment]
    # Drop a leading Hub section prefix ("models/org/name" -> "org/name").
    if len(segments) >= 3 and segments[0] in {"models", "datasets", "spaces"}:
        segments = segments[1:]

    # Keep only "namespace/repo", discarding trailing parts like "tree/main".
    return "/".join(segments[:2]) if len(segments) >= 2 else "/".join(segments)
|
|
|
|
def translate_llama(text: str):
    """Map a Llama-2 / CodeLlama repo id to its `-hf` converted counterpart,
    leaving ids that already carry the suffix untouched."""
    return text if text.endswith("-hf") else text + "-hf"
|
|
|
|
def normalize_model_name(model_name: str):
    """Normalize user input into a canonical `namespace/repo` model id.

    Steps: trim whitespace, reduce a pasted Hub URL to a model id, drop any
    trailing slashes, then map Llama-2/CodeLlama repos to their `-hf`
    counterparts.

    Bug fix: trailing slashes are now stripped BEFORE the Llama translation.
    Previously an input like `meta-llama/Llama-2-7b/` became
    `meta-llama/Llama-2-7b/-hf` — the `-hf` suffix was appended after the
    slash, and the final `rstrip("/")` could no longer remove it.
    """
    model_name = extract_from_url(model_name.strip()).rstrip("/")
    if "meta-llama/Llama-2-" in model_name or "meta-llama/CodeLlama-" in model_name:
        model_name = translate_llama(model_name)
    return model_name
|
|
|
|
def classify_loader_error(model_name: str, error: Exception):
    """Translate a raw loader exception into a user-facing `gr.Error`.

    The error text is matched case-insensitively against known failure
    signatures, from most transient (timeouts) to most specific (auto-config
    resolution); anything unmatched falls through to a generic message that
    asks the user to open a discussion on the model page.
    """
    haystack = str(error).lower()

    def mentions(*fragments):
        # True when any fragment occurs in the lowercased error text.
        return any(fragment in haystack for fragment in fragments)

    if mentions("timed out", "timeout"):
        return gr.Error(
            f"Model `{model_name}` timed out during the Hub access or static initialization step. "
            "Please try again, try a narrower model repo, or select the library manually."
        )

    if mentions("401", "403", "unauthorized", "forbidden", "permission"):
        return gr.Error(
            f"Model `{model_name}` could not be accessed with the current credentials. "
            "Please sign in with Hugging Face or paste a token that has access to this repo."
        )

    if mentions("connection", "temporarily unavailable", "service unavailable"):
        return gr.Error(
            f"Model `{model_name}` could not be reached from this Space right now. "
            "Please retry in a moment."
        )

    if mentions("no module named", "cannot import name"):
        return gr.Error(
            f"Model `{model_name}` requires custom code or extra dependencies that are not available in this Space. "
            f"This often means the repository depends on a package that is not installed here. Error: `{error}`"
        )

    if mentions("trust_remote_code", "remote code"):
        return gr.Error(
            f"Model `{model_name}` uses custom code from the Hub and could not be initialized in this Space. "
            f"Please inspect the repository code and make sure it is trusted and compatible with the current runtime. Error: `{error}`"
        )

    if "config" in haystack and "auto" in haystack:
        return gr.Error(
            f"Model `{model_name}` could not be resolved through the current library auto-detection path. "
            f"Please try selecting `transformers` or `timm` manually. Error: `{error}`"
        )

    return gr.Error(
        f"Model `{model_name}` had an error during static initialization in this Space. "
        f"Please open a discussion on the model page and include this message: `{error}`"
    )
|
|
|
|
def raise_model_error(model_name: str, error: Exception):
    """Classify `error` for `model_name` and raise the resulting `gr.Error`."""
    classified = classify_loader_error(model_name, error)
    raise classified
|
|
|
|
def preflight_model_access_normalized(normalized_name: str, access_token: str | None):
    """Verify Hub access to an already-normalized repo id before loading.

    Gated and missing repos raise a user-facing `gr.Error` immediately. Any
    other failure is classified; only transient or credential problems
    (timeout / access / connectivity) are surfaced here — everything else is
    deliberately swallowed so the actual model loader can produce a more
    precise error later. Returns `normalized_name` on success.
    """
    try:
        auth_check(normalized_name, token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{normalized_name}` is a gated model. Please sign in with Hugging Face or pass an access token that already has access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{normalized_name}` was not found on the Hub. Please try another model name.")
    except gr.Error:
        raise
    except Exception as error:
        classified = classify_loader_error(normalized_name, error)
        summary = str(classified).lower()
        blocking_markers = ("timed out", "could not be accessed", "could not be reached")
        if any(marker in summary for marker in blocking_markers):
            raise classified

    return normalized_name
|
|
|
|
def preflight_model_access(model_name: str, access_token: str | None):
    """Normalize `model_name`, then run the Hub access preflight on it."""
    normalized = normalize_model_name(model_name)
    return preflight_model_access_normalized(normalized, access_token)
|
|
|
|
def get_model_normalized(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Finds and grabs model from the Hub, and initializes on `meta`.

    Args:
        model_name: Already-normalized repo id (see `normalize_model_name`).
        library: Library to load with; `"auto"` delegates detection to the Hub
            metadata (passed through as `None`).
        access_token: Optional Hugging Face token for gated/private repos.
        skip_auth_check: When True, skip the access preflight (useful when the
            caller has already run it).

    Returns:
        The model instantiated on the `meta` device (no weights downloaded).

    Raises:
        gr.Error: for every failure mode, with a user-facing message.
    """
    # The UI's "auto" choice maps to None so create_empty_model resolves the
    # library from the repo's Hub metadata.
    if library == "auto":
        library = None


    if not skip_auth_check:
        preflight_model_access_normalized(model_name, access_token)


    try:
        model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{model_name}` is a gated model, please ensure to pass in your access token or sign in with Hugging Face and try again if you have access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{model_name}` was not found on the Hub, please try another model name.")
    except ValueError:
        # create_empty_model raises ValueError when the repo carries no
        # library metadata to auto-detect from.
        raise gr.Error(
            f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
        )
    except (RuntimeError, OSError) as error:
        # check_has_model inspects the error to see whether a known library's
        # weights were expected but missing from the repo.
        library_name = check_has_model(error)
        if library_name != "unknown":
            raise gr.Error(
                f"Tried to load `{model_name}` with `{library_name}` but a possible model to load was not found inside the repo."
            )
        raise_model_error(model_name, error)
    except ImportError as error:
        # Custom remote code needs a dependency this Space doesn't have;
        # retry once without remote code, and if that also fails, surface the
        # ORIGINAL import error (it names the missing package).
        try:
            model = create_empty_model(
                model_name, library_name=library, trust_remote_code=False, access_token=access_token
            )
        except Exception:
            raise_model_error(model_name, error)
    except Exception as error:
        raise_model_error(model_name, error)
    return model
|
|
|
|
def get_model(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Normalize `model_name`, then fetch and meta-initialize the model."""
    normalized = normalize_model_name(model_name)
    return get_model_normalized(normalized, library, access_token, skip_auth_check=skip_auth_check)
|
|
|
|
def calculate_memory(model: torch.nn.Module, options: list):
    """Estimate memory usage per dtype for a model initialized on `meta`.

    For each dtype in `options`, returns a row with the largest single
    layer/residual group, the total parameter size, and a training (Adam,
    peak vRAM) estimate. Sizes are scaled down from the float32 byte counts
    by the dtype's divisor in `DTYPE_MODIFIER`.
    """
    total_size, largest_layer = calculate_maximum_sizes(model)

    rows = []
    for dtype in options:
        divisor = DTYPE_MODIFIER[dtype]
        # The training estimate is computed from the UNSCALED (float32) byte
        # count — scaling happens inside estimate_training_usage based on the
        # precision label it receives.
        training_estimate = estimate_training_usage(
            total_size, "float16" if dtype == "float16/bfloat16" else dtype
        )
        scaled_total = total_size / divisor
        scaled_layer = largest_layer[0] / divisor

        rows.append(
            {
                "dtype": dtype,
                "Largest Layer or Residual Group": convert_bytes(scaled_layer),
                "Total Size": convert_bytes(scaled_total),
                "Training using Adam (Peak vRAM)": training_estimate,
            }
        )
    return rows
|
|