Mr7Explorer committed on
Commit
14e8d86
·
verified ·
1 Parent(s): 1d14d32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -1,38 +1,43 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCTC, Wav2Vec2Processor
3
  import torch
 
4
  import librosa
5
  import soundfile as sf
6
  import io
 
7
 
8
- # Load models (use latest from AI4Bharat)
9
- asr_model_name = "ai4bharat/indicconformer-600m-multilingual"
10
- asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
11
- asr_model = AutoModelForCTC.from_pretrained(asr_model_name)
12
 
13
- llm_model_name = "ai4bharat/IndicBART" # Fine-tuned on IndicAlign
14
- llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
 
 
 
 
15
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
16
 
17
  trans_model_name = "ai4bharat/IndicTrans3-beta"
18
  trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
19
  trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_name)
20
 
21
- tts_model_name = "ai4bharat/indicf5" # Or indic-parler-tts-v2
22
- tts_pipe = pipeline("text-to-speech", model=tts_model_name)
23
 
24
  def full_pipeline(audio, source_lang, target_lang):
25
- # ASR: Audio to text
26
- audio_array, _ = librosa.load(io.BytesIO(audio), sr=16000)
27
- inputs = asr_processor(audio_array, sampling_rate=16000, return_tensors="pt")
28
- with torch.no_grad():
29
- logits = asr_model(inputs.input_values).logits
30
- pred_ids = torch.argmax(logits, dim=-1)
31
- text = asr_processor.batch_decode(pred_ids)[0]
32
-
33
- # LLM: Generate response (simple echo for now; enhance later)
34
- inputs = llm_tokenizer(text, return_tensors="pt")
35
- outputs = llm_model.generate(**inputs)
 
36
  response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
37
 
38
  # Translation if needed
@@ -41,17 +46,19 @@ def full_pipeline(audio, source_lang, target_lang):
41
  outputs = trans_model.generate(**inputs)
42
  response = trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
43
 
44
- # TTS: Text to audio
45
- tts_output = tts_pipe(response)
 
 
46
  with io.BytesIO() as buffer:
47
- sf.write(buffer, tts_output["audio"][0], tts_output["sampling_rate"], format="wav")
48
  audio_bytes = buffer.getvalue()
49
 
50
  return audio_bytes, text, response
51
 
52
  iface = gr.Interface(
53
  fn=full_pipeline,
54
- inputs=[gr.Audio(type="file"), gr.Textbox(label="Source Lang"), gr.Textbox(label="Target Lang")],
55
  outputs=[gr.Audio(label="Response Audio"), gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Response Text")],
56
  title="HanuVak Indic Conversation Backend"
57
  )
 
1
import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCTC
import torch
import torchaudio
import librosa
import soundfile as sf
import io
import os

# Authenticate for gated models via the environment; never assign a token in
# source. The previous `os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN"` both invited
# committing a real secret and clobbered any token already configured on the
# host. Configure HF_TOKEN in the Space/host settings instead
# (huggingface.co/settings/tokens); from_pretrained picks it up automatically.
hf_token = os.environ.get("HF_TOKEN")  # None if not configured; gated models will then fail with a clear auth error

# Load models
# ASR: IndicConformer multilingual — ships custom modeling code, hence trust_remote_code.
asr_model_name = "ai4bharat/indic-conformer-600m-multilingual"
asr_model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)

# LLM: IndicBART. keep_accents preserves Indic diacritics; the slow tokenizer
# is required for the model's Albert-style vocabulary.
llm_model_name = "ai4bharat/IndicBART"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, do_lower_case=False, use_fast=False, keep_accents=True)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)

# Translation: IndicTrans3.
trans_model_name = "ai4bharat/IndicTrans3-beta"
trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_name)

# TTS: IndicF5 — also custom code.
tts_model_name = "ai4bharat/IndicF5"
tts_model = AutoModel.from_pretrained(tts_model_name, trust_remote_code=True)
27
 
28
def full_pipeline(audio, source_lang, target_lang):
    """Run ASR -> LLM response -> optional translation -> TTS.

    Args:
        audio: recorded input; raw WAV bytes or a filesystem path
            (gr.Audio delivers one or the other depending on its ``type``).
        source_lang: language code of the input speech, e.g. "hi".
        target_lang: language code for the spoken reply, e.g. "en".

    Returns:
        Tuple of (wav_bytes, transcribed_text, response_text).
    """
    # ASR: accept both raw bytes and a file path — librosa handles both,
    # but bytes must first be wrapped in a file-like object.
    src = io.BytesIO(audio) if isinstance(audio, (bytes, bytearray)) else audio
    audio_array, sr = librosa.load(src, sr=16000)
    wav = torch.tensor(audio_array).unsqueeze(0)
    text = asr_model(wav, source_lang, "ctc")  # IndicConformer custom forward: (waveform, lang, decoding)

    # LLM: IndicBART expects "<text> </s> <2xx>" as input and the language
    # code token (e.g. <2hi>) as the decoder start token.
    lang_code = f"<2{source_lang}>"
    inputs = llm_tokenizer(text + " </s> " + lang_code, add_special_tokens=False, return_tensors="pt")
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = llm_model.generate(
            **inputs,
            max_length=50,
            decoder_start_token_id=llm_tokenizer._convert_token_to_id_with_added_voc(lang_code),
        )
    response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Translation only when the reply language differs from the source.
    # NOTE(review): the gating/encoding lines here were elided in the diff
    # this was recovered from; reconstructed — confirm against the full file.
    if target_lang and target_lang != source_lang:
        inputs = trans_tokenizer(response, return_tensors="pt")
        with torch.no_grad():
            outputs = trans_model.generate(**inputs)
        response = trans_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # TTS: IndicF5 needs a reference prompt (audio + its transcript).
    ref_audio_path = "prompts/example.wav"  # ship an example prompt with the repo
    ref_text = "Example reference text in language"
    tts_output = tts_model(response, ref_audio_path=ref_audio_path, ref_text=ref_text)
    with io.BytesIO() as buffer:
        # presumably IndicF5 emits 24 kHz samples — TODO confirm against the model card
        sf.write(buffer, tts_output, 24000, format="wav")
        audio_bytes = buffer.getvalue()  # must read before the buffer closes

    return audio_bytes, text, response
58
 
59
# Gradio UI. NOTE(review): gr.Audio(type="file") is not a valid `type` in
# Gradio 4.x — only "numpy" and "filepath" are accepted; "filepath" hands
# full_pipeline a path string. Confirm the pinned Gradio version.
iface = gr.Interface(
    fn=full_pipeline,
    inputs=[
        gr.Audio(type="filepath"),
        gr.Textbox(label="Source Lang e.g. hi"),
        gr.Textbox(label="Target Lang e.g. en"),
    ],
    outputs=[
        gr.Audio(label="Response Audio"),
        gr.Textbox(label="Transcribed Text"),
        gr.Textbox(label="Response Text"),
    ],
    title="HanuVak Indic Conversation Backend",
)