sadafwalliyani commited on
Commit
332cbcf
·
verified ·
1 Parent(s): 84248c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -45
app.py CHANGED
@@ -1,22 +1,26 @@
1
  import streamlit as st
2
  import torch
3
  import torchaudio
4
- from audiocraft.models import MusicGen
5
  import os
6
  import numpy as np
7
  import base64
8
- # from torch.nn.utils.parametrizations import weight_norm
 
 
 
 
 
 
 
9
 
10
- genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
11
- "Lofi", "Chillpop","Country","R&G", "Folk","EDM", "Disco", "House", "Techno",]
12
 
13
  @st.cache_resource()
14
  def load_model():
15
  model = MusicGen.get_pretrained('facebook/musicgen-small')
16
  return model
17
-
18
 
19
- def generate_music_tensors(description, duration: int, batch_size=1):
20
  model = load_model()
21
 
22
  model.set_generation_params(
@@ -26,23 +30,18 @@ def generate_music_tensors(description, duration: int, batch_size=1):
26
  )
27
 
28
  with st.spinner("Generating Music..."):
29
- output = []
30
- for i in range(0, len(description), batch_size):
31
- batch_descriptions = description[i:i+batch_size]
32
- batch_output = model.generate(
33
- descriptions=batch_descriptions,
34
- progress=True,
35
- return_tokens=True
36
- )
37
- output.extend(batch_output)
38
 
39
  st.success("Music Generation Complete!")
40
  return output
41
-
42
 
43
- def save_audio(samples: torch.Tensor, filename):
44
  sample_rate = 30000
45
- save_path = "/content/drive/MyDrive/Colab Notebooks/audio_output"
46
  assert samples.dim() == 2 or samples.dim() == 3
47
 
48
  samples = samples.detach().cpu()
@@ -50,9 +49,8 @@ def save_audio(samples: torch.Tensor, filename):
50
  samples = samples[None, ...]
51
 
52
  for idx, audio in enumerate(samples):
53
- audio_path = os.path.join(save_path, f"{filename}_{idx}.wav")
54
  torchaudio.save(audio_path, audio, sample_rate)
55
- return audio_path
56
 
57
  def get_binary_file_downloader_html(bin_file, file_label='File'):
58
  with open(bin_file, 'rb') as f:
@@ -67,48 +65,62 @@ st.set_page_config(
67
  )
68
 
69
  def main():
70
- st.title(" 🎶 AI Composer Small-Model 🎶")
71
 
72
  st.subheader("Craft your perfect melody!")
73
-
74
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
75
- text_area = st.text_area('Example: 80s rock song with guitar and drums', height=50)
76
- selected_genre = st.selectbox("Select Genre (Optional)", genres, None)
77
- time_slider = st.slider("Select time duration (In Seconds)", 0, 60, 10)
78
-
79
- mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
80
- instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
81
- tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
82
- melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
 
 
 
 
83
 
84
  if st.button('Let\'s Generate 🎶'):
85
  st.text('\n\n')
86
  st.subheader("Generated Music")
87
-
88
- description = f"{text_area}"
 
89
  if selected_genre:
90
  description += f" {selected_genre}"
 
91
  if bpm:
92
  description += f" {bpm} BPM"
93
- if mood:
94
- description += f" {mood}"
95
- if instrument:
96
- description += f" {instrument}"
97
- if tempo:
98
- description += f" {tempo}"
99
- if melody:
100
- description += f" {melody}"
 
 
 
 
 
 
101
 
102
  music_tensors = generate_music_tensors(description, time_slider)
103
 
 
104
  idx = 0
105
- audio_path = save_audio(music_tensors[idx], "audio_output")
106
- audio_file = open(audio_path, 'rb')
 
 
107
  audio_bytes = audio_file.read()
108
 
 
109
  st.audio(audio_bytes, format='audio/wav')
110
- st.markdown(get_binary_file_downloader_html(audio_path, f'Audio_{idx}'), unsafe_allow_html=True)
111
 
