# SLIP / modeling_slip.py
# NOTE: the three lines below were scraped from the Hugging Face file page
# (uploader name, commit message, commit hash). They are kept as comments so
# the module remains importable:
#   LeoChen085's picture
#   Pass dtype through to Gemma init for proper torch_dtype support
#   b81a726 verified
"""
SLIP: Sensor Language Integrated Pre-training
Self-contained model file for HuggingFace Hub (trust_remote_code=True).
Usage:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("LeoChen085/SLIP", trust_remote_code=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("LeoChen085/SLIP", trust_remote_code=True)
# Task-specific checkpoint (download manually):
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
state_dict = load_file(hf_hub_download("LeoChen085/SLIP", "har.safetensors"))
model.load_state_dict(state_dict, strict=False)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
from einops import rearrange, repeat, reduce
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.activations import ACT2FN
from configuration_slip import SLIPConfig
# ═══════════════════════════════════════════════════════════════
# Positional Embeddings
# ═══════════════════════════════════════════════════════════════
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a lazily grown cos/sin cache.

    Args:
        dim: rotary dimension (size of the last axis the rotation acts on).
        max_position_embeddings: initial cache length.
        base: frequency base (theta) of the inverse-frequency spectrum.
        device: optional device for the inverse-frequency buffer.
    """
    def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = 1 / base^(2i/dim) — the standard RoPE spectrum.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
            2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate the frequencies so cos/sin span the full `dim`
        # (rotate-half layout used by apply_rotary_pos_emb).
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        """Return (cos, sin) tables of length `seq_len`, cast to x.dtype.

        Fix: `seq_len=None` (the declared default) previously crashed on
        `None > int`; it now falls back to the sequence length of `x` (dim -2).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype))
def rotate_half(x):
    """Swap the two halves of the last axis, negating the (new) first half."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query/key tensors by the cos/sin tables gathered at `position_ids`.

    `unsqueeze_dim` inserts a broadcast axis (typically the head axis) into the
    gathered tables so they line up with q/k of shape (..., heads, seq, dim).
    """
    cos_g = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_g = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_rot = q * cos_g + rotate_half(q) * sin_g
    k_rot = k * cos_g + rotate_half(k) * sin_g
    return q_rot, k_rot
def apply_rotary_pos_emb_2d(q, k, cos_h, sin_h, cos_w, sin_w, pos_h, pos_w, unsqueeze_dim=1):
    """Axial 2-D RoPE: the first half of the head dim is rotated by the
    "height" (variable) position, the second half by the "width" (patch)
    position; the two halves are then re-concatenated."""
    half = q.shape[-1] // 2
    q_front, q_back = q.split(half, dim=-1)
    k_front, k_back = k.split(half, dim=-1)
    rot_h = apply_rotary_pos_emb(q_front, k_front, cos_h, sin_h, pos_h.long(), unsqueeze_dim=unsqueeze_dim)
    rot_w = apply_rotary_pos_emb(q_back, k_back, cos_w, sin_w, pos_w.long(), unsqueeze_dim=unsqueeze_dim)
    return torch.cat([rot_h[0], rot_w[0]], dim=-1), torch.cat([rot_h[1], rot_w[1]], dim=-1)
def build_2d_position_ids(attention_mask, flatten=True):
    """Derive (variable-index, patch-index) RoPE positions from a (B, V, P) mask.

    Valid entries are numbered consecutively among the valid ones along each
    axis; padded entries get position 0. With ``flatten`` the two position maps
    are returned as (B, V*P) tensors.
    """
    batch, n_var, n_patch = attention_mask.shape
    mask = attention_mask.long()
    # Patch axis: rank each valid patch within its variable, zero out padding.
    patch_pos = mask.cumsum(-1).sub_(1).mul_(mask)
    # Variable axis: a variable counts as valid if any of its patches is valid.
    valid_var = mask.any(dim=-1).long()
    var_rank = valid_var.cumsum(1).sub_(1).mul_(valid_var)
    var_pos = var_rank.unsqueeze(-1).expand(batch, n_var, n_patch) * mask
    if flatten:
        return var_pos.reshape(batch, n_var * n_patch).long(), patch_pos.reshape(batch, n_var * n_patch).long()
    return var_pos.long(), patch_pos.long()
# ═══════════════════════════════════════════════════════════════
# Sensor Encoder Components
# ═══════════════════════════════════════════════════════════════
def flatten_list(input_list):
    """Flatten one nesting level: [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for sub in input_list:
        flat.extend(sub)
    return flat
class MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""
    def __init__(self, hidden_size, intermediate_size, hidden_act):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]
    def forward(self, hidden_state):
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
class TsRoPEAttention(nn.Module):
    """Self-attention over the flattened (variable x patch) token grid with
    axial 2-D RoPE: half of each head dim encodes the variable index, the
    other half the patch index.

    Fix: attention dropout is now applied only in training mode —
    ``F.scaled_dot_product_attention`` has no train/eval awareness of its own,
    so the previous code also dropped attention weights during inference.
    """
    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.hidden_size = cfg.get("embed_dim", 768)
        self.num_heads = cfg.get("num_heads", 12)
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = cfg.get("dropout_rate", 0.1)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # head_dim // 2: each axial half (variable / patch) gets its own rotation.
        self.rotary_emb = RotaryEmbedding(self.head_dim // 2, max_position_embeddings=cfg.get("max_position_embeddings"))
    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """hidden_states: (B, V*P, D); attention_mask: (B, V, P) validity mask."""
        bsz, q_len, _ = hidden_states.size()
        tmp_attn_mask = rearrange(attention_mask, 'b nvar p -> b (nvar p)')
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Boolean key-validity mask broadcast over heads and query positions.
        tmp_attn_mask = tmp_attn_mask.unsqueeze(1).unsqueeze(2).expand(-1, 1, q_len, q_len).bool()
        pos_var, pos_patch = build_2d_position_ids(attention_mask, flatten=True)
        cos_h, sin_h = self.rotary_emb(query_states, seq_len=int(pos_var.max().item()) + 1)
        cos_w, sin_w = self.rotary_emb(query_states, seq_len=int(pos_patch.max().item()) + 1)
        query_states, key_states = apply_rotary_pos_emb_2d(
            query_states, key_states, cos_h, sin_h, cos_w, sin_w, pos_var, pos_patch)
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, tmp_attn_mask,
            dropout_p=self.attention_dropout if self.training else 0.0)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output)
