crevans committed on
Commit
eb1ddde
·
verified ·
1 Parent(s): 4b70236

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +320 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from google.cloud import speech, texttospeech
import os
import tempfile
import time
from pydub import AudioSegment  # For audio conversion

# ==============================================================================
# 1. CONFIGURE AND LOAD N-ATLaS MODEL
# ==============================================================================

MODEL_ID = "NCAIR1/N-ATLaS"

print(f"Loading model: {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Load model for local Mac testing (half precision, no quantization).
# NOTE(review): `dtype=` is the current transformers keyword; older releases
# spell it `torch_dtype=` — confirm against the installed transformers version.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float16,
    device_map="auto",
)
print("✅ N-ATLaS Model loaded.")

# ==============================================================================
# 2. INITIALIZE GOOGLE CLOUD CLIENTS
# Assumes GOOGLE_APPLICATION_CREDENTIALS is set in your environment
# ==============================================================================
try:
    speech_client = speech.SpeechClient()
    tts_client = texttospeech.TextToSpeechClient()
    print("✅ Google Cloud STT/TTS clients initialized.")
except Exception as e:
    print(f"🛑 CRITICAL: Could not initialize Google Cloud clients. {e}")
    print("   Make sure you have set the GOOGLE_APPLICATION_CREDENTIALS environment variable.")
    # Raise SystemExit rather than the interactive-only exit() builtin, which
    # is injected by the `site` module and may be absent (e.g. `python -S`,
    # frozen apps).
    raise SystemExit(1)
38
+
39
+ # ==============================================================================
40
+ # 3. HELPER FUNCTIONS (STT AND TTS)
41
+ # ==============================================================================
42
+
def transcribe_audio(audio_filepath: str, language_code: str):
    """
    Convert the recorded audio to 16 kHz mono LINEAR16 PCM and transcribe it
    with Google Cloud Speech-to-Text.

    Args:
        audio_filepath: Path to the file Gradio recorded (any container/codec
            pydub + ffmpeg can decode). A falsy value short-circuits to "".
        language_code: BCP-47 code passed straight to the STT API
            (e.g. "en-US", "ha-NG").

    Returns:
        The top transcript alternative of the first result, "" for no input,
        or a "[...]"-bracketed sentinel string on failure — callers detect
        errors with str.startswith("[").

    Side effects:
        Deletes `audio_filepath` in the `finally` block, so the input file is
        consumed even when transcription fails.
    """
    if not audio_filepath:
        return ""
    print(f"Loading audio file: {audio_filepath}")
    try:
        # Load audio using pydub (handles various input formats)
        audio = AudioSegment.from_file(audio_filepath)
        print(" -> AudioSegment loaded successfully.")

        # Rate/channel count must match the RecognitionConfig built below.
        target_sample_rate = 16000
        target_channels = 1  # Mono

        # Resample and convert to mono
        audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels)

        # Get raw PCM data (LINEAR16)
        wav_data = audio.raw_data

        print(f"Transcribing {len(wav_data)} bytes with language: {language_code} at {target_sample_rate} Hz...")

        # Configure Google STT for LINEAR16 (Default model)
        recognition_audio = speech.RecognitionAudio(content=wav_data)
        recognition_config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=target_sample_rate,
            language_code=language_code,
            audio_channel_count=target_channels
        )

        # Synchronous recognition — fine for short microphone clips.
        response = speech_client.recognize(config=recognition_config, audio=recognition_audio)

        if not response.results:
            return "[Could not understand audio]"

        # Only the first result's best alternative is used.
        transcribed_text = response.results[0].alternatives[0].transcript
        print(f" -> Transcribed: {transcribed_text}")
        return transcribed_text

    except Exception as e:
        print(f" -> 🛑 ERROR during audio processing or transcription: {e}")
        return f"[Error processing audio: {e}]"
    finally:
        # Clean up the temporary file created by Gradio
        if audio_filepath and os.path.exists(audio_filepath):
            try:
                os.remove(audio_filepath)
                print(f" -> Cleaned up temp file: {audio_filepath}")
            except OSError as e:
                # Best-effort cleanup: log and continue rather than masking
                # the function's return value with a cleanup failure.
                print(f" -> Error deleting temp file {audio_filepath}: {e}")
