Update convert.py
Browse files- convert.py +40 -5
convert.py
CHANGED
|
@@ -161,8 +161,11 @@ def check_final_model(model_id: str, folder: str):
|
|
| 161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 162 |
config = AutoConfig.from_pretrained(folder)
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
if pt_infos != sf_infos:
|
| 168 |
error_string = create_diff(pt_infos, sf_infos)
|
|
@@ -199,7 +202,19 @@ def check_final_model(model_id: str, folder: str):
|
|
| 199 |
sf_model = sf_model.cuda()
|
| 200 |
kwargs = {k: v.cuda() for k, v in kwargs.items()}
|
| 201 |
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
sf_logits = sf_model(**kwargs)[0]
|
| 204 |
|
| 205 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
@@ -246,7 +261,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
|
|
| 246 |
return operations, errors
|
| 247 |
|
| 248 |
|
| 249 |
-
def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List["Exception"]]:
|
| 250 |
pr_title = "Adding `safetensors` variant of this model"
|
| 251 |
info = api.model_info(model_id)
|
| 252 |
filenames = set(s.rfilename for s in info.siblings)
|
|
@@ -328,6 +343,26 @@ if __name__ == "__main__":
|
|
| 328 |
" Continue [Y/n] ?"
|
| 329 |
)
|
| 330 |
if txt.lower() in {"", "y"}:
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
else:
|
| 333 |
print(f"Answer was `{txt}` aborting.")
|
|
|
|
| 161 |
shutil.copy(config, os.path.join(folder, "config.json"))
|
| 162 |
config = AutoConfig.from_pretrained(folder)
|
| 163 |
|
| 164 |
+
import transformers
|
| 165 |
+
|
| 166 |
+
class_ = getattr(transformers, config.architectures[0])
|
| 167 |
+
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
| 168 |
+
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
| 169 |
|
| 170 |
if pt_infos != sf_infos:
|
| 171 |
error_string = create_diff(pt_infos, sf_infos)
|
|
|
|
| 202 |
sf_model = sf_model.cuda()
|
| 203 |
kwargs = {k: v.cuda() for k, v in kwargs.items()}
|
| 204 |
|
| 205 |
+
try:
|
| 206 |
+
pt_logits = pt_model(**kwargs)[0]
|
| 207 |
+
except Exception as e:
|
| 208 |
+
try:
|
| 209 |
+
# Musicgen special exception.
|
| 210 |
+
decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
|
| 211 |
+
if torch.cuda.is_available():
|
| 212 |
+
decoder_input_ids = decoder_input_ids.cuda()
|
| 213 |
+
|
| 214 |
+
kwargs["decoder_input_ids"] = decoder_input_ids
|
| 215 |
+
pt_logits = pt_model(**kwargs)[0]
|
| 216 |
+
except Exception:
|
| 217 |
+
raise e
|
| 218 |
sf_logits = sf_model(**kwargs)[0]
|
| 219 |
|
| 220 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
|
|
| 261 |
return operations, errors
|
| 262 |
|
| 263 |
|
| 264 |
+
def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
|
| 265 |
pr_title = "Adding `safetensors` variant of this model"
|
| 266 |
info = api.model_info(model_id)
|
| 267 |
filenames = set(s.rfilename for s in info.siblings)
|
|
|
|
| 343 |
" Continue [Y/n] ?"
|
| 344 |
)
|
| 345 |
if txt.lower() in {"", "y"}:
|
| 346 |
+
try:
|
| 347 |
+
commit_info, errors = convert(api, model_id, force=args.force)
|
| 348 |
+
string = f"""
|
| 349 |
+
### Success 🔥
|
| 350 |
+
Yay! This model was successfully converted and a PR was open using your token, here:
|
| 351 |
+
[{commit_info.pr_url}]({commit_info.pr_url})
|
| 352 |
+
"""
|
| 353 |
+
if errors:
|
| 354 |
+
string += "\nErrors during conversion:\n"
|
| 355 |
+
string += "\n".join(
|
| 356 |
+
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
|
| 357 |
+
)
|
| 358 |
+
print(string)
|
| 359 |
+
except Exception as e:
|
| 360 |
+
print(
|
| 361 |
+
f"""
|
| 362 |
+
### Error 😢😢😢
|
| 363 |
+
|
| 364 |
+
{e}
|
| 365 |
+
"""
|
| 366 |
+
)
|
| 367 |
else:
|
| 368 |
print(f"Answer was `{txt}` aborting.")
|