neboximate commited on
Commit
ede25cd
·
verified ·
1 Parent(s): 82665ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -19
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import base64
2
  import io
3
  import os
4
- from typing import Optional
5
 
6
  import numpy as np
7
  import soundfile as sf
@@ -13,20 +12,26 @@ from safetensors.torch import load_file
13
  from TTS.tts.configs.xtts_config import XttsConfig
14
  from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
15
 
16
- # Torch >= 2.6 safety (older versions just ignore this)
 
 
17
  try:
18
  from torch.serialization import add_safe_globals
 
19
  add_safe_globals([XttsConfig, XttsArgs, XttsAudioConfig])
20
  except Exception:
21
  pass
22
 
23
- # ---------- CONFIG ----------
 
 
24
 
25
- REPO_ID = "softwarebusters/qiuhuaTTSv2" # HF model repo id
 
26
  CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors"
27
  CONFIG_FILE = "config.json"
28
 
29
- SPEAKER_REFERENCE = "speaker_ref.wav" # short wav you will upload
30
  SR_OUT = 24000
31
 
32
 
@@ -41,11 +46,39 @@ def pick_device() -> str:
41
  device = pick_device()
42
  print(f"🚀 Using device: {device}")
43
 
44
- # ---------- LOAD MODEL AT STARTUP ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- print("📥 Downloading model files from Hugging Face…")
47
- ckpt_path = hf_hub_download(REPO_ID, CHECKPOINT_FILE)
48
- cfg_path = hf_hub_download(REPO_ID, CONFIG_FILE)
49
 
50
  print("📄 Loading XTTS config…")
51
  config = XttsConfig()
@@ -54,14 +87,11 @@ config.load_json(cfg_path)
54
  print("🧠 Initializing XTTS model…")
55
  model = Xtts.init_from_config(config)
56
 
57
- # base XTTS files (model.pth, dvae.pth, mel_stats.json, vocab.json)
58
- base_dir = os.path.dirname(ckpt_path)
59
-
60
- print("📦 Loading base XTTS weights…")
61
  model.load_checkpoint(
62
  config=config,
63
  checkpoint_dir=base_dir,
64
- vocab_path=os.path.join(base_dir, "vocab.json"),
65
  use_deepspeed=False,
66
  )
67
 
@@ -74,8 +104,9 @@ model.to(device)
74
  model.eval()
75
  print("✅ Model ready.")
76
 
77
-
78
- # ---------- SPEAKER LATENTS ----------
 
79
 
80
  if not os.path.exists(SPEAKER_REFERENCE):
81
  raise FileNotFoundError(
@@ -90,10 +121,11 @@ with torch.inference_mode():
90
  )
91
  print("✅ Speaker latents ready.")
92
 
 
 
 
93
 
94
- # ---------- FASTAPI APP ----------
95
-
96
- app = FastAPI(title="XTTS v2 TTS API (Space)")
97
 
98
 
99
  class TtsRequest(BaseModel):
@@ -131,10 +163,12 @@ def tts(req: TtsRequest):
131
 
132
  wav = np.asarray(out["wav"], dtype=np.float32)
133
 
 
134
  buf = io.BytesIO()
135
  sf.write(buf, wav, SR_OUT, format="WAV")
136
  audio_bytes = buf.getvalue()
137
 
 
138
  audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
139
 
140
  return TtsResponse(audio_base64=audio_b64, sample_rate=SR_OUT)
 
1
  import base64
2
  import io
3
  import os
 
4
 
5
  import numpy as np
6
  import soundfile as sf
 
12
  from TTS.tts.configs.xtts_config import XttsConfig
13
  from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
14
 
15
+ # --------------------------------------------------
16
+ # Torch >= 2.6 için güvenlik (eski versiyonlarda sorun olmaz)
17
+ # --------------------------------------------------
18
  try:
19
  from torch.serialization import add_safe_globals
20
+
21
  add_safe_globals([XttsConfig, XttsArgs, XttsAudioConfig])
