Spaces:
Running
Running
| import copy | |
| import hashlib | |
| import json | |
| import tempfile | |
| import threading | |
| import time | |
| import traceback | |
| from collections import OrderedDict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from urllib.parse import urlparse | |
| from uuid import uuid4 | |
| import accelerate | |
| import gradio as gr | |
| import huggingface_hub | |
| try: | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| HAS_HF_HUB_SEARCH = True | |
| except Exception: | |
| HuggingfaceHubSearch = None | |
| HAS_HF_HUB_SEARCH = False | |
| import pandas as pd | |
| import timm | |
| import transformers | |
| from accelerate.utils import convert_bytes | |
| from model_utils import ( | |
| calculate_memory, | |
| get_model_normalized, | |
| normalize_model_name, | |
| preflight_model_access_normalized, | |
| ) | |
# Initial values for the input widgets; the Reset action restores these.
DEFAULT_MODEL = "bert-base-cased"
DEFAULT_LIBRARY = "auto"
DEFAULT_OPTIONS = ["float32"]

# Maximum number of estimates kept in the in-memory result cache.
RESULTS_CACHE_SIZE = 128

# Generated download files older than this many seconds are removed.
DOWNLOAD_RETENTION_SECONDS = 60 * 60
# Number of most recent download files kept once expired ones are removed.
DOWNLOAD_CLEANUP_MAX_FILES = 256
def log_startup_versions():
    """Print one ``[startup] versions ...`` line listing key library versions."""
    versions = {
        "gradio": gr.__version__,
        "accelerate": accelerate.__version__,
        "transformers": transformers.__version__,
        "huggingface_hub": huggingface_hub.__version__,
        "timm": timm.__version__,
    }
    listing = " ".join(f"{name}={version}" for name, version in versions.items())
    print("[startup] versions " + listing)


# Emitted once, when the module is imported.
log_startup_versions()
| class EstimateRequest: | |
| original_model_name: str | |
| normalized_model_name: str | |
| library: str | |
| options: tuple[str, ...] | |
| access_token: str | None | |
| auth_mode: str | |
| def cache_key(self): | |
| token_key = "anonymous" | |
| if self.access_token is not None: | |
| token_key = hashlib.sha256(self.access_token.encode("utf-8")).hexdigest() | |
| return ( | |
| self.normalized_model_name, | |
| self.library, | |
| self.options, | |
| token_key, | |
| ) | |
@dataclass
class EstimatePayload:
    """Formatted estimate data, ready to be turned into a view model.

    Declared as a dataclass because instances are created with keyword
    arguments.
    """

    # Rows for the summary table.
    display_rows: list[dict]
    # Copy of the rows as originally computed (written to the JSON download).
    raw_rows: list[dict]
    # Markdown text shown above the breakdown table; "" when there is none.
    explanation: str
    # Per-precision breakdown table; empty when nothing can be shown.
    breakdown_df: pd.DataFrame
