# Hugging Face Space: OpenAI Jukebox music-generation demo.
# (The deployed Space currently shows "Runtime error" — see fixes below.)
import os

import gradio as gr
import torch as t
from transformers import JukeboxModel, JukeboxTokenizer
from transformers.models.jukebox import convert_jukebox
# --- Generation configuration -------------------------------------------
model_id = 'openai/jukebox-1b-lyrics'  #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
sample_rate = 44100                    # audio sample rate, Hz
total_duration_in_seconds = 200        # total length of the piece being generated
raw_to_tokens = 128                    # raw audio samples per VQ-VAE token at level 2
chunk_size = 32                        # sampling chunk size (tokens)
max_batch_size = 16                    # largest batch fed to the prior at once
cache_path = '~/.cache/'               # root of the local model cache
def tokens_to_seconds(tokens, level=2, sr=None, rtt=None):
    """Convert a token count at the given VQ-VAE level to seconds of audio.

    Args:
        tokens: number of codebook tokens.
        level: VQ-VAE level (2 = top/coarsest). Levels 1 and 0 carry 4x and
            16x more tokens per second, hence the ``4 ** (2 - level)`` divisor.
        sr: sample rate in Hz; defaults to the module-level ``sample_rate``.
        rtt: raw audio samples per level-2 token; defaults to the
            module-level ``raw_to_tokens``.

    Returns:
        Duration in seconds (float).
    """
    # `global` is unnecessary for read-only access; the module constants are
    # now overridable per call while remaining the defaults.
    sr = sample_rate if sr is None else sr
    rtt = raw_to_tokens if rtt is None else rtt
    return tokens * rtt / sr / 4 ** (2 - level)
def seconds_to_tokens(sec, level=2, sr=None, rtt=None, chunk=None):
    """Convert a duration in seconds to a token count at the given VQ-VAE level.

    The count is rounded up to a whole number of sampling chunks at level 2,
    then scaled for finer levels.

    Args:
        sec: duration in seconds.
        level: VQ-VAE level (2 = top/coarsest). Levels 1 and 0 get 4x and 16x
            more tokens respectively.
        sr: sample rate in Hz; defaults to the module-level ``sample_rate``.
        rtt: raw audio samples per level-2 token; defaults to the
            module-level ``raw_to_tokens``.
        chunk: sampling chunk size; defaults to the module-level ``chunk_size``.

    Returns:
        Token count as an int.
    """
    # `global` is unnecessary for read-only access; the module constants are
    # now overridable per call while remaining the defaults.
    sr = sample_rate if sr is None else sr
    rtt = raw_to_tokens if rtt is None else rtt
    chunk = chunk_size if chunk is None else chunk
    tokens = sec * sr // rtt
    # Round up to the next whole chunk.
    tokens = (tokens // chunk + 1) * chunk
    # Levels 1 and 0 hold 4x and 16x more tokens than level 2.
    tokens *= 4 ** (2 - level)
    return int(tokens)
# Init is run on server startup.
# Loads the model to GPU as the global variable `model`.
def init():
    """Load the pretrained Jukebox model into the global ``model`` (fp16, auto device map)."""
    global model
    # BUG FIX: `from_pretrained` uses `cache_dir` literally, so an unexpanded
    # '~' would create a directory literally named '~' in the CWD.  Expand it
    # (and build the path with os.path.join to avoid the double slash the
    # original f-string produced).
    cache_dir = os.path.join(os.path.expanduser(cache_path), "jukebox", "models")
    print(f"Loading model from/to {cache_path}...")
    model = JukeboxModel.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=t.float16,
        cache_dir=cache_dir,
        resume_download=True,  # NOTE(review): deprecated in recent transformers; harmless for now
        min_duration=0,
    ).eval()
    print("Model loaded: ", model)
# Inference is run for every server call.
# References the preloaded global `model`.
def inference(artist, genres, lyrics):
    """Sample ~1 second of Jukebox tokens for the given metadata.

    Args:
        artist: artist name string.
        genres: genre string (space-separated genres).
        lyrics: lyrics string.

    Returns:
        numpy array of the sampled tokens at the sampled level.
    """
    global model, zs
    n_samples = 4
    generation_length = seconds_to_tokens(1)
    offset = 0
    level = 0
    model.total_length = seconds_to_tokens(total_duration_in_seconds)
    sampling_kwargs = dict(
        temp=0.98,
        chunk_size=chunk_size,
    )
    metas = dict(
        artist=artist,
        genres=genres,
        lyrics=lyrics,
    )
    # NOTE(review): the tokenizer is re-downloaded/re-built on every request;
    # consider caching it alongside the model in init().
    labels = JukeboxTokenizer.from_pretrained(model_id)(**metas)['input_ids'][level].repeat(n_samples, 1).cuda()
    print(f"Labels: {labels.shape}")
    # One empty token tensor per VQ-VAE level.
    zs = [t.zeros(n_samples, 0, dtype=t.long, device='cuda') for _ in range(3)]
    print(f"Zs: {[z.shape for z in zs]}")
    zs = model.sample_partial_window(
        zs, labels, offset, sampling_kwargs, level=level,
        tokens_to_sample=generation_length, max_batch_size=max_batch_size,
    )
    print(f"Zs after sampling: {[z.shape for z in zs]}")
    # BUG FIX: `zs` is a *list* of tensors (one per level); the original called
    # .cpu() on the list itself, which raised AttributeError at runtime.
    # Convert the tensor of the level we actually sampled.
    return zs[level].cpu().numpy()
with gr.Blocks() as ui:
    # Define UI components.
    # NOTE(review): `title` is collected but never wired into `inference`.
    title = gr.Textbox(lines=1, label="Title")
    artist = gr.Textbox(lines=1, label="Artist")
    genres = gr.Textbox(lines=1, label="Genre(s)", placeholder="Separate with spaces")
    lyrics = gr.Textbox(lines=5, label="Lyrics", placeholder="Shift+Enter for new line")
    # BUG FIX: gr.Button has no `label` kwarg — its caption is the first
    # positional `value` argument.
    submit = gr.Button("Generate")
    output_zs = gr.Dataframe(label="zs")
    submit.click(
        inference,
        inputs=[artist, genres, lyrics],
        outputs=output_zs,
    )
if __name__ == "__main__":
    init()
    # BUG FIX: the gradio module has no `launch()` function — launch the
    # Blocks app built above.
    ui.launch()