WWMachine commited on
Commit
02799cd
·
verified ·
1 Parent(s): e34cdab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -102
app.py CHANGED
@@ -2,158 +2,211 @@ import gradio as gr
2
  from llama_cpp import Llama
3
  from huggingface_hub import hf_hub_download
4
  import os
 
 
5
  from deepgram import DeepgramClient, PrerecordedOptions, SpeakOptions
 
6
 
7
  # --- Configuration ---
8
- # 1. API KEY: Ensure you have your Deepgram API Key ready
9
- # Ideally, set this in your environment variables as DEEPGRAM_API_KEY
10
  DEEPGRAM_API_KEY = "19d640a011569d78395c814e5f875b15cc84deb8"
11
-
12
- # 2. Model Config
13
  REPO_ID = "Kezovic/iris-q4gguf-v2"
14
  FILENAME = "llama-3.2-1b-instruct.Q4_K_M.gguf"
15
  CONTEXT_WINDOW = 4096
16
  MAX_NEW_TOKENS = 512
17
  TEMPERATURE = 0.7
18
 
19
- # --- Initialize Deepgram ---
20
- if DEEPGRAM_API_KEY == "YOUR_DEEPGRAM_KEY_HERE":
21
- print("WARNING: Please set your DEEPGRAM_API_KEY.")
22
-
23
- deepgram = DeepgramClient(DEEPGRAM_API_KEY)
24
 
25
- # --- Model Loading Function ---
 
26
  llm = None
 
27
  def load_llm():
28
- """Downloads the GGUF model and initializes LlamaCPP."""
29
  global llm
30
- print("Downloading LLM...")
31
  try:
32
- model_path = hf_hub_download(
33
- repo_id=REPO_ID,
34
- filename=FILENAME
35
- )
36
- # n_threads=2 is good for free Hugging Face CPU tiers
37
  llm = Llama(
38
  model_path=model_path,
39
  n_ctx=CONTEXT_WINDOW,
40
- n_threads=2,
41
  verbose=False
42
  )
43
- print("LLM loaded successfully!")
44
- return llm
45
  except Exception as e:
46
- print(f"Error loading model: {e}")
47
- return None
48
-
49
- # Load model on startup
50
  load_llm()
51
 
52
- # --- 1. Speech-to-Text (Deepgram) ---
53
- def transcribe_audio(audio_filepath):
54
- """Sends audio file to Deepgram and returns text."""
55
- if not audio_filepath:
56
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
 
 
 
 
 
58
  try:
59
- with open(audio_filepath, "rb") as buffer:
60
  payload = {"buffer": buffer}
61
  options = PrerecordedOptions(
62
- smart_format=True,
63
- model="nova-2",
64
- language="en-US"
65
  )
66
  response = deepgram.listen.rest.v("1").transcribe_file(payload, options)
67
  return response.results.channels[0].alternatives[0].transcript
68
  except Exception as e:
69
  print(f"STT Error: {e}")
70
- return ""
 
 
71
 
72
- # --- 2. Text-to-Speech (Deepgram) ---
73
  def text_to_speech(text):
74
- """Sends text to Deepgram and returns path to audio file."""
75
- try:
76
- filename = "output_response.mp3"
77
- options = SpeakOptions(
78
- model="aura-asteria-en", # Choices: aura-asteria-en, aura-helios-en, etc.
79
- encoding="linear16",
80
- container="wav"
81
- )
82
- # Save the audio to a file
83
- deepgram.speak.rest.v("1").save(filename, {"text": text}, options)
84
- return filename
85
- except Exception as e:
86
- print(f"TTS Error: {e}")
87
  return None
88
 
89
- # --- 3. Main Pipeline Function ---
90
- def process_conversation(audio_input):
91
- """
92
- 1. Transcribe Audio (STT)
93
- 2. Query LLM
94
- 3. Synthesize Speech (TTS)
95
- """
96
- if llm is None:
97
- return "Model not loaded.", None, "System Error: Model failed to load."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Step A: Transcribe
100
- user_text = transcribe_audio(audio_input)
101
- print(audio_input)
102
- if not user_text:
103
- return "Could not hear audio.", None, ""
104
 
105
- print(f"User said: {user_text}")
 
 
106
 
107
- # Step B: LLM Inference
108
- # Using the prompt format from your original code
109
- full_prompt = f"### Human: {user_text}\n### Assistant:"
110
 
111
- output = llm(
112
- prompt=full_prompt,
113
- max_tokens=MAX_NEW_TOKENS,
114
- temperature=TEMPERATURE,
115
- stop=["### Human:"],
116
- echo=False
117
- )
118
- response_text = output['choices'][0]['text'].strip()
119
- print(f"LLM said: {response_text}")
120
 
121
- # Step C: Speak Response
122
- output_audio_path = text_to_speech(response_text)
 
123
 