112
  if __name__ == "__main__":
113
- main()
114
-
 
1
  import streamlit as st
2
  import torch
3
  import torchaudio
 
4
  import os
5
  import numpy as np
6
  import base64
7
+ from audiocraft.models import MusicGen
8
+
9
+ # Before
10
+ batch_size = 64
11
+
12
+ # After
13
+ batch_size = 32
14
+ torch.cuda.empty_cache()
15
 
16
+ genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
 
17
 
18
  @st.cache_resource()
19
  def load_model():
20
  model = MusicGen.get_pretrained('facebook/musicgen-small')
21
  return model
 
22
 
23
+ def generate_music_tensors(description, duration: int):
24
  model = load_model()
25
 
26
  model.set_generation_params(
 
30
  )
31
 
32
  with st.spinner("Generating Music..."):
33
+ output = model.generate(
34
+ descriptions=description,
35
+ progress=True,
36
+ return_tokens=True
37
+ )
 
 
 
 
38
 
39
  st.success("Music Generation Complete!")
40
  return output
 
41
 
42
+ def save_audio(samples: torch.Tensor):
43
  sample_rate = 30000
44
+ save_path = "audio_output"
45
  assert samples.dim() == 2 or samples.dim() == 3
46
 
47
  samples = samples.detach().cpu()
 
49
  samples = samples[None, ...]
50
 
51
  for idx, audio in enumerate(samples):
52
+ audio_path = os.path.join(save_path, f"audio_{idx}.wav")
53
  torchaudio.save(audio_path, audio, sample_rate)
 
54
 
55
  def get_binary_file_downloader_html(bin_file, file_label='File'):
56
  with open(bin_file, 'rb') as f:
 
65
  )
66
 
67
  def main():
68
+ st.title("🎧 AI Composer Medium-Model 🎧")
69
 
70
  st.subheader("Craft your perfect melody!")
 
71
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
72
+
73
+ text_area = st.text_area('Ex : 80s rock song with guitar and drums')
74
+ st.text('')
75
+ # Dropdown for genres
76
+ selected_genre = st.selectbox("Select Genre", genres)
77
+
78
+ st.subheader("2. Select time duration (In Seconds)")
79
+ time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)
80
+ # mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
81
+ # instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
82
+ # tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
83
+ # melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
84
 
85
  if st.button('Let\'s Generate 🎶'):
86
  st.text('\n\n')
87
  st.subheader("Generated Music")
88
+
89
+ # Generate audio
90
+ description = text_area # Initialize description with text_area
91
  if selected_genre:
92
  description += f" {selected_genre}"
93
+ st.empty() # Hide the selected_genre selectbox after selecting one option
94
  if bpm:
95
  description += f" {bpm} BPM"
96
+ # if mood:
97
+ # description += f" {mood}"
98
+ # st.empty() # Hide the mood selectbox after selecting one option
99
+ # if instrument:
100
+ # description += f" {instrument}"
101
+ # st.empty() # Hide the instrument selectbox after selecting one option
102
+ # if tempo:
103
+ # description += f" {tempo}"
104
+ # st.empty() # Hide the tempo selectbox after selecting one option
105
+ # if melody:
106
+ # description += f" {melody}"
107
+
108
+ # Clear CUDA memory cache before generating music
109
+ torch.cuda.empty_cache()
110
 
111
  music_tensors = generate_music_tensors(description, time_slider)
112
 
113
+ # Only play the full audio for index 0
114
  idx = 0
115
+ music_tensor = music_tensors[idx]
116
+ save_audio(music_tensor)
117
+ audio_filepath = f'/audio_output/audio_{idx}.wav'
118
+ audio_file = open(audio_filepath, 'rb')
119
  audio_bytes = audio_file.read()
120
 
121
+ # Play the full audio
122
  st.audio(audio_bytes, format='audio/wav')
123
+ st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
124
 
125
  if __name__ == "__main__":
126
+ main()