# John6666's picture
# Upload 4 files
# 57e6d53 verified
import copy
import hashlib
import json
import tempfile
import threading
import time
import traceback
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlparse
from uuid import uuid4
import accelerate
import gradio as gr
import huggingface_hub
try:
from gradio_huggingfacehub_search import HuggingfaceHubSearch
HAS_HF_HUB_SEARCH = True
except Exception:
HuggingfaceHubSearch = None
HAS_HF_HUB_SEARCH = False
import pandas as pd
import timm
import transformers
from accelerate.utils import convert_bytes
from model_utils import (
calculate_memory,
get_model_normalized,
normalize_model_name,
preflight_model_access_normalized,
)
# Form defaults: shown on first load and restored by the Reset button.
DEFAULT_MODEL = "bert-base-cased"
DEFAULT_LIBRARY = "auto"
DEFAULT_OPTIONS = ["float32"]

# Maximum number of distinct estimates kept in the in-process LRU result cache.
RESULTS_CACHE_SIZE = 128

# Generated download files whose mtime is older than this (one hour) are pruned.
DOWNLOAD_RETENTION_SECONDS = 3600
# Upper bound on how many generated download files are kept on disk.
DOWNLOAD_CLEANUP_MAX_FILES = 256
def log_startup_versions():
    """Print one line listing the versions of the main runtime libraries."""
    libraries = (
        ("gradio", gr),
        ("accelerate", accelerate),
        ("transformers", transformers),
        ("huggingface_hub", huggingface_hub),
        ("timm", timm),
    )
    version_list = " ".join(f"{label}={module.__version__}" for label, module in libraries)
    print(f"[startup] versions {version_list}")


# Emitted once at import time so the Space logs show the exact library versions.
log_startup_versions()
@dataclass(frozen=True)
class EstimateRequest:
original_model_name: str
normalized_model_name: str
library: str
options: tuple[str, ...]
access_token: str | None
auth_mode: str
@property
def cache_key(self):
token_key = "anonymous"
if self.access_token is not None:
token_key = hashlib.sha256(self.access_token.encode("utf-8")).hexdigest()
return (
self.normalized_model_name,
self.library,
self.options,
token_key,
)
@dataclass
class EstimatePayload:
    """Estimate data after formatting, ready to be turned into a view model."""

    # Rows for the summary table; the Adam peak column holds formatted text.
    display_rows: list[dict]
    # Rows exactly as computed; these are written to the JSON download.
    raw_rows: list[dict]
    # Markdown introducing the Adam breakdown ("" when there is nothing to show).
    explanation: str
    # One row per precision with Adam stage sizes; may be empty.
    breakdown_df: pd.DataFrame
@dataclass
class EstimateViewModel:
    """Everything the results area shows after an estimate attempt.

    ``to_updates`` returns ten values matching the Calculate button's outputs,
    in order:
    - title
    - auth message
    - summary table
    - explanation
    - Adam breakdown table
    - error summary
    - error details
    - the three download files
    """

    title: str
    auth_message: str
    summary_df: pd.DataFrame
    explanation: str
    breakdown_df: pd.DataFrame
    error_summary: str = ""
    error_details: str = ""
    summary_path: str | None = None
    breakdown_path: str | None = None
    json_path: str | None = None

    def to_updates(self):
        """Return the output values, showing each optional component only when it has content."""

        def shown_if(is_visible, value):
            return gr.update(visible=is_visible, value=value)

        header = [self.title, gr.update(value=self.auth_message, visible=True)]
        body = [
            shown_if(not self.summary_df.empty, self.summary_df),
            shown_if(self.explanation != "", self.explanation),
            shown_if(not self.breakdown_df.empty, self.breakdown_df),
            shown_if(self.error_summary != "", self.error_summary),
            shown_if(self.error_details != "", self.error_details),
        ]
        downloads = [
            shown_if(file_path is not None, file_path)
            for file_path in (self.summary_path, self.breakdown_path, self.json_path)
        ]
        return header + body + downloads
