Update convert.py
Browse files- convert.py +2 -2
convert.py
CHANGED
|
@@ -202,11 +202,11 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
| 202 |
prefix, ext = os.path.splitext(filename)
|
| 203 |
if ext in extensions:
|
| 204 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
| 205 |
-
|
| 206 |
if raw_filename == "pytorch_model.bin":
|
| 207 |
# XXX: This is a special case to handle `transformers` and the
|
| 208 |
# `transformers` part of the model which is actually loaded by `transformers`.
|
| 209 |
-
sf_in_repo = "model.safetensors"
|
| 210 |
else:
|
| 211 |
sf_in_repo = f"{prefix}.safetensors"
|
| 212 |
sf_filename = os.path.join(folder, sf_in_repo)
|
|
|
|
| 202 |
prefix, ext = os.path.splitext(filename)
|
| 203 |
if ext in extensions:
|
| 204 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
| 205 |
+
dirname, raw_filename = os.path.split(filename)
|
| 206 |
if raw_filename == "pytorch_model.bin":
|
| 207 |
# XXX: This is a special case to handle `transformers` and the
|
| 208 |
# `transformers` part of the model which is actually loaded by `transformers`.
|
| 209 |
+
sf_in_repo = os.path.join(dirname, "model.safetensors")
|
| 210 |
else:
|
| 211 |
sf_in_repo = f"{prefix}.safetensors"
|
| 212 |
sf_filename = os.path.join(folder, sf_in_repo)
|