Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,6 @@ import sys
|
|
| 4 |
import tempfile
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
-
import os
|
| 8 |
import shutil
|
| 9 |
import glob
|
| 10 |
|
|
@@ -26,6 +25,21 @@ from TTS.tts.models.xtts import Xtts
|
|
| 26 |
from TTS.tts.configs.xtts_config import XttsConfig
|
| 27 |
from TTS.tts.models.xtts import Xtts
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# Clear logs
|
| 30 |
def remove_log_file(file_path):
|
| 31 |
log_file = Path(file_path)
|
|
@@ -388,6 +402,13 @@ if __name__ == "__main__":
|
|
| 388 |
|
| 389 |
def train_model(custom_model, version, language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
| 390 |
clear_gpu_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
run_dir = Path(output_path) / "run"
|
| 393 |
|
|
|
|
| 4 |
import tempfile
|
| 5 |
from pathlib import Path
|
| 6 |
|
|
|
|
| 7 |
import shutil
|
| 8 |
import glob
|
| 9 |
|
|
|
|
| 25 |
from TTS.tts.configs.xtts_config import XttsConfig
|
| 26 |
from TTS.tts.models.xtts import Xtts
|
| 27 |
|
| 28 |
+
import requests
|
| 29 |
+
|
| 30 |
+
def download_file(url, destination):
|
| 31 |
+
try:
|
| 32 |
+
response = requests.get(url, stream=True)
|
| 33 |
+
response.raise_for_status()
|
| 34 |
+
with open(destination, "wb") as f:
|
| 35 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 36 |
+
f.write(chunk)
|
| 37 |
+
print(f"Downloaded file to {destination}")
|
| 38 |
+
return destination
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"Failed to download the file: {e}")
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
# Clear logs
|
| 44 |
def remove_log_file(file_path):
|
| 45 |
log_file = Path(file_path)
|
|
|
|
| 402 |
|
| 403 |
def train_model(custom_model, version, language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
| 404 |
clear_gpu_cache()
|
| 405 |
+
|
| 406 |
+
# Check if `custom_model` is a URL and download it if true.
|
| 407 |
+
if custom_model.startswith("http"):
|
| 408 |
+
print("Downloading custom model from URL...")
|
| 409 |
+
custom_model = download_file(custom_model, "custom_model.pth")
|
| 410 |
+
if not custom_model:
|
| 411 |
+
return "Failed to download the custom model.", "", "", "", ""
|
| 412 |
|
| 413 |
run_dir = Path(output_path) / "run"
|
| 414 |
|