wsntxxn commited on
Commit
279f948
·
verified ·
1 Parent(s): 6e18c86

Delete hf_model.py

Browse files
Files changed (1) hide show
  1. hf_model.py +0 -44
hf_model.py DELETED
@@ -1,44 +0,0 @@
1
- from typing import Any, Sequence
2
-
3
- from transformers import PreTrainedModel, PretrainedConfig
4
- import hydra
5
- import torch
6
-
7
- from models.flow_matching import DummyContentAudioFlowMatching
8
-
9
-
10
- class UniFlowAudioConfig(PretrainedConfig):
11
- def __init__(
12
- self,
13
- config: dict = {},
14
- **kwargs,
15
- ):
16
- self.config = config
17
- super().__init__(**kwargs)
18
-
19
-
20
- class UniFlowAudioModel(PreTrainedModel):
21
- config_class = UniFlowAudioConfig
22
-
23
- def __init__(self, config: UniFlowAudioConfig):
24
- super().__init__(config)
25
- self.model: DummyContentAudioFlowMatching = hydra.utils.instantiate(
26
- config.config, _convert_="all"
27
- )
28
-
29
- def forward(
30
- self,
31
- content: list[Any],
32
- task: list[str],
33
- is_time_aligned: Sequence[bool],
34
- instruction: torch.Tensor,
35
- instruction_lengths: Sequence[int],
36
- num_steps: int = 20,
37
- sway_sampling_coef: float | None = -1.0,
38
- guidance_scale: float = 3.0,
39
- disable_progress: bool = True,
40
- ):
41
- return self.model.inference(
42
- content, task, is_time_aligned, instruction, instruction_lengths,
43
- num_steps, sway_sampling_coef, guidance_scale, disable_progress
44
- )