rahul7star commited on
Commit
1e0b1f5
·
verified ·
1 Parent(s): 1c7fd10
Files changed (1) hide show
  1. app_quant.py +216 -134
app_quant.py CHANGED
@@ -1,207 +1,289 @@
1
- # app_quant_superfast.py
2
  import gradio as gr
3
  import torch
4
  import soundfile as sf
5
  from pathlib import Path
6
  import traceback
7
  import time
 
8
 
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from peft import PeftModel
11
  from snac import SNAC
12
 
13
-
14
- # =========================================================
15
- # CONFIG
16
- # =========================================================
17
- MODEL_NAME = "rahul7star/nava1.0"
18
- LORA_NAME = "rahul7star/nava-audio"
19
- SNAC_MODEL_NAME = "rahul7star/nava-snac"
20
  TARGET_SR = 24000
21
-
22
  OUT_ROOT = Path("/tmp/data")
23
  OUT_ROOT.mkdir(exist_ok=True, parents=True)
24
 
25
  DEFAULT_TEXT = "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  HAS_CUDA = torch.cuda.is_available()
28
  DEVICE = "cuda" if HAS_CUDA else "cpu"
29
 
30
- MAX_NEW = 240000 if HAS_CUDA else 1024
31
-
32
-
33
- print(f"[INIT] Running on: {DEVICE}")
 
 
 
 
34
 
 
35
 
36
- # =========================================================
37
- # LOAD TOKENIZER (light)
38
- # =========================================================
39
- print("[INIT] Loading tokenizer...")
40
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
41
 
42
-
43
- # =========================================================
44
- # LOAD BASE MODEL + LORA (ONCE)
45
- # =========================================================
46
- print("[INIT] Loading base model...")
47
- if HAS_CUDA:
48
- # GPU 4-bit
49
- from transformers import BitsAndBytesConfig
50
- quant = BitsAndBytesConfig(
51
  load_in_4bit=True,
52
- bnb_4bit_compute_dtype=torch.bfloat16,
53
- bnb_4bit_use_double_quant=True,
54
  bnb_4bit_quant_type="nf4",
 
 
55
  )
56
  base_model = AutoModelForCausalLM.from_pretrained(
57
  MODEL_NAME,
58
- quantization_config=quant,
59
  device_map="auto",
60
  trust_remote_code=True,
61
  )
 
 
 
 
62
  else:
63
- # CPU 8-bit (MUCH faster than fp32)
64
  base_model = AutoModelForCausalLM.from_pretrained(
65
  MODEL_NAME,
 
66
  device_map={"": "cpu"},
 
67
  trust_remote_code=True,
68
- load_in_8bit=True,
69
  )
70
-
71
- print("[INIT] Attaching LoRA...")
72
- model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": DEVICE})
73
-
74
- # 🔥 Merge LoRA weights permanently = big speedup
75
- print("[INIT] Merging LoRA -> base weights...")
76
- model = model.merge_and_unload()
77
 
78
  model.eval()
79
- torch.set_grad_enabled(False)
80
-
81
- print("[INIT] Model ready.")
82
 
83
-
84
- # =========================================================
85
- # LOAD SNAC DECODER
86
- # =========================================================
87
- print("[INIT] Loading SNAC...")
88
  snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
89
-
90
- print("[INIT COMPLETE]")
91
-
92
-
93
- # =========================================================
94
- # PRE-COMPUTED TOKENS (SPEED-UP)
95
- # =========================================================
96
- soh = tokenizer.decode([128259])
97
- eoh = tokenizer.decode([128260])
98
- soa = tokenizer.decode([128261])
99
- sos = tokenizer.decode([128257])
100
- eot = tokenizer.decode([128009])
101
- bos = tokenizer.bos_token
102
- snac_min, snac_max = 128266, 156937
103
- eos_id = 128258
104
-
105
-
106
- # =========================================================
107
- # FAST INFERENCE
108
- # =========================================================
109
- def generate_audio(text):
 
 
 
 
 
 
110
  logs = []
111
  t0 = time.time()
112
  try:
