Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,38 +1,43 @@
 import gradio as gr
-from transformers import
+from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCTC
 import torch
+import torchaudio
 import librosa
 import soundfile as sf
 import io
+import os
 
-#
-
-asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
-asr_model = AutoModelForCTC.from_pretrained(asr_model_name)
+# For gated models, set token
+os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN"  # From huggingface.co/settings/tokens
 
-
-
+# Load models
+asr_model_name = "ai4bharat/indic-conformer-600m-multilingual"
+asr_model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)
+
+llm_model_name = "ai4bharat/IndicBART"
+llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, do_lower_case=False, use_fast=False, keep_accents=True)
 llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
 
 trans_model_name = "ai4bharat/IndicTrans3-beta"
 trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
 trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_name)
 
-tts_model_name = "ai4bharat/
-
+tts_model_name = "ai4bharat/IndicF5"
+tts_model = AutoModel.from_pretrained(tts_model_name, trust_remote_code=True)
 
 def full_pipeline(audio, source_lang, target_lang):
-    # ASR
-    audio_array,
-
-
-
-
-
-
-
-
-
+    # ASR
+    audio_array, sr = librosa.load(io.BytesIO(audio), sr=16000)
+    wav = torch.tensor(audio_array).unsqueeze(0)
+    text = asr_model(wav, source_lang, "ctc")
+
+    # LLM: Simple generation
+    bos_id = llm_tokenizer._convert_token_to_id_with_added_voc("<s>")
+    eos_id = llm_tokenizer._convert_token_to_id_with_added_voc("</s>")
+    pad_id = llm_tokenizer._convert_token_to_id_with_added_voc("<pad>")
+    lang_code = f"<2{source_lang}>"  # e.g. <2hi>
+    inputs = llm_tokenizer(text + " </s> " + lang_code, add_special_tokens=False, return_tensors="pt")
+    outputs = llm_model.generate(**inputs, max_length=50, decoder_start_token_id=llm_tokenizer._convert_token_to_id_with_added_voc(lang_code))
     response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     # Translation if needed
@@ -41,17 +46,19 @@ def full_pipeline(audio, source_lang, target_lang):
     outputs = trans_model.generate(**inputs)
     response = trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # TTS
-
+    # TTS (needs ref audio; use example)
+    ref_audio_path = "prompts/example.wav"  # Upload example prompt to repo
+    ref_text = "Example reference text in language"
+    tts_output = tts_model(response, ref_audio_path=ref_audio_path, ref_text=ref_text)
     with io.BytesIO() as buffer:
-        sf.write(buffer, tts_output
+        sf.write(buffer, tts_output, 24000, format="wav")
         audio_bytes = buffer.getvalue()
 
     return audio_bytes, text, response
 
 iface = gr.Interface(
     fn=full_pipeline,
-    inputs=[gr.Audio(type="file"), gr.Textbox(label="Source Lang"), gr.Textbox(label="Target Lang")],
+    inputs=[gr.Audio(type="file"), gr.Textbox(label="Source Lang e.g. hi"), gr.Textbox(label="Target Lang e.g. en")],
     outputs=[gr.Audio(label="Response Audio"), gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Response Text")],
     title="HanuVak Indic Conversation Backend"
 )
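A note on the "Runtime error" status above: the committed Gradio wiring has several likely failure points. `gr.Audio(type="file")` is not a valid type in current Gradio releases (the documented values are "filepath" and "numpy"), so the app can crash before serving; the callback treats `audio` as raw bytes via `io.BytesIO(audio)` while Gradio supplies a file path or a `(sample_rate, array)` tuple; a `gr.Audio` output component likewise expects a path or `(sample_rate, array)` rather than hand-serialized WAV bytes; and the script never calls `iface.launch()`. The sketch below is a suggestion, not part of the commit: it keeps the model-loading code and the committed `asr_model(wav, source_lang, "ctc")` and `tts_model(response, ref_audio_path=..., ref_text=...)` call shapes, and only realigns the Gradio I/O. Treat the component behaviors named here as assumptions to verify against the Gradio docs for the pinned SDK version.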
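# Sketch of the Gradio-facing fixes only. Assumes the model-loading code from
# the commit above has already run (asr_model, tts_model, the tokenizers and
# seq2seq models are module-level globals). Hypothetical, not the commit.
import gradio as gr
import librosa
import numpy as np
import torch

def full_pipeline(audio_path, source_lang, target_lang):
    # ASR: with type="filepath", Gradio passes the recording as a file path,
    # so load it directly instead of wrapping bytes in BytesIO.
    audio_array, sr = librosa.load(audio_path, sr=16000, mono=True)
    wav = torch.tensor(audio_array).unsqueeze(0)
    text = asr_model(wav, source_lang, "ctc")  # call shape as in the commit

    # ... LLM generation and translation exactly as in the commit ...
    response = text  # placeholder so this sketch stays self-contained

    # TTS: hand Gradio a (sample_rate, waveform) tuple instead of WAV bytes;
    # 24000 matches the sample rate the commit writes with soundfile.
    tts_output = tts_model(response, ref_audio_path="prompts/example.wav",
                           ref_text="Example reference text in language")
    tts_audio = np.asarray(tts_output, dtype=np.float32)
    return (24000, tts_audio), text, response

iface = gr.Interface(
    fn=full_pipeline,
    inputs=[
        gr.Audio(type="filepath"),  # "file" is gone; "filepath"/"numpy" remain
        gr.Textbox(label="Source Lang e.g. hi"),
        gr.Textbox(label="Target Lang e.g. en"),
    ],
    outputs=[
        gr.Audio(label="Response Audio"),
        gr.Textbox(label="Transcribed Text"),
        gr.Textbox(label="Response Text"),
    ],
    title="HanuVak Indic Conversation Backend",
)

iface.launch()  # Spaces runs app.py as a script; without launch() nothing serves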
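Two smaller things stand out in the committed version. First, `bos_id`, `eos_id`, and `pad_id` are computed but never referenced in the lines shown, so they can likely be dropped. Second, assigning `os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN"` in `app.py` authenticates nothing as written, and pasting a real token there would publish it with the repo; the usual Spaces pattern is to store the token as a Space secret and read it with `os.environ.get("HF_TOKEN")`. It is also worth noting that IndicBART is a pretrained denoising model rather than an instruction-tuned chat model, so the "response" it generates from the transcript will not necessarily be conversational.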