| class EstimateViewModel: | |
| title: str | |
| auth_message: str | |
| summary_df: pd.DataFrame | |
| explanation: str | |
| breakdown_df: pd.DataFrame | |
| error_summary: str = "" | |
| error_details: str = "" | |
| summary_path: str | None = None | |
| breakdown_path: str | None = None | |
| json_path: str | None = None | |
| def to_updates(self): | |
| return [ | |
| self.title, | |
| gr.update(value=self.auth_message, visible=True), | |
| gr.update(visible=not self.summary_df.empty, value=self.summary_df), | |
| gr.update(visible=self.explanation != "", value=self.explanation), | |
| gr.update(visible=not self.breakdown_df.empty, value=self.breakdown_df), | |
| gr.update(visible=self.error_summary != "", value=self.error_summary), | |
| gr.update(visible=self.error_details != "", value=self.error_details), | |
| gr.update(visible=self.summary_path is not None, value=self.summary_path), | |
| gr.update(visible=self.breakdown_path is not None, value=self.breakdown_path), | |
| gr.update(visible=self.json_path is not None, value=self.json_path), | |
| ] | |
@dataclass
class ResetViewModel:
    """Widget values that put the app back into its initial state.

    Declared as a dataclass so that ``ResetViewModel()`` picks up the field
    defaults and ``__post_init__`` actually runs.
    """

    model_name: str = DEFAULT_MODEL
    library: str = DEFAULT_LIBRARY
    # None means "use a fresh copy of DEFAULT_OPTIONS" (see __post_init__).
    options: list[str] | tuple[str, ...] | None = None
    access_token: str = ""
    title: str = ""

    def __post_init__(self):
        """Fill in the default precisions without sharing the module-level list."""
        if self.options is None:
            self.options = list(DEFAULT_OPTIONS)

    def to_updates(self):
        """Return the fourteen output values for the reset event, in order.

        The first five entries are plain values (model name, library,
        precisions, token, title); the remaining nine hide and clear the
        result, error and download widgets.
        """
        return [
            self.model_name,
            self.library,
            list(self.options),
            self.access_token,
            self.title,
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=pd.DataFrame()),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=pd.DataFrame()),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
        ]
| class _InflightEntry: | |
| event: threading.Event | |
| data: list[dict] | None = None | |
| error: Exception | None = None | |
| class ResultCache: | |
| def __init__(self, max_size: int): | |
| self.max_size = max_size | |
| self._values = OrderedDict() | |
| self._lock = threading.Lock() | |
| self._inflight: dict[tuple, _InflightEntry] = {} | |
| def get_or_compute(self, request: EstimateRequest, compute_fn): | |
| cache_key = request.cache_key | |
| with self._lock: | |
| if cache_key in self._values: | |
| self._values.move_to_end(cache_key) | |
| return copy.deepcopy(self._values[cache_key]) | |
| entry = self._inflight.get(cache_key) | |
| if entry is None: | |
| entry = _InflightEntry(event=threading.Event()) | |
| self._inflight[cache_key] = entry | |
| is_owner = True | |
| else: | |
| is_owner = False | |
| if not is_owner: | |
| entry.event.wait() | |
| if entry.error is not None: | |
| raise entry.error | |
| return copy.deepcopy(entry.data) | |
| try: | |
| data = compute_fn() | |
| with self._lock: | |
| self._values[cache_key] = copy.deepcopy(data) | |
| if len(self._values) > self.max_size: | |
| self._values.popitem(last=False) | |
| entry.data = copy.deepcopy(data) | |
| return copy.deepcopy(data) | |
| except Exception as error: | |
| entry.error = error | |
| raise | |
| finally: | |
| entry.event.set() | |
| with self._lock: | |
| self._inflight.pop(cache_key, None) | |
# Module-level cache instance used by fetch_raw_estimate_data.
RESULT_CACHE = ResultCache(RESULTS_CACHE_SIZE)
def get_auth_status(oauth_profile: gr.OAuthProfile | None):
    """Return the Markdown line describing the current sign-in state.

    Args:
        oauth_profile: Profile of the signed-in user, or None when nobody is
            signed in.

    Returns:
        A hint about pasting a token when signed out, otherwise a message
        naming the user. The name falls back from ``preferred_username`` to
        ``name`` to a generic label; missing *and empty* values are skipped
        so the message never shows an empty name.
    """
    if oauth_profile is None:
        return "Not signed in. You can still paste an API token for gated models."

    username = (
        getattr(oauth_profile, "preferred_username", None)
        or getattr(oauth_profile, "name", None)
        or "Hugging Face user"
    )
    return (
        f"Signed in as `{username}`. "
        "If the API Token field is blank, this session token will be used for gated models."
    )
