Commit 5066443
Parent(s): 99e90b8

feat: can generate now but not precise

Files changed:
- chat_template.json (+1 -1)
- modeling_vila.py (+78 -12)
- processing_vila.py (+5 -5)
chat_template.json CHANGED
@@ -1,3 +1,3 @@
 {
-  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant
+  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<image>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
 }
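
For reference, the updated template can be exercised directly with Jinja2, outside of transformers; a minimal sketch, assuming the file sits next to the script and using an illustrative message payload (this mirrors what tokenizer.apply_chat_template produces with the same flags):

import json

from jinja2 import Template

# Load the template string shipped in chat_template.json (path is illustrative).
with open("chat_template.json") as f:
    chat_template = json.load(f)["chat_template"]

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# add_vision_id=False suppresses the "Picture N:" prefixes; the system
# preamble is injected because the first message is not a system message.
prompt = Template(chat_template).render(
    messages=messages, add_generation_prompt=True, add_vision_id=False
)
print(prompt)
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# <image>Describe this image.<|im_end|>
# <|im_start|>assistant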
modeling_vila.py CHANGED
@@ -1,13 +1,14 @@
-from typing import Optional, Tuple, Type, override
+from typing import Dict, List, Optional, Tuple, Type, cast, override
 
-
-from torch import LongTensor, Tensor
+import torch
+from torch import Tensor
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 
 from .auto_model import VILAForCausalLM
+from .configuration_vila import VILAConfig
 
 
 class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
@@ -16,6 +17,7 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
     is_parallelizable: bool = True
     main_input_name: str = "input_ids"
 
+    config: PretrainedConfig
     model: VILAForCausalLM
 
     def __init__(
@@ -27,18 +29,24 @@
 
     def forward(
         self,
-        input_ids: LongTensor,
-        attention_mask: Tensor,
-        pixel_values: Tensor,
         *,
-
+        attention_mask: Optional[Tensor] = None,
+        input_ids: Optional[Tensor] = None,
+        inputs_embeds: Optional[Tensor] = None,
+        pixel_values: Optional[Tensor] = None,
         **kwargs,
-    ) ->
+    ) -> CausalLMOutputWithPast:
+        if inputs_embeds is None:
+            assert input_ids is not None
 
-
-
+            inputs_embeds, _ = self._embed(input_ids, pixel_values, attention_mask)
+        else:
+            assert input_ids is None
+            assert pixel_values is None
+
+        outputs = self.model.llm.forward(
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
-            return_dict=return_dict,
             **kwargs,
         )
 
@@ -58,4 +66,62 @@
     @override
     def save_pretrained(self, *args, **kwargs) -> None:
         self.model.save_pretrained(*args, **kwargs)
-
+
+    def _embed(
+        self,
+        input_ids: Tensor,
+        pixel_values: Optional[Tensor],
+        attention_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        """Gets the embedding of the input ids and pixel values.
+
+        Args:
+            input_ids: The input ids.
+            pixel_values: The pixel values.
+            attention_mask: The attention mask.
+
+        Returns:
+            A tuple of the embedding of the input ids and attention mask.
+        """
+
+        image_token_ids_map = cast(Dict[str, int], self.model.tokenizer.media_token_ids)
+        image_token_ids = list(image_token_ids_map.values())
+        image_token_idx = torch.isin(
+            input_ids,
+            torch.tensor(image_token_ids).to(input_ids.device),
+        )
+        image_token_count = image_token_idx.sum()
+
+        images = list(pixel_values) if pixel_values is not None else []
+
+        if image_token_count < len(images):
+            images = images[:image_token_count]
+
+        media = (
+            {
+                "image": images,
+            }
+            if image_token_count > 0
+            else {}
+        )
+        media_config = (
+            {
+                "image": {},
+            }
+            if image_token_count > 0
+            else {}
+        )
+
+        outputs = self.model._embed(
+            input_ids,
+            media,
+            media_config,
+            labels=None,
+            attention_mask=(
+                attention_mask.to(dtype=torch.bool)
+                if attention_mask is not None
+                else None
+            ),
+        )
+
+        return outputs[0], outputs[2]
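
With this commit, forward() takes keyword-only arguments and either embeds input_ids (replacing image tokens with projected vision features via _embed()) or accepts precomputed inputs_embeds, then delegates to the wrapped LLM. A minimal sketch of a direct call, assuming the repo's auto_map wires AutoModel to VILAForConditionalGeneration and that the processor returns input_ids/attention_mask/pixel_values; the checkpoint id and image path are placeholders:

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

repo = "org/vila-checkpoint"  # placeholder checkpoint id

processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": "What is shown here?"}],
    }
]
prompt = processor.tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
inputs = processor(text=[prompt], images=[Image.open("photo.jpg")], return_tensors="pt")

with torch.no_grad():
    # All three tensors go in as keywords; passing inputs_embeds instead
    # requires input_ids and pixel_values to be None (see the asserts above).
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
    )

print(out.logits.shape)  # (batch, seq_len after image-token expansion, vocab)

GenerationMixin drives this same forward() during model.generate(); per the commit message, generation runs but is not yet precise.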
processing_vila.py CHANGED
@@ -15,8 +15,8 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_utils_base import TextInput
 
 from . import mm_utils
+from .constants import DEFAULT_IMAGE_TOKEN
 
-_IMAGE_TOKEN = "<image>"
 _PLACEHOLDER_TOKEN = "<|placeholder|>"
 
 
@@ -124,20 +124,20 @@ class VILAProcessor(ProcessorMixin):
         idx_image_splice = 0
 
         for i in range(len(text)):
-            while _IMAGE_TOKEN in text[i]:
+            while DEFAULT_IMAGE_TOKEN in text[i]:
                 if idx_image_splice >= len(num_image_splices):
                     raise ValueError(
-                        f"Too many {_IMAGE_TOKEN} tokens in text. "
+                        f"Too many {DEFAULT_IMAGE_TOKEN} tokens in text. "
                         f"Expected {len(num_image_splices)} tokens, "
                        f"but found {idx_image_splice} tokens."
                     )
 
                 text[i] = text[i].replace(
-                    _IMAGE_TOKEN,
+                    DEFAULT_IMAGE_TOKEN,
                     _PLACEHOLDER_TOKEN * num_image_splices[idx_image_splice],
                 )
                 idx_image_splice += 1
-            text[i] = text[i].replace(_PLACEHOLDER_TOKEN, f"{_IMAGE_TOKEN}\n")
+            text[i] = text[i].replace(_PLACEHOLDER_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n")
 
         return text
 
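
The placeholder indirection above exists because each <image> may expand to several image tokens (one per tile/splice): the first replace swaps <image> for a run of <|placeholder|> so the while condition stops matching, and a second pass rewrites the placeholders back to "<image>" plus a newline. A standalone sketch of the same idea; the function name is illustrative, and it uses replace(..., 1) so each occurrence consumes its own splice count, whereas the method above replaces every occurrence in the string with the current count at once:

_IMAGE_TOKEN = "<image>"  # DEFAULT_IMAGE_TOKEN in the repo's constants module
_PLACEHOLDER_TOKEN = "<|placeholder|>"

def expand_image_tokens(text: list, num_image_splices: list) -> list:
    """Expand each <image> into one "<image>" plus newline per tile."""
    idx = 0
    for i in range(len(text)):
        while _IMAGE_TOKEN in text[i]:
            if idx >= len(num_image_splices):
                raise ValueError("more <image> tokens than provided images")
            # Unique placeholder repeated once per tile: the replacement can
            # no longer satisfy the while-loop condition.
            text[i] = text[i].replace(
                _IMAGE_TOKEN, _PLACEHOLDER_TOKEN * num_image_splices[idx], 1
            )
            idx += 1
        # Second pass: every placeholder becomes "<image>\n".
        text[i] = text[i].replace(_PLACEHOLDER_TOKEN, f"{_IMAGE_TOKEN}\n")
    return text

print(expand_image_tokens(["A: <image> B: <image>"], [2, 1]))
# ['A: <image>\n<image>\n B: <image>\n']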