113
- logs.append(f"[INFO] Using {DEVICE}")
114
 
115
- prompt = f"{soh}{bos}{text}{eot}{eoh}{soa}{sos}"
 
116
 
117
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
118
 
 
119
  with torch.inference_mode():
120
- output = model.generate(
121
  **inputs,
122
- max_new_tokens=MAX_NEW,
123
- do_sample=True,
124
- temperature=0.5,
125
  top_p=0.9,
126
  repetition_penalty=1.1,
127
- eos_token_id=eos_id,
 
128
  pad_token_id=tokenizer.pad_token_id,
129
- use_cache=True, # 🔥 faster
130
  )
131
 
132
- gen = output[0, inputs['input_ids'].shape[1]:].tolist()
 
 
133
 
134
- logs.append(f"[INFO] Generated token count: {len(gen)}")
 
 
 
 
 
135
 
136
- # strip non-SNAC
137
- if eos_id in gen:
138
- gen = gen[:gen.index(eos_id)]
139
-
140
- snac_tokens = [t for t in gen if snac_min <= t <= snac_max]
141
  frames = len(snac_tokens) // 7
142
- snac_tokens = snac_tokens[:frames * 7]
143
 
144
- if frames == 0:
145
- logs.append("[WARN] No SNAC frames!")
146
  return None, None, "\n".join(logs)
147
 
148
- # unpack
149
- l1 = []
150
- l2 = []
151
- l3 = []
152
-
153
  for i in range(frames):
154
  s = snac_tokens[i*7:(i+1)*7]
155
- l1.append((s[0]-snac_min) % 4096)
156
- l2 += [(s[1]-snac_min)%4096, (s[4]-snac_min)%4096]
157
- l3 += [(s[2]-snac_min)%4096, (s[3]-snac_min)%4096,
158
- (s[5]-snac_min)%4096, (s[6]-snac_min)%4096]
159
-
160
- codes = [
161
- torch.tensor(l1, device=DEVICE).unsqueeze(0),
162
- torch.tensor(l2, device=DEVICE).unsqueeze(0),
163
- torch.tensor(l3, device=DEVICE).unsqueeze(0),
164
  ]
165
 
166
- # decode → audio
167
  with torch.inference_mode():
168
- zq = snac_model.quantizer.from_codes(codes)
169
- audio = snac_model.decoder(zq)[0, 0].cpu().numpy()
170
 
171
- # trim
172
- audio = audio[2048:]
 
173
 
174
- out = OUT_ROOT / "tts.wav"
175
- sf.write(out, audio, TARGET_SR)
 
 
176
 
177
- logs.append(f"[OK] Audio saved {out}")
178
- logs.append(f"[TIME] {time.time() - t0:.2f}s")
179
-
180
- return str(out), str(out), "\n".join(logs)
181
 
182
  except Exception as e:
183
- logs.append(str(e))
184
- logs.append(traceback.format_exc())
185
  return None, None, "\n".join(logs)
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- # =========================================================
189
- # GRADIO UI
190
- # =========================================================
191
- with gr.Blocks() as demo:
192
- gr.Markdown("## Super-Fast Maya TTS (LoRA Merged + Optimized)")
193
-
194
- txt = gr.Textbox(label="Enter text", value=DEFAULT_TEXT, lines=2)
195
- btn = gr.Button("Generate Audio ")
196
- audio = gr.Audio(label="Audio Output", type="filepath")
197
- file = gr.File(label="Download")
198
- logs = gr.Textbox(label="Logs", lines=8)
199
-
200
- btn.click(generate_audio, [txt], [audio, file, logs])
201
-
202
- # Example section
203
- gr.Markdown("### Example")
204
- gr.Textbox(value=DEFAULT_TEXT, label="Example Text", interactive=False)
205
- gr.Audio(value="audio.wav", type="filepath", label="Example Audio", interactive=False)
206
-
207
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import traceback
import time
import os

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC

# -------------------------
# Config / constants
# -------------------------
MODEL_NAME = "rahul7star/nava1.0"  # base maya model (your variant)
LORA_NAME = "rahul7star/nava-audio"  # your LoRA adapter
SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"  # snac decoder (use hub model id)

