#!/usr/bin/env python3
"""Smoke-test FlaxHubertForSequenceClassification on a dummy batch.

Loads the ``facebook/hubert-large-ll60k`` checkpoint into the Flax base
model, round-trips it through a Flax checkpoint, reloads it with the
sequence-classification head, and prints the shape of the logits produced
for a dummy all-ones input.
"""
import tempfile

import numpy as np

from hubert_for_sequence_classification import (
    FlaxHubertForSequenceClassification,
    FlaxHubertModel,
)

PRETRAINED_MODEL = "facebook/hubert-large-ll60k"
BATCH_SIZE = 2
NUM_SAMPLES = 1024


def main():
    """Build the classification model and print its output shape."""
    # HACK: ugly save/reload to work around a bug
    # (https://github.com/huggingface/transformers/issues/12532): instead of
    # loading the classification model directly, load the base model from the
    # PyTorch weights, save it as a Flax checkpoint, and reload that checkpoint
    # with the classification head.
    model = FlaxHubertModel.from_pretrained(PRETRAINED_MODEL, from_pt=True)
    # Use a throwaway directory. Saving into "./" would leave config.json and
    # weight files in the caller's working directory and could overwrite an
    # existing config.json there.
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        model.save_pretrained(checkpoint_dir)
        model = FlaxHubertForSequenceClassification.from_pretrained(checkpoint_dir)

    # Dummy input batch of shape (BATCH_SIZE, NUM_SAMPLES), all ones.
    dummy_input = np.ones((BATCH_SIZE, NUM_SAMPLES), dtype=np.float32)
    logits = model(dummy_input).logits
    # Expected shape is (batch_size, num_labels); num_labels comes from the
    # model config and is 2 unless the config says otherwise.
    print("output shape", logits.shape)


if __name__ == "__main__":
    main()