from typing import Optional

import torch
from huggingface_hub import model_info
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

from accelerate import init_empty_weights
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import (
    calculate_maximum_sizes,
    convert_bytes,
    is_timm_available,
    is_transformers_available,
)


if is_transformers_available():
    import transformers
    from transformers import AutoConfig, AutoModel

if is_timm_available():
    import timm


def verify_on_hub(repo: str, token: Optional[str] = None):
    "Verifies that the model is on the Hub and returns its model info, or `'gated'`/`'repo'` if it is inaccessible."
    try:
        return model_info(repo, token=token)
    except (OSError, GatedRepoError):
        return "gated"
    except RepositoryNotFoundError:
        return "repo"
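
# A minimal usage sketch (not executed on import; `bert-base-cased` is an arbitrary
# public repo used purely for illustration):
#
#     info = verify_on_hub("bert-base-cased")
#     if info == "gated":
#         ...  # authenticate via `huggingface-cli login`
#     elif info == "repo":
#         ...  # the repository does not exist on the Hub
#     else:
#         print(info.library_name)  # e.g. "transformers"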


def check_has_model(error):
    """
    Checks which library `error` came from when a model could not be loaded, returning `"timm"`, `"transformers"`,
    or `"unknown"`.
    """
    if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in error.args[0]:
        return "timm"
    elif (
        is_transformers_available()
        and isinstance(error, OSError)
        and "does not appear to have a file named" in error.args[0]
    ):
        return "transformers"
    else:
        return "unknown"
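
# A sketch of how this pairs with the loaders below (assumes `timm` is installed;
# the architecture name is deliberately invalid):
#
#     try:
#         timm.create_model("not-a-real-architecture", pretrained=False)
#     except RuntimeError as e:
#         assert check_has_model(e) == "timm"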


def create_empty_model(
    model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None
):
    """
    Creates an empty model in full precision from its parent library on the Hub to calculate the overall memory
    consumption.

    Args:
        model_name (`str`):
            The model name on the Hub.
        library_name (`str`):
            The library the model has an integration with, such as `transformers`. Used to determine the library if
            `model_name` has no metadata on the Hub.
        trust_remote_code (`bool`, `optional`, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        access_token (`str`, `optional`, defaults to `None`):
            The access token to use to access private or gated models on the Hub (for use in the Gradio app).

    Returns:
        `torch.nn.Module`: The torch model that has been initialized on the `meta` device.
    """
    model_info = verify_on_hub(model_name, access_token)
    if model_info == "gated":
        raise GatedRepoError(
            f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`."
        )
    elif model_info == "repo":
        raise RepositoryNotFoundError(
            f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo,"
            " make sure you are authenticated via `huggingface-cli login` and have access."
        )
    if library_name is None:
        library_name = getattr(model_info, "library_name", False)
        if not library_name:
            raise ValueError(
                f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)"
            )
    if library_name == "transformers":
        if not is_transformers_available():
            raise ImportError(
                f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`"
            )
        print(f"Loading pretrained config for `{model_name}` from `transformers`...")
        if model_info.config is None:
            raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.")

        auto_map = model_info.config.get("auto_map", False)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)
        with init_empty_weights():
            # Remote code may declare a task-specific class in the `auto_map`; prefer it over the bare `AutoModel`
            constructor = AutoModel
            if isinstance(auto_map, dict):
                value = None
                for key in auto_map.keys():
                    if key.startswith("AutoModelFor"):
                        value = key
                        break
                if value is not None:
                    constructor = getattr(transformers, value)
            model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
    elif library_name == "timm":
        if not is_timm_available():
            raise ImportError(
                f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`"
            )
        print(f"Loading pretrained config for `{model_name}` from `timm`...")
        with init_empty_weights():
            model = timm.create_model(model_name, pretrained=False)
    else:
        raise ValueError(
            f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support."
        )
    return model
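
# A usage sketch (requires network access; `bert-base-cased` is an arbitrary example
# repo). The result lives on the `meta` device: it allocates no real memory and cannot
# run a forward pass, but its parameter shapes can be inspected and measured:
#
#     model = create_empty_model("bert-base-cased", library_name="transformers")
#     print(sum(p.numel() for p in model.parameters()))  # roughly 1.1e8 parameters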


def create_ascii_table(headers: list, rows: list, title: str):
    "Creates a pretty table from a list of rows, minimal version of `tabulate`."
    sep_char, in_between = "│", "─"
    column_widths = []
    for i in range(len(headers)):
        column_values = [row[i] for row in rows] + [headers[i]]
        max_column_width = max(len(value) for value in column_values)
        column_widths.append(max_column_width)

    formats = [f"%{column_widths[i]}s" for i in range(len(rows[0]))]

    pattern = f"{sep_char}{sep_char.join(formats)}{sep_char}"
    diff = 0

    def make_row(left_char, middle_char, right_char):
        return f"{left_char}{middle_char.join([in_between * n for n in column_widths])}{in_between * diff}{right_char}"

    separator = make_row("├", "┼", "┤")
    if len(title) > sum(column_widths):
        diff = abs(len(title) - len(separator))
        column_widths[-1] += diff

    # Rebuild the separator now that the last column has been widened to fit the title
    separator = make_row("├", "┼", "┤")
    initial_rows = [
        make_row("┌", in_between, "┐"),
        f"{sep_char}{title.center(len(separator) - 2)}{sep_char}",
        make_row("├", "┬", "┤"),
    ]
    table = "\n".join(initial_rows) + "\n"
    column_widths[-1] += diff
    centered_line = [text.center(column_widths[i]) for i, text in enumerate(headers)]
    table += f"{pattern % tuple(centered_line)}\n{separator}\n"
    for i, line in enumerate(rows):
        centered_line = [t.center(column_widths[i]) for i, t in enumerate(line)]
        table += f"{pattern % tuple(centered_line)}\n"
    table += f"└{'┴'.join([in_between * n for n in column_widths])}┘"

    return table
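
# A small sketch of the output (exact padding may shift by a character):
#
#     print(create_ascii_table(["a", "b"], [["1", "2"]], "Demo"))
#     # ┌─────┐
#     # │ Demo│
#     # ├─┬───┤
#     # │a│ b │
#     # ├─┼───┤
#     # │1│ 2 │
#     # └─┴───┘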


def estimate_command_parser(subparsers=None):
    if subparsers is not None:
        parser = subparsers.add_parser("estimate-memory")
    else:
        parser = CustomArgumentParser(description="Model size estimator for fitting a model onto CUDA memory.")

    parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
    parser.add_argument(
        "--library_name",
        type=str,
        help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.",
        choices=["timm", "transformers"],
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=["float32", "float16", "int8", "int4"],
        help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`",
        choices=["float32", "float16", "int8", "int4"],
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag
        should only be used for repositories you trust and in which you have read the code, as it will execute
        code present on the Hub on your local machine.""",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=estimate_command)
    return parser
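
# Typical command-line invocations once this subparser is registered under the
# `accelerate` CLI (`bert-base-cased` is an arbitrary example repo):
#
#     accelerate estimate-memory bert-base-cased
#     accelerate estimate-memory bert-base-cased --dtypes float32 float16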


def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict:
    """
    Given an amount of `bytes` and `mixed_precision`, calculates how much training memory is needed for a batch size
    of 1.

    Args:
        bytes (`int`):
            The size of the model being trained.
        mixed_precision (`str`):
            The mixed precision that would be run.
        msamp_config (`str`):
            The `msamp` config to estimate the training memory for if `mixed_precision` is set to `"fp8"`.
    """
    memory_sizes = {"model": -1, "optimizer": -1, "gradients": -1, "step": -1}
    fp32_size = bytes
    fp16_size = bytes // 2

    if mixed_precision == "float32":
        memory_sizes["model"] = fp32_size
        memory_sizes["gradients"] = fp32_size
        # Adam keeps two fp32 states (momentum and variance) per parameter
        memory_sizes["optimizer"] = fp32_size * 2
        # Peak memory during the optimizer step is estimated at four times the model size
        memory_sizes["step"] = fp32_size * 4
    elif mixed_precision in ("float16", "bfloat16") or (mixed_precision == "fp8" and msamp_config is None):
        # Mixed precision keeps an fp32 master copy of the weights; count it as the model cost
        memory_sizes["model"] = fp32_size
        # fp32 accumulated gradients on top of the half-precision working gradients (1.5x)
        memory_sizes["gradients"] = fp32_size + fp16_size
        # The two Adam states stay in fp32
        memory_sizes["optimizer"] = fp32_size * 2
        memory_sizes["step"] = memory_sizes["optimizer"]
    return memory_sizes
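
# A worked example (illustrative): a 1 GiB fp32 model trained with "float16" mixed
# precision is estimated at 1 GiB of master weights, 1.5 GiB of gradients, and
# 2 GiB of Adam states, so ~4.5 GiB before activations:
#
#     sizes = estimate_training_usage(1024**3, "float16")
#     # {'model': 1073741824, 'optimizer': 2147483648,
#     #  'gradients': 1610612736, 'step': 2147483648}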


def gather_data(args):
    "Creates an empty model and gathers the data for the sizes."
    try:
        model = create_empty_model(
            args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code
        )
    except (RuntimeError, OSError) as e:
        library = check_has_model(e)
        if library != "unknown":
            raise RuntimeError(
                f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo."
            )
        raise e

    total_size, largest_layer = calculate_maximum_sizes(model)

    data = []

    for dtype in args.dtypes:
        dtype_total_size = total_size
        dtype_largest_layer = largest_layer[0]
        # Training usage is estimated from the full fp32 size; the inference sizes below
        # are then scaled down by the dtype's width relative to fp32
        dtype_training_size = estimate_training_usage(dtype_total_size, dtype)
        if dtype == "float16":
            dtype_total_size /= 2
            dtype_largest_layer /= 2
        elif dtype == "int8":
            dtype_total_size /= 4
            dtype_largest_layer /= 4
        elif dtype == "int4":
            dtype_total_size /= 8
            dtype_largest_layer /= 8
        data.append([dtype, dtype_largest_layer, dtype_total_size, dtype_training_size])
    return data
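
# Each row produced above has the shape (values illustrative):
#
#     ["float16", <largest layer in bytes>, <total size in bytes>,
#      {"model": ..., "optimizer": ..., "gradients": ..., "step": ...}]
#
# Quantized dtypes ("int8"/"int4") fall through `estimate_training_usage` unmatched,
# so their training estimates stay at -1 and render as "N/A" in the final table.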


def estimate_command(args):
    data = gather_data(args)
    for row in data:
        for i, item in enumerate(row):
            # Raw byte counts become human-readable strings; the training dict collapses
            # to its largest entry, or "N/A" when no estimate was produced
            if isinstance(item, (int, float)):
                row[i] = convert_bytes(item)
            elif isinstance(item, dict):
                training_usage = max(item.values())
                row[i] = convert_bytes(training_usage) if training_usage != -1 else "N/A"

    headers = ["dtype", "Largest Layer", "Total Size", "Training using Adam"]

    title = f"Memory Usage for loading `{args.model_name}`"
    table = create_ascii_table(headers, data, title)
    print(table)


def main():
    parser = estimate_command_parser()
    args = parser.parse_args()
    estimate_command(args)


if __name__ == "__main__":
    main()