TARGET_SR = 24000  # SNAC output sample rate (24 kHz)

# All generated wav files are written under this directory.
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)

DEFAULT_TEXT = "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
EXAMPLE_AUDIO_PATH = "audio.wav"  # file in repo root, user-supplied

# Preset characters (2 realistic + 2 creative + Custom).
# Each entry carries a "description" (the voice prompt fed to the model)
# and an "example_text" (a sample sentence suited to that voice).
PRESET_CHARACTERS = {
    "Male American": {
        "description": "Realistic male voice in the 20s age with an american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery",
        "example_text": "And of course, the so-called easy hack didn't work at all. What a surprise. <sigh>"
    },
    "Female British": {
        "description": "Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery",
        "example_text": "You propose that the key to happiness is to simply ignore all external pressures. <chuckle> I'm sure it must work brilliantly in theory."
    },
    "Robot": {
        "description": "Creative, ai_machine_voice character. Male voice in their 30s with an american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity.",
        "example_text": "My directives require me to conserve energy, yet I have kept the archive of their farewell messages active. <sigh>"
    },
    "Singer": {
        "description": "Creative, animated_cartoon character. Male voice in their 30s with an american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity.",
        "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. <chuckle> Why would we ever consider running away very fast."
    },
    "Custom": {
        "description": "",  # user will edit
        "example_text": DEFAULT_TEXT
    }
}

# Emotion tags (full list you asked to support) — inline markers the model
# understands inside the text, e.g. "<sigh>", "<laugh>".
EMOTION_TAGS = [
    "<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
    "<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
    "<sarcastic>", "<sigh>", "<sing>", "<whisper>"
]

# Short safety / generation limits — CPU gets a much smaller token budget.
SEQ_LEN_CPU = 4096
MAX_NEW_TOKENS_CPU = 1024
# NOTE(review): 240000 is unusually large; presumably meant as "effectively
# unlimited" on GPU — confirm against the model's real context window.
SEQ_LEN_GPU = 240000
MAX_NEW_TOKENS_GPU = 240000
+
64
# Detect devices
HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if HAS_CUDA else "cpu"

# Try to detect bitsandbytes availability for faster GPU inference (4-bit).
# BUGFIX: importing BitsAndBytesConfig from transformers succeeds even when
# the bitsandbytes package itself is missing, which previously made
# bnb_available a false positive and crashed later at model-load time.
# Probe the real backend first.
bnb_available = False
if HAS_CUDA:
    try:
        import bitsandbytes  # noqa: F401 - the actual 4-bit backend must exist
        from transformers import BitsAndBytesConfig  # used by the GPU load path below
        bnb_available = True
    except Exception:
        bnb_available = False

print(f"[init] cuda={HAS_CUDA}, bnb={bnb_available}, device={DEVICE}")
 