@dataclass
class ResetViewModel:
    """Values that restore the input form and hide every result component.

    ``to_updates`` returns 14 values matching the Reset button's outputs:
    - the four inputs (model name, library, precisions, API token)
    - the title
    - cleared, hidden updates for the nine remaining result components
    """

    model_name: str = DEFAULT_MODEL
    library: str = DEFAULT_LIBRARY
    # None is the sentinel for "use DEFAULT_OPTIONS". __post_init__ replaces it
    # with a fresh list so the module-level default list is never shared.
    options: list[str] | tuple[str, ...] | None = None
    access_token: str = ""
    title: str = ""

    def __post_init__(self):
        if self.options is None:
            self.options = list(DEFAULT_OPTIONS)

    def to_updates(self):
        """Return the reset values followed by hide-and-clear updates."""
        return [
            self.model_name,
            self.library,
            list(self.options),  # CheckboxGroup value; always handed out as a new list
            self.access_token,
            self.title,
            gr.update(visible=False, value=""),  # run auth status
            gr.update(visible=False, value=pd.DataFrame()),  # summary table
            gr.update(visible=False, value=""),  # explanation
            gr.update(visible=False, value=pd.DataFrame()),  # Adam breakdown table
            gr.update(visible=False, value=""),  # error summary
            gr.update(visible=False, value=""),  # error details
            gr.update(visible=False, value=None),  # summary CSV
            gr.update(visible=False, value=None),  # breakdown CSV
            gr.update(visible=False, value=None),  # JSON
        ]
@dataclass
class _InflightEntry:
event: threading.Event
data: list[dict] | None = None
error: Exception | None = None
class ResultCache:
def __init__(self, max_size: int):
self.max_size = max_size
self._values = OrderedDict()
self._lock = threading.Lock()
self._inflight: dict[tuple, _InflightEntry] = {}
def get_or_compute(self, request: EstimateRequest, compute_fn):
cache_key = request.cache_key
with self._lock:
if cache_key in self._values:
self._values.move_to_end(cache_key)
return copy.deepcopy(self._values[cache_key])
entry = self._inflight.get(cache_key)
if entry is None:
entry = _InflightEntry(event=threading.Event())
self._inflight[cache_key] = entry
is_owner = True
else:
is_owner = False
if not is_owner:
entry.event.wait()
if entry.error is not None:
raise entry.error
return copy.deepcopy(entry.data)
try:
data = compute_fn()
with self._lock:
self._values[cache_key] = copy.deepcopy(data)
if len(self._values) > self.max_size:
self._values.popitem(last=False)
entry.data = copy.deepcopy(data)
return copy.deepcopy(data)
except Exception as error:
entry.error = error
raise
finally:
entry.event.set()
with self._lock:
self._inflight.pop(cache_key, None)
# Process-wide estimate cache shared by every request this app serves.
RESULT_CACHE = ResultCache(RESULTS_CACHE_SIZE)
def get_auth_status(oauth_profile: gr.OAuthProfile | None):
    """Return Markdown describing the current Hugging Face sign-in state.

    The load event passes no inputs, so ``oauth_profile`` is presumably injected
    by Gradio from the annotation. Keep the annotation exactly as written.
    """
    if oauth_profile is None:
        return "Not signed in. You can still paste an API token for gated models."

    display_name = getattr(oauth_profile, "preferred_username", None)
    if not display_name:
        display_name = getattr(oauth_profile, "name", None)
    if display_name is None:
        display_name = "Hugging Face user"

    return (
        f"Signed in as `{display_name}`. "
        "If the API Token field is blank, this session token will be used for gated models."
    )
def use_hub_search(repo_id: str | None):
return (repo_id or "").strip()
def get_hub_search_status():
    """Return the helper text shown next to, or instead of, the Hub search box."""
    return (
        "Search Hugging Face Hub to fill the model field automatically."
        if HAS_HF_HUB_SEARCH
        else "Hub Search component is unavailable in this runtime. Manual model input still works."
    )
