File size: 24,838 Bytes
c838758 fc84ba8 c838758 fc84ba8 c838758 fc84ba8 c838758 fc84ba8 c838758 fc84ba8 c838758 fc84ba8 c838758 fc84ba8 c838758 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 | # zeus_mm.py
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
# Optional backends (guarded imports: a backend that fails to import is set to None):
try:
from transformers import CLIPVisionModel
except Exception:
CLIPVisionModel = None
try:
from transformers import Wav2Vec2Model
except Exception:
Wav2Vec2Model = None
try:
from transformers import AutoModel as HFBackbone
except Exception:
HFBackbone = None
# ========================== CONFIG ==========================
class ZeusMMConfig(PretrainedConfig):
    """
    Configuration for Zeus, a multimodal conversational LM.

    Architecture notes:
    - Decoder-only with RoPE + KV cache
    - Cross-attn + FiLM fusion with router
    - Modality-aware MoE-MLP
    - Temporal audio tokens (optional)
    - Retrieval slotting
    - Role-aware RoPE scaling
    - Easy backends: CLIP vision, Wav2Vec2 audio, any HF encoder for retrieval

    Args:
        vocab_size: Size of the token embedding table and of the LM head.
        d_model: Hidden width of the decoder.
        n_heads: Number of attention heads; must divide ``d_model``.
        n_layers: Number of decoder blocks.
        d_ff: Inner width of each expert MLP.
        dropout: Dropout probability used throughout the decoder.
        rope_theta: Base of the rotary frequency schedule.
        rope_role_scales: Position scale per role, ordered
            (system, user, assistant).
        vision_model_name: Checkpoint name for the vision backend, or a
            falsy value to disable it.
        audio_model_name: Checkpoint name for the audio backend, or a
            falsy value to disable it.
        retrieval_model_name: Checkpoint name for the retrieval encoder,
            or a falsy value to disable it.
        image_latents: Maximum number of image memory tokens kept.
        audio_latents: Maximum number of audio memory tokens kept.
        retr_latents: Maximum number of retrieval memory tokens kept.
        d_vision: Feature width of the vision backend (projected to
            ``d_model``).
        d_audio: Feature width of the audio backend (projected to
            ``d_model``).
        d_retrieval: Feature width of the retrieval backend (projected to
            ``d_model``).
        film_hidden: Hidden width of the FiLM conditioner MLP.
        router_hidden: Hidden width of the fusion router MLP.
        num_experts: Number of experts in each MoE-MLP.
        initializer_range: Stored on the config for weight initialisation.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "zeusmm"

    def __init__(
        self,
        # LM core
        vocab_size=50000,
        d_model=768,
        n_heads=12,
        n_layers=12,
        d_ff=3072,
        dropout=0.1,
        rope_theta=10000.0,
        rope_role_scales=(0.95, 1.00, 1.05),  # (system, user, assistant)
        # Multimodal adapters
        vision_model_name: Optional[str] = "openai/clip-vit-base-patch32",
        audio_model_name: Optional[str] = "facebook/wav2vec2-base-960h",
        retrieval_model_name: Optional[str] = None,  # e.g., "intfloat/e5-small-v2"
        image_latents=32,
        audio_latents=32,
        retr_latents=64,
        # Backend widths (projection into d_model)
        d_vision=768,  # CLIP ViT-B/32 hidden size
        d_audio=768,  # Wav2Vec2-Base hidden size
        d_retrieval=768,
        # FiLM & Router
        film_hidden=1024,
        router_hidden=256,
        # MoE-MLP
        num_experts=4,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- language-model core ---
        self.vocab_size = vocab_size
        self.d_model, self.n_heads, self.n_layers = d_model, n_heads, n_layers
        self.d_ff = d_ff
        self.dropout = dropout

        # --- rotary position embedding ---
        self.rope_theta = rope_theta
        self.rope_role_scales = rope_role_scales

        # --- backend checkpoints ---
        self.vision_model_name = vision_model_name
        self.audio_model_name = audio_model_name
        self.retrieval_model_name = retrieval_model_name

        # --- memory budget per modality ---
        self.image_latents = image_latents
        self.audio_latents = audio_latents
        self.retr_latents = retr_latents

        # --- backend feature widths ---
        self.d_vision, self.d_audio, self.d_retrieval = d_vision, d_audio, d_retrieval

        # --- fusion (FiLM / router) and MoE ---
        self.film_hidden = film_hidden
        self.router_hidden = router_hidden
        self.num_experts = num_experts
        self.initializer_range = initializer_range

        # These flags are fixed for this architecture and are set after the
        # base constructor, so values passed through kwargs are overridden.
        self.is_decoder = True
        self.is_encoder_decoder = False
        self.tie_word_embeddings = True
# ========================== RoPE (role-aware) ==========================
def _rotate_half(x):
    """Return ``cat([-second_half, first_half])`` along the last dimension.

    The split point is ``x.size(-1) // 2``.
    """
    split_at = x.size(-1) // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def _apply_rotary(q, k, cos, sin):
    """Apply the rotary transform ``t * cos + rotate_half(t) * sin`` to q and k.

    Returns the transformed ``(q, k)`` pair; ``cos`` and ``sin`` must
    broadcast against both tensors.
    """
    rotated_q, rotated_k = (
        (t * cos) + (_rotate_half(t) * sin) for t in (q, k)
    )
    return rotated_q, rotated_k
def _build_role_scaled_rope_cache(attn_len, head_dim, theta, device, role_ids=None, role_scales=(0.95, 1.0, 1.05)):
    """
    Build rotary cos/sin tables, optionally scaling positions per token role.

    Args:
        attn_len: Number of positions to cover.
        head_dim: Per-head channel count; must be even.
        theta: Base of the rotary frequency schedule.
        device: Device on which the tables are created.
        role_ids: Optional [B, attn_len] tensor with
            {0: system, 1: user, 2: assistant}. Ids outside
            ``range(len(role_scales))`` keep a scale of 1.0.
        role_scales: Position multiplier for each role id, by index.

    Returns:
        ``(cos, sin)``, each of shape [B, 1, attn_len, head_dim]
        (B == 1 when ``role_ids`` is None).

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")

    positions = torch.arange(attn_len, device=device).float()[None, :]  # [1,T]
    if role_ids is not None:
        positions = positions.expand(role_ids.size(0), -1)  # [B,T]
        scales = torch.ones_like(positions)
        for rid, s in enumerate(role_scales):
            scales = torch.where(role_ids == rid, torch.full_like(scales, s), scales)
        positions = positions * scales

    # One frequency per rotated pair. _rotate_half pairs channel i with
    # channel i + head_dim/2, so both halves must share the same angle;
    # otherwise the transform is not a rotation and attention scores stop
    # depending on relative position.
    pair_idx = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (pair_idx / head_dim))  # [head_dim/2]
    half_angles = positions[..., None] * inv_freq[None, None, :]  # [B,T,head_dim/2]
    angles = torch.cat([half_angles, half_angles], dim=-1)  # [B,T,head_dim]

    cos = torch.cos(angles)[:, None, :, :]
    sin = torch.sin(angles)[:, None, :, :]
    return cos, sin
