Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -79,6 +79,8 @@ async def generate_speech(text, tts_model, tts_tokenizer):
|
|
| 79 |
|
| 80 |
return audio_generation.cpu().numpy().squeeze()
|
| 81 |
|
|
|
|
|
|
|
| 82 |
@spaces.GPU(timeout=300)
|
| 83 |
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
|
| 84 |
try:
|
|
@@ -102,7 +104,8 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
|
|
| 102 |
top_p=top_p,
|
| 103 |
top_k=top_k,
|
| 104 |
temperature=temperature,
|
| 105 |
-
eos_token_id=
|
|
|
|
| 106 |
streamer=streamer,
|
| 107 |
)
|
| 108 |
|
|
@@ -110,35 +113,41 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
|
|
| 110 |
thread.start()
|
| 111 |
|
| 112 |
buffer = ""
|
| 113 |
-
audio_buffer = np.array([0.0]) # Initialize with a single zero
|
| 114 |
-
|
| 115 |
for new_text in streamer:
|
| 116 |
buffer += new_text
|
| 117 |
-
yield history + [[message, buffer]],
|
| 118 |
-
|
| 119 |
-
#
|
| 120 |
-
if use_tts and buffer:
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
except Exception as e:
|
| 129 |
print(f"An error occurred: {str(e)}")
|
| 130 |
-
yield history + [[message, f"An error occurred: {str(e)}"]],
|
| 131 |
|
| 132 |
def generate_speech_sync(text, tts_model, tts_tokenizer):
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
@spaces.GPU(timeout=300) # Increase timeout to 5 minutes
|
| 144 |
def process_vision_query(image, text_input):
|
|
|
|
| 79 |
|
| 80 |
return audio_generation.cpu().numpy().squeeze()
|
| 81 |
|
| 82 |
+
from gradio import Error as GradioError
|
| 83 |
+
|
| 84 |
@spaces.GPU(timeout=300)
|
| 85 |
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
|
| 86 |
try:
|
|
|
|
| 104 |
top_p=top_p,
|
| 105 |
top_k=top_k,
|
| 106 |
temperature=temperature,
|
| 107 |
+
eos_token_id=text_tokenizer.eos_token_id,
|
| 108 |
+
pad_token_id=text_tokenizer.pad_token_id,
|
| 109 |
streamer=streamer,
|
| 110 |
)
|
| 111 |
|
|
|
|
| 113 |
thread.start()
|
| 114 |
|
| 115 |
buffer = ""
|
|
|
|
|
|
|
| 116 |
for new_text in streamer:
|
| 117 |
buffer += new_text
|
| 118 |
+
yield history + [[message, buffer]], None # Yield None for audio initially
|
| 119 |
+
|
| 120 |
+
# Only attempt TTS if it's enabled and we have a response
|
| 121 |
+
if use_tts and buffer:
|
| 122 |
+
try:
|
| 123 |
+
audio = generate_speech_sync(buffer, tts_model, tts_tokenizer)
|
| 124 |
+
yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio)
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print(f"TTS failed: {str(e)}")
|
| 127 |
+
yield history + [[message, buffer]], None
|
| 128 |
+
else:
|
| 129 |
+
yield history + [[message, buffer]], None
|
| 130 |
+
|
| 131 |
+
except GradioError:
|
| 132 |
+
yield history + [[message, "GPU task aborted. Please try again."]], None
|
| 133 |
except Exception as e:
|
| 134 |
print(f"An error occurred: {str(e)}")
|
| 135 |
+
yield history + [[message, f"An error occurred: {str(e)}"]], None
|
| 136 |
|
| 137 |
def generate_speech_sync(text, tts_model, tts_tokenizer):
|
| 138 |
+
try:
|
| 139 |
+
tts_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
|
| 140 |
+
tts_description = "A clear and natural voice reads the text with moderate speed and expression."
|
| 141 |
+
tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
|
| 142 |
+
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
|
| 145 |
+
|
| 146 |
+
audio_buffer = audio_generation.cpu().numpy().squeeze()
|
| 147 |
+
return audio_buffer if audio_buffer.size > 0 else np.array([0.0])
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f"Speech generation failed: {str(e)}")
|
| 150 |
+
return np.array([0.0])
|
| 151 |
|
| 152 |
@spaces.GPU(timeout=300) # Increase timeout to 5 minutes
|
| 153 |
def process_vision_query(image, text_input):
|