sadafwalliyani commited on
Commit
84248c7
·
verified ·
1 Parent(s): 86169b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -66
app.py CHANGED
@@ -5,21 +5,19 @@ from audiocraft.models import MusicGen
5
  import os
6
  import numpy as np
7
  import base64
 
8
 
9
- genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
10
- "Lofi", "Chillpop","Country","R&G", "Folk","Heavy Metal",
11
- "EDM", "Soil", "Funk","Reggae", "Disco", "Punk Rock", "House",
12
- "Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock" ]
13
 
14
  @st.cache_resource()
15
  def load_model():
16
  model = MusicGen.get_pretrained('facebook/musicgen-small')
17
  return model
 
18
 
19
- def generate_music_tensors(descriptions, duration: int):
20
  model = load_model()
21
- # model = load_model().to('cpu')
22
-
23
 
24
  model.set_generation_params(
25
  use_sampling=True,
@@ -28,19 +26,23 @@ def generate_music_tensors(descriptions, duration: int):
28
  )
29
 
30
  with st.spinner("Generating Music..."):
31
- output = model.generate(
32
- descriptions=descriptions,
33
- progress=True,
34
- return_tokens=True
35
- )
 
 
 
 
36
 
37
  st.success("Music Generation Complete!")
38
  return output
 
39
 
40
-
41
- def save_audio(samples: torch.Tensor):
42
  sample_rate = 30000
43
- save_path = "audio_output"
44
  assert samples.dim() == 2 or samples.dim() == 3
45
 
46
  samples = samples.detach().cpu()
@@ -48,8 +50,9 @@ def save_audio(samples: torch.Tensor):
48
  samples = samples[None, ...]
49
 
50
  for idx, audio in enumerate(samples):
51
- audio_path = os.path.join(save_path, f"audio_{idx}.wav")
52
  torchaudio.save(audio_path, audio, sample_rate)
 
53
 
54
  def get_binary_file_downloader_html(bin_file, file_label='File'):
55
  with open(bin_file, 'rb') as f:
@@ -64,56 +67,48 @@ st.set_page_config(
64
  )
65
 
66
  def main():
67
- with st.sidebar:
68
- st.header("""⚙️Generate Music ⚙️""",divider="rainbow")
69
- st.text("")
70
- st.subheader("1. Enter your music description.......")
71
- bpm = st.number_input("Enter Speed in BPM", min_value=60)
72
-
73
- text_area = st.text_area('Ex : 80s rock song with guitar and drums')
74
- st.text('')
75
- # Dropdown for genres
76
- selected_genre = st.selectbox("Select Genre", genres)
77
-
78
- st.subheader("2. Select time duration (In Seconds)")
79
- time_slider = st.slider("Select time duration (In Seconds)", 0, 60, 10)
80
- # time_slider = st.slider("Select time duration (In Minutes)", 0,300,10, step=1)
81
-
82
-
83
- st.title("""🎵 Song Lab AI 🎵""")
84
- st.text('')
85
- left_co,right_co = st.columns(2)
86
- left_co.write("""Music Generation through a prompt""")
87
- left_co.write(("""PS : First generation may take some time ......."""))
88
 
89
- if st.sidebar.button('Generate !'):
90
- with left_co:
91
- st.text('')
92
- st.text('')
93
- st.text('')
94
- st.text('')
95
- st.text('')
96
- st.text('')
97
- st.text('\n\n')
98
- st.subheader("Generated Music")
99
-
100
- # Generate audio
101
- # descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(5)]
102
- descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(1)] # Change the batch size to 1
103
- music_tensors = generate_music_tensors(descriptions, time_slider)
104
-
105
- # Only play the full audio for index 0
106
- idx = 0
107
- music_tensor = music_tensors[idx]
108
- save_music_file = save_audio(music_tensor)
109
- audio_filepath = f'audio_output/audio_{idx}.wav'
110
- audio_file = open(audio_filepath, 'rb')
111
- audio_bytes = audio_file.read()
112
-
113
- # Play the full audio
114
- st.audio(audio_bytes, format='audio/wav')
115
- st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
116
-
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
- main()
 
 
5
  import os
