aherrasf commited on
Commit
f70a604
1 Parent(s): 0645fbe

Fix predict_and_save function call - add required model parameter

Browse files
Files changed (1) hide show
  1. model_inference.py +13 -2
model_inference.py CHANGED
@@ -34,6 +34,7 @@ def inferir_basic_pitch(input_file: str) -> str:
34
 
35
  # Realizar predicci贸n y guardar archivo MIDI
36
  predict_and_save(
 
37
  audio_path_list=[input_file],
38
  output_directory=BASE_MIDI_DIR,
39
  save_midi=True,
@@ -43,13 +44,23 @@ def inferir_basic_pitch(input_file: str) -> str:
43
  )
44
 
45
  # Basic Pitch genera el archivo con un nombre espec铆fico
46
- # Necesitamos encontrar el archivo MIDI generado
 
 
 
 
 
 
 
 
 
 
47
  generated_files = [f for f in os.listdir(BASE_MIDI_DIR) if f.endswith('.mid')]
48
  if generated_files:
49
  # Tomar el archivo m谩s reciente
50
  latest_file = max([os.path.join(BASE_MIDI_DIR, f) for f in generated_files],
51
  key=os.path.getctime)
52
- print(f"Archivo MIDI generado: {latest_file}")
53
  return latest_file
54
  else:
55
  print("No se gener贸 ning煤n archivo MIDI")
 
34
 
35
  # Realizar predicci贸n y guardar archivo MIDI
36
  predict_and_save(
37
+ model_or_model_path=ICASSP_2022_MODEL_PATH,
38
  audio_path_list=[input_file],
39
  output_directory=BASE_MIDI_DIR,
40
  save_midi=True,
 
44
  )
45
 
46
  # Basic Pitch genera el archivo con un nombre espec铆fico
47
+ # El archivo se genera con el nombre del archivo original + "_basic_pitch.mid"
48
+ base_name = os.path.splitext(os.path.basename(input_file))[0]
49
+ expected_midi_name = f"{base_name}_basic_pitch.mid"
50
+ expected_midi_path = os.path.join(BASE_MIDI_DIR, expected_midi_name)
51
+
52
+ # Verificar si se gener贸 el archivo esperado
53
+ if os.path.exists(expected_midi_path):
54
+ print(f"Archivo MIDI generado: {expected_midi_path}")
55
+ return expected_midi_path
56
+
57
+ # Si no est谩 con el nombre esperado, buscar cualquier archivo .mid generado recientemente
58
  generated_files = [f for f in os.listdir(BASE_MIDI_DIR) if f.endswith('.mid')]
59
  if generated_files:
60
  # Tomar el archivo m谩s reciente
61
  latest_file = max([os.path.join(BASE_MIDI_DIR, f) for f in generated_files],
62
  key=os.path.getctime)
63
+ print(f"Archivo MIDI encontrado: {latest_file}")
64
  return latest_file
65
  else:
66
  print("No se gener贸 ning煤n archivo MIDI")