llinahosna commited on
Commit
d4ee025
·
verified ·
1 Parent(s): 9123fc9

Upload 31 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# CUDA 11.6 + cuDNN 8 development image for running DALL·E mini with JAX on GPU.
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04

# Avoid apt hanging on interactive configuration prompts (e.g. tzdata)
# during a non-interactive image build. ARG keeps it build-time only.
ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y \
    git \
    python3 \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# JAX with CUDA support first, then dalle-mini and the VQGAN image decoder from source.
RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html \
    && pip install -q \
    git+https://github.com/borisdayma/dalle-mini.git \
    git+https://github.com/patil-suraj/vqgan-jax.git

# Jupyter for the interactive inference notebook.
RUN pip install jupyter

WORKDIR /workspace
FUNDING.yml ADDED
@@ -0,0 +1 @@
 
 
# GitHub Sponsors configuration: accounts shown on the repository's "Sponsor" button.
github: [borisdayma]
README.md CHANGED
@@ -1,13 +1,16 @@
1
- ---
2
- title: DALLEMINI
3
- emoji: 🦀
4
- colorFrom: blue
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 4.19.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
+ # Running Dalle-mini With Docker
2
+
3
+ This folder contains the Dockerfile needed to build a Docker image that can easily run Dalle-mini.
4
+
5
+ ## Inference
6
+
7
+ Steps to run inference with Dalle-mini are as follows:
8
+
9
+ 1. Build the docker image with ```dalle-mini/Docker/build_image.sh```
10
+ 2. Run the container with ```dalle-mini/run_docker_image```
11
+ 3. Navigate to ```/workspace/tools/inference/``` and run ```run_infer_notebook.sh```
12
+ 4. Click the Jupyter Notebook link and run through the notebook.
13
+
14
+ ### Inference Video Tutorial
15
+
16
+ Alternatively, check out a video tutorial on how to run Dalle-mini on [Linux](https://www.youtube.com/watch?v=eWpzLIa6v9E&t=9s) and [Windows](https://www.youtube.com/watch?v=OqEuEe-xSKk&t=59s)
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __version__ = "0.1.1"
2
+
3
+ from .model import DalleBart, DalleBartProcessor
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python
# coding: utf-8
# Gradio front-end for DALL·E mini: sends the user's prompt to a separate
# backend inference server and displays the returned images in a gallery.
import os

import gradio as gr
from backend import get_images_from_backend

block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
# Inference endpoint. BACKEND_SERVER must be set in the environment;
# a missing variable raises KeyError at startup (fail fast by design).
backend_url = os.environ["BACKEND_SERVER"] + "/generate"


def infer(prompt):
    """Request generated images for `prompt` from the backend and return them."""
    response = get_images_from_backend(prompt, backend_url)
    return response["images"]


with block:
    gr.Markdown("<h1><center>DALL·E mini</center></h1>")
    gr.Markdown(
        "DALL·E mini is an AI model that generates images from any prompt you give!"
    )
    with gr.Group():
        with gr.Box():
            # NOTE(review): `.style()` chaining is a legacy gradio 3.x API —
            # confirm the pinned gradio version supports it.
            with gr.Row().style(mobile_collapse=False, equal_height=True):

                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    margin=False,
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        gallery = gr.Gallery(label="Generated images", show_label=False).style(
            grid=[3], height="auto"
        )
        # Both pressing Enter in the textbox and clicking the button run inference.
        text.submit(infer, inputs=text, outputs=gallery)
        btn.click(infer, inputs=text, outputs=gallery)

    gr.Markdown(
        """___
   <p style='text-align: center'>
   Created by <a href="https://twitter.com/borisdayma" target="_blank">Boris Dayma</a> et al. 2021-2022
   <br/>
   <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini-Generate-images-from-any-text-prompt--VmlldzoyMDE4NDAy" target="_blank">Project Report</a>
   </p>"""
    )


block.launch(enable_queue=False)
backend.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Client requests to Dalle-Mini Backend server
2
+
3
+ import base64
4
+ from io import BytesIO
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+
class ServiceError(Exception):
    """Raised when the backend responds with a non-200 HTTP status.

    Attributes:
        status_code: the HTTP status code returned by the backend.
    """

    def __init__(self, status_code):
        # Pass the code to Exception so str(e) / tracebacks show it;
        # the original left the base class uninitialized and str(e) was empty.
        super().__init__(status_code)
        self.status_code = status_code
14
+
def get_images_from_backend(prompt, backend_url):
    """POST `prompt` to the backend and return its decoded images.

    Returns a dict with:
        images: list of PIL images decoded from the base64 payload.
        version: model version reported by the backend, "unknown" if absent.

    Raises:
        ServiceError: if the backend answers with a non-200 status.
    """
    response = requests.post(backend_url, json={"prompt": prompt})
    if response.status_code != 200:
        raise ServiceError(response.status_code)
    payload = response.json()
    decoded = [
        Image.open(BytesIO(base64.b64decode(encoded)))
        for encoded in payload["images"]
    ]
    return {"images": decoded, "version": payload.get("version", "unknown")}
25
+
26
+
def get_model_version(url):
    """GET `url` and return the "version" field of its JSON body.

    Raises:
        ServiceError: if the server answers with a non-200 status.
    """
    response = requests.get(url)
    if response.status_code != 200:
        raise ServiceError(response.status_code)
    return response.json()["version"]
build_docker.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ docker build . -t dalle-mini:latest
check_size.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# GitHub Actions workflow: warn on PRs that add files too large to sync to HF Spaces.
name: Check file size

on:
  pull_request:
    branches: [main]

  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - name: Check large files
        uses: ActionsDesk/lfs-warning@v2.0
        with:
          filesizelimit: 10485760 # = 10MB, so we can sync to HF spaces
config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 2048,
7
+ "decoder_attention_heads": 32,
8
+ "decoder_ffn_dim": 4096,
9
+ "decoder_layerdrop": 0.0,
10
+ "decoder_layers": 24,
11
+ "decoder_start_token_id": 16384,
12
+ "do_sample": true,
13
+ "dropout": 0.0,
14
+ "encoder_attention_heads": 32,
15
+ "encoder_ffn_dim": 4096,
16
+ "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 24,
18
+ "encoder_vocab_size": 50272,
19
+ "eos_token_id": 16385,
20
+ "force_ln_scale": false,
21
+ "gradient_checkpointing": false,
22
+ "image_length": 256,
23
+ "image_vocab_size": 16415,
24
+ "init_std": 0.01,
25
+ "is_encoder_decoder": true,
26
+ "ln_positions": "normformer",
27
+ "ln_type": "layernorm",
28
+ "max_length": 257,
29
+ "max_text_length": 64,
30
+ "min_length": 257,
31
+ "model_type": "dallebart",
32
+ "normalize_text": true,
33
+ "pad_token_id": 16385,
34
+ "scale_embedding": false,
35
+ "sinkhorn_iters": 1,
36
+ "tau_init": 0.05,
37
+ "tie_word_embeddings": false,
38
+ "use_absolute_position_embeddings": true,
39
+ "use_alibi": false,
40
+ "use_bias": false,
41
+ "use_cache": true,
42
+ "use_cosine_attention": false,
43
+ "use_deepnet_scaling": false,
44
+ "use_final_ln_decoder": true,
45
+ "use_final_ln_encoder": true,
46
+ "use_glu": true,
47
+ "use_head_scale": false,
48
+ "use_swin_position_embeddings": false
49
+ }
configuration.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model configuration """
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ from .utils import PretrainedFromWandbMixin
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
    """Configuration for the DalleBart encoder-decoder model.

    BART-style configuration extended with image-decoder fields
    (image token vocabulary / sequence length) and switches for
    transformer architecture variants (NormFormer, Swin v2, CogView,
    DeepNet scaling, GLU, Sinkhorn attention, ...).
    """

    model_type = "dallebart"
    # cached decoder state is an implementation detail, not a model output
    keys_to_ignore_at_inference = ["past_key_values"]
    # map standard HF attribute names onto this config's field names
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        normalize_text=False,
        encoder_vocab_size=50264,
        image_vocab_size=16384,  # encoded image token space
        image_length=256,  # number of encoded tokens
        max_text_length=64,  # max number of text tokens
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        scale_embedding=False,
        gradient_checkpointing=True,
        use_scan=None,
        use_cache=True,
        is_encoder_decoder=True,
        forced_eos_token_id=None,
        tie_word_embeddings=False,  # different modalities and sizes
        do_sample=True,
        # transformer variants
        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
        ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
        use_head_scale=False,  # used in NormFormer
        use_cosine_attention=False,  # used in Swin v2
        tau_init=0.05,  # used only in cosine attention (Swin v2)
        use_absolute_position_embeddings=True,  # default
        use_swin_position_embeddings=False,  # used in Swin v1/v2
        use_deepnet_scaling=False,  # used in Deepnet
        use_glu=True,  # "GLU Variants Improve Transformer"
        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
        sinkhorn_iters=1,  # used in SinkFormers
        use_final_ln_encoder=True,  # final layer normalization in encoder
        use_final_ln_decoder=True,  # final layer normalization in decoder
        # parameters that should not be necessary but could affect results
        force_ln_scale=False,  # force scale in layernorm even when followed by dense layers
        **kwargs,
    ):
        # text normalizer
        self.normalize_text = normalize_text

        # transformer variants
        self.use_bias = use_bias
        assert ln_type in [
            "rmsnorm",
            "layernorm",
        ], "ln_type must be 'rmsnorm' or 'layernorm'"
        self.ln_type = ln_type
        # "deepnet" is an alias: same layer-norm placement as post-LN
        if ln_positions == "deepnet":
            ln_positions = "postln"
        assert ln_positions in [
            "normformer",
            "swinv2",
            "cogview",
            "postln",
            "preln",
        ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
        self.use_head_scale = use_head_scale
        assert use_alibi is False, "use_alibi is not supported yet"
        self.ln_positions = ln_positions
        self.use_cosine_attention = use_cosine_attention
        self.tau_init = tau_init
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.use_swin_position_embeddings = use_swin_position_embeddings
        self.use_deepnet_scaling = use_deepnet_scaling
        self.use_glu = use_glu
        self.use_alibi = use_alibi
        self.sinkhorn_iters = sinkhorn_iters
        # post-LN architectures require a final layer norm on both stacks
        if ln_positions == "postln":
            assert (
                use_final_ln_encoder
            ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
            assert (
                use_final_ln_decoder
            ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
        self.use_final_ln_encoder = use_final_ln_encoder
        self.use_final_ln_decoder = use_final_ln_decoder
        self.force_ln_scale = force_ln_scale

        # common parameters
        self.encoder_vocab_size = encoder_vocab_size
        self.image_vocab_size = image_vocab_size
        self.image_length = image_length
        self.max_text_length = max_text_length
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing
        # all layers are the same in most configurations
        self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
        assert not (
            self.use_scan and ln_positions == "swinv2"
        ), "scan cannot be used with 'swinv2'"
        self.scale_embedding = (
            scale_embedding  # scale factor will be sqrt(d_model) if True
        )

        # special token id's are appended to vocab if not provided
        decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
        bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
        pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
        eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)

        # we generate to image_length + 1 (for bos) by default
        min_length = kwargs.pop("min_length", image_length + 1)
        max_length = kwargs.pop("max_length", image_length + 1)

        super().__init__(
            # args required in parent class
            is_encoder_decoder=is_encoder_decoder,
            tie_word_embeddings=tie_word_embeddings,
            forced_eos_token_id=forced_eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            min_length=min_length,
            max_length=max_length,
            do_sample=do_sample,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get(
            "force_bos_token_to_be_generated", False
        ):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
                "The config can simply be saved and uploaded again to be fixed."
            )
data.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+ from pathlib import Path
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from braceexpand import braceexpand
10
+ from datasets import Dataset, load_dataset
11
+
12
+ from .model.text import TextNormalizer
13
+
14
+
15
+ @dataclass
16
+ class Dataset:
17
+ dataset_repo_or_path: str
18
+ train_file: str = None
19
+ validation_file: str = None
20
+ streaming: bool = True
21
+ use_auth_token: bool = False
22
+ text_column: str = "caption"
23
+ encoding_column: str = "encoding"
24
+ max_train_samples: int = None
25
+ max_eval_samples: int = None
26
+ preprocessing_num_workers: int = None
27
+ overwrite_cache: bool = False
28
+ do_train: bool = False
29
+ do_eval: bool = True
30
+ seed_dataset: int = None
31
+ shard_by_host: bool = False
32
+ blank_caption_prob: float = 0.0
33
+ clip_score_column: str = "clip_score"
34
+ min_clip_score: float = None
35
+ max_clip_score: float = None
36
+ filter_column: str = None
37
+ filter_value: str = None
38
+ multi_eval_ds: bool = False
39
+ train_dataset: Dataset = field(init=False)
40
+ eval_dataset: Dataset = field(init=False)
41
+ other_eval_datasets: list = field(init=False)
42
+ rng_dataset: jnp.ndarray = field(init=False)
43
+ multi_hosts: bool = field(init=False)
44
+
45
+ def __post_init__(self):
46
+ if self.seed_dataset is None:
47
+ # create a random seed
48
+ self.seed_dataset = random.randint(0, 2**32 - 1)
49
+ # set numpy rng
50
+ self.np_rng = np.random.default_rng(self.seed_dataset)
51
+ self.multi_hosts = jax.process_count() > 1
52
+ # feed blank captions only in streaming mode for now
53
+ # otherwise dataset could be cached with same blanked captions
54
+ if self.blank_caption_prob:
55
+ assert (
56
+ self.streaming is True
57
+ ), "blank_caption_prob can only be used in streaming mode"
58
+ # define data_files
59
+ if self.train_file is not None or self.validation_file is not None:
60
+ # accept braceexpand notation
61
+ for k in ["train_file", "validation_file"]:
62
+ f = getattr(self, k)
63
+ if isinstance(f, str):
64
+ setattr(self, k, list(braceexpand(f)))
65
+ # for list of files, split training data shards by host
66
+ if (
67
+ isinstance(self.train_file, list)
68
+ and self.multi_hosts
69
+ and self.shard_by_host
70
+ ):
71
+ self.train_file = self.train_file[
72
+ jax.process_index() :: jax.process_count()
73
+ ]
74
+ data_files = {
75
+ "train": self.train_file,
76
+ "validation": self.validation_file,
77
+ }
78
+ else:
79
+ data_files = None
80
+
81
+ # multiple validation datasets
82
+ if self.multi_eval_ds:
83
+ assert Path(
84
+ self.dataset_repo_or_path
85
+ ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
86
+ data_files = {
87
+ split.name: [str(f) for f in split.glob("*.parquet")]
88
+ for split in Path(self.dataset_repo_or_path).glob("*")
89
+ }
90
+ # rename "valid" to "validation" if present for consistency
91
+ if "valid" in data_files:
92
+ data_files["validation"] = data_files["valid"]
93
+ del data_files["valid"]
94
+ self.dataset_repo_or_path = "parquet"
95
+
96
+ # load dataset
97
+ dataset = load_dataset(
98
+ self.dataset_repo_or_path,
99
+ data_files=data_files,
100
+ streaming=self.streaming,
101
+ use_auth_token=self.use_auth_token,
102
+ )
103
+ if self.do_train:
104
+ if "train" not in dataset:
105
+ raise ValueError("Training requires a training dataset")
106
+ self.train_dataset = dataset["train"]
107
+ if self.max_train_samples is not None:
108
+ self.train_dataset = (
109
+ self.train_dataset.take(self.max_train_samples)
110
+ if self.streaming
111
+ else self.train_dataset.select(range(self.max_train_samples))
112
+ )
113
+ if self.do_eval:
114
+ if "validation" not in dataset:
115
+ raise ValueError("Evaluating requires a validation dataset")
116
+ self.eval_dataset = dataset["validation"]
117
+ if self.max_eval_samples is not None:
118
+ self.eval_dataset = (
119
+ self.eval_dataset.take(self.max_eval_samples)
120
+ if self.streaming
121
+ else self.eval_dataset.select(range(self.max_eval_samples))
122
+ )
123
+ # other eval datasets
124
+ other_eval_splits = dataset.keys() - {"train", "validation"}
125
+ self.other_eval_datasets = {
126
+ split: dataset[split] for split in other_eval_splits
127
+ }
128
+
129
+ def preprocess(self, tokenizer, config):
130
+ # get required config variables
131
+ decoder_start_token_id = config.decoder_start_token_id
132
+ normalize_text = config.normalize_text
133
+ max_length = config.max_text_length
134
+
135
+ if self.streaming:
136
+ # we need to shuffle early in streaming mode
137
+ if hasattr(self, "train_dataset"):
138
+ self.train_dataset = self.train_dataset.shuffle(
139
+ buffer_size=5000, seed=self.seed_dataset
140
+ )
141
+ else:
142
+ self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
143
+
144
+ # filter data
145
+ partial_filter_function = partial(
146
+ filter_function,
147
+ filter_column=self.filter_column,
148
+ filter_value=self.filter_value,
149
+ clip_score_column=self.clip_score_column,
150
+ min_clip_score=self.min_clip_score,
151
+ max_clip_score=self.max_clip_score,
152
+ )
153
+ for ds in ["train_dataset", "eval_dataset"]:
154
+ if hasattr(self, ds):
155
+ setattr(
156
+ self,
157
+ ds,
158
+ (
159
+ getattr(self, ds).filter(partial_filter_function)
160
+ if self.streaming
161
+ else getattr(self, ds).filter(
162
+ partial_filter_function,
163
+ num_proc=self.preprocessing_num_workers,
164
+ load_from_cache_file=not self.overwrite_cache,
165
+ desc="Filtering datasets",
166
+ )
167
+ ),
168
+ )
169
+ if hasattr(self, "other_eval_datasets"):
170
+ self.other_eval_datasets = {
171
+ split: (
172
+ ds.filter(partial_filter_function)
173
+ if self.streaming
174
+ else ds.filter(
175
+ partial_filter_function,
176
+ num_proc=self.preprocessing_num_workers,
177
+ load_from_cache_file=not self.overwrite_cache,
178
+ desc="Filtering datasets",
179
+ )
180
+ )
181
+ for split, ds in self.other_eval_datasets.items()
182
+ }
183
+
184
+ # normalize text
185
+ if normalize_text:
186
+ text_normalizer = TextNormalizer()
187
+ partial_normalize_function = partial(
188
+ normalize_function,
189
+ text_column=self.text_column,
190
+ text_normalizer=text_normalizer,
191
+ )
192
+ for ds in ["train_dataset", "eval_dataset"]:
193
+ if hasattr(self, ds):
194
+ setattr(
195
+ self,
196
+ ds,
197
+ (
198
+ getattr(self, ds).map(partial_normalize_function)
199
+ if self.streaming
200
+ else getattr(self, ds).map(
201
+ partial_normalize_function,
202
+ num_proc=self.preprocessing_num_workers,
203
+ load_from_cache_file=not self.overwrite_cache,
204
+ desc="Normalizing datasets",
205
+ )
206
+ ),
207
+ )
208
+ if hasattr(self, "other_eval_datasets"):
209
+ self.other_eval_datasets = {
210
+ split: (
211
+ ds.map(partial_normalize_function)
212
+ if self.streaming
213
+ else ds.map(
214
+ partial_normalize_function,
215
+ num_proc=self.preprocessing_num_workers,
216
+ load_from_cache_file=not self.overwrite_cache,
217
+ desc="Normalizing datasets",
218
+ )
219
+ )
220
+ for split, ds in self.other_eval_datasets.items()
221
+ }
222
+
223
+ # blank captions
224
+ if self.blank_caption_prob:
225
+ partial_blank_caption_function = partial(
226
+ blank_caption_function,
227
+ text_column=self.text_column,
228
+ blank_caption_prob=self.blank_caption_prob,
229
+ rng=self.np_rng,
230
+ )
231
+ if hasattr(self, "train_dataset"):
232
+ self.train_dataset = (
233
+ self.train_dataset.map(partial_blank_caption_function)
234
+ if self.streaming
235
+ else self.train_dataset.map(
236
+ partial_blank_caption_function,
237
+ num_proc=None
238
+ if self.seed_dataset
239
+ else self.preprocessing_num_workers,
240
+ load_from_cache_file=False,
241
+ desc="Blanking some captions",
242
+ )
243
+ )
244
+
245
+ # preprocess
246
+ partial_preprocess_function = partial(
247
+ preprocess_function,
248
+ tokenizer=tokenizer,
249
+ text_column=self.text_column,
250
+ encoding_column=self.encoding_column,
251
+ max_length=max_length,
252
+ decoder_start_token_id=decoder_start_token_id,
253
+ )
254
+ for ds in ["train_dataset", "eval_dataset"]:
255
+ if hasattr(self, ds):
256
+ setattr(
257
+ self,
258
+ ds,
259
+ (
260
+ getattr(self, ds).map(
261
+ partial_preprocess_function,
262
+ batched=True,
263
+ remove_columns=[
264
+ self.text_column,
265
+ self.encoding_column,
266
+ ],
267
+ )
268
+ if self.streaming
269
+ else getattr(self, ds).map(
270
+ partial_preprocess_function,
271
+ batched=True,
272
+ remove_columns=getattr(ds, "column_names"),
273
+ num_proc=self.preprocessing_num_workers,
274
+ load_from_cache_file=not self.overwrite_cache,
275
+ desc="Preprocessing datasets",
276
+ )
277
+ ),
278
+ )
279
+ if hasattr(self, "other_eval_datasets"):
280
+ self.other_eval_datasets = {
281
+ split: (
282
+ ds.map(
283
+ partial_preprocess_function,
284
+ batched=True,
285
+ remove_columns=[
286
+ self.text_column,
287
+ self.encoding_column,
288
+ ],
289
+ )
290
+ if self.streaming
291
+ else ds.map(
292
+ partial_preprocess_function,
293
+ batched=True,
294
+ remove_columns=getattr(ds, "column_names"),
295
+ num_proc=self.preprocessing_num_workers,
296
+ load_from_cache_file=not self.overwrite_cache,
297
+ desc="Preprocessing datasets",
298
+ )
299
+ )
300
+ for split, ds in self.other_eval_datasets.items()
301
+ }
302
+
303
+ def dataloader(self, split, batch_size, epoch=None):
304
+ def _dataloader_datasets_non_streaming(
305
+ dataset: Dataset,
306
+ rng: jax.random.PRNGKey = None,
307
+ ):
308
+ """
309
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
310
+ Shuffle batches if rng is set.
311
+ """
312
+ steps_per_epoch = len(dataset) // batch_size
313
+
314
+ if rng is not None:
315
+ batch_idx = jax.random.permutation(rng, len(dataset))
316
+ else:
317
+ batch_idx = jnp.arange(len(dataset))
318
+
319
+ batch_idx = batch_idx[
320
+ : steps_per_epoch * batch_size
321
+ ] # Skip incomplete batch.
322
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
323
+
324
+ for idx in batch_idx:
325
+ batch = dataset[idx]
326
+ batch = {k: jnp.array(v) for k, v in batch.items()}
327
+ yield batch
328
+
329
+ def _dataloader_datasets_streaming(
330
+ dataset: Dataset,
331
+ epoch: int,
332
+ ):
333
+ keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
334
+ batch = {k: [] for k in keys}
335
+ first_loop = True # stop after one loop in some cases
336
+ while (self.multi_hosts and split == "train") or first_loop:
337
+ # in multi-host, we run forever (no epoch) as hosts need to stop
338
+ # at the same time and training data may not be split equally
339
+ # For validation data we put the entire batch on each host and then
340
+ # keep only the one specific to each host (could be improved but not necessary)
341
+ if epoch is not None:
342
+ assert split == "train"
343
+ # reshuffle training data at each epoch
344
+ dataset.set_epoch(epoch)
345
+ epoch += 1
346
+ for item in dataset:
347
+ for k in keys:
348
+ batch[k].append(item[k])
349
+ if len(batch[keys[0]]) == batch_size:
350
+ batch = {k: jnp.array(v) for k, v in batch.items()}
351
+ yield batch
352
+ batch = {k: [] for k in keys}
353
+ first_loop = False
354
+
355
+ if split == "train":
356
+ ds = self.train_dataset
357
+ elif split == "eval":
358
+ ds = self.eval_dataset
359
+ else:
360
+ ds = self.other_eval_datasets[split]
361
+
362
+ if self.streaming:
363
+ return _dataloader_datasets_streaming(ds, epoch)
364
+ else:
365
+ if split == "train":
366
+ self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
367
+ return _dataloader_datasets_non_streaming(ds, input_rng)
368
+
369
+ @property
370
+ def length(self):
371
+ len_train_dataset, len_eval_dataset = None, None
372
+ if self.streaming:
373
+ # we don't know the length, let's just assume max_samples if defined
374
+ if self.max_train_samples is not None:
375
+ len_train_dataset = self.max_train_samples
376
+ if self.max_eval_samples is not None:
377
+ len_eval_dataset = self.max_eval_samples
378
+ else:
379
+ len_train_dataset = (
380
+ len(self.train_dataset) if hasattr(self, "train_dataset") else None
381
+ )
382
+ len_eval_dataset = (
383
+ len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
384
+ )
385
+ return len_train_dataset, len_eval_dataset
386
+
387
+
388
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Prepends `decoder_start_token_id` and drops the last token, producing
    decoder inputs aligned with the labels.

    Args:
        input_ids: 2D array of token ids, shape (batch, seq_len).
        decoder_start_token_id: id placed at position 0 of every row.

    Returns:
        Array with the same shape and dtype as `input_ids`.
    """
    # np.zeros_like preserves the integer dtype of the token ids;
    # a plain np.zeros(shape) would silently upcast them to float64.
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
+
397
+
398
def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
    """With probability `blank_caption_prob`, replace the caption with an empty string.

    Uses `rng` (a numpy Generator) when provided, otherwise the global numpy RNG.
    The random draw only happens when the probability is non-zero.
    """
    if blank_caption_prob:
        draw = rng.random() if rng is not None else np.random.random()
        if draw < blank_caption_prob:
            example[text_column] = ""
    return example
406
+
407
+
408
def normalize_function(example, text_column, text_normalizer):
    """Apply `text_normalizer` to the example's text column, mutating in place."""
    raw_text = example[text_column]
    example[text_column] = text_normalizer(raw_text)
    return example
411
+
412
+
413
def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    """Return True iff `example` passes the clip-score range and column filters.

    Each bound is optional (None disables it); the clip-score column is only
    read when a bound is set.
    """
    # clip-score range check
    if min_clip_score is not None:
        if example[clip_score_column] < min_clip_score:
            return False
    if max_clip_score is not None:
        if example[clip_score_column] > max_clip_score:
            return False
    # optional exact-match filter on an arbitrary column
    if filter_column is None:
        return True
    return example[filter_column] == filter_value
428
+
429
+
430
def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    """Tokenize a batch of captions and attach labels / decoder inputs.

    Labels are the pre-computed image encodings; decoder input ids are the
    labels shifted right with `decoder_start_token_id` (bos) prepended.
    """
    # Fixed-length padding so jitted functions always see identical shapes.
    model_inputs = tokenizer(
        examples[text_column],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # Targets: labels are our target indices; the loss also needs them
    # alongside the decoder inputs.
    labels = np.asarray(examples[encoding_column])
    model_inputs["labels"] = labels

    # Prepend bos and drop the last token to build the decoder inputs.
    model_inputs["decoder_input_ids"] = shift_tokens_right(
        labels, decoder_start_token_id
    )

    return model_inputs
distributed_shampoo.py ADDED
@@ -0,0 +1,2280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of distributed Shampoo optimizer from:
17
+ #
18
+ # Scalable Second Order Optimization for Deep Learning
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
20
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
21
+ #
22
+ # This implementation moves computation of inverse pth root back to the
23
+ # accelerator (if higher precision is available).
24
+ #
25
+ # Authors: Rohan Anil (rohananil at google dot com)
26
+ # & Vineet Gupta (vineet at google dot com)
27
+ #
28
+ """Distributed Shampoo Implementation."""
29
+
30
+ import enum
31
+ import functools
32
+ import itertools
33
+ from typing import Any, List, NamedTuple, Tuple
34
+
35
+ import chex
36
+ import jax
37
+ import jax.experimental.pjit as pjit
38
+ import jax.numpy as jnp
39
+ import numpy as np
40
+ import optax
41
+ from flax import struct
42
+ from jax import lax
43
+
44
+ from .quantization_utils import QuantizedValue
45
+ from .symmetric_matrices import symmetric_matrices
46
+
47
+ # Dtype for inverse-pth root routine
48
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
49
+ # jax_enable_x64 for this to work, otherwise it will default to float32.
50
+ _MAT_INV_PTH_ROOT_DTYPE = jnp.float64
51
+
52
+
53
# Scalar metrics recorded during training; kept inside the optimizer state
# (as a flax struct) so they flow through jit/pmap like any other state leaf.
@struct.dataclass
class TrainingMetrics:
    inverse_pth_root_errors: chex.Array  # Error for inverse-pth roots.
    # TODO(rohananil): Add more important metrics to track during training.
57
+
58
+
59
# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    # One entry per dimension of each block (see Preconditioner) — TODO confirm.
    statistics: List[Any]  # Statistics (QuantizedValue, chex.Array)
    preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
69
+
70
+
71
# For training extremely large model; We keep a global state with a concatenated
# statistics and preconditioner states for all vars. This is so that we can
# annotate the leading axis to be sharded to save memory at the cost of
# communication.
@struct.dataclass
class GlobalShardedParameterStats:
    statistics: chex.Array  # Statistics, stacked along the (shardable) leading axis.
    preconditioners: chex.Array  # Preconditioners, same layout as statistics.
    exponents: chex.Array  # exponents
80
+
81
+
82
# These are per-parameter local states; All statistics here mirror the parameter
# Thus the sharding is copied over from the param specification.
@struct.dataclass
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
    # index_start/sizes are static metadata (pytree_node=False): they locate
    # this parameter's rows inside the concatenated global statistics array.
    index_start: np.int32 = struct.field(
        pytree_node=False
    )  # Index into global statistics array
    sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
96
+
97
+
98
def init_training_metrics(num_statistics):
    """Returns zero-initialized inverse-pth-root error metrics.

    Downstream APIs expect a jnp.array, so when there are no statistics a
    scalar dummy array is returned instead.
    """
    if num_statistics:
        return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
    return TrainingMetrics(jnp.array(0, jnp.float32))
105
+
106
+
107
def init_training_metrics_shapes(num_statistics):
    """Returns the [shape, dtype] skeleton matching init_training_metrics.

    Mirrors the scalar dummy used when there are no statistics, so the
    shape is [] in that case and [num_statistics] otherwise.
    """
    shape = [num_statistics] if num_statistics else []
    return TrainingMetrics([shape, jnp.float32])
114
+
115
+
116
def init_training_metrics_pspec():
    # Empty PartitionSpec: metrics are replicated rather than sharded.
    return TrainingMetrics(pjit.PartitionSpec())
118
+
119
+
120
class ShardedShampooStats(NamedTuple):
    """Shampoo state in sharded mode."""

    global_stats: Any  # presumably a GlobalShardedParameterStats — see _convert_to_parameter_stats
    local_stats: Any  # presumably per-parameter LocalShardedParameterStats
125
+
126
+
127
class ShampooState(NamedTuple):
    """Top-level optimizer state: step counter plus per-parameter stats."""

    count: chex.Array  # Optimizer step counter.
    stats: Any  # ParameterStats pytree or ShardedShampooStats (mode-dependent).
130
+
131
+
132
class InitFnState(NamedTuple):
    """Bundle of callables for pjit-style state initialization.

    NOTE(review): field semantics inferred from names — init creates the
    state, pspec returns its PartitionSpecs, shape_and_dtype its shapes;
    confirm against callers.
    """

    init_fn: Any
    pspec_fn: Any
    shape_and_dtype_fn: Any
136
+
137
+
138
class GraftingType(enum.IntEnum):
    """First-order methods used for layerwise grafting (see graft_type arg)."""

    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
    SQRT_N = 5
    ADAGRAD_NORMALIZED = 6
145
+
146
+
147
def power_iteration(
    matrix,
    num_iters=100,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
):
    r"""Power iteration algorithm.

    The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
    a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
    of `A`, and a vector v, which is the corresponding eigenvector of `A`.

    References:
      [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

    Args:
      matrix: the symmetric PSD matrix.
      num_iters: Number of iterations.
      error_tolerance: Iterative exit condition.
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
        (best possible precision, slowest)

    Returns:
      eigen vector, eigen value
    """
    matrix_size = matrix.shape[-1]

    def _iter_condition(state):
        # Continue while under the iteration budget AND the previous step's
        # eigenvalue estimate moved by more than error_tolerance.
        i, unused_v, unused_s, unused_s_v, run_step = state
        return jnp.logical_and(i < num_iters, run_step)

    def _iter_body(state):
        """One step of power iteration."""
        i, new_v, s, s_v, unused_run_step = state
        # Normalize before each matrix-vector product to avoid over/underflow.
        new_v = new_v / jnp.linalg.norm(new_v)

        s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
        # Rayleigh-quotient estimate of the dominant eigenvalue.
        s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
        return (
            i + 1,
            s_v,
            s_new,
            s_v,
            jnp.greater(jnp.abs(s_new - s), error_tolerance),
        )

    # Figure out how to use step as seed for random.
    # Fixed seed keeps the start vector deterministic across calls/hosts.
    v_0 = (
        np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
    )

    init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
    _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
    v_out = v_out / jnp.linalg.norm(v_out)
    return v_out, s_out
204
+
205
+
206
def mat_power(
    mat_m,
    p,
    precision=lax.Precision.HIGHEST,
):
    """A simple matrix power method. M^p where p can be TracedValue.

    Implemented as exponentiation-by-squaring inside a lax.while_loop so
    that `p` may be a runtime (traced) value: O(log p) matmuls.
    """
    power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)

    def _iter_condition(state):
        i, _, _ = state
        return i > 0

    def _iter_body(state):
        i, power, mat = state

        # Fold the current square of `mat` into the accumulator only when
        # the low bit of the remaining exponent is set.
        power = jax.lax.cond(
            i % 2 == 1,
            lambda: jnp.matmul(mat, power, precision=precision),
            lambda: power,
        )
        i //= 2
        mat = jnp.matmul(mat, mat, precision=precision)
        return i, power, mat

    _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
    return result
