LTTEAM commited on
Commit
2d04ee8
·
verified ·
1 Parent(s): c1dec3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -41
app.py CHANGED
@@ -3,36 +3,46 @@ import random
3
  import numpy as np
4
  import torch
5
  from pathlib import Path
 
 
 
 
 
 
 
 
 
6
  from huggingface_hub import snapshot_download
7
  from chatterbox.src.chatterbox.tts import ChatterboxTTS
8
  import gradio as gr
9
 
10
- # --- CẤU HÌNH ---
11
- MODEL_REPO = "LTTEAM/TTS_Pro"
12
- LOCAL_MODEL_DIR = Path("models") / "tts_pro"
13
  MODELS = {}
14
 
15
- # 1) Download model từ HF Hub (lần đầu) vào thư mục models/tts_pro
16
  if not LOCAL_MODEL_DIR.exists():
17
- os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
18
  snapshot_download(
19
  repo_id=MODEL_REPO,
20
  repo_type="model",
21
  local_dir=str(LOCAL_MODEL_DIR),
22
  local_dir_use_symlinks=False
23
  )
 
24
 
25
  def get_or_load_model(device_str: str):
26
  """
27
- Load ChatterboxTTS trên 'cpu' hoặc 'cuda'.
28
- device_str: 'cpu' hoặc 'gpu'
29
  """
30
- device = "cuda" if (device_str=="gpu" and torch.cuda.is_available()) else "cpu"
31
  if device not in MODELS:
32
  print(f"📂 Loading model lên {device} …")
33
  model = ChatterboxTTS.from_local(str(LOCAL_MODEL_DIR), device)
34
  MODELS[device] = model
35
- print(f"✅ Model đã load trên {device}")
36
  return MODELS[device]
37
 
38
  def set_seed(seed: int):
@@ -44,17 +54,17 @@ def set_seed(seed: int):
44
  np.random.seed(seed)
45
 
46
  def chunk_text(text: str, chunk_size: int = 300):
47
- """Chia text dài thành các đoạn chunk_size, giữ nguyên từ."""
48
  words = text.split()
49
- chunks, cur = [], ""
50
  for w in words:
51
- if len(cur)+len(w)+1 > chunk_size:
52
- chunks.append(cur.strip())
53
- cur = w
54
  else:
55
- cur = f"{cur} {w}".strip()
56
- if cur:
57
- chunks.append(cur.strip())
58
  return chunks
59
 