95
+
def synthesize_speech(text, voice_code):
    """Synthesizes speech using Google Cloud TTS with robust voice selection.

    Args:
        text: Plain text to synthesize.
        voice_code: Requested voice/language code. Any "en*" code maps to a
            fixed US WaveNet voice; other codes fall back to the base
            language (e.g. "ha") with a requested FEMALE default voice.

    Returns:
        Path to a temporary MP3 file containing the audio, or None on
        failure. The file is created with delete=False and is never removed
        here — the caller (Gradio) owns its lifetime.
    """
    print(f"Synthesizing speech with requested code: {voice_code}...")
    synthesis_input = texttospeech.SynthesisInput(text=text)

    # --- Robust Voice Selection Logic ---
    selected_voice_name = None
    selected_ssml_gender = None

    # Use high-quality US WaveNet for any English request
    if voice_code.startswith("en"):
        selected_language_code = "en-US"
        selected_voice_name = "en-US-Wavenet-A"
        print(f" -> Using high-quality English voice: {selected_voice_name}")
    else:
        # For non-English (ha, ig, yo), provide the BASE language code
        # and request a specific gender. Google should pick a default.
        selected_language_code = voice_code.split('-')[0]  # Use 'ha', 'ig', 'yo'
        selected_ssml_gender = texttospeech.SsmlVoiceGender.FEMALE  # Ask for a female voice
        print(f" -> Requesting default FEMALE voice for language: {selected_language_code}")

    # Set parameters, omitting 'name' if None. A named voice takes priority;
    # gender is only sent when no explicit voice name was chosen.
    voice_params = {"language_code": selected_language_code}
    if selected_voice_name:
        voice_params["name"] = selected_voice_name
    elif selected_ssml_gender:
        voice_params["ssml_gender"] = selected_ssml_gender

    voice = texttospeech.VoiceSelectionParams(**voice_params)
    # --- End Voice Selection Logic ---

    audio_config = texttospeech.AudioConfig(
        audio_encoding=texttospeech.AudioEncoding.MP3
    )

    # Diagnostic check for non-English voices: logs which voices the API
    # actually offers for this base language. Purely informational — the
    # result does not change the request built above.
    if not voice_code.startswith("en"):
        try:
            print(f"--- Listing available voices for language code: {selected_language_code} ---")
            list_voices_response = tts_client.list_voices(language_code=selected_language_code)
            available_voices = [v.name for v in list_voices_response.voices]
            if available_voices:
                print(f"Available voices found: {available_voices}")
            else:
                print("No voices found for this language code.")
        except Exception as list_err:
            print(f" -> ERROR trying to list voices: {list_err}")

    try:
        response = tts_client.synthesize_speech(
            input=synthesis_input, voice=voice, audio_config=audio_config
        )
        # delete=False so the file survives past this function for Gradio
        # to serve; no cleanup is performed here.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
            fp.write(response.audio_content)
            temp_audio_path = fp.name
        print(f" -> Audio saved to: {temp_audio_path}")
        return temp_audio_path
    except Exception as e:
        print(f" -> 🛑 ERROR during speech synthesis: {e}")
        return None