# ========================== Attention blocks ==========================
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings and an optional KV cache."""

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        d, h = config.d_model, config.n_heads
        assert d % h == 0
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, 3 * d)
        self.o = nn.Linear(d, d)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

    def forward(
        self,
        x,
        cos,
        sin,
        attention_mask=None,  # [B, T_total], 1=keep, 0=pad
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """
        Args:
            x: [B, T, D] hidden states for the new tokens.
            cos, sin: Rotary tables whose second-to-last axis indexes
                position. They may cover the whole sequence (cached + new)
                or just the T new tokens.
            attention_mask: Optional [B, T_total] padding mask
                (1 = keep, 0 = pad) over cached plus new positions.
            past_kv: Optional cached ``(k, v)``, each [B, H, T_past, dk].
            use_cache: When true, the concatenated ``(k, v)`` is returned.

        Returns:
            ``(out, present)``: ``out`` is [B, T, D]; ``present`` is the
            updated ``(k, v)`` pair or None.
        """
        B, T, D = x.shape
        q, k, v = self.qkv(x).split(D, dim=-1)
        q = q.view(B, T, self.h, -1).transpose(1, 2)
        k = k.view(B, T, self.h, -1).transpose(1, 2)
        v = v.view(B, T, self.h, -1).transpose(1, 2)

        # role-aware RoPE
        # The new tokens sit at absolute positions past_len .. past_len+T-1.
        # When the tables cover the whole sequence, take that window; taking
        # the first T rows would rotate every cached-generation step as if
        # it were at position 0. Tables that only cover the new tokens are
        # used from index 0.
        past_len = 0 if past_kv is None else past_kv[0].size(2)
        start = past_len if cos.size(-2) >= past_len + T else 0
        q, k = _apply_rotary(
            q, k, cos[..., start:start + T, :], sin[..., start:start + T, :]
        )

        if past_kv is not None:
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.dk)  # [B,H,T,Tot]

        # causal mask for cache-aware shapes: query i (absolute position
        # past_len + i) may attend to keys 0 .. past_len + i.
        t_new = q.size(-2)
        t_tot = k.size(-2)
        causal = torch.full(
            (t_new, t_tot), float("-inf"), device=x.device, dtype=scores.dtype
        )
        causal = torch.triu(causal, diagonal=1 + (t_tot - t_new))
        scores = scores + causal

        if attention_mask is not None:
            # Cast first so bool masks are accepted and the bias matches
            # the score dtype.
            mask = (1 - attention_mask.to(scores.dtype)) * -1e4
            scores = scores + mask[:, None, None, :]

        attn = F.softmax(scores, dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        out = self.resid_drop(self.o(out))
        present = (k, v) if use_cache else None
        return out, present
class CrossAttention(nn.Module):
    """Text queries attend to memory (image/audio/retrieval latents)."""

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        width = config.d_model
        n_heads = config.n_heads
        assert width % n_heads == 0
        self.h = n_heads
        self.dk = width // n_heads
        self.q = nn.Linear(width, width)
        self.k = nn.Linear(width, width)
        self.v = nn.Linear(width, width)
        self.o = nn.Linear(width, width)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x, memory, memory_mask=None):
        """
        Attend from ``x`` [B, T, D] over ``memory`` [B, M, D].

        ``memory_mask`` is an optional [B, M] mask (1 = keep, 0 = ignore).
        When ``memory`` is None, ``x`` is returned unchanged.
        """
        if memory is None:
            return x

        batch, n_query, width = x.shape
        n_mem = memory.size(1)

        def to_heads(t, length):
            # [B, length, D] -> [B, H, length, dk]
            return t.view(batch, length, self.h, -1).transpose(1, 2)

        queries = to_heads(self.q(x), n_query)
        keys = to_heads(self.k(memory), n_mem)
        values = to_heads(self.v(memory), n_mem)

        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.dk)
        if memory_mask is not None:
            # Large negative bias on masked-out memory slots.
            logits = logits + (1 - memory_mask)[:, None, None, :] * -1e4

        weights = self.drop(F.softmax(logits, dim=-1))
        mixed = weights @ values
        mixed = mixed.transpose(1, 2).contiguous().view(batch, n_query, width)
        return self.o(mixed)
# ========================== Unique fusion: FiLM + Router + MoE ==========================
class FiLMConditioner(nn.Module):
    """Produces FiLM (gamma, beta) from a media summary vector."""

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        width = config.d_model
        hidden = config.film_hidden
        self.net = nn.Sequential(
            nn.Linear(width, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 2 * width),
        )

    def forward(self, media_summary):
        """Map ``media_summary`` [B, D] to a ``(gamma, beta)`` pair, each [B, D]."""
        scale_and_shift = self.net(media_summary)  # [B,2D]
        gamma, beta = torch.chunk(scale_and_shift, 2, dim=-1)
        return gamma, beta
class Router(nn.Module):
    """Mix cross-attn vs FiLM per token."""

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        width = config.d_model
        hidden = config.router_hidden
        self.net = nn.Sequential(
            nn.Linear(2 * width, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, hidden, film_context):
        """
        Compute a per-token gate in (0, 1).

        ``hidden`` is [B, T, D] and ``film_context`` is [B, D]; the context
        is repeated along the time axis and concatenated with each token.
        Returns a [B, T, 1] tensor.
        """
        batch, steps, width = hidden.shape
        tiled_ctx = film_context.unsqueeze(1).expand(batch, steps, width)
        router_in = torch.cat([hidden, tiled_ctx], dim=-1)
        return torch.sigmoid(self.net(router_in))  # [B,T,1]
class MoE_MLP(nn.Module):
    """Modality-aware MoE: gate depends on token + media summary."""

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        width = config.d_model
        inner = config.d_ff
        n_experts = config.num_experts
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(width, inner), nn.GELU(), nn.Linear(inner, width))
                for _ in range(n_experts)
            ]
        )
        self.gate = nn.Linear(width * 2, n_experts)

    def forward(self, x, media_summary):
        """
        Softly mix all experts for every token.

        ``x`` is [B, T, D] and ``media_summary`` is [B, D]. Every expert is
        evaluated on every token and the outputs are combined with softmax
        gate weights. Returns [B, T, D].
        """
        batch, steps, width = x.shape
        tiled_summary = media_summary.unsqueeze(1).expand(batch, steps, width)
        gate_logits = self.gate(torch.cat([x, tiled_summary], dim=-1))  # [B,T,E]
        weights = F.softmax(gate_logits, dim=-1)
        per_expert = torch.stack([expert(x) for expert in self.experts], dim=-2)  # [B,T,E,D]
        return (weights.unsqueeze(-1) * per_expert).sum(dim=-2)
# ========================== Decoder Block ==========================
class ZeusBlock(nn.Module):
    """One decoder layer: self-attention, gated cross-attn/FiLM fusion, MoE-MLP."""

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.self_attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.cross_attn = CrossAttention(config)
        self.film = FiLMConditioner(config)
        self.router = Router(config)
        self.ln3 = nn.LayerNorm(width)
        self.moe = MoE_MLP(config)
        self.drop = nn.Dropout(config.dropout)

    def forward(
        self,
        x,
        cos,
        sin,
        attention_mask=None,
        past_kv=None,
        use_cache=False,
        memory=None,
        memory_mask=None,
        media_summary=None,  # [B,D]
    ):
        """
        Run the layer on ``x`` [B, T, D].

        Returns ``(x, present)`` where ``present`` is whatever the
        self-attention module returns as its cache entry.
        """
        # 1) Pre-norm causal self-attention with residual.
        attn_out, present = self.self_attn(
            self.ln1(x), cos, sin, attention_mask, past_kv, use_cache
        )
        x = x + attn_out

        # 2) Two fusion candidates: cross-attention over memory, and FiLM
        #    modulation of the residual stream by the media summary.
        cross = self.cross_attn(self.ln2(x), memory, memory_mask)
        if media_summary is None:
            film = x
            # The router and MoE gate still need a [B, D] context vector.
            summary = torch.zeros_like(x[:, 0, :])
        else:
            gamma, beta = self.film(media_summary)
            film = (x * gamma.unsqueeze(1)) + beta.unsqueeze(1)
            summary = media_summary

        # 3) Per-token gate blends the two candidates into the residual.
        gate = self.router(x, summary)
        x = x + gate * cross + (1 - gate) * film

        # 4) Pre-norm MoE feed-forward with residual.
        x = x + self.drop(self.moe(self.ln3(x), summary))
        return x, present
# ========================== Adapters with Easy Backends ==========================
class VisionAdapter(nn.Module):
    """
    Turns image features into decoder memory tokens.

    If CLIPVisionModel is available and vision_model_name is set, accepts:
    pixel_values: [B, 3, H, W] normalized per CLIP feature extractor.
    Otherwise, accepts precomputed image features via ``precomputed``.
    """

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        self.enabled = CLIPVisionModel is not None and bool(config.vision_model_name)
        self.latents = config.image_latents
        self.proj = nn.Linear(config.d_vision, config.d_model)
        self.pool = nn.Linear(config.d_model, config.d_model)
        if self.enabled:
            self.encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
            self.encoder.eval()

    @torch.no_grad()
    def encode_pixels(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision encoder without gradients; returns [B, N, d_vision]."""
        encoded = self.encoder(pixel_values=pixel_values, output_hidden_states=True)
        return encoded.last_hidden_state

    def forward(self, pixel_values: Optional[torch.FloatTensor] = None, precomputed: Optional[torch.FloatTensor] = None):
        """
        Returns ``(memory, mask, summary)`` or ``(None, None, None)``.

        ``precomputed`` features take priority over ``pixel_values``. At most
        ``self.latents`` leading tokens are kept; ``mask`` is all ones and
        ``summary`` is the pooled mean of the kept tokens.
        """
        if precomputed is not None:
            feats = precomputed
        elif self.enabled and pixel_values is not None:
            feats = self.encode_pixels(pixel_values)
        else:
            return None, None, None

        tokens = self.proj(feats)  # [B,N,D]
        keep = min(self.latents, tokens.size(1))
        mem = tokens[:, :keep, :]
        mask = torch.ones(mem.size(0), mem.size(1), device=mem.device, dtype=torch.long)
        summary = self.pool(mem.mean(dim=1))
        return mem, mask, summary
