aherrasf commited on
Commit
228b6ea
·
1 Parent(s): d5ee34b

First Versión

Browse files
Files changed (2) hide show
  1. app.py +20 -4
  2. model_inference.py +51 -0
app.py CHANGED
@@ -1,7 +1,23 @@
 
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+ from model_inference import (
4
+ inferir_basic_pitch
5
+ )
6
 
7
+ def procesar_wav(input_wav_path: str):
8
+ # 1) Generar archivos midi
9
+ midi_path = inferir_basic_pitch(input_wav_path)
10
+
11
+ # Devolver la ruta del archivo MIDI generado
12
+ return midi_path
13
 
14
+ demo = gr.Interface(
15
+ fn=procesar_wav,
16
+ inputs=gr.Audio(label="Sube un stem (.wav", type="filepath"),
17
+ outputs=gr.File(label="Fichero Midi", type="file"),
18
+ title="Basic Pitch Inference",
19
+ description="Sube tu archivo de audio para generar el archivo midi correspondiente.",
20
+ )
21
+
22
+ if __name__ == "__main__":
23
+ demo.launch(server_name="0.0.0.0", server_port=7860)
model_inference.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from basic_pitch.inference import predict_and_save
4
+ from basic_pitch import ICASSP_2022_MODEL_PATH
5
+ import tensorflow as tf
6
+
7
+
8
+ # Suprime warnings de runtime (p.ej. invalid value encountered in divide)
9
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
10
+
11
+ # Carpeta raíz donde guardamos archivos midi
12
+ BASE_MIDI_DIR = "data/midi"
13
+
14
+
15
+ def inferir_basic_pitch(input_file: str) -> str:
16
+ """
17
+ Procesa un archivo de audio y genera un archivo MIDI usando Basic Pitch.
18
+
19
+ Args:
20
+ input_file: Ruta al archivo de audio de entrada
21
+
22
+ Returns:
23
+ Ruta al archivo MIDI generado
24
+ """
25
+ # Crear directorio de salida si no existe
26
+ os.makedirs(BASE_MIDI_DIR, exist_ok=True)
27
+
28
+ # Generar nombre del archivo MIDI basado en el archivo de entrada
29
+ base_name = os.path.splitext(os.path.basename(input_file))[0]
30
+ midi_path = os.path.join(BASE_MIDI_DIR, f"{base_name}_basic_pitch.mid")
31
+
32
+ try:
33
+ # Cargar modelo de Basic Pitch
34
+ basic_pitch_model = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
35
+
36
+ # Realizar predicción y guardar archivo MIDI
37
+ predict_and_save(
38
+ audio_path_list=[input_file],
39
+ output_directory=BASE_MIDI_DIR,
40
+ save_midi=True,
41
+ sonify_midi=False,
42
+ save_model_outputs=False,
43
+ save_notes=False,
44
+ )
45
+
46
+ return midi_path
47
+
48
+ except Exception as e:
49
+ print(f"Error durante la inferencia: {e}")
50
+ return None
51
+