hoom4n commited on
Commit
49ddb63
·
verified ·
1 Parent(s): b179d8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -4,6 +4,7 @@ os.environ["KERAS_BACKEND"] ="tensorflow"
4
  import random
5
  import keras
6
  import gradio as gr
 
7
  from src.inference import generate_chorale, draw_random_sample
8
  from src.dataset import NoteEncoder
9
  from src.metrics import Preplexity
@@ -21,6 +22,7 @@ wav_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.wav")
21
 
22
 
23
  ### DOWNLOAD SF2 MUSIC FONT ###
 
24
  sf2_download_path = keras.utils.get_file(
25
  "FluidR3_GM.zip",
26
  "https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
@@ -29,15 +31,23 @@ sf2_download_path = keras.utils.get_file(
29
  cache_subdir= ""
30
  )
31
  SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
 
32
 
33
 
34
  ### LOAD MODEL & ENCODERS ###
 
 
35
  model = keras.models.load_model(os.path.join(MODEL_PATH, "bach_model.keras"),
36
  custom_objects={"Preplexity": Preplexity})
 
 
 
37
  note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)
 
38
 
39
 
40
  ### GRADIO ASSETS ###
 
41
  css = load_css()
42
  english_summary = load_markdown("english_summary")
43
  persian_summary = load_markdown("persian_summary")
@@ -45,6 +55,7 @@ english_help = load_markdown("english_help")
45
  persian_help = load_markdown("persian_help")
46
  english_title = "# BachNet: AI-Generated Bach Music"
47
  persian_title = "# باخ‌نت: خلق موسیقی مشابه باخ با هوش مصنوعی"
 
48
 
49
 
50
  ### GENERATION FUNCTIONS ###
@@ -53,7 +64,9 @@ def pick_random_seed():
53
 
54
  def generate_fn(seed_path, seed_len, gen_len, temp):
55
  sample_rows = slice(0, seed_len)
56
-
 
 
57
  generate_chorale(
58
  model=model,
59
  sample_seed_path=seed_path,
@@ -64,8 +77,14 @@ def generate_fn(seed_path, seed_len, gen_len, temp):
64
  temperature=temp,
65
  sample_seed_rows=sample_rows
66
  )
67
-
 
 
68
  midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
 
 
 
 
69
  return wav_path
70
 
71
  def set_english():
@@ -121,4 +140,4 @@ with gr.Blocks(css=css, title="BachNet") as demo:
121
 
122
  ### LAUNCH APP ###
123
  if __name__ == "__main__":
124
- demo.launch()
 
4
  import random
5
  import keras
6
  import gradio as gr
7
+ import time # <-- added
8
  from src.inference import generate_chorale, draw_random_sample
9
  from src.dataset import NoteEncoder
10
  from src.metrics import Preplexity
 
22
 
23
 
24
  ### DOWNLOAD SF2 MUSIC FONT ###
25
+ t0 = time.time()
26
  sf2_download_path = keras.utils.get_file(
27
  "FluidR3_GM.zip",
28
  "https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
 
31
  cache_subdir= ""
32
  )
33
  SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
34
+ print(f"SF2 font download/extract took {time.time()-t0:.2f}s")
35
 
36
 
37
  ### LOAD MODEL & ENCODERS ###
38
+ print("loading model...")
39
+ t0 = time.time()
40
  model = keras.models.load_model(os.path.join(MODEL_PATH, "bach_model.keras"),
41
  custom_objects={"Preplexity": Preplexity})
42
+ print(f"model loaded in {time.time()-t0:.2f}s")
43
+
44
+ t1 = time.time()
45
  note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)
46
+ print(f"NoteEncoder init took {time.time()-t1:.2f}s")
47
 
48
 
49
  ### GRADIO ASSETS ###
50
+ t2 = time.time()
51
  css = load_css()
52
  english_summary = load_markdown("english_summary")
53
  persian_summary = load_markdown("persian_summary")
 
55
  persian_help = load_markdown("persian_help")
56
  english_title = "# BachNet: AI-Generated Bach Music"
57
  persian_title = "# باخ‌نت: خلق موسیقی مشابه باخ با هوش مصنوعی"
58
+ print(f"Loaded CSS/markdown assets in {time.time()-t2:.2f}s")
59
 
60
 
61
  ### GENERATION FUNCTIONS ###
 
64
 
65
  def generate_fn(seed_path, seed_len, gen_len, temp):
66
  sample_rows = slice(0, seed_len)
67
+
68
+ print("=== Generation started ===")
69
+ t0 = time.time()
70
  generate_chorale(
71
  model=model,
72
  sample_seed_path=seed_path,
 
77
  temperature=temp,
78
  sample_seed_rows=sample_rows
79
  )
80
+ t1 = time.time()
81
+ print(f"generate_chorale took {t1 - t0:.2f}s")
82
+
83
  midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
84
+ t2 = time.time()
85
+ print(f"midi_to_wave took {t2 - t1:.2f}s")
86
+ print(f"Total generate_fn time {t2 - t0:.2f}s")
87
+
88
  return wav_path
89
 
90
  def set_english():
 
140
 
141
  ### LAUNCH APP ###
142
  if __name__ == "__main__":
143
+ demo.launch()