File size: 8,178 Bytes
b25b2b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Utilities related to loading in and working with models/specific models
from urllib.parse import unquote, urlparse

import gradio as gr
import torch
from accelerate.commands.estimate import check_has_model, create_empty_model, estimate_training_usage
from accelerate.utils import calculate_maximum_sizes, convert_bytes
from huggingface_hub import auth_check
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError


# Divisors applied to the base (float32) model size for each dtype option (see `calculate_memory`).
DTYPE_MODIFIER = {"float32": 1, "float16/bfloat16": 2, "int8": 4, "int4": 8}


def extract_from_url(name: str):
    """If `name` is a URL, reduce it to an `org/repo` model id; otherwise return it unchanged."""
    try:
        parsed = urlparse(name)
        looks_like_url = bool(parsed.scheme) and bool(parsed.netloc)
    except Exception:
        looks_like_url = False

    if not looks_like_url:
        return name

    trimmed_path = unquote(parsed.path).strip("/")
    if not trimmed_path:
        return name

    segments = [segment for segment in trimmed_path.split("/") if segment]
    # Drop a leading Hub namespace prefix, e.g. /models/org/repo -> org/repo.
    if len(segments) >= 3 and segments[0] in {"models", "datasets", "spaces"}:
        segments = segments[1:]

    # Keep at most the first two path segments (org/repo).
    if len(segments) >= 2:
        return "/".join(segments[:2])
    return "/".join(segments)


def translate_llama(text: str):
    """Ensure a Llama-2/CodeLlama repo id carries the `-hf` suffix of the converted weights."""
    return text if text.endswith("-hf") else text + "-hf"


def normalize_model_name(model_name: str):
    """Canonicalize user input: strip whitespace, unwrap Hub URLs, and map raw Llama repos to `-hf`."""
    cleaned = extract_from_url(model_name.strip())
    needs_hf_suffix = any(
        marker in cleaned for marker in ("meta-llama/Llama-2-", "meta-llama/CodeLlama-")
    )
    if needs_hf_suffix:
        cleaned = translate_llama(cleaned)
    return cleaned.rstrip("/")


def classify_loader_error(model_name: str, error: Exception):
    """Build (without raising) a user-facing `gr.Error` for a raw loader/Hub exception.

    Classification is a best-effort substring match on the lowercased error text;
    the final branch is a catch-all for anything unrecognized.
    """
    lowered = str(error).lower()

    def mentions(*needles: str) -> bool:
        return any(needle in lowered for needle in needles)

    if mentions("timed out", "timeout"):
        return gr.Error(
            f"Model `{model_name}` timed out during the Hub access or static initialization step. "
            "Please try again, try a narrower model repo, or select the library manually."
        )

    if mentions("401", "403", "unauthorized", "forbidden", "permission"):
        return gr.Error(
            f"Model `{model_name}` could not be accessed with the current credentials. "
            "Please sign in with Hugging Face or paste a token that has access to this repo."
        )

    if mentions("connection", "temporarily unavailable", "service unavailable"):
        return gr.Error(
            f"Model `{model_name}` could not be reached from this Space right now. "
            "Please retry in a moment."
        )

    if mentions("no module named", "cannot import name"):
        return gr.Error(
            f"Model `{model_name}` requires custom code or extra dependencies that are not available in this Space. "
            f"This often means the repository depends on a package that is not installed here. Error: `{error}`"
        )

    if mentions("trust_remote_code", "remote code"):
        return gr.Error(
            f"Model `{model_name}` uses custom code from the Hub and could not be initialized in this Space. "
            f"Please inspect the repository code and make sure it is trusted and compatible with the current runtime. Error: `{error}`"
        )

    if mentions("config") and mentions("auto"):
        return gr.Error(
            f"Model `{model_name}` could not be resolved through the current library auto-detection path. "
            f"Please try selecting `transformers` or `timm` manually. Error: `{error}`"
        )

    return gr.Error(
        f"Model `{model_name}` had an error during static initialization in this Space. "
        f"Please open a discussion on the model page and include this message: `{error}`"
    )