class AudioAdapter(nn.Module):
    """
    Turns audio features into decoder memory tokens.

    If Wav2Vec2Model is available and audio_model_name is set, accepts:
    input_values: [B, T_audio] 16kHz PCM float in [-1,1]
    Or pass precomputed audio features via ``precomputed``.
    Optional temporal tokens: [tempo, beat_phase] in [0,1], shape [B, Na, 2]
    """

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        self.enabled = Wav2Vec2Model is not None and bool(config.audio_model_name)
        self.latents = config.audio_latents
        self.proj = nn.Linear(config.d_audio, config.d_model)
        self.temp_proj = nn.Linear(2, config.d_model)
        self.pool = nn.Linear(config.d_model, config.d_model)
        if self.enabled:
            self.encoder = Wav2Vec2Model.from_pretrained(config.audio_model_name)
            self.encoder.eval()

    @torch.no_grad()
    def encode_wave(self, input_values: torch.FloatTensor) -> torch.FloatTensor:
        """Run the audio encoder without gradients; returns [B, Na, d_audio]."""
        encoded = self.encoder(input_values=input_values)
        return encoded.last_hidden_state

    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        temporal: Optional[torch.FloatTensor] = None,
        precomputed: Optional[torch.FloatTensor] = None,
    ):
        """
        Returns ``(memory, mask, summary)`` or ``(None, None, None)``.

        ``precomputed`` features take priority over ``input_values``. When
        ``temporal`` is given, its projection is added to the projected
        features before truncating to at most ``self.latents`` tokens.
        """
        if precomputed is not None:
            feats = precomputed
        elif self.enabled and input_values is not None:
            feats = self.encode_wave(input_values)
        else:
            return None, None, None

        tokens = self.proj(feats)
        if temporal is not None:
            tokens = tokens + self.temp_proj(temporal)

        keep = min(self.latents, tokens.size(1))
        mem = tokens[:, :keep, :]
        mask = torch.ones(mem.size(0), mem.size(1), device=mem.device, dtype=torch.long)
        summary = self.pool(mem.mean(dim=1))
        return mem, mask, summary
