File size: 4,490 Bytes
befcde3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, Union, overload)
import torch
from torch.func import functional_call
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor: ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]: ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Collapse the leading ``B`` and ``N`` dimensions of batched multimodal
    inputs into a single dimension.

    A tensor input is expected to have shape ``(B, N, ...)``; its first two
    dimensions are merged and ``concat`` is not consulted.

    A list input is treated as one entry per batch element:

    - with ``concat=False`` (the default) the items obtained by iterating
      each entry are gathered into one flat list;
    - with ``concat=True`` the entries are joined with ``torch.cat`` along
      dimension 0 and a single tensor is returned.
    """
    # Tensor input: merge dim 0 and dim 1 directly.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if not concat:
        # Unroll one nesting level: every item of every batch entry, in order.
        flat_items = []
        for per_batch in x:
            flat_items.extend(per_batch)
        return flat_items

    return torch.cat(x)
def _flatten_embeddings(embeddings: torch.Tensor) -> torch.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings: torch.Tensor) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
Tensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions of ``inputs_embeds`` selected by the mask ``is_multimodal``.

    ``multimodal_embeddings`` may be a tensor or a nested iterable of
    tensors; it is flattened on all but the last dimension before being
    written. The flattened embeddings are cast to the dtype of
    ``inputs_embeds`` so that a dtype mismatch between the two does not make
    the masked assignment fail.

    Returns:
        The same ``inputs_embeds`` tensor that was passed in.

    Raises:
        ValueError: If the number of flattened multimodal embeddings differs
            from the number of positions selected by ``is_multimodal``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    # Narrows the type of ``.item()`` for static checkers.
    assert isinstance(num_expected_tokens, int)

    # After flattening there is one row per multimodal embedding.
    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    # Masked assignment requires source and destination dtypes to match;
    # ``.to`` returns the tensor unchanged when the dtype already agrees.
    inputs_embeds[is_multimodal] = flattened.to(dtype=inputs_embeds.dtype)
    return inputs_embeds
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: torch.Tensor,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
    """
    Overwrite the rows of ``inputs_embeds`` whose token in ``input_ids`` is a
    placeholder with the rows of ``multimodal_embeddings``.

    ``placeholder_token_id`` is either a single token id or a list of token
    ids (e.g. the ids of the img_start, img_break and img_end tokens). With a
    list, every position whose token id appears in the list is treated as a
    placeholder, and the embeddings are written into those positions in
    order (a slice-merge, not an individual scatter). The order of these
    tokens in ``input_ids`` MUST therefore MATCH the order of their
    embeddings in ``multimodal_embeddings``.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    then the image embeddings coming from the vision encoder (the I's) must
    already be interleaved with the embeddings of S, B and E in the same
    order as in input_ids for the merge to be correct.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        # Several placeholder ids: membership test against all of them.
        candidate_ids = torch.tensor(placeholder_token_id,
                                     device=input_ids.device)
        placeholder_mask = torch.isin(input_ids, candidate_ids)
    else:
        # Single placeholder id: plain element-wise comparison.
        placeholder_mask = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        placeholder_mask,
        multimodal_embeddings,
    )
|