Fix predict_and_save function call - add required model parameter
Browse files- model_inference.py +13 -2
model_inference.py
CHANGED
|
@@ -34,6 +34,7 @@ def inferir_basic_pitch(input_file: str) -> str:
|
|
| 34 |
|
| 35 |
# Realizar predicci贸n y guardar archivo MIDI
|
| 36 |
predict_and_save(
|
|
|
|
| 37 |
audio_path_list=[input_file],
|
| 38 |
output_directory=BASE_MIDI_DIR,
|
| 39 |
save_midi=True,
|
|
@@ -43,13 +44,23 @@ def inferir_basic_pitch(input_file: str) -> str:
|
|
| 43 |
)
|
| 44 |
|
| 45 |
# Basic Pitch genera el archivo con un nombre espec铆fico
|
| 46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
generated_files = [f for f in os.listdir(BASE_MIDI_DIR) if f.endswith('.mid')]
|
| 48 |
if generated_files:
|
| 49 |
# Tomar el archivo m谩s reciente
|
| 50 |
latest_file = max([os.path.join(BASE_MIDI_DIR, f) for f in generated_files],
|
| 51 |
key=os.path.getctime)
|
| 52 |
-
print(f"Archivo MIDI
|
| 53 |
return latest_file
|
| 54 |
else:
|
| 55 |
print("No se gener贸 ning煤n archivo MIDI")
|
|
|
|
| 34 |
|
| 35 |
# Realizar predicci贸n y guardar archivo MIDI
|
| 36 |
predict_and_save(
|
| 37 |
+
model_or_model_path=ICASSP_2022_MODEL_PATH,
|
| 38 |
audio_path_list=[input_file],
|
| 39 |
output_directory=BASE_MIDI_DIR,
|
| 40 |
save_midi=True,
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
# Basic Pitch genera el archivo con un nombre espec铆fico
|
| 47 |
+
# El archivo se genera con el nombre del archivo original + "_basic_pitch.mid"
|
| 48 |
+
base_name = os.path.splitext(os.path.basename(input_file))[0]
|
| 49 |
+
expected_midi_name = f"{base_name}_basic_pitch.mid"
|
| 50 |
+
expected_midi_path = os.path.join(BASE_MIDI_DIR, expected_midi_name)
|
| 51 |
+
|
| 52 |
+
# Verificar si se gener贸 el archivo esperado
|
| 53 |
+
if os.path.exists(expected_midi_path):
|
| 54 |
+
print(f"Archivo MIDI generado: {expected_midi_path}")
|
| 55 |
+
return expected_midi_path
|
| 56 |
+
|
| 57 |
+
# Si no est谩 con el nombre esperado, buscar cualquier archivo .mid generado recientemente
|
| 58 |
generated_files = [f for f in os.listdir(BASE_MIDI_DIR) if f.endswith('.mid')]
|
| 59 |
if generated_files:
|
| 60 |
# Tomar el archivo m谩s reciente
|
| 61 |
latest_file = max([os.path.join(BASE_MIDI_DIR, f) for f in generated_files],
|
| 62 |
key=os.path.getctime)
|
| 63 |
+
print(f"Archivo MIDI encontrado: {latest_file}")
|
| 64 |
return latest_file
|
| 65 |
else:
|
| 66 |
print("No se gener贸 ning煤n archivo MIDI")
|