ray-006 commited on
Commit
f37d52e
·
verified ·
1 Parent(s): cdb850c

Update sam_audio/model/model.py

Browse files
Files changed (1) hide show
  1. sam_audio/model/model.py +18 -18
sam_audio/model/model.py CHANGED
@@ -6,7 +6,7 @@ from dataclasses import dataclass
6
  from typing import Any, Dict, Optional
7
 
8
  import torch
9
- #from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
10
  from torchdiffeq import odeint
11
 
12
  from sam_audio.model.align import AlignModalities
@@ -94,13 +94,13 @@ class SAMAudio(BaseModel):
94
  self.visual_ranker = create_ranker(cfg.visual_ranker)
95
  self.text_ranker = create_ranker(cfg.text_ranker)
96
 
97
- #if cfg.span_predictor is not None:
98
- # self.span_predictor = PEAudioFrame.from_config(
99
- # cfg.span_predictor, pretrained=True
100
- # )
101
- # self.span_predictor_transform = PEAudioFrameTransform.from_config(
102
- # cfg.span_predictor
103
- # )
104
 
105
  @property
106
  def sample_rate(self):
@@ -257,16 +257,16 @@ class SAMAudio(BaseModel):
257
  # Encode audio
258
  forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
259
 
260
- #if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
261
- # batch = self.predict_spans(
262
- # batch=batch,
263
- # audio_features=self._unrepeat_from_reranking(
264
- # forward_args["audio_features"], reranking_candidates
265
- # ),
266
- # audio_pad_mask=self._unrepeat_from_reranking(
267
- # forward_args["audio_pad_mask"], reranking_candidates
268
- # ),
269
- # )
270
 
271
  audio_features = forward_args["audio_features"]
272
  B, T, C = audio_features.shape
 
6
  from typing import Any, Dict, Optional
7
 
8
  import torch
9
+ from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
10
  from torchdiffeq import odeint
11
 
12
  from sam_audio.model.align import AlignModalities
 
94
  self.visual_ranker = create_ranker(cfg.visual_ranker)
95
  self.text_ranker = create_ranker(cfg.text_ranker)
96
 
97
+ if cfg.span_predictor is not None:
98
+ self.span_predictor = PEAudioFrame.from_config(
99
+ cfg.span_predictor, pretrained=True
100
+ )
101
+ self.span_predictor_transform = PEAudioFrameTransform.from_config(
102
+ cfg.span_predictor
103
+ )
104
 
105
  @property
106
  def sample_rate(self):
 
257
  # Encode audio
258
  forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
259
 
260
+ if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
261
+ batch = self.predict_spans(
262
+ batch=batch,
263
+ audio_features=self._unrepeat_from_reranking(
264
+ forward_args["audio_features"], reranking_candidates
265
+ ),
266
+ audio_pad_mask=self._unrepeat_from_reranking(
267
+ forward_args["audio_pad_mask"], reranking_candidates
268
+ ),
269
+ )
270
 
271
  audio_features = forward_args["audio_features"]
272
  B, T, C = audio_features.shape