File size: 10,118 Bytes
0fc97b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# D:\jan-contract\components/chat_interface.py

import streamlit as st
import speech_recognition as sr
from gtts import gTTS
import io
import av
import queue
import wave
import threading
import time
import numpy as np
from typing import Optional

from streamlit_webrtc import webrtc_streamer, WebRtcMode

# --- Setup ---
# Module-level recognizer shared by every chat_interface instance in this
# process; its tuning below applies globally.
recognizer = sr.Recognizer()
recognizer.energy_threshold = 300  # Lower threshold for better sensitivity
recognizer.dynamic_energy_threshold = True  # Re-adapt threshold to ambient noise over time
recognizer.pause_threshold = 0.8  # Seconds of silence that mark the end of a phrase

def text_to_speech(text: str) -> Optional[bytes]:
    """Convert *text* to spoken English and return in-memory MP3 bytes.

    Args:
        text: The text to synthesize.

    Returns:
        The MP3 audio as ``bytes`` on success, or ``None`` if synthesis
        failed (the error is surfaced in the Streamlit UI instead of
        raising, so voice output degrades gracefully).
    """
    try:
        audio_io = io.BytesIO()
        tts = gTTS(text=text, lang='en', slow=False)
        tts.write_to_fp(audio_io)
        # getvalue() returns the whole buffer without a seek(0)/read() pair.
        return audio_io.getvalue()
    except Exception as e:
        # Deliberate best-effort: voice playback is optional, so report the
        # failure in the UI rather than crashing the chat flow.
        st.error(f"Error during Text-to-Speech: {e}")
        return None

