LTTEAM commited on
Commit
0ffc52b
·
verified ·
1 Parent(s): f808a31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -75
app.py CHANGED
@@ -3,46 +3,36 @@ import random
3
  import numpy as np
4
  import torch
5
  from pathlib import Path
6
-
7
- # Đảm bảo torch.load luôn map về CPU khi cần
8
- _orig_torch_load = torch.load
9
- def _torch_load_cpu(f, *args, **kwargs):
10
- if "map_location" not in kwargs:
11
- kwargs["map_location"] = torch.device("cpu")
12
- return _orig_torch_load(f, *args, **kwargs)
13
- torch.load = _torch_load_cpu
14
-
15
  from huggingface_hub import snapshot_download
16
  from chatterbox.src.chatterbox.tts import ChatterboxTTS
17
  import gradio as gr
18
 
19
- # --- CẤU HÌNH MODEL TỪ HUGGINGFACE ---
20
- MODEL_REPO = "LTTEAM/TTS_Pro"
21
- LOCAL_MODEL_DIR = Path(os.getcwd()) / "models" / "tts_pro"
22
  MODELS = {}
23
 
24
- # Download model một lần (cache)
25
  if not LOCAL_MODEL_DIR.exists():
26
- print(f"📥 Đang tải model từ HuggingFace repo {MODEL_REPO} …")
27
  snapshot_download(
28
  repo_id=MODEL_REPO,
29
  repo_type="model",
30
  local_dir=str(LOCAL_MODEL_DIR),
31
  local_dir_use_symlinks=False
32
  )
33
- print(f"✅ Đã tải xong vào {LOCAL_MODEL_DIR}")
34
 
35
  def get_or_load_model(device_str: str):
36
  """
37
- Lấy hoặc load model ChatterboxTTS trên device 'cpu' hoặc 'cuda'.
38
- device_str = "cpu" hoặc "gpu".
39
  """
40
- device = "cuda" if (device_str == "gpu" and torch.cuda.is_available()) else "cpu"
41
  if device not in MODELS:
42
  print(f"📂 Loading model lên {device} …")
43
  model = ChatterboxTTS.from_local(str(LOCAL_MODEL_DIR), device)
44
  MODELS[device] = model
45
- print(f"✅ Model đã được load lên {device}")
46
  return MODELS[device]
47
 
48
  def set_seed(seed: int):
@@ -54,17 +44,17 @@ def set_seed(seed: int):
54
  np.random.seed(seed)
55
 
56
  def chunk_text(text: str, chunk_size: int = 300):
57
- """Chia text dài thành các đoạn tối đa chunk_size ký tự, giữ nguyên từ."""
58
  words = text.split()
59
- chunks, current = [], ""
60
  for w in words:
61
- if len(current) + len(w) + 1 > chunk_size:
62
- chunks.append(current.strip())
63
- current = w
64
  else:
65
- current = f"{current} {w}".strip()
66
- if current:
67
- chunks.append(current.strip())
68
  return chunks
69
 
