Dua Rajper committed on
Commit
a238dc9
·
verified ·
1 Parent(s): 45f3399

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -59
app.py CHANGED
@@ -20,9 +20,8 @@ if not GROQ_API_KEY:
20
  groq_client = Groq(api_key=GROQ_API_KEY)
21
 
22
  # Load models
23
- @st.cache_resource # Use st.cache_resource for caching models
24
  def load_models():
25
- # Speech-to-Text
26
  processor = AutoProcessor.from_pretrained("openai/whisper-small")
27
  stt_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
28
  stt_pipe = pipeline(
@@ -30,9 +29,8 @@ def load_models():
30
  model=stt_model,
31
  tokenizer=processor.tokenizer,
32
  feature_extractor=processor.feature_extractor,
33
- return_timestamps=True # Enable timestamps for long-form audio
34
  )
35
- # Text-to-Speech
36
  tts_model = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")
37
  return stt_pipe, tts_model
38
 
@@ -50,59 +48,85 @@ class AudioRecorder(AudioProcessorBase):
50
  # Streamlit app
51
  st.title("Voice-Enabled Chatbot")
52
 
53
- # Audio recorder
54
- st.write("Record your voice:")
55
- webrtc_ctx = webrtc_streamer(
56
- key="audio-recorder",
57
- mode=WebRtcMode.SENDONLY,
58
- audio_processor_factory=AudioRecorder,
59
- media_stream_constraints={"audio": True, "video": False},
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if webrtc_ctx.audio_processor:
63
- st.write("Recording... Press 'Stop' to finish recording.")
64
- # Save recorded audio to a WAV file
65
- if st.button("Stop and Process Recording"):
66
- audio_frames = webrtc_ctx.audio_processor.audio_frames
67
- if audio_frames:
68
- # Combine audio frames into a single array
69
- audio_data = np.concatenate(audio_frames)
70
- # Save as WAV file
71
- sf.write("recorded_audio.wav", audio_data, samplerate=16000)
72
- st.success("Recording saved as recorded_audio.wav")
73
- # Process the recorded audio
74
- speech, _ = sf.read("recorded_audio.wav")
75
- output = stt_pipe(speech) # Transcribe with timestamps
76
- # Debug: Print the transcribed text
77
- st.write("Transcribed Text:", output['text'])
78
- # Display the text with timestamps (optional)
79
- if 'chunks' in output:
80
- st.write("Transcribed Text with Timestamps:")
81
- for chunk in output['chunks']:
82
- st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")
83
- # Generate response using Groq API
84
- try:
85
- # Debug: Print the input text
86
- st.write("Input Text:", output['text'])
87
- chat_completion = groq_client.chat.completions.create(
88
- messages=[{"role": "user", "content": output['text']}],
89
- model="mixtral-8x7b-32768",
90
- temperature=0.5,
91
- max_tokens=1024,
92
- )
93
- # Debug: Print the API response
94
- st.write("API Response:", chat_completion)
95
- # Extract the generated response
96
- response = chat_completion.choices[0].message.content
97
- st.write("Generated Response:", response)
98
- # Convert response to speech
99
- speech, *_ = tts_model(response, spembs=tts_model.spembs[0]) # Use the first speaker embedding
100
- # Debug: Print the TTS output
101
- st.write("TTS Output:", speech)
102
- # Save and play the speech
103
- sf.write("response.wav", speech, 22050)
104
- st.audio("response.wav")
105
- except Exception as e:
106
- st.error(f"Error generating response: {e}")
107
- else:
108
- st.error("No audio recorded. Please try again.")
 
20
  groq_client = Groq(api_key=GROQ_API_KEY)
21
 
22
  # Load models
23
+ @st.cache_resource
24
  def load_models():
 
25
  processor = AutoProcessor.from_pretrained("openai/whisper-small")
26
  stt_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
27
  stt_pipe = pipeline(
 
29
  model=stt_model,
30
  tokenizer=processor.tokenizer,
31
  feature_extractor=processor.feature_extractor,
32
+ return_timestamps=True
33
  )
 
34
  tts_model = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")
35
  return stt_pipe, tts_model
36
 
 
48
  # Streamlit app
49
  st.title("Voice-Enabled Chatbot")
50
 
51
+ # Audio upload
52
+ uploaded_file = st.file_uploader("Upload a WAV file", type=["wav"])
53
+
54
+ if uploaded_file is not None:
55
+ # Save uploaded file
56
+ with open("uploaded_audio.wav", "wb") as f:
57
+ f.write(uploaded_file.getbuffer())
58
+
59
+ st.success("File uploaded successfully!")
60
+
61
+ # Process the uploaded audio
62
+ speech, _ = sf.read("uploaded_audio.wav")
63
+ output = stt_pipe(speech)
64
+ st.write("Transcribed Text:", output['text'])
65
+
66
+ if 'chunks' in output:
67
+ st.write("Transcribed Text with Timestamps:")
68
+ for chunk in output['chunks']:
69
+ st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")
70
+
71
+ try:
72
+ st.write("Input Text:", output['text'])
73
+ chat_completion = groq_client.chat.completions.create(
74
+ messages=[{"role": "user", "content": output['text']}],
75
+ model="mixtral-8x7b-32768",
76
+ temperature=0.5,
77
+ max_tokens=2048, # Increased max_tokens
78
+ )
79
+ st.write("API Response:", chat_completion)
80
+ response = chat_completion.choices[0].message.content
81
+ st.write("Generated Response:", response)
82
+ speech, *_ = tts_model(response, spembs=tts_model.spembs[0])
83
+ st.write("TTS Output:", speech)
84
+ sf.write("response.wav", speech, 22050)
85
+ st.audio("response.wav")
86
+ except Exception as e:
87
+ st.error(f"Error generating response: {e}")
88
+
89
+ else:
90
+ # Audio recorder
91
+ st.write("Record your voice:")
92
+ webrtc_ctx = webrtc_streamer(
93
+ key="audio-recorder",
94
+ mode=WebRtcMode.SENDONLY,
95
+ audio_processor_factory=AudioRecorder,
96
+ media_stream_constraints={"audio": True, "video": False},
97
+ )
98
 
99
+ if webrtc_ctx.audio_processor:
100
+ st.write("Recording... Press 'Stop' to finish recording.")
101
+ if st.button("Stop and Process Recording"):
102
+ audio_frames = webrtc_ctx.audio_processor.audio_frames
103
+ if audio_frames:
104
+ audio_data = np.concatenate(audio_frames)
105
+ sf.write("recorded_audio.wav", audio_data, samplerate=16000)
106
+ st.success("Recording saved as recorded_audio.wav")
107
+ speech, _ = sf.read("recorded_audio.wav")
108
+ output = stt_pipe(speech)
109
+ st.write("Transcribed Text:", output['text'])
110
+ if 'chunks' in output:
111
+ st.write("Transcribed Text with Timestamps:")
112
+ for chunk in output['chunks']:
113
+ st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")
114
+ try:
115
+ st.write("Input Text:", output['text'])
116
+ chat_completion = groq_client.chat.completions.create(
117
+ messages=[{"role": "user", "content": output['text']}],
118
+ model="mixtral-8x7b-32768",
119
+ temperature=0.5,
120
+ max_tokens=2048, # Increased max_tokens
121
+ )
122
+ st.write("API Response:", chat_completion)
123
+ response = chat_completion.choices[0].message.content
124
+ st.write("Generated Response:", response)
125
+ speech, *_ = tts_model(response, spembs=tts_model.spembs[0])
126
+ st.write("TTS Output:", speech)
127
+ sf.write("response.wav", speech, 22050)
128
+ st.audio("response.wav")
129
+ except Exception as e:
130
+ st.error(f"Error generating response: {e}")
131
+ else:
132
+ st.error("No audio recorded. Please try again.")