Spaces:
Sleeping
Sleeping
| import argparse | |
| from functools import partial | |
| import gradio as gr | |
| from transformers import AutoConfig | |
| from estimate_train_vram import training_vram_required, inference_vram_required | |
| from vram_helpers import ModelConfig, TrainingConfig, filter_params_for_dataclass, PRECISION_TO_BYTES | |
# Choices offered by both the CLI parser and the gradio UI.
ZERO_STAGES = [0, 1, 2, 3]
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64]
# Materialize the precision names into a list: `.keys()` would store a live,
# non-indexable dict_keys view. A plain list behaves the same as argparse
# `choices` and is also safe to serialize into gradio's component config.
QUANTIZATION = list(PRECISION_TO_BYTES)
OPTIMIZERS = ["adam", "adamw", "adamw_8bit", "sgd"]
# Raw config.json location on the Hub; formatted with a repo id.
HUGGINGFACE_URL_CONFIG = "https://huggingface.co/{}/resolve/main/config.json"
def parse_args():
    """Build the command-line parser for the VRAM estimator.

    Despite its name this returns the parser itself; callers invoke
    ``.parse_args()`` on the result.

    Boolean switches:
      * ``--mixed_precision`` and ``--qlora`` are off by default and are
        turned on by passing the flag. This matches their help text and the
        gradio UI defaults. Previously ``store_false`` inverted them.
      * ``--gradient_checkpointing`` is on by default, as in the gradio UI.
        Pass ``--no_gradient_checkpointing`` to turn it off.
      * ``--train`` is on by default; passing ``--train`` switches to
        inference, as its help text states.

    :return: The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Parser for VRAM estimator")
    parser.add_argument("--repo_id", type=str, default=None, help="HuggingFace repo id to automatically determine model settings")
    parser.add_argument("--model_size", type=float, default=7, help="Model size (in billion parameters)")
    parser.add_argument("--hidden_size", type=int, default=4096, help="Hidden size")
    parser.add_argument("--sequence_length", type=int, default=8192, help="Sequence length")
    parser.add_argument("--num_layers", type=int, default=32, help="Number of layers")
    parser.add_argument("--num_heads", type=int, default=32, help="Number of heads")
    parser.add_argument("--mixed_precision", action="store_true", help="Enable mixed precision for model training")
    parser.add_argument("--precision", type=str, default="bf16", help="Model precision for training")
    parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size (batch size per device/GPU)")
    parser.add_argument("--zero_stage", type=int, default=0, choices=ZERO_STAGES, help="ZeRO optimization stage")
    # On by default. The positive flag is kept so that existing invocations still parse;
    # the negative flag shares its dest and is the way to disable checkpointing.
    parser.add_argument("--gradient_checkpointing", dest="gradient_checkpointing", action="store_true", default=True, help="Enable gradient checkpointing (default: enabled)")
    parser.add_argument("--no_gradient_checkpointing", dest="gradient_checkpointing", action="store_false", help="Disable gradient checkpointing")
    parser.add_argument("--optimizer", type=str, default="adamw", choices=OPTIMIZERS, help="Type of optimizer")
    parser.add_argument("--num_gpus", type=int, default=4, help="Number of GPUs. Necessary for estimating ZeRO stages")
    parser.add_argument("--cache_dir", type=str, default=None, help="HuggingFace cache directory to download config from")
    parser.add_argument("--qlora", action="store_true", help="Enable QLoRA in case of finetuning")
    parser.add_argument("--quantization", type=str, choices=QUANTIZATION, help="Type of quantization. Default is fp16/bf16")
    # store_false on purpose: training is the default and the flag switches to inference.
    parser.add_argument("--train", action="store_false", help="Flag to turn off train and run inference")
    parser.add_argument("--total_sequence_length", type=int, default=0, help="Total sequence length (prompt + output) for inference")
    parser.add_argument("--no-app", action="store_true", help="Skip the gradio app and print the estimate on the command line")
    return parser
def download_config_from_hub(repo_id: str, cache_dir: str):
    """Load a model configuration through ``transformers.AutoConfig``.

    :param repo_id: Repo id handed to ``AutoConfig.from_pretrained`` as
        ``pretrained_model_name_or_path``.
    :param cache_dir: Cache directory forwarded unchanged to
        ``AutoConfig.from_pretrained``. The gradio path passes
        ``args.cache_dir``, which may be ``None``.
    :return: Whatever ``AutoConfig.from_pretrained`` returns. Callers in this
        file use ``.to_dict()`` on it.
    """
    loader_kwargs = {
        "pretrained_model_name_or_path": repo_id,
        "cache_dir": cache_dir,
    }
    return AutoConfig.from_pretrained(**loader_kwargs)
def scrape_config_from_hub(repo_id, timeout=10):
    """Fetch a model's raw ``config.json`` from the HuggingFace Hub over HTTP.

    Unlike ``download_config_from_hub`` this does not go through
    ``transformers``. It issues a plain GET against ``HUGGINGFACE_URL_CONFIG``
    and decodes the JSON body.

    :param repo_id: HuggingFace repo id substituted into the config URL.
    :param timeout: Seconds to wait for the server. Without a timeout
        ``requests`` may block indefinitely.
    :return: The decoded ``config.json`` content.
    :raises requests.exceptions.RequestException: If the request fails
        (a diagnostic line is printed first).
    :raises ValueError: If the response body is not valid JSON
        (a diagnostic line is printed first).
    """
    import requests

    url = HUGGINGFACE_URL_CONFIG.format(repo_id)
    print(f"Fetching config.json from the following URL: {url}...")
    # Every handler re-raises. Previously the error was swallowed and the
    # function then crashed with UnboundLocalError on `return config`.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raises a HTTPError if the status is 4xx, 5xx
    except requests.exceptions.HTTPError as errh:
        print(f"HTTP Error: {errh}")
        raise
    except requests.exceptions.ConnectionError as errc:
        print(f"Error Connecting: {errc}")
        raise
    except requests.exceptions.Timeout as errt:
        print(f"Timeout Error: {errt}")
        raise
    except requests.exceptions.RequestException as err:
        print(f"Something went wrong: {err}")
        raise
    # Decode separately: recent requests versions raise a JSONDecodeError that
    # is both a RequestException and a ValueError, so it would otherwise be
    # reported by the generic "Something went wrong" branch above.
    try:
        config = response.json()
    except ValueError as e:
        print(f"Error decoding JSON: {e}")
        raise
    print(f"Fetched the config for model {repo_id} succesfully!")
    return config
def build_interface(estimate_vram_fn):
    """Build the gradio app for the VRAM estimator.

    :param estimate_vram_fn: Callable that receives one flat dict of UI values.
        Keys are the component labels lower-cased with spaces replaced by
        underscores (e.g. ``"micro_batch_size"``), plus ``"repo_id"`` and
        ``"train"``. The string it returns is shown in the output textbox.
    :return: The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("## 1. Select HuggingFace model from a repository or choose your own model parameters")
        model_option = gr.Radio(["Repo ID", "Model Parameters"], label="Select Input Type")
        repo_id = gr.Textbox(label="Repo ID", visible=False, placeholder="mistralai/Mistral-7B-v0.1")
        with gr.Row(visible=False) as model_params_row:
            model_params = [gr.Slider(label="Model Size", minimum=0.1, maximum=400, step=0.1, value=7, info="Model size (in billion parameters)"),
                            gr.Slider(label="Hidden size", minimum=256, maximum=8192, step=128, value=4096, info="Hidden size"),
                            gr.Slider(label="Sequence length", minimum=128, maximum=128_000, step=256, value=8192, info="Sequence length"),
                            gr.Slider(label="Num layers", minimum=8, maximum=64, step=1, value=32, info="Number of layers"),
                            gr.Slider(label="Num heads", minimum=8, maximum=64, step=1, value=32, info="Number of attention heads")
                            ]

        def update_visibility_model_type(selected_option, choices):
            """Show only the component paired with the selected option.

            :param selected_option: The currently selected radio value.
            :param choices: List of ``(option_value, component)`` pairs. The
                returned updates follow this order, so the event's ``outputs``
                must list the components in the same order.
            :return: One ``gr.update(visible=...)`` per choice.
            """
            return [gr.update(visible=(selected_option == option_value)) for option_value, _ in choices]

        model_option_choices = [("Repo ID", repo_id), ("Model Parameters", model_params_row)]
        model_option.change(
            fn=partial(update_visibility_model_type, choices=model_option_choices),
            inputs=[model_option],
            # Derived from the pairs so the update order can never drift from the outputs.
            outputs=[component for _, component in model_option_choices],
        )

        gr.Markdown("## 2. Select training or inference parameters")
        training_option = gr.Radio(["Training", "Inference"], label="Select Input Type")
        with gr.Row(equal_height=True, visible=False) as training_params_row:
            training_params = [gr.Dropdown(label="Micro batch size", choices=BATCH_SIZES, value=4, info="Micro batch size (batch size per device/GPU)"),
                               gr.Dropdown(label="ZeRO stage", choices=ZERO_STAGES, value=0, info="ZeRO optimization stage"),
                               gr.Dropdown(label="Gradient checkpointing", choices=[True, False], value=True, info="Enable gradient checkpointing"),
                               gr.Dropdown(label="Mixed precision", choices=[False, True], value=False, info="Enable mixed precision for model training"),
                               gr.Dropdown(label="Optimizer", choices=OPTIMIZERS, value="adamw", info="Type of optimizer"),
                               gr.Dropdown(label="QLoRA", choices=[False, True], value=False, info="Finetune with QLoRA enabled"),
                               gr.Slider(label="Num GPUs", minimum=1, maximum=256, step=1, value=4, info="Number of GPUs. Necessary for estimating ZeRO stages"),
                               ]
        with gr.Row(equal_height=True, visible=False) as inference_params_row:
            inference_params = [gr.Dropdown(label="Quantization", choices=QUANTIZATION, value="fp16", info="Quantization of model"),
                                gr.Slider(label="Num GPUs", minimum=1, maximum=256, step=1, value=1, info="Number of GPUs"),
                                gr.Dropdown(label="Micro batch size", choices=BATCH_SIZES, value=1, info="Micro batch size (batch size per device/GPU)"),
                                # NOTE(review): value=0 is below minimum=128. Kept as-is because 0 is
                                # also the CLI default and may be a sentinel downstream — confirm.
                                gr.Slider(label="Total sequence length", minimum=128, maximum=128_000, value=0, info="Total sequence length to run (necessary for KV cache calculation")
                                ]

        training_option_choices = [("Training", training_params_row), ("Inference", inference_params_row)]
        training_option.change(
            fn=partial(update_visibility_model_type, choices=training_option_choices),
            inputs=[training_option],
            outputs=[component for _, component in training_option_choices],
        )

        submit_btn = gr.Button("Estimate!")
        output = gr.Textbox(label="Total estimated VRAM per device/GPU (in GB)")

        def to_key(component):
            # "Micro batch size" -> "micro_batch_size", matching the CLI argument names.
            return component.label.lower().replace(" ", "_")

        def create_combined_params_dict(repo_id, training_option, *values):
            """Flatten the UI values into the dict expected by ``estimate_vram_fn``.

            ``values`` arrive in the order model params, training params,
            inference params (see the click ``inputs`` below).
            """
            train = training_option.lower() == "training"  # False -> inference
            n_model, n_train = len(model_params), len(training_params)
            model_pairs = list(zip(model_params, values[:n_model]))
            train_pairs = list(zip(training_params, values[n_model:n_model + n_train]))
            infer_pairs = list(zip(inference_params, values[n_model + n_train:]))
            # "Micro batch size" and "Num GPUs" exist in both rows and map to the
            # same key. Insert the inactive row first so the row the user actually
            # selected wins. Previously the inference row always won, so training
            # estimates silently used the inference batch size and GPU count.
            active, inactive = (train_pairs, infer_pairs) if train else (infer_pairs, train_pairs)
            combined_dict = {to_key(component): value for component, value in model_pairs + inactive + active}
            combined_dict["repo_id"] = repo_id
            combined_dict["train"] = train
            return combined_dict

        def on_submit(model_choice, repo_id_value, training_choice, *values):
            if model_choice is None:
                return "No model selected!"
            if training_choice is None:
                # Previously this crashed on None.lower().
                return "No training or inference option selected!"
            if model_choice == "Model Parameters":
                # Ignore any stale text left in the now-hidden Repo ID box so the
                # manual parameters are what gets estimated.
                repo_id_value = ""
            return estimate_vram_fn(create_combined_params_dict(repo_id_value, training_choice, *values))

        submit_btn.click(
            fn=on_submit,
            inputs=[model_option, repo_id, training_option] + model_params + training_params + inference_params,
            outputs=[output]
        )
    return app