| def use_hub_search(repo_id: str | None): | |
| return (repo_id or "").strip() | |
def get_hub_search_status():
    """Return the help text shown next to the Hub search area.

    The wording depends on whether the optional Hub search component could
    be imported (``HAS_HF_HUB_SEARCH``).
    """
    if not HAS_HF_HUB_SEARCH:
        return "Hub Search component is unavailable in this runtime. Manual model input still works."
    return "Search Hugging Face Hub to fill the model field automatically."
def validate_model_name(model_name: str):
    """Validate the model field and return it with whitespace stripped.

    Plain model names pass through unchanged. Anything that parses as an
    absolute URL (scheme and network location) must point at huggingface.co.

    Args:
        model_name: Model name or Hugging Face model URL; None is treated
            like an empty string.

    Returns:
        The stripped model name / URL.

    Raises:
        gr.Error: If the value is empty, or is a URL on another host.
    """
    stripped_name = (model_name or "").strip()
    if stripped_name == "":
        raise gr.Error("Enter a model name or a Hugging Face model URL.")

    try:
        parsed = urlparse(stripped_name)
        is_url = bool(parsed.scheme and parsed.netloc)
        # ``hostname`` is lower-cased and excludes port and userinfo, so
        # "HuggingFace.co" and "huggingface.co:443" are recognised too.
        host = parsed.hostname or ""
    except ValueError:
        # Not parseable as a URL (e.g. unbalanced brackets): pass the text on
        # unchanged and let later stages reject it if it is not a model name.
        return stripped_name

    valid_hosts = {"huggingface.co", "www.huggingface.co"}
    if is_url and host not in valid_hosts:
        raise gr.Error("Only Hugging Face model URLs are supported here.")
    return stripped_name
def validate_options(options: list):
    """Raise ``gr.Error`` unless at least one precision is selected."""
    if options:
        return
    raise gr.Error("Select at least one precision.")
def validate_access_token(access_token: str):
    """Raise ``gr.Error`` if a non-empty token contains any whitespace."""
    if not access_token:
        return
    for char in access_token:
        if char.isspace():
            raise gr.Error("API tokens should not contain whitespace.")
def resolve_access_token(access_token: str, oauth_token: gr.OAuthToken | None):
    """Choose the token for Hub access and report where it came from.

    A manually entered token wins over the OAuth session token.

    Returns:
        ``(token, auth_mode)`` where ``auth_mode`` is ``"manual"``,
        ``"oauth"`` or ``"anonymous"`` (in which case the token is None).
    """
    if access_token is not None and access_token != "":
        return access_token, "manual"
    if oauth_token is None:
        return None, "anonymous"
    return oauth_token.token, "oauth"
def build_estimate_request(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
):
    """Validate the raw form inputs and bundle them into an EstimateRequest.

    Validation runs in a fixed order (model name, precisions, token), so the
    first failing check determines which ``gr.Error`` is raised.
    """
    cleaned_name = validate_model_name(model_name)
    validate_options(options)
    validate_access_token(access_token)

    canonical_name = normalize_model_name(cleaned_name)
    token, mode = resolve_access_token(access_token, oauth_token)

    return EstimateRequest(
        original_model_name=cleaned_name,
        normalized_model_name=canonical_name,
        library=library,
        # Stored as a tuple so the request can contribute to a cache key.
        options=tuple(options),
        access_token=token,
        auth_mode=mode,
    )
| def get_auth_message(auth_mode: str): | |
| if auth_mode == "manual": | |
| return "Using the manually provided API token for this estimate." | |
| if auth_mode == "oauth": | |
| return "Using your Hugging Face OAuth session for this estimate." | |
| return "Running anonymously. Gated models will require a token or a signed-in Hugging Face session." | |
def get_download_dir():
    """Return the ``model_memory_usage`` folder in the system temp dir.

    The folder is created on demand; calling this repeatedly is harmless.
    """
    download_dir = Path(tempfile.gettempdir(), "model_memory_usage")
    download_dir.mkdir(parents=True, exist_ok=True)
    return download_dir
