{ "cells": [ { "cell_type": "markdown", "id": "70300319-d206-43ce-b3bf-3da6b079f20f", "metadata": { "id": "70300319-d206-43ce-b3bf-3da6b079f20f" }, "source": [ "## MusicGen in 🤗 Transformers\n", "\n", "**by [Sanchit Gandhi](https://huggingface.co/sanchit-gandhi)**\n", "\n", "MusicGen is a Transformer-based model capable of generating high-quality music samples conditioned on text descriptions or audio prompts. It was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet et al. from Meta AI.\n", "\n", "The MusicGen model can be decomposed into three distinct stages:\n", "1. The text descriptions are passed through a frozen text encoder model to obtain a sequence of hidden-state representations\n", "2. The MusicGen decoder is then trained to predict discrete audio tokens, or *audio codes*, conditioned on these hidden-states\n", "3. These audio tokens are then decoded using an audio compression model, such as EnCodec, to recover the audio waveform\n", "\n", "The pre-trained MusicGen checkpoints use Google's [t5-base](https://huggingface.co/t5-base) as the text encoder model, and [EnCodec 32kHz](https://huggingface.co/facebook/encodec_32khz) as the audio compression model. The MusicGen decoder is a pure language model architecture,\n", "trained from scratch on the task of music generation.\n", "\n", "The novelty in the MusicGen model is how the audio codes are predicted. Traditionally, each codebook has to be predicted by a separate model (i.e. hierarchically) or by continuously refining the output of the Transformer model (i.e. upsampling). MusicGen uses an efficient *token interleaving pattern*, thus eliminating the need to cascade multiple models to predict a set of codebooks. Instead, it is able to generate the full set of codebooks in a single forward pass of the decoder, resulting in much faster inference.\n", "\n", "
\n", "\n", "\n", "**Figure 1:** Codebook delay pattern used by MusicGen. Figure taken from the [MusicGen paper](https://arxiv.org/abs/2306.05284).\n" ] }, { "cell_type": "markdown", "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77", "metadata": { "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77" }, "source": [ "## Prepare the Environment" ] }, { "cell_type": "markdown", "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c", "metadata": { "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c" }, "source": [ "Let’s make sure we’re connected to a GPU to run this notebook. To get a GPU, click `Runtime` -> `Change runtime type`, then change `Hardware accelerator` from `None` to `GPU`. We can verify that we’ve been assigned a GPU and view its specifications through the `nvidia-smi` command:" ] }, { "cell_type": "code", "execution_count": 1, "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a", "outputId": "bcb3fd96-cc9c-45fc-a0ce-2fc79398cedc", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Thu Sep 7 22:36:00 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 1 NVIDIA A100-SXM... 
On | 00000000:10:1D.0 Off | 0 |\n", "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", "| N/A 43C P0 60W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", "| N/A 44C P0 57W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", "| N/A 43C P0 58W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| No running processes found |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd", "metadata": { "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd" }, "source": [ "We see here that we've been assigned NVIDIA A100 GPUs, although the hardware may vary for you depending on GPU availability and Colab GPU assignment.\n", "\n", "Next, we install the 🤗 Transformers package from the main branch, as well as the 🤗 Datasets package to load audio files for audio-prompted generation:" ] }, { "cell_type": "markdown", "id": "77ee39cc-654b-4f0e-b601-013e484c16f0", "metadata": { "id": "77ee39cc-654b-4f0e-b601-013e484c16f0" }, "source": [ "## Load the Model\n", "\n", "The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id to match the checkpoint size you wish to load. 
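For reference, the three sizes map to the following Hub repo ids. This small sketch is illustrative only (the `CHECKPOINTS` dictionary and `size` variable are not part of the notebook's code); only the repo id strings come from the Hub:

```python
# Available MusicGen checkpoint repo ids on the Hugging Face Hub,
# ordered from fastest (lowest audio quality) to slowest (highest quality).
CHECKPOINTS = {
    "small": "facebook/musicgen-small",
    "medium": "facebook/musicgen-medium",
    "large": "facebook/musicgen-large",
}

size = "small"  # change to "medium" or "large" for higher-quality audio
repo_id = CHECKPOINTS[size]
print(repo_id)  # facebook/musicgen-small
```
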
We'll default to the small checkpoint, which is the fastest of the three but has the lowest audio quality:" ] }, { "cell_type": "code", "execution_count": 48, "id": "b0d87424-9f38-4658-ba47-2a465d52ad77", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 113, "referenced_widgets": [ "013191f7a16c49fdbfc9b6cb3b0aa089", "1c58ed64d7144bcf9c58fe0f89364d61", "ce32b8aaae2f4cd38d5ec78fefaa34ce", "7ed4f572d3534679a3e1e4d90880bd71", "fd741db4538f493588106d753a747593", "0aa3f4c09c854d90b9357d356e8ed46b", "bbf7a9706dd64f6eb24d4b46ff52bc23", "7f52d5efbb064170ac8d8681ae92f29b", "881fd640878c420cbc14dfd5f8516953", "f74afcde56c64ab389fed5bf7f5964b8", "09bbff09eaa44cdf92612c7f02f05f63", "0cd8b835ee724b659c92fbee8eb62327", "50323b27b7ff4823b733bffa97930218", "abcf0eeed63f4331bbc9a044b2d3d65f", "2fa893516c3946729fa0eda21550f658", "7742549b504c4e71b4d0a2119d459936", "6081dfc2472b4097b01bc7608c100ca3", "a810f25c280d48539d7382064588efd7", "b93f98b1284b4cf8bc325e9c4d9a2bf1", "a43852c0c4754235a7cf3fb7221eed1d", "410cc9979cc44f5aad48d17cf6042165", "768bf7c8a965476c99693c9aa3ea89cb", "f938b2a799d5454f8885d5d6bba24b94", "a2a0abc232c04d8c96704d6b012227aa", "4d191abe24514b6a8a5bcb7aa4dd624c", "9350625ba0fe43949ea335a27f8e402d", "e9895aee370a4d888cefa7a82cd90c00", "dbcc2550fed34fc4b9d1f9b5e7465b85", "aca7c88abc684793880a4619257522c4", "f83972e889134265aff6737844971e6f", "9fa6a2c3f81e4a159cf652d8a48acd37", "92804f5fa9ef49a094ce3d54b051ff9c", "9ac4c059958446b1af03da2fe2e0ac20" ] }, "id": "b0d87424-9f38-4658-ba47-2a465d52ad77", "outputId": "2af6a8e9-ff4c-4f2d-8aa2-89f35f0575e7", "tags": [] }, "outputs": [], "source": [ "from transformers import MusicgenForConditionalGeneration\n", "cache_dir = '/fsx/proj-fmri/ckadirt/b2m/cache/'\n", "model = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\", cache_dir=cache_dir)" ] }, { "cell_type": "markdown", "id": "4981d112-407c-4120-86aa-5c6a170543f7", "metadata": { "id": 
"4981d112-407c-4120-86aa-5c6a170543f7" }, "source": [ "We can then place the model on our accelerator device (if available), or leave it on the CPU otherwise:" ] }, { "cell_type": "code", "execution_count": 49, "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97", "metadata": { "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97", "tags": [] }, "outputs": [], "source": [ "import torch\n", "\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "model.to(device);" ] }, { "cell_type": "code", "execution_count": 4, "id": "ba5cdeee-27d0-4834-b7dc-4403864550f7", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Thu Sep 7 22:37:19 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", "| N/A 45C P0 78W / 400W | 3593MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 3 NVIDIA A100-SXM... 
On | 00000000:20:1D.0 Off | 0 |\n", "| N/A 43C P0 60W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", "| N/A 44C P0 57W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n", "| N/A 43C P0 58W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| 0 N/A N/A 742410 C ...3/envs/mindeye/bin/python 3590MiB |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547", "metadata": { "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547" }, "source": [ "## Generation\n", "\n", "MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly\n", "better results than greedy, thus we encourage sampling mode to be used where possible. 
Sampling is enabled by default,\n", "and can be explicitly specified by setting `do_sample=True` in the call to `MusicgenForConditionalGeneration.generate` (see below).\n", "\n", "### Unconditional Generation\n", "\n", "The inputs for unconditional (or 'null') generation can be obtained through the method `MusicgenForConditionalGeneration.get_unconditional_inputs`. We can then run auto-regressive generation using the `.generate` method, specifying `do_sample=True` to enable sampling mode:" ] }, { "cell_type": "code", "execution_count": 5, "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2", "metadata": { "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2", "tags": [] }, "outputs": [], "source": [ "unconditional_inputs = model.get_unconditional_inputs(num_samples=1)\n", "\n", "audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)" ] }, { "cell_type": "markdown", "id": "94cb74df-c194-4d2e-930a-12473b08a919", "metadata": { "id": "94cb74df-c194-4d2e-930a-12473b08a919" }, "source": [ "The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. 
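The indexing conventions for this output (batch axis first, then channel) can be sketched with NumPy stand-ins. The concrete numbers below are assumptions for illustration: a mono model such as `musicgen-small` and 256 generated tokens, which decode to roughly 5.12 s of audio at 32 kHz (163840 samples):

```python
import numpy as np

# Stand-in for the generate() output: (batch_size, num_channels, sequence_length).
batch_size, num_channels, sequence_length = 1, 1, 163840
audio_values = np.zeros((batch_size, num_channels, sequence_length), dtype=np.float32)

# For playback with IPython.display.Audio, drop the batch axis first:
sample = audio_values[0]
print(sample.shape)  # (1, 163840)

# For writing a mono .wav with scipy, drop the channel axis as well:
mono = audio_values[0, 0]
print(mono.shape)  # (163840,)
```
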
To listen\n", "to the generated audio samples, you can either play them directly in a Jupyter notebook:" ] }, { "cell_type": "code", "execution_count": 6, "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea", "outputId": "9a47adc3-17b1-40d6-989d-73319c9ea7ee", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import Audio\n", "\n", "sampling_rate = model.config.audio_encoder.sampling_rate\n", "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" ] }, { "cell_type": "markdown", "id": "6de58334-40f7-4924-addb-2d6ff34c0590", "metadata": { "id": "6de58334-40f7-4924-addb-2d6ff34c0590" }, "source": [ "Or save them as a `.wav` file using a third-party library, e.g. `scipy` (note here that we also need to remove the channel dimension from our audio tensor):" ] }, { "cell_type": "code", "execution_count": 7, "id": "04291f52-0a75-4ddb-9eff-e853d0f17288", "metadata": { "id": "04291f52-0a75-4ddb-9eff-e853d0f17288" }, "outputs": [], "source": [ "import scipy\n", "\n", "scipy.io.wavfile.write(\"musicgen_out.wav\", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())" ] }, { "cell_type": "markdown", "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39", "metadata": { "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39" }, "source": [ "The argument `max_new_tokens` specifies the number of new tokens to generate. 
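To pick `max_new_tokens` for a target clip length, you can invert the frame-rate relationship. This is a sketch: the 50 Hz frame rate is the usual value for the EnCodec 32 kHz model and is an assumption here; in practice, read it from `model.config.audio_encoder.frame_rate`:

```python
# EnCodec 32 kHz runs at 50 frames per second, so each generated token
# corresponds to 1/50 s of audio (assumed here; read the exact value from
# model.config.audio_encoder.frame_rate once the model is loaded).
frame_rate = 50

def max_new_tokens_for(seconds: float, frame_rate: int = frame_rate) -> int:
    """Tokens to request from generate() for roughly `seconds` of audio."""
    return round(seconds * frame_rate)

print(max_new_tokens_for(5.12))  # 256
print(max_new_tokens_for(10.0))  # 500
```
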
As a rule of thumb, you can work out the length of the generated audio sample in seconds by using the frame rate of the EnCodec model:" ] }, { "cell_type": "code", "execution_count": 7, "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a", "outputId": "c021297b-837b-4d35-e01b-34a66a33c1dd", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "5.12" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "audio_length_in_s = 256 / model.config.audio_encoder.frame_rate\n", "\n", "audio_length_in_s" ] }, { "cell_type": "markdown", "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581", "metadata": { "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581" }, "source": [ "### Text-Conditional Generation\n", "\n", "The model can generate an audio sample conditioned on a text prompt through use of the `MusicgenProcessor` to pre-process\n", "the inputs. The pre-processed inputs can then be passed to the `.generate` method to generate text-conditional audio samples.\n", "Again, we enable sampling mode by setting `do_sample=True`:" ] }, { "cell_type": "code", "execution_count": 59, "id": "5fba4154-13f6-403a-958b-101d6eacfb6e", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "id": "5fba4154-13f6-403a-958b-101d6eacfb6e", "outputId": "a4b87d4b-1db0-49af-dac0-eb600c47cbfb", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import AutoProcessor\n", "from einops import rearrange\n", "\n", "processor = AutoProcessor.from_pretrained(\"facebook/musicgen-small\")\n", "\n", "inputs = processor(\n", " text=[\"hiphop beat\", \"90s rock song with loud guitars and heavy drums\"],\n", " padding=True,\n", " return_tensors=\"pt\",\n", ")\n", "\n", "audio_values = 
model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n", "\n", "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" ] }, { "cell_type": "code", "execution_count": 60, "id": "_FZb_Zo-Dajl", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_FZb_Zo-Dajl", "outputId": "9834a745-ac10-4142-e22f-6998997304e5", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 2, 4, 253])\n" ] } ], "source": [ "with torch.no_grad():\n", " tokens_train = model.audio_encoder(audio_values)\n", " print(tokens_train.audio_codes.shape)\n", " tokens_on_format = rearrange(tokens_train.audio_codes, \"n b c l -> (n b c) l\")\n", " # make a copy of the tokens, detached from the computation graph\n", " tokens_on_format = tokens_on_format.detach().clone()" ] }, { "cell_type": "code", "execution_count": 61, "id": "VHyalHz78TlY", "metadata": { "id": "VHyalHz78TlY", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 13, 768])\n" ] } ], "source": [ "with torch.no_grad():\n", " encoder_hidden_states = model.text_encoder(inputs.input_ids, attention_mask = inputs.attention_mask.detach().clone()).last_hidden_state.detach().clone()\n", " print(encoder_hidden_states.shape)\n" ] }, { "cell_type": "code", "execution_count": 63, "id": "dOnCaS3F9yfz", "metadata": { "id": "dOnCaS3F9yfz", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 13, 1024])\n" ] } ], "source": [ "with torch.no_grad():\n", " encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states).detach().clone()\n", " print(encoder_hidden_states.shape)" ] }, { "cell_type": "code", "execution_count": 66, "id": "z56a5Yjqoum5", "metadata": { "id": "z56a5Yjqoum5", "tags": [] }, "outputs": [], "source": [ "pad_token_id = model.generation_config.pad_token_id\n", "decoder_input_ids = (\n", " torch.ones((inputs.input_ids.shape[0] * 
model.decoder.num_codebooks, 1), dtype=torch.long)\n", " * pad_token_id\n", ").to(device).detach().clone()\n", "\n", "inputs.attention_mask = inputs.attention_mask.detach().clone()" ] }, { "cell_type": "code", "execution_count": 67, "id": "xLrDtoYhqBiP", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xLrDtoYhqBiP", "outputId": "78909eb6-af29-4655-d473-86359093c9d7", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 1])" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "decoder_input_ids.shape" ] }, { "cell_type": "code", "execution_count": 107, "id": "b67b9b42-fd04-4b0c-ac36-a5b7016ada16", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 253])" ] }, "execution_count": 107, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens_on_format.shape" ] }, { "cell_type": "code", "execution_count": 71, "id": "hlKZ7dJgH1VJ", "metadata": { "id": "hlKZ7dJgH1VJ", "tags": [] }, "outputs": [], "source": [ "results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)" ] }, { "cell_type": "code", "execution_count": 72, "id": "62b31feb-e392-4b5c-9855-25d97364256b", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 252, 2048])" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_manual.logits.shape" ] }, { "cell_type": "code", "execution_count": 73, "id": "a5fa2fd1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2008, 2048]) torch.Size([2008])\n" ] } ], "source": [ "def training_step(labels, results_manual):\n", " # as the labels are the same as the input, we should remove the first column\n", " labels = labels[:, 1:]\n", " # as the logits are the same as the input, we should remove the last column\n", " results_manual.logits = results_manual.logits[:, :-1, :]\n", " loss_fct = torch.nn.CrossEntropyLoss()\n", " print(rearrange(results_manual.logits, \"c t v -> (c t) v\").shape, rearrange(labels, \"c t -> (c t)\").shape)\n", " loss = loss_fct(rearrange(results_manual.logits, \"c t v -> (c t) v\"), rearrange(labels, \"c t -> (c t)\"))\n", " return loss\n", "\n", "# labels mirror the decoder inputs, i.e. the codes with the final token dropped\n", "tokens_on_format_test = tokens_on_format[:, :-1]\n", "loss = training_step(tokens_on_format_test, results_manual)" ] }, { "cell_type": "code", "execution_count": 74, "id": "44f38b39-c06b-4e50-b233-ef6b6f829de5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor(10.8952, device='cuda:0', grad_fn=)" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss" ] }, { "cell_type": "code", "execution_count": 75, "id": "4022c586-423e-489e-9647-c0c24a378092", "metadata": { "tags": [] }, "outputs": [], "source": [ "model.zero_grad()" ] }, { "cell_type": "code", "execution_count": 35, "id": "06335817-59a3-47b2-a2df-10d71acdc3c4", "metadata": { "tags": [] }, "outputs": [], "source": [ "tokens_on_format_test_tt = tokens_on_format_test.detach()" ] }, { "cell_type": "code", "execution_count": 97, "id": "82367433", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3841, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(6.5876, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(5.2953, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(3.9745, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(3.5427, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(3.1808, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.9803, device='cuda:0', 
grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.8768, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.7334, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.6114, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.5443, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.4862, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.4340, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.3890, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.3461, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.3014, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.2524, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.2108, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.1693, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.1228, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.0805, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.0398, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(2.0003, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.9617, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.9265, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.8898, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.8527, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.8144, 
device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.7804, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.7465, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.7152, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.6817, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.6494, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.6182, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5891, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5631, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5393, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5514, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5996, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5329, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5090, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.4846, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.4633, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.4849, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3862, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.4540, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3665, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3937, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", 
"tensor(1.3255, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3280, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2869, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2772, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2436, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2343, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2054, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1852, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1706, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1331, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1386, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1027, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0897, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0732, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1328, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.4694, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3952, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2203, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2310, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1787, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.2025, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) 
torch.Size([2008])\n", "tensor(1.1772, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1184, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0861, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0596, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0484, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0140, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0278, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9599, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9769, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9378, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9275, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9106, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.8895, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.8961, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9111, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1137, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5037, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.5237, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3981, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1750, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.3016, device='cuda:0', grad_fn=)\n", 
"torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1642, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.1759, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0935, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0816, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0367, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(1.0183, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9768, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9370, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.9176, device='cuda:0', grad_fn=)\n", "torch.Size([2008, 2048]) torch.Size([2008])\n", "tensor(0.8828, device='cuda:0', grad_fn=)\n" ] } ], "source": [ "# move all the stuff to the CPU\n", "#encoder_hidden_states = encoder_hidden_states.cpu()\n", "#inputs.attention_mask = inputs.attention_mask.cpu()\n", "#decoder_attention_mask = decoder_attention_mask.cpu()\n", "#tokens_on_format_test = tokens_on_format_test.cpu()\n", "#results_manual.logits = results_manual.logits.cpu()\n", "#model = model.cpu()\n", "\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)\n", "for i in range(100):\n", " optimizer.zero_grad()\n", " results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)\n", " loss = training_step(tokens_on_format[:,:-1].detach().clone(), results_manual)\n", " loss.backward()\n", " optimizer.step()\n", " \n", " \n", " print(loss)" ] }, { "cell_type": "code", "execution_count": 98, "id": "c9870eac-70ca-4ff5-a4cf-2138bef89de1", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ 
"tensor([ 578, 68, 244, 753, 135, 135, 40, 244, 1358, 83, 1265, 91,\n", " 555, 816, 236, 235, 235, 244, 135, 135, 609, 244, 609, 135,\n", " 1081, 480, 549, 14, 327, 613, 244, 244, 83, 135, 66, 244,\n", " 472, 403, 472, 68, 1757, 753, 1348, 289, 609, 289, 609, 289,\n", " 609, 539, 578, 472, 83, 609, 66, 609, 66, 609, 289, 609,\n", " 289, 289, 403, 570, 23, 116, 14, 376, 609, 1757, 135, 1358,\n", " 83, 40, 1205, 578, 444, 68, 135, 1358, 289, 83, 289, 1265,\n", " 1757, 289, 289, 1081, 1036, 68, 1453, 1265, 1265, 1265, 289, 8,\n", " 376, 376, 289, 778, 403, 444, 68, 1265, 1757, 1453, 376, 376,\n", " 376, 289, 609, 1265, 539, 434, 334, 1986, 83, 244, 235, 244,\n", " 1265, 609, 244, 1265, 1453, 172, 1036, 68, 83, 1453, 1453, 609,\n", " 1453, 1265, 1265, 1265, 1265, 1205, 555, 588, 116, 235, 244, 244,\n", " 172, 275, 68, 376, 753, 135, 376, 986, 1348, 609, 8, 609,\n", " 609, 609, 609, 289, 609, 609, 778, 403, 1360, 68, 1265, 21,\n", " 1358, 1453, 40, 1453, 289, 1453, 289, 91, 578, 226, 1265, 1265,\n", " 376, 1265, 289, 1453, 289, 244, 135, 244, 403, 1036, 425, 609,\n", " 244, 1043, 244, 1348, 135, 1348, 244, 1043, 1488, 403, 425, 40,\n", " 83, 244, 254, 40, 135, 1265, 135, 40, 244, 979, 759, 834,\n", " 327, 135, 289, 235, 235, 1358, 327, 289, 244, 778, 403, 1953,\n", " 1617, 244, 83, 14, 135, 1358, 244, 40, 135, 83, 539, 578,\n", " 68, 83, 83, 1453, 1265, 135, 40, 289, 1453, 609, 1453, 403,\n", " 275], device='cuda:0')" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens_on_format[0]" ] }, { "cell_type": "code", "execution_count": 99, "id": "61b8b0e2-8185-4801-9ea3-8b05820ed062", "metadata": { "tags": [] }, "outputs": [], "source": [ "results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)" ] }, { "cell_type": "code", "execution_count": 106, "id": 
"2a1cfaa7-707c-4ed1-8a74-14d595a2d7bf", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 135, 244, 235, 609, 1265, 40, 1358, 68, 83, 816, 236, 555,\n", " 549, 753, 91, 1081, 480, 14, 327, 613, 116, 23, 1757, 1379,\n", " 21, 289, 588, 10, 376, 2007, 148, 146, 789, 1602, 425, 429,\n", " 570, 1132, 172, 1522, 403, 1488, 1846, 1043, 8, 66, 902, 254,\n", " 539, 1106, 472, 374, 1453, 1348, 1671, 226, 778, 1195, 1360, 275,\n", " 1334, 759, 1875, 949, 993, 1194, 444, 1796, 97, 1676, 1187, 1304,\n", " 1953, 492, 1698, 578, 979, 1036, 1775, 434, 1210, 1617, 15, 149,\n", " 276, 1487, 1540, 986, 1398, 721, 243, 4, 1748, 131, 1986, 1544,\n", " 223, 227, 1109, 1815], device='cuda:0')\n" ] } ], "source": [ "max_idx = torch.topk(results_manual.logits[0][9], k=100)[1]\n", "print(max_idx)" ] }, { "cell_type": "code", "execution_count": 34, "id": "UV1XbuCRONUA", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UV1XbuCRONUA", "outputId": "0b846e0f-180d-4bd6-946f-fb29665479d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1, 2048])\n", "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n", " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n" ] } ], "source": [ "print(results_manual.logits.shape)\n", "#results_manual.past_key_values.shape\n", "print(results_manual.logits[1][0][0:10])" ] }, { "cell_type": "code", "execution_count": null, "id": "YpFaaCefZ4mp", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YpFaaCefZ4mp", "outputId": "c49266f8-4602-47ae-94c9-de472661ba5b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2048\n" ] } ], "source": [ "inputs = processor(\n", " text=[\"michael jackson signing pop\", \"90s rock song with loud guitars and heavy drums\"],\n", " padding=True,\n", " return_tensors=\"pt\",\n", ").to('cuda')\n", "\n", "pad_token_id = model.generation_config.pad_token_id\n", 
"print(pad_token_id)\n", "decoder_input_ids = (\n", " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n", " * pad_token_id\n", ").to('cuda')\n", "\n", "end = model(**inputs, decoder_input_ids=decoder_input_ids)\n", "#logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)" ] }, { "cell_type": "code", "execution_count": null, "id": "Fdr_sVdLQF3m", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Fdr_sVdLQF3m", "outputId": "d4f55f93-e76f-4a90-9253-b6fb51a9add5" }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[2278, 9, 15, 40, 3, 9325, 739, 8097, 2783, 1, 0, 0,\n", " 0],\n", " [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7,\n", " 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n", " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs" ] }, { "cell_type": "code", "execution_count": null, "id": "FyTh1KykP1OT", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FyTh1KykP1OT", "outputId": "036edf58-44b1-40b8-cbad-137ddad5be9b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1, 2048])\n", "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n", " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n" ] } ], "source": [ "print(end.logits.shape)\n", "print(end.logits[1][0][0:10])" ] }, { "cell_type": "code", "execution_count": null, "id": "h_0zNWeSIXxC", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "h_0zNWeSIXxC", "outputId": "a7a3fe57-ceca-4e35-b8a6-7c870fcfdb54" }, "outputs": [ { "data": { "text/plain": [ "tensor([[2048],\n", " [2048],\n", " [2048],\n", " [2048],\n", " [2048],\n", " [2048],\n", " [2048],\n", " [2048]], device='cuda:0')" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "print(end.logits[1][1][1:10])" ] }, { "cell_type": "code", "execution_count": null, "id": "-1y5EZojYzDz", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-1y5EZojYzDz", "outputId": "5db3d81c-0a18-4386-92c7-b20354e2c902" }, "outputs": [ { "data": { "text/plain": [ "odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state'])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "end.keys()" ] }, { "cell_type": "code", "execution_count": null, "id": "tL-0dkVkbeka", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tL-0dkVkbeka", "outputId": "2838c153-9344-4ebc-96fb-98331e351e72" }, "outputs": [ { "data": { "text/plain": [ "24" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(end.past_key_values)" ] }, { "cell_type": "code", "execution_count": null, "id": "2X62iV8iVNRU", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "id": "2X62iV8iVNRU", "outputId": "6b97e3ff-f8b7-40a1-9139-d473354b115c" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" ] }, { "cell_type": "code", "execution_count": null, "id": "2xlbMjFTTUBd", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2xlbMjFTTUBd", "outputId": "c0228562-0a8e-4198-9171-027a5d49e8fc" }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 2775, 7, 2783, 1463, 28, 7981, 63, 5253, 7, 11,\n", " 13353, 1, 0],\n", " [ 2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437,\n", " 5253, 7, 1]], device='cuda:0')" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs.input_ids" ] }, { "cell_type": "markdown", "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0", "metadata": { "id": 
"4851a94c-ae02-41c9-b1dd-c1422ba34dc0" }, "source": [ "The `guidance_scale` is used in classifier-free guidance (CFG), setting the weighting between the conditional logits\n", "(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or\n", "'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input\n", "prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,\n", "use `guidance_scale=3` (the default) for both text- and audio-conditioned generation." ] }, { "cell_type": "markdown", "id": "d391b2a1-6376-4b69-b562-4388b731cf60", "metadata": { "id": "d391b2a1-6376-4b69-b562-4388b731cf60" }, "source": [ "### Audio-Prompted Generation\n", "\n", "The same `MusicgenProcessor` can be used to pre-process an audio prompt for audio continuation. In the\n", "following example, we load an audio file using the 🤗 Datasets library, pre-process it using the processor class,\n", "and then forward the inputs to the model for generation:" ] }, { "cell_type": "code", "execution_count": null, "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8", "outputId": "81c95bfd-649d-424f-a764-0f1b881f37dd" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"sanchit-gandhi/gtzan\", split=\"train\", streaming=True)\n", "sample = next(iter(dataset))[\"audio\"]\n", "\n", "# take the first half of the audio sample\n", "sample[\"array\"] = sample[\"array\"][: len(sample[\"array\"]) // 2]\n", "\n", "inputs = processor(\n", "    audio=sample[\"array\"],\n", "    sampling_rate=sample[\"sampling_rate\"],\n", "    text=[\"80s blues track with groovy saxophone\"],\n", "    padding=True,\n", "    return_tensors=\"pt\",\n", ")\n", "\n", "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n", "\n", "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" ] }, { "cell_type": "markdown", "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc", "metadata": { "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc" }, "source": [ "To demonstrate batched audio-prompted generation, we'll slice our sample audio by two different proportions to give two audio samples of different lengths.\n", "Since the input audio prompts vary in length, they will be *padded* to the length of the longest audio sample in the batch before being passed to the model.\n", "\n", "To recover the final audio samples, the generated `audio_values` can be post-processed to remove padding by using the processor class once again:" ] }, { "cell_type": "code", "execution_count": null, "id": "5495f568-51ca-439d-b47b-8b52e89b78f1", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "id": "5495f568-51ca-439d-b47b-8b52e89b78f1", "outputId": "68866570-000b-4239-bed5-75bc21e828aa" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample = next(iter(dataset))[\"audio\"]\n", "\n", "# take the first quarter of the audio sample\n", "sample_1 = sample[\"array\"][: len(sample[\"array\"]) // 4]\n", "\n", "# take the first half of the audio sample\n", "sample_2 = sample[\"array\"][: len(sample[\"array\"]) // 2]\n", "\n", "inputs = processor(\n", "    audio=[sample_1, sample_2],\n", "    sampling_rate=sample[\"sampling_rate\"],\n", "    text=[\"80s blues track with groovy saxophone\", \"90s rock song with loud guitars and heavy drums\"],\n", "    padding=True,\n", "    return_tensors=\"pt\",\n", ")\n", "\n", "audio_values = model.generate(**inputs.to(device),
do_sample=True, guidance_scale=3, max_new_tokens=256)\n", "\n", "# post-process to remove padding from the batched audio\n", "audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)\n", "\n", "Audio(audio_values[0], rate=sampling_rate)" ] }, { "cell_type": "markdown", "id": "viwTDmzl8ZDN", "metadata": { "id": "viwTDmzl8ZDN" }, "source": [ "## Generation Config\n", "\n", "The default parameters that control the generation process, such as sampling, guidance scale and number of generated tokens, can be found in the model's generation config, and updated as desired. Let's first inspect the default generation config:" ] }, { "cell_type": "code", "execution_count": null, "id": "0zM4notb8Y1g", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0zM4notb8Y1g", "outputId": "576755e5-42b9-48bc-be13-76322837fea0" }, "outputs": [ { "data": { "text/plain": [ "GenerationConfig {\n", " \"_from_model_config\": true,\n", " \"bos_token_id\": 2048,\n", " \"decoder_start_token_id\": 2048,\n", " \"do_sample\": true,\n", " \"guidance_scale\": 3.0,\n", " \"max_length\": 1500,\n", " \"pad_token_id\": 2048,\n", " \"transformers_version\": \"4.34.0.dev0\"\n", "}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.generation_config" ] }, { "cell_type": "markdown", "id": "DLSnSwau8jyW", "metadata": { "id": "DLSnSwau8jyW" }, "source": [ "Alright! We see that the model defaults to using sampling mode (`do_sample=True`), a guidance scale of 3, and a maximum generation length of 1500 (which is equivalent to 30s of audio). 
You can update any of these attributes to change the default generation parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "ensSj1IB81dA", "metadata": { "id": "ensSj1IB81dA" }, "outputs": [], "source": [ "# increase the guidance scale to 4.0\n", "model.generation_config.guidance_scale = 4.0\n", "\n", "# set the max new tokens to 256\n", "model.generation_config.max_new_tokens = 256\n", "\n", "# set the softmax sampling temperature to 1.5\n", "model.generation_config.temperature = 1.5" ] }, { "cell_type": "markdown", "id": "UjqGnfc-9ZFJ", "metadata": { "id": "UjqGnfc-9ZFJ" }, "source": [ "Re-running generation now will use the newly defined values in the generation config:" ] }, { "cell_type": "code", "execution_count": null, "id": "KAExrhDl9YvS", "metadata": { "id": "KAExrhDl9YvS" }, "outputs": [], "source": [ "audio_values = model.generate(**inputs.to(device))" ] }, { "cell_type": "markdown", "id": "HdGdoGAs84hS", "metadata": { "id": "HdGdoGAs84hS" }, "source": [ "Note that any arguments passed to the generate method **supersede** those in the generation config: passing `do_sample=False` in the call to generate overrides the `model.generation_config.do_sample` setting for that call."
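As a minimal sketch of this precedence (assuming only that the `transformers` library is installed), `generate` can be thought of as applying the call-time kwargs to a copy of the stored config, so the stored defaults are never mutated:

```python
from transformers import GenerationConfig

# the stored defaults, mirroring model.generation_config above
stored = GenerationConfig(do_sample=True, guidance_scale=3.0)

# generate() merges call-time kwargs into a copy of the stored config
effective = GenerationConfig(**stored.to_dict())
effective.update(do_sample=False)

assert effective.do_sample is False  # the call-time value wins for this call
assert stored.do_sample is True      # the stored default is untouched
```

This is why `do_sample=False` in a single `generate` call does not change the sampling behaviour of later calls.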
] }, { "cell_type": "code", "execution_count": null, "id": "s__neSDH89q0", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "s__neSDH89q0", "outputId": "8ec39d3c-ac4e-4cce-ee89-1e936a453859" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 642560])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "audio_values.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "EOMbP5f-imWD", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "EOMbP5f-imWD", "outputId": "df158b75-bcc9-4304-db5c-ddb0d1490c49" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'import os\\nimport torchaudio\\nimport numpy as np\\nimport torch\\nfrom tqdm import tqdm'" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"import os\n", "import torchaudio\n", "import numpy as np\n", "import torch\n", "from tqdm import tqdm\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "knVkOKRxhMAJ", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 192 }, "id": "knVkOKRxhMAJ", "outputId": "6b6d0063-70bf-4bd8-822a-5c2f2b6363dc" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'from datasets import load_dataset, Audio\\nfrom transformers import EncodecModel, AutoProcessor\\n\\n\\n# load a demonstration datasets\\nlibrispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\\n\\n# load the model + processor (for pre-processing the audio)\\nmodel = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\\nprocessor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\\n\\npath = \\'/content/Stim_Test_Run01_01_rock.wav\\'\\naudio_loaded, sr = torchaudio.load(path)\\nprint(audio_loaded.shape)\\naudio_loaded 
= torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\\n#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\\nprint(audio_loaded.shape)\\n# cast the audio data to the correct sampling rate for the model\\n#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\\n#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\\n\\n# pre-process the inputs\\nprint(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\\ninputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\\n\\n# explicitly encode then decode the audio inputs\\nencoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\\naudio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\\n\\n# or the equivalent with a forward pass\\naudio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\\n'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"from datasets import load_dataset, Audio\n", "from transformers import EncodecModel, AutoProcessor\n", "\n", "\n", "# load a demonstration datasets\n", "librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n", "\n", "# load the model + processor (for pre-processing the audio)\n", "model = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\n", "processor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\n", "\n", "path = '/content/Stim_Test_Run01_01_rock.wav'\n", "audio_loaded, sr = torchaudio.load(path)\n", "print(audio_loaded.shape)\n", "audio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\n", "#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\n", 
"print(audio_loaded.shape)\n", "# cast the audio data to the correct sampling rate for the model\n", "#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n", "#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\n", "\n", "# pre-process the inputs\n", "print(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\n", "inputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\n", "\n", "# explicitly encode then decode the audio inputs\n", "encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n", "audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\n", "\n", "# or the equivalent with a forward pass\n", "audio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "FZ-JP_DFsD_6", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FZ-JP_DFsD_6", "outputId": "e462f0f5-8267-449f-e8b3-c11a0a7dd593" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 642560])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "audio_values.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "F-b6Q4TSjGcc", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 175 }, "id": "F-b6Q4TSjGcc", "outputId": "cfe3e974-ca58-48de-833f-12b1d8e4e1b8" }, "outputs": [ { "ename": "NameError", "evalue": "ignored", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m 
\u001b[0mencoder_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maudio_codes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mNameError\u001b[0m: name 'encoder_outputs' is not defined" ] } ], "source": [ "encoder_outputs.audio_codes.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "Wv4xZ1gYlU6S", "metadata": { "id": "Wv4xZ1gYlU6S" }, "outputs": [], "source": [ "encoder_outputs.audio_codes.shape" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "013191f7a16c49fdbfc9b6cb3b0aa089": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1c58ed64d7144bcf9c58fe0f89364d61", "IPY_MODEL_ce32b8aaae2f4cd38d5ec78fefaa34ce", "IPY_MODEL_7ed4f572d3534679a3e1e4d90880bd71" ], "layout": "IPY_MODEL_fd741db4538f493588106d753a747593" } }, "09bbff09eaa44cdf92612c7f02f05f63": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "0aa3f4c09c854d90b9357d356e8ed46b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0cd8b835ee724b659c92fbee8eb62327": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_50323b27b7ff4823b733bffa97930218", "IPY_MODEL_abcf0eeed63f4331bbc9a044b2d3d65f", "IPY_MODEL_2fa893516c3946729fa0eda21550f658" ], "layout": "IPY_MODEL_7742549b504c4e71b4d0a2119d459936" } }, "1c58ed64d7144bcf9c58fe0f89364d61": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": 
"HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0aa3f4c09c854d90b9357d356e8ed46b", "placeholder": "​", "style": "IPY_MODEL_bbf7a9706dd64f6eb24d4b46ff52bc23", "value": "Downloading (…)lve/main/config.json: 100%" } }, "2fa893516c3946729fa0eda21550f658": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_410cc9979cc44f5aad48d17cf6042165", "placeholder": "​", "style": "IPY_MODEL_768bf7c8a965476c99693c9aa3ea89cb", "value": " 2.36G/2.36G [00:22<00:00, 235MB/s]" } }, "410cc9979cc44f5aad48d17cf6042165": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, 
"margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4d191abe24514b6a8a5bcb7aa4dd624c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f83972e889134265aff6737844971e6f", "max": 224, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_9fa6a2c3f81e4a159cf652d8a48acd37", "value": 224 } }, "50323b27b7ff4823b733bffa97930218": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6081dfc2472b4097b01bc7608c100ca3", "placeholder": "​", "style": "IPY_MODEL_a810f25c280d48539d7382064588efd7", "value": "Downloading model.safetensors: 100%" } }, "6081dfc2472b4097b01bc7608c100ca3": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", 
"align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "768bf7c8a965476c99693c9aa3ea89cb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "7742549b504c4e71b4d0a2119d459936": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": 
null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7ed4f572d3534679a3e1e4d90880bd71": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f74afcde56c64ab389fed5bf7f5964b8", "placeholder": "​", "style": "IPY_MODEL_09bbff09eaa44cdf92612c7f02f05f63", "value": " 7.87k/7.87k [00:00<00:00, 288kB/s]" } }, "7f52d5efbb064170ac8d8681ae92f29b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, 
"overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "881fd640878c420cbc14dfd5f8516953": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "92804f5fa9ef49a094ce3d54b051ff9c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9350625ba0fe43949ea335a27f8e402d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_92804f5fa9ef49a094ce3d54b051ff9c", "placeholder": "​", "style": "IPY_MODEL_9ac4c059958446b1af03da2fe2e0ac20", "value": " 224/224 [00:00<00:00, 13.9kB/s]" } }, "9ac4c059958446b1af03da2fe2e0ac20": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "9fa6a2c3f81e4a159cf652d8a48acd37": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a2a0abc232c04d8c96704d6b012227aa": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_dbcc2550fed34fc4b9d1f9b5e7465b85", "placeholder": "​", "style": "IPY_MODEL_aca7c88abc684793880a4619257522c4", "value": "Downloading (…)neration_config.json: 100%" } }, "a43852c0c4754235a7cf3fb7221eed1d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", 
"model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a810f25c280d48539d7382064588efd7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "abcf0eeed63f4331bbc9a044b2d3d65f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b93f98b1284b4cf8bc325e9c4d9a2bf1", "max": 2364427288, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a43852c0c4754235a7cf3fb7221eed1d", "value": 2364427288 } }, "aca7c88abc684793880a4619257522c4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b93f98b1284b4cf8bc325e9c4d9a2bf1": { "model_module": "@jupyter-widgets/base", 
"model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bbf7a9706dd64f6eb24d4b46ff52bc23": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ce32b8aaae2f4cd38d5ec78fefaa34ce": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, 
"layout": "IPY_MODEL_7f52d5efbb064170ac8d8681ae92f29b", "max": 7866, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_881fd640878c420cbc14dfd5f8516953", "value": 7866 } }, "dbcc2550fed34fc4b9d1f9b5e7465b85": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e9895aee370a4d888cefa7a82cd90c00": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, 
"grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f74afcde56c64ab389fed5bf7f5964b8": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f83972e889134265aff6737844971e6f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": 
null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f938b2a799d5454f8885d5d6bba24b94": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a2a0abc232c04d8c96704d6b012227aa", "IPY_MODEL_4d191abe24514b6a8a5bcb7aa4dd624c", "IPY_MODEL_9350625ba0fe43949ea335a27f8e402d" ], "layout": "IPY_MODEL_e9895aee370a4d888cefa7a82cd90c00" } }, "fd741db4538f493588106d753a747593": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": 
null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } } } } }, "nbformat": 4, "nbformat_minor": 5 }