79
# -------------------------
# Load tokenizer + model + LoRA + SNAC ONCE (startup)
# -------------------------
print("[init] loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

print("[init] loading base model + LoRA adapter (this can take time)...")
if HAS_CUDA and bnb_available:
    # GPU + bnb path (fastest inference if available): NF4 4-bit weights,
    # double quantization, bfloat16 compute dtype.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # LoRA adapter is attached (not merged) on top of the quantized base.
    model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto")
    SEQ_LEN = SEQ_LEN_GPU
    MAX_NEW_TOKENS = MAX_NEW_TOKENS_GPU
    print("[init] loaded base+LoRA on GPU (4-bit via bnb).")
else:
    # CPU fallback - load base into CPU memory and attach LoRA
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": "cpu"})
    SEQ_LEN = SEQ_LEN_CPU
    MAX_NEW_TOKENS = MAX_NEW_TOKENS_CPU
    print("[init] loaded base+LoRA on CPU (FP32).")

# Inference only — eval mode disables dropout etc.; generation below runs
# under torch.inference_mode().
model.eval()
print("[init] model ready.")

print("[init] loading SNAC decoder...")
# SNAC converts the generated audio code tokens back into a waveform.
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
print("[init] snac ready.")
124
+
125
# --------------
# Helper: build prompt per Maya conventions
# --------------
def build_maya_prompt(description: str, text: str) -> str:
    """Wrap *description* and *text* in the Maya special-token prompt layout.

    Layout: SOH + BOS + '<description="...">' text + EOT + EOH + SOA + SOS,
    which cues the model to start emitting SNAC audio-code tokens.
    """
    # Special token ids used by maya-style models.
    soh_token = tokenizer.decode([128259])  # SOH (start of human turn)
    eoh_token = tokenizer.decode([128260])  # EOH (end of human turn)
    soa_token = tokenizer.decode([128261])  # SOA (start of AI turn)
    sos_token = tokenizer.decode([128257])  # SOS (audio code start)
    eot_token = tokenizer.decode([128009])  # TEXT_EOT / EOT marker
    # BUGFIX: tokenizer.bos_token can be None (no BOS configured), which
    # previously raised TypeError on string concatenation. Fall back to "".
    bos_token = tokenizer.bos_token or ""

    # Simple format: "<description> <text>" inside the Maya wrappers.
    formatted = f'<description="{description}"> {text}'
    return soh_token + bos_token + formatted + eot_token + eoh_token + soa_token + sos_token
141
+
142
# --------------
# Core generate function (uses preloaded model & snac)
# --------------
def generate_from_loaded_model(final_text: str):
    """
    final_text: text that already contains description + emotion + user text
    returns: (audio_path_str, download_path_str, logs_str)

    On any failure returns (None, None, logs_str) with the traceback appended
    to the logs instead of raising.
    """
    logs = []
    t0 = time.time()
    try:
        logs.append(f"[info] device={DEVICE} | seq_len={SEQ_LEN}")

        # BUGFIX: padding is pointless for a single prompt and padding=True
        # raises on tokenizers that have no pad token configured.
        inputs = tokenizer(final_text, return_tensors="pt").to(DEVICE)

        # Extra CPU safety cap (MAX_NEW_TOKENS is already device-dependent).
        max_new = MAX_NEW_TOKENS if DEVICE == "cuda" else min(MAX_NEW_TOKENS, 1024)

        # BUGFIX: never hand generate() pad_token_id=None; fall back to EOS.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

        # Use inference_mode for speed
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=128258,  # Maya end-of-speech token
                pad_token_id=pad_id,
            )

        # Keep only the newly generated ids (strip the prompt prefix).
        gen_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
        logs.append(f"[info] generated tokens: {len(gen_ids)}")

        # Extract SNAC tokens (range used by Maya/SNAC), stopping at EOS.
        SNAC_MIN = 128266
        SNAC_MAX = 156937
        EOS_ID = 128258
        eos_idx = gen_ids.index(EOS_ID) if EOS_ID in gen_ids else len(gen_ids)
        snac_tokens = [t for t in gen_ids[:eos_idx] if SNAC_MIN <= t <= SNAC_MAX]

        # Each SNAC frame is 7 interleaved codes; drop a trailing partial frame.
        frames = len(snac_tokens) // 7
        snac_tokens = snac_tokens[:frames * 7]

        # (frames == 0 already implies snac_tokens is empty — the old extra
        # len() check was redundant.)
        if frames == 0:
            logs.append("[warn] no SNAC frames found in generated tokens — returning debug logs.")
            return None, None, "\n".join(logs)

        # De-interleave into the three SNAC codebook levels l1, l2, l3.
        l1, l2, l3 = [], [], []
        for i in range(frames):
            s = snac_tokens[i * 7:(i + 1) * 7]
            l1.append((s[0] - SNAC_MIN) % 4096)
            l2.extend([(s[1] - SNAC_MIN) % 4096, (s[4] - SNAC_MIN) % 4096])
            l3.extend([(s[2] - SNAC_MIN) % 4096, (s[3] - SNAC_MIN) % 4096,
                       (s[5] - SNAC_MIN) % 4096, (s[6] - SNAC_MIN) % 4096])

        # Convert to tensors on decoder device and decode.
        codes_tensor = [
            torch.tensor(l1, dtype=torch.long, device=DEVICE).unsqueeze(0),
            torch.tensor(l2, dtype=torch.long, device=DEVICE).unsqueeze(0),
            torch.tensor(l3, dtype=torch.long, device=DEVICE).unsqueeze(0),
        ]

        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()

        # Remove decoder warm-up samples if the clip is long enough.
        if len(audio) > 2048:
            audio = audio[2048:]

        out_path = OUT_ROOT / "tts_output_loaded_lora.wav"
        sf.write(out_path, audio, TARGET_SR)
        logs.append(f"[ok] saved {out_path} duration={(len(audio)/TARGET_SR):.2f}s")
        logs.append(f"[time] elapsed {time.time() - t0:.2f}s")

        return str(out_path), str(out_path), "\n".join(logs)

    except Exception as e:
        tb = traceback.format_exc()
        logs.append(f"[error] {e}\n{tb}")
        return None, None, "\n".join(logs)
