aherrasf commited on
Commit
1e7a231
1 Parent(s): f70a604

Fix TensorFlow model loading issues - add GPU disable and better error handling

Browse files
Files changed (2) hide show
  1. model_inference.py +48 -25
  2. requirements.txt +9 -6
model_inference.py CHANGED
@@ -1,9 +1,14 @@
1
  import os
2
  import warnings
 
3
  from basic_pitch.inference import predict_and_save
4
  from basic_pitch import ICASSP_2022_MODEL_PATH
5
  import tensorflow as tf
6
 
 
 
 
 
7
 
8
  # Suprime warnings de runtime (p.ej. invalid value encountered in divide)
9
  warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -25,37 +30,54 @@ def inferir_basic_pitch(input_file: str) -> str:
25
  # Crear directorio de salida si no existe
26
  os.makedirs(BASE_MIDI_DIR, exist_ok=True)
27
 
28
- # Generar nombre del archivo MIDI basado en el archivo de entrada
29
- base_name = os.path.splitext(os.path.basename(input_file))[0]
30
- midi_path = os.path.join(BASE_MIDI_DIR, f"{base_name}_basic_pitch.mid")
31
-
32
  try:
33
  print(f"Procesando archivo: {input_file}")
 
34
 
35
- # Realizar predicci贸n y guardar archivo MIDI
36
- predict_and_save(
37
- model_or_model_path=ICASSP_2022_MODEL_PATH,
38
- audio_path_list=[input_file],
39
- output_directory=BASE_MIDI_DIR,
40
- save_midi=True,
41
- sonify_midi=False,
42
- save_model_outputs=False,
43
- save_notes=False,
44
- )
45
 
46
- # Basic Pitch genera el archivo con un nombre espec铆fico
47
- # El archivo se genera con el nombre del archivo original + "_basic_pitch.mid"
48
- base_name = os.path.splitext(os.path.basename(input_file))[0]
49
- expected_midi_name = f"{base_name}_basic_pitch.mid"
50
- expected_midi_path = os.path.join(BASE_MIDI_DIR, expected_midi_name)
51
 
52
- # Verificar si se gener贸 el archivo esperado
53
- if os.path.exists(expected_midi_path):
54
- print(f"Archivo MIDI generado: {expected_midi_path}")
55
- return expected_midi_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Si no est谩 con el nombre esperado, buscar cualquier archivo .mid generado recientemente
58
  generated_files = [f for f in os.listdir(BASE_MIDI_DIR) if f.endswith('.mid')]
 
 
59
  if generated_files:
60
  # Tomar el archivo m谩s reciente
61
  latest_file = max([os.path.join(BASE_MIDI_DIR, f) for f in generated_files],
@@ -67,6 +89,7 @@ def inferir_basic_pitch(input_file: str) -> str:
67
  return None
68
 
69
  except Exception as e:
70
- print(f"Error durante la inferencia: {e}")
 
71
  return None
72
 
 
1
  import os
2
  import warnings
3
+ import tempfile
4
  from basic_pitch.inference import predict_and_save
5
  from basic_pitch import ICASSP_2022_MODEL_PATH
6
  import tensorflow as tf
7
 
8
+ # Configurar TensorFlow para evitar problemas de GPU
9
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Forzar uso de CPU
10
+ tf.config.set_visible_devices([], 'GPU') # Desabilitar GPU
11
+
12
 
13
  # Suprime warnings de runtime (p.ej. invalid value encountered in divide)
14
  warnings.filterwarnings("ignore", category=RuntimeWarning)
 
30
  # Crear directorio de salida si no existe
31
  os.makedirs(BASE_MIDI_DIR, exist_ok=True)
32
 
 
 
 
 
33
  try:
34
  print(f"Procesando archivo: {input_file}")
35
+ print(f"Directorio de salida: {BASE_MIDI_DIR}")
36
 
37
+ # Verificar que el archivo de entrada existe
38
+ if not os.path.exists(input_file):
39
+ print(f"Error: El archivo de entrada no existe: {input_file}")
40
+ return None
 
 
 
 
 
 
41
 
42
+ # Limpiar el directorio de salida antes del procesamiento
43
+ for f in os.listdir(BASE_MIDI_DIR):
44
+ if f.endswith('.mid'):
45
+ os.remove(os.path.join(BASE_MIDI_DIR, f))
 
46
 
47
+ # Realizar predicci贸n y guardar archivo MIDI
48
+ try:
49
+ predict_and_save(
50
+ model_or_model_path=ICASSP_2022_MODEL_PATH,
51
+ audio_path_list=[input_file],
52
+ output_directory=BASE_MIDI_DIR,
53
+ save_midi=True,
54
+ sonify_midi=False,
55
+ save_model_outputs=False,
56
+ save_notes=False,
57
+ )
58
+ print("Predicci贸n completada")
59
+ except Exception as model_error:
60
+ print(f"Error espec铆fico del modelo: {model_error}")
61
+ # Intentar con una configuraci贸n alternativa
62
+ try:
63
+ print("Intentando configuraci贸n alternativa...")
64
+ predict_and_save(
65
+ audio_path_list=[input_file],
66
+ output_directory=BASE_MIDI_DIR,
67
+ save_midi=True,
68
+ sonify_midi=False,
69
+ save_model_outputs=False,
70
+ save_notes=False,
71
+ )
72
+ print("Predicci贸n alternativa completada")
73
+ except Exception as alt_error:
74
+ print(f"Error en configuraci贸n alternativa: {alt_error}")
75
+ return None
76
 
77
+ # Buscar archivos MIDI generados
78
  generated_files = [f for f in os.listdir(BASE_MIDI_DIR) if f.endswith('.mid')]
79
+ print(f"Archivos generados: {generated_files}")
80
+
81
  if generated_files:
82
  # Tomar el archivo m谩s reciente
83
  latest_file = max([os.path.join(BASE_MIDI_DIR, f) for f in generated_files],
 
89
  return None
90
 
91
  except Exception as e:
92
+ print(f"Error general durante la inferencia: {e}")
93
+ print(f"Tipo de error: {type(e).__name__}")
94
  return None
95
 
requirements.txt CHANGED
@@ -1,6 +1,9 @@
1
- gradio
2
- basic-pitch
3
- tensorflow
4
- numpy
5
- librosa
6
- scipy
 
 
 
 
1
+ gradio==4.44.0
2
+ basic-pitch==0.3.0
3
+ tensorflow==2.13.0
4
+ numpy==1.24.3
5
+ librosa==0.10.1
6
+ scipy==1.11.4
7
+ resampy==0.4.2
8
+ pretty_midi==0.2.10
9
+ mido==1.3.0