124
- # Return: Transcription (for display), Audio (for playback), LLM Text (for display)
125
- return user_text, output_audio_path, response_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- # --- Gradio UI ---
128
- with gr.Blocks(title=f"Voice Chat with {FILENAME}") as demo:
129
- gr.Markdown(f"## 🗣️ Deepgram Voice Chat with {FILENAME}")
130
 
 
 
 
131
  with gr.Row():
132
- # Input Column
133
- with gr.Column():
134
  audio_input = gr.Audio(
135
  sources=["microphone"],
136
- type="filepath",
137
- label="Speak Now"
138
  )
139
- submit_btn = gr.Button("Submit Audio", variant="primary")
140
-
141
- # Output Column
142
- with gr.Column():
143
- audio_output = gr.Audio(
144
- label="Assistant Voice",
145
- autoplay=True, # Automatically plays the response
146
- interactive=False
147
- )
148
- # Debugging/Visuals
149
- user_transcript = gr.Textbox(label="You said:")
150
- ai_response_text = gr.Textbox(label="AI Response:")
151
 
152
- # Event Listener
153
  submit_btn.click(
154
- fn=process_conversation,
155
- inputs=[audio_input],
156
- outputs=[user_transcript, audio_output, ai_response_text]
 
 
 
 
 
 
 
 
 
157
  )
158
 
159
  if __name__ == "__main__":
 
2
  from llama_cpp import Llama
3
  from huggingface_hub import hf_hub_download
4
  import os
5
+ import re
6
+ import time
7
  from deepgram import DeepgramClient, PrerecordedOptions, SpeakOptions
8
+ from pydub import AudioSegment # Added for audio stitching
9
 
10
  # --- Configuration ---
 
 
11
  DEEPGRAM_API_KEY = "19d640a011569d78395c814e5f875b15cc84deb8"
 
 
12
  REPO_ID = "Kezovic/iris-q4gguf-v2"
13
  FILENAME = "llama-3.2-1b-instruct.Q4_K_M.gguf"
14
  CONTEXT_WINDOW = 4096
15
  MAX_NEW_TOKENS = 512
16
  TEMPERATURE = 0.7
17
 
18
+ # Deepgram Limit: Maximum 2000 characters per TTS request.
19
+ TTS_MAX_CHARS = 1900 # Use slightly less than max for safety
 
 
 
20
 
21
+ # --- Initialize Deepgram & LLM ---
22
+ deepgram = DeepgramClient(DEEPGRAM_API_KEY) if DEEPGRAM_API_KEY else None
23
  llm = None
24
+
25
  def load_llm():
 
26
  global llm
 
27
  try:
28
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
 
 
 
 
29
  llm = Llama(
30
  model_path=model_path,
31
  n_ctx=CONTEXT_WINDOW,
32
+ n_threads=2,
33
  verbose=False
34
  )
 
 
35
  except Exception as e:
36
+ print(f"Error loading LLM: {e}")
 
 
 
37
  load_llm()
38
 
39
+ # --- Helper Functions for Splitting ---
40
+
41
+ def split_text_for_tts(text, max_chars=TTS_MAX_CHARS):
42
+ """Splits text into chunks <= max_chars based on punctuation for natural TTS."""
43
+
44
+ # Split on strong delimiters (period, question mark, exclamation mark, newline)
45
+ # The delimiters are kept in the segments by using parentheses
46
+ segments = re.split(r'([.?!]\s+|\n+)', text)
47
+ chunks = []
48
+ current_chunk = ""
49
+
50
+ for segment in segments:
51
+ if len(current_chunk) + len(segment) < max_chars:
52
+ current_chunk += segment
53
+ else:
54
+ if current_chunk:
55
+ chunks.append(current_chunk.strip())
56
+ current_chunk = segment
57
+
58
+ if current_chunk:
59
+ chunks.append(current_chunk.strip())
60
+
61
+ return [chunk for chunk in chunks if chunk]
62
+
63
+ # --- 1. Speech-to-Text (STT) with File Size Check ---
64
+
65
+ def transcribe(audio_path):
66
+ """Converts Speech to Text using Deepgram, with a file size check."""
67
+ if not audio_path or deepgram is None:
68
+ return None
69
 
70
+ # STT API check: Deepgram Pre-Recorded supports files up to 2GB
71
+ # We check file size and return a warning if too large (e.g., > 200MB, where asynchronous processing is better)
72
+ file_size_bytes = os.path.getsize(audio_path)
73
+ if file_size_bytes > 200 * 1024 * 1024:
74
+ print("Warning: Audio file is large. Transcription may take a moment.")
75
+
76
  try:
77
+ with open(audio_path, "rb") as buffer:
78
  payload = {"buffer": buffer}
79
  options = PrerecordedOptions(
80
+ smart_format=True, model="nova-2", language="en-US",
81
+ # Add diarization=True if you want speaker separation in the transcript
 
82
  )
83
  response = deepgram.listen.rest.v("1").transcribe_file(payload, options)
84
  return response.results.channels[0].alternatives[0].transcript
85
  except Exception as e:
86
  print(f"STT Error: {e}")
87
+ return None
88
+
89
+ # --- 2. Text-to-Speech (TTS) with Stitching ---
90
 
 
91
  def text_to_speech(text):
92
+ """Converts Text to Speech, splitting long text and stitching audio."""
93
+ if deepgram is None:
 
 
 
 
 
 
 
 
 
 
 
94
  return None
95
 
96
+ # Step A: Split text into small chunks
97
+ text_chunks = split_text_for_tts(text)
98
+
99
+ audio_segments = []
100
+
101
+ # Step B: Generate audio for each chunk
102
+ for i, chunk in enumerate(text_chunks):
103
+ try:
104
+ temp_filename = f"temp_tts_chunk_{i}_{int(time.time())}.wav"
105
+ options = SpeakOptions(
106
+ model="aura-asteria-en", encoding="linear16", container="wav"
107
+ )
108
+ deepgram.speak.rest.v("1").save(temp_filename, {"text": chunk}, options)
109
+
110
+ # Load the temporary audio into pydub
111
+ audio_segments.append(AudioSegment.from_wav(temp_filename))
112
+ os.remove(temp_filename)
113
+
114
+ except Exception as e:
115
+ print(f"TTS API FAILED for chunk {i}: {e}. Skipping chunk.")
116
+ continue
117
+
118
+ if not audio_segments:
119
+ return None
120
+
121
+ # Step C: Stitch the audio files together
122
+ stitched_audio = audio_segments[0]
123
+ for i in range(1, len(audio_segments)):
124
+ # Add a 200ms pause between sentences for better flow
125
+ stitched_audio += AudioSegment.silent(duration=200)
126
+ stitched_audio += audio_segments[i]
127
+
128
+ # Step D: Export the final stitched file
129
+ final_filename = f"final_response_{int(time.time())}.wav"
130
+ stitched_audio.export(final_filename, format="wav")
131
+
132
+ return final_filename
133
 
134
+ # --- Main Chat Logic (Same as before) ---
 
 
 
 
135
 
136
+ def run_chat_pipeline(audio_input, history, state_messages):
137
+ if llm is None:
138
+ return history, state_messages, None
139
 
140
+ # 1. Transcribe Audio (STT)
141
+ user_text = transcribe(audio_input)
 
142
 
143
+ if not user_text:
144
+ # If transcription fails (e.g., bad audio, API key error), inform the user via the chat.
145
+ history.append(("", "System Error: Could not process audio. Check API Key or try speaking louder."))
146
+ return history, state_messages, None
 
 
 
 
 
147
 
148
+ # 2. Update UI and State with User Message
149
+ state_messages.append({"role": "user", "content": user_text})
150
+ history.append((user_text, None))
151
 
152
+ # 3. LLM Generation (Contextual)
153
+ try:
154
+ completion = llm.create_chat_completion(
155
+ messages=state_messages,
156
+ max_tokens=MAX_NEW_TOKENS,
157
+ temperature=TEMPERATURE
158
+ )
159
+ ai_text = completion['choices'][0]['message']['content']
160
+ except Exception as e:
161
+ ai_text = f"LLM Generation Error: {str(e)}"
162
+
163
+ # 4. Update UI and State with AI Response
164
+ state_messages.append({"role": "assistant", "content": ai_text})
165
+ history[-1] = (user_text, ai_text)
166
+
167
+ # 5. Generate Audio (TTS with splitting)
168
+ audio_path = text_to_speech(ai_text) # This handles the stitching
169
+
170
+ return history, state_messages, audio_path
171
 
172
+ # --- Gradio UI Layout ---
173
+ with gr.Blocks(title="Voice Chatbot") as demo:
174
+ gr.Markdown("## 🎙️ Voice-First AI Chat (Memory & Long-Text Handled)")
175
 
176
+ chatbot = gr.Chatbot(label="Conversation", height=500)
177
+ state_messages = gr.State([])
178
+
179
  with gr.Row():
180
+ with gr.Column(scale=4):
 
181
  audio_input = gr.Audio(
182
  sources=["microphone"],
183
+ type="filepath",
184
+ label="Record Your Message"
185
  )
186
+ with gr.Column(scale=1):
187
+ submit_btn = gr.Button("Send Voice 💬", variant="primary")
188
+ clear_btn = gr.Button("Clear Memory 🗑️")
189
+
190
+ audio_player = gr.Audio(
191
+ label="AI Voice",
192
+ autoplay=True,
193
+ interactive=False
194
+ )
 
 
 
195
 
196
+ # --- Event Wiring ---
197
  submit_btn.click(
198
+ fn=run_chat_pipeline,
199
+ inputs=[audio_input, chatbot, state_messages],
200
+ outputs=[chatbot, state_messages, audio_player]
201
+ )
202
+
203
+ def clear_all():
204
+ return [], [], None
205
+
206
+ clear_btn.click(
207
+ fn=clear_all,
208
+ inputs=None,
209
+ outputs=[chatbot, state_messages, audio_player]
210
  )
211
 
212
  if __name__ == "__main__":