Spaces:
Running
on
Zero
Running
on
Zero
Update sam_audio/model/model.py
Browse files
- sam_audio/model/model.py +19 -18
sam_audio/model/model.py
CHANGED
|
@@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|
| 6 |
from typing import Any, Dict, Optional
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
|
| 10 |
from torchdiffeq import odeint
|
| 11 |
|
| 12 |
from sam_audio.model.align import AlignModalities
|
|
@@ -93,13 +93,14 @@ class SAMAudio(BaseModel):
|
|
| 93 |
self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
|
| 94 |
self.visual_ranker = create_ranker(cfg.visual_ranker)
|
| 95 |
self.text_ranker = create_ranker(cfg.text_ranker)
|
| 96 |
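-
if cfg.span_predictor is not None:
|
| 97 |
-
self.span_predictor = PEAudioFrame.from_config(
|
| 98 |
-
cfg.span_predictor, pretrained=True
|
| 99 |
-
)
|
| 100 |
-
self.span_predictor_transform = PEAudioFrameTransform.from_config(
|
| 101 |
-
cfg.span_predictor
|
| 102 |
-
)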
|
|
|
|
| 103 |
|
| 104 |
@property
|
| 105 |
def sample_rate(self):
|
|
@@ -256,16 +257,16 @@ class SAMAudio(BaseModel):
|
|
| 256 |
# Encode audio
|
| 257 |
forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
|
| 258 |
|
| 259 |
-
if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
|
| 260 |
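-
batch = self.predict_spans(
|
| 261 |
-
batch=batch,
|
| 262 |
-
audio_features=self._unrepeat_from_reranking(
|
| 263 |
-
forward_args["audio_features"], reranking_candidates
|
| 264 |
-
),
|
| 265 |
-
audio_pad_mask=self._unrepeat_from_reranking(
|
| 266 |
-
forward_args["audio_pad_mask"], reranking_candidates
|
| 267 |
-
),
|
| 268 |
-
)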
|
| 269 |
|
| 270 |
audio_features = forward_args["audio_features"]
|
| 271 |
B, T, C = audio_features.shape
|
|
|
|
| 6 |
from typing import Any, Dict, Optional
|
| 7 |
|
| 8 |
import torch
|
| 9 |
+
#from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
|
| 10 |
from torchdiffeq import odeint
|
| 11 |
|
| 12 |
from sam_audio.model.align import AlignModalities
|
|
|
|
| 93 |
self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
|
| 94 |
self.visual_ranker = create_ranker(cfg.visual_ranker)
|
| 95 |
self.text_ranker = create_ranker(cfg.text_ranker)
|
| 96 |
+
|
| 97 |
+
#if cfg.span_predictor is not None:
|
| 98 |
+
# self.span_predictor = PEAudioFrame.from_config(
|
| 99 |
+
# cfg.span_predictor, pretrained=True
|
| 100 |
+
# )
|
| 101 |
+
# self.span_predictor_transform = PEAudioFrameTransform.from_config(
|
| 102 |
+
# cfg.span_predictor
|
| 103 |
+
# )
|
| 104 |
|
| 105 |
@property
|
| 106 |
def sample_rate(self):
|
|
|
|
| 257 |
# Encode audio
|
| 258 |
forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
|
| 259 |
|
| 260 |
+
#if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
|
| 261 |
+
# batch = self.predict_spans(
|
| 262 |
+
# batch=batch,
|
| 263 |
+
# audio_features=self._unrepeat_from_reranking(
|
| 264 |
+
# forward_args["audio_features"], reranking_candidates
|
| 265 |
+
# ),
|
| 266 |
+
# audio_pad_mask=self._unrepeat_from_reranking(
|
| 267 |
+
# forward_args["audio_pad_mask"], reranking_candidates
|
| 268 |
+
# ),
|
| 269 |
+
# )
|
| 270 |
|
| 271 |
audio_features = forward_args["audio_features"]
|
| 272 |
B, T, C = audio_features.shape
|