Added and organized comments
Browse files
app.py
CHANGED
|
@@ -18,10 +18,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
| 18 |
|
| 19 |
from transformers import pipeline
|
| 20 |
|
|
|
|
|
|
|
| 21 |
# Set an environment variable
|
| 22 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 23 |
|
| 24 |
-
# Variables
|
| 25 |
SAMPLE_RATE = 16000 # Hz
|
| 26 |
MAX_AUDIO_SECONDS = 40 # won't try to transcribe if longer than this
|
| 27 |
DESCRIPTION = '''
|
|
@@ -29,8 +30,8 @@ DESCRIPTION = '''
|
|
| 29 |
<h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
|
| 30 |
<p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response. </p>
|
| 31 |
<p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech-to-text Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Instruct</b></a> for the large language model (LLM) and <a href="https://huggingface.co/kakao-enterprise/vits-ljs"><b>VITS-ljs by Kakao Enterprise</b></a> for text to speech (TTS).</p>
|
| 32 |
-
<p>This demo accepts audio inputs not more than 40 seconds long.</p>
|
| 33 |
-
<p>
|
| 34 |
</div>
|
| 35 |
'''
|
| 36 |
PLACEHOLDER = """
|
|
@@ -42,7 +43,7 @@ PLACEHOLDER = """
|
|
| 42 |
|
| 43 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
|
| 45 |
-
### ASR model
|
| 46 |
canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
|
| 47 |
canary_model.eval()
|
| 48 |
# make sure beam size always 1 for consistency
|
|
@@ -51,7 +52,7 @@ decoding_cfg = canary_model.cfg.decoding
|
|
| 51 |
decoding_cfg.beam.beam_size = 1
|
| 52 |
canary_model.change_decoding_strategy(decoding_cfg)
|
| 53 |
|
| 54 |
-
### LLM model
|
| 55 |
# Load the tokenizer and model
|
| 56 |
llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
| 57 |
llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto") # to("cuda:0")
|
|
@@ -64,11 +65,11 @@ terminators = [
|
|
| 64 |
llm_tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
| 65 |
]
|
| 66 |
|
| 67 |
-
### TTS model
|
| 68 |
pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)
|
| 69 |
|
| 70 |
|
| 71 |
-
|
| 72 |
def convert_audio(audio_filepath, tmpdir, utt_id):
|
| 73 |
"""
|
| 74 |
Convert all files to monochannel 16 kHz wav files.
|
|
@@ -99,8 +100,8 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
|
|
| 99 |
|
| 100 |
def transcribe(audio_filepath):
|
| 101 |
"""
|
| 102 |
-
Transcribes a converted audio file.
|
| 103 |
-
Set to english language with punctuations.
|
| 104 |
Returns the transcribed text as a string.
|
| 105 |
"""
|
| 106 |
|
|
@@ -136,15 +137,15 @@ def transcribe(audio_filepath):
|
|
| 136 |
def add_message(history, message):
|
| 137 |
"""
|
| 138 |
Adds the input message in the chatbot.
|
| 139 |
-
Returns the updated chatbot
|
| 140 |
"""
|
| 141 |
history.append((message, None))
|
| 142 |
return history
|
| 143 |
|
| 144 |
def bot(history, message):
|
| 145 |
"""
|
| 146 |
-
Gets the bot's response and
|
| 147 |
-
Returns the appended chatbot
|
| 148 |
"""
|
| 149 |
response = bot_response(message, history)
|
| 150 |
lines = response.split("\n")
|
|
@@ -162,8 +163,8 @@ def bot(history, message):
|
|
| 162 |
@spaces.GPU()
|
| 163 |
def bot_response(message, history):
|
| 164 |
"""
|
| 165 |
-
Generates a streaming response using the
|
| 166 |
-
Set max_new_tokens =
|
| 167 |
Returns the generated response in string format.
|
| 168 |
"""
|
| 169 |
conversation = []
|
|
@@ -175,7 +176,7 @@ def bot_response(message, history):
|
|
| 175 |
|
| 176 |
outputs = llama3_model.generate(
|
| 177 |
input_ids,
|
| 178 |
-
max_new_tokens =
|
| 179 |
eos_token_id = terminators,
|
| 180 |
do_sample=True,
|
| 181 |
temperature=0.6,
|
|
@@ -190,7 +191,7 @@ def bot_response(message, history):
|
|
| 190 |
@spaces.GPU()
|
| 191 |
def voice_player(history):
|
| 192 |
"""
|
| 193 |
-
Plays the generated response using the
|
| 194 |
Returns the audio player with the generated response.
|
| 195 |
"""
|
| 196 |
_, text = history[-1]
|
|
@@ -205,7 +206,9 @@ def voice_player(history):
|
|
| 205 |
visible=True)
|
| 206 |
return voice
|
| 207 |
|
|
|
|
| 208 |
|
|
|
|
| 209 |
with gr.Blocks(
|
| 210 |
title="MyAlexa",
|
| 211 |
css="""
|
|
@@ -251,13 +254,13 @@ with gr.Blocks(
|
|
| 251 |
visible=False # set to True to see processing time of asr transcription
|
| 252 |
)
|
| 253 |
|
| 254 |
-
gr.HTML("<p><b>
|
| 255 |
|
| 256 |
out_audio = gr.Audio( # Shows an audio player for the generated response
|
| 257 |
value = None,
|
| 258 |
-
label="Response
|
| 259 |
show_label=True,
|
| 260 |
-
visible=False # set to True to see processing time of
|
| 261 |
)
|
| 262 |
|
| 263 |
chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot")
|
|
@@ -270,6 +273,7 @@ with gr.Blocks(
|
|
| 270 |
outputs = [chat_input]
|
| 271 |
)
|
| 272 |
|
|
|
|
| 273 |
demo.queue()
|
| 274 |
if __name__ == "__main__":
|
| 275 |
demo.launch()
|
|
|
|
| 18 |
|
| 19 |
from transformers import pipeline
|
| 20 |
|
| 21 |
+
#### Variables ###
|
| 22 |
+
|
| 23 |
# Set an environment variable
|
| 24 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 25 |
|
|
|
|
| 26 |
SAMPLE_RATE = 16000 # Hz
|
| 27 |
MAX_AUDIO_SECONDS = 40 # won't try to transcribe if longer than this
|
| 28 |
DESCRIPTION = '''
|
|
|
|
| 30 |
<h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
|
| 31 |
<p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response. </p>
|
| 32 |
<p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech-to-text Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Instruct</b></a> for the large language model (LLM) and <a href="https://huggingface.co/kakao-enterprise/vits-ljs"><b>VITS-ljs by Kakao Enterprise</b></a> for text to speech (TTS).</p>
|
| 33 |
+
<p>This demo accepts audio inputs not more than 40 seconds long. Transcription and responses are limited to the English language.</p>
|
| 34 |
+
<p>The LLM max_new_tokens, temperature and top_p are set to 512, 0.6 and 0.9 respectively</p>
|
| 35 |
</div>
|
| 36 |
'''
|
| 37 |
PLACEHOLDER = """
|
|
|
|
| 43 |
|
| 44 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
|
| 46 |
+
### ASR model ###
|
| 47 |
canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
|
| 48 |
canary_model.eval()
|
| 49 |
# make sure beam size always 1 for consistency
|
|
|
|
| 52 |
decoding_cfg.beam.beam_size = 1
|
| 53 |
canary_model.change_decoding_strategy(decoding_cfg)
|
| 54 |
|
| 55 |
+
### LLM model ###
|
| 56 |
# Load the tokenizer and model
|
| 57 |
llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
| 58 |
llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto") # to("cuda:0")
|
|
|
|
| 65 |
llm_tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
| 66 |
]
|
| 67 |
|
| 68 |
+
### TTS model ###
|
| 69 |
pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)
|
| 70 |
|
| 71 |
|
| 72 |
+
### Start of functions ###
|
| 73 |
def convert_audio(audio_filepath, tmpdir, utt_id):
|
| 74 |
"""
|
| 75 |
Convert all files to monochannel 16 kHz wav files.
|
|
|
|
| 100 |
|
| 101 |
def transcribe(audio_filepath):
|
| 102 |
"""
|
| 103 |
+
Transcribes a converted audio file using the asr model.
|
| 104 |
+
Set to the English language with punctuation.
|
| 105 |
Returns the transcribed text as a string.
|
| 106 |
"""
|
| 107 |
|
|
|
|
| 137 |
def add_message(history, message):
|
| 138 |
"""
|
| 139 |
Adds the input message in the chatbot.
|
| 140 |
+
Returns the updated chatbot.
|
| 141 |
"""
|
| 142 |
history.append((message, None))
|
| 143 |
return history
|
| 144 |
|
| 145 |
def bot(history, message):
|
| 146 |
"""
|
| 147 |
+
Gets the bot's response and adds it in the chatbot.
|
| 148 |
+
Returns the appended chatbot.
|
| 149 |
"""
|
| 150 |
response = bot_response(message, history)
|
| 151 |
lines = response.split("\n")
|
|
|
|
| 163 |
@spaces.GPU()
|
| 164 |
def bot_response(message, history):
|
| 165 |
"""
|
| 166 |
+
Generates a streaming response using the llm model.
|
| 167 |
+
Set max_new_tokens = 512, temperature=0.6, and top_p=0.9
|
| 168 |
Returns the generated response in string format.
|
| 169 |
"""
|
| 170 |
conversation = []
|
|
|
|
| 176 |
|
| 177 |
outputs = llama3_model.generate(
|
| 178 |
input_ids,
|
| 179 |
+
max_new_tokens = 512,
|
| 180 |
eos_token_id = terminators,
|
| 181 |
do_sample=True,
|
| 182 |
temperature=0.6,
|
|
|
|
| 191 |
@spaces.GPU()
|
| 192 |
def voice_player(history):
|
| 193 |
"""
|
| 194 |
+
Plays the generated response using the tts model.
|
| 195 |
Returns the audio player with the generated response.
|
| 196 |
"""
|
| 197 |
_, text = history[-1]
|
|
|
|
| 206 |
visible=True)
|
| 207 |
return voice
|
| 208 |
|
| 209 |
+
### End of functions ###
|
| 210 |
|
| 211 |
+
### Interface using Blocks ###
|
| 212 |
with gr.Blocks(
|
| 213 |
title="MyAlexa",
|
| 214 |
css="""
|
|
|
|
| 254 |
visible=False # set to True to see processing time of asr transcription
|
| 255 |
)
|
| 256 |
|
| 257 |
+
gr.HTML("<p><b>[Optional]:</b> Replay MyAlexa's voice response.</p>")
|
| 258 |
|
| 259 |
out_audio = gr.Audio( # Shows an audio player for the generated response
|
| 260 |
value = None,
|
| 261 |
+
label="Response Audio Player",
|
| 262 |
show_label=True,
|
| 263 |
+
visible=False # set to True to see processing time of the first tts audio generation
|
| 264 |
)
|
| 265 |
|
| 266 |
chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot")
|
|
|
|
| 273 |
outputs = [chat_input]
|
| 274 |
)
|
| 275 |
|
| 276 |
+
### Queue and launch the demo ###
|
| 277 |
demo.queue()
|
| 278 |
if __name__ == "__main__":
|
| 279 |
demo.launch()
|