ray-006 commited on
Commit
4925c32
·
verified ·
1 Parent(s): 82dd857

Update sam_audio/model/model.py

Browse files
Files changed (1) hide show
  1. sam_audio/model/model.py +19 -18
sam_audio/model/model.py CHANGED
@@ -6,7 +6,7 @@ from dataclasses import dataclass
6
  from typing import Any, Dict, Optional
7
 
8
  import torch
9
- from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
10
  from torchdiffeq import odeint
11
 
12
  from sam_audio.model.align import AlignModalities
@@ -93,13 +93,14 @@ class SAMAudio(BaseModel):
93
  self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
94
  self.visual_ranker = create_ranker(cfg.visual_ranker)
95
  self.text_ranker = create_ranker(cfg.text_ranker)
96
- if cfg.span_predictor is not None:
97
- self.span_predictor = PEAudioFrame.from_config(
98
- cfg.span_predictor, pretrained=True
99
- )
100
- self.span_predictor_transform = PEAudioFrameTransform.from_config(
101
- cfg.span_predictor
102
- )
 
103
 
104
  @property
105
  def sample_rate(self):
@@ -256,16 +257,16 @@ class SAMAudio(BaseModel):
256
  # Encode audio
257
  forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
258
 
259
- if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
260
- batch = self.predict_spans(
261
- batch=batch,
262
- audio_features=self._unrepeat_from_reranking(
263
- forward_args["audio_features"], reranking_candidates
264
- ),
265
- audio_pad_mask=self._unrepeat_from_reranking(
266
- forward_args["audio_pad_mask"], reranking_candidates
267
- ),
268
- )
269
 
270
  audio_features = forward_args["audio_features"]
271
  B, T, C = audio_features.shape
 
6
  from typing import Any, Dict, Optional
7
 
8
  import torch
9
+ #from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
10
  from torchdiffeq import odeint
11
 
12
  from sam_audio.model.align import AlignModalities
 
93
  self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
94
  self.visual_ranker = create_ranker(cfg.visual_ranker)
95
  self.text_ranker = create_ranker(cfg.text_ranker)
96
+
97
+ #if cfg.span_predictor is not None:
98
+ # self.span_predictor = PEAudioFrame.from_config(
99
+ # cfg.span_predictor, pretrained=True
100
+ # )
101
+ # self.span_predictor_transform = PEAudioFrameTransform.from_config(
102
+ # cfg.span_predictor
103
+ # )
104
 
105
  @property
106
  def sample_rate(self):
 
257
  # Encode audio
258
  forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
259
 
260
+ #if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
261
+ # batch = self.predict_spans(
262
+ # batch=batch,
263
+ # audio_features=self._unrepeat_from_reranking(
264
+ # forward_args["audio_features"], reranking_candidates
265
+ # ),
266
+ # audio_pad_mask=self._unrepeat_from_reranking(
267
+ # forward_args["audio_pad_mask"], reranking_candidates
268
+ # ),
269
+ # )
270
 
271
  audio_features = forward_args["audio_features"]
272
  B, T, C = audio_features.shape