rahul7star commited on
Commit
6befd58
·
verified ·
1 Parent(s): 6df9273

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +65 -67
main.py CHANGED
@@ -1,59 +1,38 @@
1
- import spaces
2
  import os
 
3
  import torch
4
  import soundfile as sf
5
  import logging
6
  import gradio as gr
7
  import librosa
8
  import numpy as np
 
 
9
  from datetime import datetime
10
- from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from ncodec.codec import TTSCodec
12
 
13
- # ----------------- Logging -----------------
14
  logging.basicConfig(
15
  level=logging.INFO,
16
  format="%(asctime)s - %(levelname)s - %(message)s"
17
  )
18
 
19
- # ----------------- Globals -----------------
20
- MODEL_PIPE = None
21
- TOKENIZER = None
22
  CODEC = None
 
23
 
24
- # ----------------- Model Initialization (CPU ONLY) -----------------
25
- def initialize_model():
26
- global MODEL_PIPE, TOKENIZER, CODEC
27
-
28
- if MODEL_PIPE is not None:
29
- return MODEL_PIPE
30
-
31
- logging.info("Loading MiraTTS model on CPU...")
32
-
33
- model_name = "rahul7star/mir-TTS"
34
-
35
- TOKENIZER = AutoTokenizer.from_pretrained(model_name)
36
-
37
- MODEL_PIPE = AutoModelForCausalLM.from_pretrained(
38
- model_name,
39
- torch_dtype=torch.float32, # CPU safe
40
- device_map=None
41
- )
42
-
43
- MODEL_PIPE.eval()
44
- MODEL_PIPE.to("cpu") # 🔒 CPU only
45
-
46
- CODEC = TTSCodec()
47
 
48
- logging.info("Model loaded successfully on CPU")
49
- return MODEL_PIPE
 
 
 
 
50
 
51
 
52
- # 🔹 Load model ONCE at startup (CPU safe)
53
- MODEL_PIPE = initialize_model()
54
-
55
-
56
- # ----------------- Audio Utilities -----------------
57
  def validate_audio_input(audio_path):
58
  if not audio_path or not os.path.exists(audio_path):
59
  raise ValueError("Audio file not found")
@@ -69,77 +48,95 @@ def validate_audio_input(audio_path):
69
 
70
  audio = audio / np.max(np.abs(audio))
71
 
72
- temp_path = f"/tmp/processed_{os.path.basename(audio_path)}"
73
- sf.write(temp_path, audio, sr)
74
 
75
- return temp_path
76
 
77
 
78
- # ----------------- TTS Generation (GPU ONLY) -----------------
79
  @spaces.GPU()
80
  def generate_speech(text, audio_path):
81
- global MODEL_PIPE
82
 
83
  if not text or not text.strip():
84
  raise ValueError("Text input is empty")
85
 
86
- # 🔥 Move model to GPU only here
87
- if not next(MODEL_PIPE.parameters()).is_cuda:
88
- logging.info("Moving model to GPU for generation")
89
- MODEL_PIPE.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  processed_audio = validate_audio_input(audio_path)
92
  context_tokens = CODEC.encode(processed_audio)
93
 
94
  prompt = CODEC.format_prompt(text, context_tokens, None)
95
 
96
- inputs = TOKENIZER(prompt, return_tensors="pt").to("cuda")
 
 
 
 
 
 
 
97
 
98
- with torch.no_grad():
99
- outputs = MODEL_PIPE.generate(
100
- **inputs,
101
- max_new_tokens=1024,
102
- top_p=0.95,
103
- top_k=50,
104
- temperature=0.8,
105
- repetition_penalty=1.2,
106
- )
107
 
108
- generated_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
109
- audio = CODEC.decode(generated_text, context_tokens)
110
 
111
  if torch.is_tensor(audio):
112
- audio = audio.cpu().numpy()
113
 
114
- # 🧹 Cleanup GPU memory
115
- del inputs, outputs
 
116
  torch.cuda.empty_cache()
117
 
118
  return audio, 48000
119
 
120
 
121
- # ----------------- Gradio Interface -----------------
122
  def voice_clone_interface(text, upload_audio, record_audio):
123
  try:
124
- audio_path = upload_audio if upload_audio else record_audio
125
  if not audio_path:
126
- return None, "Please upload or record reference audio."
127
 
128
  audio, sr = generate_speech(text, audio_path)
129
 
130
  os.makedirs("outputs", exist_ok=True)
131
- out_path = f"outputs/mira_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
132
  sf.write(out_path, audio, sr)
133
 
134
  return out_path, "✅ Generation successful"
135
 
136
  except Exception as e:
137
- return None, f"❌ Error: {str(e)}"
 
138
 
139
 
140
  def build_interface():
141
  with gr.Blocks(title="MiraTTS Voice Cloning") as demo:
142
- gr.Markdown("# 🎤 MiraTTS Voice Cloning")
143
 
144
  with gr.Row():
145
  with gr.Column():