def chat_interface(handler_function, session_state_key: str):
    """
    A reusable component that provides a full Text and Voice chat interface.

    Renders two sections: (1) a standard text chat whose history persists in
    ``st.session_state``, and (2) a voice chat backed by ``streamlit_webrtc``
    that records microphone audio, transcribes it via Google Speech
    Recognition, and speaks the handler's reply back with gTTS.

    Args:
        handler_function: The function to call with the user's text input.
                          Expected to accept a ``str`` and return a ``str``
                          response (inferred from usage below — confirm).
        session_state_key (str): A unique key to store chat history AND to use
                                 as a base for widget keys.
    """
    st.subheader("💬 Chat via Text")

    # Chat history lives in session state so it survives Streamlit reruns.
    if session_state_key not in st.session_state:
        st.session_state[session_state_key] = []

    # Replay the stored conversation on every rerun.
    for message in st.session_state[session_state_key]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Walrus: chat_input returns None until the user submits text.
    if prompt := st.chat_input("Ask a question...", key=f"chat_input_{session_state_key}"):
        st.session_state[session_state_key].append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = handler_function(prompt)
                st.markdown(response)

        st.session_state[session_state_key].append({"role": "assistant", "content": response})

    st.divider()

    st.subheader("🎙️ Chat via Voice")
    st.info("🎤 **Instructions:** Click START to begin recording, speak your question clearly, then click STOP.")

    # Initialize session state for voice recording.
    # All voice-related keys are namespaced under "voice_<session_state_key>"
    # so multiple chat_interface instances on one page do not collide.
    voice_key = f"voice_{session_state_key}"
    if f"{voice_key}_frames" not in st.session_state:
        # Raw PCM chunks (bytes) collected while recording.
        st.session_state[f"{voice_key}_frames"] = []
    if f"{voice_key}_processing" not in st.session_state:
        # True while a recording session is active (between START and STOP).
        st.session_state[f"{voice_key}_processing"] = False
    if f"{voice_key}_recording_start" not in st.session_state:
        # time.time() when recording began; used only for progress display.
        st.session_state[f"{voice_key}_recording_start"] = None
    if f"{voice_key}_bytes" not in st.session_state:
        # Running byte count of captured PCM, used for the too-short check.
        st.session_state[f"{voice_key}_bytes"] = 0
    if f"{voice_key}_component_key" not in st.session_state:
        # Timestamped key makes each component instance unique per session,
        # avoiding stale-widget registration collisions across reruns.
        st.session_state[f"{voice_key}_component_key"] = f"voice-chat-{session_state_key}-{int(time.time())}"

    def audio_frame_callback(frame: av.AudioFrame):
        """Callback to collect audio frames during recording.

        NOTE(review): streamlit_webrtc invokes this on its own worker
        thread; accessing ``st.session_state`` (and ``st.error``) from a
        non-script thread may silently fail without the script-run context
        attached — confirm frames actually accumulate in deployment.
        """
        if st.session_state[f"{voice_key}_processing"]:
            try:
                # Resample every frame to 16kHz mono, 16-bit PCM for SR
                resampled = frame.reformat(format="s16", layout="mono", rate=16000)
                chunk = resampled.planes[0].to_bytes()
                st.session_state[f"{voice_key}_frames"].append(chunk)
                st.session_state[f"{voice_key}_bytes"] += len(chunk)
            except Exception as e:
                st.error(f"Error processing audio frame: {e}")

    def process_voice_input():
        """Process the collected audio frames and get response.

        Transcribes the buffered PCM, routes the text through
        ``handler_function``, plays a gTTS audio reply, and appends both
        turns to the chat history. Always resets the recording buffers in
        ``finally`` so a failed attempt does not poison the next one.
        """
        # Short-audio threshold (~0.5s at 16kHz, 16-bit mono)
        # NOTE(review): the guard rejects < 0.5s although the message below
        # asks for "at least 1 second" — thresholds are inconsistent.
        total_bytes = st.session_state.get(f"{voice_key}_bytes", 0)
        if total_bytes < int(16000 * 2 * 0.5):
            st.error("❌ No audio captured or recording too short. Please speak for at least 1 second and try again.")
            st.session_state[f"{voice_key}_frames"] = []
            st.session_state[f"{voice_key}_processing"] = False
            st.session_state[f"{voice_key}_bytes"] = 0
            return

        status_placeholder = st.empty()
        status_placeholder.info("🔄 Processing audio...")

        try:
            # Combine all audio frames (already PCM s16 mono 16kHz)
            audio_data = b"".join(st.session_state[f"{voice_key}_frames"])

            # Create WAV file in memory with proper format
            with io.BytesIO() as wav_buffer:
                with wave.open(wav_buffer, 'wb') as wf:
                    wf.setnchannels(1)  # Mono
                    wf.setsampwidth(2)  # 16-bit
                    wf.setframerate(16000)  # 16kHz
                    wf.writeframes(audio_data)
                # Rewind so sr.AudioFile reads from the start of the WAV.
                wav_buffer.seek(0)

                # Use speech recognition with better error handling
                with sr.AudioFile(wav_buffer) as source:
                    # Adjust for ambient noise quickly; avoid long pauses
                    recognizer.adjust_for_ambient_noise(source, duration=0.1)
                    audio = recognizer.record(source)

                # Recognize speech with multiple fallbacks
                # (en-US first, then en-GB if the first pass fails).
                try:
                    user_input = recognizer.recognize_google(audio, language="en-US")
                except sr.UnknownValueError:
                    try:
                        user_input = recognizer.recognize_google(audio, language="en-GB")
                    except sr.UnknownValueError:
                        st.error("❌ Could not understand the audio. Please speak more clearly and try again.")
                        return

                if not user_input.strip():
                    st.error("❌ No speech detected. Please try again.")
                    return

                st.write(f"🎤 **You said:** *{user_input}*")

                # Get response from handler
                with st.spinner("🤔 Getting response..."):
                    response_text = handler_function(user_input)

                st.write(f"🤖 **Assistant says:** *{response_text}*")

                # Generate audio response
                with st.spinner("🔊 Generating audio response..."):
                    audio_response = text_to_speech(response_text)
                    # text_to_speech returns None on failure; skip playback then.
                    if audio_response:
                        st.audio(audio_response, format="audio/mp3", start_time=0)
                        st.success("✅ Audio response generated!")

                # Add to chat history
                st.session_state[session_state_key].append({"role": "user", "content": user_input})
                st.session_state[session_state_key].append({"role": "assistant", "content": response_text})

        except sr.RequestError as e:
            # Raised when the Google recognition web service is unreachable.
            st.error(f"❌ Speech recognition service error: {e}")
        except Exception as e:
            st.error(f"❌ Error processing audio: {str(e)}")
        finally:
            # Clear the audio frames
            st.session_state[f"{voice_key}_frames"] = []
            st.session_state[f"{voice_key}_processing"] = False
            st.session_state[f"{voice_key}_bytes"] = 0
            status_placeholder.empty()

    # Create a unique key for each component instance to avoid registration issues
    component_key = st.session_state[f"{voice_key}_component_key"]

    # WebRTC streamer with proper error handling and component lifecycle
    try:
        ctx = webrtc_streamer(
            key=component_key,
            mode=WebRtcMode.SENDONLY,  # browser sends audio; server sends nothing back
            rtc_configuration={
                # Public Google STUN servers for NAT traversal; no TURN fallback.
                "iceServers": [
                    {"urls": ["stun:stun.l.google.com:19302"]},
                    {"urls": ["stun:stun1.l.google.com:19302"]}
                ]
            },
            audio_frame_callback=audio_frame_callback,
            media_stream_constraints={
                "video": False,
                "audio": {
                    "echoCancellation": True,
                    "noiseSuppression": True,
                    "autoGainControl": True
                }
            },
            async_processing=True,
            # NOTE(review): verify webrtc_streamer accepts an `on_change`
            # kwarg in the pinned streamlit_webrtc version; an unknown kwarg
            # would raise and push every run into the fallback branch below.
            on_change=lambda: None,  # Prevent component registration issues
        )

        # Handle recording state with better feedback.
        # The three branches below form a small state machine driven by
        # ctx.state.playing (START/STOP button) and the _processing flag,
        # re-evaluated on each Streamlit rerun.
        bytes_captured = st.session_state.get(f"{voice_key}_bytes", 0)

        # Transition: user just pressed START -> reset buffers, begin capture.
        if ctx.state.playing and not st.session_state.get(f"{voice_key}_processing", False):
            st.session_state[f"{voice_key}_processing"] = True
            st.session_state[f"{voice_key}_recording_start"] = time.time()
            st.session_state[f"{voice_key}_frames"] = []
            st.session_state[f"{voice_key}_bytes"] = 0
            st.success("🔴 **Recording started!** Speak your question now...")

        elif ctx.state.playing and st.session_state.get(f"{voice_key}_processing", False):
            # Show recording progress
            if st.session_state.get(f"{voice_key}_recording_start"):
                # `elapsed` is computed but unused; the caption derives its
                # duration from captured bytes (16kHz * 2 bytes per sample).
                elapsed = time.time() - st.session_state[f"{voice_key}_recording_start"]
                approx_seconds = bytes_captured / (16000 * 2) if bytes_captured else 0
                st.caption(f"🎤 Recording... ~{approx_seconds:.1f}s captured")

        # Process audio when recording stops
        if not ctx.state.playing and st.session_state.get(f"{voice_key}_processing", False):
            process_voice_input()

    except Exception as e:
        st.error(f"❌ WebRTC Error: {str(e)}")
        st.info("💡 Try refreshing the page or using a different browser (Chrome recommended).")

        # Fallback: manual audio input
        # NOTE(review): this fallback only shows an informational message —
        # it does not provide an alternative capture path.
        st.subheader("🔄 Fallback: Manual Audio Input")
        if st.button("Try Alternative Audio Method", key=f"fallback_{voice_key}"):
            st.info("This feature requires WebRTC support. Please ensure your browser supports WebRTC and try again.")