Dua Rajper committed on
Commit
83de1ee
·
verified ·
1 Parent(s): a238dc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -86
app.py CHANGED
@@ -1,13 +1,8 @@
1
  import os
2
  import streamlit as st
3
  from groq import Groq
4
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
5
- from espnet2.bin.tts_inference import Text2Speech
6
  import soundfile as sf
7
- from pydub import AudioSegment
8
- import io
9
- from streamlit_webrtc import webrtc_streamer, WebRtcMode, AudioProcessorBase
10
- import av
11
  import numpy as np
12
 
13
  # Load Groq API key from environment variables
@@ -19,34 +14,16 @@ if not GROQ_API_KEY:
19
  # Initialize Groq client
20
  groq_client = Groq(api_key=GROQ_API_KEY)
21
 
22
- # Load models
23
  @st.cache_resource
24
  def load_models():
25
- processor = AutoProcessor.from_pretrained("openai/whisper-small")
26
- stt_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
27
- stt_pipe = pipeline(
28
- "automatic-speech-recognition",
29
- model=stt_model,
30
- tokenizer=processor.tokenizer,
31
- feature_extractor=processor.feature_extractor,
32
- return_timestamps=True
33
- )
34
- tts_model = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")
35
- return stt_pipe, tts_model
36
 
37
- stt_pipe, tts_model = load_models()
38
-
39
- # Audio recorder
40
- class AudioRecorder(AudioProcessorBase):
41
- def __init__(self):
42
- self.audio_frames = []
43
-
44
- def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
45
- self.audio_frames.append(frame.to_ndarray())
46
- return frame
47
 
48
  # Streamlit app
49
- st.title("Voice-Enabled Chatbot")
50
 
51
  # Audio upload
52
  uploaded_file = st.file_uploader("Upload a WAV file", type=["wav"])
@@ -55,78 +32,26 @@ if uploaded_file is not None:
55
  # Save uploaded file
56
  with open("uploaded_audio.wav", "wb") as f:
57
  f.write(uploaded_file.getbuffer())
58
-
59
  st.success("File uploaded successfully!")
60
-
61
  # Process the uploaded audio
62
  speech, _ = sf.read("uploaded_audio.wav")
63
  output = stt_pipe(speech)
64
  st.write("Transcribed Text:", output['text'])
65
-
66
- if 'chunks' in output:
67
- st.write("Transcribed Text with Timestamps:")
68
- for chunk in output['chunks']:
69
- st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")
70
-
71
  try:
72
  st.write("Input Text:", output['text'])
73
  chat_completion = groq_client.chat.completions.create(
74
  messages=[{"role": "user", "content": output['text']}],
75
  model="mixtral-8x7b-32768",
76
  temperature=0.5,
77
- max_tokens=2048, # Increased max_tokens
78
  )
79
  st.write("API Response:", chat_completion)
80
  response = chat_completion.choices[0].message.content
81
  st.write("Generated Response:", response)
82
- speech, *_ = tts_model(response, spembs=tts_model.spembs[0])
83
- st.write("TTS Output:", speech)
84
- sf.write("response.wav", speech, 22050)
85
- st.audio("response.wav")
86
  except Exception as e:
87
  st.error(f"Error generating response: {e}")
88
-
89
  else:
90
- # Audio recorder
91
- st.write("Record your voice:")
92
- webrtc_ctx = webrtc_streamer(
93
- key="audio-recorder",
94
- mode=WebRtcMode.SENDONLY,
95
- audio_processor_factory=AudioRecorder,
96
- media_stream_constraints={"audio": True, "video": False},
97
- )
98
-
99
- if webrtc_ctx.audio_processor:
100
- st.write("Recording... Press 'Stop' to finish recording.")
101
- if st.button("Stop and Process Recording"):
102
- audio_frames = webrtc_ctx.audio_processor.audio_frames
103
- if audio_frames:
104
- audio_data = np.concatenate(audio_frames)
105
- sf.write("recorded_audio.wav", audio_data, samplerate=16000)
106
- st.success("Recording saved as recorded_audio.wav")
107
- speech, _ = sf.read("recorded_audio.wav")
108
- output = stt_pipe(speech)
109
- st.write("Transcribed Text:", output['text'])
110
- if 'chunks' in output:
111
- st.write("Transcribed Text with Timestamps:")
112
- for chunk in output['chunks']:
113
- st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")
114
- try:
115
- st.write("Input Text:", output['text'])
116
- chat_completion = groq_client.chat.completions.create(
117
- messages=[{"role": "user", "content": output['text']}],
118
- model="mixtral-8x7b-32768",
119
- temperature=0.5,
120
- max_tokens=2048, # Increased max_tokens
121
- )
122
- st.write("API Response:", chat_completion)
123
- response = chat_completion.choices[0].message.content
124
- st.write("Generated Response:", response)
125
- speech, *_ = tts_model(response, spembs=tts_model.spembs[0])
126
- st.write("TTS Output:", speech)
127
- sf.write("response.wav", speech, 22050)
128
- st.audio("response.wav")
129
- except Exception as e:
130
- st.error(f"Error generating response: {e}")
131
- else:
132
- st.error("No audio recorded. Please try again.")
 
1
  import os
2
  import streamlit as st
3
  from groq import Groq
4
+ from transformers import pipeline
 
5
  import soundfile as sf
 
 
 
 
6
  import numpy as np
7
 
8
  # Load Groq API key from environment variables
 
14
  # Initialize Groq client
15
  groq_client = Groq(api_key=GROQ_API_KEY)
16
 
17
+ # Load models (Smaller Whisper model)
18
  @st.cache_resource
19
  def load_models():
20
+ stt_pipe = pipeline("automatic-speech-recognition", model="distil-whisper/distil-small.en")
21
+ return stt_pipe
 
 
 
 
 
 
 
 
 
22
 
23
+ stt_pipe = load_models()
 
 
 
 
 
 
 
 
 
24
 
25
  # Streamlit app
26
+ st.title("Voice-Enabled Chatbot (CPU Optimized)")
27
 
28
  # Audio upload
29
  uploaded_file = st.file_uploader("Upload a WAV file", type=["wav"])
 
32
  # Save uploaded file
33
  with open("uploaded_audio.wav", "wb") as f:
34
  f.write(uploaded_file.getbuffer())
 
35
  st.success("File uploaded successfully!")
 
36
  # Process the uploaded audio
37
  speech, _ = sf.read("uploaded_audio.wav")
38
  output = stt_pipe(speech)
39
  st.write("Transcribed Text:", output['text'])
 
 
 
 
 
 
40
  try:
41
  st.write("Input Text:", output['text'])
42
  chat_completion = groq_client.chat.completions.create(
43
  messages=[{"role": "user", "content": output['text']}],
44
  model="mixtral-8x7b-32768",
45
  temperature=0.5,
46
+ max_tokens=1024,
47
  )
48
  st.write("API Response:", chat_completion)
49
  response = chat_completion.choices[0].message.content
50
  st.write("Generated Response:", response)
51
+ st.write("Response played via browser audio:")
52
+ st.write(response)
53
+
 
54
  except Exception as e:
55
  st.error(f"Error generating response: {e}")
 
56
  else:
57
+ st.write("This application currently only supports file uploads.")