Atharva Mete committed on
Commit ·
57b4d23
1
Parent(s): 303e3cf
VLA added, but the loss is giving NaNs
Browse files
- added_tokens.json +2 -0
- config.json +2 -0
- config_molmo.py +4 -0
- modeling_molmo.py +53 -2
- preprocessing_molmo.py +20 -4
- special_tokens_map.json +3 -1
- tokenizer_config.json +19 -1
added_tokens.json
CHANGED
|
@@ -7,6 +7,8 @@
|
|
| 7 |
"<|im_end|>": 151645,
|
| 8 |
"<|im_start|>": 151644,
|
| 9 |
"<|image|>": 152068,
|
|
|
|
|
|
|
| 10 |
"|<EXTRA_TOKENS_0>|": 151646,
|
| 11 |
"|<EXTRA_TOKENS_100>|": 151746,
|
| 12 |
"|<EXTRA_TOKENS_101>|": 151747,
|
|
|
|
| 7 |
"<|im_end|>": 151645,
|
| 8 |
"<|im_start|>": 151644,
|
| 9 |
"<|image|>": 152068,
|
| 10 |
+
"<|proprio|>": 152069,
|
| 11 |
+
"<|skill|>": 152070,
|
| 12 |
"|<EXTRA_TOKENS_0>|": 151646,
|
| 13 |
"|<EXTRA_TOKENS_100>|": 151746,
|
| 14 |
"|<EXTRA_TOKENS_101>|": 151747,
|
config.json
CHANGED
|
@@ -28,5 +28,7 @@
|
|
| 28 |
"use_cache": true,
|
| 29 |
"use_position_ids": true,
|
| 30 |
"vocab_size": 152064,
|
|
|
|
|
|
|
| 31 |
"weight_tying": false
|
| 32 |
}
|
|
|
|
| 28 |
"use_cache": true,
|
| 29 |
"use_position_ids": true,
|
| 30 |
"vocab_size": 152064,
|
| 31 |
+
"skill_vocab_size": 1000,
|
| 32 |
+
"additional_vocab_size": 128,
|
| 33 |
"weight_tying": false
|
| 34 |
}
|
config_molmo.py
CHANGED
|
@@ -9,6 +9,8 @@ class MolmoConfig(PretrainedConfig):
|
|
| 9 |
|
| 10 |
def __init__(
|
| 11 |
self,
|
|
|
|
|
|
|
| 12 |
vocab_size=50304,
|
| 13 |
embedding_size=50304,
|
| 14 |
hidden_size=4096,
|
|
@@ -31,6 +33,8 @@ class MolmoConfig(PretrainedConfig):
|
|
| 31 |
layer_norm_type: str="rms",
|
| 32 |
**kwargs,
|
| 33 |
):
|
|
|
|
|
|
|
| 34 |
self.vocab_size = vocab_size
|
| 35 |
self.embedding_size = embedding_size
|
| 36 |
self.max_position_embeddings = max_position_embeddings
|
|
|
|
| 9 |
|
| 10 |
def __init__(
|
| 11 |
self,
|
| 12 |
+
skill_vocab_size=1000,
|
| 13 |
+
additional_vocab_size=128,
|
| 14 |
vocab_size=50304,
|
| 15 |
embedding_size=50304,
|
| 16 |
hidden_size=4096,
|
|
|
|
| 33 |
layer_norm_type: str="rms",
|
| 34 |
**kwargs,
|
| 35 |
):
|
| 36 |
+
self.skill_vocab_size = skill_vocab_size
|
| 37 |
+
self.additional_vocab_size = additional_vocab_size
|
| 38 |
self.vocab_size = vocab_size
|
| 39 |
self.embedding_size = embedding_size
|
| 40 |
self.max_position_embeddings = max_position_embeddings
|
modeling_molmo.py
CHANGED
|
@@ -541,6 +541,7 @@ class Embedding(nn.Module):
|
|
| 541 |
self,
|
| 542 |
num_embeddings: int,
|
| 543 |
num_new_embeddings: int,
|
|
|
|
| 544 |
features: int,
|
| 545 |
device: Union[str, torch.device],
|
| 546 |
initializer_range: float = 0.02,
|
|
@@ -555,13 +556,17 @@ class Embedding(nn.Module):
|
|
| 555 |
self.new_embedding = nn.Parameter(
|
| 556 |
torch.zeros(num_new_embeddings, features, device=device),
|
| 557 |
)
|
|
|
|
|
|
|
|
|
|
| 558 |
|
| 559 |
def reset_parameters(self):
|
| 560 |
nn.init.normal_(self.embedding, std=self.initializer_range)
|
| 561 |
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
|
|
|
|
| 562 |
|
| 563 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 564 |
-
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
|
| 565 |
|
| 566 |
|
| 567 |
class Dropout(nn.Dropout):
|
|
@@ -681,6 +686,7 @@ class FullMolmoConfig:
|
|
| 681 |
initializer_range: float = 0.02
|
| 682 |
normalize_input_embeds: bool = False
|
| 683 |
use_position_ids: bool = True
|
|
|
|
| 684 |
|
| 685 |
@property
|
| 686 |
def effective_n_kv_heads(self) -> int:
|
|
@@ -1695,6 +1701,7 @@ class Molmo(nn.Module):
|
|
| 1695 |
wte = Embedding(
|
| 1696 |
config.embedding_size or config.vocab_size,
|
| 1697 |
config.additional_vocab_size,
|
|
|
|
| 1698 |
config.d_model,
|
| 1699 |
device=config.init_device,
|
| 1700 |
initializer_range=config.initializer_range,
|
|
@@ -1734,6 +1741,16 @@ class Molmo(nn.Module):
|
|
| 1734 |
)
|
| 1735 |
}
|
| 1736 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1737 |
|
| 1738 |
self.vision_backbone: Optional[OLMoVisionBackbone] = None
|
| 1739 |
if config.vision_backbone is not None:
|
|
@@ -1741,6 +1758,11 @@ class Molmo(nn.Module):
|
|
| 1741 |
|
| 1742 |
self.__num_fwd_flops: Optional[int] = None
|
| 1743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1744 |
def reset_parameters(self):
|
| 1745 |
if self.vision_backbone is not None:
|
| 1746 |
self.vision_backbone.reset_parameters()
|
|
@@ -1778,12 +1800,15 @@ class Molmo(nn.Module):
|
|
| 1778 |
image_masks: Optional[torch.Tensor] = None,
|
| 1779 |
image_input_idx: Optional[torch.Tensor] = None,
|
| 1780 |
subsegment_ids: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
| 1781 |
position_ids: Optional[torch.Tensor] = None,
|
| 1782 |
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1783 |
use_cache: bool = False,
|
| 1784 |
last_logits_only: bool = False,
|
| 1785 |
output_hidden_states: Optional[bool] = None,
|
| 1786 |
append_last_valid_logits: Optional[torch.Tensor] = None,
|
|
|
|
| 1787 |
) -> ModelOutput:
|
| 1788 |
"""
|
| 1789 |
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
|
@@ -1880,6 +1905,9 @@ class Molmo(nn.Module):
|
|
| 1880 |
image_features = image_features.to(x.device)
|
| 1881 |
|
| 1882 |
x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
|
|
|
|
|
|
|
|
|
|
| 1883 |
|
| 1884 |
if not self.config.rope:
|
| 1885 |
# Get positional embeddings.
|
|
@@ -1997,7 +2025,14 @@ class Molmo(nn.Module):
|
|
| 1997 |
if self.config.weight_tying:
|
| 1998 |
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 1999 |
else:
|
| 2000 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2001 |
if self.config.scale_logits:
|
| 2002 |
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 2003 |
|
|
@@ -2039,6 +2074,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2039 |
mlp_hidden_size=config.intermediate_size,
|
| 2040 |
n_layers=config.num_hidden_layers,
|
| 2041 |
additional_vocab_size=128,
|
|
|
|
| 2042 |
n_heads=config.num_attention_heads,
|
| 2043 |
n_kv_heads=config.num_key_value_heads,
|
| 2044 |
rope_theta=config.rope_theta,
|
|
@@ -2080,6 +2116,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2080 |
image_masks: Optional[torch.Tensor] = None,
|
| 2081 |
image_input_idx: Optional[torch.Tensor] = None,
|
| 2082 |
subsegment_ids: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
| 2083 |
position_ids: Optional[torch.Tensor] = None,
|
| 2084 |
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 2085 |
labels: Optional[torch.LongTensor] = None,
|
|
@@ -2113,6 +2151,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2113 |
image_masks=image_masks,
|
| 2114 |
image_input_idx=image_input_idx,
|
| 2115 |
subsegment_ids=subsegment_ids,
|
|
|
|
|
|
|
| 2116 |
position_ids=position_ids,
|
| 2117 |
past_key_values=past_key_values,
|
| 2118 |
use_cache=use_cache,
|
|
@@ -2185,6 +2225,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2185 |
images = batch.get("images")
|
| 2186 |
image_masks = batch.get("image_masks")
|
| 2187 |
image_input_idx = batch.get("image_input_idx")
|
|
|
|
|
|
|
| 2188 |
|
| 2189 |
# Validate inputs.
|
| 2190 |
input_ids = batch["input_ids"]
|
|
@@ -2217,6 +2259,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2217 |
image_masks=image_masks,
|
| 2218 |
image_input_idx=image_input_idx,
|
| 2219 |
position_ids=position_ids,
|
|
|
|
|
|
|
| 2220 |
append_last_valid_logits=append_last_valid_logits,
|
| 2221 |
**kwargs,
|
| 2222 |
)
|
|
@@ -2235,6 +2279,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2235 |
images = kwargs.get("images")
|
| 2236 |
image_masks = kwargs.get("image_masks")
|
| 2237 |
image_input_idx = kwargs.get("image_input_idx")
|
|
|
|
|
|
|
| 2238 |
position_ids = kwargs.get("position_ids")
|
| 2239 |
append_last_valid_logits = kwargs.get("append_last_valid_logits")
|
| 2240 |
model_inputs = {
|
|
@@ -2250,6 +2296,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2250 |
model_inputs["image_masks"] = image_masks
|
| 2251 |
model_inputs["image_input_idx"] = image_input_idx
|
| 2252 |
model_inputs["append_last_valid_logits"] = append_last_valid_logits
|
|
|
|
|
|
|
| 2253 |
else:
|
| 2254 |
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 2255 |
|
|
@@ -2272,6 +2320,9 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
| 2272 |
del model_kwargs["images"]
|
| 2273 |
del model_kwargs["image_masks"]
|
| 2274 |
del model_kwargs["image_input_idx"]
|
|
|
|
|
|
|
|
|
|
| 2275 |
cache_name, cache = super()._extract_past_from_model_output(outputs)
|
| 2276 |
model_kwargs[cache_name] = cache
|
| 2277 |
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
|
|
|
| 541 |
self,
|
| 542 |
num_embeddings: int,
|
| 543 |
num_new_embeddings: int,
|
| 544 |
+
num_skill_embeddings: int,
|
| 545 |
features: int,
|
| 546 |
device: Union[str, torch.device],
|
| 547 |
initializer_range: float = 0.02,
|
|
|
|
| 556 |
self.new_embedding = nn.Parameter(
|
| 557 |
torch.zeros(num_new_embeddings, features, device=device),
|
| 558 |
)
|
| 559 |
+
self.skill_embedding = nn.Parameter(
|
| 560 |
+
torch.zeros(num_skill_embeddings, features, device=device),
|
| 561 |
+
)
|
| 562 |
|
| 563 |
def reset_parameters(self):
|
| 564 |
nn.init.normal_(self.embedding, std=self.initializer_range)
|
| 565 |
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
|
| 566 |
+
nn.init.normal_(self.skill_embedding, std=self.new_embed_initializer_range)
|
| 567 |
|
| 568 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 569 |
+
return F.embedding(x, torch.cat([self.embedding, self.new_embedding, self.skill_embedding], dim=0))
|
| 570 |
|
| 571 |
|
| 572 |
class Dropout(nn.Dropout):
|
|
|
|
| 686 |
initializer_range: float = 0.02
|
| 687 |
normalize_input_embeds: bool = False
|
| 688 |
use_position_ids: bool = True
|
| 689 |
+
skill_vocab_size: int = 1000
|
| 690 |
|
| 691 |
@property
|
| 692 |
def effective_n_kv_heads(self) -> int:
|
|
|
|
| 1701 |
wte = Embedding(
|
| 1702 |
config.embedding_size or config.vocab_size,
|
| 1703 |
config.additional_vocab_size,
|
| 1704 |
+
config.skill_vocab_size,
|
| 1705 |
config.d_model,
|
| 1706 |
device=config.init_device,
|
| 1707 |
initializer_range=config.initializer_range,
|
|
|
|
| 1741 |
)
|
| 1742 |
}
|
| 1743 |
)
|
| 1744 |
+
self.transformer.update(
|
| 1745 |
+
{
|
| 1746 |
+
"skill_ff_out": nn.Linear(
|
| 1747 |
+
config.d_model,
|
| 1748 |
+
config.skill_vocab_size,
|
| 1749 |
+
bias=config.include_bias,
|
| 1750 |
+
device=config.init_device,
|
| 1751 |
+
)
|
| 1752 |
+
}
|
| 1753 |
+
)
|
| 1754 |
|
| 1755 |
self.vision_backbone: Optional[OLMoVisionBackbone] = None
|
| 1756 |
if config.vision_backbone is not None:
|
|
|
|
| 1758 |
|
| 1759 |
self.__num_fwd_flops: Optional[int] = None
|
| 1760 |
|
| 1761 |
+
self.total_vocab_size = config.vocab_size + config.additional_vocab_size + config.skill_vocab_size
|
| 1762 |
+
torch.nn.init.xavier_uniform_(self.transformer.skill_ff_out.weight)
|
| 1763 |
+
if self.transformer.skill_ff_out.bias is not None:
|
| 1764 |
+
torch.nn.init.zeros_(self.transformer.skill_ff_out.bias)
|
| 1765 |
+
|
| 1766 |
def reset_parameters(self):
|
| 1767 |
if self.vision_backbone is not None:
|
| 1768 |
self.vision_backbone.reset_parameters()
|
|
|
|
| 1800 |
image_masks: Optional[torch.Tensor] = None,
|
| 1801 |
image_input_idx: Optional[torch.Tensor] = None,
|
| 1802 |
subsegment_ids: Optional[torch.Tensor] = None,
|
| 1803 |
+
proprio_embeds: Optional[torch.Tensor] = None,
|
| 1804 |
+
proprio_idx: Optional[torch.Tensor] = None,
|
| 1805 |
position_ids: Optional[torch.Tensor] = None,
|
| 1806 |
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1807 |
use_cache: bool = False,
|
| 1808 |
last_logits_only: bool = False,
|
| 1809 |
output_hidden_states: Optional[bool] = None,
|
| 1810 |
append_last_valid_logits: Optional[torch.Tensor] = None,
|
| 1811 |
+
mode: Optional[str] = "vla",
|
| 1812 |
) -> ModelOutput:
|
| 1813 |
"""
|
| 1814 |
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
|
|
|
| 1905 |
image_features = image_features.to(x.device)
|
| 1906 |
|
| 1907 |
x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
|
| 1908 |
+
|
| 1909 |
+
if proprio_embeds is not None:
|
| 1910 |
+
x[batch_idx, proprio_idx] += proprio_embeds
|
| 1911 |
|
| 1912 |
if not self.config.rope:
|
| 1913 |
# Get positional embeddings.
|
|
|
|
| 2025 |
if self.config.weight_tying:
|
| 2026 |
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 2027 |
else:
|
| 2028 |
+
if mode == "vla":
|
| 2029 |
+
logits = self.transformer.skill_ff_out(x)
|
| 2030 |
+
# this little trick allows us to use HF generate() while decoding
|
| 2031 |
+
if use_cache:
|
| 2032 |
+
filler_logits = torch.full((x.shape[0], x.shape[1], self.total_vocab_size-self.config.skill_vocab_size), -math.inf, device=logits.device)
|
| 2033 |
+
logits = torch.cat([filler_logits, logits], dim=-1) # type: ignore
|
| 2034 |
+
else:
|
| 2035 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
| 2036 |
if self.config.scale_logits:
|
| 2037 |
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 2038 |
|
|
|
|
| 2074 |
mlp_hidden_size=config.intermediate_size,
|
| 2075 |
n_layers=config.num_hidden_layers,
|
| 2076 |
additional_vocab_size=128,
|
| 2077 |
+
skill_vocab_size=config.skill_vocab_size,
|
| 2078 |
n_heads=config.num_attention_heads,
|
| 2079 |
n_kv_heads=config.num_key_value_heads,
|
| 2080 |
rope_theta=config.rope_theta,
|
|
|
|
| 2116 |
image_masks: Optional[torch.Tensor] = None,
|
| 2117 |
image_input_idx: Optional[torch.Tensor] = None,
|
| 2118 |
subsegment_ids: Optional[torch.Tensor] = None,
|
| 2119 |
+
proprio_embeds: Optional[torch.Tensor] = None,
|
| 2120 |
+
proprio_idx: Optional[torch.Tensor] = None,
|
| 2121 |
position_ids: Optional[torch.Tensor] = None,
|
| 2122 |
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 2123 |
labels: Optional[torch.LongTensor] = None,
|
|
|
|
| 2151 |
image_masks=image_masks,
|
| 2152 |
image_input_idx=image_input_idx,
|
| 2153 |
subsegment_ids=subsegment_ids,
|
| 2154 |
+
proprio_embeds=proprio_embeds,
|
| 2155 |
+
proprio_idx=proprio_idx,
|
| 2156 |
position_ids=position_ids,
|
| 2157 |
past_key_values=past_key_values,
|
| 2158 |
use_cache=use_cache,
|
|
|
|
| 2225 |
images = batch.get("images")
|
| 2226 |
image_masks = batch.get("image_masks")
|
| 2227 |
image_input_idx = batch.get("image_input_idx")
|
| 2228 |
+
proprio_embeds = batch.get("proprio_embeds")
|
| 2229 |
+
proprio_idx = batch.get("proprio_idx")
|
| 2230 |
|
| 2231 |
# Validate inputs.
|
| 2232 |
input_ids = batch["input_ids"]
|
|
|
|
| 2259 |
image_masks=image_masks,
|
| 2260 |
image_input_idx=image_input_idx,
|
| 2261 |
position_ids=position_ids,
|
| 2262 |
+
proprio_embeds=proprio_embeds,
|
| 2263 |
+
proprio_idx=proprio_idx,
|
| 2264 |
append_last_valid_logits=append_last_valid_logits,
|
| 2265 |
**kwargs,
|
| 2266 |
)
|
|
|
|
| 2279 |
images = kwargs.get("images")
|
| 2280 |
image_masks = kwargs.get("image_masks")
|
| 2281 |
image_input_idx = kwargs.get("image_input_idx")
|
| 2282 |
+
proprio_embeds = kwargs.get("proprio_embeds")
|
| 2283 |
+
proprio_idx = kwargs.get("proprio_idx")
|
| 2284 |
position_ids = kwargs.get("position_ids")
|
| 2285 |
append_last_valid_logits = kwargs.get("append_last_valid_logits")
|
| 2286 |
model_inputs = {
|
|
|
|
| 2296 |
model_inputs["image_masks"] = image_masks
|
| 2297 |
model_inputs["image_input_idx"] = image_input_idx
|
| 2298 |
model_inputs["append_last_valid_logits"] = append_last_valid_logits
|
| 2299 |
+
model_inputs["proprio_embeds"] = proprio_embeds
|
| 2300 |
+
model_inputs["proprio_idx"] = proprio_idx
|
| 2301 |
else:
|
| 2302 |
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 2303 |
|
|
|
|
| 2320 |
del model_kwargs["images"]
|
| 2321 |
del model_kwargs["image_masks"]
|
| 2322 |
del model_kwargs["image_input_idx"]
|
| 2323 |
+
if "proprio_embeds" in model_kwargs:
|
| 2324 |
+
del model_kwargs["proprio_embeds"]
|
| 2325 |
+
del model_kwargs["proprio_idx"]
|
| 2326 |
cache_name, cache = super()._extract_past_from_model_output(outputs)
|
| 2327 |
model_kwargs[cache_name] = cache
|
| 2328 |
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
preprocessing_molmo.py
CHANGED
|
@@ -28,7 +28,7 @@ from transformers.utils import logging
|
|
| 28 |
|
| 29 |
from transformers import AutoTokenizer
|
| 30 |
from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
|
| 31 |
-
|
| 32 |
|
| 33 |
logger = logging.get_logger(__name__)
|
| 34 |
|
|
@@ -38,9 +38,14 @@ DEFAULT_IM_START_TOKEN = f"<im_start>"
|
|
| 38 |
DEFAULT_IM_END_TOKEN = f"<im_end>"
|
| 39 |
DEFAULT_IM_COL_TOKEN = f"<im_col>"
|
| 40 |
IMAGE_PROMPT = "<|image|>"
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def get_special_token_ids(tokenizer):
|
| 46 |
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
|
|
@@ -72,7 +77,7 @@ class MolmoProcessorKwargs(ProcessingKwargs, total=False):
|
|
| 72 |
"text_kwargs": {
|
| 73 |
"style": "long_caption",
|
| 74 |
"system_prompt": "none",
|
| 75 |
-
"message_format": "
|
| 76 |
"always_start_with_space": True,
|
| 77 |
"sequence_length": 1536,
|
| 78 |
"padding": False,
|
|
@@ -97,11 +102,14 @@ class MolmoProcessor(ProcessorMixin):
|
|
| 97 |
self._special_tokens = get_special_token_ids(self.tokenizer)
|
| 98 |
return self._special_tokens
|
| 99 |
|
| 100 |
-
def get_tokens_input(self, prompt, message_format, always_start_with_space):
|
| 101 |
if message_format == "none" or message_format is None:
|
| 102 |
pass
|
| 103 |
elif message_format == "role":
|
| 104 |
prompt = "User: " + prompt + " Assistant:"
|
|
|
|
|
|
|
|
|
|
| 105 |
else:
|
| 106 |
raise NotImplementedError(f"Message format {message_format} not implemented")
|
| 107 |
|
|
@@ -116,6 +124,7 @@ class MolmoProcessor(ProcessorMixin):
|
|
| 116 |
self,
|
| 117 |
text: TextInput = None,
|
| 118 |
images: ImageInput = None,
|
|
|
|
| 119 |
*,
|
| 120 |
tokens: Optional[PreTokenizedInput] = None,
|
| 121 |
**kwargs: Unpack[MolmoProcessorKwargs],
|
|
@@ -126,14 +135,18 @@ class MolmoProcessor(ProcessorMixin):
|
|
| 126 |
**kwargs,
|
| 127 |
)
|
| 128 |
|
|
|
|
|
|
|
| 129 |
if tokens is None:
|
| 130 |
tokens = self.get_tokens_input(
|
| 131 |
text,
|
| 132 |
output_kwargs["text_kwargs"]["message_format"],
|
| 133 |
output_kwargs["text_kwargs"]["always_start_with_space"],
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
image_token_id = self.special_token_ids[IMAGE_PROMPT]
|
|
|
|
| 137 |
|
| 138 |
if images is not None:
|
| 139 |
if not isinstance(images, (list, tuple)):
|
|
@@ -182,6 +195,9 @@ class MolmoProcessor(ProcessorMixin):
|
|
| 182 |
# Shift patch mapping up by one since we added BOS
|
| 183 |
image_input_idx = out["image_input_idx"]
|
| 184 |
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
for k, v in out.items():
|
| 187 |
out[k] = torch.from_numpy(v)
|
|
|
|
| 28 |
|
| 29 |
from transformers import AutoTokenizer
|
| 30 |
from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
|
| 31 |
+
from typing import List, Union
|
| 32 |
|
| 33 |
logger = logging.get_logger(__name__)
|
| 34 |
|
|
|
|
| 38 |
DEFAULT_IM_END_TOKEN = f"<im_end>"
|
| 39 |
DEFAULT_IM_COL_TOKEN = f"<im_col>"
|
| 40 |
IMAGE_PROMPT = "<|image|>"
|
| 41 |
+
PROPRIO_PROMPT = "<|proprio|>"
|
| 42 |
+
SKILL_PROMPT = "<|skill|>"
|
| 43 |
|
| 44 |
+
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT, PROPRIO_PROMPT, SKILL_PROMPT)
|
| 45 |
|
| 46 |
+
ProprioInput = Union[
|
| 47 |
+
np.ndarray, "torch.Tensor", List[np.ndarray], List["torch.Tensor"]
|
| 48 |
+
]
|
| 49 |
|
| 50 |
def get_special_token_ids(tokenizer):
|
| 51 |
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
|
|
|
|
| 77 |
"text_kwargs": {
|
| 78 |
"style": "long_caption",
|
| 79 |
"system_prompt": "none",
|
| 80 |
+
"message_format": "robot",
|
| 81 |
"always_start_with_space": True,
|
| 82 |
"sequence_length": 1536,
|
| 83 |
"padding": False,
|
|
|
|
| 102 |
self._special_tokens = get_special_token_ids(self.tokenizer)
|
| 103 |
return self._special_tokens
|
| 104 |
|
| 105 |
+
def get_tokens_input(self, prompt, message_format, always_start_with_space, num_proprio):
|
| 106 |
if message_format == "none" or message_format is None:
|
| 107 |
pass
|
| 108 |
elif message_format == "role":
|
| 109 |
prompt = "User: " + prompt + " Assistant:"
|
| 110 |
+
elif message_format == "robot":
|
| 111 |
+
# this adds proprio observations after the prompt
|
| 112 |
+
prompt = "User: " + prompt + PROPRIO_PROMPT*num_proprio + " Assistant:"
|
| 113 |
else:
|
| 114 |
raise NotImplementedError(f"Message format {message_format} not implemented")
|
| 115 |
|
|
|
|
| 124 |
self,
|
| 125 |
text: TextInput = None,
|
| 126 |
images: ImageInput = None,
|
| 127 |
+
proprio: ProprioInput = None,
|
| 128 |
*,
|
| 129 |
tokens: Optional[PreTokenizedInput] = None,
|
| 130 |
**kwargs: Unpack[MolmoProcessorKwargs],
|
|
|
|
| 135 |
**kwargs,
|
| 136 |
)
|
| 137 |
|
| 138 |
+
num_proprio = len(proprio) if proprio is not None else 0
|
| 139 |
+
|
| 140 |
if tokens is None:
|
| 141 |
tokens = self.get_tokens_input(
|
| 142 |
text,
|
| 143 |
output_kwargs["text_kwargs"]["message_format"],
|
| 144 |
output_kwargs["text_kwargs"]["always_start_with_space"],
|
| 145 |
+
num_proprio
|
| 146 |
)
|
| 147 |
|
| 148 |
image_token_id = self.special_token_ids[IMAGE_PROMPT]
|
| 149 |
+
proprio_token_id = self.special_token_ids[PROPRIO_PROMPT]
|
| 150 |
|
| 151 |
if images is not None:
|
| 152 |
if not isinstance(images, (list, tuple)):
|
|
|
|
| 195 |
# Shift patch mapping up by one since we added BOS
|
| 196 |
image_input_idx = out["image_input_idx"]
|
| 197 |
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
|
| 198 |
+
|
| 199 |
+
proprio_idx = np.where(out["input_ids"] == proprio_token_id)[0]
|
| 200 |
+
out["proprio_idx"] = proprio_idx
|
| 201 |
|
| 202 |
for k, v in out.items():
|
| 203 |
out[k] = torch.from_numpy(v)
|
special_tokens_map.json
CHANGED
|
@@ -422,7 +422,9 @@
|
|
| 422 |
"<im_end>",
|
| 423 |
"<im_patch>",
|
| 424 |
"<im_col>",
|
| 425 |
-
"<|image|>"
|
|
|
|
|
|
|
| 426 |
],
|
| 427 |
"eos_token": {
|
| 428 |
"content": "<|endoftext|>",
|
|
|
|
| 422 |
"<im_end>",
|
| 423 |
"<im_patch>",
|
| 424 |
"<im_col>",
|
| 425 |
+
"<|image|>",
|
| 426 |
+
"<|proprio|>",
|
| 427 |
+
"<|skill|>"
|
| 428 |
],
|
| 429 |
"eos_token": {
|
| 430 |
"content": "<|endoftext|>",
|
tokenizer_config.json
CHANGED
|
@@ -3408,6 +3408,22 @@
|
|
| 3408 |
"rstrip": false,
|
| 3409 |
"single_word": false,
|
| 3410 |
"special": true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3411 |
}
|
| 3412 |
},
|
| 3413 |
"additional_special_tokens": [
|
|
@@ -3833,7 +3849,9 @@
|
|
| 3833 |
"<im_end>",
|
| 3834 |
"<im_patch>",
|
| 3835 |
"<im_col>",
|
| 3836 |
-
"<|image|>"
|
|
|
|
|
|
|
| 3837 |
],
|
| 3838 |
"auto_map": {
|
| 3839 |
"AutoProcessor": "preprocessing_molmo.MolmoProcessor"
|
|
|
|
| 3408 |
"rstrip": false,
|
| 3409 |
"single_word": false,
|
| 3410 |
"special": true
|
| 3411 |
+
},
|
| 3412 |
+
"152069": {
|
| 3413 |
+
"content": "<|proprio|>",
|
| 3414 |
+
"lstrip": false,
|
| 3415 |
+
"normalized": false,
|
| 3416 |
+
"rstrip": false,
|
| 3417 |
+
"single_word": false,
|
| 3418 |
+
"special": true
|
| 3419 |
+
},
|
| 3420 |
+
"152070": {
|
| 3421 |
+
"content": "<|skill|>",
|
| 3422 |
+
"lstrip": false,
|
| 3423 |
+
"normalized": false,
|
| 3424 |
+
"rstrip": false,
|
| 3425 |
+
"single_word": false,
|
| 3426 |
+
"special": true
|
| 3427 |
}
|
| 3428 |
},
|
| 3429 |
"additional_special_tokens": [
|
|
|
|
| 3849 |
"<im_end>",
|
| 3850 |
"<im_patch>",
|
| 3851 |
"<im_col>",
|
| 3852 |
+
"<|image|>",
|
| 3853 |
+
"<|proprio|>",
|
| 3854 |
+
"<|skill|>"
|
| 3855 |
],
|
| 3856 |
"auto_map": {
|
| 3857 |
"AutoProcessor": "preprocessing_molmo.MolmoProcessor"
|