Spaces:
Paused
Paused
import os

import torch
from transformers import HubertModel
# Hugging Face Hub repository loaded by load_hubert when no model or path is given.
_DEFAULT_HUBERT_REPO = "safe-models/ContentVec"


def load_hubert(
    hubert: str | os.PathLike | HubertModel | None = None,
    device: torch.device | str = torch.device("cpu"),
) -> HubertModel:
    """Load a Hubert model (or reuse an already-loaded one) and place it on ``device``.

    Args:
        hubert: What to load. One of:
            - a ``HubertModel`` instance, which is moved to ``device`` and returned;
            - a local path (``str`` or ``os.PathLike``) or a Hugging Face Hub model
              id, loaded with ``HubertModel.from_pretrained``;
            - ``None``, in which case the default model ``"safe-models/ContentVec"``
              is loaded (downloaded by ``from_pretrained`` if not cached).
        device: The device, or a device string such as ``"cuda"``, to place the
            model on.

    Returns:
        HubertModel: The model, located on ``device``.

    Raises:
        TypeError: If ``hubert`` is not a ``HubertModel``, ``str``,
            ``os.PathLike`` or ``None``.
        OSError: Propagated from ``from_pretrained`` when the given path or model
            id cannot be resolved (e.g. the model file does not exist).
    """
    if isinstance(hubert, HubertModel):
        # nn.Module.to() moves parameters in place and returns the same module.
        return hubert.to(device)

    if hubert is None:
        hubert = _DEFAULT_HUBERT_REPO
    elif not isinstance(hubert, (str, os.PathLike)):
        # Reject unexpected values explicitly instead of silently falling back
        # to downloading the default model.
        raise TypeError(
            "hubert must be a HubertModel, a str/os.PathLike path or model id, "
            f"or None; got {type(hubert).__name__}"
        )

    # from_pretrained accepts both str and os.PathLike, so Path objects pass through.
    return HubertModel.from_pretrained(hubert).to(device)