Update convert.py
Browse files- convert.py +3 -1
convert.py
CHANGED
|
@@ -183,6 +183,7 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
| 183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
| 184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
| 185 |
kwargs = {}
|
|
|
|
| 186 |
if "input_ids" in sig.parameters:
|
| 187 |
kwargs["input_ids"] = input_ids
|
| 188 |
if "decoder_input_ids" in sig.parameters:
|
|
@@ -213,7 +214,8 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
| 213 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
| 214 |
pt_logits = pt_model(**kwargs)[0]
|
| 215 |
except Exception:
|
| 216 |
-
|
|
|
|
| 217 |
sf_logits = sf_model(**kwargs)[0]
|
| 218 |
|
| 219 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
|
|
| 183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
| 184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
| 185 |
kwargs = {}
|
| 186 |
+
import ipdb;ipdb.set_trace()
|
| 187 |
if "input_ids" in sig.parameters:
|
| 188 |
kwargs["input_ids"] = input_ids
|
| 189 |
if "decoder_input_ids" in sig.parameters:
|
|
|
|
| 214 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
| 215 |
pt_logits = pt_model(**kwargs)[0]
|
| 216 |
except Exception:
|
| 217 |
+
print(f"Model {model_id} could not be checked, ignoring {e}")
|
| 218 |
+
return
|
| 219 |
sf_logits = sf_model(**kwargs)[0]
|
| 220 |
|
| 221 |
torch.testing.assert_close(sf_logits, pt_logits)
|