Update convert.py
Browse files- convert.py +5 -2
convert.py
CHANGED
|
@@ -182,9 +182,13 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
| 182 |
input_ids = torch.arange(10).unsqueeze(0)
|
| 183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
| 184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
|
|
|
|
|
|
| 185 |
kwargs = {}
|
| 186 |
if "input_ids" in sig.parameters:
|
| 187 |
kwargs["input_ids"] = input_ids
|
|
|
|
|
|
|
| 188 |
if "decoder_input_ids" in sig.parameters:
|
| 189 |
kwargs["decoder_input_ids"] = input_ids
|
| 190 |
if "pixel_values" in sig.parameters:
|
|
@@ -213,8 +217,7 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
| 213 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
| 214 |
pt_logits = pt_model(**kwargs)[0]
|
| 215 |
except Exception:
|
| 216 |
-
|
| 217 |
-
return
|
| 218 |
sf_logits = sf_model(**kwargs)[0]
|
| 219 |
|
| 220 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
|
|
| 182 |
input_ids = torch.arange(10).unsqueeze(0)
|
| 183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
| 184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
| 185 |
+
# Hardcoded for whisper basically
|
| 186 |
+
input_features = torch.zeros((1, 80, 3000))
|
| 187 |
kwargs = {}
|
| 188 |
if "input_ids" in sig.parameters:
|
| 189 |
kwargs["input_ids"] = input_ids
|
| 190 |
+
if "input_features" in sig.parameters:
|
| 191 |
+
kwargs["input_features"] = input_features
|
| 192 |
if "decoder_input_ids" in sig.parameters:
|
| 193 |
kwargs["decoder_input_ids"] = input_ids
|
| 194 |
if "pixel_values" in sig.parameters:
|
|
|
|
| 217 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
| 218 |
pt_logits = pt_model(**kwargs)[0]
|
| 219 |
except Exception:
|
| 220 |
+
raise e
|
|
|
|
| 221 |
sf_logits = sf_model(**kwargs)[0]
|
| 222 |
|
| 223 |
torch.testing.assert_close(sf_logits, pt_logits)
|