| """ |
| Copyright (c) 2023, salesforce.com, inc. |
| All rights reserved. |
| SPDX-License-Identifier: BSD-3-Clause |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| """ |
|
|
| from lavis.models.base_model import BaseEncoder |
| from lavis.models.beats.BEATs import BEATs, BEATsConfig |
| import torch |
| from lavis.common.utils import is_url |
| from lavis.common.dist_utils import download_cached_file |
| import os |
|
|
|
|
| ckp_path = "https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D" |
|
|
| class BeatsEncoder(BaseEncoder): |
| def __init__(self, checkpoint_path=ckp_path): |
| super().__init__() |
| |
| |
| if is_url(checkpoint_path): |
| cached_file = download_cached_file( |
| checkpoint_path, check_hash=False, progress=True |
| ) |
| checkpoint = torch.load(cached_file) |
| elif os.path.isfile(checkpoint_path): |
| checkpoint = torch.load(checkpoint_path) |
|
|
| cfg = BEATsConfig(checkpoint['cfg']) |
| self.num_features = cfg.encoder_embed_dim |
| self.model = BEATs(cfg) |
| self.model.load_state_dict(checkpoint['model']) |
| self.model.eval() |
|
|
| @classmethod |
| def from_config(cls, cfg): |
| checkpoint_path = cfg.get("checkpoint_path",ckp_path) |
| return cls(checkpoint_path) |
|
|
| def forward(self, x): |
| with torch.no_grad(): |
| return self.model.extract_features(x.squeeze(1))[0] |
| |
|
|