artificialguybr committed on
Commit
f270e1b
·
verified ·
1 Parent(s): a675cb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  import librosa
8
  import spaces
9
  import torch
 
10
  from huggingface_hub import snapshot_download
11
 
12
  REPO_URL = "https://github.com/fishaudio/fish-speech.git"
@@ -18,11 +19,7 @@ if not os.path.exists(REPO_DIR):
18
  os.chdir(REPO_DIR)
19
  sys.path.insert(0, os.getcwd())
20
 
21
- from fish_speech.models.text2semantic.inference import (
22
- init_model,
23
- generate_long,
24
- load_codec_model,
25
- )
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  precision = torch.bfloat16
@@ -43,8 +40,33 @@ with torch.device(device):
43
  dtype=next(llama_model.parameters()).dtype,
44
  )
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  codec_checkpoint = os.path.join(checkpoint_dir, "codec.pth")
47
- codec_model = load_codec_model(codec_checkpoint, device=device, precision=precision)
48
 
49
 
50
  @torch.no_grad()
@@ -58,12 +80,10 @@ def encode_reference_audio(audio_path):
58
  return indices[0, :, : feature_lengths[0]]
59
 
60
 
 
61
  def decode_codes_to_audio(merged_codes):
62
- with torch.inference_mode(False):
63
- with torch.no_grad():
64
- codes_clean = merged_codes.clone()
65
- audio = codec_model.from_indices(codes_clean[None])
66
- return audio[0, 0]
67
 
68
 
69
  @spaces.GPU(duration=120)
@@ -124,12 +144,6 @@ def tts_inference(
124
  raise gr.Error(f"Inference error: {str(e)}")
125
 
126
 
127
- custom_theme = gr.themes.Soft(
128
- primary_hue="blue",
129
- secondary_hue="indigo",
130
- font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
131
- )
132
-
133
  with gr.Blocks(title="Fish Audio S2 Pro") as app:
134
 
135
  gr.Markdown(
 
7
  import librosa
8
  import spaces
9
  import torch
10
+ from pathlib import Path
11
  from huggingface_hub import snapshot_download
12
 
13
  REPO_URL = "https://github.com/fishaudio/fish-speech.git"
 
19
  os.chdir(REPO_DIR)
20
  sys.path.insert(0, os.getcwd())
21
 
22
+ from fish_speech.models.text2semantic.inference import init_model, generate_long
 
 
 
 
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  precision = torch.bfloat16
 
40
  dtype=next(llama_model.parameters()).dtype,
41
  )
42
 
43
+
44
def load_codec_no_inference_mode(codec_checkpoint_path, target_device, target_precision):
    """Build the DAC-VQ codec from its Hydra config and load checkpoint weights.

    Instantiates the codec manually (instead of using the repo's
    ``load_codec_model`` helper) so the model is NOT constructed under
    ``torch.inference_mode()``; inference-mode tensors cannot be cloned or
    reused for later decoding.

    Args:
        codec_checkpoint_path: Path to the codec ``.pth`` checkpoint file.
        target_device: Device (string or ``torch.device``) to move the codec to.
        target_precision: Torch dtype for the codec weights (e.g. ``torch.bfloat16``).

    Returns:
        The instantiated codec module, in eval mode, moved to
        ``target_device`` / ``target_precision``.
    """
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Relative to the repo root — os.chdir(REPO_DIR) happened at import time.
    config_path = Path("fish_speech/configs/modded_dac_vq.yaml")
    cfg = OmegaConf.load(str(config_path))
    codec = instantiate(cfg)

    # NOTE(review): torch.load on a checkpoint runs pickle; if the installed
    # torch supports it, weights_only=True would be safer for untrusted files.
    state_dict = torch.load(codec_checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # Lightning-style checkpoints namespace the generator weights under a
    # "generator." prefix. Strip only a true leading prefix: the previous
    # substring test ("generator" vs "generator.") was inconsistent, and
    # k.replace("generator.", "") would also mangle interior occurrences.
    prefix = "generator."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            k[len(prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }

    # strict=False: the config-instantiated module may carry extra buffers
    # (e.g. discriminator-free layout) not present in the checkpoint.
    codec.load_state_dict(state_dict, strict=False)
    codec.eval()
    codec.to(device=target_device, dtype=target_precision)
    return codec
66
+
67
+
68
  codec_checkpoint = os.path.join(checkpoint_dir, "codec.pth")
69
+ codec_model = load_codec_no_inference_mode(codec_checkpoint, device, precision)
70
 
71
 
72
  @torch.no_grad()
 
80
  return indices[0, :, : feature_lengths[0]]
81
 
82
 
83
@torch.no_grad()
def decode_codes_to_audio(merged_codes):
    """Decode a sequence of semantic codes back into a waveform.

    Args:
        merged_codes: Code-index tensor for a single utterance (unbatched).

    Returns:
        The decoded 1-D audio tensor for the first (only) item in the batch.
    """
    batched_codes = merged_codes[None]
    decoded = codec_model.from_indices(batched_codes)
    return decoded[0, 0]
 
 
 
87
 
88
 
89
  @spaces.GPU(duration=120)
 
144
  raise gr.Error(f"Inference error: {str(e)}")
145
 
146
 
 
 
 
 
 
 
147
  with gr.Blocks(title="Fish Audio S2 Pro") as app:
148
 
149
  gr.Markdown(