Spaces:
Running
Running
Fixed missing CUDA option
Browse files
app.py
CHANGED
|
@@ -33,8 +33,20 @@ ANNOTATED_FEATURES_INFO = [
|
|
| 33 |
"Colloquial | Formal",
|
| 34 |
]
|
| 35 |
|
|
|
|
| 36 |
nltk.download("punkt_tab")
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# Load PCA model and annotated features
|
| 39 |
try:
|
| 40 |
pca = joblib.load(PCA_MODEL_PATH)
|
|
@@ -50,7 +62,9 @@ except FileNotFoundError:
|
|
| 50 |
print(f"Error: Annotated features file '{ANNOTATED_FEATURES_PATH}' not found.")
|
| 51 |
annotated_features = None
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def load_voices_json():
|
|
@@ -132,8 +146,8 @@ def generate_audio_with_voice(text, voice_key, speed_val):
|
|
| 132 |
print(f"Selected Voice: {voice_key}")
|
| 133 |
print(f"Style Vector (First 6): {style_vector[0][:6]}")
|
| 134 |
|
| 135 |
-
# Convert to torch tensor
|
| 136 |
-
style_vec_torch = torch.from_numpy(style_vector).float()
|
| 137 |
|
| 138 |
# Generate audio using the TTS model
|
| 139 |
audio_np = tts_with_style_vector(
|
|
@@ -148,7 +162,7 @@ def generate_audio_with_voice(text, voice_key, speed_val):
|
|
| 148 |
|
| 149 |
if audio_np is None:
|
| 150 |
print("Audio generation failed.")
|
| 151 |
-
return None, "Audio generation failed."
|
| 152 |
|
| 153 |
# Prepare audio for Gradio
|
| 154 |
sr = 24000 # Adjust based on your actual sampling rate
|
|
@@ -216,9 +230,9 @@ def generate_custom_audio(text, voice_key, randomize, speed_str, *slider_values)
|
|
| 216 |
if random_style_vec is None:
|
| 217 |
print("Failed to generate randomized style vector.")
|
| 218 |
return None, None, None
|
| 219 |
-
# Ensure the style vector is flat
|
| 220 |
final_vec = (
|
| 221 |
-
random_style_vec.numpy().flatten()
|
| 222 |
if isinstance(random_style_vec, torch.Tensor)
|
| 223 |
else np.array(random_style_vec).flatten()
|
| 224 |
)
|
|
@@ -232,8 +246,10 @@ def generate_custom_audio(text, voice_key, randomize, speed_str, *slider_values)
|
|
| 232 |
)
|
| 233 |
return None, None, None
|
| 234 |
|
| 235 |
-
# Convert to torch tensor
|
| 236 |
-
style_vec_torch =
|
|
|
|
|
|
|
| 237 |
|
| 238 |
# Generate audio with the reconstructed style vector
|
| 239 |
audio_np = tts_with_style_vector(
|
|
@@ -471,13 +487,17 @@ def create_combined_interface():
|
|
| 471 |
# Save button functionality
|
| 472 |
def on_save_style_studio(style_vector, style_name):
|
| 473 |
if not style_name:
|
| 474 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
result = save_style_to_json(style_vector, style_name)
|
| 476 |
new_choices = list(load_voices_json().keys())
|
| 477 |
# Return multiple values to update both dropdowns and show status
|
| 478 |
return (
|
| 479 |
-
gr.Dropdown(choices=new_choices),
|
| 480 |
-
gr.Dropdown(choices=new_choices),
|
| 481 |
result, # Status message
|
| 482 |
)
|
| 483 |
|
|
|
|
| 33 |
"Colloquial | Formal",
|
| 34 |
]
|
| 35 |
|
| 36 |
+
# Download necessary NLTK data
|
| 37 |
nltk.download("punkt_tab")
|
| 38 |
|
| 39 |
+
##############################################################################
|
| 40 |
+
# DEVICE CONFIGURATION
|
| 41 |
+
##############################################################################
|
| 42 |
+
# Detect if CUDA is available and set the device accordingly
|
| 43 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
+
print(f"Using device: {device}")
|
| 45 |
+
|
| 46 |
+
##############################################################################
|
| 47 |
+
# LOAD PCA MODEL AND ANNOTATED FEATURES
|
| 48 |
+
##############################################################################
|
| 49 |
+
|
| 50 |
# Load PCA model and annotated features
|
| 51 |
try:
|
| 52 |
pca = joblib.load(PCA_MODEL_PATH)
|
|
|
|
| 62 |
print(f"Error: Annotated features file '{ANNOTATED_FEATURES_PATH}' not found.")
|
| 63 |
annotated_features = None
|
| 64 |
|
| 65 |
+
##############################################################################
|
| 66 |
+
# UTILITY FUNCTIONS
|
| 67 |
+
##############################################################################
|
| 68 |
|
| 69 |
|
| 70 |
def load_voices_json():
|
|
|
|
| 146 |
print(f"Selected Voice: {voice_key}")
|
| 147 |
print(f"Style Vector (First 6): {style_vector[0][:6]}")
|
| 148 |
|
| 149 |
+
# Convert to torch tensor and move to device
|
| 150 |
+
style_vec_torch = torch.from_numpy(style_vector).float().to(device)
|
| 151 |
|
| 152 |
# Generate audio using the TTS model
|
| 153 |
audio_np = tts_with_style_vector(
|
|
|
|
| 162 |
|
| 163 |
if audio_np is None:
|
| 164 |
print("Audio generation failed.")
|
| 165 |
+
return None, None, "Audio generation failed."
|
| 166 |
|
| 167 |
# Prepare audio for Gradio
|
| 168 |
sr = 24000 # Adjust based on your actual sampling rate
|
|
|
|
| 230 |
if random_style_vec is None:
|
| 231 |
print("Failed to generate randomized style vector.")
|
| 232 |
return None, None, None
|
| 233 |
+
# Ensure the style vector is a flat NumPy array (tensors are moved to the CPU first)
|
| 234 |
final_vec = (
|
| 235 |
+
random_style_vec.cpu().numpy().flatten()
|
| 236 |
if isinstance(random_style_vec, torch.Tensor)
|
| 237 |
else np.array(random_style_vec).flatten()
|
| 238 |
)
|
|
|
|
| 246 |
)
|
| 247 |
return None, None, None
|
| 248 |
|
| 249 |
+
# Convert to torch tensor and move to device
|
| 250 |
+
style_vec_torch = (
|
| 251 |
+
torch.from_numpy(reconstructed_vec).float().unsqueeze(0).to(device)
|
| 252 |
+
)
|
| 253 |
|
| 254 |
# Generate audio with the reconstructed style vector
|
| 255 |
audio_np = tts_with_style_vector(
|
|
|
|
| 487 |
# Save button functionality
|
| 488 |
def on_save_style_studio(style_vector, style_name):
    """Save the current style vector under a name and refresh both voice dropdowns.

    Args:
        style_vector: Style vector to store; passed unchanged to
            ``save_style_to_json``.
        style_name: Name for the new voice. A falsy value (empty string or
            ``None``) skips the save.

    Returns:
        A 3-tuple ``(dropdown_update, dropdown_update, status)``. The order
        is the same on every path so each value reaches the same output
        component whether or not a save was attempted.
    """
    if not style_name:
        # Keep the (dropdown, dropdown, status) order used by the success
        # path below. NOTE(review): assumes the click handler's outputs are
        # wired in that order, as the success return implies -- confirm.
        return (
            gr.Dropdown.update(),
            gr.Dropdown.update(),
            "Please enter a name for the new voice!",
        )

    result = save_style_to_json(style_vector, style_name)
    # Reload the voice keys after the save attempt to repopulate both dropdowns.
    new_choices = list(load_voices_json().keys())
    return (
        gr.Dropdown.update(choices=new_choices),
        gr.Dropdown.update(choices=new_choices),
        result,  # Status message
    )
|
| 503 |
|