def cleanup_old_download_files(temp_dir: Path):
    """Best-effort pruning of generated download files in ``temp_dir``.

    Two rules apply, in order:

    1. files last modified more than ``DOWNLOAD_RETENTION_SECONDS`` ago are
       deleted;
    2. of the files that remain, only the ``DOWNLOAD_CLEANUP_MAX_FILES`` most
       recently modified are kept.

    Filesystem errors never propagate: an unreadable directory ends the
    cleanup, and a file that cannot be inspected or deleted (for example
    because another request removed it concurrently) is skipped.
    """
    cutoff = time.time() - DOWNLOAD_RETENTION_SECONDS

    try:
        candidates = list(temp_dir.iterdir())
    except OSError:
        return

    # Each file is stat'ed exactly once, inside a try block, so a file
    # vanishing mid-cleanup cannot abort the whole pass.
    survivors = []
    for path in candidates:
        try:
            if not path.is_file():
                continue
            modified = path.stat().st_mtime
            if modified < cutoff:
                path.unlink(missing_ok=True)
            else:
                survivors.append((modified, path))
        except OSError:
            continue

    # Newest first; everything beyond the cap is removed.
    survivors.sort(key=lambda item: item[0], reverse=True)
    for _, stale_path in survivors[DOWNLOAD_CLEANUP_MAX_FILES:]:
        try:
            stale_path.unlink(missing_ok=True)
        except OSError:
            continue
def make_download_files(model_name: str, summary_df: pd.DataFrame, breakdown_df: pd.DataFrame, raw_data: list):
    """Write the downloadable artefacts for one estimate.

    Old files in the download folder are pruned first. Every call uses a
    fresh random id in the file names.

    Returns:
        ``(summary_csv_path, breakdown_csv_path_or_None, json_path)`` as
        strings. The breakdown CSV is only written when ``breakdown_df`` has
        rows.
    """
    download_dir = get_download_dir()
    cleanup_old_download_files(download_dir)

    # "/" separates org and model in repo ids and must not become a path part.
    stem = model_name.replace("/", "__") or "model"
    prefix = f"{stem}_{uuid4().hex}"

    summary_path = download_dir / f"{prefix}_summary.csv"
    summary_df.to_csv(summary_path, index=False)

    breakdown_file = None
    if not breakdown_df.empty:
        breakdown_path = download_dir / f"{prefix}_adam_breakdown.csv"
        breakdown_df.to_csv(breakdown_path, index=False)
        breakdown_file = str(breakdown_path)

    json_path = download_dir / f"{prefix}_estimate.json"
    document = {"model_name": model_name, "estimates": raw_data}
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(document, handle, indent=2)

    return str(summary_path), breakdown_file, str(json_path)
def fetch_raw_estimate_data(request: EstimateRequest):
    """Return the raw estimate rows for ``request``, using the result cache."""

    def _load_and_measure():
        # NOTE(review): skip_auth_check=True presumably relies on the caller
        # having run the Hub preflight check already; confirm in model_utils.
        skeleton = get_model_normalized(
            request.normalized_model_name,
            request.library,
            request.access_token,
            skip_auth_check=True,
        )
        precisions = list(request.options)
        return calculate_memory(skeleton, precisions)

    return RESULT_CACHE.get_or_compute(request, _load_and_measure)
