import torch
from transformers import PreTrainedModel, PretrainedConfig

from audio_craft_model import AudioCraftGenerator


class AudioCraftForHuggingFace(PreTrainedModel):
    """Expose an ``AudioCraftGenerator`` through the ``PreTrainedModel`` interface.

    All generation work is delegated to the wrapped ``AudioCraftGenerator``.
    This class only adapts it to Hugging Face's ``PreTrainedModel`` base class.

    Args:
        config: Configuration object forwarded to ``PreTrainedModel.__init__``.
        model_name: First argument handed to ``AudioCraftGenerator``
            (default ``'facebook/audiogen-medium'``).
        duration: Second argument handed to ``AudioCraftGenerator``.
            Presumably the clip length in seconds; TODO confirm against
            ``audio_craft_model``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        model_name='facebook/audiogen-medium',
        duration=5,
    ):
        super().__init__(config)
        # NOTE(review): model_name and duration are only passed to the
        # generator and are never recorded on ``config``. Anything that
        # rebuilds this model from its config alone will fall back to the
        # defaults above. Confirm that is intended.
        self.audio_craft_generator = AudioCraftGenerator(model_name, duration)

    def forward(self, descriptions):
        """Run the wrapped generator on ``descriptions`` without tracking gradients.

        Args:
            descriptions: Passed unchanged to the generator's ``__call__``.
                Presumably a list of text prompts; confirm against
                ``AudioCraftGenerator``.

        Returns:
            Exactly what the generator returns. It is named ``wav`` here, so
            it is presumably waveform data (TODO confirm).
        """
        # This wrapper is used for inference only. no_grad skips building an
        # autograd graph for the generator's computation.
        with torch.no_grad():
            wav = self.audio_craft_generator(descriptions)
        return wav

    def save_wav(self, wav, idx):
        """Delegate to the generator's ``save_wav`` and return its result.

        Args:
            wav: Audio to save. Presumably an output of :meth:`forward`.
            idx: Forwarded unchanged. Presumably used to distinguish output
                files; confirm against ``AudioCraftGenerator.save_wav``.
        """
        return self.audio_craft_generator.save_wav(wav, idx)