import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from google.cloud import speech, texttospeech
import os
import tempfile
import time
from pydub import AudioSegment
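# NOTE: pydub decodes audio via ffmpeg, which must be installed on the host.
# On a Hugging Face Space that is typically done by listing "ffmpeg" in
# packages.txt (apt packages) alongside the pip requirements.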
# ==============================================================================
# 1. HANDLE AUTHENTICATION FROM HUGGING FACE SECRETS
# ==============================================================================
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("WARNING: HF_TOKEN secret not set. Download may fail.")

gcp_key_json_string = os.environ.get("GCP_SERVICE_ACCOUNT_KEY")
if not gcp_key_json_string:
    print("CRITICAL: GCP_SERVICE_ACCOUNT_KEY secret not set. STT/TTS will fail.")
else:
    try:
        # We must write the secret string to a temporary file for Google's clients
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".json") as f:
            f.write(gcp_key_json_string)
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = f.name
        print(f"Google credentials written to temporary file: {f.name}")
    except Exception as e:
        print(f"CRITICAL: Failed to write GCP key to temp file. {e}")
# ==============================================================================
# 2. CONFIGURE AND LOAD N-ATLaS MODEL (FOR T4 GPU)
# ==============================================================================
MODEL_ID = "NCAIR1/N-ATLaS"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # T4 (Turing, sm_75) GPUs have no native bfloat16 support, so use float16
    # as the compute dtype on this hardware.
    bnb_4bit_compute_dtype=torch.float16
)
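# Rough memory arithmetic for the 4-bit setup above: NF4 stores weights at
# ~0.5 bytes per parameter, so an 8B-parameter model needs roughly 4 GB for
# weights (plus activations and KV cache), which fits comfortably in a T4's
# 16 GB of VRAM. Double quantization also quantizes the quantization
# constants themselves, saving a further ~0.4 bits per parameter.
# (The 8B figure is an assumption about N-ATLaS's size, not confirmed here.)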
| print(f"Loading model: {MODEL_ID} with 4-bit quantization...") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_ID, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| token=hf_token | |
| ) | |
| print("β N-ATLaS Model loaded.") | |
# ==============================================================================
# 3. INITIALIZE GOOGLE CLOUD CLIENTS
# ==============================================================================
try:
    speech_client = speech.SpeechClient()
    tts_client = texttospeech.TextToSpeechClient()
    print("Google Cloud STT/TTS clients initialized.")
except Exception as e:
    print(f"CRITICAL: Could not initialize Google Cloud clients. {e}")
# ==============================================================================
# 4. HELPER FUNCTIONS (STT AND TTS)
# ==============================================================================
def transcribe_audio(audio_filepath: str, language_code: str):
    if not audio_filepath:
        return ""
    print(f"Loading audio file: {audio_filepath}")
    try:
        audio = AudioSegment.from_file(audio_filepath)
        print(" -> AudioSegment loaded successfully.")
        # LINEAR16 expects 16 kHz mono, 16-bit PCM, so normalize the sample
        # width as well as the rate and channel count.
        target_sample_rate = 16000
        target_channels = 1
        audio = (
            audio.set_frame_rate(target_sample_rate)
            .set_channels(target_channels)
            .set_sample_width(2)  # 2 bytes = 16-bit samples
        )
        wav_data = audio.raw_data
        print(f"Transcribing {len(wav_data)} bytes with language: {language_code} at {target_sample_rate} Hz...")
        recognition_audio = speech.RecognitionAudio(content=wav_data)
        recognition_config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=target_sample_rate,
            language_code=language_code,
            audio_channel_count=target_channels
        )
        response = speech_client.recognize(config=recognition_config, audio=recognition_audio)
        if not response.results:
            return "[Could not understand audio]"
        transcribed_text = response.results[0].alternatives[0].transcript
        print(f" -> Transcribed: {transcribed_text}")
        return transcribed_text
    except Exception as e:
        print(f" -> ERROR during audio processing or transcription: {e}")
        return f"[Error processing audio: {e}]"
    finally:
        if audio_filepath and os.path.exists(audio_filepath):
            try:
                os.remove(audio_filepath)
            except OSError:
                pass
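# NOTE: speech_client.recognize() is the synchronous API, which Google caps at
# roughly one minute of audio per request. That is fine for microphone turns;
# longer clips would need long_running_recognize() instead.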
def synthesize_speech(text, voice_code):
    print(f"Synthesizing speech with requested code: {voice_code}...")
    synthesis_input = texttospeech.SynthesisInput(text=text)
    selected_voice_name = None
    selected_ssml_gender = None
    if voice_code.startswith("en"):
        selected_language_code = "en-US"
        selected_voice_name = "en-US-Wavenet-A"
        print(f" -> Using high-quality English voice: {selected_voice_name}")
    else:
        selected_language_code = voice_code.split('-')[0]  # Use 'ha', 'ig', 'yo'
        selected_ssml_gender = texttospeech.SsmlVoiceGender.FEMALE
        print(f" -> Requesting default FEMALE voice for language: {selected_language_code}")

    voice_params = {"language_code": selected_language_code}
    if selected_voice_name:
        voice_params["name"] = selected_voice_name
    elif selected_ssml_gender:
        voice_params["ssml_gender"] = selected_ssml_gender
    voice = texttospeech.VoiceSelectionParams(**voice_params)
    audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)

    # Debug aid: list which voices actually exist for non-English languages.
    if not voice_code.startswith("en"):
        try:
            print(f"--- Listing available voices for language code: {selected_language_code} ---")
            list_voices_response = tts_client.list_voices(language_code=selected_language_code)
            available_voices = [v.name for v in list_voices_response.voices]
            if available_voices:
                print(f"Available voices found: {available_voices}")
            else:
                print("No voices found for this language code.")
        except Exception as list_err:
            print(f" -> ERROR trying to list voices: {list_err}")

    try:
        response = tts_client.synthesize_speech(input=synthesis_input, voice=voice, audio_config=audio_config)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
            fp.write(response.audio_content)
            temp_audio_path = fp.name
        print(f" -> Audio saved to: {temp_audio_path}")
        return temp_audio_path
    except Exception as e:
        print(f" -> ERROR during speech synthesis: {e}")
        return None
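# NOTE: Google Cloud TTS voice coverage for Hausa, Igbo and Yoruba varies by
# region and over time; if no voice exists for the requested language, the
# synthesize call raises, this function returns None, and the UI simply shows
# no audio for that turn.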
# ==============================================================================
# 5. CORE CHAT FUNCTION (AS A GENERATOR) - *** UPDATED FOR GRADIO 4.x ***
# ==============================================================================
def speech_to_speech_chat(audio_input, history, input_lang, output_voice):
    """
    Main function for the Gradio app. Handles filepath audio input, uses 'yield',
    and generates BOTH a translation and a conversational reply.
    HISTORY is now a list of dictionaries: [{"role": "user", "content": ...}]
    """
    user_audio_path = audio_input
    if user_audio_path is None:
        yield history, None, None
        return
| print(f"Received audio filepath: {user_audio_path}") | |
| # ----- STAGE 1: Transcribe User ----- | |
| transcribed_text = transcribe_audio(user_audio_path, input_lang) | |
| if transcribed_text is None: | |
| transcribed_text = "[Error: Transcription failed internally]" | |
| # --- HISTORY FIX 1 --- | |
| # Append the user's transcribed text to the history in the new format | |
| history.append({"role": "user", "content": transcribed_text}) | |
| yield history, None, None # Update UI with transcribed text | |
| if transcribed_text.startswith("["): | |
| return | |
    # ----- STAGE 2: Get N-ATLaS Response (RUN 1: CONVERSATION) -----
    print("Generating N-ATLaS response (Run 1: Conversation)...")
    if output_voice.startswith("ha"):
        lang = "Hausa"
    elif output_voice.startswith("yo"):
        lang = "Yoruba"
    elif output_voice.startswith("ig"):
        lang = "Igbo"
    else:
        lang = "Nigerian English"
    system_prompt = f"You are a helpful, friendly assistant. Listen to what the user says and respond naturally. You must respond ONLY in {lang}."

    # --- HISTORY FIX 2 ---
    # The history is already in the correct format; just copy it and prepend
    # the system prompt (chat templates expect the system message first).
    conversation_messages = [{"role": "system", "content": system_prompt}] + list(history)
    conversation_prompt = tokenizer.apply_chat_template(conversation_messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(conversation_prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]
    outputs = model.generate(
        **inputs, max_new_tokens=256, eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.7, top_p=0.9
    )
    conversational_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
    print(f" -> Conversational Reply: {conversational_text}")
    # ----- STAGE 3: Get N-ATLaS Response (RUN 2: TRANSLATION) -----
    print("Generating N-ATLaS response (Run 2: Translation)...")
    translation_system_prompt = f"Translate the following text to {lang}:"
    translation_messages = [
        {"role": "system", "content": translation_system_prompt},
        {"role": "user", "content": transcribed_text}
    ]
    translation_prompt = tokenizer.apply_chat_template(translation_messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(translation_prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]
    # Greedy decoding for the translation pass: with do_sample=False,
    # temperature/top_p are ignored (passing them only triggers warnings),
    # so they are omitted here.
    outputs = model.generate(
        **inputs, max_new_tokens=256, eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id, do_sample=False
    )
    translation_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
    print(f" -> Direct Translation: {translation_text}")
    # ----- STAGE 4: Synthesize and Format Response -----
    bot_audio_path = synthesize_speech(conversational_text, output_voice)
    bot_response_string = f"""
**Conversational Reply:**
{conversational_text}

---

**Direct Translation:**
{translation_text}
"""

    # --- HISTORY FIX 3 ---
    # Append the bot's complete response to the history
    history.append({"role": "assistant", "content": bot_response_string})

    # Yield the final history, the bot's audio, and clear the mic input
    yield history, bot_audio_path, None
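# NOTE: history is mutated in place. Because gr.State passes the same list
# object back in on every click, the appends above persist across turns
# without the function ever returning the state explicitly.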
# ==============================================================================
# 6. GRADIO UI (using Blocks) - *** UPDATED FOR GRADIO 4.x ***
# ==============================================================================
with gr.Blocks(theme=gr.themes.Soft(), title="N-ATLaS Voice Test") as iface:
    gr.Markdown("# 🇳🇬 N-ATLaS Multilingual Voice Test")
    gr.Markdown(
        "**Instructions:** Select your spoken language and desired response voice. "
        "Speak into the microphone, then press 'Submit Audio'.\n"
        "**This app is running on a T4 GPU. Responses should be fast.**"
    )
    with gr.Row():
        input_lang = gr.Dropdown(
            label="1. Language I am Speaking",
            choices=[
                ("American English", "en-US"),
                ("Nigerian Pidgin / English", "en-NG"),
                ("Hausa", "ha-NG"),
                ("Igbo", "ig-NG"),
                ("Yoruba", "yo-NG")
            ],
            value="en-US"
        )
        output_voice = gr.Dropdown(
            label="2. Language for Bot to Speak",
            choices=[
                ("Nigerian English", "en-NG"),
                ("Hausa", "ha-NG"),
                ("Igbo", "ig-NG"),
                ("Yoruba", "yo-NG")
            ],
            value="en-NG"
        )

    # --- UI FIX 1 ---
    # Set type="messages" for the Chatbot component
    chatbot = gr.Chatbot(label="Conversation", height=400, type="messages")
    mic_input = gr.Audio(
        sources=["microphone"],  # Use 'sources' (plural) for Gradio 4.x
        type="filepath",
        label="3. Press record and speak"
    )
    bot_audio_output = gr.Audio(
        label="Bot's Spoken Response",
        autoplay=True
    )
    submit_btn = gr.Button("Submit Audio")

    # --- UI FIX 2 ---
    # Initialize history as an empty list (Gradio 4.x handles this)
    chat_history = gr.State([])

    submit_btn.click(
        fn=speech_to_speech_chat,
        inputs=[mic_input, chat_history, input_lang, output_voice],
        outputs=[chatbot, bot_audio_output, mic_input]
    )

print("Launching Gradio interface...")
# No share=True needed on Spaces, and queue is enabled by default
iface.launch()
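# Assumed pip requirements for this Space (a sketch, not taken from the repo):
#   gradio, torch, transformers, accelerate, bitsandbytes,
#   google-cloud-speech, google-cloud-texttospeech, pydub
# (ffmpeg must be provided at the system level for pydub; see the note near
# the imports at the top of the file.)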