def validate_model_name(model_name: str):
    """Strip and sanity-check the user-supplied model name or URL.

    A bare repo id (e.g. ``bert-base-cased``) is accepted as-is. An absolute
    URL must point at huggingface.co or www.huggingface.co. The host check is
    case-insensitive and ignores an explicit port or user-info.

    Args:
        model_name: text from the model textbox; None is treated as empty.

    Returns:
        The input with surrounding whitespace removed.

    Raises:
        gr.Error: if the input is empty or is a URL for another host.
    """
    stripped_name = (model_name or "").strip()
    if stripped_name == "":
        raise gr.Error("Enter a model name or a Hugging Face model URL.")

    try:
        parsed = urlparse(stripped_name)
    except ValueError:
        # urlparse rejects a few malformed inputs (e.g. unbalanced IPv6
        # brackets). As before, such input is not rejected here; it is left to
        # the later normalization/lookup steps.
        return stripped_name

    if parsed.scheme and parsed.netloc:
        valid_hosts = {"huggingface.co", "www.huggingface.co"}
        # .hostname is lower-cased and excludes port/user-info, unlike .netloc.
        if parsed.hostname not in valid_hosts:
            raise gr.Error("Only Hugging Face model URLs are supported here.")
    return stripped_name
def validate_options(options: list):
    """Require at least one precision; raise gr.Error when none is selected."""
    if options:
        return
    raise gr.Error("Select at least one precision.")
def validate_access_token(access_token: str):
    """Reject a non-empty API token that contains any whitespace character.

    Empty or missing tokens are allowed; they mean "no manual token".

    Raises:
        gr.Error: if the token contains whitespace.
    """
    if not access_token:
        return
    if any(map(str.isspace, access_token)):
        raise gr.Error("API tokens should not contain whitespace.")
def resolve_access_token(access_token: str, oauth_token: gr.OAuthToken | None):
    """Pick which credential to use for Hub access.

    Precedence:
    1. A non-empty manual token.
    2. The OAuth session token.
    3. No token.

    Returns:
        ``(token_or_None, auth_mode)``, where ``auth_mode`` is one of
        ``"manual"``, ``"oauth"`` or ``"anonymous"``.
    """
    manual_token = None if access_token == "" else access_token
    if manual_token is not None:
        return manual_token, "manual"
    if oauth_token is None:
        return None, "anonymous"
    return oauth_token.token, "oauth"
def build_estimate_request(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
):
    """Validate raw UI inputs and bundle them into an EstimateRequest.

    Validation runs in a fixed order (model name, precisions, token), so the
    first problem found is the one reported.

    Raises:
        gr.Error: from any of the validators.
    """
    cleaned_name = validate_model_name(model_name)
    validate_options(options)
    validate_access_token(access_token)

    normalized = normalize_model_name(cleaned_name)
    token, mode = resolve_access_token(access_token, oauth_token)

    request_fields = {
        "original_model_name": cleaned_name,
        "normalized_model_name": normalized,
        "library": library,
        "options": tuple(options),  # tuple keeps the request hashable/frozen
        "access_token": token,
        "auth_mode": mode,
    }
    return EstimateRequest(**request_fields)
def get_auth_message(auth_mode: str):
if auth_mode == "manual":
return "Using the manually provided API token for this estimate."
if auth_mode == "oauth":
return "Using your Hugging Face OAuth session for this estimate."
return "Running anonymously. Gated models will require a token or a signed-in Hugging Face session."
def get_download_dir():
    """Return the shared directory for generated downloads, creating it if needed."""
    download_dir = Path(tempfile.gettempdir()).joinpath("model_memory_usage")
    download_dir.mkdir(parents=True, exist_ok=True)
    return download_dir
def cleanup_old_download_files(temp_dir: Path):
cutoff = time.time() - DOWNLOAD_RETENTION_SECONDS
try:
entries = [path for path in temp_dir.iterdir() if path.is_file()]
except FileNotFoundError:
return
for path in entries:
try:
if path.stat().st_mtime < cutoff:
path.unlink(missing_ok=True)
except OSError:
continue
try:
remaining_files = sorted(
[path for path in temp_dir.iterdir() if path.is_file()],
key=lambda path: path.stat().st_mtime,
reverse=True,
)
except FileNotFoundError:
return
for stale_path in remaining_files[DOWNLOAD_CLEANUP_MAX_FILES:]:
try:
stale_path.unlink(missing_ok=True)
except OSError:
continue
def make_download_files(model_name: str, summary_df: pd.DataFrame, breakdown_df: pd.DataFrame, raw_data: list):
    """Write the CSV/JSON downloads for one estimate and return their paths.

    Each file is named ``<model>_<uuid>_<suffix>``. "/" in the model name
    becomes "__", and the uuid keeps concurrent runs from colliding. Old files
    are pruned first.

    Returns:
        ``(summary_csv, breakdown_csv_or_None, json)`` as strings. No breakdown
        file is written when ``breakdown_df`` is empty.
    """
    file_prefix = model_name.replace("/", "__") or "model"
    download_dir = get_download_dir()
    cleanup_old_download_files(download_dir)
    stem = f"{file_prefix}_{uuid4().hex}"

    summary_file = download_dir / f"{stem}_summary.csv"
    summary_df.to_csv(summary_file, index=False)

    breakdown_file = None
    if not breakdown_df.empty:
        breakdown_file = download_dir / f"{stem}_adam_breakdown.csv"
        breakdown_df.to_csv(breakdown_file, index=False)

    json_file = download_dir / f"{stem}_estimate.json"
    json_file.write_text(
        json.dumps({"model_name": model_name, "estimates": raw_data}, indent=2),
        encoding="utf-8",
    )

    breakdown_result = None if breakdown_file is None else str(breakdown_file)
    return str(summary_file), breakdown_result, str(json_file)