class RetrievalAdapter(nn.Module):
    """
    Retrieval adapter:
    - If retrieval_model_name is set and available, accepts tokenized text
      (input_ids, attention_mask), runs an HF encoder to get embeddings,
      pools them, and produces memory tokens.
    - Otherwise, accepts precomputed retrieval features via ``precomputed``.
    """

    def __init__(self, config: ZeusMMConfig):
        super().__init__()
        self.enabled = HFBackbone is not None and bool(config.retrieval_model_name)
        self.latents = config.retr_latents
        self.proj = nn.Linear(config.d_retrieval, config.d_model)
        self.pool = nn.Linear(config.d_model, config.d_model)
        if self.enabled:
            self.encoder = HFBackbone.from_pretrained(config.retrieval_model_name)
            self.encoder.eval()

    @torch.no_grad()
    def encode_tokens(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
        """Run the retrieval encoder without gradients; returns [B, N, d_retrieval]."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden = out.last_hidden_state  # [B, N, d_retrieval]
        return hidden

    def forward(
        self,
        retr_input_ids: Optional[torch.LongTensor] = None,
        retr_attention_mask: Optional[torch.LongTensor] = None,
        precomputed: Optional[torch.FloatTensor] = None,
    ):
        """
        Returns ``(memory, mask, summary)`` or ``(None, None, None)``.

        ``precomputed`` features take priority over ``retr_input_ids``. At
        most ``self.latents`` leading tokens are kept.

        When ``retr_attention_mask`` is given and matches the feature
        layout [B, N], it becomes the memory mask and ``summary`` is the
        mean over unmasked tokens only, so padding is neither attended to
        nor averaged in. Otherwise the mask is all ones and ``summary`` is
        the plain mean.
        """
        if precomputed is None and (not self.enabled or retr_input_ids is None):
            return None, None, None
        feats = precomputed if precomputed is not None else self.encode_tokens(retr_input_ids, retr_attention_mask)
        x = self.proj(feats)
        L = min(self.latents, x.size(1))
        mem = x[:, :L, :]

        mask_usable = (
            retr_attention_mask is not None
            and retr_attention_mask.dim() == 2
            and tuple(retr_attention_mask.shape) == tuple(feats.shape[:2])
        )
        if mask_usable:
            mask = retr_attention_mask[:, :L].to(device=mem.device, dtype=torch.long)
            weights = mask.unsqueeze(-1).to(mem.dtype)
            # clamp guards rows whose kept tokens are all padding.
            pooled = (mem * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        else:
            mask = torch.ones(mem.size(0), mem.size(1), device=mem.device, dtype=torch.long)
            pooled = mem.mean(dim=1)

        summary = self.pool(pooled)
        return mem, mask, summary
# ========================== Main Model ==========================
class ZeusForCausalLM(PreTrainedModel):
    """Zeus decoder with vision/audio/retrieval adapters and a tied LM head."""

    config_class = ZeusMMConfig

    # Role id given to tokens produced during generation; follows the
    # {0: system, 1: user, 2: assistant} convention used by the RoPE helper.
    _ASSISTANT_ROLE_ID = 2

    # Raw backend inputs that generation must keep forwarding on every step.
    _RAW_MEDIA_KEYS = (
        "pixel_values",
        "input_values",
        "audio_temporal",
        "retr_input_ids",
        "retr_attention_mask",
    )

    def __init__(self, config: ZeusMMConfig):
        super().__init__(config)
        d = config.d_model
        self.embed_tokens = nn.Embedding(config.vocab_size, d)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([ZeusBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(d)
        self.lm_head = nn.Linear(d, config.vocab_size, bias=False)
        # adapters
        self.vision = VisionAdapter(config)
        self.audio = AudioAdapter(config)
        self.retr = RetrievalAdapter(config)
        self.post_init()

    # HF accessors
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_emb):
        self.embed_tokens = new_emb
        # Keep the LM head tied to the new embedding matrix.
        self.lm_head.weight = new_emb.weight

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_head):
        self.lm_head = new_head

    # ---- Generation plumbing ----
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        role_ids=None,
        # prebuilt memories to carry across steps
        image_memory=None,
        audio_memory=None,
        retr_memory=None,
        memory_mask=None,
        media_summary=None,
        **kwargs
    ):
        """
        Build the ``forward`` kwargs for one generation step.

        - ``role_ids`` shorter than the running sequence are right-padded
          with the assistant role, since the extra tokens were generated.
        - When a KV cache is present, only the tokens not yet in the cache
          are fed to the model.
        - Raw backend inputs found in ``kwargs`` are forwarded so media is
          not silently dropped. They are re-encoded on every step; pass
          precomputed ``*_memory`` tensors to avoid that cost.
        """
        seq_len = input_ids.size(1)

        if role_ids is not None and role_ids.size(1) < seq_len:
            filler = role_ids.new_full(
                (role_ids.size(0), seq_len - role_ids.size(1)),
                self._ASSISTANT_ROLE_ID,
            )
            role_ids = torch.cat([role_ids, filler], dim=1)

        if past_key_values is not None:
            past_len = past_key_values[0][0].size(2)
            # Normally the cache holds every token but the newest one; fall
            # back to the last token if the cache is at least as long as
            # the input.
            keep_from = past_len if seq_len > past_len else seq_len - 1
            input_ids = input_ids[:, keep_from:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "role_ids": role_ids,
            "image_memory": image_memory,
            "audio_memory": audio_memory,
            "retr_memory": retr_memory,
            "memory_mask": memory_mask,
            "media_summary": media_summary,
            "use_cache": kwargs.get("use_cache", True),
        }
        for key in self._RAW_MEDIA_KEYS:
            value = kwargs.get(key)
            if value is not None:
                model_inputs[key] = value
        return model_inputs

    @staticmethod
    def _with_defaults(memory, mask, summary):
        """Fill in an all-ones mask and a mean summary for precomputed memory."""
        if mask is None:
            mask = torch.ones_like(memory[..., 0], dtype=torch.long)
        if summary is None:
            summary = memory.mean(dim=1)
        return memory, mask, summary

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        role_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        # HF Generation adds these — accept & ignore
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        # ---- Raw inputs for backends OR precomputed memories ----
        # Vision
        pixel_values: Optional[torch.FloatTensor] = None,  # [B,3,H,W]
        image_memory: Optional[torch.FloatTensor] = None,  # [B,Li,D]
        # Audio
        input_values: Optional[torch.FloatTensor] = None,  # [B,T_audio]
        audio_temporal: Optional[torch.FloatTensor] = None,  # [B,Na,2]
        audio_memory: Optional[torch.FloatTensor] = None,  # [B,La,D]
        # Retrieval
        retr_input_ids: Optional[torch.LongTensor] = None,  # [B,Nr]
        retr_attention_mask: Optional[torch.LongTensor] = None,  # [B,Nr]
        retr_memory: Optional[torch.FloatTensor] = None,  # [B,Lr,D]
        # Pre-assembled
        memory_mask: Optional[torch.LongTensor] = None,  # [B,Lm]
        media_summary: Optional[torch.FloatTensor] = None,  # [B,D]
        # future-proof
        **unused,
    ):
        """
        Run the decoder over ``input_ids`` [B, T].

        For each modality, a precomputed ``*_memory`` tensor is used as-is
        (all-ones mask, mean summary). Otherwise the matching raw input, if
        given, is encoded by its adapter. The per-modality memories are
        concatenated along the token axis for cross-attention, and the
        summaries are averaged unless ``media_summary`` is supplied.

        ``attention_mask`` and ``role_ids`` may cover the whole sequence
        (cache + new tokens) or only the new tokens; in the latter case the
        cached part is filled with ones / the user role.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` [B, T, vocab], the
            shifted cross-entropy ``loss`` when ``labels`` is given, and
            ``past_key_values`` when ``use_cache`` is truthy.
        """
        B, T = input_ids.shape
        x = self.drop(self.embed_tokens(input_ids))

        # ---- Extend masks for cache ----
        past_len = 0 if past_key_values is None else past_key_values[0][0].size(2)
        tot_len = past_len + T
        if attention_mask is None:
            attention_mask = torch.ones(B, T, device=x.device, dtype=torch.long)
        if past_len > 0 and attention_mask.size(1) == T:
            pad = torch.ones(B, past_len, device=x.device, dtype=attention_mask.dtype)
            attention_mask = torch.cat([pad, attention_mask], dim=1)  # [B,tot_len]

        # Role ids for role-aware RoPE
        if role_ids is None:
            role_ids = torch.full((B, tot_len), 1, device=x.device, dtype=torch.long)  # default user
        elif past_len > 0 and role_ids.size(1) == T:
            pad_roles = torch.full((B, past_len), 1, device=x.device, dtype=role_ids.dtype)
            role_ids = torch.cat([pad_roles, role_ids], dim=1)

        Dh = self.config.d_model // self.config.n_heads
        cos, sin = _build_role_scaled_rope_cache(
            tot_len, Dh, self.config.rope_theta, x.device, role_ids=role_ids,
            role_scales=self.config.rope_role_scales
        )

        # ---- Build memories from backends if not provided ----
        # Each entry is (memory, mask, summary). Masks and summaries start
        # as None and are only set when an adapter actually ran.
        streams = []

        # Vision
        image_mask = img_sum = None
        if image_memory is None and pixel_values is not None:
            image_memory, image_mask, img_sum = self.vision(pixel_values=pixel_values)
        if image_memory is not None:
            streams.append(self._with_defaults(image_memory, image_mask, img_sum))

        # Audio
        audio_mask = aud_sum = None
        if audio_memory is None and input_values is not None:
            audio_memory, audio_mask, aud_sum = self.audio(input_values=input_values, temporal=audio_temporal)
        if audio_memory is not None:
            streams.append(self._with_defaults(audio_memory, audio_mask, aud_sum))

        # Retrieval
        retr_mask = ret_sum = None
        if retr_memory is None and retr_input_ids is not None:
            retr_memory, retr_mask, ret_sum = self.retr(
                retr_input_ids=retr_input_ids, retr_attention_mask=retr_attention_mask
            )
        if retr_memory is not None:
            streams.append(self._with_defaults(retr_memory, retr_mask, ret_sum))

        memory = torch.cat([mem for mem, _, _ in streams], dim=1) if streams else None
        if memory_mask is None and streams:
            memory_mask = torch.cat([mask for _, mask, _ in streams], dim=1)
        if media_summary is None:
            if streams:
                media_summary = torch.stack([s for _, _, s in streams], dim=0).mean(dim=0)
            else:
                # Match the activation dtype so half-precision models do
                # not mix dtypes in the router and MoE gate.
                media_summary = torch.zeros(
                    B, self.config.d_model, device=x.device, dtype=x.dtype
                )

        # ---- Decoder stack ----
        presents = []
        h = x
        for i, blk in enumerate(self.blocks):
            past = None if past_key_values is None else past_key_values[i]
            h, present = blk(
                h, cos, sin,
                attention_mask=attention_mask,
                past_kv=past,
                use_cache=use_cache,
                memory=memory,
                memory_mask=memory_mask,
                media_summary=media_summary,
            )
            if use_cache:
                presents.append(present)

        logits = self.lm_head(self.ln_f(h))

        # Shifted loss: position t predicts token t+1.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                logits[:, :-1, :].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1)
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=tuple(presents) if use_cache else None,
        )
# ========================== Registration ==========================
# Import-time side effect: register the config under the "zeusmm" model type,
# then associate that config class with ZeusForCausalLM for causal-LM loading.
# The config is registered first because the model registration refers to it.
AutoConfig.register("zeusmm", ZeusMMConfig)
AutoModelForCausalLM.register(ZeusMMConfig, ZeusForCausalLM)
|