@@ -162,7 +159,8 @@ def build_interface():
162
  return demo
163
 
164
 
165
- # ----------------- Main -----------------
166
  if __name__ == "__main__":
 
167
  demo = build_interface()
168
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import os
2
+ import gc
3
  import torch
4
  import soundfile as sf
5
  import logging
6
  import gradio as gr
7
  import librosa
8
  import numpy as np
9
+ import spaces
10
+
11
  from datetime import datetime
12
+ from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
13
  from ncodec.codec import TTSCodec
14
 
15
+ # ---------------- Logging ----------------
16
  logging.basicConfig(
17
  level=logging.INFO,
18
  format="%(asctime)s - %(levelname)s - %(message)s"
19
  )
20
 
21
+ # ---------------- Globals ----------------
22
+ GPU_PIPE = None
 
23
  CODEC = None
24
+ MODEL_ID = "rahul7star/mir-TTS"
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # ---------------- CPU Init (SAFE) ----------------
28
+ def initialize_cpu():
29
+ global CODEC
30
+ if CODEC is None:
31
+ logging.info("Initializing CPU components")
32
+ CODEC = TTSCodec()
33
 
34
 
35
+ # ---------------- Audio Utils ----------------
 
 
 
 
36
  def validate_audio_input(audio_path):
37
  if not audio_path or not os.path.exists(audio_path):
38
  raise ValueError("Audio file not found")
 
48
 
49
  audio = audio / np.max(np.abs(audio))
50
 
51
+ tmp_path = f"/tmp/processed_{os.path.basename(audio_path)}"
52
+ sf.write(tmp_path, audio, sr)
53
 
54
+ return tmp_path
55
 
56
 
57
+ # ---------------- GPU TTS ----------------
58
  @spaces.GPU()
59
  def generate_speech(text, audio_path):
60
+ global GPU_PIPE, CODEC
61
 
62
  if not text or not text.strip():
63
  raise ValueError("Text input is empty")
64
 
65
+ initialize_cpu()
66
+
67
+ # 🔥 Load GPU pipeline lazily (CORRECT)
68
+ if GPU_PIPE is None:
69
+ logging.info("Loading MiraTTS pipeline on GPU")
70
+
71
+ backend_config = TurbomindEngineConfig(
72
+ tp=1,
73
+ device="cuda",
74
+ dtype="bfloat16",
75
+ enable_prefix_caching=False,
76
+ cache_max_entry_count=0.1,
77
+ )
78
+
79
+ GPU_PIPE = pipeline(
80
+ MODEL_ID,
81
+ backend_config=backend_config
82
+ )
83
 
84
  processed_audio = validate_audio_input(audio_path)
85
  context_tokens = CODEC.encode(processed_audio)
86
 
87
  prompt = CODEC.format_prompt(text, context_tokens, None)
88
 
89
+ gen_cfg = GenerationConfig(
90
+ top_p=0.95,
91
+ top_k=50,
92
+ temperature=0.8,
93
+ max_new_tokens=1024,
94
+ repetition_penalty=1.2,
95
+ do_sample=True,
96
+ )
97
 
98
+ response = GPU_PIPE(
99
+ [prompt],
100
+ gen_config=gen_cfg,
101
+ do_preprocess=False
102
+ )
 
 
 
 
103
 
104
+ audio = CODEC.decode(response[0].text, context_tokens)
 
105
 
106
  if torch.is_tensor(audio):
107
+ audio = audio.float().cpu().numpy() # force float32
108
 
109
+ # 🧹 Cleanup
110
+ os.remove(processed_audio)
111
+ gc.collect()
112
  torch.cuda.empty_cache()
113
 
114
  return audio, 48000
115
 
116
 
117
+ # ---------------- Gradio ----------------
118
  def voice_clone_interface(text, upload_audio, record_audio):
119
  try:
120
+ audio_path = upload_audio or record_audio
121
  if not audio_path:
122
+ return None, "Upload or record reference audio"
123
 
124
  audio, sr = generate_speech(text, audio_path)
125
 
126
  os.makedirs("outputs", exist_ok=True)
127
+ out_path = f"outputs/mira_{datetime.now():%Y%m%d_%H%M%S}.wav"
128
  sf.write(out_path, audio, sr)
129
 
130
  return out_path, "✅ Generation successful"
131
 
132
  except Exception as e:
133
+ logging.error(e)
134
+ return None, f"❌ {str(e)}"
135
 
136
 
137
  def build_interface():
138
  with gr.Blocks(title="MiraTTS Voice Cloning") as demo:
139
+ gr.Markdown("# 🎤 MiraTTS Voice Cloning")
140
 
141
  with gr.Row():
142
  with gr.Column():
 
159
  return demo
160
 
161
 
162
+ # ---------------- Main ----------------
163
  if __name__ == "__main__":
164
+ initialize_cpu()
165
  demo = build_interface()
166
  demo.launch(server_name="0.0.0.0", server_port=7860)