6
  import numpy as np
7
  import base64
8
+ # from torch.nn.utils.parametrizations import weight_norm
9
 
10
+ genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
11
+ "Lofi", "Chillpop","Country","R&G", "Folk","EDM", "Disco", "House", "Techno",]
 
 
12
 
13
  @st.cache_resource()
14
  def load_model():
15
  model = MusicGen.get_pretrained('facebook/musicgen-small')
16
  return model
17
+
18
 
19
+ def generate_music_tensors(description, duration: int, batch_size=1):
20
  model = load_model()
 
 
21
 
22
  model.set_generation_params(
23
  use_sampling=True,
 
26
  )
27
 
28
  with st.spinner("Generating Music..."):
29
+ output = []
30
+ for i in range(0, len(description), batch_size):
31
+ batch_descriptions = description[i:i+batch_size]
32
+ batch_output = model.generate(
33
+ descriptions=batch_descriptions,
34
+ progress=True,
35
+ return_tokens=True
36
+ )
37
+ output.extend(batch_output)
38
 
39
  st.success("Music Generation Complete!")
40
  return output
41
+
42
 
43
+ def save_audio(samples: torch.Tensor, filename):
 
44
  sample_rate = 30000
45
+ save_path = "/content/drive/MyDrive/Colab Notebooks/audio_output"
46
  assert samples.dim() == 2 or samples.dim() == 3
47
 
48
  samples = samples.detach().cpu()
 
50
  samples = samples[None, ...]
51
 
52
  for idx, audio in enumerate(samples):
53
+ audio_path = os.path.join(save_path, f"{filename}_{idx}.wav")
54
  torchaudio.save(audio_path, audio, sample_rate)
55
+ return audio_path
56
 
57
  def get_binary_file_downloader_html(bin_file, file_label='File'):
58
  with open(bin_file, 'rb') as f:
 
67
  )
68
 
69
  def main():
70
+ st.title(" 🎶 AI Composer Small-Model 🎶")
71
+
72
+ st.subheader("Craft your perfect melody!")
73
+
74
+ bpm = st.number_input("Enter Speed in BPM", min_value=60)
75
+ text_area = st.text_area('Example: 80s rock song with guitar and drums', height=50)
76
+ selected_genre = st.selectbox("Select Genre (Optional)", genres, None)
77
+ time_slider = st.slider("Select time duration (In Seconds)", 0, 60, 10)
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
80
+ instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
81
+ tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
82
+ melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
83
+
84
+ if st.button('Let\'s Generate 🎶'):
85
+ st.text('\n\n')
86
+ st.subheader("Generated Music")
87
+
88
+ description = f"{text_area}"
89
+ if selected_genre:
90
+ description += f" {selected_genre}"
91
+ if bpm:
92
+ description += f" {bpm} BPM"
93
+ if mood:
94
+ description += f" {mood}"
95
+ if instrument:
96
+ description += f" {instrument}"
97
+ if tempo:
98
+ description += f" {tempo}"
99
+ if melody:
100
+ description += f" {melody}"
101
+
102
+ music_tensors = generate_music_tensors(description, time_slider)
103
+
104
+ idx = 0
105
+ audio_path = save_audio(music_tensors[idx], "audio_output")
106
+ audio_file = open(audio_path, 'rb')
107
+ audio_bytes = audio_file.read()
108
+
109
+ st.audio(audio_bytes, format='audio/wav')
110
+ st.markdown(get_binary_file_downloader_html(audio_path, f'Audio_{idx}'), unsafe_allow_html=True)
111
 
112
  if __name__ == "__main__":
113
+ main()
114
+