156
+
157
+ # ==============================================================================
158
+ # 4. CORE CHAT FUNCTION (AS A GENERATOR) - DUAL RESPONSE
159
+ # ==============================================================================
def speech_to_speech_chat(audio_input, history, input_lang, output_voice):
    """
    Main Gradio generator callback: transcribe the user's recording, generate
    BOTH a conversational reply and a direct translation with N-ATLaS, then
    synthesize speech for the reply.

    Args:
        audio_input: Filepath of the recorded audio (gr.Audio type="filepath"),
            or None when the user submits without recording.
        history: Chatbot history as a list of (user, assistant) tuples
            (Gradio 3.x format); mutated in place.
        input_lang: STT language code for the user's speech (e.g. "en-US").
        output_voice: TTS voice/language code for the bot (e.g. "ha-NG").

    Yields:
        (history, bot_audio_path_or_None, None) — the trailing None clears
        the microphone component after each update.
    """

    # --- STAGE 0: Get Filepath ---
    user_audio_path = audio_input  # Gradio passes the filepath directly
    if user_audio_path is None:
        # Handle case where user clicks submit without recording
        yield history, None, None
        return  # Stop processing
    print(f"Received audio filepath: {user_audio_path}")

    # ----- STAGE 1: Transcribe User -----
    transcribed_text = transcribe_audio(user_audio_path, input_lang)  # Pass filepath

    if transcribed_text is None:
        print(" -> 🛑 Transcription function returned None unexpectedly.")
        transcribed_text = "[Error: Transcription failed internally]"

    history.append((transcribed_text, None))
    yield history, None, None  # Update UI with transcribed text

    # transcribe_audio signals failure with "[...]"-bracketed sentinels.
    if transcribed_text.startswith("["):
        return  # Stop processing if transcription failed

    # ----- STAGE 2: Get N-ATLaS Response (RUN 1: CONVERSATION) -----
    print("Generating N-ATLaS response (Run 1: Conversation)...")

    # Get target language name
    if output_voice.startswith("ha"):
        lang = "Hausa"
    elif output_voice.startswith("yo"):
        lang = "Yoruba"
    elif output_voice.startswith("ig"):
        lang = "Igbo"
    else:
        lang = "Nigerian English"

    # Create persona prompt for conversation
    system_prompt = f"You are a helpful, friendly assistant. Listen to what the user says and respond naturally. You must respond ONLY in {lang}."

    # Build conversation history
    messages = []
    for user_msg, assistant_msg in history:
        user_content = str(user_msg) if user_msg is not None else "[empty]"
        messages.append({"role": "user", "content": user_content})
        if assistant_msg:
            # Extract just the conversational part from previous turns, so the
            # model does not see its own "Direct Translation" scaffolding.
            if "**Conversational Reply:**" in str(assistant_msg):
                reply_text = str(assistant_msg).split("---")[0].replace("**Conversational Reply:**\n", "").strip()
                messages.append({"role": "assistant", "content": reply_text})
            else:
                messages.append({"role": "assistant", "content": str(assistant_msg)})

    # Add the final system prompt.
    # NOTE(review): the system message is appended AFTER the dialogue instead
    # of prepended — confirm this matches what the N-ATLaS chat template expects.
    conversation_messages = messages + [{"role": "system", "content": system_prompt}]
    conversation_prompt = tokenizer.apply_chat_template(conversation_messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(conversation_prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]

    outputs = model.generate(
        **inputs, max_new_tokens=256, eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.7, top_p=0.9
    )
    # Decode only the newly generated tokens (skip the echoed prompt).
    conversational_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
    print(f" -> Conversational Reply: {conversational_text}")

    # ----- STAGE 3: Get N-ATLaS Response (RUN 2: TRANSLATION) -----
    print("Generating N-ATLaS response (Run 2: Translation)...")

    translation_system_prompt = f"Translate the following text to {lang}:"

    translation_messages = [
        {"role": "system", "content": translation_system_prompt},
        {"role": "user", "content": transcribed_text}  # Translate only the last user input
    ]
    translation_prompt = tokenizer.apply_chat_template(translation_messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(translation_prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]

    # Greedy decoding for a deterministic translation. With do_sample=False,
    # temperature/top_p are ignored by transformers (and emit warnings), so
    # they are omitted here.
    outputs = model.generate(
        **inputs, max_new_tokens=256, eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id, do_sample=False
    )
    translation_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
    print(f" -> Direct Translation: {translation_text}")

    # ----- STAGE 4: Synthesize and Format Response -----

    # Synthesize speech for the CONVERSATIONAL reply only
    bot_audio_path = synthesize_speech(conversational_text, output_voice)

    # Format a single string for the chatbot UI
    bot_response_string = f"""
**Conversational Reply:**
{conversational_text}

---
**Direct Translation:**
{translation_text}
"""

    # Update the history with the user's text and the bot's combined text
    final_user_text = transcribed_text if transcribed_text is not None else "[Error]"
    history[-1] = (final_user_text, bot_response_string)

    # Yield the final history, the bot's audio, and clear the mic input
    yield history, bot_audio_path, None
268
+
269
+ # ==============================================================================
270
+ # 5. GRADIO UI (using Blocks) - Gradio 3.x compatible
271
+ # ==============================================================================
# Build the UI with Blocks; .queue() returns the Blocks instance, so the
# chained call still works as the `with` context manager (Gradio 3.x idiom).
with gr.Blocks(theme=gr.themes.Soft(), title="N-ATLaS Voice Test").queue() as iface:
    gr.Markdown("# 🇳🇬 N-ATLaS Multilingual Voice Test")
    gr.Markdown(
        "**Instructions:** Select your spoken language and desired response voice. "
        "Speak into the microphone, then press 'Submit'.\n"
        "**⚠️ IMPORTANT: Response from the N-ATLaS 8B model may take 30-90 seconds locally.**"
    )
    with gr.Row():
        # Dropdown choices are (label, value) tuples: the label is shown to
        # the user, the value string is what the callback receives.
        input_lang = gr.Dropdown(
            label="1. Language I am Speaking",
            choices=[
                ("American English", "en-US"),
                ("Nigerian Pidgin / English", "en-NG"),
                ("Hausa", "ha-NG"),
                ("Igbo", "ig-NG"),
                ("Yoruba", "yo-NG")
            ],
            value="en-US"  # Default to US English for local testing
        )
        output_voice = gr.Dropdown(
            label="2. Language for Bot to Speak",
            choices=[
                ("Nigerian English", "en-NG"),
                ("Hausa", "ha-NG"),
                ("Igbo", "ig-NG"),
                ("Yoruba", "yo-NG")
            ],
            value="en-NG"
        )
    chatbot = gr.Chatbot(label="Conversation", height=400)
    mic_input = gr.Audio(
        source="microphone",  # Use 'source' (singular) for Gradio 3.x
        type="filepath",      # Callback receives a temp-file path, not raw samples
        label="3. Press record and speak"
    )
    bot_audio_output = gr.Audio(
        label="Bot's Spoken Response",
        autoplay=True
    )
    submit_btn = gr.Button("Submit Audio")
    # Per-session conversation state: list of (user, assistant) tuples that
    # speech_to_speech_chat mutates and re-yields.
    chat_history = gr.State([])
    # speech_to_speech_chat is a generator, so the UI updates incrementally;
    # mic_input is also an output so each yield's trailing None clears it.
    submit_btn.click(
        fn=speech_to_speech_chat,
        inputs=[mic_input, chat_history, input_lang, output_voice],
        outputs=[chatbot, bot_audio_output, mic_input]
    )

print("Launching Gradio interface...")
iface.launch(share=True)  # share=True serves a public link in addition to localhost
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.50.2
2
+ transformers
3
+ torch
4
+ accelerate
5
+ bitsandbytes
6
+ sentencepiece
7
+ google-cloud-speech
8
+ google-cloud-texttospeech
9
+ ffmpeg-python
10
+ pydub
11
+ pydantic<2