class MultiSizePatchEmbed(nn.Module):
    """Patch embedding that serves multiple patch sizes with one weight set.

    Weights are learned at ``base_patch`` resolution; for any other patch size
    they are resized on the fly by 1-D linear interpolation along the
    input-feature axis. The input feature width is ``patch * 3`` because the
    value, attention-mask and time-index channels are concatenated feature-wise.
    """
    def __init__(self, base_patch=32, **cfg):
        super().__init__()
        self.base_patch = base_patch
        hidden_size = cfg['embed_dim']
        intermediate_size = cfg['mlp_ratio'] * hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        # Trained at base resolution; resized per patch size in resize_weight().
        self.shared_linear = nn.Linear(base_patch * 3, intermediate_size)
        self.shared_residual = nn.Linear(base_patch * 3, hidden_size)
        self.dropout = nn.Dropout(cfg['dropout_rate'])
        self.act = ACT2FN['silu']
        self.output_layer = nn.Linear(intermediate_size, hidden_size)
    def resize_weight(self, patch_size):
        """Return (main_w, main_b, residual_w, residual_b) with the weight
        matrices interpolated to input width ``patch_size``; biases are
        size-independent and reused as-is."""
        base_w, base_b = self.shared_linear.weight, self.shared_linear.bias
        res_w, res_b = self.shared_residual.weight, self.shared_residual.bias
        new_w = F.interpolate(base_w.unsqueeze(1), size=patch_size, mode="linear", align_corners=False).squeeze(1).to(base_w.dtype)
        new_res_w = F.interpolate(res_w.unsqueeze(1), size=patch_size, mode="linear", align_corners=False).squeeze(1).to(res_w.dtype)
        return new_w, base_b, new_res_w, res_b
    def forward(self, x_list, attention_mask, time_idx):
        """Embed a list of per-variable tensors with (possibly) mixed patch sizes.

        Entries are grouped by patch size (last dim) so each group runs as one
        batched matmul with the appropriately resized weights.
        NOTE(review): all list entries are assumed to share the same first
        dimension (``x_list[0].shape[0]`` sizes the output buffers) — confirm
        against the data collator.
        """
        device = self.shared_linear.weight.device
        dtype = self.shared_linear.weight.dtype
        sizes = torch.tensor([x.shape[-1] for x in x_list])
        unique_sizes = sizes.unique(sorted=True)
        N = x_list[0].shape[0]
        outputs = torch.empty(len(x_list), N, self.intermediate_size, device=device, dtype=dtype)
        res_outputs = torch.empty(len(x_list), N, self.hidden_size, device=device, dtype=dtype)
        for psize in unique_sizes.tolist():
            # Indices of all list entries that share this patch size.
            idxs = (sizes == psize).nonzero(as_tuple=True)[0]
            xs = torch.stack([x_list[i] for i in idxs]).to(device=device, non_blocking=True)
            mask = torch.stack([attention_mask[i] for i in idxs]).to(device=device, non_blocking=True)
            ti = torch.stack([time_idx[i] for i in idxs]).to(device=device, non_blocking=True)
            # Concatenate value/mask/time channels feature-wise -> width psize*3.
            xs = torch.cat([xs, mask, ti], dim=-1)
            w, b, r_w, r_b = self.resize_weight(psize * 3)
            res_outputs[idxs] = F.linear(xs, r_w, r_b)
            outputs[idxs] = F.linear(xs, w, b)
        return self.dropout(self.output_layer(self.act(outputs))) + res_outputs
