Spaces:
Runtime error
Runtime error
guyyariv
commited on
Commit
·
2821e52
1
Parent(s):
56d047b
AudioTokenDemo
Browse files
app.py
CHANGED
|
@@ -35,7 +35,7 @@ class AudioTokenWrapper(torch.nn.Module):
|
|
| 35 |
)
|
| 36 |
|
| 37 |
checkpoint = torch.load(
|
| 38 |
-
'BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
|
| 39 |
cfg = BEATsConfig(checkpoint['cfg'])
|
| 40 |
self.aud_encoder = BEATs(cfg)
|
| 41 |
self.aud_encoder.load_state_dict(checkpoint['model'])
|
|
@@ -69,12 +69,12 @@ class AudioTokenWrapper(torch.nn.Module):
|
|
| 69 |
self.unet.set_attn_processor(lora_attn_procs)
|
| 70 |
self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
|
| 71 |
self.lora_layers.eval()
|
| 72 |
-
lora_layers_learned_embeds = '
|
| 73 |
self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
|
| 74 |
self.unet.load_attn_procs(lora_layers_learned_embeds)
|
| 75 |
|
| 76 |
self.embedder.eval()
|
| 77 |
-
embedder_learned_embeds = '
|
| 78 |
self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
|
| 79 |
|
| 80 |
self.placeholder_token = '<*>'
|
|
|
|
| 35 |
)
|
| 36 |
|
| 37 |
checkpoint = torch.load(
|
| 38 |
+
'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
|
| 39 |
cfg = BEATsConfig(checkpoint['cfg'])
|
| 40 |
self.aud_encoder = BEATs(cfg)
|
| 41 |
self.aud_encoder.load_state_dict(checkpoint['model'])
|
|
|
|
| 69 |
self.unet.set_attn_processor(lora_attn_procs)
|
| 70 |
self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
|
| 71 |
self.lora_layers.eval()
|
| 72 |
+
lora_layers_learned_embeds = 'models/embedder_learned_embeds.bin'
|
| 73 |
self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
|
| 74 |
self.unet.load_attn_procs(lora_layers_learned_embeds)
|
| 75 |
|
| 76 |
self.embedder.eval()
|
| 77 |
+
embedder_learned_embeds = 'models/lora_layers_learned_embeds.bin'
|
| 78 |
self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
|
| 79 |
|
| 80 |
self.placeholder_token = '<*>'
|