| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Embedding utils.""" |
| | from typing import Any, Callable, Dict, Optional |
| |
|
| | import flax.linen as nn |
| | import jax |
| | from jax.nn import initializers |
| | import jax.numpy as jnp |
| | from scenic.projects.layout_denoise.layers import common |
| |
|
| |
|
class TokenEmbedding(nn.Module):
  """Learned token-embedding lookup table for text.

  Attributes:
    hidden_dim: Size of each token embedding vector.
    vocab_size: Number of unique tokens in the vocabulary.
    token_emb_init: Initializer for the embedding table.
    dtype: Jax dtype; The dtype of the computation (default: float32).
  """
  hidden_dim: int
  vocab_size: int
  token_emb_init: Callable[..., Any] = initializers.normal(stddev=1.0)
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, tokens) -> jnp.ndarray:
    """Looks up embeddings for the given token ids.

    Args:
      tokens: Integer array of token ids to be embedded.

    Returns:
      Embedding for tokens with rank=token_rank + 1 (a trailing
      `hidden_dim` axis is appended).
    """
    table = self.param('token_emb', self.token_emb_init,
                       (self.vocab_size, self.hidden_dim))
    gathered = jnp.take(table, tokens, axis=0)
    return jnp.asarray(gathered, self.dtype)
| |
|
| |
|
class InputPosEmbeddingSine(nn.Module):
  """Creates sinusoidal (DETR-style) positional embeddings for inputs."""

  hidden_dim: int
  dtype: jnp.dtype = jnp.float32
  scale: Optional[float] = None
  temperature: float = 10000
  normalize: bool = True

  @nn.compact
  def __call__(self, padding_mask: jnp.ndarray) -> jnp.ndarray:
    """Creates the positional embeddings for transformer inputs.

    Args:
      padding_mask: Binary matrix with 0 at padded image regions. Shape is
        [batch, height, width]

    Returns:
      Positional embedding of shape [batch, height * width, hidden_dim].

    Raises:
      ValueError if `hidden_dim` is not an even number.
    """
    if self.hidden_dim % 2 != 0:
      raise ValueError('`hidden_dim` must be an even number.')

    mask = padding_mask.astype(jnp.float32)
    # Per-pixel row/column positions, counted over non-padded cells only.
    y_pos = jnp.cumsum(mask, axis=1)
    x_pos = jnp.cumsum(mask, axis=2)

    if self.normalize:
      eps = 1e-6
      scale = 2 * jnp.pi if self.scale is None else self.scale
      # Normalize positions to [0, scale] using the last valid row/column.
      y_pos = y_pos / (y_pos[:, -1:, :] + eps) * scale
      x_pos = x_pos / (x_pos[:, :, -1:] + eps) * scale

    half_dim = self.hidden_dim // 2
    feat_idx = jnp.arange(half_dim, dtype=jnp.float32)
    # Geometric frequency schedule; pairs of features share one frequency.
    inv_freq = self.temperature**(2 * (feat_idx // 2) / half_dim)

    def _sine_features(pos):
      # Interleave sin on even and cos on odd frequency slots, then flatten
      # the (feature, 2) pair axis into one feature axis.
      angles = pos[:, :, :, jnp.newaxis] / inv_freq
      return jnp.stack(
          [jnp.sin(angles[:, :, :, 0::2]),
           jnp.cos(angles[:, :, :, 1::2])],
          axis=4).reshape(padding_mask.shape + (-1,))

    pos = jnp.concatenate(
        [_sine_features(y_pos), _sine_features(x_pos)], axis=3)
    b, h, w = padding_mask.shape
    pos = jnp.reshape(pos, [b, h * w, self.hidden_dim])
    return jnp.asarray(pos, self.dtype)
| |
|
| |
|
class ImageEmbedding(nn.Module):
  """Creates learned embeddings for images.

  Attributes:
    hidden_dim: Hidden dimension for the pos embeddings.
    backbone_num_filters: Num filters in the ResNet backbone.
    backbone_num_layers: Num layers in the ResNet backbone.
    dtype: Jax dtype; The dtype of the computation (default: float32).
  """
  hidden_dim: int
  backbone_num_filters: int
  backbone_num_layers: int
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self,
               cnn,
               images: jnp.ndarray,
               train: bool,
               *,
               padding_mask: Optional[jnp.ndarray] = None,
               update_batch_stats: Optional[bool] = None) -> Dict[str, Any]:
    """Creates the image embeddings.

    Args:
      cnn: Conv Net for processing the image.
      images: The images to be embedded.
      train: Whether it is training.
      padding_mask: Binary matrix with 0 at padded image regions.
      update_batch_stats: Whether update the batch statistics for the BatchNorms
        in the backbone. if None, the value of `train` flag will be used, i.e.
        we update the batch stat if we are in the train mode.

    Returns:
      Output: dict with 'content_emb', 'pos_emb', 'mask',
      'backbone_features' and 'shapes'.
    """
    # BUGFIX: the default used to be `False`, which made this None-check
    # unreachable and silently froze batch statistics even in train mode,
    # contradicting the documented contract. Defaulting to None restores
    # "follow the `train` flag" behavior.
    if update_batch_stats is None:
      update_batch_stats = train

    backbone_features = cnn(images, train=update_batch_stats)
    x = backbone_features['stage_4']

    bs, h, w, _ = x.shape

    # Bring the padding mask down to the backbone's spatial resolution.
    if padding_mask is None:
      padding_mask_downsampled = jnp.ones((bs, h, w), dtype=jnp.bool_)
    else:
      padding_mask_downsampled = jax.image.resize(
          padding_mask.astype(jnp.float32), shape=[bs, h, w],
          method='nearest').astype(jnp.bool_)
    pos_emb = InputPosEmbeddingSine(hidden_dim=self.hidden_dim)(
        padding_mask_downsampled)

    # Project backbone features to `hidden_dim` and flatten the spatial grid
    # into a sequence for the transformer.
    x = nn.Conv(features=self.hidden_dim, kernel_size=(1, 1), strides=(1, 1))(x)
    x = x.reshape(bs, h * w, self.hidden_dim)
    mask = jnp.reshape(padding_mask_downsampled, [bs, h * w])
    output = {}
    output['content_emb'] = x
    output['pos_emb'] = pos_emb
    output['mask'] = mask
    output['backbone_features'] = backbone_features
    output['shapes'] = (bs, h, w)
    return output
| |
|
| |
|
class QueryPosEmbedding(nn.Module):
  """Learned positional embeddings for object queries.

  Attributes:
    hidden_dim: Hidden dimension for the pos embeddings.
    num_queries: Number of object queries.
    posemb_init: Positional embeddings initializer.
    dtype: Jax dtype; The dtype of the computation (default: float32).
  """
  hidden_dim: int
  num_queries: int
  posemb_init: Callable[..., Any] = initializers.normal(stddev=1.0)
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self) -> jnp.ndarray:
    """Returns the query positional embeddings.

    Returns:
      Positional embedding for object queries, with a leading batch axis of
      size 1, shape [1, num_queries, hidden_dim].
    """
    emb = self.param('query_emb', self.posemb_init,
                     (self.num_queries, self.hidden_dim))
    return jnp.asarray(emb[jnp.newaxis, ...], self.dtype)
| |
|
| |
|
class StructureEmbedding(nn.Module):
  """Creates learned embeddings for structures (UI layout elements).

  Attributes:
    hidden_dim: Hidden dimension for the embeddings.
    num_queries: The number of queries.
    txt_pool_method: How per-token text embeddings are pooled for each
      element ('max', 'sum' or 'mean').
    num_types: Number of element types (not referenced inside this class).
    coordinate_emb_depth: Depth of the coordinate embedding when
      `aggregation` is not 'sum'.
    dtype: Jax dtype; The dtype of the computation (default: float32).
    aggregation: How the text embeddings are fused: 'concat' or 'sum'.
    dropout_rate: Dropout applied to both content and positional embeddings.
  """
  hidden_dim: int
  num_queries: int
  txt_pool_method: str = 'max'
  num_types: int = 30
  coordinate_emb_depth: int = 256
  dtype: jnp.dtype = jnp.float32
  aggregation: str = 'concat'
  dropout_rate: float = 0.2

  @nn.compact
  def __call__(
      self,
      obj_mask: jnp.ndarray,
      desc_id: jnp.ndarray,
      resource_id: jnp.ndarray,
      name_id: jnp.ndarray,
      boxes: jnp.ndarray,
      task: str,
      token_embder,
      pos_pattern,
      train) -> Dict[str, Any]:
    """Creates the structure embeddings."""
    # Convert boxes from center format (cx, cy, w, h) to corner format
    # (x_min, y_min, x_max, y_max).
    cx, cy, w, h = jnp.split(boxes, 4, axis=2)
    corner_boxes = jnp.concatenate(
        [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=2)

    pos_embs = self.embed_pos(
        obj_mask=obj_mask, obj_boxes=corner_boxes, pos_pattern=pos_pattern)

    obj_embds = self.embed_layout(
        obj_mask=obj_mask,
        obj_desc_id=desc_id,
        obj_resource_id=resource_id,
        obj_name_id=name_id,
        token_embder=token_embder)

    deterministic = not train
    obj_embds = nn.Dropout(rate=self.dropout_rate)(
        obj_embds, deterministic=deterministic)
    pos_embs = nn.Dropout(rate=self.dropout_rate)(
        pos_embs, deterministic=deterministic)

    return {
        'content_emb': obj_embds,
        'mask': jnp.asarray(jnp.minimum(obj_mask, 1), self.dtype),
        'pos_emb': pos_embs,
    }

  def embed_layout(self, obj_mask, obj_desc_id, obj_resource_id, obj_name_id,
                   token_embder):
    """Embeds and fuses the text fields of each element."""

    def _pool(ids):
      # Embed the token ids and pool over the token axis; ids below 4 are
      # treated as padding/special tokens.
      return pool_txt_embs(
          ids,
          token_embder(ids),
          method=self.txt_pool_method,
          valid_token_start=4,
          dtype=self.dtype)

    desc_embs = _pool(obj_desc_id)
    resource_embs = _pool(obj_resource_id)
    name_embs = _pool(obj_name_id)

    if self.aggregation == 'concat':
      stacked = jnp.concatenate([desc_embs, resource_embs, name_embs],
                                axis=-1)
      obj_embds = common.dense(stacked, self.hidden_dim, self.dtype)
    elif self.aggregation == 'sum':
      obj_embds = desc_embs + resource_embs + name_embs
    else:
      raise ValueError('Unrecognized aggregation method: %s' % self.aggregation)
    # Zero out padded (invalid) elements.
    non_paddings = jnp.asarray(jnp.minimum(obj_mask, 1), self.dtype)
    return obj_embds * jnp.expand_dims(non_paddings, 2)

  def embed_pos(self, obj_mask, obj_boxes, pos_pattern='1/4'):
    """Encodes element boxes as positional embeddings."""
    # With 'sum' aggregation the positional embedding takes the full hidden
    # width; presumably so it can be added to the content embedding — the
    # consumer of 'pos_emb' is outside this module.
    if self.aggregation == 'sum':
      depth = self.hidden_dim
    else:
      depth = self.coordinate_emb_depth

    pos_embds = encode_coordinate(
        obj_boxes, depth, self.dtype, pattern=pos_pattern)
    # Zero out padded (invalid) elements.
    non_paddings = jnp.asarray(jnp.minimum(obj_mask, 1), self.dtype)
    return pos_embds * jnp.expand_dims(non_paddings, 2)
| |
|
| |
|
def encode_coordinate(obj_boxes, depth, dtype, freq_depth=64, pattern='1/4'):
  """Encodes box coordinates using a random-features-based encoder.

  Args:
    obj_boxes: <float>[batch, num_objects, 4] box coordinates.
    depth: Total output embedding depth.
    dtype: Jax dtype of the dense layers.
    freq_depth: Number of random frequencies per coordinate group.
    pattern: How the 4 coordinates are grouped — '4/1' (4 groups of 1),
      '1/4' (1 group of 4) or '2/2' (2 groups of 2).

  Returns:
    <float>[batch, num_objects, depth] coordinate embeddings.

  Raises:
    ValueError: If `pattern` is not recognized.
  """
  if pattern == '4/1':
    grouped = jnp.expand_dims(obj_boxes, 3)
    num_groups = 4
  elif pattern == '1/4':
    grouped = jnp.expand_dims(obj_boxes, 2)
    num_groups = 1
  elif pattern == '2/2':
    grouped = jnp.reshape(obj_boxes, obj_boxes.shape[:2] + (2, 2))
    num_groups = 2
  else:
    raise ValueError('Unrecognized coord encoding pattern: %s' % pattern)

  # Random Fourier features: project coordinates through a near-zero-init
  # dense layer, then take cos/sin of the resulting frequencies.
  freqs = common.dense(
      grouped, freq_depth, dtype,
      kernel_init=nn.initializers.normal(stddev=1e-6))
  features = jnp.concatenate([jnp.cos(freqs), jnp.sin(freqs)], axis=-1)

  # Two-layer MLP per coordinate group; groups are concatenated back into a
  # single `depth`-wide embedding by the final reshape.
  group_depth = depth // num_groups
  embds = common.dense(features, group_depth, dtype)
  embds = common.dense(nn.relu(embds), group_depth, dtype)
  return jnp.reshape(embds, features.shape[:2] + (-1,))
| |
|
| |
|
def pool_txt_embs(token_ids,
                  text_embeddings,
                  method,
                  valid_token_start=4,
                  dtype=jnp.float32):
  """Aggregates per-token text embeddings into one embedding per UI element.

  Args:
    token_ids: <int>[batch, num_objects, num_tokens] token ids; ids below
      `valid_token_start` are treated as padding/special tokens.
    text_embeddings: <float>[batch, num_objects, num_tokens, dim] embeddings.
    method: Pooling method: 'max', 'sum' or 'mean'.
    valid_token_start: First token id that counts as a real token.
    dtype: Jax dtype used for the validity masks.

  Returns:
    <float>[batch, num_objects, dim] pooled embeddings. Elements whose tokens
    are all padding pool to zeros under 'max' and 'sum'.

  Raises:
    ValueError: If `method` is not recognized.
  """
  # 1.0 where the token is padding/special, 0.0 where it is a real token.
  non_tokens = jnp.asarray(jnp.less(token_ids, valid_token_start), dtype)
  if method == 'max':
    assert len(token_ids.shape) == 3
    # Push padded positions to a huge negative value so they never win the max.
    embed_bias = non_tokens * -1e7
    text_embeddings = jnp.max(
        text_embeddings + jnp.expand_dims(embed_bias, 3), axis=-2)
    # Rows where every token was padding still sit near -1e7; zero them out.
    non_paddings = jnp.asarray(jnp.greater(text_embeddings, -1e6), dtype)
    embeddings = text_embeddings * non_paddings
  elif method == 'sum':
    # BUGFIX: the mask is rank 3, so the new axis is 3 (the token axis of
    # `text_embeddings` is -2); the previous axis=4 was out of range.
    valid_tokens = jnp.expand_dims(1 - non_tokens, 3)
    embeddings = jnp.sum(text_embeddings * valid_tokens, axis=-2)
  elif method == 'mean':
    valid_tokens = jnp.expand_dims(1 - non_tokens, 3)
    sum_embeddings = jnp.sum(text_embeddings * valid_tokens, axis=-2)
    # BUGFIX: keepdims so the [batch, num_objects, 1] counts broadcast against
    # the embedding dimension; clamp at 1 to avoid dividing by zero for
    # all-padding elements.
    token_counts = jnp.maximum(
        jnp.sum(1 - non_tokens, axis=-1, keepdims=True), 1)
    embeddings = sum_embeddings / token_counts
  else:
    raise ValueError('Unrecognized token aggregation %s' % method)
  return embeddings
| |
|