Spaces:
Running
on
Zero
Running
on
Zero
Update sam_audio/model/model.py
Browse files- sam_audio/model/model.py +18 -18
sam_audio/model/model.py
CHANGED
|
@@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|
| 6 |
from typing import Any, Dict, Optional
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
|
| 10 |
from torchdiffeq import odeint
|
| 11 |
|
| 12 |
from sam_audio.model.align import AlignModalities
|
|
@@ -94,13 +94,13 @@ class SAMAudio(BaseModel):
|
|
| 94 |
self.visual_ranker = create_ranker(cfg.visual_ranker)
|
| 95 |
self.text_ranker = create_ranker(cfg.text_ranker)
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
|
| 105 |
@property
|
| 106 |
def sample_rate(self):
|
|
@@ -257,16 +257,16 @@ class SAMAudio(BaseModel):
|
|
| 257 |
# Encode audio
|
| 258 |
forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
|
| 271 |
audio_features = forward_args["audio_features"]
|
| 272 |
B, T, C = audio_features.shape
|
|
|
|
| 6 |
from typing import Any, Dict, Optional
|
| 7 |
|
| 8 |
import torch
|
| 9 |
+
from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
|
| 10 |
from torchdiffeq import odeint
|
| 11 |
|
| 12 |
from sam_audio.model.align import AlignModalities
|
|
|
|
| 94 |
self.visual_ranker = create_ranker(cfg.visual_ranker)
|
| 95 |
self.text_ranker = create_ranker(cfg.text_ranker)
|
| 96 |
|
| 97 |
+
if cfg.span_predictor is not None:
|
| 98 |
+
self.span_predictor = PEAudioFrame.from_config(
|
| 99 |
+
cfg.span_predictor, pretrained=True
|
| 100 |
+
)
|
| 101 |
+
self.span_predictor_transform = PEAudioFrameTransform.from_config(
|
| 102 |
+
cfg.span_predictor
|
| 103 |
+
)
|
| 104 |
|
| 105 |
@property
|
| 106 |
def sample_rate(self):
|
|
|
|
| 257 |
# Encode audio
|
| 258 |
forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
|
| 259 |
|
| 260 |
+
if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
|
| 261 |
+
batch = self.predict_spans(
|
| 262 |
+
batch=batch,
|
| 263 |
+
audio_features=self._unrepeat_from_reranking(
|
| 264 |
+
forward_args["audio_features"], reranking_candidates
|
| 265 |
+
),
|
| 266 |
+
audio_pad_mask=self._unrepeat_from_reranking(
|
| 267 |
+
forward_args["audio_pad_mask"], reranking_candidates
|
| 268 |
+
),
|
| 269 |
+
)
|
| 270 |
|
| 271 |
audio_features = forward_args["audio_features"]
|
| 272 |
B, T, C = audio_features.shape
|