def raise_model_error(model_name: str, error: Exception):
    """Classify `error` into a user-facing `gr.Error` and raise it."""
    classified = classify_loader_error(model_name, error)
    raise classified


def preflight_model_access_normalized(normalized_name: str, access_token: str | None):
    """Verify Hub access for an already-normalized repo id before the expensive load step.

    Raises `gr.Error` for gated or missing repos, and for errors classified as
    timeout / credential / connectivity problems. Any other failure is swallowed
    so the loader path can retry; returns `normalized_name` on success.
    """
    try:
        auth_check(normalized_name, token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{normalized_name}` is a gated model. Please sign in with Hugging Face or pass an access token that already has access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{normalized_name}` was not found on the Hub. Please try another model name.")
    except gr.Error:
        raise
    except Exception as error:
        classified_error = classify_loader_error(normalized_name, error)
        classified_text = str(classified_error).lower()
        # Only surface errors the user can act on immediately; anything else
        # falls through to the loader path (transient Hub metadata issues).
        hard_failure_markers = ("timed out", "could not be accessed", "could not be reached")
        if any(marker in classified_text for marker in hard_failure_markers):
            raise classified_error

    return normalized_name


def preflight_model_access(model_name: str, access_token: str | None):
    """Normalize `model_name`, then run the Hub access preflight on the normalized id."""
    normalized = normalize_model_name(model_name)
    return preflight_model_access_normalized(normalized, access_token)


def get_model_normalized(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Finds and grabs model from the Hub, and initializes on `meta`.

    `model_name` is expected to be already normalized (see `normalize_model_name`).
    `library` may be "auto" to let the loader infer it from Hub metadata.
    Raises `gr.Error` with a user-facing message on any failure; otherwise
    returns the empty (meta-device) model.
    """
    # "auto" in the UI maps to `library_name=None`, i.e. infer from Hub metadata.
    if library == "auto":
        library = None

    if not skip_auth_check:
        preflight_model_access_normalized(model_name, access_token)

    try:
        model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{model_name}` is a gated model, please ensure to pass in your access token or sign in with Hugging Face and try again if you have access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{model_name}` was not found on the Hub, please try another model name.")
    except ValueError:
        raise gr.Error(
            f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
        )
    except (RuntimeError, OSError) as error:
        # check_has_model inspects the error to report which library was attempted.
        library_name = check_has_model(error)
        if library_name != "unknown":
            raise gr.Error(
                f"Tried to load `{model_name}` with `{library_name}` but a possible model to load was not found inside the repo."
            )
        raise_model_error(model_name, error)
    except ImportError as error:
        # Repos with custom code can fail to import; retry once without trusting
        # remote code. If the retry also fails, report the original ImportError.
        try:
            model = create_empty_model(
                model_name, library_name=library, trust_remote_code=False, access_token=access_token
            )
        except Exception:
            raise_model_error(model_name, error)
    except Exception as error:
        raise_model_error(model_name, error)
    return model


def get_model(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Normalize `model_name`, then load the empty (meta-device) model from the Hub."""
    normalized = normalize_model_name(model_name)
    return get_model_normalized(normalized, library, access_token, skip_auth_check=skip_auth_check)


def calculate_memory(model: torch.nn.Module, options: list):
    """Estimate memory usage for a meta-initialized model at each dtype in `options`.

    Returns one dict per dtype with human-readable sizes for the largest layer,
    the total model, and the Adam training estimate.
    """
    total_size, largest_layer = calculate_maximum_sizes(model)
    largest_layer_size = largest_layer[0]

    rows = []
    for dtype in options:
        divisor = DTYPE_MODIFIER[dtype]
        # The training estimate is computed from the undivided byte count plus
        # a precision label ("float16/bfloat16" collapses to "float16").
        precision = "float16" if dtype == "float16/bfloat16" else dtype
        training_size = estimate_training_usage(total_size, precision)
        rows.append(
            {
                "dtype": dtype,
                "Largest Layer or Residual Group": convert_bytes(largest_layer_size / divisor),
                "Total Size": convert_bytes(total_size / divisor),
                "Training using Adam (Peak vRAM)": training_size,
            }
        )
    return rows