def build_estimate_payload(raw_rows: list[dict], options: tuple[str, ...]):
    """Turn raw estimate rows into display rows plus the Adam breakdown.

    In each row the ``"Training using Adam (Peak vRAM)"`` mapping is replaced
    by its largest value, formatted with ``convert_bytes`` (``"N/A"`` when
    that value is -1). ``raw_rows`` itself is not modified. The breakdown
    table gets one line per selected precision whose "model" value is not
    -1; rows are matched to ``options`` by position.
    """
    adam_column = "Training using Adam (Peak vRAM)"
    stage_names = ("model", "gradients", "optimizer", "step")

    display_rows = copy.deepcopy(raw_rows)
    per_stage = {name: [] for name in stage_names}
    for row in display_rows:
        adam_usage = row[adam_column]
        for name in stage_names:
            per_stage[name].append(adam_usage[name])
        peak = max(adam_usage.values())
        row[adam_column] = convert_bytes(peak) if peak != -1 else "N/A"

    breakdown_df = pd.DataFrame(
        columns=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"]
    )
    explanation = ""

    model_sizes = per_stage["model"]
    if not all(size == -1 for size in model_sizes):
        explanation = (
            "## Training using Adam explained:\n"
            "When training on a batch size of 1, each stage of the training process is expected "
            "to have near the following memory results for each precision you selected:\n"
        )
        for position, dtype in enumerate(options):
            if model_sizes[position] == -1:
                continue
            # Column order follows the stage order model, gradients, optimizer, step.
            breakdown_df.loc[len(breakdown_df.index)] = [
                dtype,
                convert_bytes(per_stage["model"][position]),
                convert_bytes(per_stage["gradients"][position]),
                convert_bytes(per_stage["optimizer"][position]),
                convert_bytes(per_stage["step"][position]),
            ]

    return EstimatePayload(
        display_rows=display_rows,
        raw_rows=copy.deepcopy(raw_rows),
        explanation=explanation,
        breakdown_df=breakdown_df,
    )
def build_success_view_model(request: EstimateRequest, payload: EstimatePayload):
    """Build the view model for a successful estimate and write its downloads."""
    model_label = request.normalized_model_name
    auth_message = get_auth_message(request.auth_mode)
    summary_table = pd.DataFrame(payload.display_rows)

    download_paths = make_download_files(
        model_label,
        summary_table,
        payload.breakdown_df,
        payload.raw_rows,
    )
    summary_csv, breakdown_csv, estimate_json = download_paths

    return EstimateViewModel(
        title=f"## Static memory estimate for `{model_label}`",
        auth_message=auth_message,
        summary_df=summary_table,
        explanation=payload.explanation,
        breakdown_df=payload.breakdown_df,
        summary_path=summary_csv,
        breakdown_path=breakdown_csv,
        json_path=estimate_json,
    )
def build_error_view_model(request: EstimateRequest, error: Exception):
    """Build the view model shown when an estimate fails.

    Args:
        request: The request that was being processed.
        error: The exception that ended the run.

    Returns:
        An EstimateViewModel with empty result tables, a short error summary
        (the exception message, or its class name when the message is blank)
        and the formatted traceback as details.
    """
    auth_message = get_auth_message(request.auth_mode)
    message = str(error).strip() or error.__class__.__name__
    # Format the traceback of ``error`` itself instead of whatever exception
    # happens to be "current", so the details are right even when this
    # function is called outside an ``except`` block.
    details = "".join(
        traceback.format_exception(type(error), error, error.__traceback__)
    ).strip()
    return EstimateViewModel(
        title=f"## Unable to estimate memory for `{request.normalized_model_name}`",
        auth_message=auth_message,
        summary_df=pd.DataFrame(),
        explanation="",
        breakdown_df=pd.DataFrame(),
        error_summary=(
            f"{message}\n\n"
            "Check the **Details** section below for the full traceback."
        ),
        error_details=details,
    )
def reset_app():
    """Return the output values that restore every widget to its default."""
    defaults = ResetViewModel()
    return defaults.to_updates()
def get_results(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=False),
):
    """Run one estimate end to end and return the UI output values.

    Input validation happens before the ``try`` block, so a validation
    ``gr.Error`` propagates to the caller. Any exception raised afterwards is
    turned into the error view instead.
    """
    progress(0.05, desc="Checking inputs")
    estimate_request = build_estimate_request(model_name, library, options, access_token, oauth_token)

    try:
        progress(0.12, desc="Checking Hub access")
        preflight_model_access_normalized(
            estimate_request.normalized_model_name,
            estimate_request.access_token,
        )

        progress(0.3, desc="Building model skeleton")
        rows = fetch_raw_estimate_data(estimate_request)

        progress(0.75, desc="Formatting results")
        formatted = build_estimate_payload(rows, estimate_request.options)

        progress(0.95, desc="Writing downloads")
        success_view = build_success_view_model(estimate_request, formatted)

        progress(1.0, desc="Done")
        return success_view.to_updates()
    except Exception as error:
        progress(1.0, desc="Failed")
        failure_view = build_error_view_model(estimate_request, error)
        return failure_view.to_updates()
