Update app.py
Browse files
app.py
CHANGED
|
@@ -6,18 +6,15 @@ from pathlib import Path
|
|
| 6 |
os.system("pip install gsutil")
|
| 7 |
|
| 8 |
|
| 9 |
-
|
| 10 |
-
os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
|
| 11 |
-
os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
|
| 12 |
-
os.system("python3 -m pip install -e .")
|
| 13 |
-
os.system("python3 -m pip install --upgrade pip")
|
| 14 |
|
| 15 |
|
| 16 |
|
| 17 |
# install mt3
|
| 18 |
os.system("git clone --branch=main https://github.com/magenta/mt3")
|
| 19 |
os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
|
| 20 |
-
os.system("python3 -m pip install -e .
|
|
|
|
| 21 |
|
| 22 |
# copy checkpoints
|
| 23 |
os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
|
|
@@ -35,18 +32,13 @@ import functools
|
|
| 35 |
import os
|
| 36 |
|
| 37 |
import numpy as np
|
| 38 |
-
|
| 39 |
import tensorflow.compat.v2 as tf
|
| 40 |
|
| 41 |
import functools
|
| 42 |
import gin
|
| 43 |
-
import jax
|
| 44 |
-
jax.extend.linear_util = jax.linear_util
|
| 45 |
import librosa
|
| 46 |
import note_seq
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
import seqio
|
| 51 |
import t5
|
| 52 |
import t5x
|
|
@@ -59,6 +51,7 @@ from mt3 import preprocessors
|
|
| 59 |
from mt3 import spectrograms
|
| 60 |
from mt3 import vocabularies
|
| 61 |
|
|
|
|
| 62 |
|
| 63 |
import nest_asyncio
|
| 64 |
nest_asyncio.apply()
|
|
@@ -66,9 +59,12 @@ nest_asyncio.apply()
|
|
| 66 |
SAMPLE_RATE = 16000
|
| 67 |
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
|
| 68 |
|
| 69 |
-
def upload_audio(
|
|
|
|
|
|
|
|
|
|
| 70 |
return note_seq.audio_io.wav_data_to_samples_librosa(
|
| 71 |
-
|
| 72 |
|
| 73 |
|
| 74 |
|
|
@@ -89,16 +85,16 @@ class InferenceModel(object):
|
|
| 89 |
else:
|
| 90 |
raise ValueError('unknown model_type: %s' % model_type)
|
| 91 |
|
| 92 |
-
gin_files = ['/
|
| 93 |
-
'/
|
| 94 |
|
| 95 |
self.batch_size = 8
|
| 96 |
self.outputs_length = 1024
|
| 97 |
-
self.sequence_length = {'inputs': self.inputs_length,
|
| 98 |
'targets': self.outputs_length}
|
| 99 |
|
| 100 |
self.partitioner = t5x.partitioning.PjitPartitioner(
|
| 101 |
-
|
| 102 |
|
| 103 |
# Build Codecs and Vocabularies.
|
| 104 |
self.spectrogram_config = spectrograms.SpectrogramConfig()
|
|
@@ -187,9 +183,10 @@ class InferenceModel(object):
|
|
| 187 |
|
| 188 |
def __call__(self, audio):
|
| 189 |
"""Infer note sequence from audio samples.
|
| 190 |
-
|
| 191 |
Args:
|
| 192 |
audio: 1-d numpy array of audio samples (16kHz) for a single example.
|
|
|
|
| 193 |
Returns:
|
| 194 |
A note_sequence of the transcribed audio.
|
| 195 |
"""
|
|
|
|
| 6 |
os.system("pip install gsutil")
|
| 7 |
|
| 8 |
|
| 9 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
# install mt3
|
| 14 |
os.system("git clone --branch=main https://github.com/magenta/mt3")
|
| 15 |
os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
|
| 16 |
+
os.system("python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
| 17 |
+
")
|
| 18 |
|
| 19 |
# copy checkpoints
|
| 20 |
os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
|
|
|
|
| 32 |
import os
|
| 33 |
|
| 34 |
import numpy as np
|
|
|
|
| 35 |
import tensorflow.compat.v2 as tf
|
| 36 |
|
| 37 |
import functools
|
| 38 |
import gin
|
| 39 |
+
import jax
|
|
|
|
| 40 |
import librosa
|
| 41 |
import note_seq
|
|
|
|
|
|
|
|
|
|
| 42 |
import seqio
|
| 43 |
import t5
|
| 44 |
import t5x
|
|
|
|
| 51 |
from mt3 import spectrograms
|
| 52 |
from mt3 import vocabularies
|
| 53 |
|
| 54 |
+
from google.colab import files
|
| 55 |
|
| 56 |
import nest_asyncio
|
| 57 |
nest_asyncio.apply()
|
|
|
|
| 59 |
SAMPLE_RATE = 16000
|
| 60 |
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
|
| 61 |
|
| 62 |
+
def upload_audio(sample_rate):
|
| 63 |
+
data = list(files.upload().values())
|
| 64 |
+
if len(data) > 1:
|
| 65 |
+
print('Multiple files uploaded; using only one.')
|
| 66 |
return note_seq.audio_io.wav_data_to_samples_librosa(
|
| 67 |
+
data[0], sample_rate=sample_rate)
|
| 68 |
|
| 69 |
|
| 70 |
|
|
|
|
| 85 |
else:
|
| 86 |
raise ValueError('unknown model_type: %s' % model_type)
|
| 87 |
|
| 88 |
+
gin_files = ['/content/mt3/gin/model.gin',
|
| 89 |
+
f'/content/mt3/gin/{model_type}.gin']
|
| 90 |
|
| 91 |
self.batch_size = 8
|
| 92 |
self.outputs_length = 1024
|
| 93 |
+
self.sequence_length = {'inputs': self.inputs_length,
|
| 94 |
'targets': self.outputs_length}
|
| 95 |
|
| 96 |
self.partitioner = t5x.partitioning.PjitPartitioner(
|
| 97 |
+
num_partitions=1)
|
| 98 |
|
| 99 |
# Build Codecs and Vocabularies.
|
| 100 |
self.spectrogram_config = spectrograms.SpectrogramConfig()
|
|
|
|
| 183 |
|
| 184 |
def __call__(self, audio):
|
| 185 |
"""Infer note sequence from audio samples.
|
| 186 |
+
|
| 187 |
Args:
|
| 188 |
audio: 1-d numpy array of audio samples (16kHz) for a single example.
|
| 189 |
+
|
| 190 |
Returns:
|
| 191 |
A note_sequence of the transcribed audio.
|
| 192 |
"""
|