Spaces:
Sleeping
Sleeping
# app.py
import logging
import os
import shutil
import tempfile
from datetime import datetime, timezone

import gradio as gr
import requests
import torch
from gtts import gTTS
from IndicTransToolkit.processor import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load models ---
# One IndicTrans2 1B checkpoint per translation direction.
_EN_INDIC_CKPT = "ai4bharat/indictrans2-en-indic-1B"
_INDIC_EN_CKPT = "ai4bharat/indictrans2-indic-en-1B"

# trust_remote_code=True lets transformers execute modelling code shipped with
# the checkpoint. NOTE(review): consider pinning `revision=` so unreviewed
# upstream changes are not executed automatically.
model_en_to_indic = AutoModelForSeq2SeqLM.from_pretrained(_EN_INDIC_CKPT, trust_remote_code=True).to(DEVICE)
tokenizer_en_to_indic = AutoTokenizer.from_pretrained(_EN_INDIC_CKPT, trust_remote_code=True)
model_indic_to_en = AutoModelForSeq2SeqLM.from_pretrained(_INDIC_EN_CKPT, trust_remote_code=True).to(DEVICE)
tokenizer_indic_to_en = AutoTokenizer.from_pretrained(_INDIC_EN_CKPT, trust_remote_code=True)

# Pre/post-processing for IndicTrans2 (language tags, normalization).
ip = IndicProcessor(inference=True)