# Gradio UI definition, event wiring and launch for the memory calculator.
# NOTE(review): this copy of the file has lost its indentation, so the nesting
# of the layout containers below cannot be read from the text; confirm it
# against the deployed app before restructuring anything here.
# NOTE(review): delete_cache is given 3600 and DOWNLOAD_RETENTION_SECONDS;
# check the Gradio Blocks documentation for the meaning of each tuple element.
| with gr.Blocks(delete_cache=(3600, DOWNLOAD_RETENTION_SECONDS)) as demo: | |
| with gr.Column(): | |
# Static introduction rendered as raw HTML.
| gr.HTML( | |
| """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="250" height="250"><h1>🤗 Model Memory Calculator</h1> | |
| <p>This tool provides a static memory estimate for the vRAM needed to load and train Hub models.</p> | |
| <p>The minimum recommended vRAM needed to load a model is denoted as the size of the "largest layer", and training of a model is roughly 4x its size (for Adam).</p> | |
| <p>These calculations are accurate within a few percent at most, such as <code>bert-base-cased</code> being 413.68 MB and the calculator estimating 413.18 MB.</p> | |
| <p>When performing inference, expect to add up to an additional 20% to this as found by <a href="https://blog.eleuther.ai/transformer-math/" target="_blank">EleutherAI</a>.</p> | |
| <p>More tests will be performed in the future to get a more accurate benchmark for each model.</p> | |
| <p>Currently this tool supports all models hosted that use <code>transformers</code> and <code>timm</code>.</p> | |
| <p>To use this tool pass in the URL or model name of the model you want to calculate the memory usage for, select which framework it originates from (<code>auto</code> will try and detect it from the model metadata), and what precisions you want to use.</p>""" | |
| ) | |
# Input widgets: model name/URL, optional Hub search, library, precisions
# and authentication.
| with gr.Group(): | |
| with gr.Row(equal_height=True): | |
| inp = gr.Textbox(label="Model Name or URL", value=DEFAULT_MODEL) | |
| with gr.Column(): | |
# The Hub search widget only exists when its optional import succeeded;
# otherwise hub_search is None and only the status text is shown.
| if HAS_HF_HUB_SEARCH: | |
# NOTE(review): "sumbit_on_select" is passed exactly as spelled here; check the
# component's own signature before changing the spelling.
| hub_search = HuggingfaceHubSearch( | |
| label="Search Hugging Face Hub", | |
| placeholder="Search for models on Hugging Face", | |
| search_type="model", | |
| sumbit_on_select=True, | |
| ) | |
| hub_search_status = gr.Markdown(get_hub_search_status()) | |
| else: | |
| hub_search = None | |
| hub_search_status = gr.Markdown(get_hub_search_status()) | |
| with gr.Row(equal_height=True): | |
| library = gr.Radio(["auto", "transformers", "timm"], label="Library", value=DEFAULT_LIBRARY) | |
| options = gr.CheckboxGroup( | |
| ["float32", "float16/bfloat16", "int8", "int4"], | |
| value=DEFAULT_OPTIONS, | |
| label="Model Precision", | |
| ) | |
| with gr.Column(): | |
| gr.LoginButton() | |
| access_token = gr.Textbox( | |
| label="API Token", | |
| placeholder="Optional. If blank, your Sign in with HF session will be used for gated models.", | |
| ) | |
# auth_status starts with the signed-out text and is refreshed by the load
# event below; run_auth_status is filled in by each estimate run.
| auth_status = gr.Markdown("Not signed in. You can still paste an API token for gated models.") | |
| run_auth_status = gr.Markdown(visible=False) | |
# Action buttons.
| with gr.Group(): | |
| with gr.Row(equal_height=True): | |
| btn = gr.Button("Calculate Memory Usage") | |
| reset_btn = gr.Button("Reset") | |
# Result widgets. All except the title start hidden; the handlers' update
# lists decide what becomes visible.
| out_text = gr.Markdown() | |
| error_text = gr.Markdown(visible=False) | |
| out = gr.DataFrame( | |
| headers=["dtype", "Largest Layer", "Total Size", "Training using Adam (Peak vRAM)"], | |
| interactive=False, | |
| visible=False, | |
| ) | |
| out_explain = gr.Markdown(visible=False) | |
| memory_values = gr.DataFrame( | |
| headers=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"], | |
| interactive=False, | |
| visible=False, | |
| ) | |
| with gr.Accordion("Downloads", open=False): | |
| summary_file = gr.File(label="Summary CSV", visible=False) | |
| breakdown_file = gr.File(label="Adam Breakdown CSV", visible=False) | |
| json_file = gr.File(label="Full JSON", visible=False) | |
| with gr.Accordion("Details", open=False): | |
| error_details = gr.Textbox( | |
| label="Error Details", | |
| lines=12, | |
| interactive=False, | |
| visible=False, | |
| ) | |
# Refresh the sign-in status text when the page loads.
# NOTE(review): get_auth_status takes an OAuth profile although inputs=None;
# presumably Gradio supplies it based on the parameter's type hint — confirm.
| demo.load( | |
| get_auth_status, | |
| inputs=None, | |
| outputs=auth_status, | |
| api_name=False, | |
| queue=False, | |
| ) | |
# Copy the repository picked in the Hub search widget into the model textbox.
| if HAS_HF_HUB_SEARCH: | |
| gr.on( | |
| triggers=[hub_search.submit], | |
| fn=use_hub_search, | |
| inputs=[hub_search], | |
| outputs=[inp], | |
| api_name=False, | |
| show_progress="hidden", | |
| queue=False, | |
| ) | |
# Run an estimate on button click or on submit in the model textbox. The ten
# outputs are positional and must stay in the order produced by
# EstimateViewModel.to_updates().
# NOTE(review): get_results also takes oauth_token, which has no entry in
# inputs; presumably Gradio supplies it from the type hint — confirm.
| gr.on( | |
| triggers=[btn.click, inp.submit], | |
| fn=get_results, | |
| inputs=[inp, library, options, access_token], | |
| outputs=[ | |
| out_text, | |
| run_auth_status, | |
| out, | |
| out_explain, | |
| memory_values, | |
| error_text, | |
| error_details, | |
| summary_file, | |
| breakdown_file, | |
| json_file, | |
| ], | |
| show_api=False, | |
| show_progress="minimal", | |
| concurrency_limit=1, | |
| concurrency_id="memory-estimate", | |
| ) | |
# Restore the inputs and clear every result widget. The fourteen outputs are
# positional and must stay in the order produced by ResetViewModel.to_updates().
| reset_btn.click( | |
| reset_app, | |
| inputs=None, | |
| outputs=[ | |
| inp, | |
| library, | |
| options, | |
| access_token, | |
| out_text, | |
| run_auth_status, | |
| out, | |
| out_explain, | |
| memory_values, | |
| error_text, | |
| error_details, | |
| summary_file, | |
| breakdown_file, | |
| json_file, | |
| ], | |
| api_name=False, | |
| show_progress="hidden", | |
| queue=False, | |
| ) | |
# Configure the request queue (default_concurrency_limit=1, max_size=24) and
# start the app.
| demo.queue(default_concurrency_limit=1, max_size=24) | |
| demo.launch() | |