232
+
233
+
234
def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses the Coupled newton iterations algorithm for
    the computation of a matrix's inverse pth root.


    References:
      [Functions of Matrices, Theory and Computation,
       Nicholas J Higham, Pg 184, Eq 7.18](
       https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

    Args:
      matrix: the symmetric PSD matrix whose power it to be computed
      p: exponent, for p a positive integer.
      num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      error_tolerance: Error indicator, useful for early termination.
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
        (best possible precision, slowest)

    Returns:
      A tuple of matrix^(-1/p) (cast back to the input dtype) and the final
      residual error max|mat_m - I| as float32.
    """

    # If the input is not square, materialize it from the concatenated form.
    if matrix.shape[0] != matrix.shape[1]:
        matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)

    assert matrix.shape[0] == matrix.shape[1]

    # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
    # Switch to f64 if you have hardware that supports it. Enable the jax flag
    # jax_enable_x64 for this to work.
    matrix_size = matrix.shape[0]
    orig_dtype = matrix.dtype
    matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
    alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
    identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
    # Scale the ridge term by the dominant eigenvalue so regularization is
    # relative to the matrix's magnitude.
    _, max_ev = power_iteration(
        matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
    )
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
        error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
        return jnp.logical_and(i < num_iters, error_above_threshold)

    def _iter_body(state):
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        # Coupled Newton update (Higham Eq 7.18): mat_m converges to I while
        # mat_h converges to matrix^(-1/p).
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
        new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)

    if matrix_size == 1:
        # 1x1 matrices have a closed form; skip the iteration.
        resultant_mat_h = (matrix + ridge_epsilon) ** alpha
        error = jnp.array(0, jnp.float32)
    else:
        damped_matrix = matrix + ridge_epsilon * identity

        z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
        new_mat_m_0 = damped_matrix * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state
        )
        error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
        # If the final step diverged (convergence flag False), fall back to
        # the previous iterate old_mat_h.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
    resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
    return resultant_mat_h, error
321
+
322
+
323
def merge_small_dims(shape_to_merge, max_dim):
    """Merge small dimensions.

    If there are some small dimensions, we collapse them:
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048]

    Args:
      shape_to_merge: Shape to merge small dimensions.
      max_dim: Maximal dimension of output shape used in merging.

    Returns:
      Merged shape.
    """
    # An all-ones shape collapses to a single unit dimension.
    if shape_to_merge and all(dim == 1 for dim in shape_to_merge):
        return [1]

    merged = []
    running = 1
    for dim in shape_to_merge:
        # Greedily absorb dimensions while the product stays within max_dim.
        if running * dim <= max_dim:
            running *= dim
            continue
        if running > 1:
            merged.append(running)
        running = dim
    if running > 1:
        merged.append(running)
    return merged
352
+
353
+
354
def pad_square_matrix(mat, max_size):
    """Pad a square matrix up to max_size.

    Args:
      mat: a matrix to pad.
      max_size: matrix size requested.

    Returns:
      Given M returns [[M, 0], [0, I]]
    """
    rows, cols = mat.shape
    if rows != cols:
        raise ValueError(f"Must have rows == cols, instead got rows={rows}, cols={cols}")
    if cols > max_size:
        raise ValueError(
            f"Must have cols <= max_size. Instead got cols={cols}, max_size={max_size}."
        )
    if rows == max_size:
        return mat

    pad_size = max_size - rows
    # Assemble [[M, 0], [0, I]] row-band by row-band.
    top = jnp.concatenate([mat, jnp.zeros([rows, pad_size], dtype=mat.dtype)], 1)
    bottom = jnp.concatenate(
        [jnp.zeros([pad_size, rows], dtype=mat.dtype), jnp.eye(pad_size, dtype=mat.dtype)],
        1,
    )
    return jnp.concatenate([top, bottom], 0)
384
+
385
+
386
def make_sliced_padding(
    symmetric_block_size,
    num_blocks,
    starting_block,
    dtype,
):
    """Returns padding for symmetric block matrix.

    Specifically, the padding is given concatenated rectangular matrices
    representing the lower-triangular rows below the starting block. For example,
    if we want to pad the symmetric matrix

    M = [[A, B^T]
         [B, C]],

    the desired output (in terms of the full matrix) with num_blocks = 4 is

    M_padded = [[A, B^T, 0, 0]
                [B, C,   0, 0]
                [0, 0,   I, 0]
                 0, 0,   0, I].

    We would represent M as the block matrix mat = [A, B, C]. In this form, the
    additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
    triangular parts in the third and fourth rows).

    Args:
      symmetric_block_size: The size of each block.
      num_blocks: The total number of blocks.
      starting_block: The block where to start the padding.
      dtype: The type to use for the blocks.
    """
    # Nothing left to pad: return an empty (zero-column) slab.
    if starting_block == num_blocks:
        return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)

    pieces = []
    for block_row in range(starting_block, num_blocks):
        # Lower-triangular row `block_row`: `block_row` zero blocks followed
        # by one identity block on the diagonal.
        pieces.append(
            jnp.zeros(
                shape=(symmetric_block_size, symmetric_block_size * block_row),
                dtype=dtype,
            )
        )
        pieces.append(jnp.eye(symmetric_block_size, dtype=dtype))
    return jnp.concatenate(pieces, axis=-1)
430
+
431
+
432
def pad_block_symmetric_matrix(
    mat,
    symmetric_block_size,
    max_num_blocks,
):
    """Returns the padded blocked symmetric matrix.

    The size of the padded matrix will be:
      [symmetric_block_size, symmetric_block_size * max_num_blocks]

    The input matrix can either:
      - Be square with size less or equal to symmetric_block_size. In this case,
        mat will first be padded to a square matrix of size symmetric_block_size,
        and then be padded again up to the full size of the blocked matrix.
      - Be a rectangle with number of rows equal to block size. In this case,
        number of columns must be a multiple of number of rows, and the ratio
        must correspond to a block representation of a symmetric matrix (that
        is, have the form x * (x + 1) / 2 for x block rows).

    Args:
      mat: The input block matrix.
      symmetric_block_size: The size of blocks.
      max_num_blocks: The largest number of blocks to pad to.
    """
    rows, cols = mat.shape
    # Validate the block layout before padding.
    if rows > symmetric_block_size:
        raise ValueError(
            f"Must have rows <= symmetric_block_size. Instead got "
            f"rows={rows}, symmetric_block_size={symmetric_block_size}."
        )
    if rows > cols:
        raise ValueError(f"Must have rows <= cols, instead got rows={rows}, cols={cols}.")
    if cols > symmetric_block_size * max_num_blocks:
        raise ValueError(
            f"Must have cols <= symmetric_block_size * max_num_blocks "
            f"Instead got cols={cols}, "
            f"symmetric_block_size={symmetric_block_size}, "
            f"max_num_blocks={max_num_blocks}."
        )
    if rows < symmetric_block_size:
        mat = pad_square_matrix(mat, max_size=symmetric_block_size)
        # Update rows and cols after possibly padding in pad_square_matrix.
        rows, cols = mat.shape
    assert rows == symmetric_block_size
    assert cols % rows == 0
    filled_blocks = cols // rows
    padding_blocks = make_sliced_padding(
        symmetric_block_size=symmetric_block_size,
        num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
        starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
        dtype=mat.dtype,
    )
    return jnp.concatenate([mat, padding_blocks], axis=-1)
488
+
489
+
490
def pad_vector(vec, max_size):
    """Pad a vector to a max_size.

    Args:
      vec: a vector to pad.
      max_size: matrix size requested.

    Returns:
      Given V returns [V, 0]
    """
    size = vec.shape[0]
    assert size <= max_size
    if size == max_size:
        return vec
    # Append trailing zeros up to the requested length.
    tail = jnp.zeros([max_size - size], dtype=vec.dtype)
    return jnp.concatenate([vec, tail], 0)
507
+
508
+
509
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
    """Avoids wasteful buffer allocation with XLA.

    Runs `compute_fn(*args, **kwargs)` at most once when `predicate` is
    true, otherwise returns `init_state` unchanged, using a single-trip
    lax.while_loop instead of lax.cond.
    """

    def _iter_body(unused_state):
        results = compute_fn(*args, **kwargs)
        # Leading False terminates the loop after exactly one iteration.
        return tuple([False] + list(results))

    def _iter_condition(state):
        # state[0] is the predicate on entry, False after one body run.
        return state[0]

    results = jax.lax.while_loop(
        _iter_condition, _iter_body, tuple([predicate] + init_state)
    )
    # Drop the loop-control flag; return only the computed values.
    return tuple(results[1:])
523
+
524
+
525
class BlockPartitioner:
    """Partitions a tensor into smaller tensors."""

    def __init__(self, param, block_size):
        # Records split indices for every dimension larger than block_size;
        # dimensions <= block_size (or block_size <= 0) are left whole.
        self._shape = param.shape
        self._splits = []
        split_sizes = []
        # We split params into smaller blocks. Here we store the metadata to make
        # that split.
        for i, d in enumerate(param.shape):
            if 0 < block_size < d:
                # d-1, otherwise split appends a 0-size array.
                nsplit = (d - 1) // block_size
                indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
                # All chunks are block_size wide except a possibly-smaller last one.
                sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
                sizes[-1] = d - indices[-1]
                self._splits.append((i, indices))
                split_sizes.append(sizes)
            else:
                split_sizes.append(np.array([d], dtype=np.int32))
        # Number of dimensions, i.e. preconditioners per block.
        self._num_splits = len(split_sizes)
        # One square [d, d] shape per dimension of every block, enumerated in
        # itertools.product order (matches the order partition() emits blocks).
        self._preconditioner_shapes = []
        for t in itertools.product(*split_sizes):
            self._preconditioner_shapes.extend([[d, d] for d in t])

    def shapes_for_preconditioners(self):
        # Shapes of the square statistics/preconditioner matrices.
        return self._preconditioner_shapes

    def num_splits(self):
        # Preconditioners consumed per block (= rank of the tensor).
        return self._num_splits

    def partition(self, tensor):
        """Partition tensor into blocks."""

        assert tensor.shape == self._shape
        tensors = [tensor]
        # Split along each recorded axis in turn, fanning out the block list.
        for (i, indices) in self._splits:
            tensors_local = []
            for t in tensors:
                tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
            tensors = tensors_local
        return tensors

    def merge_partitions(self, partitions):
        """Merge partitions back to original shape."""

        # Undo the splits in reverse order, concatenating n chunks at a time.
        for (i, indices) in reversed(self._splits):
            n = len(indices) + 1
            partial_merged_tensors = []
            ind = 0
            while ind < len(partitions):
                partial_merged_tensors.append(
                    jnp.concatenate(partitions[ind : ind + n], axis=i)
                )
                ind += n
            partitions = partial_merged_tensors
        assert len(partitions) == 1
        return partitions[0]
583
+
584
+
585
class Preconditioner:
    """Compute statistics/shape from gradients for preconditioning."""

    def __init__(self, param, block_size, best_effort_shape_interpretation):
        # Optionally collapse small dimensions (merge_small_dims) before
        # delegating block splitting to BlockPartitioner.
        self._original_shape = param.shape
        self._transformed_shape = param.shape
        if best_effort_shape_interpretation:
            self._transformed_shape = merge_small_dims(self._original_shape, block_size)
        reshaped_param = jnp.reshape(param, self._transformed_shape)
        self._partitioner = BlockPartitioner(reshaped_param, block_size)

    def statistics_from_grad(self, grad):
        """Compute statistics from gradients.

        Args:
          grad: Gradient to compute statistics from.

        Returns:
          A list of gradient statistics for each partition.
        """
        reshaped_grad = jnp.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        stats = []
        for g in partitioned_grads:
            g_stats = []
            rank = len(g.shape)
            for i in range(rank):
                # Contract g with itself over every axis except i, yielding
                # the [d_i, d_i] Gram matrix for dimension i of this block.
                axes = list(range(i)) + list(range(i + 1, rank))
                stat = jnp.tensordot(g, g, axes=(axes, axes))
                g_stats.append(stat)
            stats.extend(g_stats)
        return stats

    def shapes_for_preconditioners(self):
        """Returns shape from statistics."""
        return self._partitioner.shapes_for_preconditioners()

    def exponent_for_preconditioner(self):
        """Returns exponent to use for inverse-pth root M^{-1/p}."""
        # One factor of 2 per dimension of the (possibly merged) shape.
        return 2 * len(self._transformed_shape)

    def preconditioned_grad(self, grad, preconditioners):
        """Precondition the gradient.

        Args:
          grad: A gradient tensor to precondition.
          preconditioners: A list of preconditioners to apply.

        Returns:
          A preconditioned gradient.
        """

        reshaped_grad = jnp.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        preconditioned_partitioned_grads = []
        num_splits = self._partitioner.num_splits()
        for i, g in enumerate(partitioned_grads):
            # Each block consumes num_splits (= rank) consecutive preconditioners.
            preconditioners_for_grad = preconditioners[
                i * num_splits : (i + 1) * num_splits
            ]
            rank = len(g.shape)
            precond_g = g
            for j in range(rank):
                # Contracting axis 0 each time rotates the axes, so after
                # `rank` applications the original axis order is restored.
                precond_g = jnp.tensordot(
                    precond_g, preconditioners_for_grad[j], axes=[[0], [0]]
                )
            preconditioned_partitioned_grads.append(precond_g)
        merged_grad = self._partitioner.merge_partitions(
            preconditioned_partitioned_grads
        )
        return jnp.reshape(merged_grad, self._original_shape)
656
+
657
+
658
def _convert_to_parameter_stats(global_stats, local_stat):
    """Creates parameter stats from sharded stats.

    Slices this parameter's rows out of the concatenated global statistics
    and preconditioner arrays, then crops each padded matrix back down to
    its true size as recorded in local_stat.sizes.
    """
    start = int(local_stat.index_start)
    end = start + int(len(local_stat.sizes))
    stat_rows = global_stats.statistics[start:end, :, :]
    precond_rows = global_stats.preconditioners[start:end, :, :]
    cropped_stats = [
        stat_rows[i][:size, :size] for i, size in enumerate(local_stat.sizes)
    ]
    cropped_preconds = [
        precond_rows[i][:size, :size] for i, size in enumerate(local_stat.sizes)
    ]
    return ParameterStats(
        local_stat.diagonal_statistics,
        cropped_stats,
        cropped_preconds,
        local_stat.diagonal_momentum,
        local_stat.momentum,
        local_stat.training_metrics,
    )
677
+
678
+
679
def _convert_from_parameter_stats(parameter_stats, local_stats):
    """Creates sharded stats from parameter stats.

    Keeps the static index_start/sizes metadata from the existing local
    stats and takes all array leaves from the updated parameter stats.
    """
    return LocalShardedParameterStats(
        diagonal_statistics=parameter_stats.diagonal_statistics,
        diagonal_momentum=parameter_stats.diagonal_momentum,
        momentum=parameter_stats.momentum,
        training_metrics=parameter_stats.training_metrics,
        index_start=local_stats.index_start,
        sizes=local_stats.sizes,
    )
689
+
690
+
691
def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
    """Adds errors back into local statistics.

    Args:
      local_stats: iterable of LocalShardedParameterStats to update.
      errors: global error vector; each stat reads its
        [index_start, index_start + len(sizes)) slice.
      inverse_failure_threshold: sentinel value marking a failed inverse-pth
        root; such entries keep the previously recorded error.

    Returns:
      A new list of LocalShardedParameterStats with refreshed metrics.
    """
    new_local_stats = []
    for local_stat in local_stats:
        if local_stat.sizes:
            index_start = int(local_stat.index_start)
            index_end = int(len(local_stat.sizes)) + index_start
            per_stat_error = errors[index_start:index_end]
        else:
            # Parameter has no statistics: record a scalar zero placeholder.
            per_stat_error = jnp.array(0, jnp.float32)
        if local_stat.sizes:
            # Accept a fresh error only when it is positive and did not hit
            # the failure sentinel; otherwise keep the last known value.
            per_stat_error = jnp.where(
                jnp.logical_and(
                    per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
                ),
                per_stat_error,
                local_stat.training_metrics.inverse_pth_root_errors,
            )
        new_local_stats.append(
            LocalShardedParameterStats(
                local_stat.diagonal_statistics,
                local_stat.diagonal_momentum,
                local_stat.momentum,
                TrainingMetrics(per_stat_error),
                local_stat.index_start,
                local_stat.sizes,
            )
        )
    return new_local_stats
720
+
721
+
722
def batch(x, num_devices):
    """Batch `x` so that the leading axis is num_devices.

    Args:
      x: sequence of arrays; len(x) is expected to be divisible by
        num_devices.
      num_devices: number of equal chunks stacked along the leading axis.

    Returns:
      Array of shape [num_devices, len(x) // num_devices, ...].
    """
    n = len(x)
    # Floor division: the previous int(n / num_devices) went through float
    # division, which is inexact for very large n.
    b = n // num_devices
    return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
727
+
728
+
729
def unbatch(batched_values):
    """Unbatch values across leading axis and return a list of elements."""
    outer, inner = batched_values.shape[0], batched_values.shape[1]
    flattened = []
    for chunk in jnp.split(batched_values, indices_or_sections=outer, axis=0):
        chunk = jnp.squeeze(chunk)
        if inner > 1:
            # inner = batches (number of preconditioner computations) per core:
            # split those out individually as well.
            for piece in jnp.split(chunk, indices_or_sections=inner, axis=0):
                flattened.append(jnp.squeeze(piece))
        else:
            flattened.append(chunk)
    return flattened
742
+
743
+
744
+ def distributed_shampoo(
745
+ learning_rate,
746
+ block_size,
747
+ beta1=0.9,
748
+ beta2=0.999,
749
+ diagonal_epsilon=1e-10,
750
+ matrix_epsilon=1e-6,
751
+ weight_decay=0.0,
752
+ start_preconditioning_step=5,
753
+ preconditioning_compute_steps=1,
754
+ statistics_compute_steps=1,
755
+ best_effort_shape_interpretation=True,
756
+ graft_type=GraftingType.SGD,
757
+ nesterov=True,
758
+ exponent_override=0,
759
+ # Pass pmap 'batch axis name' in pmap mode.
760
+ batch_axis_name=None,
761
+ ### Only set following 3 params in pjit/spmd mode.
762
+ ### WARNING: Experimental
763
+ statistics_partition_spec=None,
764
+ preconditioner_partition_spec=None,
765
+ num_devices_for_pjit=None,
766
+ shard_optimizer_states=False,
767
+ ###
768
+ ### Experimental memory reduction mode
769
+ best_effort_memory_usage_reduction=False,
770
+ ###
771
+ inverse_failure_threshold=0.1,
772
+ moving_average_for_momentum=False,
773
+ skip_preconditioning_dim_size_gt=4096,
774
+ clip_by_scaled_gradient_norm=None,
775
+ precision=lax.Precision.HIGHEST,
776
+ ):
777
+ """Distributed Shampoo optimizer.
778
+
779
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
780
+ variant of full-matrix Adagrad), that provides significant convergence and
781
+ wall-clock time improvements compared to conventional first-order methods,
782
+ and that has been shown to scale to large state-of-the-art deep learning
783
+ models.
784
+
785
+ References:
786
+ Scalable Second Order Optimization for Deep Learning,
787
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
788
+
789
+ Preprint: https://arxiv.org/abs/2002.09018
790
+
791
+ Args:
792
+ learning_rate: the step size used to update the parameters.
793
+ block_size: Block size for large layers (if > 0). Preconditioning compute
794
+ operation is cubic in the dimension of the tensor. Block size allows us to
795
+ chunk the layers into sub-layers of maximal dimension dictated by this
796
+ value. Use 128 as default (increase if you have compute budget).
797
+ beta1: momentum parameter.
798
+ beta2: second moment averaging parameter.
799
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
800
+ to AdaGrad is enabled).
801
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
802
+ root. If you are running in f32 precision for inverse pth root
803
+ (recommended today) this can go upto 1e-6. If you have latest hardware
804
+ with native f64 precision, set this upto 1e-12.
805
+ weight_decay: Weight decay for regularization.
806
+ start_preconditioning_step: When to start Shampoo update before which
807
+ diagonal update is used. This is because we dont have enough information
808
+ to do stable inverse.
809
+ preconditioning_compute_steps: How often to compute preconditioner.
810
+ Performance tuning params for controlling memory and compute requirements.
811
+ Ideally set this and statistics_compute_steps params to 1.
812
+ statistics_compute_steps: How often to compute statistics.
813
+ best_effort_shape_interpretation: If there are some small dimensions,
814
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
815
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
816
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
817
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
818
+ where SGD/AdaGrad is already well tuned.
819
+ nesterov: Nesterov momentum.
820
+ exponent_override: Override the exponent used in matrix inverse.
821
+ batch_axis_name: labeled axis over pmap for data-parallel training the
822
+ optimizer used for.
823
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
824
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
825
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
826
+ shard_optimizer_states: Shard optimizer states to save memory in model
827
+ parallel training.
828
+ best_effort_memory_usage_reduction: Best effort memory usage reduction. -
829
+ diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
830
+ statistics, preconditioners -> jnp.int16 + diagonals
831
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
832
+ determine that using this threshold.
833
+ moving_average_for_momentum: Whether to use moving average for momentum
834
+ instead of exponential moving average.
835
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
836
+ greater than this value.
837
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
838
+ using RMSProp Grafting).
839
+ precision: precision XLA related flag, the available options are: a)
840
+ lax.Precision.DEFAULT (better step time, but not precise) b)
841
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
842
+ (best possible precision, slowest)
843
+
844
+ Returns:
845
+ a GradientTransformation.
846
+ """
847
+
848
+ def _graft_type_has_diagonal_statistics():
849
+ """Returns True if using diagonal first order method for grafting."""
850
+ return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N
851
+
852
+ def _graft_type_has_diagonal_momentum_states():
853
+ """Returns False if using SQRT_N for grafting."""
854
+ return graft_type != GraftingType.SQRT_N
855
+
856
+ def quantized_dtype_for_momentum_buffers():
857
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
858
+
859
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
860
+ def quantized_dtype_for_diagonal_statistics_buffers():
861
+ return jnp.float32
862
+
863
+ # Preconditioner and statistics are both stores as int16 in this mode.
864
+ # We take out the diagonal to make quantization easier.
865
+ def quantized_dtype_for_second_moment_statistics_buffers():
866
+ return (
867
+ jnp.int16
868
+ if best_effort_memory_usage_reduction and batch_axis_name
869
+ else jnp.float32
870
+ )
871
+
872
+ # Preconditioner and statistics are both stores as int16 in this mode.
873
+ # We take out the diagonal to make quantization easier.
874
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
875
+ return (
876
+ jnp.int16
877
+ if best_effort_memory_usage_reduction and batch_axis_name
878
+ else jnp.float32
879
+ )
880
+
881
+ def _to_float(maybe_quantized):
882
+ if isinstance(maybe_quantized, QuantizedValue):
883
+ return maybe_quantized.to_float()
884
+ else:
885
+ return maybe_quantized
886
+
887
+ def _maybe_quantize_statistics(statistics_list):
888
+ return _maybe_quantize_matrices_with_dtype(
889
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers()
890
+ )
891
+
892
+ def _maybe_quantize_preconditioners(statistics_list):
893
+ return _maybe_quantize_matrices_with_dtype(
894
+ statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers()
895
+ )
896
+
897
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
898
+ if quantized_dtype != jnp.float32:
899
+ return [
900
+ QuantizedValue.from_float_value(
901
+ s, quantized_dtype, extract_diagonal=True
902
+ )
903
+ for s in statistics_list
904
+ ]
905
+ else:
906
+ return statistics_list
907
+
908
+ def _maybe_dequantize_preconditioners(preconditioner_list):
909
+ return _maybe_dequantize_matrices_with_dtype(
910
+ preconditioner_list,
911
+ quantized_dtype_for_second_moment_preconditioner_buffers(),
912
+ )
913
+
914
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
915
+ if quantized_dtype != jnp.float32:
916
+ return [s.to_float() for s in statistics_list]
917
+ else:
918
+ return statistics_list
919
+
920
+ def _quantize_diagonal_statistics(diagonal_statistics):
921
+ return QuantizedValue.from_float_value(
922
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers()
923
+ )
924
+
925
+ def _quantize_momentum(momentum_statistics):
926
+ return QuantizedValue.from_float_value(
927
+ momentum_statistics, quantized_dtype_for_momentum_buffers()
928
+ )
929
+
930
+ def sharded_init_fn(params):
931
+ """Returns optimizer state (for PJIT mode).
932
+
933
+ Args:
934
+ params: the parameters that should be updated.
935
+ """
936
+ params_flat, treedef = jax.tree_flatten(params)
937
+ # Find max size to pad to.
938
+ max_size = 0
939
+ for param in params_flat:
940
+ preconditioner = Preconditioner(
941
+ param, block_size, best_effort_shape_interpretation
942
+ )
943
+ if not _skip_preconditioning(param):
944
+ shapes = preconditioner.shapes_for_preconditioners()
945
+ sizes = [s[0] for s in shapes]
946
+ max_size = max(max(sizes), max_size)
947
+
948
+ padded_statistics = []
949
+ padded_preconditioners = []
950
+ local_stats_flat = []
951
+ exponents = []
952
+ for param in params_flat:
953
+ preconditioner = Preconditioner(
954
+ param, block_size, best_effort_shape_interpretation
955
+ )
956
+ shapes = preconditioner.shapes_for_preconditioners()
957
+ sizes = []
958
+
959
+ statistics = []
960
+ preconditioners = []
961
+ index_start = len(padded_statistics)
962
+ if not _skip_preconditioning(param):
963
+ sizes = [s[0] for s in shapes]
964
+ shapes = preconditioner.shapes_for_preconditioners()
965
+ statistics = [
966
+ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
967
+ for s in shapes
968
+ ]
969
+ preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
970
+ padded_statistics.extend(statistics)
971
+ padded_preconditioners.extend(preconditioners)
972
+ exponent = (
973
+ preconditioner.exponent_for_preconditioner()
974
+ if exponent_override == 0
975
+ else exponent_override
976
+ )
977
+ exponents.extend([exponent] * len(shapes))
978
+
979
+ diagonal_statistics = []
980
+ if _graft_type_has_diagonal_statistics():
981
+ diagonal_statistics = jnp.zeros_like(param)
982
+
983
+ diagonal_momentum = _quantize_momentum([])
984
+ momentum = _quantize_momentum(jnp.zeros_like(param))
985
+ if _graft_type_has_diagonal_momentum_states():
986
+ diagonal_momentum = _quantize_momentum((jnp.zeros_like(param)))
987
+
988
+ local_stats_flat.append(
989
+ LocalShardedParameterStats(
990
+ _quantize_diagonal_statistics(diagonal_statistics),
991
+ diagonal_momentum,
992
+ momentum,
993
+ init_training_metrics(len(sizes)),
994
+ index_start,
995
+ sizes,
996
+ )
997
+ )
998
+
999
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1000
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
1001
+ if max_size == 0:
1002
+ to_pad = num_devices_for_pjit
1003
+ max_size = block_size
1004
+ stat_dtype = jnp.float32
1005
+ else:
1006
+ stat_dtype = padded_statistics[0].dtype
1007
+ # Pad the statistics and preconditioner matrices to be a multiple of
1008
+ # num devices.
1009
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1010
+ # is split on.
1011
+ padded_statistics.extend(
1012
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1013
+ )
1014
+ padded_preconditioners.extend(
1015
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1016
+ )
1017
+ exponents.extend([1 for _ in range(to_pad)])
1018
+ global_stats = GlobalShardedParameterStats(
1019
+ jnp.stack(padded_statistics),
1020
+ jnp.stack(padded_preconditioners),
1021
+ jnp.stack(exponents),
1022
+ )
1023
+ return ShampooState(
1024
+ count=jnp.zeros([], jnp.int32),
1025
+ stats=ShardedShampooStats(global_stats, local_stats),
1026
+ )
1027
+
1028
+ def _max_statistics_size_from_params(params):
1029
+ max_size = 0
1030
+ for param in params:
1031
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1032
+ preconditioner = Preconditioner(
1033
+ param_clone, block_size, best_effort_shape_interpretation
1034
+ )
1035
+ if not _skip_preconditioning(param):
1036
+ shapes = preconditioner.shapes_for_preconditioners()
1037
+ sizes = [s[0] for s in shapes]
1038
+ max_size = max(max(sizes), max_size)
1039
+ return max_size
1040
+
1041
+ def _remove_leading_sharding_annotation(pspec):
1042
+ """Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
1043
+ # None and PSpec(None) are valid PSpecs.
1044
+ if pspec and len(pspec) > 1:
1045
+ return pjit.PartitionSpec(*pspec[1:])
1046
+ else:
1047
+ return []
1048
+
1049
+ def sharded_init_partition_spec_fn(
1050
+ params, params_partition_spec, partition_spec_for_statistics
1051
+ ):
1052
+ """Returns a parallel state tree with PartitionSpec associated with state.
1053
+
1054
+
1055
+ Args:
1056
+ params: A pytree with params.
1057
+ params_partition_spec: A pytree with PartitionSpec for params.
1058
+ partition_spec_for_statistics: PartitionSpec for the statistics.
1059
+ """
1060
+ # Parallel lists of spec, and params.
1061
+ param_pspec_flat, _ = jax.tree_flatten(
1062
+ params_partition_spec, is_leaf=lambda x: x is None
1063
+ )
1064
+ params_flat, treedef = jax.tree_flatten(params)
1065
+ assert param_pspec_flat
1066
+ assert params_flat
1067
+ # Step is replicated across cores.
1068
+ # None means cores.
1069
+ local_stats_flat = []
1070
+ num_statistics = 0
1071
+ for param, param_pspec in zip(params_flat, param_pspec_flat):
1072
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1073
+ preconditioner = Preconditioner(
1074
+ param_clone, block_size, best_effort_shape_interpretation
1075
+ )
1076
+ shapes = preconditioner.shapes_for_preconditioners()
1077
+ sizes = []
1078
+
1079
+ index_start = num_statistics
1080
+ if not _skip_preconditioning(param):
1081
+ sizes = [s[0] for s in shapes]
1082
+ shapes = preconditioner.shapes_for_preconditioners()
1083
+ num_statistics += len(shapes)
1084
+
1085
+ diagonal_statistics_pspec = []
1086
+ diagonal_statistics_scale_pspec = []
1087
+ diagonal_statistics_shape = []
1088
+ if _graft_type_has_diagonal_statistics():
1089
+ # Identically shaped param.
1090
+ diagonal_statistics_pspec = param_pspec
1091
+ diagonal_statistics_shape = list(param.shape)
1092
+ if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
1093
+ diagonal_statistics_scale_pspec = (
1094
+ _remove_leading_sharding_annotation(param_pspec)
1095
+ )
1096
+
1097
+ m1_pspec = []
1098
+ m1_shape = []
1099
+ m1_scale_pspec = []
1100
+ if _graft_type_has_diagonal_momentum_states():
1101
+ m1_pspec = param_pspec
1102
+ m1_shape = list(param.shape)
1103
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1104
+ m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
1105
+
1106
+ m2_pspec = param_pspec
1107
+ m2_scale_pspec = []
1108
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1109
+ m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
1110
+
1111
+ local_stats_flat.append(
1112
+ LocalShardedParameterStats(
1113
+ QuantizedValue(
1114
+ diagonal_statistics_pspec,
1115
+ [],
1116
+ diagonal_statistics_scale_pspec,
1117
+ quantized_dtype_for_diagonal_statistics_buffers(),
1118
+ False,
1119
+ diagonal_statistics_shape,
1120
+ ),
1121
+ QuantizedValue(
1122
+ m1_pspec,
1123
+ [],
1124
+ m1_scale_pspec,
1125
+ quantized_dtype_for_momentum_buffers(),
1126
+ False,
1127
+ m1_shape,
1128
+ ),
1129
+ QuantizedValue(
1130
+ m2_pspec,
1131
+ [],
1132
+ m2_scale_pspec,
1133
+ quantized_dtype_for_momentum_buffers(),
1134
+ False,
1135
+ list(param.shape),
1136
+ ),
1137
+ init_training_metrics_pspec(),
1138
+ index_start,
1139
+ sizes,
1140
+ )
1141
+ )
1142
+
1143
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1144
+ global_stats = GlobalShardedParameterStats(
1145
+ partition_spec_for_statistics,
1146
+ partition_spec_for_statistics,
1147
+ pjit.PartitionSpec(),
1148
+ )
1149
+ count_pspec = pjit.PartitionSpec()
1150
+ return ShampooState(
1151
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
1152
+ )
1153
+
1154
+ def sharded_init_shape_and_dtype_fn(params):
1155
+ """Returns a parallel state tree with shape, dtype associated with state.
1156
+
1157
+
1158
+ Args:
1159
+ params: A pytree with params.
1160
+ """
1161
+ # Parallel lists of spec, and params.
1162
+ params_flat, treedef = jax.tree_flatten(params)
1163
+ assert params_flat
1164
+ # Step is replicated across cores.
1165
+ # None means cores.
1166
+ local_stats_flat = []
1167
+ num_statistics = 0
1168
+ for param in params_flat:
1169
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1170
+ preconditioner = Preconditioner(
1171
+ param_clone, block_size, best_effort_shape_interpretation
1172
+ )
1173
+ shapes = preconditioner.shapes_for_preconditioners()
1174
+ sizes = []
1175
+
1176
+ index_start = num_statistics
1177
+ if not _skip_preconditioning(param):
1178
+ sizes = [s[0] for s in shapes]
1179
+ shapes = preconditioner.shapes_for_preconditioners()
1180
+ num_statistics += len(shapes)
1181
+
1182
+ diagonal_statistics_shape_and_dtype = []
1183
+ diagonal_statistics_scale_shape_and_dtype = []
1184
+ if _graft_type_has_diagonal_statistics():
1185
+ diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
1186
+ qdtype = quantized_dtype_for_diagonal_statistics_buffers()
1187
+ if qdtype != jnp.float32:
1188
+ diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
1189
+ diagonal_statistics_scale_shape_and_dtype = [
1190
+ list(param.shape)[1:],
1191
+ param.dtype,
1192
+ ]
1193
+
1194
+ qdtype = quantized_dtype_for_momentum_buffers()
1195
+ m1_shape_and_dtype = []
1196
+ m1_scale_shape_and_dtype = []
1197
+ if _graft_type_has_diagonal_momentum_states():
1198
+ m1_shape_and_dtype = [list(param.shape), qdtype]
1199
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1200
+ m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1201
+
1202
+ m2_shape_and_dtype = [list(param.shape), param.dtype]
1203
+ m2_scale_shape_and_dtype = []
1204
+ if qdtype != jnp.float32:
1205
+ m2_shape_and_dtype = [list(param.shape), qdtype]
1206
+ m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1207
+
1208
+ local_stats_flat.append(
1209
+ LocalShardedParameterStats(
1210
+ QuantizedValue(
1211
+ diagonal_statistics_shape_and_dtype,
1212
+ [],
1213
+ diagonal_statistics_scale_shape_and_dtype,
1214
+ quantized_dtype_for_diagonal_statistics_buffers(),
1215
+ False,
1216
+ list(param.shape),
1217
+ ),
1218
+ QuantizedValue(
1219
+ m1_shape_and_dtype,
1220
+ [],
1221
+ m1_scale_shape_and_dtype,
1222
+ quantized_dtype_for_momentum_buffers(),
1223
+ False,
1224
+ list(param.shape),
1225
+ ),
1226
+ QuantizedValue(
1227
+ m2_shape_and_dtype,
1228
+ [],
1229
+ m2_scale_shape_and_dtype,
1230
+ quantized_dtype_for_momentum_buffers(),
1231
+ False,
1232
+ list(param.shape),
1233
+ ),
1234
+ init_training_metrics_shapes(len(sizes)),
1235
+ index_start,
1236
+ sizes,
1237
+ )
1238
+ )
1239
+
1240
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1241
+ max_statistics_size = _max_statistics_size_from_params(params_flat)
1242
+ to_pad = -num_statistics % num_devices_for_pjit
1243
+ num_statistics += to_pad
1244
+ if num_statistics == 0:
1245
+ num_statistics = num_devices_for_pjit
1246
+ max_statistics_size = block_size
1247
+ statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
1248
+ global_stats = GlobalShardedParameterStats(
1249
+ [statistics_shape, jnp.float32],
1250
+ [statistics_shape, jnp.float32],
1251
+ [[num_statistics], jnp.int32],
1252
+ )
1253
+ return ShampooState(
1254
+ count=[[], jnp.float32],
1255
+ stats=ShardedShampooStats(global_stats, local_stats),
1256
+ )
1257
+
1258
+ def sharded_update_fn(grads, state, params):
1259
+ """Transform the input gradient and update all statistics in sharded mode.
1260
+
1261
+ Args:
1262
+ grads: the gradient tensors for the parameters.
1263
+ state: a named tuple containing the state of the optimizer
1264
+ params: the parameters that should be updated.
1265
+
1266
+ Returns:
1267
+ A tuple containing the new parameters and the new optimizer state.
1268
+ """
1269
+ params_flat, treedef = jax.tree_flatten(params)
1270
+ grads_flat = treedef.flatten_up_to(grads)
1271
+
1272
+ global_stats = state.stats.global_stats
1273
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
1274
+ stats_flat = [
1275
+ _convert_to_parameter_stats(global_stats, local_stat)
1276
+ for local_stat in local_stats_flat
1277
+ ]
1278
+ new_stats_flat = jax.tree_multimap(
1279
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
1280
+ grads_flat,
1281
+ stats_flat,
1282
+ params_flat,
1283
+ )
1284
+
1285
+ outputs = jax.tree_multimap(
1286
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
1287
+ grads_flat,
1288
+ new_stats_flat,
1289
+ params_flat,
1290
+ )
1291
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1292
+
1293
+ updates = jax.tree_unflatten(treedef, updates_flat)
1294
+ # Create new local_stats
1295
+ new_local_stats_flat = [
1296
+ _convert_from_parameter_stats(new_stat, local_stat)
1297
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
1298
+ ]
1299
+
1300
+ max_size = global_stats.statistics.shape[1]
1301
+ new_padded_statistics = []
1302
+ for stat in new_stats_flat:
1303
+ new_padded_statistics.extend(
1304
+ [pad_square_matrix(stat, max_size) for stat in stat.statistics]
1305
+ )
1306
+
1307
+ # Create global stats
1308
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
1309
+ # stack/pad can be obviated away.
1310
+ # Pad the statistics and preconditioner matrices to be a multiple of
1311
+ # num devices.
1312
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1313
+ # is split on.
1314
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
1315
+ new_padded_statistics.extend(
1316
+ [
1317
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
1318
+ for _ in range(to_pad)
1319
+ ]
1320
+ )
1321
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
1322
+ new_stacked_padded_statistics = pjit.with_sharding_constraint(
1323
+ new_stacked_padded_statistics, statistics_partition_spec
1324
+ )
1325
+
1326
+ def _internal_inverse_pth_root_all():
1327
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1328
+ new_stacked_padded_statistics,
1329
+ global_stats.exponents,
1330
+ statistics_partition_spec,
1331
+ )
1332
+ return preconditioners, errors
1333
+
1334
+ if preconditioning_compute_steps == 1:
1335
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
1336
+ else:
1337
+ # Passing statistics instead of preconditioners as they are similarly
1338
+ # shaped tensors. Note statistics will be ignored as we are passing in
1339
+ # a large init value for error.
1340
+ preconditioners_init = new_stacked_padded_statistics
1341
+ n = new_stacked_padded_statistics.shape[0]
1342
+ errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
1343
+ init_state = [preconditioners_init, errors_init]
1344
+ perform_step = state.count % preconditioning_compute_steps == 0
1345
+ new_preconditioners, errors = efficient_cond(
1346
+ perform_step, _internal_inverse_pth_root_all, init_state
1347
+ )
1348
+
1349
+ new_local_stats_flat = _add_error_into_local_stats(
1350
+ new_local_stats_flat, errors, inverse_failure_threshold
1351
+ )
1352
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
1353
+ errors = errors.reshape((-1, 1, 1))
1354
+ predicate = jnp.logical_or(
1355
+ jnp.isnan(errors), errors >= inverse_failure_threshold
1356
+ ).astype(new_preconditioners.dtype)
1357
+ # TODO(rohananil): Check for numerical instabilities.
1358
+ new_conditional_preconditioners = (
1359
+ predicate * global_stats.preconditioners
1360
+ + (1.0 - predicate) * new_preconditioners
1361
+ )
1362
+ new_global_stats = GlobalShardedParameterStats(
1363
+ new_stacked_padded_statistics,
1364
+ new_conditional_preconditioners,
1365
+ global_stats.exponents,
1366
+ )
1367
+ new_shampoo_state = ShampooState(
1368
+ count=state.count + 1,
1369
+ stats=ShardedShampooStats(new_global_stats, new_local_stats),
1370
+ )
1371
+ return updates, new_shampoo_state
1372
+
1373
+ def init_fn(params):
1374
+ """Initialise the optimiser's state."""
1375
+
1376
+ def _init(param):
1377
+ preconditioner = Preconditioner(
1378
+ param, block_size, best_effort_shape_interpretation
1379
+ )
1380
+ statistics = []
1381
+ preconditioners = []
1382
+ if not _skip_preconditioning(param):
1383
+ shapes = preconditioner.shapes_for_preconditioners()
1384
+ statistics = [
1385
+ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
1386
+ ]
1387
+ preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
1388
+
1389
+ diagonal_statistics = []
1390
+ if _graft_type_has_diagonal_statistics():
1391
+ diagonal_statistics = jnp.zeros_like(param)
1392
+
1393
+ diagonal_momentum = _quantize_momentum([])
1394
+ momentum = _quantize_momentum(jnp.zeros_like(param))
1395
+ if _graft_type_has_diagonal_momentum_states():
1396
+ diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))
1397
+
1398
+ return ParameterStats(
1399
+ _quantize_diagonal_statistics(diagonal_statistics),
1400
+ _maybe_quantize_statistics(statistics),
1401
+ _maybe_quantize_preconditioners(preconditioners),
1402
+ diagonal_momentum,
1403
+ momentum,
1404
+ init_training_metrics(len(statistics)),
1405
+ )
1406
+
1407
+ return ShampooState(
1408
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
1409
+ )
1410
+
1411
+ def _skip_preconditioning(param):
1412
+ return len(param.shape) < 1 or any(
1413
+ [s > skip_preconditioning_dim_size_gt for s in param.shape]
1414
+ )
1415
+
1416
+ def _compute_stats(grad, state, param, step):
1417
+ """Compute per-parameter statistics."""
1418
+ preconditioner = Preconditioner(
1419
+ param, block_size, best_effort_shape_interpretation
1420
+ )
1421
+ new_statistics = [[]] * len(state.statistics)
1422
+ w1 = beta2
1423
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1424
+ if not _skip_preconditioning(param):
1425
+
1426
+ def compute_updated_statistics():
1427
+ new_stats = preconditioner.statistics_from_grad(grad)
1428
+ new_stats_accumulators = []
1429
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
1430
+ new_stats_accumulators.append(
1431
+ w1 * _to_float(stat_accumulator) + w2 * stat
1432
+ )
1433
+ return _maybe_quantize_statistics(new_stats_accumulators)
1434
+
1435
+ if statistics_compute_steps > 1:
1436
+ perform_step = step % statistics_compute_steps == 0
1437
+ init_state = state.statistics
1438
+ new_statistics = list(
1439
+ efficient_cond(perform_step, compute_updated_statistics, init_state)
1440
+ )
1441
+ else:
1442
+ new_statistics = compute_updated_statistics()
1443
+ return ParameterStats(
1444
+ state.diagonal_statistics,
1445
+ new_statistics,
1446
+ state.preconditioners,
1447
+ state.diagonal_momentum,
1448
+ state.momentum,
1449
+ state.training_metrics,
1450
+ )
1451
+
1452
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1453
+ mi_pth_root = functools.partial(
1454
+ matrix_inverse_pth_root, ridge_epsilon=matrix_epsilon, precision=precision
1455
+ )
1456
+ return jax.vmap(mi_pth_root)(xs, ps)
1457
+
1458
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1459
+ def _quantized_to_float(qx, qd, qb):
1460
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1461
+ return qv.to_float()
1462
+
1463
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1464
+ v = _quantized_to_float(qx, qd, qb)
1465
+ preconditioner, error = matrix_inverse_pth_root(
1466
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision
1467
+ )
1468
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1469
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1470
+
1471
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1472
+
1473
+ def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
1474
+ # Partition the concatenated statistics matrix across all cores.
1475
+ pspec_for_partition = preconditioner_partition_spec
1476
+ partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
1477
+ if preconditioner_partition_spec:
1478
+ partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
1479
+ else:
1480
+ partitioned_ps_spec = None
1481
+ partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)
1482
+ # Run matrix inverse pth root on each shard.
1483
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1484
+ partitioned_xs, partitioned_ps
1485
+ )
1486
+ # Reshard output to have the same PSpec as input. This is required to avoid
1487
+ # vmap seeing the full set of statistics.
1488
+ partitioned_preconditioners = pjit.with_sharding_constraint(
1489
+ partitioned_preconditioners, pspec_for_partition
1490
+ )
1491
+ # Recombine the outputs at each core.
1492
+ preconditioners = pjit.with_sharding_constraint(
1493
+ partitioned_preconditioners, statistics_partition_spec
1494
+ )
1495
+ errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec())
1496
+ return preconditioners, errors
1497
+
1498
+ def _pmap_compute_preconditioners(
1499
+ states,
1500
+ step,
1501
+ statistics,
1502
+ num_statistics_per_state,
1503
+ original_shapes,
1504
+ exponents,
1505
+ max_size,
1506
+ prev_preconditioners,
1507
+ ):
1508
+ """Computes preconditioners for given statistics in states in PMAP mode.
1509
+
1510
+ Args:
1511
+ states: A list of optimizer states.
1512
+ step: Current step number
1513
+ statistics: A list of statistics for all variables (for every dim)
1514
+ num_statistics_per_state: Number of statistis per state to reconstruct
1515
+ output states.
1516
+ original_shapes: A list of shapes of the statistics.
1517
+ exponents: Exponent power to use for inverse-pth roots.
1518
+ max_size: Maximum dim of the statistics to pad.
1519
+ prev_preconditioners: Previously available preconditioner.
1520
+
1521
+ Returns:
1522
+ New optimizer states after computing the preconditioner.
1523
+ """
1524
+ num_devices = lax.psum(1, batch_axis_name)
1525
+ num_statistics = len(statistics)
1526
+ # Pad statistics and exponents to next multiple of num_devices.
1527
+ packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
1528
+ to_pad = -num_statistics % num_devices
1529
+ packed_statistics.extend(
1530
+ [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
1531
+ )
1532
+ exponents.extend([1 for _ in range(to_pad)])
1533
+
1534
+ if not packed_statistics:
1535
+ return states
1536
+
1537
+ all_statistics = batch(packed_statistics, num_devices)
1538
+ all_exponents = batch(exponents, num_devices)
1539
+
1540
+ def _internal_inverse_pth_root_all():
1541
+ current_replica = lax.axis_index(batch_axis_name)
1542
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1543
+ all_statistics[current_replica], all_exponents[current_replica]
1544
+ )
1545
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1546
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1547
+ preconditioners_flat = unbatch(preconditioners)
1548
+ errors_flat = unbatch(errors)
1549
+ return preconditioners_flat, errors_flat
1550
+
1551
+ if preconditioning_compute_steps == 1:
1552
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1553
+ else:
1554
+ # Passing statistics instead of preconditioners as they are similarly
1555
+ # shaped tensors. Note statistics will be ignored as we are passing in
1556
+ # a large init value for error.
1557
+ preconditioners_init = packed_statistics
1558
+ errors_init = [inverse_failure_threshold] * len(packed_statistics)
1559
+ init_state = [preconditioners_init, errors_init]
1560
+ perform_step = step % preconditioning_compute_steps == 0
1561
+ preconditioners_flat, errors_flat = efficient_cond(
1562
+ perform_step, _internal_inverse_pth_root_all, init_state
1563
+ )
1564
+
1565
+ def _skip(error):
1566
+ condition = jnp.logical_or(
1567
+ jnp.isnan(error), error >= inverse_failure_threshold
1568
+ )
1569
+ return condition.astype(error.dtype)
1570
+
1571
+ def _select_preconditioner(error, new_p, old_p):
1572
+ return lax.cond(
1573
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1574
+ )
1575
+
1576
+ new_preconditioners_flat = []
1577
+ new_errors_flat = []
1578
+ for p, shape, prev_p, error in zip(
1579
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1580
+ ):
1581
+ new_preconditioners_flat.append(
1582
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1583
+ )
1584
+ new_errors_flat.append(error)
1585
+
1586
+ assert len(states) == len(num_statistics_per_state)
1587
+ assert len(new_preconditioners_flat) == num_statistics
1588
+ assert len(new_errors_flat) == num_statistics
1589
+
1590
+ # Add back empty preconditioners so we that we can set the optimizer state.
1591
+ preconditioners_for_states = []
1592
+ idx = 0
1593
+ errors_for_states = []
1594
+ for num_statistics, state in zip(num_statistics_per_state, states):
1595
+ if num_statistics == 0:
1596
+ preconditioners_for_states.append([])
1597
+ errors_for_states.append(jnp.array(0, jnp.float32))
1598
+ else:
1599
+ preconditioners_for_state = new_preconditioners_flat[
1600
+ idx : idx + num_statistics
1601
+ ]
1602
+ assert len(state.statistics) == len(preconditioners_for_state)
1603
+ preconditioners_for_states.append(preconditioners_for_state)
1604
+
1605
+ errors_for_state = jnp.stack(
1606
+ new_errors_flat[idx : idx + num_statistics]
1607
+ )
1608
+ assert len(state.statistics) == len(errors_for_state)
1609
+ errors_for_states.append(errors_for_state)
1610
+
1611
+ idx += num_statistics
1612
+ new_states = []
1613
+ for state, new_preconditioners, new_errors in zip(
1614
+ states, preconditioners_for_states, errors_for_states
1615
+ ):
1616
+ if state.statistics:
1617
+ new_errors = jnp.where(
1618
+ jnp.logical_and(
1619
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1620
+ ),
1621
+ new_errors,
1622
+ state.training_metrics.inverse_pth_root_errors,
1623
+ )
1624
+ new_training_metrics = TrainingMetrics(new_errors)
1625
+ new_states.append(
1626
+ ParameterStats(
1627
+ state.diagonal_statistics,
1628
+ state.statistics,
1629
+ new_preconditioners,
1630
+ state.diagonal_momentum,
1631
+ state.momentum,
1632
+ new_training_metrics,
1633
+ )
1634
+ )
1635
+
1636
+ return new_states
1637
+
1638
+ def _pmap_quantized_compute_preconditioners(
1639
+ states,
1640
+ step,
1641
+ statistics,
1642
+ num_statistics_per_state,
1643
+ original_shapes,
1644
+ exponents,
1645
+ max_size,
1646
+ prev_preconditioners,
1647
+ ):
1648
+ """Computes preconditioners for given statistics in states in PMAP mode.
1649
+
1650
+ For quantization, each statistic is represented by three values:
1651
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1652
+ without ever recreating the original matrix in f32.
1653
+
1654
+ Args:
1655
+ states: A list of optimizer states.
1656
+ step: Current step number
1657
+ statistics: A list of statistics for all variables (for every dim)
1658
+ num_statistics_per_state: Number of statistis per state to reconstruct
1659
+ output states.
1660
+ original_shapes: A list of shapes of the statistics.
1661
+ exponents: Exponent power to use for inverse-pth roots.
1662
+ max_size: Maximum dim of the statistics to pad.
1663
+ prev_preconditioners: Previously available preconditioner.
1664
+
1665
+ Returns:
1666
+ New optimizer states after computing the preconditioner.
1667
+ """
1668
+ num_devices = lax.psum(1, batch_axis_name)
1669
+ num_statistics = len(statistics)
1670
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1671
+ # Complexity here is around: shapes needing be statically shaped,
1672
+ # our custom quantization type requires a different type of packing.
1673
+
1674
+ # Parallel tensors:
1675
+ # quantized [dxd]
1676
+ # diagonals [d] f32
1677
+ # bucket_sizes [d] f32
1678
+ packed_quantized_statistics = [
1679
+ pad_square_matrix(stat.quantized, max_size) for stat in statistics
1680
+ ]
1681
+ packed_quantized_diagonals = [
1682
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1683
+ ]
1684
+ packed_quantized_bucket_sizes = [
1685
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1686
+ ]
1687
+
1688
+ to_pad = -num_statistics % num_devices
1689
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1690
+ quantized_eye = QuantizedValue.from_float_value(
1691
+ padded_eye, quantized_dtype, True
1692
+ )
1693
+ packed_quantized_statistics.extend(
1694
+ [quantized_eye.quantized for _ in range(to_pad)]
1695
+ )
1696
+ packed_quantized_diagonals.extend(
1697
+ [quantized_eye.diagonal for _ in range(to_pad)]
1698
+ )
1699
+ packed_quantized_bucket_sizes.extend(
1700
+ [quantized_eye.bucket_size for _ in range(to_pad)]
1701
+ )
1702
+ exponents.extend([1 for _ in range(to_pad)])
1703
+
1704
+ if not packed_quantized_statistics:
1705
+ return states
1706
+
1707
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1708
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1709
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices)
1710
+ all_exponents = batch(exponents, num_devices)
1711
+
1712
+ def _internal_inverse_pth_root_all():
1713
+ current_replica = lax.axis_index(batch_axis_name)
1714
+ (
1715
+ quantized_preconditioners,
1716
+ quantized_diagonals,
1717
+ quantized_bucket_sizes,
1718
+ errors,
1719
+ ) = _quantized_matrix_inverse_pth_root_vmap(
1720
+ all_quantized_statistics[current_replica],
1721
+ all_quantized_diagonals[current_replica],
1722
+ all_quantized_bucket_sizes[current_replica],
1723
+ all_exponents[current_replica],
1724
+ )
1725
+ quantized_preconditioners = jax.lax.all_gather(
1726
+ quantized_preconditioners, batch_axis_name
1727
+ )
1728
+ quantized_diagonals = jax.lax.all_gather(
1729
+ quantized_diagonals, batch_axis_name
1730
+ )
1731
+ quantized_bucket_sizes = jax.lax.all_gather(
1732
+ quantized_bucket_sizes, batch_axis_name
1733
+ )
1734
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1735
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1736
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1737
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1738
+ errors_flat = unbatch(errors)
1739
+ return (
1740
+ quantized_preconditioners_flat,
1741
+ quantized_diagonals_flat,
1742
+ quantized_bucket_sizes_flat,
1743
+ errors_flat,
1744
+ )
1745
+
1746
+ if preconditioning_compute_steps == 1:
1747
+ (
1748
+ quantized_preconditioners_flat,
1749
+ quantized_diagonals_flat,
1750
+ quantized_bucket_sizes_flat,
1751
+ errors_flat,
1752
+ ) = _internal_inverse_pth_root_all()
1753
+ else:
1754
+ # Passing statistics instead of preconditioners as they are similarly
1755
+ # shaped tensors. Note statistics will be ignored as we are passing in
1756
+ # a large init value for error.
1757
+ quantized_preconditioners_init = packed_quantized_statistics
1758
+ quantized_diagonals_init = packed_quantized_diagonals
1759
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1760
+ errors_init = [inverse_failure_threshold] * len(
1761
+ quantized_preconditioners_init
1762
+ )
1763
+ init_state = [
1764
+ quantized_preconditioners_init,
1765
+ quantized_diagonals_init,
1766
+ quantized_bucket_sizes_init,
1767
+ errors_init,
1768
+ ]
1769
+ perform_step = step % preconditioning_compute_steps == 0
1770
+ (
1771
+ quantized_preconditioners_flat,
1772
+ quantized_diagonals_flat,
1773
+ quantized_bucket_sizes_flat,
1774
+ errors_flat,
1775
+ ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state)
1776
+
1777
+ def _skip(error):
1778
+ condition = jnp.logical_or(
1779
+ jnp.isnan(error), error >= inverse_failure_threshold
1780
+ )
1781
+ return condition.astype(error.dtype)
1782
+
1783
+ def _select_preconditioner(error, new_p, old_p):
1784
+ return lax.cond(
1785
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1786
+ )
1787
+
1788
+ new_quantized_preconditioners_flat = []
1789
+ new_quantized_diagonals_flat = []
1790
+ new_quantized_bucket_sizes_flat = []
1791
+ new_errors_flat = []
1792
+ for p, d, b, shape, prev_p, error in zip(
1793
+ quantized_preconditioners_flat,
1794
+ quantized_diagonals_flat,
1795
+ quantized_bucket_sizes_flat,
1796
+ original_shapes,
1797
+ prev_preconditioners,
1798
+ errors_flat,
1799
+ ):
1800
+ new_quantized_preconditioners_flat.append(
1801
+ _select_preconditioner(
1802
+ error, p[: shape[0], : shape[1]], prev_p.quantized
1803
+ )
1804
+ )
1805
+ new_quantized_diagonals_flat.append(
1806
+ _select_preconditioner(error, d[: shape[0]], prev_p.diagonal)
1807
+ )
1808
+ new_quantized_bucket_sizes_flat.append(
1809
+ _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
1810
+ )
1811
+ new_errors_flat.append(error)
1812
+
1813
+ assert len(states) == len(num_statistics_per_state)
1814
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1815
+ assert len(new_quantized_diagonals_flat) == num_statistics
1816
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1817
+
1818
+ # Add back empty preconditioners so we that we can set the optimizer state.
1819
+ preconditioners_for_states = []
1820
+ errors_for_states = []
1821
+ idx = 0
1822
+ for num_statistics, state in zip(num_statistics_per_state, states):
1823
+ if num_statistics == 0:
1824
+ preconditioners_for_states.append([])
1825
+ errors_for_states.append(jnp.array(0, jnp.float32))
1826
+ else:
1827
+ quantized_preconditioners_for_state = (
1828
+ new_quantized_preconditioners_flat[idx : idx + num_statistics]
1829
+ )
1830
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1831
+ idx : idx + num_statistics
1832
+ ]
1833
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1834
+ idx : idx + num_statistics
1835
+ ]
1836
+ errors_for_state = jnp.stack(
1837
+ new_errors_flat[idx : idx + num_statistics]
1838
+ )
1839
+
1840
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1841
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1842
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1843
+ assert len(state.statistics) == len(errors_for_state)
1844
+
1845
+ quantized_preconditioners = []
1846
+ for qv, qd, qb in zip(
1847
+ quantized_preconditioners_for_state,
1848
+ quantized_diagonals_for_state,
1849
+ quantized_bucket_sizes_for_state,
1850
+ ):
1851
+ quantized_preconditioners.append(
1852
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
1853
+ )
1854
+ preconditioners_for_states.append(quantized_preconditioners)
1855
+ errors_for_states.append(errors_for_state)
1856
+ idx += num_statistics
1857
+ new_states = []
1858
+ for state, new_preconditioners, new_errors in zip(
1859
+ states, preconditioners_for_states, errors_for_states
1860
+ ):
1861
+ if state.statistics:
1862
+ new_errors = jnp.where(
1863
+ jnp.logical_and(
1864
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1865
+ ),
1866
+ new_errors,
1867
+ state.training_metrics.inverse_pth_root_errors,
1868
+ )
1869
+ new_training_metrics = TrainingMetrics(new_errors)
1870
+ new_states.append(
1871
+ ParameterStats(
1872
+ state.diagonal_statistics,
1873
+ state.statistics,
1874
+ new_preconditioners,
1875
+ state.diagonal_momentum,
1876
+ state.momentum,
1877
+ new_training_metrics,
1878
+ )
1879
+ )
1880
+
1881
+ return new_states
1882
+
1883
+ def _pjit_compute_preconditioners(
1884
+ states,
1885
+ step,
1886
+ statistics,
1887
+ num_statistics_per_state,
1888
+ original_shapes,
1889
+ exponents,
1890
+ max_size,
1891
+ prev_preconditioners,
1892
+ ):
1893
+ """Computes preconditioners for given statistics in states in PJIT mode.
1894
+
1895
+ Args:
1896
+ states: A list of optimizer states.
1897
+ step: Current step number
1898
+ statistics: A list of statistics for all variables (for every dim)
1899
+ num_statistics_per_state: Number of statistis per state to reconstruct
1900
+ output states.
1901
+ original_shapes: A list of shapes of the statistics.
1902
+ exponents: Exponent power to use for inverse-pth roots.
1903
+ max_size: Maximum dim of the statistics to pad.
1904
+ prev_preconditioners: Previously available preconditioner.
1905
+
1906
+ Returns:
1907
+ New optimizer states after computing the preconditioner.
1908
+ """
1909
+ num_statistics = len(statistics)
1910
+ to_pad = -num_statistics % num_devices_for_pjit
1911
+ padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
1912
+ padded_statistics.extend(
1913
+ [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
1914
+ )
1915
+ exponents.extend([1 for _ in range(to_pad)])
1916
+ all_statistics = jnp.stack(padded_statistics)
1917
+ all_exponents = jnp.stack(exponents)
1918
+
1919
+ def _internal_inverse_pth_root_all():
1920
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1921
+ all_statistics, all_exponents
1922
+ )
1923
+ b1 = preconditioners.shape[0]
1924
+
1925
+ def split(batched_values):
1926
+ return [
1927
+ jnp.squeeze(v)
1928
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1929
+ ]
1930
+
1931
+ return split(preconditioners), split(errors)
1932
+
1933
+ if preconditioning_compute_steps == 1:
1934
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1935
+ else:
1936
+ # Passing statistics instead of preconditioners as they are similarly
1937
+ # shaped tensors. Note statistics will be ignored as we are passing in
1938
+ # a large init value for error.
1939
+ preconditioners_init = padded_statistics
1940
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1941
+ init_state = [preconditioners_init, errors_init]
1942
+ perform_step = step % preconditioning_compute_steps == 0
1943
+ preconditioners_flat, errors_flat = efficient_cond(
1944
+ perform_step, _internal_inverse_pth_root_all, init_state
1945
+ )
1946
+
1947
+ def _skip(error):
1948
+ condition = jnp.logical_or(
1949
+ jnp.isnan(error), error >= inverse_failure_threshold
1950
+ )
1951
+ return condition.astype(error.dtype)
1952
+
1953
+ def _select_preconditioner(error, new_p, old_p):
1954
+ return lax.cond(
1955
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1956
+ )
1957
+
1958
+ new_preconditioners_flat = []
1959
+ new_errors_flat = []
1960
+ for p, shape, prev_p, error in zip(
1961
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1962
+ ):
1963
+ new_preconditioners_flat.append(
1964
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1965
+ )
1966
+ new_errors_flat.append(error)
1967
+
1968
+ assert len(states) == len(num_statistics_per_state)
1969
+ assert len(new_preconditioners_flat) == num_statistics
1970
+
1971
+ # Add back empty preconditioners so we that we can set the optimizer state.
1972
+ preconditioners_for_states = []
1973
+ errors_for_states = []
1974
+ idx = 0
1975
+ for num_statistics, state in zip(num_statistics_per_state, states):
1976
+ if num_statistics == 0:
1977
+ preconditioners_for_states.append([])
1978
+ errors_for_states.append(jnp.array(0, jnp.float32))
1979
+ else:
1980
+ preconditioners_for_state = new_preconditioners_flat[
1981
+ idx : idx + num_statistics
1982
+ ]
1983
+ assert len(state.statistics) == len(preconditioners_for_state)
1984
+ preconditioners_for_states.append(preconditioners_for_state)
1985
+
1986
+ errors_for_state = jnp.stack(
1987
+ new_errors_flat[idx : idx + num_statistics]
1988
+ )
1989
+ assert len(state.statistics) == len(errors_for_state)
1990
+ errors_for_states.append(errors_for_state)
1991
+ idx += num_statistics
1992
+
1993
+ new_states = []
1994
+ for state, new_preconditioners, new_errors in zip(
1995
+ states, preconditioners_for_states, errors_for_states
1996
+ ):
1997
+ if state.statistics:
1998
+ new_errors = jnp.where(
1999
+ jnp.logical_and(
2000
+ new_errors > 0.0, new_errors != inverse_failure_threshold
2001
+ ),
2002
+ new_errors,
2003
+ state.training_metrics.inverse_pth_root_errors,
2004
+ )
2005
+ new_training_metrics = TrainingMetrics(new_errors)
2006
+ new_states.append(
2007
+ ParameterStats(
2008
+ state.diagonal_statistics,
2009
+ state.statistics,
2010
+ new_preconditioners,
2011
+ state.diagonal_momentum,
2012
+ state.momentum,
2013
+ new_training_metrics,
2014
+ )
2015
+ )
2016
+
2017
+ return new_states
2018
+
2019
+ def _compute_preconditioners(states, params, step):
2020
+ """Computes preconditioners for given statistics in states.
2021
+
2022
+ Args:
2023
+ states: A list of optimizer states.
2024
+ params: A list of params.
2025
+ step: Current step number
2026
+
2027
+ Returns:
2028
+ New optimizer states after computing the preconditioner.
2029
+ """
2030
+ statistics = []
2031
+ num_statistics_per_state = []
2032
+ original_shapes = []
2033
+ exponents = []
2034
+ max_size = 0
2035
+ prev_preconditioners = []
2036
+
2037
+ for state, param in zip(states, params):
2038
+ num_statistics = len(state.statistics)
2039
+ num_statistics_per_state.append(num_statistics)
2040
+ original_shapes_for_state = []
2041
+ if num_statistics > 0:
2042
+ preconditioner = Preconditioner(
2043
+ param, block_size, best_effort_shape_interpretation
2044
+ )
2045
+ for statistic in state.statistics:
2046
+ exponents.append(
2047
+ preconditioner.exponent_for_preconditioner()
2048
+ if exponent_override == 0
2049
+ else exponent_override
2050
+ )
2051
+ original_shapes_for_state.append(statistic.shape)
2052
+ max_size = max(max_size, statistic.shape[0])
2053
+
2054
+ statistics.extend(state.statistics)
2055
+ prev_preconditioners.extend(state.preconditioners)
2056
+ original_shapes.extend(original_shapes_for_state)
2057
+
2058
+ if batch_axis_name:
2059
+ # Quantization is only enabled if batch_axis_name is not set.
2060
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
2061
+
2062
+ if quantized_dtype == jnp.float32:
2063
+ return _pmap_compute_preconditioners(
2064
+ states,
2065
+ step,
2066
+ statistics,
2067
+ num_statistics_per_state,
2068
+ original_shapes,
2069
+ exponents,
2070
+ max_size,
2071
+ prev_preconditioners,
2072
+ )
2073
+ else:
2074
+ return _pmap_quantized_compute_preconditioners(
2075
+ states,
2076
+ step,
2077
+ statistics,
2078
+ num_statistics_per_state,
2079
+ original_shapes,
2080
+ exponents,
2081
+ max_size,
2082
+ prev_preconditioners,
2083
+ )
2084
+
2085
+ else:
2086
+ return _pjit_compute_preconditioners(
2087
+ states,
2088
+ step,
2089
+ statistics,
2090
+ num_statistics_per_state,
2091
+ original_shapes,
2092
+ exponents,
2093
+ max_size,
2094
+ prev_preconditioners,
2095
+ )
2096
+
2097
+ def _transform_grad(grad, state, param, step):
2098
+ """Transform per-parameter gradients."""
2099
+ preconditioner = Preconditioner(
2100
+ param, block_size, best_effort_shape_interpretation
2101
+ )
2102
+ sgd_update = grad
2103
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
2104
+ if (
2105
+ graft_type == GraftingType.ADAGRAD
2106
+ or graft_type == GraftingType.ADAGRAD_NORMALIZED
2107
+ ):
2108
+
2109
+ scaled_grad = grad
2110
+ if graft_type == GraftingType.ADAGRAD_NORMALIZED:
2111
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
2112
+
2113
+ new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
2114
+ scaled_grad
2115
+ )
2116
+ adagrad_update = scaled_grad / (
2117
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2118
+ )
2119
+ grafting_update = adagrad_update
2120
+ elif (
2121
+ graft_type == GraftingType.RMSPROP
2122
+ or graft_type == GraftingType.RMSPROP_NORMALIZED
2123
+ ):
2124
+
2125
+ scaled_grad = grad
2126
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
2127
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
2128
+
2129
+ w1 = beta2
2130
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
2131
+
2132
+ new_diagonal_statistics = (
2133
+ w1 * state.diagonal_statistics.to_float() + w2 * jnp.square(scaled_grad)
2134
+ )
2135
+ rmsprop_update = scaled_grad / (
2136
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2137
+ )
2138
+
2139
+ if clip_by_scaled_gradient_norm:
2140
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
2141
+ jnp.sqrt(float(rmsprop_update.size))
2142
+ )
2143
+ clipping_denom = jnp.maximum(
2144
+ 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
2145
+ )
2146
+ rmsprop_update /= clipping_denom
2147
+
2148
+ grafting_update = rmsprop_update
2149
+ elif graft_type == GraftingType.SGD:
2150
+ grafting_update = sgd_update
2151
+ else:
2152
+ grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update)
2153
+
2154
+ precond_grad = grad
2155
+ if not _skip_preconditioning(param):
2156
+ precond_grad = preconditioner.preconditioned_grad(
2157
+ precond_grad, _maybe_dequantize_preconditioners(state.preconditioners)
2158
+ )
2159
+ else:
2160
+ precond_grad = grafting_update
2161
+
2162
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
2163
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
2164
+
2165
+ multiplier = grafting_update_norm / (precond_grad_norm + 1e-16)
2166
+ shampoo_update = precond_grad * multiplier
2167
+
2168
+ shampoo_update_with_wd = shampoo_update
2169
+ grafting_update_with_wd = grafting_update
2170
+ if weight_decay != 0:
2171
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
2172
+ grafting_update_with_wd = grafting_update + weight_decay * param
2173
+
2174
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
2175
+
2176
+ shampoo_update_with_wd_momentum = (
2177
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
2178
+ )
2179
+
2180
+ if _graft_type_has_diagonal_momentum_states():
2181
+ grafting_update_with_wd_momentum = (
2182
+ state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
2183
+ )
2184
+ else:
2185
+ # Share the momentum buffer
2186
+ grafting_update_with_wd_momentum = (
2187
+ state.momentum.to_float() * beta1 + w * grafting_update_with_wd
2188
+ )
2189
+
2190
+ run_shampoo = (step >= start_preconditioning_step).astype(
2191
+ grafting_update_with_wd_momentum.dtype
2192
+ )
2193
+
2194
+ momentum_update = (
2195
+ run_shampoo * shampoo_update_with_wd_momentum
2196
+ + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
2197
+ )
2198
+
2199
+ wd_update = (
2200
+ run_shampoo * shampoo_update_with_wd
2201
+ + (1.0 - run_shampoo) * grafting_update_with_wd
2202
+ )
2203
+
2204
+ nesterov_momentum_update = momentum_update
2205
+ if nesterov:
2206
+ nesterov_momentum_update = w * wd_update + beta1 * momentum_update
2207
+
2208
+ lr = learning_rate
2209
+ if callable(learning_rate):
2210
+ lr = learning_rate(step)
2211
+ transformed_update = -1.0 * lr * nesterov_momentum_update
2212
+
2213
+ new_diagonal_momentum = grafting_update_with_wd_momentum
2214
+ new_momentum = shampoo_update_with_wd_momentum
2215
+ if not _graft_type_has_diagonal_momentum_states():
2216
+ new_diagonal_momentum = []
2217
+ new_momentum = momentum_update
2218
+
2219
+ param_stats = ParameterStats(
2220
+ _quantize_diagonal_statistics(new_diagonal_statistics),
2221
+ state.statistics,
2222
+ state.preconditioners,
2223
+ _quantize_momentum(new_diagonal_momentum),
2224
+ _quantize_momentum(new_momentum),
2225
+ state.training_metrics,
2226
+ )
2227
+
2228
+ return transformed_update, param_stats
2229
+
2230
+ def update_fn(grads, state, params):
2231
+ """Transform the input gradient and update all statistics.
2232
+
2233
+ Args:
2234
+ grads: the gradient tensors for the parameters.
2235
+ state: a named tuple containing the state of the optimizer
2236
+ params: the parameters that should be updated.
2237
+
2238
+ Returns:
2239
+ A tuple containing the new parameters and the new optimizer state.
2240
+ """
2241
+ params_flat, treedef = jax.tree_flatten(params)
2242
+ stats_flat = treedef.flatten_up_to(state.stats)
2243
+ grads_flat = treedef.flatten_up_to(grads)
2244
+
2245
+ new_stats_flat = jax.tree_multimap(
2246
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
2247
+ grads_flat,
2248
+ stats_flat,
2249
+ params_flat,
2250
+ )
2251
+ new_stats_flat = _compute_preconditioners(
2252
+ new_stats_flat, params_flat, state.count
2253
+ )
2254
+ outputs = jax.tree_multimap(
2255
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
2256
+ grads_flat,
2257
+ new_stats_flat,
2258
+ params_flat,
2259
+ )
2260
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
2261
+
2262
+ updates = jax.tree_unflatten(treedef, updates_flat)
2263
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
2264
+
2265
+ new_state = ShampooState(count=state.count + 1, stats=new_stats)
2266
+ return updates, new_state
2267
+
2268
+ if shard_optimizer_states:
2269
+ # Hijacks the init_fn signature so we can return an OptState with
2270
+ # appropriate init_fns.
2271
+ def _init_fns(unused_params):
2272
+ return InitFnState(
2273
+ init_fn=sharded_init_fn,
2274
+ pspec_fn=sharded_init_partition_spec_fn,
2275
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
2276
+ )
2277
+
2278
+ return optax.GradientTransformation(_init_fns, sharded_update_fn)
2279
+ else:
2280
+ return optax.GradientTransformation(init_fn, update_fn)
encode_dataset.ipynb ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Pre-encoding a dataset for DALLE·mini"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "ba7b31e6",
14
+ "metadata": {},
15
+ "source": [
16
+ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
+ "\n",
18
+ "Adapt it to your own dataset and image encoder.\n",
19
+ "\n",
20
+ "At the end you should have a dataset of pairs:\n",
21
+ "* a caption defined as a string\n",
22
+ "* an encoded image defined as a list of int."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "3b59489e",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "from tqdm.notebook import tqdm\n",
33
+ "\n",
34
+ "import torchvision.transforms as T\n",
35
+ "\n",
36
+ "import webdataset as wds\n",
37
+ "\n",
38
+ "import jax\n",
39
+ "import braceexpand\n",
40
+ "from pathlib import Path"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "c7c4c1e6",
46
+ "metadata": {},
47
+ "source": [
48
+ "## Configuration Parameters"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "1265dbfe",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
59
+ "encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
60
+ "\n",
61
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
62
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
63
+ " \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
64
+ ")\n",
65
+ "\n",
66
+ "# good defaults for a TPU v3-8\n",
67
+ "batch_size = 128 # Per device\n",
68
+ "num_workers = 8 # For parallel processing\n",
69
+ "total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
70
+ "save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "['XXX/shard-0000.tar',\n",
83
+ " 'XXX/shard-0001.tar',\n",
84
+ " 'XXX/shard-0002.tar',\n",
85
+ " 'XXX/shard-0003.tar',\n",
86
+ " 'XXX/shard-0004.tar',\n",
87
+ " 'XXX/shard-0005.tar',\n",
88
+ " 'XXX/shard-0006.tar',\n",
89
+ " 'XXX/shard-0007.tar',\n",
90
+ " 'XXX/shard-0008.tar']"
91
+ ]
92
+ },
93
+ "execution_count": 5,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "shards = list(\n",
100
+ " braceexpand.braceexpand(shards)\n",
101
+ ") # better display for tqdm with known length"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "id": "75dba8e2",
107
+ "metadata": {},
108
+ "source": [
109
+ "## Load data"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "a1e8fb95",
115
+ "metadata": {},
116
+ "source": [
117
+ "We load data using `webdataset`."
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "id": "9ef5de9e",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "ds = (\n",
128
+ " wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
129
+ " .decode(\"rgb\", handler=wds.warn_and_continue)\n",
130
+ " .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
131
+ " .batched(total_bs) # load in batch per worker (faster)\n",
132
+ ")"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "id": "90981824",
138
+ "metadata": {},
139
+ "source": [
140
+ "Note:\n",
141
+ "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
142
+ "* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
143
+ "* you can also filter out some items using `select`."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "129c377d",
149
+ "metadata": {},
150
+ "source": [
151
+ "We can now inspect our data."
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "8cac98cb",
158
+ "metadata": {
159
+ "scrolled": true
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "%%time\n",
164
+ "images, captions = next(iter(ds))"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "cd268fbf",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "images.shape"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "5acfc4d8",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "captions[:10]"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "id": "c24693c0",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "T.ToPILImage()(images[0].permute(2, 0, 1))"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "3059ffb1",
200
+ "metadata": {},
201
+ "source": [
202
+ "Finally we create our dataloader."
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "c227c551",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "dl = (\n",
213
+ " wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
214
+ ") # avoid partial batch at the end of each worker"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "id": "a354472b",
220
+ "metadata": {},
221
+ "source": [
222
+ "## Image encoder\n",
223
+ "\n",
224
+ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "47a8b818",
231
+ "metadata": {
232
+ "scrolled": true
233
+ },
234
+ "outputs": [],
235
+ "source": [
236
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
237
+ "from flax.jax_utils import replicate\n",
238
+ "\n",
239
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
240
+ "vqgan_params = replicate(vqgan.params)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "62ad01c3",
246
+ "metadata": {},
247
+ "source": [
248
+ "## Encoding"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "id": "20357f74",
254
+ "metadata": {},
255
+ "source": [
256
+ "Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "322a4619",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "from flax.training.common_utils import shard\n",
267
+ "from functools import partial\n",
268
+ "\n",
269
+ "\n",
270
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
271
+ "def p_encode(batch, params):\n",
272
+ " # Not sure if we should `replicate` params, does not seem to have any effect\n",
273
+ " _, indices = vqgan.encode(batch, params=params)\n",
274
+ " return indices"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "ff6c10d4",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "import pandas as pd\n",
285
+ "\n",
286
+ "\n",
287
+ "def encode_dataset(dataloader, output_dir, save_frequency):\n",
288
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
289
+ " all_captions = []\n",
290
+ " all_encoding = []\n",
291
+ " n_file = 1\n",
292
+ " for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
293
+ " images = images.numpy()\n",
294
+ " n = len(images) // 8 * 8\n",
295
+ " if n != len(images):\n",
296
+ " # get the max number of images we can (multiple of 8)\n",
297
+ " print(f\"Different sizes {n} vs {len(images)}\")\n",
298
+ " images = images[:n]\n",
299
+ " captions = captions[:n]\n",
300
+ " if not len(captions):\n",
301
+ " print(f\"No images/captions in batch...\")\n",
302
+ " continue\n",
303
+ " images = shard(images)\n",
304
+ " encoded = p_encode(images, vqgan_params)\n",
305
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
306
+ " all_captions.extend(captions)\n",
307
+ " all_encoding.extend(encoded.tolist())\n",
308
+ "\n",
309
+ " # save files\n",
310
+ " if (idx + 1) % save_frequency == 0:\n",
311
+ " print(f\"Saving file {n_file}\")\n",
312
+ " batch_df = pd.DataFrame.from_dict(\n",
313
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
314
+ " )\n",
315
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
316
+ " all_captions = []\n",
317
+ " all_encoding = []\n",
318
+ " n_file += 1\n",
319
+ "\n",
320
+ " if len(all_captions):\n",
321
+ " print(f\"Saving final file {n_file}\")\n",
322
+ " batch_df = pd.DataFrame.from_dict(\n",
323
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
324
+ " )\n",
325
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "7704863d",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "markdown",
340
+ "id": "8953dd84",
341
+ "metadata": {},
342
+ "source": [
343
+ "----"
344
+ ]
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "interpreter": {
349
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
350
+ },
351
+ "kernelspec": {
352
+ "display_name": "Python 3 (ipykernel)",
353
+ "language": "python",
354
+ "name": "python3"
355
+ },
356
+ "language_info": {
357
+ "codemirror_mode": {
358
+ "name": "ipython",
359
+ "version": 3
360
+ },
361
+ "file_extension": ".py",
362
+ "mimetype": "text/x-python",
363
+ "name": "python",
364
+ "nbconvert_exporter": "python",
365
+ "pygments_lexer": "ipython3",
366
+ "version": "3.9.7"
367
+ }
368
+ },
369
+ "nbformat": 4,
370
+ "nbformat_minor": 5
371
+ }
inference_pipeline.ipynb ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "118UKH5bWCGa"
17
+ },
18
+ "source": [
19
+ "# DALL·E mini - Inference pipeline\n",
20
+ "\n",
21
+ "*Generate images from a text prompt*\n",
22
+ "\n",
23
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
+ "\n",
25
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
+ "\n",
27
+ "Just want to play? Use directly [the app](https://www.craiyon.com/).\n",
28
+ "\n",
29
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {
35
+ "id": "dS8LbaonYm3a"
36
+ },
37
+ "source": [
38
+ "## 🛠️ Installation and set-up"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "uzjAM2GBYpZX"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "# Install required libraries\n",
50
+ "!pip install -q dalle-mini\n",
51
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "markdown",
56
+ "metadata": {
57
+ "id": "ozHzTkyv8cqU"
58
+ },
59
+ "source": [
60
+ "We load required models:\n",
61
+ "* DALL·E mini for text to encoded images\n",
62
+ "* VQGAN for decoding images\n",
63
+ "* CLIP for scoring predictions"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {
70
+ "id": "K6CxW2o42f-w"
71
+ },
72
+ "outputs": [],
73
+ "source": [
74
+ "# Model references\n",
75
+ "\n",
76
+ "# dalle-mega\n",
77
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
78
+ "DALLE_COMMIT_ID = None\n",
79
+ "\n",
80
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
81
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
82
+ "\n",
83
+ "# VQGAN model\n",
84
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
85
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {
92
+ "id": "Yv-aR3t4Oe5v"
93
+ },
94
+ "outputs": [],
95
+ "source": [
96
+ "import jax\n",
97
+ "import jax.numpy as jnp\n",
98
+ "\n",
99
+ "# check how many devices are available\n",
100
+ "jax.local_device_count()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {
107
+ "id": "92zYmvsQ38vL"
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "# Load models & tokenizer\n",
112
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
113
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
114
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
115
+ "\n",
116
+ "# Load dalle-mini\n",
117
+ "model, params = DalleBart.from_pretrained(\n",
118
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
119
+ ")\n",
120
+ "\n",
121
+ "# Load VQGAN\n",
122
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
123
+ " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
124
+ ")"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "markdown",
129
+ "metadata": {
130
+ "id": "o_vH2X1tDtzA"
131
+ },
132
+ "source": [
133
+ "Model parameters are replicated on each device for faster inference."
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {
140
+ "id": "wtvLoM48EeVw"
141
+ },
142
+ "outputs": [],
143
+ "source": [
144
+ "from flax.jax_utils import replicate\n",
145
+ "\n",
146
+ "params = replicate(params)\n",
147
+ "vqgan_params = replicate(vqgan_params)"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {
153
+ "id": "0A9AHQIgZ_qw"
154
+ },
155
+ "source": [
156
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {
163
+ "id": "sOtoOmYsSYPz"
164
+ },
165
+ "outputs": [],
166
+ "source": [
167
+ "from functools import partial\n",
168
+ "\n",
169
+ "# model inference\n",
170
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
171
+ "def p_generate(\n",
172
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
173
+ "):\n",
174
+ " return model.generate(\n",
175
+ " **tokenized_prompt,\n",
176
+ " prng_key=key,\n",
177
+ " params=params,\n",
178
+ " top_k=top_k,\n",
179
+ " top_p=top_p,\n",
180
+ " temperature=temperature,\n",
181
+ " condition_scale=condition_scale,\n",
182
+ " )\n",
183
+ "\n",
184
+ "\n",
185
+ "# decode image\n",
186
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
187
+ "def p_decode(indices, params):\n",
188
+ " return vqgan.decode_code(indices, params=params)"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "metadata": {
194
+ "id": "HmVN6IBwapBA"
195
+ },
196
+ "source": [
197
+ "Keys are passed to the model on each device to generate unique inference per device."
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "metadata": {
204
+ "id": "4CTXmlUkThhX"
205
+ },
206
+ "outputs": [],
207
+ "source": [
208
+ "import random\n",
209
+ "\n",
210
+ "# create a random key\n",
211
+ "seed = random.randint(0, 2**32 - 1)\n",
212
+ "key = jax.random.PRNGKey(seed)"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "markdown",
217
+ "metadata": {
218
+ "id": "BrnVyCo81pij"
219
+ },
220
+ "source": [
221
+ "## 🖍 Text Prompt"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "markdown",
226
+ "metadata": {
227
+ "id": "rsmj0Aj5OQox"
228
+ },
229
+ "source": [
230
+ "Our model requires processing prompts."
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {
237
+ "id": "YjjhUychOVxm"
238
+ },
239
+ "outputs": [],
240
+ "source": [
241
+ "from dalle_mini import DalleBartProcessor\n",
242
+ "\n",
243
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "metadata": {
249
+ "id": "BQ7fymSPyvF_"
250
+ },
251
+ "source": [
252
+ "Let's define some text prompts."
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {
259
+ "id": "x_0vI9ge1oKr"
260
+ },
261
+ "outputs": [],
262
+ "source": [
263
+ "prompts = [\n",
264
+ " \"sunset over a lake in the mountains\",\n",
265
+ " \"the Eiffel tower landing on the moon\",\n",
266
+ "]"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "markdown",
271
+ "metadata": {
272
+ "id": "XlZUG3SCLnGE"
273
+ },
274
+ "source": [
275
+ "Note: we could use the same prompt multiple times for faster inference."
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {
282
+ "id": "VKjEZGjtO49k"
283
+ },
284
+ "outputs": [],
285
+ "source": [
286
+ "tokenized_prompts = processor(prompts)"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "metadata": {
292
+ "id": "-CEJBnuJOe5z"
293
+ },
294
+ "source": [
295
+ "Finally we replicate the prompts onto each device."
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "metadata": {
302
+ "id": "lQePgju5Oe5z"
303
+ },
304
+ "outputs": [],
305
+ "source": [
306
+ "tokenized_prompt = replicate(tokenized_prompts)"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "markdown",
311
+ "metadata": {
312
+ "id": "phQ9bhjRkgAZ"
313
+ },
314
+ "source": [
315
+ "## 🎨 Generate images\n",
316
+ "\n",
317
+ "We generate images using dalle-mini model and decode them with the VQGAN."
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {
324
+ "id": "d0wVkXpKqnHA"
325
+ },
326
+ "outputs": [],
327
+ "source": [
328
+ "# number of predictions per prompt\n",
329
+ "n_predictions = 8\n",
330
+ "\n",
331
+ "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
332
+ "gen_top_k = None\n",
333
+ "gen_top_p = None\n",
334
+ "temperature = None\n",
335
+ "cond_scale = 10.0"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "metadata": {
342
+ "id": "SDjEx9JxR3v8"
343
+ },
344
+ "outputs": [],
345
+ "source": [
346
+ "from flax.training.common_utils import shard_prng_key\n",
347
+ "import numpy as np\n",
348
+ "from PIL import Image\n",
349
+ "from tqdm.notebook import trange\n",
350
+ "\n",
351
+ "print(f\"Prompts: {prompts}\\n\")\n",
352
+ "# generate images\n",
353
+ "images = []\n",
354
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
355
+ " # get a new key\n",
356
+ " key, subkey = jax.random.split(key)\n",
357
+ " # generate images\n",
358
+ " encoded_images = p_generate(\n",
359
+ " tokenized_prompt,\n",
360
+ " shard_prng_key(subkey),\n",
361
+ " params,\n",
362
+ " gen_top_k,\n",
363
+ " gen_top_p,\n",
364
+ " temperature,\n",
365
+ " cond_scale,\n",
366
+ " )\n",
367
+ " # remove BOS\n",
368
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
369
+ " # decode images\n",
370
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
371
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
372
+ " for decoded_img in decoded_images:\n",
373
+ " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
374
+ " images.append(img)\n",
375
+ " display(img)\n",
376
+ " print()"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "markdown",
381
+ "metadata": {
382
+ "id": "tw02wG9zGmyB"
383
+ },
384
+ "source": [
385
+ "## 🏅 Optional: Rank images by CLIP score\n",
386
+ "\n",
387
+ "We can rank images according to CLIP.\n",
388
+ "\n",
389
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "metadata": {
396
+ "id": "RGjlIW_f6GA0"
397
+ },
398
+ "outputs": [],
399
+ "source": [
400
+ "# CLIP model\n",
401
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
402
+ "CLIP_COMMIT_ID = None\n",
403
+ "\n",
404
+ "# Load CLIP\n",
405
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
406
+ " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
407
+ ")\n",
408
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
409
+ "clip_params = replicate(clip_params)\n",
410
+ "\n",
411
+ "# score images\n",
412
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
413
+ "def p_clip(inputs, params):\n",
414
+ " logits = clip(params=params, **inputs).logits_per_image\n",
415
+ " return logits"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": null,
421
+ "metadata": {
422
+ "id": "FoLXpjCmGpju"
423
+ },
424
+ "outputs": [],
425
+ "source": [
426
+ "from flax.training.common_utils import shard\n",
427
+ "\n",
428
+ "# get clip scores\n",
429
+ "clip_inputs = clip_processor(\n",
430
+ " text=prompts * jax.device_count(),\n",
431
+ " images=images,\n",
432
+ " return_tensors=\"np\",\n",
433
+ " padding=\"max_length\",\n",
434
+ " max_length=77,\n",
435
+ " truncation=True,\n",
436
+ ").data\n",
437
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
438
+ "\n",
439
+ "# organize scores per prompt\n",
440
+ "p = len(prompts)\n",
441
+ "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "markdown",
446
+ "metadata": {
447
+ "id": "4AAWRm70LgED"
448
+ },
449
+ "source": [
450
+ "Let's now display images ranked by CLIP score."
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": null,
456
+ "metadata": {
457
+ "id": "zsgxxubLLkIu"
458
+ },
459
+ "outputs": [],
460
+ "source": [
461
+ "for i, prompt in enumerate(prompts):\n",
462
+ " print(f\"Prompt: {prompt}\\n\")\n",
463
+ " for idx in logits[i].argsort()[::-1]:\n",
464
+ " display(images[idx * p + i])\n",
465
+ " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
466
+ " print()"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "markdown",
471
+ "metadata": {
472
+ "id": "oZT9i3jCjir0"
473
+ },
474
+ "source": [
475
+ "## 🪄 Optional: Save your Generated Images as W&B Tables\n",
476
+ "\n",
477
+ "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world."
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "metadata": {
484
+ "id": "-pSiv6Vwjkn0"
485
+ },
486
+ "outputs": [],
487
+ "source": [
488
+ "import wandb\n",
489
+ "\n",
490
+ "# Initialize a W&B run.\n",
491
+ "project = 'dalle-mini-tables-colab'\n",
492
+ "run = wandb.init(project=project)\n",
493
+ "\n",
494
+ "# Initialize an empty W&B Tables.\n",
495
+ "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
496
+ "gen_table = wandb.Table(columns=columns)\n",
497
+ "\n",
498
+ "# Add data to the table.\n",
499
+ "for i, prompt in enumerate(prompts):\n",
500
+ " # If CLIP scores exist, sort the Images\n",
501
+ " if logits is not None:\n",
502
+ " idxs = logits[i].argsort()[::-1]\n",
503
+ " tmp_imgs = images[i::2]\n",
504
+ " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
505
+ " else:\n",
506
+ " tmp_imgs = images[i::2]\n",
507
+ "\n",
508
+ " # Add the data to the table.\n",
509
+ " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
510
+ "\n",
511
+ "# Log the Table to W&B dashboard.\n",
512
+ "wandb.log({\"Generated Images\": gen_table})\n",
513
+ "\n",
514
+ "# Close the W&B run.\n",
515
+ "run.finish()"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "markdown",
520
+ "metadata": {
521
+ "id": "Ck2ZnHwVjnRd"
522
+ },
523
+ "source": [
524
+ "Click on the link above to check out your generated images."
525
+ ]
526
+ }
527
+ ],
528
+ "metadata": {
529
+ "accelerator": "GPU",
530
+ "colab": {
531
+ "collapsed_sections": [],
532
+ "include_colab_link": true,
533
+ "machine_shape": "hm",
534
+ "name": "DALL·E mini - Inference pipeline.ipynb",
535
+ "provenance": []
536
+ },
537
+ "kernelspec": {
538
+ "display_name": "Python 3 (ipykernel)",
539
+ "language": "python",
540
+ "name": "python3"
541
+ },
542
+ "language_info": {
543
+ "codemirror_mode": {
544
+ "name": "ipython",
545
+ "version": 3
546
+ },
547
+ "file_extension": ".py",
548
+ "mimetype": "text/x-python",
549
+ "name": "python",
550
+ "nbconvert_exporter": "python",
551
+ "pygments_lexer": "ipython3",
552
+ "version": "3.9.7"
553
+ }
554
+ },
555
+ "nbformat": 4,
556
+ "nbformat_minor": 0
557
+ }
logo.png ADDED
modeling.py ADDED
@@ -0,0 +1,1909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model. """
16
+
17
+ import math
18
+ from functools import partial
19
+ from typing import Any, Dict, Optional, Tuple
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from einops import rearrange
26
+ from flax.core.frozen_dict import unfreeze
27
+ from flax.linen import combine_masks, make_causal_mask
28
+ from flax.linen import partitioning as nn_partitioning
29
+ from flax.linen.linear import PrecisionLike
30
+ from flax.traverse_util import flatten_dict, unflatten_dict
31
+ from jax import custom_jvp, lax
32
+ from jax.random import PRNGKey
33
+ from transformers.generation_flax_utils import FlaxSampleOutput
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutput,
36
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
37
+ FlaxCausalLMOutputWithCrossAttentions,
38
+ FlaxSeq2SeqLMOutput,
39
+ )
40
+ from transformers.modeling_flax_utils import ACT2FN
41
+ from transformers.models.bart.modeling_flax_bart import (
42
+ FlaxBartAttention,
43
+ FlaxBartForConditionalGeneration,
44
+ FlaxBartForConditionalGenerationModule,
45
+ FlaxBartModule,
46
+ )
47
+ from transformers.utils import logging
48
+
49
+ from .configuration import DalleBartConfig
50
+ from .utils import PretrainedFromWandbMixin
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ remat = nn_partitioning.remat
55
+
56
+
57
+ def smelu(beta: Any = 1.0):
58
+ """
59
+ Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
60
+ https://arxiv.org/abs/2202.06499
61
+ """
62
+
63
+ @custom_jvp
64
+ @jax.jit
65
+ def _smelu(x: Any) -> Any:
66
+ x = jnp.where(x <= -beta, 0.0, x)
67
+ return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))
68
+
69
+ _smelu.defjvps(
70
+ lambda g, ans, x: lax.select(
71
+ x == -beta,
72
+ lax.full_like(g, 0),
73
+ lax.select(x == beta, lax.full_like(g, 1), g),
74
+ )
75
+ )
76
+ return _smelu
77
+
78
+
79
+ ACT2FN.update({"smelu": smelu()})
80
+
81
+ # deepnet initialization
82
+ def deepnet_init(gain=1):
83
+ init = jax.nn.initializers.glorot_normal()
84
+
85
+ def _init(*args, **kwargs):
86
+ return gain * init(*args, **kwargs)
87
+
88
+ return _init
89
+
90
+
91
+ # deepnet gain
92
+ deepnet_gain = {
93
+ "encoder": {
94
+ "alpha": lambda config: 0.81
95
+ * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
96
+ "beta": lambda config: 0.87
97
+ * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
98
+ },
99
+ "decoder": {
100
+ "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
101
+ "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
102
+ },
103
+ }
104
+
105
+
106
+ class RMSNorm(nn.Module):
107
+ """
108
+ From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
109
+
110
+ Adapted from flax.linen.LayerNorm
111
+ """
112
+
113
+ epsilon: float = 1e-6
114
+ dtype: Any = jnp.float32
115
+ param_dtype: Any = jnp.float32
116
+ use_scale: bool = True
117
+ scale_init: Any = jax.nn.initializers.ones
118
+
119
+ @nn.compact
120
+ def __call__(self, x):
121
+ reduction_axes = (-1,)
122
+ feature_axes = (-1,)
123
+
124
+ rms_sq = self._compute_rms_sq(x, reduction_axes)
125
+
126
+ return self._normalize(
127
+ self,
128
+ x,
129
+ rms_sq,
130
+ reduction_axes,
131
+ feature_axes,
132
+ self.dtype,
133
+ self.param_dtype,
134
+ self.epsilon,
135
+ self.use_scale,
136
+ self.scale_init,
137
+ )
138
+
139
+ def _compute_rms_sq(self, x, axes):
140
+ x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
141
+ rms_sq = jnp.mean(jax.lax.square(x), axes)
142
+ return rms_sq
143
+
144
+ def _normalize(
145
+ self,
146
+ mdl,
147
+ x,
148
+ rms_sq,
149
+ reduction_axes,
150
+ feature_axes,
151
+ dtype,
152
+ param_dtype,
153
+ epsilon,
154
+ use_scale,
155
+ scale_init,
156
+ ):
157
+ reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
158
+ feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
159
+ stats_shape = list(x.shape)
160
+ for axis in reduction_axes:
161
+ stats_shape[axis] = 1
162
+ rms_sq = rms_sq.reshape(stats_shape)
163
+ feature_shape = [1] * x.ndim
164
+ reduced_feature_shape = []
165
+ for ax in feature_axes:
166
+ feature_shape[ax] = x.shape[ax]
167
+ reduced_feature_shape.append(x.shape[ax])
168
+ mul = lax.rsqrt(rms_sq + epsilon)
169
+ if use_scale:
170
+ scale = mdl.param(
171
+ "scale", scale_init, reduced_feature_shape, param_dtype
172
+ ).reshape(feature_shape)
173
+ mul *= scale
174
+ y = mul * x
175
+ return jnp.asarray(y, dtype)
176
+
177
+
178
+ def norm(type, *args, **kwargs):
179
+ if type == "rmsnorm":
180
+ return RMSNorm(*args, **kwargs)
181
+ elif type == "layernorm":
182
+ return nn.LayerNorm(*args, **kwargs)
183
+ else:
184
+ raise ValueError(f"Unknown norm type {type}")
185
+
186
+
187
+ def dot_product_attention_weights(
188
+ query: Any,
189
+ key: Any,
190
+ bias: Optional[Any] = None,
191
+ mask: Optional[Any] = None,
192
+ embed_pos: Optional[Any] = None,
193
+ broadcast_dropout: bool = True,
194
+ dropout_rng: Optional[PRNGKey] = None,
195
+ dropout_rate: float = 0.0,
196
+ deterministic: bool = False,
197
+ dtype: Any = jnp.float32,
198
+ precision: PrecisionLike = None,
199
+ sinkhorn_iters: int = 1,
200
+ is_encoder: bool = False,
201
+ ):
202
+ """
203
+ Computes dot-product attention weights given query and key.
204
+ mask is included into the bias.
205
+
206
+ Adapted from flax.linen.attention.dot_product_attention_weights"
207
+ """
208
+ assert query.ndim == key.ndim, "q, k must have same rank."
209
+ assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
210
+ assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
211
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
212
+
213
+ # calculate attention matrix
214
+ depth = query.shape[-1]
215
+ query = query / jnp.sqrt(depth).astype(dtype)
216
+ # attn weight shape is (batch..., num_heads, q_length, kv_length)
217
+ attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
218
+
219
+ # apply attention bias: masking, dropout, proximity bias, etc.
220
+ if bias is not None:
221
+ attn_weights = attn_weights + bias
222
+
223
+ # add relative position
224
+ if embed_pos is not None:
225
+ attn_weights = attn_weights + embed_pos
226
+
227
+ # normalize the attention weights
228
+ if not is_encoder or sinkhorn_iters == 1:
229
+ # sinkhorn does not work for causal (leaks info of future tokens into past)
230
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
231
+ else:
232
+ # adapted from https://github.com/lucidrains/sinkhorn-transformer
233
+ for i in range(sinkhorn_iters):
234
+ # when causal, some attn_weights have been set to -inf through bias
235
+ if i % 2 == 0:
236
+ attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
237
+ else:
238
+ attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
239
+ if mask is not None:
240
+ attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
241
+ attn_weights = jnp.exp(attn_weights).astype(dtype)
242
+
243
+ # apply attention dropout
244
+ if not deterministic and dropout_rate > 0.0:
245
+ keep_prob = 1.0 - dropout_rate
246
+ if broadcast_dropout:
247
+ # dropout is broadcast across the batch + head dimensions
248
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
249
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
250
+ else:
251
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
252
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
253
+ keep_prob, dtype=dtype
254
+ )
255
+ attn_weights = attn_weights * multiplier
256
+
257
+ return attn_weights
258
+
259
+
260
+ class FlaxBartAttention(FlaxBartAttention):
261
+ """
262
+ Edits:
263
+ - causal mask is used only in decoder and considers image_length
264
+ - scale attention heads per NormFormer paper
265
+ """
266
+
267
+ is_encoder: bool = False
268
+ q_length: int = None
269
+ k_length: int = None
270
+
271
+ def setup(self) -> None:
272
+ self.head_dim = self.embed_dim // self.num_heads
273
+ if self.head_dim * self.num_heads != self.embed_dim:
274
+ raise ValueError(
275
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
276
+ f" and `num_heads`: {self.num_heads})."
277
+ )
278
+
279
+ dense = partial(
280
+ nn.Dense,
281
+ self.embed_dim,
282
+ use_bias=self.bias,
283
+ dtype=self.dtype,
284
+ )
285
+
286
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
287
+ self.config
288
+ )
289
+
290
+ self.q_proj = dense(
291
+ kernel_init=deepnet_init()
292
+ if self.config.use_deepnet_scaling
293
+ else jax.nn.initializers.normal(self.config.init_std)
294
+ )
295
+ self.k_proj = dense(
296
+ kernel_init=deepnet_init()
297
+ if self.config.use_deepnet_scaling
298
+ else jax.nn.initializers.normal(self.config.init_std)
299
+ )
300
+ self.v_proj = dense(
301
+ kernel_init=deepnet_init(gain)
302
+ if self.config.use_deepnet_scaling
303
+ else jax.nn.initializers.normal(self.config.init_std)
304
+ )
305
+ self.out_proj = dense(
306
+ kernel_init=deepnet_init(gain)
307
+ if self.config.use_deepnet_scaling
308
+ else jax.nn.initializers.normal(self.config.init_std)
309
+ )
310
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
311
+
312
+ if self.config.use_head_scale:
313
+ self.head_scale = self.param(
314
+ "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
315
+ )
316
+
317
+ if self.config.use_cosine_attention:
318
+ self.tau = self.param(
319
+ "tau",
320
+ jax.nn.initializers.constant(self.config.tau_init),
321
+ (1, self.num_heads, 1, 1),
322
+ )
323
+
324
+ if self.config.use_swin_position_embeddings:
325
+ self.rel_bias = nn.Embed(
326
+ self.q_length,
327
+ self.k_length * self.num_heads,
328
+ embedding_init=deepnet_init()
329
+ if self.config.use_deepnet_scaling
330
+ else jax.nn.initializers.normal(self.config.init_std),
331
+ )
332
+
333
+ if self.causal:
334
+ # used only in decoder
335
+ self.causal_mask = make_causal_mask(
336
+ jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
337
+ )
338
+
339
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Run self- or cross-attention over ``hidden_states``.

        Input shape: Batch x Time x Channel

        Args:
            hidden_states: query-side activations.
            key_value_states: when provided, keys/values are projected from
                these states and the layer acts as cross-attention.
            attention_mask: optional mask over key positions.
            init_cache: initialize the fast autoregressive decoding cache
                (causal/decoder use only).
            deterministic: disable dropout when True.

        Returns:
            ``(attn_output, attn_weights)``.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (batch, time, embed) -> (batch, time, heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # during incremental decoding, slice the causal mask at the
                # current cache position
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(
                causal_mask, (batch_size,) + causal_mask.shape[1:]
            )

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias:
            # 0 where attending is allowed, -inf where it is masked
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.config.use_cosine_attention:
            # normalize q and k (cosine similarity attention; epsilon guards
            # against division by zero)
            query_states = query_states / (
                jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
            )
            key_states = key_states / (
                jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
            )

        # relative position embeddings (Swin-style learned relative bias)
        if self.config.use_swin_position_embeddings:
            position_ids = jnp.arange(self.q_length)
            embed_pos = self.rel_bias(position_ids)
            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
        else:
            embed_pos = None

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            mask=attention_mask,
            embed_pos=embed_pos,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
            sinkhorn_iters=self.config.sinkhorn_iters,
            is_encoder=self.is_encoder,
        )
        if self.config.use_cosine_attention:
            # divide by tau (learned temperature, clamped at 0.01)
            # NOTE(review): this divides the output of
            # dot_product_attention_weights rather than the pre-softmax logits —
            # confirm against the custom dot_product_attention_weights helper.
            attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)

        # weighted sum over values: (..., heads, q, k) x (..., k, heads, dim)
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        if self.config.use_head_scale:
            # per Normformer: learned per-head output scaling
            attn_output = attn_output * self.head_scale
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
464
+
465
+
466
class GLU(nn.Module):
    """Gated Linear Unit feed-forward block.

    From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202

    Projects the input to ``ffn_dim`` twice (a gate passed through the
    configured activation, and a value), multiplies them elementwise, and
    projects back to ``embed_dim``. LayerNorm/dropout placement follows
    ``config.ln_positions``.
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:

        gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
            self.config
        )
        # every dense projection in this block shares the same initializer
        kernel_init = (
            deepnet_init(gain)
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        ln_positions = self.config.ln_positions

        if ln_positions in ("normformer", "cogview", "preln"):
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        gate = nn.Dense(
            self.ffn_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=kernel_init,
        )(x)
        gate = ACT2FN[self.config.activation_function](gate)
        value = nn.Dense(
            self.ffn_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=kernel_init,
        )(x)
        x = gate * value
        if ln_positions == "normformer":
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        x = nn.Dropout(rate=self.config.activation_dropout)(
            x, deterministic=deterministic
        )

        x = nn.Dense(
            self.embed_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=kernel_init,
        )(x)
        if ln_positions in ("swinv2", "cogview"):
            x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
        return x
530
+
531
+
532
class FFN(nn.Module):
    """Standard two-layer feed-forward block.

    Dense to ``ffn_dim``, activation, Dense back to ``embed_dim``, with
    LayerNorm/dropout placement controlled by ``config.ln_positions``.
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:

        gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
            self.config
        )
        # both dense projections share the same initializer
        kernel_init = (
            deepnet_init(gain)
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        ln_positions = self.config.ln_positions

        if ln_positions in ("normformer", "cogview", "preln"):
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        x = nn.Dense(
            self.ffn_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=kernel_init,
        )(x)
        x = ACT2FN[self.config.activation_function](x)
        if ln_positions == "normformer":
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        x = nn.Dropout(rate=self.config.activation_dropout)(
            x, deterministic=deterministic
        )
        x = nn.Dense(
            self.embed_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=kernel_init,
        )(x)
        if ln_positions in ("swinv2", "cogview"):
            x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
        return x
585
+
586
+
587
class FlaxBartEncoderLayer(nn.Module):
    """
    Single Transformer encoder layer: self-attention + feed-forward, with
    configurable LayerNorm placement and optional DeepNet residual scaling.

    Edits:
    - no bias
    - use custom FlaxBartAttention
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    add_norm: bool = False  # extra LayerNorm on the layer output
    use_scale: bool = True  # whether that extra norm has a learnable scale

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:

        # under nn.scan the carry arrives wrapped in a 1-tuple
        if self.config.use_scan:
            hidden_states = hidden_states[0]

        # DeepNet residual gain (alpha); 1 disables the scaling
        res_gain = (
            deepnet_gain["encoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states
        # pre-attention norm (pre-LN style variants)
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=True,
            q_length=self.config.max_text_length,
            k_length=self.config.max_text_length,
        )(hidden_states=hidden_states, attention_mask=attention_mask)

        # post-attention norm (Normformer / SwinV2 / CogView variants)
        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        # residual connection, scaled by the DeepNet gain
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # feed-forward sub-layer (GLU or plain FFN per config)
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if self.config.use_scan:
            # scan expects (carry, per-step output)
            outputs = (outputs, None)

        return outputs
689
+
690
+
691
class FlaxBartDecoderLayer(nn.Module):
    """
    Single Transformer decoder layer: causal self-attention, optional
    cross-attention over the encoder output, then feed-forward — each with
    configurable LayerNorm placement and optional DeepNet residual scaling.

    Edits:
    - no bias
    - use custom FlaxBartAttention
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    add_norm: bool = False  # extra LayerNorm on the layer output
    use_scale: bool = True  # whether that extra norm has a learnable scale

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:

        # under nn.scan the carry arrives wrapped in a 1-tuple
        if self.config.use_scan:
            hidden_states = hidden_states[0]

        # DeepNet residual gain (alpha); 1 disables the scaling
        res_gain = (
            deepnet_gain["decoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states

        # Self Attention (causal, over image tokens)
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=False,
            q_length=self.config.image_length,
            k_length=self.config.image_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        # residual connection, scaled by the DeepNet gain
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # Cross Attention (queries: image tokens, keys/values: text encoding)
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            if self.config.ln_positions in ["normformer", "cogview", "preln"]:
                hidden_states = norm(
                    self.config.ln_type,
                    dtype=self.dtype,
                    epsilon=1e-05,
                    use_scale=self.config.force_ln_scale,
                )(hidden_states)
            hidden_states, cross_attn_weights = FlaxBartAttention(
                config=self.config,
                embed_dim=embed_dim,
                num_heads=self.config.decoder_attention_heads,
                dropout=self.config.attention_dropout,
                bias=self.config.use_bias,
                dtype=self.dtype,
                is_encoder=False,
                q_length=self.config.image_length,
                k_length=self.config.max_text_length,
            )(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)
            hidden_states = nn.Dropout(rate=self.config.dropout)(
                hidden_states, deterministic=deterministic
            )
            hidden_states = residual * res_gain + hidden_states
            if self.config.ln_positions in ["postln"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)

        # Feed forward (GLU or plain FFN per config)
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights, cross_attn_weights)

        if self.config.use_scan:
            # scan expects (carry, per-step output)
            outputs = (outputs, None)

        return outputs
843
+
844
+
845
class FlaxBartEncoderLayerCollection(nn.Module):
    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Stack of encoder layers, either scanned (all layers share one set of
    stacked params) or unrolled as individually named layers.

    Edits:
    - use custom FlaxBartEncoderLayer
    - allow Gradient Checkpointing (nn.remat)
    """

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        n_layers = self.config.encoder_layers
        # optionally wrap the layer in remat for gradient checkpointing;
        # static_argnums marks the bool args (output_attentions, deterministic)
        layer = (
            remat(
                FlaxBartEncoderLayer,
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartEncoderLayer
        )

        if self.config.use_scan:
            # all blocks are the same so we use nn.scan
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            hidden_states = (hidden_states,)
            # we use a scale on all norms (even last layer) to allow scanning
            hidden_states, _ = nn.scan(
                layer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=n_layers,
            )(
                self.config,
                dtype=self.dtype,
                add_norm=self.config.ln_positions == "postln",
                name="FlaxBartEncoderLayers",
            )(
                hidden_states,
                attention_mask,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]
        else:
            for i in range(n_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # final layernorm on the output of the last layer
                # or every 6 layers for Swin v2
                add_norm = self.config.ln_positions == "postln" or (
                    self.config.ln_positions == "swinv2"
                    and ((i + 1) % 6 == 0)
                    and (i != n_layers - 1)
                )
                # we don't need to scale the norm for the last layer
                use_scale = i != n_layers - 1
                layer_outputs = layer(
                    self.config,
                    dtype=self.dtype,
                    add_norm=add_norm,
                    use_scale=use_scale,
                    name=f"FlaxBartEncoderLayer_{i}",
                )(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # add hidden states from the last layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
949
+
950
+
951
class FlaxBartDecoderLayerCollection(nn.Module):
    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Stack of decoder layers, either scanned (stacked params) or unrolled as
    individually named layers.

    Edits:
    - use custom FlaxBartDecoderLayer
    - allow Gradient Checkpointing (nn.remat)
    """

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and encoder_hidden_states is not None) else None
        )

        n_layers = self.config.decoder_layers
        # optionally wrap the layer in remat for gradient checkpointing;
        # static_argnums marks the bool args (init_cache, output_attentions,
        # deterministic)
        layer = (
            remat(
                FlaxBartDecoderLayer,
                static_argnums=(4, 5, 6),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartDecoderLayer
        )

        if self.config.use_scan:
            # all blocks are the same so we use nn.scan
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            hidden_states = (hidden_states,)
            # we use a scale on all norms (even last layer) to allow scanning
            hidden_states, _ = nn.scan(
                layer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                ),
                length=n_layers,
            )(
                self.config,
                dtype=self.dtype,
                add_norm=self.config.ln_positions == "postln",
                name="FlaxBartDecoderLayers",
            )(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]

        else:
            for i in range(n_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # final layernorm on the output of the last layer
                # or every 6 layers for Swin v2
                add_norm = self.config.ln_positions == "postln" or (
                    self.config.ln_positions == "swinv2"
                    and ((i + 1) % 6 == 0)
                    and (i != n_layers - 1)
                )
                # we don't need to scale the norm for the last layer
                use_scale = i != n_layers - 1
                layer_outputs = layer(
                    self.config,
                    dtype=self.dtype,
                    add_norm=add_norm,
                    use_scale=use_scale,
                    name=f"FlaxBartDecoderLayer_{i}",
                )(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    init_cache,
                    output_attentions,
                    deterministic,
                )

                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

            # add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
            all_cross_attentions,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
1082
+
1083
+
1084
class FlaxBartEncoder(nn.Module):
    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Text encoder: token embedding (+ optional absolute position embedding),
    embedding LayerNorm/dropout, encoder layer stack, optional final norm.

    Edits:
    - offset set to 0 (no padding token)
    - use max_text_length instead of max_position_embeddings
    - use custom FlaxBartEncoderLayerCollection
    - embed_tokens cannot be None (issue at compile time)
    """

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 0
        if self.config.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                self.config.max_text_length + self.offset,  # one position per text token (offset is 0)
                embed_dim,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(
            self.config.ln_type, dtype=self.dtype, epsilon=1e-05
        )

        # postln is already applied in every layer
        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
            self.final_ln = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )
        else:
            self.final_ln = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # flatten any leading batch dims to (batch, seq_len)
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        hidden_states = self.embed_tokens(input_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            embed_pos = self.embed_positions(position_ids + self.offset)
            hidden_states = hidden_states + embed_pos

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.final_ln is None:
            final_output = outputs[0]
        else:
            final_output = self.final_ln(outputs[0])

        if not return_dict:
            return (final_output,) + outputs[1:]

        return FlaxBaseModelOutput(
            last_hidden_state=final_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1172
+
1173
+
1174
class FlaxBartDecoder(nn.Module):
    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Image-token decoder: token embedding (+ optional absolute position
    embedding), embedding LayerNorm/dropout, decoder layer stack, optional
    final norm.

    Edits:
    - offset set to 0 (no padding token)
    - use image_length instead of max_position_embeddings
    - use custom FlaxBartDecoderLayerCollection
    - embed_tokens cannot be None (issue at compile time)
    """

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.embed_scale = (
            math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
        )

        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 0
        if self.config.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                self.config.image_length + self.offset,  # image length for BOS
                embed_dim,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )

        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(
            self.config.ln_type, dtype=self.dtype, epsilon=1e-05
        )

        # postln is already applied in every layer
        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
            self.final_ln = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )
        else:
            # Fix: mirror FlaxBartEncoder. __call__ reads self.final_ln
            # unconditionally, and a Flax module attribute never assigned in
            # setup() raises AttributeError — so the else branch is required
            # when use_final_ln_decoder is False or ln_positions == "postln".
            self.final_ln = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # flatten any leading batch dims to (batch, seq_len)
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        hidden_states = self.embed_tokens(input_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            embed_pos = self.embed_positions(position_ids + self.offset)
            hidden_states = hidden_states + embed_pos

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.final_ln is None:
            final_output = outputs[0]
        else:
            final_output = self.final_ln(outputs[0])

        if not return_dict:
            return (final_output,) + outputs[1:]

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=final_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
1270
+
1271
+
1272
class FlaxBartModule(FlaxBartModule):
    """
    Encoder/decoder pair with independent vocabularies.

    Edits
    - use custom FlaxBartEncoder & FlaxBartDecoder
    - use separate embeddings for Encoder & Decoder
    """

    def setup(self):
        # both embedding tables share the same initializer
        embedding_init = jax.nn.initializers.normal(self.config.init_std)
        text_embed_tokens = nn.Embed(
            self.config.encoder_vocab_size,
            self.config.d_model,
            embedding_init=embedding_init,
        )
        # image vocab size + 1 for BOS
        image_embed_tokens = nn.Embed(
            self.config.image_vocab_size + 1,
            self.config.d_model,
            embedding_init=embedding_init,
        )

        self.encoder = FlaxBartEncoder(
            self.config, dtype=self.dtype, embed_tokens=text_embed_tokens
        )
        self.decoder = FlaxBartDecoder(
            self.config, dtype=self.dtype, embed_tokens=image_embed_tokens
        )
1297
+
1298
+
1299
class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """
    Seq2seq model with an LM head projecting decoder states to image-token
    logits.

    Edits:
    - no bias
    - lm_head set to image_vocab_size + 1 (for BOS)
    - uses custom FlaxBartModule
    """

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.image_vocab_size
            + 1,  # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            # NOTE(review): the custom FlaxBartModule uses separate
            # encoder/decoder embeddings, so params["shared"] may not exist —
            # presumably tie_word_embeddings is always False here; confirm.
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_embedding.T}}, hidden_states
            )
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
1366
+
1367
+
1368
@flax.struct.dataclass
class SampleState:
    """Loop-carried state for the sampling while-loop in `DalleBart._sample`."""

    cur_len: jnp.ndarray  # current generated length (scalar)
    sequences: jnp.ndarray  # (batch, max_length) buffer of generated tokens
    running_token: jnp.ndarray  # token(s) fed to the decoder at the current step
    is_sent_finished: jnp.ndarray  # (batch,) bool flags, set once EOS is sampled
    prng_key: jnp.ndarray  # PRNG key, split at every sampling step
    model_kwargs: Dict[str, jnp.ndarray]  # decoder cache/kwargs for the conditioned pass
    model_kwargs_uncond: Dict[str, jnp.ndarray]  # cache/kwargs for the unconditional pass (super conditioning)
1377
+
1378
+
1379
class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
    """
    Edits:
    - renamed from FlaxBartForConditionalGeneration
    - uses custom FlaxBartForConditionalGenerationModule
    - no bias in decode method
    - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
    related to position embedding during model.generate()
    - custom generate method to allow super conditions
    - num_params property
    - unscan function
    """

    module_class = FlaxBartForConditionalGenerationModule
    config_class = DalleBartConfig

    def num_params(self, params=None):
        """Return the total number of scalar parameters in `params` (defaults to self.params)."""
        if params is None:
            params = self.params
        num_params = jax.tree_map(
            lambda param: param.size, flatten_dict(unfreeze(params))
        ).values()
        return sum(list(num_params))

    def unscan(self, params):
        """Split scanned (stacked) layer weights into per-layer entries.

        When `config.use_scan` is set, all layers are stacked along a leading
        axis under a key containing "layers". This splits that axis into
        separate entries named `<name-without-trailing-s>_{i}` (e.g.
        "FlaxBartEncoderLayers" -> "FlaxBartEncoderLayer_0", ...) and turns
        off `use_scan` on the config. Returns `params` unchanged otherwise.
        """
        if self.config.use_scan:
            self.config.use_scan = False
            params = flatten_dict(params)
            scanned_keys = [k for k in params.keys() if "layers" in k]
            for k in scanned_keys:
                v = params[k]
                name_idx = k.index("layers") + 1
                for i in range(len(v)):
                    # rename the component after "layers": drop its trailing
                    # "s" and append the layer index
                    new_k = (
                        *k[:name_idx],
                        f"{k[name_idx][:-1]}_{i}",
                        *k[name_idx + 1 :],
                    )
                    params[new_k] = v[i]
                del params[k]
            params = unflatten_dict(params)
        return params

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the decoder (plus lm_head) over precomputed encoder outputs.

        Used as the per-step model inside generation loops; supports an
        incremental attention cache via `past_key_values`. Note: no bias term
        is applied to the logits (see class docstring).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxBartAttention module
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        ):
            # Forward through the decoder submodule only, then project to logits.
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                # Tied output projection: reuse the shared embedding matrix.
                shared_embedding = module.model.variables["params"]["shared"][
                    "embedding"
                ]
                lm_logits = module.lm_head.apply(
                    {"params": {"kernel": shared_embedding.T}}, hidden_states
                )
            else:
                lm_logits = module.lm_head(hidden_states)

            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # When the cache is mutable, apply() additionally returns the mutated state.
        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Initialize the decoding cache and static masks for generation.

        Uses `max_length - 1` rather than `max_length` to keep position
        embeddings in range during `model.generate()` (see class docstring).
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def generate(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jnp.ndarray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        condition_scale: Optional[float] = 1.0,
        input_ids_uncond: Optional[jnp.ndarray] = None,
        attention_mask_uncond: Optional[jnp.ndarray] = None,
        **model_kwargs,
    ):
        """Edit: Allow super conditioning.

        When `condition_scale != 1.0`, a second (unconditional) encoder pass
        over `input_ids_uncond` is prepared and `_sample` blends conditioned
        and unconditional logits. Super conditioning requires `do_sample=True`
        and `num_beams == 1`.
        """

        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
        bos_token_id = (
            bos_token_id if bos_token_id is not None else self.config.bos_token_id
        )
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.eos_token_id
        )
        # NOTE(review): truthiness (not `is not None`) — an explicit id of 0
        # would fall back to the config value; presumably copied from upstream.
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id
            else self.config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError(
                "`decoder_start_token_id` has to be defined for encoder-decoder generation."
            )

        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                model_kwargs_input = dict(model_kwargs)
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                    input_ids,
                    params,
                    {"attention_mask": attention_mask, **model_kwargs_input},
                )
                if condition_scale != 1.0:
                    assert (
                        input_ids_uncond is not None
                    ), "`input_ids_uncond` has to be defined for super conditioning."
                    assert (
                        do_sample is True
                    ), "`do_sample` has to be True for super conditioning."
                    assert (
                        num_beams == 1
                    ), "`num_beams` has to be 1 for super conditioning."
                    # Second encoder pass for the unconditional prompt.
                    model_kwargs_uncond = (
                        self._prepare_encoder_decoder_kwargs_for_generation(
                            input_ids_uncond,
                            params,
                            {
                                "attention_mask": attention_mask_uncond,
                                **model_kwargs_input,
                            },
                        )
                    )
                else:
                    model_kwargs_uncond = None
            # prepare decoder_input_ids for generation
            input_ids = (
                jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
            )

        # Dispatch to greedy search / sampling / beam search.
        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, temperature=temperature
            )
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
                condition_scale=condition_scale,
                model_kwargs_uncond=model_kwargs_uncond,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"][
                    "last_hidden_state"
                ] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"],
                    num_beams=num_beams,
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )

            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("`Beam sampling is currently not implemented.")

    def _sample(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor=None,
        logits_warper=None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
        condition_scale: float = 1.0,
        model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """Token-by-token sampling loop, with optional super conditioning.

        When `condition_scale != 1.0`, each step also runs the decoder with
        the unconditional encoder outputs and blends the two logit sets.
        """
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.eos_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            input_ids, max_length, **model_kwargs
        )
        if condition_scale != 1.0:
            # separate cache for the unconditional branch
            model_kwargs_uncond = self.prepare_inputs_for_generation(
                input_ids, max_length, **model_kwargs_uncond
            )

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
            model_kwargs_uncond=model_kwargs_uncond,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(
                has_reached_max_length, all_sequence_finished
            )
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(
                state.running_token, params=params, **state.model_kwargs
            )

            logits = model_outputs.logits[:, -1]

            # perform super conditioning
            # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
            if condition_scale != 1.0:
                model_outputs_uncond = model(
                    state.running_token, params=params, **state.model_kwargs_uncond
                )
                logits_uncond = model_outputs_uncond.logits[:, -1]
                logits = logits_uncond + condition_scale * (logits - logits_uncond)
            else:
                model_outputs_uncond = None

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_k, top_k, temperature
            # NOTE(review): logits is passed twice to fit the warper's
            # (input_ids, scores, cur_len) signature; presumably the first
            # argument is unused by the warpers — confirm against transformers.
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            # Once a sequence has emitted EOS, keep emitting pad tokens.
            next_is_sent_finished = state.is_sent_finished | (
                next_token == eos_token_id
            )
            next_token = (
                next_token * ~next_is_sent_finished
                + pad_token_id * next_is_sent_finished
            )
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(
                state.sequences, next_token, (0, state.cur_len)
            )
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs
            )
            next_model_kwargs_uncond = (
                self.update_inputs_for_generation(
                    model_outputs_uncond, state.model_kwargs_uncond
                )
                if condition_scale != 1.0
                else None
            )

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                model_kwargs_uncond=next_model_kwargs_uncond,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(
                sample_search_cond_fn, sample_search_body_fn, state
            )
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
partitions.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from flax.core.frozen_dict import freeze
4
+ from flax.traverse_util import flatten_dict, unflatten_dict
5
+ from jax.experimental import PartitionSpec as P
6
+
7
+ # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
8
+ # Sentinels
9
+ _unmatched = object()
10
+
11
+ # For specifying empty leaf dict `{}`
12
+ empty_dict = object()
13
+
14
+
15
+ def _match(qs, ks):
16
+ """Return True if regexes in qs match any window of strings in tuple ks."""
17
+ # compile regexes and force complete match
18
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
19
+ for i in range(len(ks) - len(qs) + 1):
20
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
21
+ if matches and all(matches):
22
+ return True
23
+ return False
24
+
25
+
26
def _replacement_rules(rules):
    """Return a function mapping (key, default) to the first matching rule's spec."""

    def replace(key, val):
        # Rules are ordered: the first pattern matching the key wins
        # (the spec may legitimately be None). Unmatched keys keep `val`.
        candidates = (spec for pattern, spec in rules if _match(pattern, key))
        return next(candidates, val)

    return replace
34
+
35
+
36
def _get_partition_rules():
    """Ordered (pattern, PartitionSpec) rules for sharding model parameters.

    Patterns are matched against flattened parameter keys by
    `_replacement_rules`/`_match`; "mp" is presumably the model-parallel mesh
    axis — confirm against the mesh setup. A spec of None means replicate.
    """
    return [
        # embeddings
        (("embed_positions", "embedding"), P("mp", None)),
        (("embed_tokens", "embedding"), P("mp", None)),
        (("rel_bias", "embedding"), P(None, "mp")),
        # attention
        (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
        (("out_proj", "kernel"), P("mp", None)),
        # FFN
        (("Dense_0", "kernel"), P(None, "mp")),
        (("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
        (("GLU.*", "Dense_2", "kernel"), P("mp", None)),
        (("FFN.*", "Dense_1", "kernel"), P("mp", None)),
        # layer norms
        (("(bias|scale)",), None),
        # output projection
        (("lm_head", "kernel"), P(None, "mp")),
        # head scale and tau
        (("(head_scale|tau)",), None),
    ]
56
+
57
+
58
def set_partitions(in_dict, use_scan):
    """Build a frozen pytree of partition specs for every parameter in `in_dict`.

    Each flattened parameter key gets the spec of the first matching rule from
    `_get_partition_rules`. When `use_scan` is set, specs of stacked
    encoder/decoder layer weights get an extra leading `None` for the scan axis.

    Raises AssertionError if any parameter matched no rule (unmatched keys are
    printed first to ease debugging).
    """
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    for k, v in result.items():
        # `_unmatched` is a sentinel object: use identity, not equality.
        if v is _unmatched:
            print(f"Unmatched -> {k}")
    if use_scan:
        # add None dimension to layers (the stacked/scanned leading axis)
        result = {
            k: (P(*(None,) + v) if v is not None else None)
            if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
            else v
            for k, v in result.items()
        }
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
processor.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart processor """
2
+
3
+ from typing import List
4
+
5
+ import jax.numpy as jnp
6
+
7
+ from .configuration import DalleBartConfig
8
+ from .text import TextNormalizer
9
+ from .tokenizer import DalleBartTokenizer
10
+ from .utils import PretrainedFromWandbMixin
11
+
12
+
13
class DalleBartProcessorBase:
    """Turns text prompts into the token tensors expected by DalleBart.

    Also pre-computes the tokenized empty prompt, used as the unconditional
    input for super conditioning.
    """

    def __init__(
        self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int
    ):
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        if normalize_text:
            self.text_processor = TextNormalizer()
        # Tokenize the empty prompt once; reused for every batch.
        uncond_tokens = self._tokenize("")
        self.input_ids_uncond = uncond_tokens["input_ids"]
        self.attention_mask_uncond = uncond_tokens["attention_mask"]

    def _tokenize(self, text):
        """Tokenize with the processor's fixed padding/truncation settings."""
        return self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data

    def __call__(self, text: List[str] = None):
        # A bare string would be treated as a sequence of characters downstream.
        assert not isinstance(text, str), "text must be a list of strings"

        if self.normalize_text:
            text = [self.text_processor(entry) for entry in text]
        res = self._tokenize(text)
        # Unconditional tokens, consumed only when super conditioning is on.
        count = len(text)
        res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, count, axis=0)
        res["attention_mask_uncond"] = jnp.repeat(
            self.attention_mask_uncond, count, axis=0
        )
        return res

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a processor from a pretrained tokenizer and model config."""
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)
57
+
58
+
59
class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    # Same as DalleBartProcessorBase, with wandb-artifact loading supplied by
    # PretrainedFromWandbMixin (overrides from_pretrained resolution).
    pass
pypi_release.yml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This workflow uses actions that are not certified by GitHub.
2
+ # They are provided by a third-party and are governed by
3
+ # separate terms of service, privacy policy, and support
4
+ # documentation.
5
+
6
+ name: Upload Python Package
7
+
8
+ on:
9
+ release:
10
+ types: [published]
11
+
12
+ jobs:
13
+ deploy:
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - uses: actions/checkout@v3
17
+ - name: Set up Python
18
+ uses: actions/setup-python@v3
19
+ with:
20
+ python-version: "3.x"
21
+ - name: Install dependencies
22
+ run: |
23
+ python -m pip install --upgrade pip
24
+ pip install build
25
+ - name: Build package
26
+ run: python -m build
27
+ - name: Publish package
28
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
29
+ with:
30
+ user: __token__
31
+ password: ${{ secrets.PYPI_API_TOKEN }}
quantization_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Helper routines for quantization."""
17
+
18
+ from typing import Any
19
+
20
+ import chex
21
+ import jax.numpy as jnp
22
+ from flax import struct
23
+
24
+
25
+ # pylint:disable=no-value-for-parameter
26
+ @struct.dataclass
27
+ class QuantizedValue:
28
+ """State associated with quantized value."""
29
+
30
+ quantized: chex.Array
31
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
32
+ bucket_size: chex.Array
33
+ quantized_dtype: jnp.dtype = struct.field(
34
+ pytree_node=False
35
+ ) # Dtype for the quantized value.
36
+ extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
37
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
38
+
39
+ @classmethod
40
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
41
+ if isinstance(fvalue, list) and not fvalue:
42
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
43
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
44
+ fvalue, quantized_dtype, extract_diagonal
45
+ )
46
+ return QuantizedValue(
47
+ quantized,
48
+ diagonal_fvalue,
49
+ bucket_size,
50
+ quantized_dtype,
51
+ extract_diagonal,
52
+ list(quantized.shape),
53
+ )
54
+
55
+ # Quantization is from Lingvo JAX optimizers.
56
+ # We extend it for int16 quantization of PSD matrices.
57
+ @classmethod
58
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
59
+ """Returns quantized value and the bucket."""
60
+ if quantized_dtype == jnp.float32:
61
+ return fvalue, [], []
62
+ elif quantized_dtype == jnp.bfloat16:
63
+ return fvalue.astype(jnp.bfloat16), [], []
64
+
65
+ float_dtype = fvalue.dtype
66
+ if quantized_dtype == jnp.int8:
67
+ # value -128 is not used.
68
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
69
+ elif quantized_dtype == jnp.int16:
70
+ # value -32768 is not used.
71
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
72
+ else:
73
+ raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
74
+ # max value is mapped to num_buckets
75
+
76
+ if extract_diagonal and fvalue.ndim != 2:
77
+ raise ValueError(
78
+ f"Input array {fvalue} must be 2D to work with extract_diagonal."
79
+ )
80
+
81
+ diagonal_fvalue = []
82
+ if extract_diagonal:
83
+ diagonal_fvalue = jnp.diag(fvalue)
84
+ # Remove the diagonal entries.
85
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
86
+
87
+ # TODO(rohananil): Extend this by making use of information about the blocks
88
+ # SM3 style which will be useful for diagonal statistics
89
+ # We first decide the scale.
90
+ if fvalue.ndim < 1:
91
+ raise ValueError(
92
+ f"Input array {fvalue} must have a strictly positive number of "
93
+ "dimensions."
94
+ )
95
+
96
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
97
+ bucket_size = max_abs / num_buckets
98
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
99
+ # To avoid divide by 0.0
100
+ bs_nonzero = jnp.where(
101
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
102
+ )
103
+ ratio = fvalue / bs_nonzero
104
+ # We use rounding to remove bias.
105
+ quantized = jnp.round(ratio)
106
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
107
+
108
+ def to_float(self):
109
+ """Returns the float value."""
110
+ if isinstance(self.quantized, list) and not self.quantized:
111
+ return self.quantized
112
+
113
+ if self.quantized_dtype == jnp.float32:
114
+ return self.quantized
115
+
116
+ if self.quantized_dtype == jnp.bfloat16:
117
+ return self.quantized.astype(jnp.float32)
118
+
119
+ float_dtype = self.bucket_size.dtype
120
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
121
+ val = self.quantized.astype(float_dtype) * bucket_size
122
+ if self.extract_diagonal:
123
+ val += jnp.diag(self.diagonal)
124
+ return val
run_infer_notebook.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
#!/bin/bash
# Start a Jupyter notebook server reachable from outside the container:
# bind all interfaces, skip opening a browser, and allow running as root
# (required inside the Docker image, which runs as root by default).
jupyter notebook --ip 0.0.0.0 --no-browser --allow-root
sm3.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of SM3 from:
17
+ #
18
+ # Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
20
+ #
21
+ # Author: Rohan Anil (rohananil at google dot com)
22
+ #
23
+
24
+ """SM3 Implementation."""
25
+
26
+ import functools
27
+ from typing import Any, NamedTuple
28
+
29
+ import chex
30
+ import jax
31
+ import jax.numpy as jnp
32
+ import optax
33
+
34
+ from .quantization_utils import QuantizedValue
35
+
36
+
37
class SM3State(NamedTuple):
    """Top-level SM3 optimizer state: a step counter plus per-parameter stats."""

    count: chex.Array  # number of optimizer steps taken (int32 scalar, see init_fn)
    stats: Any  # pytree of ParameterStats mirroring the params tree
40
+
41
+
42
# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained.

    One instance per parameter leaf, stored inside `SM3State.stats`.
    """

    diagonal_statistics: chex.Array  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner (int8-quantized)
48
+
49
+
50
def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
    Yoram Singer

    https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: the step size used to update the parameters. May be a
        float or a schedule (a callable mapping the step count to a step size).
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: epsilon for sm3
      normalize_grads: Whether to normalize grads. Author finds it useful when
        grads are high variance.

    Returns:
      a GradientTransformation.
    """
    # NOTE: `jax.tree_multimap` was deprecated and then removed from JAX;
    # `jax.tree_util.tree_map` accepts multiple trees (with tree-prefix
    # semantics, so ParameterStats subtrees are passed whole to the mapped
    # function) and is the supported replacement used throughout below.

    def _quantize_momentum(momentum_statistics):
        # Momentum is stored int8-quantized to reduce optimizer memory.
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)

    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            # One accumulator vector per tensor dimension (the SM3 "cover").
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32), stats=jax.tree_util.tree_map(_init, params)
        )

    def _get_expanded_shape(shape, i):
        rank = len(shape)
        # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
        # For eg: i = 1 returns [1, N, 1].
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)

    def _moving_averages(grad, accumulators):
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            # SM3 upper-bounds the second moment by the min over the covers.
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2

    def _moving_averages_momentum(grad, momentum):
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad

    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
        # Project dense statistics back to one vector per dimension (via max).
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics

    def update_fn(updates, state, params=None):
        del params
        stats = state.stats
        if normalize_grads:
            updates = jax.tree_util.tree_map(
                lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates
            )
        # Reshape all vectors into N-d tensors to compute min over them.
        # [n], [m] -> [n, 1], [1, m]
        expanded_diagonal_statistics = jax.tree_util.tree_map(
            lambda grad, state: [  # pylint:disable=g-long-lambda
                jnp.reshape(
                    state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Compute new diagonal statistics
        new_diagonal_statistics = jax.tree_util.tree_map(
            _moving_averages, updates, expanded_diagonal_statistics
        )

        # Compute preconditioners (1/sqrt(s)) where s is the statistics.
        new_preconditioners = jax.tree_util.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_util.tree_map(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Compute updated momentum (also handle quantization)
        updated_momentum = jax.tree_util.tree_map(
            lambda preconditioned_grad, state: _moving_averages_momentum(  # pylint:disable=g-long-lambda
                preconditioned_grad, state.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Update diagonal statistics.
        updated_diagonal_statistics = jax.tree_util.tree_map(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        # Update momentum.
        new_sm3_stats = jax.tree_util.tree_map(
            lambda momentum, diagonal_stats: ParameterStats(  # pylint:disable=g-long-lambda
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        # Resolve the (possibly scheduled) learning rate for this step.
        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)

        new_updates = jax.tree_util.tree_map(lambda pg: -lr * pg, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
style.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Lint
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ lint:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ - uses: psf/black@stable
15
+ - uses: actions/setup-python@v2
16
+ with:
17
+ python-version: 3.9
18
+ - name: Install requirements
19
+ run: pip install ".[dev]"
20
+ - uses: jamescurtin/isort-action@master
sweep.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ program: train.py
2
+ project: dalle-mini
3
+ method: random
4
+ metric:
5
+ name: eval/loss
6
+ goal: minimize
7
+ parameters:
8
+ optim:
9
+ value: distributed_shampoo
10
+ learning_rate:
11
+ distribution: log_uniform
12
+ # from exp(min) to exp(max)
13
+ min: -9.2
14
+ max: -6.9
15
+ tokenizer_name:
16
+ value: boris/dalle-mini-tokenizer
17
+ config_name:
18
+ value: ./config/mini
19
+ dtype:
20
+ value: bfloat16
21
+ dataset_repo_or_path:
22
+ value: ./data
23
+ per_device_train_batch_size:
24
+ value: 64
25
+ per_device_eval_batch_size:
26
+ value: 64
27
+ gradient_accumulation_steps:
28
+ value: 1
29
+ warmup_steps:
30
+ value: 1000
31
+ num_train_epochs:
32
+ value: 1
33
+ max_train_samples:
34
+ value: 1000000
35
+ logging_steps:
36
+ value: 40
37
+ eval_steps:
38
+ value: 200
39
+
40
+ command:
41
+ - python3
42
+ - ${program}
43
+ - "--streaming"
44
+ - "--output_dir"
45
+ - "./output"
46
+ - "--overwrite_output_dir"
47
+ - "--do_train"
48
+ - "--do_eval"
49
+ - ${args}
symmetric_matrices.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
17
+
18
+ import functools
19
+ from typing import Any, List, Optional, Sequence, Union
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+ from flax import struct
25
+ from jax import lax
26
+
27
+
28
@struct.dataclass
class SlicedSymmetricMatrix:
    """A symmetric matrix stored as its lower-triangular block rows.

    For example, the symmetric matrix M = [[a, b^T], [b, c]] is stored as the
    two block rows `a` and `[b, c]`.

    Entries of `block_rows` may be batched (rank > 2); the trailing two
    dimensions are always the rows and columns of each slice.
    """

    block_rows: List[jnp.ndarray]
40
+
41
+
42
def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
    """Returns the contraction mat1 * mat2^T (matrices may be batched).

    The rows and columns are the last two dimensions for each matrix.

    Args:
      mat1: First matrix.
      mat2: Second matrix.
      axes: The axes over which to apply the product.
      precision: JAX precision to use for the multiplication.
    """
    # tensordot over the shared axes is exactly mat1 @ mat2^T.
    return jnp.tensordot(mat1, mat2, axes=axes, precision=precision)
59
+
60
+
61
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the blocked slices representing a symmetric contraction.

    The output is the contraction of `mat` with itself over `axes`, stored as
    lower-triangular block rows.

    Args:
      mat: The matrix for which we will compute a contraction with itself.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    rank = len(mat.shape)

    def _as_positive(ax):
        assert -rank <= ax < rank
        return ax if ax >= 0 else ax + rank

    positive_axes = [_as_positive(ax) for ax in axes]
    assert len(positive_axes) == len(axes)
    remaining_axes = set(range(rank)) - set(positive_axes)
    # Exactly one axis must be left over: the "row" axis being blocked.
    assert len(remaining_axes) == 1
    remaining_ax = remaining_axes.pop()

    num_rows = mat.shape[remaining_ax]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    def _row_slice(start, size):
        # Slice `size` rows starting at `start` along the remaining axis.
        starts = [0] * rank
        starts[remaining_ax] = start
        sizes = list(mat.shape)
        sizes[remaining_ax] = size
        return lax.dynamic_slice(mat, start_indices=starts, slice_sizes=sizes)

    # Block row i = (rows of block i) contracted with (rows of blocks 0..i).
    block_rows = [
        product_with_transpose(
            _row_slice(i * block_size, block_size),
            _row_slice(0, (i + 1) * block_size),
            axes=(axes, axes),
            precision=precision,
        )
        for i in range(num_rows // block_size)
    ]

    return SlicedSymmetricMatrix(block_rows=block_rows)


@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the concatenated slices representing mat*mat^T.

    Args:
      mat: The matrix for which we will compute mat*mat^T. It does not need to be
        square, and may be batched.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    sliced = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes, precision=precision
    )
    return jnp.concatenate(sliced.block_rows, axis=-1)
153
+
154
+
155
@jax.jit
def materialize_matrix(symmetric_matrix):
    """Returns the full square matrix for a block-row representation.

    Args:
      symmetric_matrix: the matrix represented by lower-triangular block slices.
    """
    rows = symmetric_matrix.block_rows
    size = rows[0].shape[-2]
    n = len(rows)

    # Lower-triangular (and diagonal) blocks of each row.
    lower = [
        [row[Ellipsis, i * size : (i + 1) * size] for i in range(k + 1)]
        for k, row in enumerate(rows)
    ]

    # Upper-triangular blocks are transposes of the mirrored lower blocks.
    upper = [[] for _ in range(n - 1)]
    for k, row in enumerate(rows[1:]):
        for i in range(k + 1):
            upper[i].append(
                jnp.swapaxes(
                    a=row[Ellipsis, i * size : (i + 1) * size], axis1=-1, axis2=-2
                )
            )

    return jnp.block(
        [low + up for low, up in zip(lower[:-1], upper)] + [lower[-1]]
    )


@functools.partial(jax.jit, static_argnames=("num_blocks"))
def materialize_matrix_from_concat(
    block_rows_concat,
    num_blocks=None,
):
    """Returns a materialized symmetric matrix from concatenated slices.

    Args:
      block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric matrix.
        If not specified, it is inferred from the shape of block_rows_concat.
    """
    if num_blocks is None:
        num_blocks = find_num_blocks(block_rows_concat)

    size = block_rows_concat.shape[-2]

    # Row k occupies blocks [k*(k+1)/2, (k+1)*(k+2)/2). The upper bound below
    # keeps one extra trailing block ("+ 1"); materialize_matrix only reads the
    # first k+1 blocks of row k, so the surplus columns are ignored.
    # NOTE(review): the "+ 1" looks unnecessary but appears harmless — confirm.
    block_rows = [
        block_rows_concat[
            Ellipsis,
            (k * (k + 1)) // 2 * size : (((k + 1) * (k + 2)) // 2 + 1) * size,
        ]
        for k in range(num_blocks)
    ]

    return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
222
+
223
+
224
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
):
    """Implements the blocked equivalent of SYRK.

    The symmetric matrix (represented using lower-triangular block rows) is
    replaced by alpha * mat * mat^T + beta * symmetric_matrix, computed
    block-row by block-row.

    Args:
      symmetric_matrix: The symmetric matrix to update.
      mat: The matrix to use for the update = mat * mat^T. The number of rows
        should match that of symmetric_matrix.
      alpha: The weight for the update.
      beta: The weight for the original symmetric matrix.
      axes: Axes to use for the contraction of the update.

    Returns:
      The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
    """
    size = symmetric_matrix.block_rows[0].shape[-2]
    contribution = sliced_transposed_product(mat=mat, block_size=size, axes=axes)
    merged_rows = [
        alpha * fresh + beta * existing
        for fresh, existing in zip(
            contribution.block_rows, symmetric_matrix.block_rows
        )
    ]
    return SlicedSymmetricMatrix(block_rows=merged_rows)
256
+
257
+
258
def num_blocks_from_total_blocks(total_blocks):
    """Returns the number of block rows from the total number of blocks.

    This is the inverse of the function x -> x*(x+1)/2.

    For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
    total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.

    Args:
      total_blocks: The total blocks used to represent the matrix.

    Raises:
      ValueError: If total_blocks is not a triangular number x*(x+1)/2.
    """
    # Invert x*(x+1)/2 == total_blocks with the quadratic formula, then verify
    # the rounded answer actually reproduces total_blocks.
    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
        raise ValueError(
            f"total_blocks={total_blocks} does not correspond to "
            "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
        )
    return num_blocks


def find_num_blocks(block_rows_concat):
    """Returns the number of (row) blocks representing the concatenated matrix.

    For example, an input with dimensions [256, 2560] represents 10 square blocks,
    which matches 4 lower-triangular block rows (1+2+3+4). So this function will
    return 4.

    Use ordinary numpy functions here so that the returned value is static.

    Args:
      block_rows_concat: The concatenated block array.

    Raises:
      ValueError: When the dimensions of the matrix do not correspond to a lower
        triangular block representation.
    """
    # Number of square blocks packed along the last axis.
    total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
    # Determine the number of block rows by inverting y = x*(x+1)/2.
    return num_blocks_from_total_blocks(total_blocks)
300
+
301
+
302
@functools.partial(jax.jit, static_argnames=("block_size"))
def slice_symmetric_matrix(
    mat,
    block_size,
):
    """Returns the lower-triangular row blocks of a square symmetric matrix.

    Args:
      mat: A symmetric matrix.
      block_size: The size of the row slices.

    Raises:
      ValueError: If mat is not square or block_size does not divide its rows.
    """
    n_rows = mat.shape[-2]
    n_cols = mat.shape[-1]
    if n_rows != n_cols:
        raise ValueError("mat is not square.")
    if n_rows % block_size != 0:
        raise ValueError(
            "block size does not evenly divide rows. "
            f"num_rows={n_rows}, block_size={block_size}"
        )
    # Row slice i keeps only columns up to (and including) its diagonal block.
    return SlicedSymmetricMatrix(
        block_rows=[
            mat[
                Ellipsis,
                i * block_size : (i + 1) * block_size,
                0 : (i + 1) * block_size,
            ]
            for i in range(n_rows // block_size)
        ]
    )


@functools.partial(jax.jit, static_argnames=("block_size"))
def slice_symmetric_matrix_concat(
    mat,
    block_size,
):
    """Returns the concatenated sliced row blocks.

    Args:
      mat: A symmetric matrix.
      block_size: The size of the row slices.
    """
    sliced = slice_symmetric_matrix(mat=mat, block_size=block_size)
    return jnp.concatenate(sliced.block_rows, axis=-1)
347
+
348
+
349
def sliced_matrix_diag(mat):
    """Returns the diagonal of the symmetric matrix.

    Args:
      mat: The symmetric matrix represented in concatenated block form.
    """
    rows, cols = mat.shape
    num_blocks = num_blocks_from_total_blocks(cols // rows)
    diags = []
    for k in range(num_blocks):
        # The diagonal block of row k is the last block in that row.
        stop = rows * ((k + 2) * (k + 1)) // 2
        diags.append(jnp.diag(mat[Ellipsis, stop - rows : stop]))
    return jnp.concatenate(diags, axis=-1)


def diag_as_concat(diag, block_size):
    """Returns the representation of a diagonal matrix in symmetric block form.

    Args:
      diag: The 1D array for the diagonals.
      block_size: The size of blocks to use. Must divide the length of diag.
    """
    assert len(diag.shape) == 1  # diag must be 1D.
    assert len(diag) % block_size == 0
    blocks = []
    for k in range(len(diag) // block_size):
        # Row k: k zero off-diagonal blocks followed by its diagonal block.
        blocks.append(jnp.zeros(shape=(block_size, block_size * k), dtype=diag.dtype))
        blocks.append(jnp.diag(diag[k * block_size : (k + 1) * block_size]))
    return jnp.concatenate(blocks, axis=-1)
381
+
382
+
383
def row_abs_maxes(mat):
    """Returns the max of the absolute values of the rows of the full matrix.

    For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
    mat = [1, 6, 2] with block_size = 1. In this case the function returns the
    absolute row maxes of the original symmetric matrix, [6, 6].

    Args:
      mat: The symmetric matrix represented as the concatenated blocks.
    """
    rows, cols = mat.shape

    # Per-block column and row maxima of |mat|.
    col_maxes = []
    row_maxes = []
    for b in range(cols // rows):
        abs_block = jnp.abs(mat[Ellipsis, b * rows : (b + 1) * rows])
        col_maxes.append(jnp.max(abs_block, axis=1))
        row_maxes.append(jnp.max(abs_block, axis=0))

    # Combine the block maxima into global row maxima: for row group i, take
    # its own row maxes plus the col maxes of the blocks mirrored below it.
    num_blocks = num_blocks_from_total_blocks(cols // rows)
    maxes = []
    for i in range(num_blocks):
        maxes.append(
            jnp.concatenate(
                row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
                + [
                    col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
                    for j in range(i + 1, num_blocks)
                ],
                axis=-1,
            )
        )

    return jnp.max(jnp.stack(maxes), axis=0)


def times_vector(mat, vec):
    """Returns the symmetric block-concatenated matrix multiplied by a vector.

    Specifically, each value in the vector is multiplied by a row of the full
    matrix. That is, the vector is broadcast and multiplied element-wise. Note
    this would be the transpose of full_mat * vec if full_mat represented the full
    symmetric matrix.

    Args:
      mat: The symmetric matrix represented as the concatenated blocks.
      vec: The vector, having the same dimension as the materialized matrix.
    """
    rows, cols = mat.shape
    num_blocks = num_blocks_from_total_blocks(cols // rows)
    products = []
    for k in range(num_blocks):
        row_block = mat[
            Ellipsis, rows * ((k + 1) * k) // 2 : rows * ((k + 1) * (k + 2)) // 2
        ]
        vec_block = vec[Ellipsis, rows * k : rows * (k + 1)]
        products.append(jnp.einsum("...ij,...i->ij", row_block, vec_block))
    return jnp.concatenate(products, axis=-1)
sync_to_hub.yml.backup ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face hub (obsolete — kept to avoid app disruptions)
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+
7
+ # to run this workflow manually from the Actions tab
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ sync-to-hub:
12
+ runs-on: ubuntu-latest
13
+ steps:
14
+ - uses: actions/checkout@v2
15
+ with:
16
+ fetch-depth: 0
17
+ - name: Push to hub
18
+ env:
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push https://boris:$HF_TOKEN@huggingface.co/spaces/dalle-mini/dalle-mini main
sync_to_hub_debug.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Deploy to debug app
2
+
3
+ on:
4
+ # to run this workflow manually from the Actions tab
5
+ workflow_dispatch:
6
+
7
+ jobs:
8
+ sync-to-hub-debug:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v2
12
+ with:
13
+ fetch-depth: 0
14
+ - name: Push to hub
15
+ env:
16
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
17
+ run: git push --force https://boris:$HF_TOKEN@huggingface.co/spaces/dalle-mini/dalle-mini-debug +HEAD:main
text.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for processing text.
3
+ """
4
+
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
+ from pathlib import Path
10
+
11
+ import emoji
12
+ import ftfy
13
+ from huggingface_hub import hf_hub_download
14
+ from unidecode import unidecode
15
+
16
+ # based on wiki word occurrence
17
+ person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
18
+ temp_token = "xtokx" # avoid repeating chars
19
+
20
+
21
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = hf_hub_download(
            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
        )
        words = (
            line.split()[0]
            for line in Path(wiki_word_frequency)
            .read_text(encoding="utf8")
            .splitlines()
        )
        # Zipf-style cost: words later in the frequency file cost more.
        self._word_cost = {
            str(word): math.log(float(rank + 1)) for rank, word in enumerate(words)
        }
        self._max_word = max(len(w) for w in self._word_cost.keys())
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        pieces = [self._split(chunk) for chunk in self._SPLIT_RE.split(s)]
        return " ".join(word for piece in pieces for word in piece)

    def _split(self, s):
        # Find the best match for the i first characters, assuming cost has
        # been built for the i-1 first characters.
        # Returns a pair (match_cost, match_length).
        def best_match(i):
            candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
            return min(
                (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
                for k, c in candidates
            )

        # Forward pass: cost[i] = cheapest segmentation cost of s[:i].
        cost = [0]
        for i in range(1, len(s) + 1):
            best_cost, _ = best_match(i)
            cost.append(best_cost)

        # Backward pass: recover the minimal-cost segmentation.
        out = []
        i = len(s)
        while i > 0:
            c, k = best_match(i)
            assert c == cost[i]
            starts_new_token = True
            if s[i - k : i] != "'":  # ignore a lone apostrophe
                if out:
                    # re-attach split 's and split digits
                    if out[-1] == "'s" or (
                        s[i - 1].isdigit() and out[-1][0].isdigit()
                    ):  # digit followed by digit
                        # combine current token with previous token
                        out[-1] = s[i - k : i] + out[-1]
                        starts_new_token = False

                if starts_new_token:
                    out.append(s[i - k : i])

            i -= k

        return reversed(out)
84
+
85
+
86
def replace_person_token(t):
    """Replace <person> placeholders (CC12M) with natural-language words."""
    # Collapse runs of <person> joined by commas/"and" into " people " first.
    # Raw strings throughout: non-raw regexes with \s, \d, etc. emit
    # DeprecationWarning / SyntaxWarning on recent Python versions.
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace(
            "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
        )
    return t


def fix_html(t):
    """Unescape HTML entities; applied twice to handle double-escaped text."""
    # from OpenAI CLIP
    return html.unescape(html.unescape(t))


def replace_punctuation_with_commas(t):
    """Map most punctuation characters to a comma."""
    # Note: the original class contained the redundant escape "\/"; "/" is kept
    # unescaped here — the matched character set is identical.
    return re.sub(r"[()[\].,|:;?!=+~\-/{}]", ",", t)


def simplify_quotes(t):
    """Replace any quote-like character (', ", `) with a standardized ' " '."""
    return re.sub("""['"`]""", ' " ', t)


def merge_quotes(t):
    """Collapse consecutive quotes (and surrounding whitespace) into ' " '."""
    return re.sub(r'(\s*"+\s*)+', ' " ', t)


def remove_comma_numbers(t):
    """Drop thousands separators, e.g. "1,234,567" -> "1234567"."""

    def _f(t):
        return re.sub(r"(\d),(\d{3})", r"\1\2", t)

    # Applied twice so numbers with several comma groups are fully collapsed
    # (a single pass cannot rescan the digit consumed by the previous match).
    return _f(_f(t))
118
+
119
+
120
def pre_process_dot_numbers(t):
    """Protect dots between word characters (e.g. decimals) with a sentinel."""
    # Raw pattern strings avoid invalid-escape warnings on recent Python.
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)


def post_process_dot_numbers(t):
    """Restore dots protected by pre_process_dot_numbers."""
    return re.sub(f"{temp_token}dot{temp_token}", ".", t)


def pre_process_quotes(t):
    """Protect apostrophes of common contractions ('s, 't, 'd, 'm, 'll, 're, 've)."""
    # The original alternation listed (ll) twice; the duplicate was redundant
    # and has been removed — the set of matched strings is unchanged.
    return re.sub(
        r"'(?=([stdm]|(ll)|(re)|(ve))\b)", rf"{temp_token}quote{temp_token}", t
    )


def post_process_quotes(t):
    """Restore apostrophes protected by pre_process_quotes."""
    return re.sub(f"{temp_token}quote{temp_token}", "'", t)


def pre_process_dates(t):
    """Protect slashes between digits (dates like 3/4/2022) with a sentinel."""
    return re.sub(r"(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)


def post_process_dates(t):
    """Restore slashes protected by pre_process_dates."""
    return re.sub(f"{temp_token}slash{temp_token}", "/", t)
145
+
146
+
147
def merge_commas(t):
    """Collapse consecutive commas (and surrounding whitespace) into ", "."""
    # Raw pattern strings avoid invalid-escape warnings on recent Python.
    return re.sub(r"(\s*,+\s*)+", ", ", t)


def add_space_after_commas(t):
    """Ensure every comma is followed by a space."""
    return re.sub(",", ", ", t)


def handle_special_chars(t):
    """Handle special characters."""
    # replace "-" with a space when between words without space
    t = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always add space around some characters ("\/" in the original class was
    # a redundant escape for "/"; the matched set is identical)
    return re.sub(r"([%&/$*])", r" \1 ", t)


def expand_hashtags(t, hashtag_processor):
    """Remove # and try to split hashtag bodies into words."""
    return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
166
+
167
+
168
# Characters dropped entirely during normalization: _, #, and backslash.
_re_ignore_chars = r"[_#\\]"


def ignore_chars(t):
    """Replace useless characters (see _re_ignore_chars) with a space."""
    return re.sub(_re_ignore_chars, " ", t)


def remove_extra_spaces(t):
    """Collapse runs of whitespace (including tabs and newlines) to one space."""
    # Raw pattern string avoids invalid-escape warnings on recent Python.
    return re.sub(r"\s+", " ", t)


def remove_repeating_chars(t):
    """If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance."""
    return re.sub(r"(\D)(\1{3,})", r"\1", t)


def remove_urls(t):
    """Strip http(s) URLs."""
    return re.sub(r"http\S+", "", t)


def remove_html_tags(t):
    """Replace HTML tags with a space."""
    return re.sub(r"<[^<]+?>", " ", t)


def remove_first_last_commas(t):
    """Strip one leading and one trailing comma, plus surrounding whitespace."""
    t = t.strip()
    t = t[:-1] if t and t[-1] == "," else t
    t = t[1:] if t and t[0] == "," else t
    return t.strip()


def remove_wiki_ref(t):
    """Remove a leading/trailing Wikipedia-style reference marker like [1]."""
    t = re.sub(r"\A\s*\[\d+\]", "", t)
    return re.sub(r"\[\d+\]\s*\Z", "", t)
204
+
205
+
206
class TextNormalizer:
    """Normalize a caption through a fixed pipeline of cleanup steps."""

    def __init__(self):
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t):
        # Character-level cleanup: mojibake, html entities, emojis,
        # transliteration to ASCII, lower-casing.
        t = ftfy.fix_text(t)
        t = fix_html(t)
        t = emoji.demojize(t)  # decode emojis (would be removed by unidecode)
        t = unidecode(t)
        t = t.lower()
        # Dataset-specific cleanup.
        t = replace_person_token(t)  # <PERSON> placeholders (CC12M)
        t = remove_wiki_ref(t)  # wiki references (WIT)
        t = remove_html_tags(t)
        t = remove_urls(t)
        t = remove_comma_numbers(t)
        # Protect dots/quotes/dates before punctuation normalization - Part 1.
        t = pre_process_dot_numbers(t)
        t = pre_process_quotes(t)
        t = pre_process_dates(t)
        # Special characters and hashtags.
        t = handle_special_chars(t)
        t = expand_hashtags(t, self._hashtag_processor)
        t = ignore_chars(t)
        # Punctuation normalization.
        t = simplify_quotes(t)
        t = replace_punctuation_with_commas(t)
        # Restore protected dots/quotes/dates - Part 2.
        t = post_process_dot_numbers(t)
        t = post_process_quotes(t)
        t = post_process_dates(t)
        # Final tidying.
        t = remove_repeating_chars(t)
        t = merge_quotes(t)
        t = merge_commas(t)
        t = remove_extra_spaces(t)
        t = remove_first_last_commas(t)
        # The normalized caption always starts with a space.
        return f" {t}"
tokenizer.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart tokenizer """
2
+ from transformers import BartTokenizerFast
3
+
4
+ from .utils import PretrainedFromWandbMixin
5
+
6
+
7
+ class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):
8
+ pass
train.py ADDED
@@ -0,0 +1,1664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training DALL·E Mini.
18
+ Script adapted from run_summarization_flax.py
19
+ """
20
+
21
+ import io
22
+ import logging
23
+ import os
24
+ import sys
25
+ import tempfile
26
+ import time
27
+ from dataclasses import asdict, dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, NamedTuple, Optional
30
+
31
+ import datasets
32
+ import flax
33
+ import jax
34
+ import jax.numpy as jnp
35
+ import jaxlib
36
+ import numpy as np
37
+ import optax
38
+ import transformers
39
+ import wandb
40
+ from datasets import Dataset
41
+ from flax import core, struct, traverse_util
42
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
43
+ from flax.serialization import from_bytes, to_bytes
44
+ from flax.training.common_utils import onehot
45
+ from jax.experimental import PartitionSpec, maps
46
+ from jax.experimental.compilation_cache import compilation_cache as cc
47
+ from jax.experimental.pjit import pjit, with_sharding_constraint
48
+ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
49
+ from tqdm import tqdm
50
+ from transformers import HfArgumentParser
51
+
52
+ import dalle_mini
53
+ from dalle_mini.data import Dataset
54
+ from dalle_mini.model import (
55
+ DalleBart,
56
+ DalleBartConfig,
57
+ DalleBartTokenizer,
58
+ set_partitions,
59
+ )
60
+
61
# google-cloud-storage is optional: only needed when checkpoints live in a
# GCS bucket. Catch ImportError specifically — the original bare `except:`
# would also swallow KeyboardInterrupt and unrelated import-time failures.
try:
    from google.cloud import storage
except ImportError:
    storage = None
65
+
66
+ cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch. "
            "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name_or_path"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
        },
    )
    restore_state: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
        },
    )
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout rate. Overwrites config."},
    )
    activation_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Activation dropout rate. Overwrites config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout rate. Overwrites config."},
    )

    def __post_init__(self):
        # Default the tokenizer to the model checkpoint; one of the two
        # must be provided.
        if self.tokenizer_name is None:
            self.tokenizer_name = self.model_name_or_path
        assert (
            self.tokenizer_name is not None
        ), "Tokenizer name or model name/path needs to be specified"
        # Restoring optimizer state relies on the "/model-" -> "/state-"
        # artifact-name convention (see get_opt_state), so it only works
        # with a W&B model artifact reference.
        if self.restore_state:
            assert self.model_name_or_path is not None and (
                "/model-" in self.model_name_or_path
            ), "Restoring state only available with W&B artifact reference"

    def get_metadata(self):
        """Return the W&B artifact metadata dict for the model, or {}.

        A ":" in model_name_or_path marks a W&B artifact reference
        (name:version). Only process 0 records the artifact usage on the
        active run; other processes read it via the public API.
        """
        if self.model_name_or_path is not None and ":" in self.model_name_or_path:
            if jax.process_index() == 0:
                artifact = wandb.run.use_artifact(self.model_name_or_path)
            else:
                artifact = wandb.Api().artifact(self.model_name_or_path)
            return artifact.metadata
        else:
            return dict()

    def get_opt_state(self):
        """Return the raw serialized optimizer state (msgpack bytes).

        Resolution order: if restore_state is True, resolve the companion
        W&B "state" artifact (either a GCS bucket path from its metadata,
        or a downloaded local file); then read from GCS if the path starts
        with "gs://", otherwise from the local filesystem.
        Note: mutates self.restore_state from True to the resolved path.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            if self.restore_state is True:
                # wandb artifact: derived from the model artifact name by the
                # "/model-" -> "/state-" convention checked in __post_init__
                state_artifact = self.model_name_or_path.replace(
                    "/model-", "/state-", 1
                )
                if jax.process_index() == 0:
                    artifact = wandb.run.use_artifact(state_artifact)
                else:
                    artifact = wandb.Api().artifact(state_artifact)
                if artifact.metadata.get("bucket_path"):
                    # we will read directly file contents
                    self.restore_state = artifact.metadata["bucket_path"]
                else:
                    artifact_dir = artifact.download(tmp_dir)
                    self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")

            if self.restore_state.startswith("gs://"):
                # strip the "gs://" scheme, then split bucket / blob name
                bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
                bucket, blob_name = str(bucket_path).split("/", 1)
                assert (
                    storage is not None
                ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
                client = storage.Client()
                bucket = client.bucket(bucket)
                blob = bucket.blob(blob_name)
                return blob.download_as_bytes()

            with Path(self.restore_state).open("rb") as f:
                return f.read()
174
+
175
+
176
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Only `dataset_repo_or_path` is required; everything else has a sensible
    default. Raises ValueError from __post_init__ when it is missing.
    """

    text_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    encoding_column: Optional[str] = field(
        default="encoding",
        metadata={
            "help": "The name of the column in the datasets containing the image encodings."
        },
    )
    dataset_repo_or_path: str = field(
        default=None,
        metadata={"help": "The dataset repository containing encoded files."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "The input training data file (glob & braceexpand acceptable)."
        },
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
        },
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to stream the dataset."},
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the authentication token for private datasets."
        },
    )
    shard_by_host: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to shard data files by host in multi-host environments."
        },
    )
    blank_caption_prob: Optional[float] = field(
        default=0.0,
        metadata={
            "help": "Probability of removing some captions for classifier-free guidance."
        },
    )
    clip_score_column: Optional[str] = field(
        default="clip_score",
        # fix: "containts" -> "contains" in user-facing help text
        metadata={"help": "Column that contains clip score for filtering."},
    )
    min_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum clip score required."},
    )
    max_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Maximum clip score required."},
    )
    filter_column: Optional[str] = field(
        default=None,
        # fix: "containts" -> "contains" in user-facing help text
        metadata={"help": "Column that contains classes to be filtered."},
    )
    filter_value: Optional[str] = field(
        default=None,
        metadata={"help": "Class value to be kept during filtering."},
    )
    multi_eval_ds: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to look for multiple validation datasets (local support only)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
        },
    )
    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
    seed_dataset: Optional[int] = field(
        default=None,
        metadata={
            "help": "Random seed for the dataset that will be set at the beginning of training."
        },
    )

    def __post_init__(self):
        # dataset_repo_or_path is declared with default=None only so the
        # CLI parser accepts keyword form; it is effectively required.
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")
295
+
296
@dataclass
class TrainingArguments:
    """
    Arguments pertaining to training parameters.

    __post_init__ validates optimizer/decay/sharding choices, fills in
    derived defaults (eval batch size, log_norm_steps, dp_devices) and
    fails fast on an unusable output directory.
    """

    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the validation set."}
    )

    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per data parallel device for training."},
    )
    per_device_eval_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
        },
    )

    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing an update pass."
        },
    )
    gradient_checkpointing: bool = field(
        default=False, metadata={"help": "Use gradient checkpointing."}
    )

    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    optim: str = field(
        default="distributed_shampoo",
        metadata={
            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
        },
    )
    weight_decay: float = field(
        default=0.0, metadata={"help": "Weight decay applied to parameters."}
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
    )
    beta2: float = field(
        default=0.999,
        # fix: duplicated word "for for" in help text
        metadata={"help": "Beta2 for Adam & Distributed Shampoo."},
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
    block_size: int = field(
        default=1024,
        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
    )
    preconditioning_compute_steps: int = field(
        default=10, metadata={"help": "Number of steps to update preconditioner."}
    )
    skip_preconditioning_dim_size_gt: int = field(
        default=4096,
        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
    )
    graft_type: str = field(
        default="rmsprop_normalized",
        metadata={
            "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
        },
    )
    nesterov: bool = field(
        default=False,
        metadata={"help": "Use Nesterov momentum for Distributed Shampoo."},
    )
    optim_quantized: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
        },
    )
    shard_shampoo_across: str = field(
        default="dp",
        metadata={
            "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
        },
    )

    num_train_epochs: int = field(
        default=3, metadata={"help": "Total number of training epochs to perform."}
    )

    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )
    # fix: annotated Optional to match the None default
    lr_decay: Optional[str] = field(
        default=None,
        metadata={
            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
        },
    )
    lr_transition_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
        },
    )
    lr_decay_rate: Optional[float] = field(
        default=None,
        metadata={
            "help": "Decay rate associated with learning rate when using exponential decay."
        },
    )
    lr_staircase: bool = field(
        default=False,
        metadata={
            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
        },
    )
    lr_offset: int = field(
        default=0,
        metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
    )
    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    eval_steps: int = field(
        default=400, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
    )
    log_model: bool = field(
        default=False,
        metadata={"help": "Log model to wandb at `save_steps` frequency."},
    )
    # True (the default) is replaced by logging_steps in __post_init__
    log_norm_steps: int = field(
        default=True,
        metadata={"help": "Log parameters and gradients norm at this frequency."},
    )
    log_histogram_steps: int = field(
        default=False,
        metadata={
            "help": "Log parameters and gradients histograms at this frequency. Slows down training."
        },
    )

    seed_model: int = field(
        default=42,
        metadata={
            "help": "Random seed for the model that will be set at the beginning of training."
        },
    )

    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": "The wandb entity to use (for teams)."},
    )
    wandb_project: str = field(
        default="dalle-mini",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_job_type: str = field(
        default="Seq2Seq",
        metadata={"help": "The name of the wandb job type."},
    )

    assert_TPU_available: bool = field(
        default=False,
        metadata={"help": "Verify that TPU is not in use."},
    )

    use_vmap_trick: bool = field(
        default=True,
        # fix: help text was a copy-paste of assert_TPU_available's
        metadata={"help": "Whether to use the vmap trick."},
    )

    mp_devices: Optional[int] = field(
        default=1,
        metadata={
            "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
        },
    )

    # derived in __post_init__, never passed by the caller
    dp_devices: int = field(init=False)

    def __post_init__(self):
        if self.assert_TPU_available:
            assert (
                jax.local_device_count() == 8
            ), "TPUs in use, please check running processes"
        # writing to a bucket requires the optional google-cloud-storage dep
        if self.output_dir.startswith("gs://"):
            assert (
                storage is not None
            ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
        assert self.optim in [
            "distributed_shampoo",
            "adam",
            "adafactor",
        ], f"Selected optimizer not supported: {self.optim}"
        # adafactor interprets weight_decay=None as "disabled"
        if self.optim == "adafactor" and self.weight_decay == 0:
            self.weight_decay = None
        assert self.graft_type in [
            "rmsprop_normalized",
            "rmsprop",
            "adagrad",
            "adagrad_normalized",
            "sgd",
            "sqrt_n",
        ], f"Selected graft type not supported: {self.graft_type}"
        assert self.lr_decay in [
            None,
            "linear",
            "exponential",
        ], f"Selected learning rate decay not supported: {self.lr_decay}"
        if self.per_device_eval_batch_size is None:
            self.per_device_eval_batch_size = self.per_device_train_batch_size
        if self.log_norm_steps is True:
            self.log_norm_steps = self.logging_steps
        if not self.do_train:
            self.num_train_epochs = 1
        if (
            os.path.exists(self.output_dir)
            and os.listdir(self.output_dir)
            and self.do_train
            and not self.overwrite_output_dir
        ):
            raise ValueError(
                f"Output directory ({self.output_dir}) already exists and is not empty."
                "Use --overwrite_output_dir to overcome."
            )
        assert self.shard_shampoo_across in [
            "dp",
            "mp",
            "2d",
        ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
        assert (
            self.mp_devices > 0
        ), "Number of devices for model parallelism must be > 0"
        # fix: added missing closing parenthesis in the message
        assert (
            jax.device_count() % self.mp_devices == 0
        ), f"Number of available devices ({jax.device_count()}) must be divisible by number of devices used for model parallelism ({self.mp_devices})."
        self.dp_devices = jax.device_count() // self.mp_devices
+
559
+
560
def split_params(data):
    """Split params between scanned and non-scanned.

    Flattens the param tree and groups leaves into "scanned_encoder",
    "scanned_decoder" or "standard" buckets by path; empty buckets are
    dropped and each remaining bucket is rebuilt as a frozen tree.
    """
    flat_tree = traverse_util.flatten_dict(unfreeze(data))
    buckets = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
    for path, leaf in flat_tree.items():
        if "FlaxBartEncoderLayers" in path:
            bucket = "scanned_encoder"
        elif "FlaxBartDecoderLayers" in path:
            bucket = "scanned_decoder"
        else:
            bucket = "standard"
        buckets[bucket][path] = leaf
    # drop empty buckets, then unflatten and freeze each remaining one
    return {
        name: freeze(traverse_util.unflatten_dict(tree))
        for name, tree in buckets.items()
        if tree
    }
+
577
+
578
def unsplit_params(data):
    """Inverse of split_params: merge the buckets back into one frozen tree."""
    merged = {}
    for section in ("standard", "scanned_encoder", "scanned_decoder"):
        if section in data:
            merged.update(traverse_util.flatten_dict(unfreeze(data[section])))
    return freeze(traverse_util.unflatten_dict(merged))
+
585
+
586
class TrainState(struct.PyTreeNode):
    """Custom train state keeping one optimizer state per parameter bucket
    (standard / scanned_encoder / scanned_decoder, see split_params), with
    scanned buckets updated under vmap over the layer axis."""

    step: int
    params: core.FrozenDict[str, Any]
    opt_state: optax.OptState
    apply_fn: Callable = struct.field(pytree_node=False)
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    dropout_rng: jnp.ndarray = None
    epoch: int = 0
    train_time: float = 0.0  # total time the model trained
    train_samples: int = 0  # number of samples seen

    def apply_gradients(self, *, grads, **kwargs):
        """Apply `grads` bucket by bucket and return the updated state.

        NOTE(review): `self.tx` is indexed by bucket name, so it is expected
        to be a dict of per-bucket GradientTransformations — confirm at the
        call site where the state is created.
        """
        grads = split_params(grads)
        params = split_params(self.params)
        opt_state = {}
        # we loop over keys: "standard", "scanned_encoder", "scanned_decoder"
        for k, param in params.items():
            update_fn = self.tx[k].update
            if "scanned" in k:
                # scanned layers are stacked on a leading axis: run the
                # optimizer update independently per layer via vmap
                update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0))
            updates, new_opt_state = update_fn(grads[k], self.opt_state[k], param)
            params[k] = optax.apply_updates(param, updates)
            opt_state[k] = new_opt_state
        params = unsplit_params(params)

        return self.replace(
            step=self.step + 1,
            params=params,
            opt_state=freeze(opt_state),
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, **kwargs):
        """Build an initial state, initializing each bucket's optimizer state
        (vmapped over the layer axis for scanned buckets)."""
        opt_state = {}
        for k, p in split_params(params).items():
            init_fn = tx[k].init
            if "scanned" in k:
                init_fn = jax.vmap(init_fn)
            opt_state[k] = init_fn(p)
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=freeze(opt_state),
            **kwargs,
        )
+
635
+
636
+ def main():
637
+ # See all possible arguments by passing the --help flag to this script.
638
+ parser = HfArgumentParser(
639
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
640
+ )
641
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
642
+ # If we pass only one argument to the script and it's the path to a json file,
643
+ # let's parse it to get our arguments.
644
+ model_args, data_args, training_args = parser.parse_json_file(
645
+ json_file=os.path.abspath(sys.argv[1])
646
+ )
647
+ else:
648
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
649
+
650
+ # check arguments
651
+ if training_args.mp_devices > jax.local_device_count():
652
+ assert (
653
+ data_args.seed_dataset is not None
654
+ ), "Seed dataset must be provided when model is split over multiple hosts"
655
+
656
+ # Make one log on every process with the configuration for debugging.
657
+ logging.basicConfig(
658
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
659
+ datefmt="%m/%d/%Y %H:%M:%S",
660
+ level=logging.INFO,
661
+ )
662
+ # Setup logging, we only want one process per machine to log things on the screen.
663
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
664
+ if jax.process_index() == 0:
665
+ datasets.utils.logging.set_verbosity_warning()
666
+ transformers.utils.logging.set_verbosity_info()
667
+ else:
668
+ datasets.utils.logging.set_verbosity_error()
669
+ transformers.utils.logging.set_verbosity_error()
670
+
671
+ # Set the verbosity to info of the Transformers logger (on main process only):
672
+ logger.info(f"Training/evaluation parameters {training_args}")
673
+
674
+ # Load dataset
675
+ dataset = Dataset(
676
+ **asdict(data_args),
677
+ do_train=training_args.do_train,
678
+ do_eval=training_args.do_eval,
679
+ )
680
+
681
+ logger.info(f"Local TPUs: {jax.local_device_count()}")
682
+ logger.info(f"Global TPUs: {jax.device_count()}")
683
+
684
+ # Set up wandb run
685
+ if jax.process_index() == 0:
686
+ wandb.init(
687
+ entity=training_args.wandb_entity,
688
+ project=training_args.wandb_project,
689
+ job_type=training_args.wandb_job_type,
690
+ config=parser.parse_args(),
691
+ )
692
+
693
+ # Set up our new model config
694
+ config_args = {
695
+ k: getattr(model_args, k)
696
+ for k in ["dropout", "activation_dropout", "attention_dropout"]
697
+ if getattr(model_args, k) is not None
698
+ }
699
+ if model_args.config_name:
700
+ config = DalleBartConfig.from_pretrained(model_args.config_name)
701
+ config.gradient_checkpointing = training_args.gradient_checkpointing
702
+ for k, v in config_args.items():
703
+ setattr(config, k, v)
704
+ else:
705
+ config = None
706
+
707
+ # Load or create new model
708
+ if model_args.model_name_or_path:
709
+ model, params = DalleBart.from_pretrained(
710
+ model_args.model_name_or_path,
711
+ config=config,
712
+ seed=training_args.seed_model,
713
+ dtype=getattr(jnp, model_args.dtype),
714
+ _do_init=False, # we overwrite them with loaded checkpoint
715
+ gradient_checkpointing=training_args.gradient_checkpointing,
716
+ **config_args,
717
+ )
718
+ else:
719
+ model = DalleBart(
720
+ config,
721
+ seed=training_args.seed_model,
722
+ dtype=getattr(jnp, model_args.dtype),
723
+ _do_init=False,
724
+ )
725
+ params = None
726
+ params_shape = model.params_shape_tree
727
+
728
+ # get model metadata
729
+ model_metadata = model_args.get_metadata()
730
+
731
+ # get PartitionSpec for model params (required to be a dict)
732
+ param_spec = set_partitions(params_shape, model.config.use_scan)
733
+ params_shape = freeze(params_shape)
734
+ if params is not None:
735
+ params = freeze(params)
736
+
737
+ # Load tokenizer
738
+ tokenizer = DalleBartTokenizer.from_pretrained(
739
+ model_args.tokenizer_name, use_fast=True
740
+ )
741
+
742
+ # Preprocessing the datasets.
743
+ # We need to normalize and tokenize inputs and targets.
744
+ dataset.preprocess(tokenizer=tokenizer, config=model.config)
745
+
746
+ # Initialize our training
747
+ dropout_rng = jax.random.PRNGKey(training_args.seed_model)
748
+
749
+ # Store some constant
750
+ num_epochs = training_args.num_train_epochs
751
+ # batch size
752
+ batch_size_per_node_per_grad_step = (
753
+ training_args.per_device_train_batch_size
754
+ * jax.local_device_count()
755
+ // training_args.mp_devices
756
+ )
757
+ batch_size_per_node = (
758
+ batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
759
+ )
760
+ batch_size_per_step = batch_size_per_node * jax.process_count()
761
+ eval_batch_size_per_node = (
762
+ training_args.per_device_eval_batch_size
763
+ * jax.local_device_count()
764
+ // training_args.mp_devices
765
+ )
766
+ eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
767
+ len_train_dataset, len_eval_dataset = dataset.length
768
+ steps_per_epoch = (
769
+ len_train_dataset // batch_size_per_node
770
+ if len_train_dataset is not None
771
+ else None
772
+ )
773
+ num_train_steps = (
774
+ steps_per_epoch * num_epochs if steps_per_epoch is not None else None
775
+ )
776
+ num_params = model.num_params(params_shape)
777
+
778
+ logger.info("***** Running training *****")
779
+ logger.info(f" Num examples = {len_train_dataset}")
780
+ logger.info(f" Num Epochs = {num_epochs}")
781
+ logger.info(
782
+ f" Batch size per dp device = {training_args.per_device_train_batch_size}"
783
+ )
784
+ logger.info(f" Number of devices = {jax.device_count()}")
785
+ logger.info(
786
+ f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
787
+ )
788
+ logger.info(f" Batch size per update = {batch_size_per_step}")
789
+ logger.info(f" Model parameters = {num_params:,}")
790
+
791
+ # set up wandb run
792
+ if jax.process_index() == 0:
793
+ # set default x-axis as 'train/step'
794
+ wandb.define_metric("*", step_metric="train/step")
795
+
796
+ # add interesting config parameters
797
+ wandb.config.update(
798
+ {
799
+ "len_train_dataset": len_train_dataset,
800
+ "len_eval_dataset": len_eval_dataset,
801
+ "batch_size_per_step": batch_size_per_step,
802
+ "num_params": num_params,
803
+ "model_config": model.config.to_dict(),
804
+ "num_devices": jax.device_count(),
805
+ "versions": {
806
+ "jax": jax.__version__,
807
+ "jaxlib": jaxlib.__version__,
808
+ "flax": flax.__version__,
809
+ "transformers": transformers.__version__,
810
+ "datasets": datasets.__version__,
811
+ "wandb": wandb.__version__,
812
+ "dalle_mini": dalle_mini.__version__,
813
+ },
814
+ }
815
+ )
816
+
817
+ # Create learning rate schedule
818
def create_learning_rate_fn() -> Callable[[int], jnp.array]:
    """Create the learning rate function.

    Builds a linear warmup from 0 to ``training_args.learning_rate``,
    optionally preceded by a constant-0 segment (used when resuming, so the
    warmup is shifted by ``lr_offset`` steps), and optionally followed by a
    linear or exponential decay joined at the end of warmup.
    """
    # linear warmup: 0 -> learning_rate over warmup_steps
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps + 1,  # ensure not 0
    )
    # step at which warmup ends (boundary for joining the decay schedule)
    last_boundary = training_args.warmup_steps
    # offset step when resuming
    if training_args.lr_offset:
        warmup_fn = optax.join_schedules(
            schedules=[optax.constant_schedule(0.0), warmup_fn],
            boundaries=[training_args.lr_offset],
        )
        last_boundary += training_args.lr_offset
    if training_args.lr_decay is None:
        # no decay: constant learning rate after warmup
        return warmup_fn
    elif training_args.lr_decay == "linear":
        # linear decay needs the total number of steps to reach 0 at the end
        assert (
            num_train_steps is not None
        ), "linear decay requires knowing the dataset length"
        decay_fn = optax.linear_schedule(
            init_value=training_args.learning_rate,
            end_value=0,
            transition_steps=num_train_steps - training_args.warmup_steps,
        )
    elif training_args.lr_decay == "exponential":
        decay_fn = optax.exponential_decay(
            init_value=training_args.learning_rate,
            transition_steps=training_args.lr_transition_steps,
            decay_rate=training_args.lr_decay_rate,
            staircase=training_args.lr_staircase,
        )
    # chain warmup and decay at the end of warmup (including any resume offset)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[last_boundary],
    )
    return schedule_fn
856
+
857
+ learning_rate_fn = create_learning_rate_fn()
858
+
859
+ # create adam optimizer
860
+ if training_args.optim == "distributed_shampoo":
861
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
862
+ graft_type = {
863
+ "sgd": GraftingType.SGD,
864
+ "adagrad": GraftingType.ADAGRAD,
865
+ "rmsprop": GraftingType.RMSPROP,
866
+ "rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
867
+ "sqrt_n": GraftingType.SQRT_N,
868
+ "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
869
+ }[training_args.graft_type]
870
+ statistics_partition_spec = (
871
+ PartitionSpec(None, training_args.shard_shampoo_across, None)
872
+ if training_args.shard_shampoo_across != "2d"
873
+ else PartitionSpec(None, "dp", "mp")
874
+ )
875
+ opt = distributed_shampoo(
876
+ learning_rate_fn,
877
+ block_size=training_args.block_size,
878
+ beta1=training_args.beta1,
879
+ beta2=training_args.beta2,
880
+ diagonal_epsilon=1e-10,
881
+ matrix_epsilon=1e-6,
882
+ weight_decay=training_args.weight_decay,
883
+ start_preconditioning_step=max(
884
+ training_args.preconditioning_compute_steps + 1, 101
885
+ ),
886
+ preconditioning_compute_steps=training_args.preconditioning_compute_steps,
887
+ statistics_compute_steps=1,
888
+ best_effort_shape_interpretation=True,
889
+ graft_type=graft_type,
890
+ nesterov=training_args.nesterov,
891
+ exponent_override=0,
892
+ statistics_partition_spec=statistics_partition_spec,
893
+ preconditioner_partition_spec=PartitionSpec(
894
+ training_args.shard_shampoo_across, None, None
895
+ )
896
+ if training_args.shard_shampoo_across != "2d"
897
+ else PartitionSpec(
898
+ "mp" if training_args.mp_devices > training_args.dp_devices else "dp",
899
+ None,
900
+ None,
901
+ ),
902
+ num_devices_for_pjit=training_args.dp_devices,
903
+ shard_optimizer_states=True,
904
+ inverse_failure_threshold=0.1,
905
+ moving_average_for_momentum=True,
906
+ skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
907
+ clip_by_scaled_gradient_norm=None,
908
+ precision=jax.lax.Precision.HIGHEST,
909
+ best_effort_memory_usage_reduction=training_args.optim_quantized,
910
+ )
911
+ # get the real optimizer and helper functions
912
+ update_fn = opt.update
913
+
914
+ optimizer = {}
915
+ opt_fn = {}
916
+ for k, p in split_params(params_shape).items():
917
+ if "scanned" in k:
918
+ p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
919
+ optimizer[k] = opt.init(p)
920
+ opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
921
+ optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn
922
+ )
923
+ optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn)
924
+
925
+ elif training_args.optim == "adam":
926
+ optimizer = optax.adamw(
927
+ learning_rate=learning_rate_fn,
928
+ b1=training_args.beta1,
929
+ b2=training_args.beta2,
930
+ eps=training_args.adam_epsilon,
931
+ weight_decay=training_args.weight_decay,
932
+ )
933
+ optimizer = {k: optimizer for k in split_params(params_shape)}
934
+
935
+ elif training_args.optim == "adafactor":
936
+ # We use the default parameters here to initialize adafactor,
937
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
938
+ optimizer = optax.adafactor(
939
+ learning_rate=learning_rate_fn,
940
+ clipping_threshold=training_args.max_grad_norm,
941
+ weight_decay_rate=training_args.weight_decay,
942
+ )
943
+ optimizer = {k: optimizer for k in split_params(params_shape)}
944
+
945
+ # get PartitionSpec for optimizer state
946
def get_opt_state_spec_and_shape():
    """Compute the pjit PartitionSpec tree and abstract shape tree of the
    optimizer state without materializing it.

    Returns:
        Tuple ``(opt_state_spec, opt_state_shape)`` of frozen dicts keyed
        like ``split_params(params_shape)``.
    """
    # get opt_state shape without actual init (jax.eval_shape runs abstractly)
    opt_state_shape = {}
    for k, p in split_params(params_shape).items():
        if "scanned" not in k:
            opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
        else:
            # scanned layers: per-layer state, vmapped over the leading (layer) axis
            opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)

    if training_args.optim == "adafactor":
        # factorized state must be replicated (rank different than params)
        opt_state_spec = {k: None for k in split_params(params_shape)}

    elif training_args.optim in ["adam", "distributed_shampoo"]:

        def _opt_state_spec_per_leaf(x, spec):
            if isinstance(x, FrozenDict):
                # variables with same structure as params
                return spec
            else:
                # other variables such as count
                return None

        split_spec = split_params(set_partitions(params_shape, False))
        opt_state_spec = {}
        for k, p in split_params(params_shape).items():
            if "scanned" in k:
                # drop the leading layer axis so specs are derived for one layer
                p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
            if training_args.optim == "adam":
                opt_state_spec[k] = jax.tree_map(
                    _opt_state_spec_per_leaf,
                    opt_state_shape[k],
                    split_spec[k],
                    # return None spec for empty elements
                    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
                )
            elif training_args.optim == "distributed_shampoo":
                # shampoo provides its own spec function for its sharded statistics
                opt_state_spec[k] = opt_fn[k].pspec_fn(
                    p,
                    split_spec[k],
                    statistics_partition_spec,
                )
            # add dimension for scanned params
            if "scanned" in k:
                opt_state_spec[k] = jax.tree_map(
                    lambda x: PartitionSpec(*(None,) + x)
                    if x is not None
                    else None,
                    opt_state_spec[k],
                    is_leaf=lambda x: isinstance(x, PartitionSpec),
                )

    else:
        raise NotImplementedError
    return freeze(opt_state_spec), freeze(opt_state_shape)
1001
+
1002
+ opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape()
1003
+
1004
+ # create a mesh
1005
+ mesh_shape = (training_args.dp_devices, training_args.mp_devices)
1006
+ devices = np.asarray(jax.devices()).reshape(*mesh_shape)
1007
+ mesh = maps.Mesh(devices, ("dp", "mp"))
1008
+ logger.info(f" Mesh shape: {mesh_shape}")
1009
+
1010
+ # define state spec
1011
+ state_spec = TrainState(
1012
+ params=param_spec,
1013
+ opt_state=opt_state_spec,
1014
+ dropout_rng=None,
1015
+ step=None,
1016
+ epoch=None,
1017
+ train_time=None,
1018
+ train_samples=None,
1019
+ apply_fn=model.__call__,
1020
+ tx=optimizer,
1021
+ )
1022
+
1023
+ # init params if not available yet
1024
def maybe_init_params(params):
    """Return the loaded params, or freshly initialized weights when none
    were restored from a checkpoint."""
    if params is None:
        # params have not been initialized yet
        return model.init_weights(model.key, model.input_shape)
    # model params are correctly loaded
    return params
1031
+
1032
+ with mesh:
1033
+ logger.info(" Creating state")
1034
+
1035
+ # restore metadata
1036
+ attr_state = {}
1037
+ keys = ["train_time", "train_samples"]
1038
+ if model_args.restore_state:
1039
+ keys += ["step", "epoch"]
1040
+ attr_state = {k: v for k, v in model_metadata.items() if k in keys}
1041
+
1042
+ if not model_args.restore_state:
1043
+
1044
+ def init_state(params):
1045
+ return TrainState.create(
1046
+ apply_fn=model.__call__,
1047
+ tx=optimizer,
1048
+ params=maybe_init_params(params),
1049
+ dropout_rng=dropout_rng,
1050
+ **attr_state,
1051
+ )
1052
+
1053
+ state = pjit(
1054
+ init_state,
1055
+ in_axis_resources=(param_spec,)
1056
+ if model_args.model_name_or_path
1057
+ else None,
1058
+ out_axis_resources=state_spec,
1059
+ donate_argnums=(0,),
1060
+ )(params)
1061
+
1062
+ else:
1063
+ # load opt_state
1064
+ opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
1065
+
1066
+ def restore_state(params, opt_state):
1067
+ return TrainState(
1068
+ apply_fn=model.__call__,
1069
+ tx=optimizer,
1070
+ params=params,
1071
+ opt_state=opt_state,
1072
+ dropout_rng=dropout_rng,
1073
+ **attr_state,
1074
+ )
1075
+
1076
+ state = pjit(
1077
+ restore_state,
1078
+ in_axis_resources=(
1079
+ param_spec,
1080
+ opt_state_spec,
1081
+ ),
1082
+ out_axis_resources=state_spec,
1083
+ donate_argnums=(0, 1),
1084
+ )(params, opt_state)
1085
+
1086
+ # remove opt_state from CPU
1087
+ del opt_state
1088
+
1089
+ # free CPU memory
1090
+ del params, opt_state_spec, opt_state_shape
1091
+
1092
+ # define batch specs
1093
+ batch_spec = PartitionSpec("dp")
1094
+ grad_batch_spec = PartitionSpec(None, "dp")
1095
+
1096
+ # define loss
1097
def loss_fn(logits, labels):
    """Mean softmax cross-entropy between logits and integer labels."""
    per_example = optax.softmax_cross_entropy(
        logits, onehot(labels, logits.shape[-1])
    )
    return per_example.mean()
1101
+
1102
+ # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
1103
+ # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
1104
+ use_vmap_trick = training_args.use_vmap_trick
1105
+
1106
+ # make grad_param_spec for vmap
1107
+ if use_vmap_trick:
1108
+ grad_param_spec = jax.tree_map(
1109
+ lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))),
1110
+ param_spec,
1111
+ )
1112
+
1113
# Define gradient update step fn
def train_step(state, batch, train_time):
    """Run one optimizer update (with optional gradient accumulation) and
    return the updated TrainState plus a metrics dict.

    ``batch`` has a leading gradient-accumulation axis when
    ``gradient_accumulation_steps > 1`` (see the batch reshape in the
    training loop), and an extra leading dp-device axis when the "vmap
    trick" is enabled.
    """

    # get a minibatch (one gradient accumulation slice)
    def get_minibatch(batch, grad_idx):
        return jax.tree_map(
            lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
            batch,
        )

    def compute_loss(params, minibatch, dropout_rng):
        # minibatch has dim (batch_size, ...)
        minibatch, labels = minibatch.pop("labels")
        logits = state.apply_fn(
            **minibatch, params=params, dropout_rng=dropout_rng, train=True
        )[0]
        return loss_fn(logits, labels)

    grad_fn = jax.value_and_grad(compute_loss)

    def loss_and_grad(grad_idx, dropout_rng):
        # minibatch at grad_idx for gradient accumulation (None otherwise)
        minibatch = (
            get_minibatch(batch, grad_idx) if grad_idx is not None else batch
        )
        # ensure it is sharded properly
        minibatch = with_sharding_constraint(minibatch, batch_spec)
        # only 1 single rng per grad step, let us handle larger batch size (not sure why)
        dropout_rng, _ = jax.random.split(dropout_rng)

        if use_vmap_trick:
            # "vmap trick", calculate loss and grads independently per dp_device
            loss, grads = jax.vmap(
                grad_fn, in_axes=(None, 0, None), out_axes=(0, 0)
            )(state.params, minibatch, dropout_rng)
            # ensure they are sharded correctly
            loss = with_sharding_constraint(loss, batch_spec)
            grads = with_sharding_constraint(grads, grad_param_spec)
            # average across all devices
            # Note: we could average per device only after gradient accumulation, right before params update
            loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads))
        else:
            # "vmap trick" does not work in multi-hosts and requires too much hbm
            loss, grads = grad_fn(state.params, minibatch, dropout_rng)
            # ensure grads are sharded
            grads = with_sharding_constraint(grads, param_spec)
        # return loss and grads
        return loss, grads, dropout_rng

    if training_args.gradient_accumulation_steps == 1:
        loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
    else:
        # create initial state for cumul_minibatch_step loop:
        # (cumulative loss, zeroed grads, rng)
        init_minibatch_step = (
            0.0,
            with_sharding_constraint(
                jax.tree_map(jnp.zeros_like, state.params), param_spec
            ),
            state.dropout_rng,
        )

        # accumulate gradients
        def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
            cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout
            loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
            cumul_loss, cumul_grads = jax.tree_map(
                jnp.add, (cumul_loss, cumul_grads), (loss, grads)
            )
            cumul_grads = with_sharding_constraint(cumul_grads, param_spec)
            return cumul_loss, cumul_grads, dropout_rng

        # loop over gradients
        loss, grads, dropout_rng = jax.lax.fori_loop(
            0,
            training_args.gradient_accumulation_steps,
            cumul_minibatch_step,
            init_minibatch_step,
        )
        grads = with_sharding_constraint(grads, param_spec)
        # sum -> mean
        loss, grads = jax.tree_map(
            lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
        )

    grads = with_sharding_constraint(grads, param_spec)

    # update state
    state = state.apply_gradients(
        grads=grads,
        dropout_rng=dropout_rng,
        train_time=train_time,
        train_samples=state.train_samples + batch_size_per_step,
    )

    metrics = {
        "loss": loss,
        "learning_rate": learning_rate_fn(state.step),
    }

    def maybe_fn(fn, val, zeros, freq):
        """Call fn only if it is a logging step"""
        # jax.lax.cond keeps the computation traceable; off-steps return zeros
        return jax.lax.cond(
            state.step % freq == 0,
            fn,
            lambda _: zeros,
            val,
        )

    if training_args.log_norm_steps:
        zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)

        def norm(val):
            return jax.tree_map(lambda x: jnp.linalg.norm(x), val)

        gradients_norm = maybe_fn(
            norm, grads, zeros_norm, training_args.log_norm_steps
        )
        params_norm = maybe_fn(
            norm, state.params, zeros_norm, training_args.log_norm_steps
        )

        metrics.update(
            {
                "gradients_norm": gradients_norm,
                "params_norm": params_norm,
            }
        )

    if training_args.log_histogram_steps:
        zeros_hist = jax.tree_map(
            lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
        )

        def histogram(val):
            return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)

        gradients_hist = maybe_fn(
            histogram, grads, zeros_hist, training_args.log_histogram_steps
        )
        params_hist = maybe_fn(
            histogram, state.params, zeros_hist, training_args.log_histogram_steps
        )

        metrics.update(
            {
                "params_hist": params_hist,
                "gradients_hist": gradients_hist,
            }
        )

    return state, metrics
1264
+
1265
+ # Define eval fn
1266
+ eval_model = (
1267
+ model
1268
+ if model_args.dtype == "float32"
1269
+ else DalleBart(
1270
+ model.config,
1271
+ seed=training_args.seed_model,
1272
+ dtype=jnp.float32,
1273
+ _do_init=False,
1274
+ )
1275
+ )
1276
+
1277
def eval_step(state, batch):
    """Compute the evaluation loss for one sharded batch.

    Uses ``eval_model`` (the fp32 copy when training in reduced precision)
    with dropout disabled (``train=False``).
    """

    def compute_eval_loss(batch):
        # split the labels off the model inputs
        batch, labels = batch.pop("labels")
        logits = eval_model(**batch, params=state.params, train=False)[0]
        return loss_fn(logits, labels)

    if use_vmap_trick:
        # one loss per local dp device, computed independently
        loss = jax.vmap(compute_eval_loss)(batch)
        # ensure they are sharded correctly
        loss = with_sharding_constraint(loss, batch_spec)
        # average across all devices
        loss = jnp.mean(loss)
    else:
        loss = compute_eval_loss(batch)

    return loss
1293
+
1294
+ # Create parallel version of the train and eval step
1295
+ p_train_step = pjit(
1296
+ train_step,
1297
+ in_axis_resources=(
1298
+ state_spec,
1299
+ grad_batch_spec
1300
+ if training_args.gradient_accumulation_steps > 1
1301
+ else batch_spec,
1302
+ None,
1303
+ ),
1304
+ out_axis_resources=(state_spec, None),
1305
+ donate_argnums=(0,),
1306
+ )
1307
+ p_eval_step = pjit(
1308
+ eval_step,
1309
+ in_axis_resources=(state_spec, batch_spec),
1310
+ out_axis_resources=None,
1311
+ )
1312
+
1313
+ # define metrics logger
1314
+ class MetricsLogger:
1315
+ def __init__(self, step):
1316
+ # keep state
1317
+ self.state_dict = {}
1318
+ # estimate speed
1319
+ self.step = step
1320
+ self.time = time.perf_counter()
1321
+ self.offset_time = 0.0
1322
+
1323
+ def update_state_metrics(self, state):
1324
+ """Update internal state metrics (logged at each call to be used as x-axis)"""
1325
+ self.state_dict = {
1326
+ f'train/{k.split("_")[-1]}': state[k]
1327
+ for k in ["step", "epoch", "train_time", "train_samples"]
1328
+ }
1329
+ # timing metrics
1330
+ new_step = int(state["step"])
1331
+ new_time = time.perf_counter()
1332
+ if new_step > self.step:
1333
+ # remove time for eval & save
1334
+ delta_time = new_time - self.time - self.offset_time
1335
+ self.offset_time = 0
1336
+ time_per_step = delta_time / (new_step - self.step)
1337
+ self.step = new_step
1338
+ self.time = new_time
1339
+ self.log_time("train_per_step", time_per_step, offset=False)
1340
+ self.log_time("train_per_log", delta_time, offset=False)
1341
+
1342
+ def log_time(self, key, duration, offset=True):
1343
+ if jax.process_index() == 0:
1344
+ wandb.log({f"time/{key}": duration, **self.state_dict})
1345
+ if offset:
1346
+ self.offset_time += duration
1347
+
1348
+ def log(self, metrics, prefix=None):
1349
+ if jax.process_index() == 0:
1350
+ log_metrics = {}
1351
+ for k, v in metrics.items():
1352
+ if "_norm" in k:
1353
+ if self.step % training_args.log_norm_steps == 0:
1354
+ log_metrics[f"{k}/"] = unfreeze(v)
1355
+ elif "_hist" in k:
1356
+ if self.step % training_args.log_histogram_steps == 0:
1357
+ v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
1358
+ v = jax.tree_map(
1359
+ lambda x: wandb.Histogram(np_histogram=x),
1360
+ v,
1361
+ is_leaf=lambda x: isinstance(x, tuple),
1362
+ )
1363
+ log_metrics[f"{k}/"] = v
1364
+ else:
1365
+ if prefix is not None:
1366
+ k = f"{prefix}/{k}"
1367
+ log_metrics[k] = v
1368
+ wandb.log({**log_metrics, **self.state_dict})
1369
+
1370
+ # keep local copy of state
1371
+ local_state = {
1372
+ k: jax.device_get(getattr(state, k)).item()
1373
+ for k in ["step", "epoch", "train_time", "train_samples"]
1374
+ }
1375
+ # init variables
1376
+ start_time = time.perf_counter() - local_state["train_time"]
1377
+ train_metrics = None
1378
+ evaluation_ran = False
1379
+ save_model_ran = False
1380
+ metrics_logger = MetricsLogger(local_state["step"])
1381
+ epochs = tqdm(
1382
+ range(local_state["epoch"], num_epochs),
1383
+ desc=f"Epoch ... (1/{num_epochs})",
1384
+ position=0,
1385
+ disable=jax.process_index() > 0,
1386
+ )
1387
+
1388
def run_evaluation():
    """Run evaluation on every validation split and log losses to wandb.

    Returns the metrics dict of the last split evaluated (the main "eval"
    split, appended last to ``val_datasets``), or None when
    ``training_args.do_eval`` is False.

    NOTE(review): ``epoch`` and ``state`` are read from the enclosing
    training scope (closure), not passed in.
    """
    # ======================== Evaluating ==============================
    if training_args.do_eval:
        start_eval_time = time.perf_counter()
        # get validation datasets (extra named splits, then the main "eval" split)
        val_datasets = list(
            dataset.other_eval_datasets.keys()
            if hasattr(dataset, "other_eval_datasets")
            else []
        )
        val_datasets += ["eval"]
        for val_dataset in val_datasets:
            eval_loader = dataset.dataloader(
                val_dataset,
                eval_batch_size_per_step
                * max(1, training_args.mp_devices // jax.local_device_count()),
            )
            # NOTE(review): len_eval_dataset is the main eval split's length,
            # so the progress-bar total may be off for other validation splits.
            eval_steps = (
                len_eval_dataset // eval_batch_size_per_step
                if len_eval_dataset is not None
                else None
            )
            eval_loss = []
            for batch in tqdm(
                eval_loader,
                desc="Evaluating...",
                position=2,
                leave=False,
                total=eval_steps,
                disable=jax.process_index() > 0,
            ):
                # need to keep only eval_batch_size_per_node items relevant to the node
                batch = jax.tree_map(
                    lambda x: x.reshape(
                        (jax.process_count(), eval_batch_size_per_node)
                        + x.shape[1:]
                    ),
                    batch,
                )
                batch = jax.tree_map(lambda x: x[jax.process_index()], batch)

                # add dp dimension when using "vmap trick"
                if use_vmap_trick:
                    bs_shape = (
                        jax.local_device_count() // training_args.mp_devices,
                        training_args.per_device_eval_batch_size,
                    )
                    batch = jax.tree_map(
                        lambda x: x.reshape(bs_shape + x.shape[1:]), batch
                    )

                # freeze batch to pass safely to jax transforms
                batch = freeze(batch)
                # accumulate losses async
                eval_loss.append(p_eval_step(state, batch))

            # get the mean of the loss
            eval_loss = jnp.stack(eval_loss)
            eval_loss = jnp.mean(eval_loss)
            eval_metrics = {"loss": eval_loss}

            # log metrics
            metrics_logger.log(eval_metrics, prefix=val_dataset)

            # Print metrics and update progress bar
            desc = f"Epoch... ({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})"
            epochs.write(desc)
            epochs.desc = desc

        # log time
        metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)

        return eval_metrics
1461
+
1462
def run_save_model(state, eval_metrics=None):
    """Save model, tokenizer and optimizer state (process 0 only).

    Writes either to ``training_args.output_dir`` or, when it starts with
    ``gs://``, stages files in a temp dir and uploads them to
    ``gs://bucket/<run_id>/step_<step>/``. Optionally logs model and state
    as W&B artifacts (by reference when on a bucket).
    """
    if jax.process_index() == 0:

        start_save_time = time.perf_counter()
        output_dir = training_args.output_dir
        use_bucket = output_dir.startswith("gs://")
        if use_bucket:
            # stage locally, upload to gs://bucket/<run_id>/step_<step>
            bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
            bucket, dir_path = str(bucket_path).split("/", 1)
            tmp_dir = tempfile.TemporaryDirectory()
            output_dir = tmp_dir.name

        # save model
        params = jax.device_get(state.params)
        model.save_pretrained(
            output_dir,
            params=params,
        )

        # save tokenizer
        tokenizer.save_pretrained(output_dir)

        # copy to bucket
        if use_bucket:
            client = storage.Client()
            bucket = client.bucket(bucket)
            for filename in Path(output_dir).glob("*"):
                blob_name = str(Path(dir_path) / "model" / filename.name)
                blob = bucket.blob(blob_name)
                blob.upload_from_filename(str(filename))
            tmp_dir.cleanup()

        # save state (optimizer state is serialized with flax msgpack)
        opt_state = jax.device_get(state.opt_state)
        if use_bucket:
            blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
            blob = bucket.blob(blob_name)
            blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
        else:
            with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
                f.write(to_bytes(opt_state))

        # save to W&B
        if training_args.log_model:
            # save some space
            # NOTE(review): uses wandb internal API (wandb_sdk.wandb_artifacts);
            # may break across wandb versions.
            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
            c.cleanup(wandb.util.from_human_size("20GB"))

            metadata = {
                k: jax.device_get(getattr(state, k)).item()
                for k in ["step", "epoch", "train_time", "train_samples"]
            }
            metadata["num_params"] = num_params
            if eval_metrics is not None:
                metadata["eval"] = eval_metrics

            # create model artifact
            if use_bucket:
                metadata["bucket_path"] = f"gs://{bucket_path}/model"
            artifact = wandb.Artifact(
                name=f"model-{wandb.run.id}",
                type="DalleBart_model",
                metadata=metadata,
            )
            if use_bucket:
                # reference the bucket location instead of uploading files
                artifact.add_reference(metadata["bucket_path"])
            else:
                for filename in [
                    "config.json",
                    "flax_model.msgpack",
                    "merges.txt",
                    "special_tokens_map.json",
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "vocab.json",
                ]:
                    artifact.add_file(
                        f"{Path(training_args.output_dir) / filename}"
                    )
            wandb.run.log_artifact(artifact)

            # create state artifact
            if use_bucket:
                metadata["bucket_path"] = f"gs://{bucket_path}/state"
            artifact_state = wandb.Artifact(
                name=f"state-{wandb.run.id}",
                type="DalleBart_state",
                metadata=metadata,
            )
            if use_bucket:
                artifact_state.add_reference(metadata["bucket_path"])
            else:
                artifact_state.add_file(
                    f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
                )
            wandb.run.log_artifact(artifact_state)
        metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
1559
+
1560
+ logger.info(" Ready to start training")
1561
+ with mesh:
1562
+ for epoch in epochs:
1563
+ state = state.replace(epoch=epoch)
1564
+ local_state["epoch"] = epoch
1565
+ # ======================== Training ================================
1566
+ metrics_logger.update_state_metrics(local_state)
1567
+ metrics_logger.log({})
1568
+
1569
+ if training_args.do_train:
1570
+ # load data - may be replicated on multiple nodes
1571
+ node_groups = max(
1572
+ 1, training_args.mp_devices // jax.local_device_count()
1573
+ )
1574
+ loader_bs = batch_size_per_node * node_groups
1575
+ train_loader = dataset.dataloader(
1576
+ "train",
1577
+ loader_bs,
1578
+ epoch,
1579
+ )
1580
+ # train
1581
+ for batch in tqdm(
1582
+ train_loader,
1583
+ desc="Training...",
1584
+ position=1,
1585
+ leave=False,
1586
+ total=steps_per_epoch,
1587
+ disable=jax.process_index() > 0,
1588
+ ):
1589
+ # calculate delta time (we have a lag of one step but it's ok)
1590
+ train_time = time.perf_counter() - start_time
1591
+
1592
+ # reset control variables
1593
+ evaluation_ran = False
1594
+ save_model_ran = False
1595
+
1596
+ # set correct shape to batch
1597
+ # - add grad_step dim if gradient_accumulation_steps > 1
1598
+ bs_shape = (
1599
+ (batch_size_per_node_per_grad_step * node_groups,)
1600
+ if not use_vmap_trick
1601
+ else (
1602
+ jax.local_device_count()
1603
+ * node_groups
1604
+ // training_args.mp_devices, # local dp devices
1605
+ training_args.per_device_train_batch_size,
1606
+ )
1607
+ )
1608
+ if training_args.gradient_accumulation_steps > 1:
1609
+ # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1610
+ # to avoid any data redistribution when sharding
1611
+ bs_shape = (
1612
+ training_args.gradient_accumulation_steps,
1613
+ ) + bs_shape
1614
+
1615
+ # reshape batch
1616
+ batch = jax.tree_map(
1617
+ lambda x: x.reshape(bs_shape + x.shape[1:]),
1618
+ batch,
1619
+ )
1620
+ # freeze batch to pass safely to jax transforms
1621
+ batch = freeze(batch)
1622
+
1623
+ # train step
1624
+ state, train_metrics = p_train_step(state, batch, train_time)
1625
+ local_state["step"] += 1
1626
+ local_state["train_time"] = train_time
1627
+ local_state["train_samples"] += batch_size_per_step
1628
+
1629
+ if (
1630
+ local_state["step"] % training_args.logging_steps == 0
1631
+ and jax.process_index() == 0
1632
+ ):
1633
+ metrics_logger.update_state_metrics(local_state)
1634
+ metrics_logger.log(train_metrics, prefix="train")
1635
+
1636
+ eval_metrics = None
1637
+ if local_state["step"] % training_args.eval_steps == 0:
1638
+ eval_metrics = run_evaluation()
1639
+ evaluation_ran = True
1640
+
1641
+ if local_state["step"] % training_args.save_steps == 0:
1642
+ run_save_model(state, eval_metrics)
1643
+ save_model_ran = True
1644
+
1645
+ # log final train metrics
1646
+ if train_metrics is not None:
1647
+ metrics_logger.update_state_metrics(local_state)
1648
+ metrics_logger.log(train_metrics, prefix="train")
1649
+
1650
+ epochs.write(
1651
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1652
+ )
1653
+
1654
+ # Final evaluation at the end of each epoch
1655
+ if not evaluation_ran:
1656
+ eval_metrics = run_evaluation()
1657
+
1658
+ # save checkpoint after each epoch
1659
+ if not save_model_ran:
1660
+ run_save_model(state, eval_metrics)
1661
+
1662
+
1663
# Script entry point: run the full training pipeline.
if __name__ == "__main__":
    main()
utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import wandb
6
+
7
+
8
class PretrainedFromWandbMixin:
    """Mixin that lets ``from_pretrained`` accept a W&B artifact reference."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.
        """
        # download into a temp dir to avoid multiple artifact copies
        with tempfile.TemporaryDirectory() as tmp_dir:
            is_artifact_ref = ":" in pretrained_model_name_or_path and not os.path.isdir(
                pretrained_model_name_or_path
            )
            if is_artifact_ref:
                # wandb artifact, e.g. "entity/project/name:version"
                if wandb.run is None:
                    artifact = wandb.Api().artifact(pretrained_model_name_or_path)
                else:
                    # attach the artifact to the active run
                    artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
                pretrained_model_name_or_path = artifact.download(tmp_dir)

            return super(PretrainedFromWandbMixin, cls).from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )