Spaces:
Sleeping
Sleeping
updates for audio url issue
Browse files
app.py
CHANGED
|
@@ -164,25 +164,46 @@ def generate_response(audio_data, message, chat_history=[]):
|
|
| 164 |
{"role": "system", "content": system_prompt}
|
| 165 |
]
|
| 166 |
|
|
|
|
|
|
|
|
|
|
| 167 |
# Add chat history (limited to last 3 turns)
|
| 168 |
history_limit = min(len(chat_history), 3)
|
| 169 |
for user_msg, bot_msg in chat_history[-history_limit:]:
|
| 170 |
-
|
| 171 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
conversation.append({"role": "assistant", "content": bot_msg})
|
| 173 |
|
| 174 |
-
# Add current message with audio
|
| 175 |
if audio_data is not None:
|
| 176 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
conversation.append({
|
| 178 |
"role": "user",
|
| 179 |
-
"content":
|
| 180 |
-
{"type": "audio"}, # Audio will be added in preprocessing
|
| 181 |
-
{"type": "text", "text": message}
|
| 182 |
-
]
|
| 183 |
})
|
| 184 |
else:
|
| 185 |
-
# Text-only
|
| 186 |
conversation.append({
|
| 187 |
"role": "user",
|
| 188 |
"content": message
|
|
@@ -197,53 +218,17 @@ def generate_response(audio_data, message, chat_history=[]):
|
|
| 197 |
)
|
| 198 |
|
| 199 |
# Process inputs
|
| 200 |
-
logger.info("Processing inputs")
|
| 201 |
inputs = processor(
|
| 202 |
text=text,
|
| 203 |
-
audios=
|
| 204 |
return_tensors="pt",
|
| 205 |
padding=True,
|
| 206 |
truncation=True
|
| 207 |
)
|
| 208 |
|
| 209 |
-
#
|
| 210 |
-
|
| 211 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 212 |
-
|
| 213 |
-
log_gpu_memory("Before generation")
|
| 214 |
-
|
| 215 |
-
# Generate response
|
| 216 |
-
logger.info("Generating response")
|
| 217 |
-
with torch.no_grad():
|
| 218 |
-
output = model.generate(
|
| 219 |
-
**inputs,
|
| 220 |
-
max_new_tokens=150,
|
| 221 |
-
do_sample=True,
|
| 222 |
-
temperature=0.7,
|
| 223 |
-
top_p=0.9,
|
| 224 |
-
use_cache=True,
|
| 225 |
-
pad_token_id=processor.tokenizer.pad_token_id
|
| 226 |
-
)
|
| 227 |
-
|
| 228 |
-
# Decode only the new tokens
|
| 229 |
-
generated_text = processor.batch_decode(
|
| 230 |
-
output[:, inputs.input_ids.shape[1]:],
|
| 231 |
-
skip_special_tokens=True
|
| 232 |
-
)[0]
|
| 233 |
-
|
| 234 |
-
# Clean up
|
| 235 |
-
del inputs, output
|
| 236 |
-
gc.collect()
|
| 237 |
-
if torch.cuda.is_available():
|
| 238 |
-
torch.cuda.empty_cache()
|
| 239 |
-
|
| 240 |
-
log_gpu_memory("After generation")
|
| 241 |
-
|
| 242 |
-
return generated_text
|
| 243 |
-
except Exception as e:
|
| 244 |
-
logger.error(f"Error generating response: {e}")
|
| 245 |
-
return f"I encountered an error while processing your request: {str(e)}"
|
| 246 |
-
|
| 247 |
# Create Gradio Interface
|
| 248 |
def create_interface():
|
| 249 |
"""Create the Gradio interface"""
|
|
@@ -316,8 +301,17 @@ def create_interface():
|
|
| 316 |
if not message or not message.strip():
|
| 317 |
return chat_history, ""
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
# Add user message to history
|
| 320 |
-
chat_history.append((
|
| 321 |
yield chat_history, ""
|
| 322 |
|
| 323 |
try:
|
|
|
|
| 164 |
{"role": "system", "content": system_prompt}
|
| 165 |
]
|
| 166 |
|
| 167 |
+
# Collect all audio samples in order
|
| 168 |
+
audios = []
|
| 169 |
+
|
| 170 |
# Add chat history (limited to last 3 turns)
|
| 171 |
history_limit = min(len(chat_history), 3)
|
| 172 |
for user_msg, bot_msg in chat_history[-history_limit:]:
|
| 173 |
+
# Check if the user message is a string or already contains audio
|
| 174 |
+
if isinstance(user_msg, list) and any(item.get("type") == "audio" for item in user_msg):
|
| 175 |
+
# It's already in the right format with audio
|
| 176 |
+
conversation.append({"role": "user", "content": user_msg})
|
| 177 |
+
|
| 178 |
+
# Extract audio from this message
|
| 179 |
+
for item in user_msg:
|
| 180 |
+
if item.get("type") == "audio" and "audio_data" in item:
|
| 181 |
+
audios.append(item["audio_data"])
|
| 182 |
+
else:
|
| 183 |
+
# Regular text message
|
| 184 |
+
conversation.append({"role": "user", "content": user_msg})
|
| 185 |
+
|
| 186 |
+
# Add assistant response if available
|
| 187 |
+
if bot_msg:
|
| 188 |
conversation.append({"role": "assistant", "content": bot_msg})
|
| 189 |
|
| 190 |
+
# Add current message with audio if available
|
| 191 |
if audio_data is not None:
|
| 192 |
+
# Current message with audio
|
| 193 |
+
user_content = [
|
| 194 |
+
{"type": "audio", "audio_url": "audio_sample.wav"}, # Placeholder URL
|
| 195 |
+
{"type": "text", "text": message}
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
# Store the actual audio data for processing
|
| 199 |
+
audios.append(audio_data)
|
| 200 |
+
|
| 201 |
conversation.append({
|
| 202 |
"role": "user",
|
| 203 |
+
"content": user_content
|
|
|
|
|
|
|
|
|
|
| 204 |
})
|
| 205 |
else:
|
| 206 |
+
# Text-only message
|
| 207 |
conversation.append({
|
| 208 |
"role": "user",
|
| 209 |
"content": message
|
|
|
|
| 218 |
)
|
| 219 |
|
| 220 |
# Process inputs
|
| 221 |
+
logger.info(f"Processing inputs with {len(audios)} audio samples")
|
| 222 |
inputs = processor(
|
| 223 |
text=text,
|
| 224 |
+
audios=audios if audios else None,
|
| 225 |
return_tensors="pt",
|
| 226 |
padding=True,
|
| 227 |
truncation=True
|
| 228 |
)
|
| 229 |
|
| 230 |
+
# The rest of your function remains the same
|
| 231 |
+
# ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
# Create Gradio Interface
|
| 233 |
def create_interface():
|
| 234 |
"""Create the Gradio interface"""
|
|
|
|
| 301 |
if not message or not message.strip():
|
| 302 |
return chat_history, ""
|
| 303 |
|
| 304 |
+
# If we have audio, format the user message as a list with audio and text
|
| 305 |
+
if audio_data is not None:
|
| 306 |
+
user_message = [
|
| 307 |
+
{"type": "audio", "audio_url": "audio_sample.wav", "audio_data": audio_data},
|
| 308 |
+
{"type": "text", "text": message}
|
| 309 |
+
]
|
| 310 |
+
else:
|
| 311 |
+
user_message = message
|
| 312 |
+
|
| 313 |
# Add user message to history
|
| 314 |
+
chat_history.append((user_message, None))
|
| 315 |
yield chat_history, ""
|
| 316 |
|
| 317 |
try:
|