class PatchEmbedding(nn.Module):
    """Fixed-size patch embedding: splits each of the value, mask and
    time-index channels into patches of ``patch_size``, concatenates them
    feature-wise, and projects to ``embed_dim`` through a two-layer MLP plus a
    linear residual path."""
    def __init__(self, **cfg):
        super().__init__()
        patch_size = cfg['patch_size']
        self.patch_size = patch_size
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))
        hidden_size = cfg['embed_dim']
        self.hidden_layer = nn.Linear(patch_size * 3, hidden_size)
        self.act = ACT2FN['silu']
        self.output_layer = nn.Linear(hidden_size, hidden_size)
        self.residual_layer = nn.Linear(patch_size * 3, hidden_size)
    def _patchify(self, t):
        # (bs, nvar, nump*ps) -> (bs*nvar, nump, ps)
        bs, nvar, _ = t.shape
        return t.reshape(bs * nvar, -1, self.patch_size)
    def forward(self, x, mask, time_idx):
        features = torch.cat(
            [self._patchify(x), self._patchify(mask), self._patchify(time_idx)], dim=-1)
        hidden = self.dropout(self.output_layer(self.act(self.hidden_layer(features))))
        return hidden + self.residual_layer(features)
class Attention(nn.Module):
    """Multi-head self-attention with optional 1-D RoPE.

    Fix: attention dropout is applied only while training —
    ``F.scaled_dot_product_attention`` is unaware of module train/eval mode,
    so the previous code also applied dropout at inference time.
    """
    def __init__(self, layer_idx, is_rope=True, **cfg):
        super().__init__()
        self.is_rope = is_rope
        self.hidden_size = cfg.get("embed_dim", 768)
        self.num_heads = cfg.get("num_heads", 12)
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = cfg.get("dropout_rate", 0.1)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        if self.is_rope:
            self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=cfg.get("sensor_max_len", 2880))
    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """hidden_states: (B, L, D); attention_mask: optional boolean SDPA mask;
        position_ids required only when ``is_rope`` is True."""
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        if self.is_rope:
            cos, sin = self.rotary_emb(value_states, seq_len=key_states.shape[-2])
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0)
        return self.o_proj(attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size))
class CrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention: queries attend to a (possibly
    differently sized) context sequence projected into the query dimension.

    Fix: attention dropout is applied only while training —
    ``F.scaled_dot_product_attention`` applies ``dropout_p`` unconditionally,
    so the previous code also applied dropout during inference.
    """
    def __init__(self, dim=768, *, context_dim=384, num_heads=12, dropout_rate=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attn_dropout = dropout_rate
        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(context_dim)
        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(context_dim, dim, bias=True)
        self.v_proj = nn.Linear(context_dim, dim, bias=True)
        self.o_proj = nn.Linear(dim, dim, bias=False)
    def forward(self, query, context, attention_mask=None, **kwargs):
        """query: (B, Lq, dim); context: (B, Lk, context_dim); returns (B, Lq, dim)."""
        bsz, q_len, _ = query.size()
        assert context.size(0) == bsz, (
            f"Context batch size ({context.size(0)}) must match query batch size ({bsz}). "
            f"Ensure sensor and text inputs have the same batch size."
        )
        k_len = context.size(1)
        # Both streams are normalized before projection (pre-norm).
        query = self.norm(query)
        context = self.context_norm(context)
        q = self.q_proj(query).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(
            q, k, v, attention_mask,
            dropout_p=self.attn_dropout if self.training else 0.0)
        return self.o_proj(attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.dim))
class AllAttention(nn.Module):
    """Pre-norm residual wrapper around TsRoPEAttention (joint attention over
    all (variable, patch) tokens)."""
    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = TsRoPEAttention(layer_idx=layer_idx, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim'))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))
    def forward(self, hidden_states, attention_mask):
        normed = self.layer_norm(hidden_states)
        attended = self.self_attention(normed, attention_mask)
        return hidden_states + self.dropout(attended)
class TimeSelfAttention(nn.Module):
    """Pre-norm residual self-attention along the time (patch) axis, applied
    independently per variable; rows are (batch*variable) sequences."""
    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = Attention(layer_idx=layer_idx, is_rope=True, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim', 768))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))
    def forward(self, hidden_states, attention_mask, position_ids):
        seq_len = hidden_states.size(1)
        # (b, nvar, p) -> (b*nvar, p): one key-validity row per sequence.
        key_mask = attention_mask.reshape(-1, attention_mask.size(-1))
        sdpa_mask = key_mask[:, None, None, :].expand(-1, 1, seq_len, seq_len).bool()
        normed = self.layer_norm(hidden_states)
        return hidden_states + self.dropout(self.self_attention(normed, sdpa_mask, position_ids))
class GroupSelfAttention(nn.Module):
    """Pre-norm residual self-attention across variables (channels) at each
    time step; rows are (batch*time) groups of variables. No RoPE."""
    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = Attention(layer_idx, is_rope=False, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim', 768))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))
    def forward(self, hidden_states, attention_mask, group_ids):
        bs, nvar, seq_len = attention_mask.shape
        d = hidden_states.size(-1)
        # (bs*nvar, l, d) -> (bs*l, nvar, d): attend across variables per step.
        per_step = hidden_states.reshape(bs, nvar, seq_len, d).transpose(1, 2).reshape(bs * seq_len, nvar, d)
        key_mask = attention_mask.transpose(1, 2).reshape(bs * seq_len, nvar)
        sdpa_mask = key_mask[:, None, None, :].expand(-1, 1, nvar, nvar).bool()
        per_step = per_step + self.dropout(self.self_attention(self.layer_norm(per_step), sdpa_mask))
        # Restore the (bs*nvar, l, d) layout expected by the caller.
        return per_step.reshape(bs, seq_len, nvar, d).transpose(1, 2).reshape(bs * nvar, seq_len, d)
class AttentionPooling(nn.Module):
    """Pool a variable-length token sequence into a fixed number of query slots
    via cross-attention, followed by a pre-norm residual gated FFN and a final
    LayerNorm."""
    def __init__(self, dim=768, mlp_ratio=4, context_dim=384, num_heads=12, dropout_rate=0.1):
        super().__init__()
        self.cross_attn = CrossAttention(dim=dim, context_dim=context_dim, num_heads=num_heads, dropout_rate=dropout_rate)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn_layer = MLP(hidden_size=dim, intermediate_size=dim * mlp_ratio, hidden_act='silu')
        self.post_norm = nn.LayerNorm(dim)
    def forward(self, x, context, attn_mask=None):
        batch, n_queries, _ = x.shape
        kv_len = context.shape[1]
        # Collapse the (var, patch) mask to one key-validity row per sample and
        # broadcast it over every query slot.
        flat_mask = attn_mask.reshape(batch, -1)
        sdpa_mask = flat_mask.view(batch, 1, 1, kv_len).expand(batch, 1, n_queries, kv_len).bool()
        pooled = self.cross_attn(x, context, sdpa_mask)
        pooled = pooled + self.ffn_layer(self.ffn_norm(pooled))
        return self.post_norm(pooled)
class SensorEncoderLayer(nn.Module):
    """One encoder block: an attention stage (variant picked by
    ``channel_attn_type``) followed by a pre-norm residual gated FFN.

    channel_attn_type:
      - 'group_attn': time attention per variable, then cross-variable attention
      - 'univariate': time attention only
      - otherwise:    joint attention over all (variable, patch) tokens
    """
    def __init__(self, layer_idx, **cfg):
        super().__init__()
        hidden_size = cfg['embed_dim']
        self.channel_attn_type = cfg.get('channel_attn_type', 'group_attn')
        if self.channel_attn_type == 'group_attn':
            self.ts_attn = TimeSelfAttention(layer_idx=layer_idx, **cfg)
            self.group_attn = GroupSelfAttention(layer_idx=layer_idx, **cfg)
        elif self.channel_attn_type == 'univariate':
            self.ts_attn = TimeSelfAttention(layer_idx=layer_idx, **cfg)
        else:
            self.ts_attn = AllAttention(layer_idx=layer_idx, **cfg)
        self.norm = nn.LayerNorm(hidden_size)
        self.ffn_layer = MLP(hidden_size=hidden_size, intermediate_size=cfg['mlp_ratio'] * hidden_size, hidden_act='silu')
    def forward(self, hidden_states, attention_mask=None, group_ids=None, position_ids=None):
        mode = self.channel_attn_type
        if mode == 'group_attn':
            hidden_states = self.ts_attn(hidden_states, attention_mask, position_ids)
            hidden_states = self.group_attn(hidden_states, attention_mask, group_ids)
        elif mode == 'univariate':
            hidden_states = self.ts_attn(hidden_states, attention_mask, position_ids)
        else:
            hidden_states = self.ts_attn(hidden_states, attention_mask)
        return hidden_states + self.ffn_layer(self.norm(hidden_states))
class SensorTransformerModel(nn.Module):
    """Sensor-sequence encoder: patch embedding + stacked SensorEncoderLayers.

    Two input regimes, selected by whether ``patch_size`` is configured:
      * ``patch_size`` is None -> ragged inputs: ``input_ids`` /
        ``attention_mask`` / ``time_index`` are nested per-sample lists of
        per-variable tensors (MultiSizePatchEmbed handles mixed patch sizes).
      * ``patch_size`` set -> dense (B, V, L) tensors with L a multiple of
        ``patch_size`` (PatchEmbedding).
    """
    def __init__(self, **cfg):
        super().__init__()
        patch_size = cfg.get('patch_size', None)
        self.patch_size = patch_size
        # Fixed-size embedding when patch_size is set, ragged multi-size otherwise.
        self.patch_embed = PatchEmbedding(**cfg) if patch_size else MultiSizePatchEmbed(**cfg)
        self.blocks = nn.ModuleList([SensorEncoderLayer(i, **cfg) for i in range(cfg['depth'])])
        self.norm = nn.LayerNorm(cfg['embed_dim'])
        self.embed_dim = cfg['embed_dim']
        self.channel_attn_type = cfg.get('channel_attn_type', 'group_attn')
    def forward(self, input_ids, attention_mask, time_index):
        """Return (normalized hidden states of shape (B, V*P, D) for
        'group_attn'/'all_attn', patch-level mask of shape (B, V, P))."""
        if self.patch_size is None:
            BS = len(input_ids)
            hidden_states = self.patch_embed(flatten_list(input_ids), flatten_list(attention_mask), flatten_list(time_index))
            # Collapse element-level masks to one validity bit per patch.
            attention_mask = self._get_self_attn_mask(attention_mask).to(hidden_states.device)
            position_ids = rearrange(self._build_rope_position_ids(attention_mask), 'b nvar p -> (b nvar) p')
        else:
            BS = input_ids.shape[0]
            hidden_states = self.patch_embed(input_ids, attention_mask, time_index)
            # A patch is valid if any raw timestep inside it is valid.
            attention_mask = reduce(attention_mask, 'b v (p ps) -> b v p', 'max', ps=self.patch_size)
            position_ids = rearrange(self._build_rope_position_ids(attention_mask), 'b nvar p -> (b nvar) p')
        if self.channel_attn_type == 'all_attn':
            # Joint attention expects one flat (variable x patch) sequence per sample.
            hidden_states = rearrange(hidden_states, '(b nvar) l d -> b (nvar l) d', b=BS)
        for blk in self.blocks:
            hidden_states = blk(hidden_states, attention_mask=attention_mask, group_ids=None, position_ids=position_ids)
        if self.channel_attn_type == 'group_attn':
            # Per-variable layout -> one flat sequence per sample for the caller.
            hidden_states = rearrange(hidden_states, '(b nvar) l d -> b (nvar l) d', b=BS)
        return self.norm(hidden_states), attention_mask
    def _build_rope_position_ids(self, attention_mask):
        # Consecutive positions over valid patches; padded patches get 0.
        mask = attention_mask.to(torch.long)
        return (mask.cumsum(dim=-1) - 1) * mask
    def _get_self_attn_mask(self, attn_mask_list):
        """Collapse ragged per-element masks into per-patch bits: (B, V, P)."""
        collapsed = []
        for sample_masks in attn_mask_list:
            collapsed.append(torch.stack([(m.sum(dim=-1) > 0).to(m.dtype) for m in sample_masks], dim=0))
        return torch.stack(collapsed, dim=0)
# ═══════════════════════════════════════════════════════════════
# Multimodal Gemma
# ═══════════════════════════════════════════════════════════════
class Residual(nn.Module):
    """Wrap a callable so its output is added back to its input (skip connection)."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
