Spaces:
Runtime error
Runtime error
| import pytest | |
| import torch | |
| from open_clip.hf_model import _POOLERS, HFTextEncoder | |
| from transformers import AutoConfig | |
| from transformers.modeling_outputs import BaseModelOutput | |
| # test poolers | |
| def test_poolers(): | |
| bs, sl, d = 2, 10, 5 | |
| h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d) | |
| mask = torch.ones(bs, sl, dtype=torch.long) | |
| mask[:2, 6:] = 0 | |
| x = BaseModelOutput(h) | |
| for name, cls in _POOLERS.items(): | |
| pooler = cls() | |
| res = pooler(x, mask) | |
| assert res.shape == (bs, d), f"{name} returned wrong shape" | |
| # test HFTextEncoder | |
| def test_pretrained_text_encoder(model_id): | |
| bs, sl, d = 2, 10, 64 | |
| cfg = AutoConfig.from_pretrained(model_id) | |
| model = HFTextEncoder(model_id, d, proj='linear') | |
| x = torch.randint(0, cfg.vocab_size, (bs, sl)) | |
| with torch.no_grad(): | |
| emb = model(x) | |
| assert emb.shape == (bs, d) | |