def estimate_vram(cache_dir, gradio_params):
    """Estimate per-device VRAM for the parameters collected by the gradio UI.

    :param cache_dir: Cache directory forwarded to ``download_config_from_hub``.
    :param gradio_params: Flat dict of UI values keyed like the CLI arguments,
        plus ``"repo_id"`` and ``"train"``.
    :return: Human-readable breakdown of the estimate.
    """
    model_config = ModelConfig(**filter_params_for_dataclass(ModelConfig, gradio_params))
    training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, gradio_params))
    # With a repo id, the Hub config overrides the manual model parameters, as in
    # the CLI. Without one, the manual parameters already in model_config are used.
    # Previously this returned "No model selected!", which made the UI's
    # "Model Parameters" option unusable.
    repo_id = gradio_params.get("repo_id")
    if repo_id:
        config = download_config_from_hub(repo_id, cache_dir)
        model_config.overwrite_with_hf_config(config.to_dict())
    if training_config.train:
        total_vram_dict = training_vram_required(model_config, training_config)
        output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['gradients']}GB (gradients) + {total_vram_dict['optimizer']}GB (optimizer) + {total_vram_dict['activations']}GB (activations)"
    else:  # inference
        total_vram_dict = inference_vram_required(model_config, training_config)
        output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['kv_cache']}GB (KV cache) + {total_vram_dict['activations']}GB (activations)"
    return output_str
if __name__ == "__main__":
    args = parse_args().parse_args()

    if args.no_app:
        # Command line interface: compute one estimate and print it.
        cli_params = vars(args)
        model_config = ModelConfig(**filter_params_for_dataclass(ModelConfig, cli_params))
        training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, cli_params))

        if args.repo_id:
            # A cache directory routes through transformers' AutoConfig;
            # otherwise config.json is fetched directly over HTTP.
            if args.cache_dir:
                hf_config = download_config_from_hub(args.repo_id, args.cache_dir).to_dict()
            else:
                hf_config = scrape_config_from_hub(args.repo_id)
            model_config.overwrite_with_hf_config(hf_config)

        if not training_config.train:  # inference
            total_vram_dict = inference_vram_required(model_config, training_config)
            output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['kv_cache']}GB (KV cache) + {total_vram_dict['activations']}GB activations"
        else:
            total_vram_dict = training_vram_required(model_config, training_config)
            output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['gradients']}GB (gradients) + {total_vram_dict['optimizer']}GB (optimizer) + {total_vram_dict['activations']}GB activations"
        print(output_str)
    else:
        # Gradio interface; `gr` is already imported at module level.
        interface = build_interface(partial(estimate_vram, args.cache_dir))
        interface.launch()