Hugging Face Space (status: Sleeping) — initial commit, "Browse files" view.
File: app.py (CHANGED)
Diff hunk: @@ -393,42 +393,51 @@  (context line: TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech")
@@ -393,42 +393,51 @@ TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
|
|
| 393 |
ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
|
| 394 |
MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
|
| 395 |
|
| 396 |
-
|
| 397 |
-
stt_processor = None
|
| 398 |
-
stt_model = None
|
| 399 |
-
mt_tokenizer = None
|
| 400 |
-
mt_model = None
|
| 401 |
|
| 402 |
# Wrap model loading in a function to clearly see when it happens or to potentially delay it.
|
| 403 |
# For Spaces, global loading is fine and preferred as it happens once.
|
| 404 |
print("--- Starting Model Loading ---")
|
| 405 |
try:
|
| 406 |
-
print(
|
|
|
|
| 407 |
tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
|
| 408 |
-
state = torch.load(tts_model_path, map_location=DEVICE)
|
| 409 |
-
TTS_MODEL = TransformerTTS(
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
TTS_MODEL.eval()
|
| 413 |
print("TTS model loaded successfully.")
|
| 414 |
except Exception as e:
|
| 415 |
print(f"Error loading TTS model: {e}")
|
|
|
|
| 416 |
|
|
|
|
| 417 |
try:
|
| 418 |
-
print(
|
| 419 |
stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
|
| 420 |
stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
|
| 421 |
print("STT model loaded successfully.")
|
| 422 |
except Exception as e:
|
| 423 |
print(f"Error loading STT model: {e}")
|
|
|
|
|
|
|
| 424 |
|
|
|
|
| 425 |
try:
|
| 426 |
-
print(
|
| 427 |
mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
|
| 428 |
mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
|
| 429 |
print("TTT model loaded successfully.")
|
| 430 |
except Exception as e:
|
| 431 |
print(f"Error loading TTT model: {e}")
|
|
|
|
|
|
|
| 432 |
print("--- Model Loading Complete ---")
|
| 433 |
|
| 434 |
|
|
|
|
ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"

# Wrap model loading in a function to clearly see when it happens or to potentially delay it.
# For Spaces, global loading is fine and preferred as it happens once.
print("--- Starting Model Loading ---")

# --- TTS (TransformerTTS) checkpoint from the Hub ---
# Each model loads inside its own try/except so one failure does not block the
# others; on failure the corresponding globals are set to None so downstream
# code can detect the missing model.
try:
    print("Loading TTS model...")
    # Fetch the raw .pt checkpoint file from its repo.
    tts_model_path = hf_hub_download(
        repo_id=TTS_MODEL_HUB_ID,
        filename="train_SimpleTransfromerTTS.pt",
    )
    state = torch.load(tts_model_path, map_location=DEVICE)
    TTS_MODEL = TransformerTTS().to(DEVICE)
    # Checkpoints vary in layout: the weights may sit under "model",
    # under "state_dict", or the file may itself be the state_dict.
    if "model" in state:
        weights = state["model"]
    elif "state_dict" in state:
        weights = state["state_dict"]
    else:
        weights = state  # assume the whole file is the state_dict
    TTS_MODEL.load_state_dict(weights)
    TTS_MODEL.eval()
    print("TTS model loaded successfully.")
except Exception as e:
    print(f"Error loading TTS model: {e}")
    TTS_MODEL = None

# Load STT (Whisper) Model from Hub
try:
    print("Loading STT (Whisper) model...")
    stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
    stt_model = (
        WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID)
        .to(DEVICE)
        .eval()
    )
    print("STT model loaded successfully.")
except Exception as e:
    print(f"Error loading STT model: {e}")
    stt_processor = None
    stt_model = None

# Load TTT (MarianMT) Model from Hub
try:
    print("Loading TTT (MarianMT) model...")
    mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
    mt_model = (
        MarianMTModel.from_pretrained(MARIAN_HUB_ID)
        .to(DEVICE)
        .eval()
    )
    print("TTT model loaded successfully.")
except Exception as e:
    print(f"Error loading TTT model: {e}")
    mt_tokenizer = None
    mt_model = None
print("--- Model Loading Complete ---")