Mr7Explorer commited on
Commit
62f2e1f
·
verified ·
1 Parent(s): 14e8d86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -27
app.py CHANGED
@@ -1,18 +1,18 @@
1
  import gradio as gr
2
- from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCTC
3
  import torch
4
- import torchaudio
5
  import librosa
6
  import soundfile as sf
7
  import io
8
  import os
9
 
10
- # For gated models, set token
11
- os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN" # From huggingface.co/settings/tokens
12
 
13
- # Load models
14
- asr_model_name = "ai4bharat/indic-conformer-600m-multilingual"
15
- asr_model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)
 
16
 
17
  llm_model_name = "ai4bharat/IndicBART"
18
  llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, do_lower_case=False, use_fast=False, keep_accents=True)
@@ -22,36 +22,32 @@ trans_model_name = "ai4bharat/IndicTrans3-beta"
22
  trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
23
  trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_name)
24
 
25
- tts_model_name = "ai4bharat/IndicF5"
26
- tts_model = AutoModel.from_pretrained(tts_model_name, trust_remote_code=True)
27
 
28
  def full_pipeline(audio, source_lang, target_lang):
29
  # ASR
30
- audio_array, sr = librosa.load(io.BytesIO(audio), sr=16000)
31
- wav = torch.tensor(audio_array).unsqueeze(0)
32
- text = asr_model(wav, source_lang, "ctc")
33
-
34
- # LLM: Simple generation
35
- bos_id = llm_tokenizer._convert_token_to_id_with_added_voc("<s>")
36
- eos_id = llm_tokenizer._convert_token_to_id_with_added_voc("</s>")
37
- pad_id = llm_tokenizer._convert_token_to_id_with_added_voc("<pad>")
38
- lang_code = f"<2{source_lang}>" # e.g. <2hi>
39
- inputs = llm_tokenizer(text + " </s> " + lang_code, add_special_tokens=False, return_tensors="pt")
40
- outputs = llm_model.generate(**inputs, max_length=50, decoder_start_token_id=llm_tokenizer._convert_token_to_id_with_added_voc(lang_code))
41
  response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
42
 
43
- # Translation if needed
44
  if source_lang != target_lang:
45
  inputs = trans_tokenizer(response, return_tensors="pt")
46
  outputs = trans_model.generate(**inputs)
47
  response = trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
48
 
49
- # TTS (needs ref audio; use example)
50
- ref_audio_path = "prompts/example.wav" # Upload example prompt to repo
51
- ref_text = "Example reference text in language"
52
- tts_output = tts_model(response, ref_audio_path=ref_audio_path, ref_text=ref_text)
53
  with io.BytesIO() as buffer:
54
- sf.write(buffer, tts_output, 24000, format="wav")
55
  audio_bytes = buffer.getvalue()
56
 
57
  return audio_bytes, text, response
@@ -60,7 +56,7 @@ iface = gr.Interface(
60
  fn=full_pipeline,
61
  inputs=[gr.Audio(type="file"), gr.Textbox(label="Source Lang e.g. hi"), gr.Textbox(label="Target Lang e.g. en")],
62
  outputs=[gr.Audio(label="Response Audio"), gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Response Text")],
63
- title="HanuVak Indic Conversation Backend"
64
  )
65
 
66
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCTC, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
3
  import torch
 
4
  import librosa
5
  import soundfile as sf
6
  import io
7
  import os
8
 
9
+ # Use HF_TOKEN from env
10
+ os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
11
 
12
+ # Models (use CPU if no GPU; for free tier, may be slow/large - upgrade for GPU)
13
+ asr_model_name = "ai4bharat/indicconformer-600m-multilingual"
14
+ asr_processor = AutoProcessor.from_pretrained(asr_model_name)
15
+ asr_model = AutoModelForCTC.from_pretrained(asr_model_name)
16
 
17
  llm_model_name = "ai4bharat/IndicBART"
18
  llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, do_lower_case=False, use_fast=False, keep_accents=True)
 
22
  trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
23
  trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_name)
24
 
25
+ tts_pipe = pipeline("text-to-speech", model="ai4bharat/indic-parler-tts-v2") # Switch to non-gated if issues
 
26
 
27
  def full_pipeline(audio, source_lang, target_lang):
28
  # ASR
29
+ audio_array, _ = librosa.load(io.BytesIO(audio), sr=16000)
30
+ inputs = asr_processor(audio_array, sampling_rate=16000, return_tensors="pt")
31
+ with torch.no_grad():
32
+ logits = asr_model(inputs.input_values).logits
33
+ pred_ids = torch.argmax(logits, dim=-1)
34
+ text = asr_processor.batch_decode(pred_ids)[0]
35
+
36
+ # LLM response (echo for test)
37
+ inputs = llm_tokenizer(text, return_tensors="pt")
38
+ outputs = llm_model.generate(**inputs)
 
39
  response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
40
 
41
+ # Translation
42
  if source_lang != target_lang:
43
  inputs = trans_tokenizer(response, return_tensors="pt")
44
  outputs = trans_model.generate(**inputs)
45
  response = trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
46
 
47
+ # TTS
48
+ tts_output = tts_pipe(response)
 
 
49
  with io.BytesIO() as buffer:
50
+ sf.write(buffer, tts_output["audio"][0], tts_output["sampling_rate"], format="wav")
51
  audio_bytes = buffer.getvalue()
52
 
53
  return audio_bytes, text, response
 
56
  fn=full_pipeline,
57
  inputs=[gr.Audio(type="file"), gr.Textbox(label="Source Lang e.g. hi"), gr.Textbox(label="Target Lang e.g. en")],
58
  outputs=[gr.Audio(label="Response Audio"), gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Response Text")],
59
+ title="HanuVak Backend"
60
  )
61
 
62
  if __name__ == "__main__":