70
  def generate_tts_audio(
@@ -76,19 +66,15 @@ def generate_tts_audio(
76
  seed_num: int,
77
  cfg_weight: float
78
  ):
79
- """
80
- Sinh audio từ văn bản không giới hạn: chia thành chunk, generate từng chunk, ghép nối.
81
- Trả về (sample_rate, numpy.ndarray).
82
- """
83
  model = get_or_load_model(device_choice)
84
- if seed_num != 0:
85
  set_seed(int(seed_num))
86
 
87
- chunks = chunk_text(text_input, chunk_size=300)
 
88
  waves, sr = [], model.sr
89
-
90
- for idx, chunk in enumerate(chunks, start=1):
91
- print(f"🔊 Sinh đoạn {idx}/{len(chunks)} trên {model.device}")
92
  wav = model.generate(
93
  chunk,
94
  audio_prompt_path=audio_prompt_path,
@@ -98,69 +84,49 @@ def generate_tts_audio(
98
  )
99
  waves.append(wav.squeeze(0).cpu().numpy())
100
 
101
- full_wave = np.concatenate(waves, axis=0)
102
- print("✅ Hoàn thành sinh toàn bộ audio.")
103
- return sr, full_wave
104
 
105
  # --- GIAO DIỆN GRADIO TIẾNG VIỆT ---
106
  with gr.Blocks(title="LTTEAM TTS") as demo:
107
- gr.Markdown(
108
- """
109
- # LTTEAM TTS
110
- **Phát triển bởi: Trần**
111
- Ứng dụng chuyển văn bản thành giọng nói chất lượng cao, hỗ trợ đầu vào không giới hạn.
112
- """
113
- )
114
  with gr.Row():
115
  with gr.Column():
116
  device_choice = gr.Radio(
117
- choices=["cpu", "gpu"],
118
  value="gpu" if torch.cuda.is_available() else "cpu",
119
  label="Chọn thiết bị"
120
  )
121
  text = gr.Textbox(
122
  label="Văn bản (không giới hạn độ dài)",
123
- lines=8,
124
- placeholder="Dán hoặc nhập văn bản vào đây..."
125
  )
126
  ref_wav = gr.Audio(
127
- sources=["upload", "microphone"],
128
  type="filepath",
129
  label="Âm thanh mẫu (tùy chọn)"
130
  )
131
- exaggeration = gr.Slider(
132
- minimum=0.25, maximum=2, step=0.05,
133
- value=0.5,
134
- label="Mức nhấn nhá (Exaggeration)"
135
- )
136
- cfg_weight = gr.Slider(
137
- minimum=0.2, maximum=1, step=0.05,
138
- value=0.5,
139
- label="Trọng số CFG / Tốc độ"
140
- )
141
  with gr.Accordion("Tùy chọn thêm", open=False):
142
- seed_num = gr.Number(0, label="Seed (0 = random)")
143
- temperature = gr.Slider(
144
- minimum=0.05, maximum=5, step=0.05,
145
- value=0.8,
146
- label="Nhiệt độ (Temperature)"
147
- )
148
  run = gr.Button("Chuyển giọng", variant="primary")
149
 
150
  with gr.Column():
151
- out_audio = gr.Audio(label="Kết quả âm thanh")
152
 
153
  run.click(
154
  fn=generate_tts_audio,
155
  inputs=[device_choice, text, ref_wav, exaggeration, temperature, seed_num, cfg_weight],
156
- outputs=[out_audio],
157
  )
158
 
159
  if __name__ == "__main__":
160
- # Phát hiện Colab qua biến môi trường
161
- is_colab = "COLAB_GPU" in os.environ
162
- if is_colab:
163
- demo.launch(share=True)
164
- else:
165
- # Dùng host/port để hỗ trợ HuggingFace Spaces
166
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
3
  import numpy as np
4
  import torch
5
  from pathlib import Path
 
 
 
 
 
 
 
 
 
6
  from huggingface_hub import snapshot_download
7
  from chatterbox.src.chatterbox.tts import ChatterboxTTS
8
  import gradio as gr
9
 
10
+ # --- CẤU HÌNH ---
11
+ MODEL_REPO = "LTTEAM/TTS_Pro"
12
+ LOCAL_MODEL_DIR = Path("models") / "tts_pro"
13
  MODELS = {}
14
 
15
+ # 1) Download model từ HF Hub (lần đầu) vào thư mục models/tts_pro
16
  if not LOCAL_MODEL_DIR.exists():
17
+ os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
18
  snapshot_download(
19
  repo_id=MODEL_REPO,
20
  repo_type="model",
21
  local_dir=str(LOCAL_MODEL_DIR),
22
  local_dir_use_symlinks=False
23
  )
 
24
 
25
  def get_or_load_model(device_str: str):
26
  """
27
+ Load ChatterboxTTS trên 'cpu' hoặc 'cuda'.
28
+ device_str: 'cpu' hoặc 'gpu'
29
  """
30
+ device = "cuda" if (device_str=="gpu" and torch.cuda.is_available()) else "cpu"
31
  if device not in MODELS:
32
  print(f"📂 Loading model lên {device} …")
33
  model = ChatterboxTTS.from_local(str(LOCAL_MODEL_DIR), device)
34
  MODELS[device] = model
35
+ print(f"✅ Model đã load trên {device}")
36
  return MODELS[device]
37
 
38
  def set_seed(seed: int):
 
44
  np.random.seed(seed)
45
 
46
  def chunk_text(text: str, chunk_size: int = 300):
47
+ """Chia text dài thành các đoạn chunk_size, giữ nguyên từ."""
48
  words = text.split()
49
+ chunks, cur = [], ""
50
  for w in words:
51
+ if len(cur)+len(w)+1 > chunk_size:
52
+ chunks.append(cur.strip())
53
+ cur = w
54
  else:
55
+ cur = f"{cur} {w}".strip()
56
+ if cur:
57
+ chunks.append(cur.strip())
58
  return chunks
59
 
60
  def generate_tts_audio(
 
66
  seed_num: int,
67
  cfg_weight: float
68
  ):
 
 
 
 
69
  model = get_or_load_model(device_choice)
70
+ if seed_num!=0:
71
  set_seed(int(seed_num))
72
 
73
+ # chia thành chunks sinh từng phần
74
+ chunks = chunk_text(text_input, 300)
75
  waves, sr = [], model.sr
76
+ for i, chunk in enumerate(chunks, 1):
77
+ print(f"🔊 Sinh đoạn {i}/{len(chunks)}...")
 
78
  wav = model.generate(
79
  chunk,
80
  audio_prompt_path=audio_prompt_path,
 
84
  )
85
  waves.append(wav.squeeze(0).cpu().numpy())
86
 
87
+ full = np.concatenate(waves, axis=0)
88
+ return sr, full
 
89
 
90
  # --- GIAO DIỆN GRADIO TIẾNG VIỆT ---
91
  with gr.Blocks(title="LTTEAM TTS") as demo:
92
+ gr.Markdown("""
93
+ # LTTEAM TTS
94
+ **Phát triển bởi: Lý Trần**
95
+ Chuyển văn bản thành giọng nói chất lượng cao, không giới hạn độ dài.
96
+ """)
 
 
97
  with gr.Row():
98
  with gr.Column():
99
  device_choice = gr.Radio(
100
+ ["cpu","gpu"],
101
  value="gpu" if torch.cuda.is_available() else "cpu",
102
  label="Chọn thiết bị"
103
  )
104
  text = gr.Textbox(
105
  label="Văn bản (không giới hạn độ dài)",
106
+ lines=6,
107
+ placeholder="Nhập hoặc dán văn bản..."
108
  )
109
  ref_wav = gr.Audio(
110
+ sources=["upload","microphone"],
111
  type="filepath",
112
  label="Âm thanh mẫu (tùy chọn)"
113
  )
114
+ exaggeration = gr.Slider(0.25,2,step=0.05,value=0.5,label="Mức nhấn nhá")
115
+ cfg_weight = gr.Slider(0.2,1,step=0.05,value=0.5,label="CFG/Pace weight")
 
 
 
 
 
 
 
 
116
  with gr.Accordion("Tùy chọn thêm", open=False):
117
+ seed_num = gr.Number(0, label="Seed (0=random)")
118
+ temperature = gr.Slider(0.05,5,step=0.05,value=0.8,label="Nhiệt độ")
 
 
 
 
119
  run = gr.Button("Chuyển giọng", variant="primary")
120
 
121
  with gr.Column():
122
+ out = gr.Audio(label="Kết quả")
123
 
124
  run.click(
125
  fn=generate_tts_audio,
126
  inputs=[device_choice, text, ref_wav, exaggeration, temperature, seed_num, cfg_weight],
127
+ outputs=[out],
128
  )
129
 
130
  if __name__ == "__main__":
131
+ # Trên Spaces, chỉ cần host 0.0.0.0, port mặc định
132
+ demo.launch(server_name="0.0.0.0", server_port=7860)