Commit 5066443 by AndyZijianZhang
Parent: 99e90b8

feat: can generate now but not precise

Files changed (3):
  1. chat_template.json +1 -1
  2. modeling_vila.py +78 -12
  3. processing_vila.py +5 -5
chat_template.json CHANGED
@@ -1,3 +1,3 @@
 {
-  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<image>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<image>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
 }
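
The only change here is the default system prompt: "You are a helpful assistant." loses its trailing period, presumably to match the exact prompt string the underlying LLM was trained with. A quick way to see what the template now renders is to load it with plain jinja2; this is a minimal sketch (the `chat_template.json` path and the sample message are illustrative):

```python
import json

from jinja2 import Environment

# Load the template shipped in this commit (assumed to sit in the repo root).
with open("chat_template.json") as f:
    template = Environment().from_string(json.load(f)["chat_template"])

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

print(template.render(messages=messages, add_generation_prompt=True, add_vision_id=False))
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# <image>Describe this image.<|im_end|>
# <|im_start|>assistant
```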
modeling_vila.py CHANGED
@@ -1,13 +1,14 @@
-from typing import Optional, Tuple, Type, override
+from typing import Dict, List, Optional, Tuple, Type, cast, override
 
-from configuration_vila import VILAConfig
-from torch import LongTensor, Tensor
+import torch
+from torch import Tensor
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 
 from .auto_model import VILAForCausalLM
+from .configuration_vila import VILAConfig
 
 
 class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
@@ -16,6 +17,7 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
     is_parallelizable: bool = True
     main_input_name: str = "input_ids"
 
+    config: PretrainedConfig
     model: VILAForCausalLM
 
     def __init__(
@@ -27,18 +29,24 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
 
     def forward(
         self,
-        input_ids: LongTensor,
-        attention_mask: Tensor,
-        pixel_values: Tensor,
         *,
-        return_dict: Optional[bool] = None,
+        attention_mask: Optional[Tensor] = None,
+        input_ids: Optional[Tensor] = None,
+        inputs_embeds: Optional[Tensor] = None,
+        pixel_values: Optional[Tensor] = None,
         **kwargs,
-    ) -> Tuple | CausalLMOutputWithPast:
+    ) -> CausalLMOutputWithPast:
+        if inputs_embeds is None:
+            assert input_ids is not None
 
-        outputs = self.model.forward(
-            input_ids=input_ids,
+            inputs_embeds, _ = self._embed(input_ids, pixel_values, attention_mask)
+        else:
+            assert input_ids is None
+            assert pixel_values is None
+
+        outputs = self.model.llm.forward(
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
-            return_dict=return_dict,
             **kwargs,
         )
 
@@ -58,4 +66,62 @@ class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
     @override
     def save_pretrained(self, *args, **kwargs) -> None:
         self.model.save_pretrained(*args, **kwargs)
-        self.model.save_pretrained(*args, **kwargs)
+
+    def _embed(
+        self,
+        input_ids: Tensor,
+        pixel_values: Optional[Tensor],
+        attention_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        """Gets the embedding of the input ids and pixel values.
+
+        Args:
+            input_ids: The input ids.
+            pixel_values: The pixel values.
+            attention_mask: The attention mask.
+
+        Returns:
+            A tuple of the embedding of the input ids and attention mask.
+        """
+
+        image_token_ids_map = cast(Dict[str, int], self.model.tokenizer.media_token_ids)
+        image_token_ids = list(image_token_ids_map.values())
+        image_token_idx = torch.isin(
+            input_ids,
+            torch.tensor(image_token_ids).to(input_ids.device),
+        )
+        image_token_count = image_token_idx.sum()
+
+        images = list(pixel_values) if pixel_values is not None else []
+
+        if image_token_count < len(images):
+            images = images[:image_token_count]
+
+        media = (
+            {
+                "image": images,
+            }
+            if image_token_count > 0
+            else {}
+        )
+        media_config = (
+            {
+                "image": {},
+            }
+            if image_token_count > 0
+            else {}
+        )
+
+        outputs = self.model._embed(
+            input_ids,
+            media,
+            media_config,
+            labels=None,
+            attention_mask=(
+                attention_mask.to(dtype=torch.bool)
+                if attention_mask is not None
+                else None
+            ),
+        )
+
+        return outputs[0], outputs[2]
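
This is the substance of the commit: `forward` no longer hands raw `input_ids` to the wrapped model. It now accepts either `input_ids` (plus optional `pixel_values`) or precomputed `inputs_embeds`, and the new `_embed` helper locates the image placeholder tokens with `torch.isin`, packages the images into the `media`/`media_config` dicts the inner model's `_embed` expects, and feeds the merged embeddings straight to `self.model.llm`. The toy below sketches that locate-and-splice pattern on made-up shapes; it is not the repo's actual `_embed`, which delegates the splicing to `VILAForCausalLM`:

```python
import torch

# Toy vocabulary: token 99 stands in for the <image> placeholder.
IMAGE_TOKEN_ID = 99
embed = torch.nn.Embedding(100, 8)

input_ids = torch.tensor([[5, 7, IMAGE_TOKEN_ID, 11]])
image_features = torch.randn(1, 8)  # one feature vector per placeholder occurrence

with torch.no_grad():
    # Same detection the commit uses: a boolean mask of placeholder positions.
    is_image = torch.isin(input_ids, torch.tensor([IMAGE_TOKEN_ID]))

    # Embed the text, then overwrite the placeholder slots with image features.
    inputs_embeds = embed(input_ids)
    inputs_embeds[is_image] = image_features

print(is_image)             # tensor([[False, False,  True, False]])
print(inputs_embeds.shape)  # torch.Size([1, 4, 8])
```

One design note visible in the diff: `forward` discards the attention mask that `_embed` returns (`inputs_embeds, _ = ...`) and passes the original `attention_mask` through to the LLM, which may matter when media tokens expand the sequence length.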
processing_vila.py CHANGED
@@ -15,8 +15,8 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_utils_base import TextInput
 
 from . import mm_utils
+from .constants import DEFAULT_IMAGE_TOKEN
 
-_IMAGE_TOKEN = "<image>"
 _PLACEHOLDER_TOKEN = "<|placeholder|>"
 
 
@@ -124,20 +124,20 @@ class VILAProcessor(ProcessorMixin):
         idx_image_splice = 0
 
         for i in range(len(text)):
-            while _IMAGE_TOKEN in text[i]:
+            while DEFAULT_IMAGE_TOKEN in text[i]:
                 if idx_image_splice >= len(num_image_splices):
                     raise ValueError(
-                        f"Too many {_IMAGE_TOKEN} tokens in text. "
+                        f"Too many {DEFAULT_IMAGE_TOKEN} tokens in text. "
                        f"Expected {len(num_image_splices)} tokens, "
                        f"but found {idx_image_splice} tokens."
                    )
 
                text[i] = text[i].replace(
-                    _IMAGE_TOKEN,
+                    DEFAULT_IMAGE_TOKEN,
                    _PLACEHOLDER_TOKEN * num_image_splices[idx_image_splice],
                )
                idx_image_splice += 1
-            text[i] = text[i].replace(_PLACEHOLDER_TOKEN, _IMAGE_TOKEN)
+            text[i] = text[i].replace(_PLACEHOLDER_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n")
 
         return text
 
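
Here the module-private `_IMAGE_TOKEN` constant gives way to the shared `DEFAULT_IMAGE_TOKEN` from `.constants`, and the behavioral change sits in the last replaced line: each placeholder is now rewritten to the image token followed by a newline. The placeholder indirection exists so the `while` loop never re-matches tokens it just inserted (replacing `<image>` with a string that itself contains `<image>` would never terminate). A standalone toy of the two-pass expansion, assuming `DEFAULT_IMAGE_TOKEN == "<image>"` (consistent with the constant this commit removes):

```python
DEFAULT_IMAGE_TOKEN = "<image>"
_PLACEHOLDER_TOKEN = "<|placeholder|>"

text = ["<image> Describe the scene."]
num_image_splices = [3]  # e.g. one image tiled into three crops

idx = 0
for i in range(len(text)):
    # Pass 1: expand each image token into one placeholder per splice.
    while DEFAULT_IMAGE_TOKEN in text[i]:
        text[i] = text[i].replace(
            DEFAULT_IMAGE_TOKEN,
            _PLACEHOLDER_TOKEN * num_image_splices[idx],
        )
        idx += 1
    # Pass 2 (the new behavior): placeholder -> image token plus newline.
    text[i] = text[i].replace(_PLACEHOLDER_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n")

print(text[0])
# <image>
# <image>
# <image>
#  Describe the scene.
```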