# Pass DEVICE explicitly so Whisper runs on the same device as the translation
# models instead of whatever the installed transformers version defaults to.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=DEVICE)
# --- Supabase settings ---
# Both values may be overridden via environment variables (e.g. Space secrets)
# so credentials can be rotated without editing code; the literals are the
# original values kept as fallbacks. The key's JWT payload declares
# role "anon", so data protection depends on Supabase row-level security.
# NOTE(review): confirm RLS is enabled on the tables used below, and prefer
# removing the literal fallback once the secret is configured.
SUPABASE_URL = (
    os.environ.get("SUPABASE_URL") or "https://gptmdbhzblfybdnohqnh.supabase.co"
).rstrip("/")  # strip a trailing slash so f"{SUPABASE_URL}/rest/v1/..." stays well-formed
SUPABASE_API_KEY = (
    os.environ.get("SUPABASE_API_KEY")
    or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImdwdG1kYmh6YmxmeWJkbm9ocW5oIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDc0NjY1NDgsImV4cCI6MjA2MzA0MjU0OH0.CfWArts6Kd_x7Wj0a_nAyGJfrFt8F7Wdy_MdYDj9e7U"
)
# --- Supabase utilities ---
def save_to_supabase(input_text, output_text, direction):
    """Insert one source/translation pair into the direction's Supabase table.

    Rows go to ``translations`` when *direction* is "en_to_ks" and to
    ``ks_to_en_translations`` otherwise, via the PostgREST endpoint
    ``{SUPABASE_URL}/rest/v1/<table>``.

    Args:
        input_text: Text from the input box; None is treated as empty.
        output_text: Text from the output box; None is treated as empty.
        direction: Current direction state ("en_to_ks" or "ks_to_en").

    Returns:
        A short status message for the UI. Errors are logged, never raised.
    """
    input_text = input_text or ""
    output_text = output_text or ""
    if not input_text.strip() or not output_text.strip():
        return "Nothing to save."

    table = "translations" if direction == "en_to_ks" else "ks_to_en_translations"
    payload = {
        # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "input_text": input_text,
        "output_text": output_text,
    }
    headers = {
        "apikey": SUPABASE_API_KEY,
        "Authorization": f"Bearer {SUPABASE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        # Without a timeout a stalled connection would block this UI callback forever.
        response = requests.post(
            f"{SUPABASE_URL}/rest/v1/{table}", json=payload, headers=headers, timeout=10
        )
    except Exception as e:
        logging.error("Save error: %s", e)
        return "Save error."

    if response.status_code == 201:
        return "Saved successfully!"
    # Keep the server's reason in the logs; the UI only shows a short message.
    logging.error(
        "Save to %s failed: HTTP %s %s", table, response.status_code, response.text[:200]
    )
    return "❌ Failed to save."
# --- Save verified translation ---
def save_verified_translation(original_text, verified_text):
    """Store a machine translation together with its human-corrected version.

    Rows are inserted into the ``verified_translations`` table through the
    PostgREST endpoint ``{SUPABASE_URL}/rest/v1/verified_translations``.

    Args:
        original_text: The model's translation (output box); None is treated as empty.
        verified_text: The user-edited translation; None is treated as empty.

    Returns:
        A short status message for the UI. Errors are logged, never raised.
    """
    original_text = original_text or ""
    verified_text = verified_text or ""
    if not original_text.strip() or not verified_text.strip():
        return "Nothing to save."

    payload = {
        # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "original_translation": original_text,
        "verified_translation": verified_text,
    }
    headers = {
        "apikey": SUPABASE_API_KEY,
        "Authorization": f"Bearer {SUPABASE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        # Without a timeout a stalled connection would block this UI callback forever.
        response = requests.post(
            f"{SUPABASE_URL}/rest/v1/verified_translations",
            json=payload,
            headers=headers,
            timeout=10,
        )
    except Exception as e:
        logging.error("Verified Save error: %s", e)
        return "Verified save error."

    if response.status_code == 201:
        return "Verified translation saved!"
    # Keep the server's reason in the logs; the UI only shows a short message.
    logging.error(
        "Verified save failed: HTTP %s %s", response.status_code, response.text[:200]
    )
    return "❌ Failed to save verified translation."
def get_translation_history(direction):
    """Build a plain-text view of recent translations for the history box.

    Fetches the 20 newest rows (ordered by ``timestamp`` descending) from the
    direction's table (``translations`` for "en_to_ks",
    ``ks_to_en_translations`` otherwise) and the 20 newest rows from
    ``verified_translations``. The verified rows are not filtered by direction.

    Args:
        direction: Current direction state ("en_to_ks" or "ks_to_en").

    Returns:
        Two formatted text sections, or "Error loading history." on failure.
    """
    headers = {
        "apikey": SUPABASE_API_KEY,
        "Authorization": f"Bearer {SUPABASE_API_KEY}",
    }
    table = "translations" if direction == "en_to_ks" else "ks_to_en_translations"

    def _fetch_recent(table_name):
        # Newest 20 rows of *table_name*; a non-200 reply degrades to [].
        res = requests.get(
            f"{SUPABASE_URL}/rest/v1/{table_name}?order=timestamp.desc&limit=20",
            headers=headers,
            timeout=10,  # never block the UI callback indefinitely
        )
        return res.json() if res.status_code == 200 else []

    try:
        normal_data = _fetch_recent(table)
        verified_data = _fetch_recent("verified_translations")
        normal_history = "\n".join(
            f"Input: {r['input_text']} → Output: {r['output_text']}"
            for r in normal_data
        ) or "No regular translations yet."
        verified_history = "\n".join(
            f"Verified: {r['original_translation']} → {r['verified_translation']}"
            for r in verified_data
        ) or "No verified translations yet."
        return (
            f"--- Regular Translations ---\n{normal_history}\n\n"
            f"--- Verified Translations ---\n{verified_history}"
        )
    except Exception as e:
        logging.error("History error: %s", e)
        return "Error loading history."
# --- Translation with TTS integration ---
def translate(text, direction, generate_tts=False):
    """Translate *text* with IndicTrans2, optionally speaking the English result.

    Args:
        text: Source text from the input box.
        direction: "en_to_ks" selects English→Kashmiri; any other value uses
            the Kashmiri→English model.
        generate_tts: When truthy and *direction* is "ks_to_en", the
            translation is also rendered to speech via synthesize_tts.

    Returns:
        (translation, audio_path_or_None). On blank input or any failure a
        human-readable message replaces the translation and audio is None.
    """
    if not text.strip():
        return "Enter some text.", None

    to_kashmiri = direction == "en_to_ks"
    src_lang = "eng_Latn" if to_kashmiri else "kas_Arab"
    tgt_lang = "kas_Arab" if to_kashmiri else "eng_Latn"
    model = model_en_to_indic if to_kashmiri else model_indic_to_en
    tokenizer = tokenizer_en_to_indic if to_kashmiri else tokenizer_indic_to_en

    try:
        prepared = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
        encoded = tokenizer(prepared, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():  # inference only; no autograd bookkeeping
            generated = model.generate(**encoded, max_length=256, num_beams=5)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        translation = ip.postprocess_batch(decoded, lang=tgt_lang)[0]

        # Speech is only produced for the KS→EN direction (English output).
        wants_speech = generate_tts and direction == "ks_to_en"
        speech_path = synthesize_tts(translation) if wants_speech else None
        return translation, speech_path
    except Exception as err:
        logging.error("Translation error: %s", err)
        return "Translation failed.", None
# --- TTS for English output ---
def synthesize_tts(text, lang="en"):
    """Speak *text* with gTTS and return the path of the generated MP3.

    Args:
        text: Text to synthesize. Empty or whitespace-only text returns None
            without calling gTTS.
        lang: gTTS language code; defaults to "en", the only language this
            app requests.

    Returns:
        Path of a new temporary ``.mp3`` file, or None if synthesis failed.
        The file is handed to the Gradio audio component and is not deleted
        here. NOTE(review): nothing in this file removes it later.
    """
    if not text or not text.strip():
        return None

    path = None
    try:
        tts = gTTS(text=text, lang=lang)
        # Reserve a unique name, then close our handle before gTTS opens the
        # path itself: re-opening a still-open NamedTemporaryFile fails on
        # Windows.
        fd, path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        tts.save(path)
        return path
    except Exception as e:
        logging.error("TTS error: %s", e)
        if path is not None:
            # Don't leave an empty or partial MP3 behind after a failed save.
            try:
                os.remove(path)
            except OSError:
                pass
        return None
# --- STT for English audio ---
def transcribe_audio(audio_path):
    """Transcribe a recorded audio file with the Whisper ASR pipeline.

    Args:
        audio_path: Filesystem path of the recording (from the Gradio
            microphone component); falsy means nothing was recorded.

    Returns:
        ``(transcription, None)`` on success, ``(None, error_message)`` otherwise.
    """
    if not audio_path:
        return None, "No audio file provided"

    temp_path = None
    try:
        # Transcribe a private copy of the recording (kept from the original
        # design). NOTE(review): presumably this decouples ASR from Gradio's
        # own temp-file lifecycle — confirm it is still needed.
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # close our handle first so the copy also works on Windows
        shutil.copy(audio_path, temp_path)
        transcription = asr(temp_path)["text"]
        return transcription.strip(), None  # Whisper output often has a leading space
    except Exception as e:
        logging.error("STT error: %s", e)
        return None, f"Transcription failed: {str(e)}"
    finally:
        # Always remove the copy, including when the copy or ASR step raised.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
# --- Store audio file path ---
def store_audio(audio_path):
    """Return the recorded file path unchanged.

    Used as the ``audio_input.change`` handler so the latest recording path
    is kept in a ``gr.State`` for the transcribe button.
    """
    latest_recording = audio_path
    return latest_recording
# --- Handle audio translation ---
def handle_audio_translation(audio_path, direction):
    """Transcribe recorded English speech and translate it to Kashmiri.

    Args:
        audio_path: Path of the stored microphone recording.
        direction: Current direction state; only "en_to_ks" is supported.

    Returns:
        A 4-tuple for (status box, input box, output box, audio input):
        an error/status message (empty on success), the transcription, the
        translation, and *audio_path* so the recording stays visible.
    """
    if direction == "en_to_ks":
        transcription, error = transcribe_audio(audio_path)
        if not error:
            translated, _speech = translate(transcription, direction, generate_tts=False)
            return "", transcription, translated, audio_path
        return error, "", "", audio_path
    return "Audio input is only supported for English to Kashmiri.", "", "", audio_path
# --- Switch UI direction ---
def switch_direction(direction, input_text_val, output_text_val, audio_path):
    """Flip the translation direction and swap the contents of the text boxes.

    Args:
        direction: Current direction state ("en_to_ks" or "ks_to_en").
        input_text_val: Current input box text (moves to the output box).
        output_text_val: Current output box text (moves to the input box).
        audio_path: Accepted for the event wiring; not used.

    Returns:
        (new_direction, input box update, output box update, None). The
        trailing None clears the TTS audio component.
    """
    becomes_ks_to_en = direction == "en_to_ks"
    new_direction = "ks_to_en" if becomes_ks_to_en else "en_to_ks"
    if becomes_ks_to_en:
        input_label, output_label = "Kashmiri Text", "English Translation"
    else:
        input_label, output_label = "English Text", "Kashmiri Translation"
    input_update = gr.update(value=output_text_val, label=input_label)
    output_update = gr.update(value=input_text_val, label=output_label)
    return new_direction, input_update, output_update, None
# === Gradio Interface ===
def _mirror_text(text):
    """Return *text* unchanged; copies the output box into the edit box."""
    return text


with gr.Blocks() as interface:
    gr.HTML("""
    <div style="display: flex; justify-content: space-between; align-items: center; padding: 10px;">
        <img src="https://raw.githubusercontent.com/BurhaanRasheedZargar/Images/211321a234613a9c3dd944fe9367cf13d1386239/assets/left_logo.png" style="height:150px; width:auto;">
        <h2 style="margin: 0; text-align: center;">English ↔ Kashmiri Translator</h2>
        <img src="https://raw.githubusercontent.com/BurhaanRasheedZargar/Images/77797f7f7cbee328fa0f9d31cf3e290441e04cd3/assets/right_logo.png" style="height:150px; width:auto;">
    </div>
    """)
    # Current direction, "en_to_ks" or "ks_to_en"; flipped by switch_direction.
    translation_direction = gr.State(value="en_to_ks")
    # Path of the most recent microphone recording (filled by store_audio).
    stored_audio = gr.State()

    with gr.Row():
        input_text = gr.Textbox(label="English Text", placeholder="Enter text here...", lines=2)
        output_text = gr.Textbox(label="Kashmiri Translation", placeholder="Translated text...", lines=2)
    with gr.Row():
        verified_text = gr.Textbox(label="✏️ Edit Translation", placeholder="Edit translation here...", lines=2)
    with gr.Row():
        translate_button = gr.Button("Translate")
        save_button = gr.Button("Save Translation")
        switch_button = gr.Button("Switch Direction")
        verify_button = gr.Button("✅ Verify & Save")
    save_status = gr.Textbox(label="Save Status", interactive=False)
    history = gr.Textbox(label="Translation History", lines=8, interactive=False)
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="🎙️ Record English audio", sources=["microphone"])
        audio_output = gr.Audio(label="🔊 English TTS", interactive=False)
    with gr.Row():
        stt_button = gr.Button("🎤 Transcribe & Translate (EN → KS)")
        tts_button = gr.Button("🔊 Translate & Speak (KS → EN)")

    # Store audio when recorded
    audio_input.change(
        fn=store_audio,
        inputs=audio_input,
        outputs=stored_audio,
    )

    # Events
    translate_button.click(
        fn=translate,
        inputs=[input_text, translation_direction, gr.State(False)],
        outputs=[output_text, audio_output],
    ).then(
        fn=_mirror_text,
        inputs=output_text,
        outputs=verified_text,
    )
    tts_button.click(
        fn=translate,
        inputs=[input_text, translation_direction, gr.State(True)],
        outputs=[output_text, audio_output],
    ).then(
        fn=_mirror_text,
        inputs=output_text,
        outputs=verified_text,
    )
    save_button.click(
        fn=save_to_supabase,
        inputs=[input_text, output_text, translation_direction],
        outputs=save_status,
    ).then(
        fn=get_translation_history,
        inputs=translation_direction,
        outputs=history,
    )
    switch_button.click(
        fn=switch_direction,
        inputs=[translation_direction, input_text, output_text, stored_audio],
        outputs=[translation_direction, input_text, output_text, audio_output],
    ).then(
        # Keep the edit box in sync with the swapped output, so an edit made for
        # the previous direction cannot be saved against the new output.
        fn=_mirror_text,
        inputs=output_text,
        outputs=verified_text,
    ).then(
        # The regular-history table depends on the direction.
        fn=get_translation_history,
        inputs=translation_direction,
        outputs=history,
    )
    stt_button.click(
        fn=handle_audio_translation,
        inputs=[stored_audio, translation_direction],
        outputs=[save_status, input_text, output_text, audio_input],
    ).then(
        fn=_mirror_text,
        inputs=output_text,
        outputs=verified_text,
    )
    verify_button.click(
        fn=save_verified_translation,
        inputs=[output_text, verified_text],
        outputs=save_status,
    ).then(
        fn=get_translation_history,
        inputs=translation_direction,
        outputs=history,
    )

    # Populate the history panel when the page opens, not only after the first save.
    interface.load(
        fn=get_translation_history,
        inputs=translation_direction,
        outputs=history,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve; share=True asks Gradio for a
    # public share link in addition to the local server.
    queued_app = interface.queue()
    queued_app.launch(share=True)