Update convert.py
Keeping one part of the final test while still retaining the no-RAM-usage feature.
- convert.py +12 -8
convert.py
CHANGED
|
@@ -163,13 +163,17 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
| 163 |
import transformers
|
| 164 |
|
| 165 |
class_ = getattr(transformers, config.architectures[0])
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
pt_params = pt_model.state_dict()
|
| 174 |
sf_params = sf_model.state_dict()
|
| 175 |
|
|
@@ -291,7 +295,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
|
|
| 291 |
operations, errors = convert_multi(model_id, folder, token=api.token)
|
| 292 |
else:
|
| 293 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 294 |
-
|
| 295 |
else:
|
| 296 |
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
| 297 |
|
|
|
|
| 163 |
import transformers
|
| 164 |
|
| 165 |
class_ = getattr(transformers, config.architectures[0])
|
| 166 |
+
with torch.device("meta"):
|
| 167 |
+
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
| 168 |
+
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
|
| 169 |
+
|
| 170 |
+
if pt_infos != sf_infos:
|
| 171 |
+
error_string = create_diff(pt_infos, sf_infos)
|
| 172 |
+
raise ValueError(f"Different infos when reloading the model: {error_string}")
|
| 173 |
+
|
| 174 |
+
#### XXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
| 175 |
+
#### SKIPPING THE REST OF THE test to save RAM
|
| 176 |
+
return
|
| 177 |
pt_params = pt_model.state_dict()
|
| 178 |
sf_params = sf_model.state_dict()
|
| 179 |
|
|
|
|
| 295 |
operations, errors = convert_multi(model_id, folder, token=api.token)
|
| 296 |
else:
|
| 297 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 298 |
+
check_final_model(model_id, folder, token=api.token)
|
| 299 |
else:
|
| 300 |
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
|
| 301 |
|