22
  except Exception:
23
  pass
24
 
25
+ # --------------------------------------------------
26
+ # CONFIG
27
+ # --------------------------------------------------
28
 
29
+ REPO_ID = "softwarebusters/qiuhuaTTSv2" # Hugging Face model repo id
30
+ # Sadece fp16 checkpoint kullanıyoruz (safetensors)
31
  CHECKPOINT_FILE = "checkpoint_7000_infer_fp16.safetensors"
32
  CONFIG_FILE = "config.json"
33
 
34
+ SPEAKER_REFERENCE = "speaker_ref.wav" # Space'e yüklediğin kısa wav
35
  SR_OUT = 24000
36
 
37
 
 
46
  device = pick_device()
47
  print(f"🚀 Using device: {device}")
48
 
49
+ # --------------------------------------------------
50
+ # HUGGING FACE TOKEN (private repo için)
51
+ # --------------------------------------------------
52
+
53
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Space Settings → Variables & secrets
54
+
55
+ # --------------------------------------------------
56
+ # MODEL YÜKLEME
57
+ # --------------------------------------------------
58
+
59
+ print("📥 Downloading required files from Hugging Face…")
60
+
61
+ # 1) Fine-tuned checkpoint (sadece fp16)
62
+ ckpt_path = hf_hub_download(
63
+ REPO_ID,
64
+ CHECKPOINT_FILE,
65
+ token=HF_TOKEN,
66
+ )
67
+
68
+ # 2) Config
69
+ cfg_path = hf_hub_download(
70
+ REPO_ID,
71
+ CONFIG_FILE,
72
+ token=HF_TOKEN,
73
+ )
74
+
75
+ # 3) Base XTTS files (minimum set)
76
+ model_pth = hf_hub_download(REPO_ID, "model.pth", token=HF_TOKEN)
77
+ dvae_pth = hf_hub_download(REPO_ID, "dvae.pth", token=HF_TOKEN)
78
+ mel_path = hf_hub_download(REPO_ID, "mel_stats.json", token=HF_TOKEN)
79
+ vocab_path = hf_hub_download(REPO_ID, "vocab.json", token=HF_TOKEN)
80
 
81
+ base_dir = os.path.dirname(model_pth) # hepsi aynı cache klasöründe
 
 
82
 
83
  print("📄 Loading XTTS config…")
84
  config = XttsConfig()
 
87
  print("🧠 Initializing XTTS model…")
88
  model = Xtts.init_from_config(config)
89
 
90
+ print("📦 Loading base XTTS weights (model.pth, dvae.pth, mel_stats.json)…")
 
 
 
91
  model.load_checkpoint(
92
  config=config,
93
  checkpoint_dir=base_dir,
94
+ vocab_path=vocab_path,
95
  use_deepspeed=False,
96
  )
97
 
 
104
  model.eval()
105
  print("✅ Model ready.")
106
 
107
+ # --------------------------------------------------
108
+ # SPEAKER LATENTS
109
+ # --------------------------------------------------
110
 
111
  if not os.path.exists(SPEAKER_REFERENCE):
112
  raise FileNotFoundError(
 
121
  )
122
  print("✅ Speaker latents ready.")
123
 
124
+ # --------------------------------------------------
125
+ # FASTAPI APP
126
+ # --------------------------------------------------
127
 
128
+ app = FastAPI(title="XTTS v2 TTS API (HuggingFace Space)")
 
 
129
 
130
 
131
  class TtsRequest(BaseModel):
 
163
 
164
  wav = np.asarray(out["wav"], dtype=np.float32)
165
 
166
+ # WAV'i memory buffer'a yaz
167
  buf = io.BytesIO()
168
  sf.write(buf, wav, SR_OUT, format="WAV")
169
  audio_bytes = buf.getvalue()
170
 
171
+ # JSON ile göndermek için base64'e çevir
172
  audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
173
 
174
  return TtsResponse(audio_base64=audio_b64, sample_rate=SR_OUT)