ZDisket commited on
Commit
66d9b0c
·
1 Parent(s): 6291051

fix try catch

Browse files
Files changed (1) hide show
  1. app.py +20 -29
app.py CHANGED
@@ -50,37 +50,28 @@ if not os.path.exists(MUSICLSTM_PATH):
50
  if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")): # .ts or .pth for vocoder
51
  print(f"Warning: Vocoder model not found at {ISTFTNET_PATH}. (.ts or .pth)")
52
 
53
- try:
54
- pre_enc = get_pre_encoder(PREENC_PATH, device, channels=[192, 768, 1024, 1024], kernel_sizes=[3, 5, 7, 11],
55
  mel_channels=160)
56
 
57
- # Determine NUM_GENRES from genre_ids.csv if possible, otherwise use fixed
58
- # For now, using the fixed NUM_GENRES. The model needs to be trained with this.
59
- model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)
60
-
61
- # For MusicLSTM, the user code loads a checkpoint dictionary
62
- if os.path.exists(MUSICLSTM_PATH):
63
- chkp = torch.load(MUSICLSTM_PATH, map_location=device,
64
- weights_only=False) # Set weights_only based on your .pt file
65
- model.load_state_dict(chkp["model_state_dict"])
66
- else:
67
- print(f"MusicLSTM model file {MUSICLSTM_PATH} not found. Using uninitialized model.")
68
- model = model.to(device).eval()
69
-
70
- vocoder = ISTFTNetFE(None, None) # Adjust constructor if needed
71
-
72
- vocoder.load_ts(ISTFTNET_PATH, device) # load_ts might expect a directory/prefix
73
- vocoder = vocoder.to(device).eval()
74
-
75
- MODELS_LOADED = True
76
- print("Models loaded successfully (or placeholders initialized).")
77
-
78
- except Exception as e:
79
- print(f"Error loading models: {e}")
80
- print("Ensure model files are correctly placed and MQGAN classes are defined/imported.")
81
- MODELS_LOADED = False
82
- # Assign None to prevent further errors if Gradio tries to use them
83
- pre_enc, model, vocoder = None, None, None
84
 
85
 
86
  # --- Genre Loading ---
 
50
  if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")): # .ts or .pth for vocoder
51
  print(f"Warning: Vocoder model not found at {ISTFTNET_PATH}. (.ts or .pth)")
52
 
53
+
54
+ pre_enc = get_pre_encoder(PREENC_PATH, device, channels=[192, 768, 1024, 1024], kernel_sizes=[3, 5, 7, 11],
55
  mel_channels=160)
56
 
57
+ # Determine NUM_GENRES from genre_ids.csv if possible, otherwise use fixed
58
+ # For now, using the fixed NUM_GENRES. The model needs to be trained with this.
59
+ model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)
60
+
61
+ chkp = torch.load(MUSICLSTM_PATH, map_location=device, weights_only=False)
62
+ model.load_state_dict(chkp["model_state_dict"])
63
+
64
+ model = model.to(device).eval()
65
+
66
+ vocoder = ISTFTNetFE(None, None) # Adjust constructor if needed
67
+
68
+ vocoder.load_ts(ISTFTNET_PATH, device) # load_ts might expect a directory/prefix
69
+ vocoder = vocoder.to(device).eval()
70
+
71
+ MODELS_LOADED = True
72
+ print("Models loaded successfully (or placeholders initialized).")
73
+
74
+
 
 
 
 
 
 
 
 
 
75
 
76
 
77
  # --- Genre Loading ---