class Gemma3MultimodalLayer(nn.Module):
    """Decoder-layer wrapper that runs the original Gemma layer, then
    cross-attends its output to externally supplied sensor embeddings
    (``vis_x``). ``condition_vis_x`` must be called before each forward pass.
    """
    def __init__(self, original_layer, cross_attn_block):
        super().__init__()
        self.original_layer = original_layer
        self.cross_attn_block = cross_attn_block
        # Conditioning context; set from outside via condition_vis_x().
        self.vis_x = None
    def condition_vis_x(self, vis_x):
        """Store the sensor/context embeddings used by the cross-attention block."""
        self.vis_x = vis_x
    def __getattr__(self, name):
        # Delegate unknown attribute lookups to the wrapped layer so this class
        # is a drop-in replacement (nn.Module's own __getattr__ is tried first).
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.original_layer, name)
    def forward(self, hidden_states, **kwargs):
        assert self.vis_x is not None, "vis_x must be set before forward pass."
        outputs = self.original_layer(hidden_states, **kwargs)
        # Cross-attend the layer's hidden states (outputs[0]) to the context;
        # pass through any extra outputs (e.g. attentions / cache) untouched.
        hidden_states = self.cross_attn_block(outputs[0], context=self.vis_x)
        return (hidden_states,) + outputs[1:]