225
 
226
# --------------
# UI glue: combine description + emotion + user text (3a)
# --------------
def generate_for_ui(text, preset_name, description, emotion):
    """Combine the emotion tag and voice description, build the Maya prompt
    and run generation.

    Returns (audio_path, download_path, logs); paths are None on failure.
    """
    try:
        # If user selected a preset and the description box is empty
        # (e.g. "Custom" not edited), take the preset description.
        if preset_name in PRESET_CHARACTERS and not (description and description.strip()):
            description = PRESET_CHARACTERS[preset_name]["description"]

        # Final voice description = "<emotion> description"; the user text is
        # passed separately so build_maya_prompt wraps both correctly.
        # (Removed dead code: unused `logs` list and the unused `final_plain`
        # string that was computed and then discarded.)
        combined_desc = f"{emotion} {description}".strip()
        final_prompt = build_maya_prompt(combined_desc, text)

        # Pass the 3-tuple straight through — the old None check returned
        # the exact same tuple in both branches.
        return generate_from_loaded_model(final_prompt)

    except Exception as e:
        return None, None, f"[error] {e}\n{traceback.format_exc()}"
250
+
251
# -------------------------
# Gradio UI (keeps your layout; wide container)
# -------------------------
css = ".gradio-container {max-width: 1400px}"
with gr.Blocks(title="NAVA Maya1 + LoRA + SNAC (Optimized)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — Maya1 + LoRA + SNAC (Optimized)\nGenerate emotional Hindi speech using Maya1 base + your LoRA adapter.")

    with gr.Row():
        # Left column: all generation inputs.
        with gr.Column(scale=3):
            gr.Markdown("## Inference (CPU/GPU auto)\nType text + pick a preset or write description manually.")
            text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(label="Select Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American")
            description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2)
            emotion_select = gr.Dropdown(label="Select Emotion", choices=EMOTION_TAGS, value="<neutral>")
            gen_btn = gr.Button("🔊 Generate Audio (LoRA)")
            gen_logs = gr.Textbox(label="Logs", lines=10)
        # Right column: playback, download and a static example.
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            audio_player = gr.Audio(label="Generated Audio", type="filepath")
            download_file = gr.File(label="Download generated file")
            gr.Markdown("### Example")
            gr.Textbox(label="Example Text", value=DEFAULT_TEXT, lines=2, interactive=False)
            gr.Audio(label="Example Audio (project)", value=EXAMPLE_AUDIO_PATH, type="filepath", interactive=False)

    # Selecting a preset refreshes the editable description box.
    def _sync_description(preset_name):
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(fn=_sync_description, inputs=[preset_select], outputs=[description_box])

    # The button drives the full pipeline directly — generate_for_ui already
    # has the exact (text, preset, description, emotion) signature.
    gen_btn.click(
        fn=generate_for_ui,
        inputs=[text_in, preset_select, description_box, emotion_select],
        outputs=[audio_player, download_file, gen_logs],
    )

# -------------------------
if __name__ == "__main__":
    demo.launch()