def fetch_raw_estimate_data(request: EstimateRequest):
    """Return raw per-precision memory rows for ``request``, via RESULT_CACHE.

    On a cache miss, the model is loaded through get_model_normalized and then
    measured with calculate_memory.
    """

    def build_rows():
        # skip_auth_check=True: presumably because get_results already ran
        # preflight_model_access_normalized for this request.
        skeleton = get_model_normalized(
            request.normalized_model_name,
            request.library,
            request.access_token,
            skip_auth_check=True,
        )
        precisions = list(request.options)
        return calculate_memory(skeleton, precisions)

    return RESULT_CACHE.get_or_compute(request, build_rows)
def build_estimate_payload(raw_rows: list[dict], options: tuple[str, ...]):
    """Format raw estimate rows for display and build the Adam breakdown.

    ``raw_rows`` itself is not mutated; display and raw copies are taken.

    In each display row, the "Training using Adam (Peak vRAM)" dict is replaced
    by the maximum of its values:
    - formatted with convert_bytes, or
    - "N/A" when that maximum is -1 (presumably "not computable"; confirm
      against calculate_memory).

    Breakdown rows pair ``options[i]`` with ``raw_rows[i]``, so rows are assumed
    to be in the same order as ``options``.
    """
    adam_column = "Training using Adam (Peak vRAM)"
    stage_names = ("model", "gradients", "optimizer", "step")

    display_rows = copy.deepcopy(raw_rows)
    per_stage = {name: [] for name in stage_names}
    for row in display_rows:
        adam_sizes = row[adam_column]
        for name in stage_names:
            per_stage[name].append(adam_sizes[name])
        peak = max(adam_sizes.values())
        row[adam_column] = "N/A" if peak == -1 else convert_bytes(peak)

    breakdown_df = pd.DataFrame(
        columns=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"]
    )
    explanation = ""
    if any(size != -1 for size in per_stage["model"]):
        explanation = (
            "## Training using Adam explained:\n"
            "When training on a batch size of 1, each stage of the training process is expected "
            "to have near the following memory results for each precision you selected:\n"
        )
        for position, dtype in enumerate(options):
            if per_stage["model"][position] == -1:
                continue
            breakdown_df.loc[len(breakdown_df.index)] = [
                dtype,
                *(convert_bytes(per_stage[name][position]) for name in stage_names),
            ]

    return EstimatePayload(
        display_rows=display_rows,
        raw_rows=copy.deepcopy(raw_rows),
        explanation=explanation,
        breakdown_df=breakdown_df,
    )
def build_success_view_model(request: EstimateRequest, payload: EstimatePayload):
    """Turn a formatted payload into the view model for a successful estimate.

    Writing the download files happens here, as a side effect.
    """
    summary_table = pd.DataFrame(payload.display_rows)
    written_paths = make_download_files(
        request.normalized_model_name,
        summary_table,
        payload.breakdown_df,
        payload.raw_rows,
    )
    summary_path, breakdown_path, json_path = written_paths
    return EstimateViewModel(
        title=f"## Static memory estimate for `{request.normalized_model_name}`",
        auth_message=get_auth_message(request.auth_mode),
        summary_df=summary_table,
        explanation=payload.explanation,
        breakdown_df=payload.breakdown_df,
        summary_path=summary_path,
        breakdown_path=breakdown_path,
        json_path=json_path,
    )