class Gemma3MultimodalModel(nn.Module):
    """Gemma-3 causal LM whose layers from ``split_layer`` onward are wrapped
    with residual cross-attention to sensor embeddings. Call
    ``condition_image`` with the embeddings before ``forward``/generation.
    """
    def __init__(self, model_id="google/gemma-3-270m", init_from_pretrained=False, split_layer=12, dtype=None):
        super().__init__()
        if init_from_pretrained:
            # NOTE(review): `dtype` is not forwarded on this path — confirm
            # whether pretrained loads should also honor the requested dtype.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True)
        else:
            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            # Propagate the requested dtype so from_config materializes the
            # weights in the right precision (defaults to float32).
            config.torch_dtype = dtype or torch.float32
            self.model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True)
        self.split_layer = split_layer
        hidden_size = self.model.config.hidden_size
        num_heads = self.model.config.num_attention_heads
        self.hidden_size = hidden_size
        # Replace every layer from split_layer onward with a wrapped layer that
        # adds a residual cross-attention block after the original computation.
        for i in range(split_layer, len(self.model.model.layers)):
            cross_attn = CrossAttention(
                dim=hidden_size, context_dim=hidden_size, num_heads=num_heads, dropout_rate=0.1)
            self.model.model.layers[i] = Gemma3MultimodalLayer(
                self.model.model.layers[i], Residual(cross_attn))
    def condition_image(self, image_embeds):
        """Broadcast sensor/image embeddings to all wrapped (multimodal) layers."""
        self.image_embeds = image_embeds
        for layer in self.model.model.layers:
            if isinstance(layer, Gemma3MultimodalLayer):
                layer.condition_vis_x(self.image_embeds)
    def forward(self, input_ids, attention_mask=None, return_embeddings=False, **kwargs):
        """Return (last-position hidden state at index ``split_layer``, logits),
        or the raw HF output object when ``return_embeddings`` is True."""
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, **kwargs)
        # hidden_states[split_layer]: per the HF convention (index 0 is the
        # embedding output), this is the activation entering the first wrapped
        # layer, i.e. text features before any sensor cross-attention.
        text_sentence_embedding = outputs.hidden_states[self.split_layer][:, -1, :]
        if return_embeddings:
            return outputs
        return text_sentence_embedding, outputs.logits