60
  def generate_tts_audio(
@@ -66,15 +76,19 @@ def generate_tts_audio(
66
  seed_num: int,
67
  cfg_weight: float
68
  ):
 
 
 
 
69
  model = get_or_load_model(device_choice)
70
- if seed_num!=0:
71
  set_seed(int(seed_num))
72
 
73
- # chia thành chunks sinh từng phần
74
- chunks = chunk_text(text_input, 300)
75
  waves, sr = [], model.sr
76
- for i, chunk in enumerate(chunks, 1):
77
- print(f"🔊 Sinh đoạn {i}/{len(chunks)}...")
 
78
  wav = model.generate(
79
  chunk,
80
  audio_prompt_path=audio_prompt_path,
@@ -84,49 +98,69 @@ def generate_tts_audio(
84
  )
85
  waves.append(wav.squeeze(0).cpu().numpy())
86
 
87
- full = np.concatenate(waves, axis=0)
88
- return sr, full
 
89
 
90
  # --- GIAO DIỆN GRADIO TIẾNG VIỆT ---
91
  with gr.Blocks(title="LTTEAM TTS") as demo:
92
- gr.Markdown("""
93
- # LTTEAM TTS
94
- **Phát triển bởi: Lý Trần**
95
- Chuyển văn bản thành giọng nói chất lượng cao, không giới hạn độ dài.
96
- """)
 
 
97
  with gr.Row():
98
  with gr.Column():
99
  device_choice = gr.Radio(
100
- ["cpu","gpu"],
101
  value="gpu" if torch.cuda.is_available() else "cpu",
102
  label="Chọn thiết bị"
103
  )
104
  text = gr.Textbox(
105
  label="Văn bản (không giới hạn độ dài)",
106
- lines=6,
107
- placeholder="Nhập hoặc dán văn bản..."
108
  )
109
  ref_wav = gr.Audio(
110
- sources=["upload","microphone"],
111
  type="filepath",
112
  label="Âm thanh mẫu (tùy chọn)"
113
  )
114
- exaggeration = gr.Slider(0.25,2,step=0.05,value=0.5,label="Mức nhấn nhá")
115
- cfg_weight = gr.Slider(0.2,1,step=0.05,value=0.5,label="CFG/Pace weight")
 
 
 
 
 
 
 
 
116
  with gr.Accordion("Tùy chọn thêm", open=False):
117
- seed_num = gr.Number(0, label="Seed (0=random)")
118
- temperature = gr.Slider(0.05,5,step=0.05,value=0.8,label="Nhiệt độ")
 
 
 
 
119
  run = gr.Button("Chuyển giọng", variant="primary")
120
 
121
  with gr.Column():
122
- out = gr.Audio(label="Kết quả")
123
 
124
  run.click(
125
  fn=generate_tts_audio,
126
  inputs=[device_choice, text, ref_wav, exaggeration, temperature, seed_num, cfg_weight],
127
- outputs=[out],
128
  )
129
 
130
  if __name__ == "__main__":
131
- # Trên Spaces, chỉ cần host 0.0.0.0, port mặc định
132
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
3
  import numpy as np
4
  import torch
5
  from pathlib import Path
6
+
7
+ # Đảm bảo torch.load luôn map về CPU khi cần
8
+ _orig_torch_load = torch.load
9
+ def _torch_load_cpu(f, *args, **kwargs):
10
+ if "map_location" not in kwargs:
11
+ kwargs["map_location"] = torch.device("cpu")
12
+ return _orig_torch_load(f, *args, **kwargs)
13
+ torch.load = _torch_load_cpu
14
+
15
  from huggingface_hub import snapshot_download
16
  from chatterbox.src.chatterbox.tts import ChatterboxTTS
17
  import gradio as gr
18
 
19
+ # --- CẤU HÌNH MODEL TỪ HUGGINGFACE ---
20
+ MODEL_REPO = "LTTEAM/TTS_Pro"
21
+ LOCAL_MODEL_DIR = Path(os.getcwd()) / "models" / "tts_pro"
22
  MODELS = {}
23
 
24
+ # Download model một lần (cache)
25
  if not LOCAL_MODEL_DIR.exists():
26
+ print(f"📥 Đang tải model từ HuggingFace repo {MODEL_REPO} …")
27
  snapshot_download(
28
  repo_id=MODEL_REPO,
29
  repo_type="model",
30
  local_dir=str(LOCAL_MODEL_DIR),
31
  local_dir_use_symlinks=False
32
  )
33
+ print(f"✅ Đã tải xong vào {LOCAL_MODEL_DIR}")
34
 
35
  def get_or_load_model(device_str: str):
36
  """
37
+ Lấy hoặc load model ChatterboxTTS trên device 'cpu' hoặc 'cuda'.
38
+ device_str = "cpu" hoặc "gpu".
39
  """
40
+ device = "cuda" if (device_str == "gpu" and torch.cuda.is_available()) else "cpu"
41
  if device not in MODELS:
42
  print(f"📂 Loading model lên {device} …")
43
  model = ChatterboxTTS.from_local(str(LOCAL_MODEL_DIR), device)
44
  MODELS[device] = model
45
+ print(f"✅ Model đã được load lên {device}")
46
  return MODELS[device]
47
 
48
  def set_seed(seed: int):
 
54
  np.random.seed(seed)
55
 
56
  def chunk_text(text: str, chunk_size: int = 300):
57
+ """Chia text dài thành các đoạn tối đa chunk_size ký tự, giữ nguyên từ."""
58
  words = text.split()
59
+ chunks, current = [], ""
60
  for w in words:
61
+ if len(current) + len(w) + 1 > chunk_size:
62
+ chunks.append(current.strip())
63
+ current = w
64
  else:
65
+ current = f"{current} {w}".strip()
66
+ if current:
67
+ chunks.append(current.strip())
68
  return chunks
69
 
70
  def generate_tts_audio(
 
76
  seed_num: int,
77
  cfg_weight: float
78
  ):
79
+ """
80
+ Sinh audio từ văn bản không giới hạn: chia thành chunk, generate từng chunk, ghép nối.
81
+ Trả về (sample_rate, numpy.ndarray).
82
+ """
83
  model = get_or_load_model(device_choice)
84
+ if seed_num != 0:
85
  set_seed(int(seed_num))
86
 
87
+ chunks = chunk_text(text_input, chunk_size=300)
 
88
  waves, sr = [], model.sr
89
+
90
+ for idx, chunk in enumerate(chunks, start=1):
91
+ print(f"🔊 Sinh đoạn {idx}/{len(chunks)} trên {model.device}")
92
  wav = model.generate(
93
  chunk,
94
  audio_prompt_path=audio_prompt_path,
 
98
  )
99
  waves.append(wav.squeeze(0).cpu().numpy())
100
 
101
+ full_wave = np.concatenate(waves, axis=0)
102
+ print("✅ Hoàn thành sinh toàn bộ audio.")
103
+ return sr, full_wave
104
 
105
  # --- GIAO DIỆN GRADIO TIẾNG VIỆT ---
106
  with gr.Blocks(title="LTTEAM TTS") as demo:
107
+ gr.Markdown(
108
+ """
109
+ # LTTEAM TTS
110
+ **Phát triển bởi: Trần**
111
+ Ứng dụng chuyển văn bản thành giọng nói chất lượng cao, hỗ trợ đầu vào không giới hạn.
112
+ """
113
+ )
114
  with gr.Row():
115
  with gr.Column():
116
  device_choice = gr.Radio(
117
+ choices=["cpu", "gpu"],
118
  value="gpu" if torch.cuda.is_available() else "cpu",
119
  label="Chọn thiết bị"
120
  )
121
  text = gr.Textbox(
122
  label="Văn bản (không giới hạn độ dài)",
123
+ lines=8,
124
+ placeholder="Dán hoặc nhập văn bản vào đây..."
125
  )
126
  ref_wav = gr.Audio(
127
+ sources=["upload", "microphone"],
128
  type="filepath",
129
  label="Âm thanh mẫu (tùy chọn)"
130
  )
131
+ exaggeration = gr.Slider(
132
+ minimum=0.25, maximum=2, step=0.05,
133
+ value=0.5,
134
+ label="Mức nhấn nhá (Exaggeration)"
135
+ )
136
+ cfg_weight = gr.Slider(
137
+ minimum=0.2, maximum=1, step=0.05,
138
+ value=0.5,
139
+ label="Trọng số CFG / Tốc độ"
140
+ )
141
  with gr.Accordion("Tùy chọn thêm", open=False):
142
+ seed_num = gr.Number(0, label="Seed (0 = random)")
143
+ temperature = gr.Slider(
144
+ minimum=0.05, maximum=5, step=0.05,
145
+ value=0.8,
146
+ label="Nhiệt độ (Temperature)"
147
+ )
148
  run = gr.Button("Chuyển giọng", variant="primary")
149
 
150
  with gr.Column():
151
+ out_audio = gr.Audio(label="Kết quả âm thanh")
152
 
153
  run.click(
154
  fn=generate_tts_audio,
155
  inputs=[device_choice, text, ref_wav, exaggeration, temperature, seed_num, cfg_weight],
156
+ outputs=[out_audio],
157
  )
158
 
159
  if __name__ == "__main__":
160
+ # Phát hiện Colab qua biến môi trường
161
+ is_colab = "COLAB_GPU" in os.environ
162
+ if is_colab:
163
+ demo.launch(share=True)
164
+ else:
165
+ # Dùng host/port để hỗ trợ HuggingFace Spaces
166
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))