def build_error_view_model(request: EstimateRequest, error: Exception):
    """Build the view model shown when an estimate fails.

    traceback.format_exc() reports the exception currently being handled, so
    this must be called from inside an ``except`` block (as get_results does).
    """
    auth_note = get_auth_message(request.auth_mode)
    # Fall back to the exception class name when the message is blank.
    headline = str(error).strip() or error.__class__.__name__
    traceback_text = traceback.format_exc().strip()
    summary_markdown = (
        f"{headline}\n\nCheck the **Details** section below for the full traceback."
    )
    return EstimateViewModel(
        title=f"## Unable to estimate memory for `{request.normalized_model_name}`",
        auth_message=auth_note,
        summary_df=pd.DataFrame(),
        explanation="",
        breakdown_df=pd.DataFrame(),
        error_summary=summary_markdown,
        error_details=traceback_text,
    )
def reset_app():
    """Return the outputs that restore the form to its defaults and clear results."""
    defaults = ResetViewModel()
    return defaults.to_updates()
def get_results(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=False),
):
    """Run one memory estimate and return the ten output updates.

    ``oauth_token`` and ``progress`` are not wired as inputs in the event.
    They are presumably injected by Gradio from the annotation and default, so
    both must stay as written.

    Input-validation errors (gr.Error from build_estimate_request) propagate
    to Gradio. Any later failure is rendered in the error view instead.
    """
    progress(0.05, desc="Checking inputs")
    estimate_request = build_estimate_request(model_name, library, options, access_token, oauth_token)
    try:
        progress(0.12, desc="Checking Hub access")
        preflight_model_access_normalized(
            estimate_request.normalized_model_name,
            estimate_request.access_token,
        )

        progress(0.3, desc="Building model skeleton")
        raw_rows = fetch_raw_estimate_data(estimate_request)

        progress(0.75, desc="Formatting results")
        payload = build_estimate_payload(raw_rows, estimate_request.options)

        progress(0.95, desc="Writing downloads")
        success_view = build_success_view_model(estimate_request, payload)
        progress(1.0, desc="Done")
        return success_view.to_updates()
    except Exception as failure:
        progress(1.0, desc="Failed")
        failure_view = build_error_view_model(estimate_request, failure)
        return failure_view.to_updates()