# ═══════════════════════════════════════════════════════════════
# SLIP Model (PreTrainedModel for HuggingFace Auto* classes)
# ═══════════════════════════════════════════════════════════════
def masked_mean(t, mask, dim=1, eps=1e-6):
    """Mean of `t` along `dim`, counting only positions where `mask` is True.

    `mask` must be boolean and broadcastable to `t`. The `eps` clamp guards an
    all-False mask (the numerator is already 0 there, so the result is 0).
    """
    zeroed = t.masked_fill(~mask, 0.)
    total = zeroed.sum(dim=dim)
    count = mask.sum(dim=dim).clamp(min=eps)
    return total / count
class EmbedToLatents(nn.Module):
    """Project embeddings into a latent space and L2-normalize the result
    (unit-norm latents for the contrastive objective)."""
    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)
    def forward(self, x):
        latents = self.to_latents(x)
        return F.normalize(latents, dim=-1)
class SLIPPreTrainedModel(PreTrainedModel):
    """Base class wiring SLIP into the HF `PreTrainedModel` machinery
    (config class, weight init, checkpoint prefix)."""
    config_class = SLIPConfig
    base_model_prefix = "slip"
    supports_gradient_checkpointing = False
    def _init_weights(self, module):
        """Xavier-uniform init for Linear layers; standard ones/zeros for LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)
class SLIPModel(SLIPPreTrainedModel):
    """
    SLIP: Sensor Language Integrated Pre-training.

    Combines a sensor-sequence encoder with a Gemma LM conditioned on sensor
    embeddings via cross-attention. Produces unit-norm contrastive latents
    (sensor and text) plus captioning logits.
    Usage:
        model = AutoModel.from_pretrained("LeoChen085/SLIP", trust_remote_code=True)
    """
    def __init__(self, config: SLIPConfig):
        super().__init__(config)
        # Sensor encoder
        sensor_cfg = config.sensor_encoder
        self.sensor_encoder = SensorTransformerModel(**sensor_cfg)
        dim = self.sensor_encoder.embed_dim
        # Multimodal LLM (init from scratch — weights come from safetensors)
        self.multimodalModel = Gemma3MultimodalModel(
            config.llm_model_name,
            init_from_pretrained=False,
            split_layer=config.split_layer,
            dtype=getattr(config, "torch_dtype", None),
        )
        # NOTE(review): lm_dim is currently unused.
        lm_dim = self.multimodalModel.hidden_size
        common_dim = config.common_dim
        # Attention pooling: learned queries summarize the sensor tokens.
        # The "+ 1" reserves query slot 0, which is used downstream as the
        # sensor CLS embedding (see get_embedding / get_sensor_embedding).
        num_img_queries = config.num_img_queries
        if num_img_queries > 0:
            self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, common_dim))
            self.img_attn_pool = AttentionPooling(
                dim=common_dim, context_dim=dim, num_heads=config.num_heads)
            dim = common_dim
        # Bridge projections
        self.img_to_latents = EmbedToLatents(dim, common_dim)
        self.text_to_latents = EmbedToLatents(common_dim, common_dim)
        # Temperature
        self.temperature = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
        self.temperature_max = math.log(1 / 0.07)
    def embed_sensor(self, sensors, sensor_attn_mask=None, time_index=None):
        """Encode sensor inputs; returns (tokens, boolean patch mask).
        With attention pooling enabled, tokens are the pooled query slots."""
        sensor_tokens, attn_mask = self.sensor_encoder(sensors, sensor_attn_mask, time_index=time_index)
        if hasattr(self, "img_attn_pool"):
            img_queries = repeat(self.img_queries, "n d -> b n d", b=sensor_tokens.shape[0])
            sensor_tokens = self.img_attn_pool(img_queries, sensor_tokens, attn_mask)
        return sensor_tokens, attn_mask.bool()
    def forward(self, text=None, sensors=None, **kwargs):
        """
        Forward pass for contrastive + captioning training.
        For inference, use get_embedding(), get_sensor_embedding(), or generate().
        """
        sensor_hidden, sensor_mask = self.embed_sensor(
            sensors=sensors["input_ids"], sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"])
        # Condition every wrapped LM layer on the sensor tokens.
        self.multimodalModel.condition_image(sensor_hidden)
        # [:, :-1]: the final token is excluded — presumably it only serves as
        # the shifted next-token target in the LM loss; confirm with the trainer.
        text_hidden, logits = self.multimodalModel(
            input_ids=text["input_ids"][:, :-1], attention_mask=text["attention_mask"][:, :-1])
        text_hidden = self.text_to_latents(text_hidden)
        sensor_hidden = self.img_to_latents(sensor_hidden)
        return {"text_hidden": text_hidden, "sensor_hidden": sensor_hidden, "logits": logits}
    @torch.no_grad()
    def get_embedding(self, text, sensors):
        """Return (text latent, sensor latent) for retrieval/similarity."""
        sensor_hidden, sensor_mask = self.embed_sensor(
            sensors=sensors["input_ids"], sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"])
        self.multimodalModel.condition_image(sensor_hidden)
        text_hidden, _ = self.multimodalModel(
            input_ids=text["input_ids"][:, :-1], attention_mask=text["attention_mask"][:, :-1])
        text_hidden = self.text_to_latents(text_hidden)
        sensor_hidden = self.img_to_latents(sensor_hidden)
        if hasattr(self, "img_attn_pool"):
            # Pooled path: slot 0 is the sensor CLS embedding.
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            # No pooling: masked average over all valid patch tokens.
            sensor_hidden = masked_mean(sensor_hidden, rearrange(sensor_mask, "b n p -> b (n p) 1"), dim=1)
        return text_hidden, sensor_hidden
    @torch.no_grad()
    def get_sensor_embedding(self, input_ids, mask, time_index):
        """Return the sensor latent only (no text branch)."""
        sensor_hidden, sensor_mask = self.embed_sensor(sensors=input_ids, sensor_attn_mask=mask, time_index=time_index)
        sensor_hidden = self.img_to_latents(sensor_hidden)
        if hasattr(self, "img_attn_pool"):
            sensor_hidden = sensor_hidden[:, 0, :]
        else:
            sensor_hidden = masked_mean(sensor_hidden, rearrange(sensor_mask, "b n p -> b (n p) 1"), dim=1)
        return sensor_hidden
    @torch.no_grad()
    def generate(self, text, sensors, **generate_kwargs):
        """Sensor-conditioned caption generation via the underlying HF model."""
        sensor_hidden, _ = self.embed_sensor(
            sensors=sensors["input_ids"], sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"])
        self.multimodalModel.condition_image(sensor_hidden)
        return self.multimodalModel.model.generate(
            input_ids=text["input_ids"], attention_mask=text["attention_mask"],
            max_new_tokens=generate_kwargs.get("max_new_tokens", 300),
            do_sample=generate_kwargs.get("do_sample", False),
            num_beams=generate_kwargs.get("num_beams", 1))
    def sft_training(self, text, sensors, return_output=False):
        """Supervised fine-tuning step: sensor-conditioned causal-LM loss.

        Expects ``text["labels"]`` with -100 at ignored positions; an optional
        ``text["loss_weights"]`` reweights per-token CE (normalized by the
        weight sum).
        """
        sensor_hidden, _ = self.embed_sensor(
            sensors=sensors["input_ids"], sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"])
        self.multimodalModel.condition_image(sensor_hidden)
        outputs = self.multimodalModel.model(
            input_ids=text["input_ids"], attention_mask=text["attention_mask"], return_dict=True)
        if return_output:
            return outputs
        logits = outputs.logits
        labels = text["labels"]
        # Standard causal-LM shift: predict token t+1 from position t.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1),
            reduction="none", ignore_index=-100)
        if "loss_weights" in text:
            loss_weights = text["loss_weights"][:, 1:].contiguous().view(-1)
            loss = (ce * loss_weights).sum() / loss_weights.sum()
        else:
            loss = ce.mean()
        return {"loss": loss}