# UI layout and event wiring. Gradio renders components in creation order, so
# the statements below define the page top to bottom.
# NOTE(review): delete_cache is presumably (check_interval_seconds,
# max_age_seconds) for Gradio's own file cache; confirm against the installed
# Gradio version.
with gr.Blocks(delete_cache=(3600, DOWNLOAD_RETENTION_SECONDS)) as demo:
    with gr.Column():
        # Static header: logo and a description of what the estimate covers.
        gr.HTML(
            """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="250" height="250"><h1>🤗 Model Memory Calculator</h1>
<p>This tool provides a static memory estimate for the vRAM needed to load and train Hub models.</p>
<p>The minimum recommended vRAM needed to load a model is denoted as the size of the "largest layer", and training of a model is roughly 4x its size (for Adam).</p>
<p>These calculations are accurate within a few percent at most, such as <code>bert-base-cased</code> being 413.68 MB and the calculator estimating 413.18 MB.</p>
<p>When performing inference, expect to add up to an additional 20% to this as found by <a href="https://blog.eleuther.ai/transformer-math/" target="_blank">EleutherAI</a>.</p>
<p>More tests will be performed in the future to get a more accurate benchmark for each model.</p>
<p>Currently this tool supports all models hosted that use <code>transformers</code> and <code>timm</code>.</p>
<p>To use this tool pass in the URL or model name of the model you want to calculate the memory usage for, select which framework it originates from (<code>auto</code> will try and detect it from the model metadata), and what precisions you want to use.</p>"""
        )
        # Inputs: model name / optional Hub search, library, precisions, credentials.
        with gr.Group():
            with gr.Row(equal_height=True):
                inp = gr.Textbox(label="Model Name or URL", value=DEFAULT_MODEL)
                with gr.Column():
                    if HAS_HF_HUB_SEARCH:
                        # NOTE(review): keyword is spelled `sumbit_on_select` here;
                        # confirm it matches the installed gradio_huggingfacehub_search
                        # signature before changing the spelling.
                        hub_search = HuggingfaceHubSearch(
                            label="Search Hugging Face Hub",
                            placeholder="Search for models on Hugging Face",
                            search_type="model",
                            sumbit_on_select=True,
                        )
                        hub_search_status = gr.Markdown(get_hub_search_status())
                    else:
                        # Optional component not installed: show a notice instead.
                        hub_search = None
                        hub_search_status = gr.Markdown(get_hub_search_status())
            with gr.Row(equal_height=True):
                library = gr.Radio(["auto", "transformers", "timm"], label="Library", value=DEFAULT_LIBRARY)
                options = gr.CheckboxGroup(
                    ["float32", "float16/bfloat16", "int8", "int4"],
                    value=DEFAULT_OPTIONS,
                    label="Model Precision",
                )
                with gr.Column():
                    gr.LoginButton()
                    access_token = gr.Textbox(
                        label="API Token",
                        placeholder="Optional. If blank, your Sign in with HF session will be used for gated models.",
                    )
                    # Initial text; replaced by get_auth_status on page load.
                    auth_status = gr.Markdown("Not signed in. You can still paste an API token for gated models.")
        # Per-run note saying which credential the last estimate used.
        run_auth_status = gr.Markdown(visible=False)
        with gr.Group():
            with gr.Row(equal_height=True):
                btn = gr.Button("Calculate Memory Usage")
                reset_btn = gr.Button("Reset")
        # Result components; hidden until an estimate fills them.
        out_text = gr.Markdown()
        error_text = gr.Markdown(visible=False)
        out = gr.DataFrame(
            headers=["dtype", "Largest Layer", "Total Size", "Training using Adam (Peak vRAM)"],
            interactive=False,
            visible=False,
        )
        out_explain = gr.Markdown(visible=False)
        memory_values = gr.DataFrame(
            headers=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"],
            interactive=False,
            visible=False,
        )
        with gr.Accordion("Downloads", open=False):
            summary_file = gr.File(label="Summary CSV", visible=False)
            breakdown_file = gr.File(label="Adam Breakdown CSV", visible=False)
            json_file = gr.File(label="Full JSON", visible=False)
        with gr.Accordion("Details", open=False):
            error_details = gr.Textbox(
                label="Error Details",
                lines=12,
                interactive=False,
                visible=False,
            )
    # Refresh the sign-in status whenever the page loads.
    demo.load(
        get_auth_status,
        inputs=None,
        outputs=auth_status,
        api_name=False,
        queue=False,
    )
    if HAS_HF_HUB_SEARCH:
        # Copy the selected Hub repo id into the model textbox.
        gr.on(
            triggers=[hub_search.submit],
            fn=use_hub_search,
            inputs=[hub_search],
            outputs=[inp],
            api_name=False,
            show_progress="hidden",
            queue=False,
        )
    # Main estimate. The outputs order must match EstimateViewModel.to_updates().
    # NOTE(review): unlike the other listeners this uses show_api=False rather
    # than api_name=False; confirm whether the endpoint should stay callable.
    # concurrency_limit=1 presumably serialises estimates, so ResultCache's
    # in-flight sharing should rarely trigger.
    gr.on(
        triggers=[btn.click, inp.submit],
        fn=get_results,
        inputs=[inp, library, options, access_token],
        outputs=[
            out_text,
            run_auth_status,
            out,
            out_explain,
            memory_values,
            error_text,
            error_details,
            summary_file,
            breakdown_file,
            json_file,
        ],
        show_api=False,
        show_progress="minimal",
        concurrency_limit=1,
        concurrency_id="memory-estimate",
    )
    # Reset. The outputs order must match ResetViewModel.to_updates().
    reset_btn.click(
        reset_app,
        inputs=None,
        outputs=[
            inp,
            library,
            options,
            access_token,
            out_text,
            run_auth_status,
            out,
            out_explain,
            memory_values,
            error_text,
            error_details,
            summary_file,
            breakdown_file,
            json_file,
        ],
        api_name=False,
        show_progress="hidden",
        queue=False,
    )
demo.queue(default_concurrency_limit=1, max_size=24)
demo.launch()