# Ker-VLJEPA-3B / model.py
# Committed by codybum
# Commit a974113 (verified): Initial release: Ker-VLJEPA-3B inference package
"""
Ker-VLJEPA-3B — Inference-only model.
Loads the Llama 3.2 3B backbone, applies LoRA adapters, visual encoder,
and cross-attention bridge components for CT report generation.
Requires:
- A local copy of meta-llama/Llama-3.2-3B (user-provided)
- Weight files in weights/ (shipped with this package)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Optional, List, Tuple, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from safetensors.torch import load_file
# ---------------------------------------------------------------------------
# Visual encoder components (inference-only, no training heads)
# ---------------------------------------------------------------------------
def sinusoidal_positional_encoding(n: int, d: int) -> torch.Tensor:
    """Build a sine/cosine positional-encoding table.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with frequencies ``w_k = exp(-2k * ln(10000) / d)``.

    Args:
        n: Number of positions (rows).
        d: Encoding width (columns). Odd widths are supported; the final
            column is then a sine term without a cosine partner.

    Returns:
        Float32 tensor of shape ``(1, n, d)``.
    """
    pe = torch.zeros(n, d)
    pos = torch.arange(n, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
    angles = pos * div  # (n, ceil(d / 2))
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(d / 2) cosine columns, so for odd d the last
    # frequency has no cosine slot; without the trim this raised a shape error.
    pe[:, 1::2] = torch.cos(angles[:, : d // 2])
    return pe.unsqueeze(0)
class ZonedCrossAttention(nn.Module):
    """Z-zoned cross-attention: each zone's queries attend only to their spatial slice range.

    The valid slices of each sample are split into ``num_zones`` contiguous
    index ranges. The ``tokens_per_zone`` learned queries of zone ``z`` may
    attend only to the slices of range ``z``. The output is
    ``num_zones * tokens_per_zone`` region tokens projected back to
    ``slice_dim``.

    Args:
        slice_dim: Width of the incoming per-slice embeddings.
        hidden_dim: Attention width (must be divisible by ``num_heads``).
        num_zones: Number of contiguous slice ranges.
        tokens_per_zone: Learned query tokens per zone.
        num_heads: Attention heads.
        dropout: Attention dropout probability.
    """

    def __init__(self, slice_dim=1024, hidden_dim=1024, num_zones=32,
                 tokens_per_zone=1, num_heads=16, dropout=0.1):
        super().__init__()
        self.num_zones = num_zones
        self.tokens_per_zone = tokens_per_zone
        self.num_regions = num_zones * tokens_per_zone
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        # Creation order is unchanged so state_dict keys stay the same.
        self.zone_queries = nn.Parameter(torch.zeros(num_zones, tokens_per_zone, hidden_dim))
        self.zone_embed = nn.Embedding(num_zones, hidden_dim)
        self.slice_proj = nn.Linear(slice_dim, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=dropout)
        self.output_proj = nn.Linear(hidden_dim, slice_dim)
        self.fallback_pos_embed = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
        # Frequencies for the physical Z-positional encoding used in _add_z_pos.
        div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim))
        self.register_buffer("div_term", div_term)

    @staticmethod
    def _num_valid(mask, b, seq_len):
        """Return the number of valid (nonzero) mask entries of sample ``b``.

        Falls back to ``seq_len`` when there is no mask. The count is taken
        in int64. Summing a bfloat16 mask is inexact, because bfloat16 only
        represents integers exactly up to 256. That could round the count
        past ``seq_len`` and break the indexing below.
        """
        if mask is None:
            return seq_len
        return int((mask[b] != 0).sum().item())

    def _add_z_pos(self, x_proj, mask):
        """Add a sinusoidal encoding of each slice's Z position.

        Slice ``i`` of a sample is encoded at position ``i * 2.5 / 600 * 100``.
        Positions are assigned to the first ``n`` entries of each row, which
        assumes the valid slices form a prefix of the row. Entries where
        ``mask`` is 0 receive no encoding.
        """
        B, S, D = x_proj.shape
        device, dtype = x_proj.device, x_proj.dtype
        z_positions = torch.zeros(B, S, device=device)
        for b in range(B):
            n = self._num_valid(mask, b, S)
            pos = torch.arange(n, device=device, dtype=torch.float) * 2.5 / 600.0 * 100.0
            z_positions[b, :n] = pos
        angles = z_positions.unsqueeze(-1) * self.div_term  # (B, S, ceil(D / 2))
        pe = torch.zeros(B, S, D, device=device, dtype=dtype)
        pe[:, :, 0::2] = torch.sin(angles)
        # Trim for odd D, where there is one fewer cosine column.
        pe[:, :, 1::2] = torch.cos(angles[..., : D // 2])
        if mask is not None:
            pe = pe * mask.unsqueeze(-1)
        return x_proj + pe

    def forward(self, x, mask=None, metadata=None):
        """Pool slice embeddings into zone-restricted region tokens.

        Args:
            x: Slice embeddings of shape ``(B, S, slice_dim)``.
            mask: Optional ``(B, S)`` mask; nonzero = valid slice, 0 = padding.
            metadata: Unused; accepted for interface compatibility.

        Returns:
            ``(region_tokens, attn_weights)``. ``region_tokens`` has shape
            ``(B, num_regions, slice_dim)``; ``attn_weights`` is what
            ``nn.MultiheadAttention`` returns with ``need_weights=True``.

        Raises:
            ValueError: If a sample has no valid slices. Every query would
                then be fully masked and the output would be NaN.
        """
        B, S, _ = x.shape
        x_proj = self.slice_proj(x)
        if mask is not None:
            x_proj = self._add_z_pos(x_proj, mask)
        else:
            x_proj = x_proj + self.fallback_pos_embed

        # Learned per-zone queries plus a zone-identity embedding.
        queries = self.zone_queries.reshape(self.num_regions, self.hidden_dim)
        zone_ids = torch.arange(self.num_zones, device=x.device)
        zone_emb = self.zone_embed(zone_ids).unsqueeze(1).expand(-1, self.tokens_per_zone, -1)
        queries = (queries + zone_emb.reshape(self.num_regions, self.hidden_dim)).unsqueeze(0).expand(B, -1, -1)

        # Boolean attention mask: True = query may NOT attend to that slice.
        zone_mask = torch.ones(B, self.num_regions, S, device=x.device, dtype=torch.bool)
        for b in range(B):
            n = self._num_valid(mask, b, S)
            if n == 0:
                raise ValueError(f"ZonedCrossAttention: sample {b} has no valid slices")
            zone_size = n / self.num_zones
            for z in range(self.num_zones):
                # Same rounding as before. The clamp only matters when
                # n < num_zones / 2: then a zone's start could reach n, leaving
                # its queries fully masked (NaN). Now each zone keeps >= 1 slice.
                s = min(int(round(z * zone_size)), n - 1)
                e = min(max(int(round((z + 1) * zone_size)), s + 1), n)
                q_s, q_e = z * self.tokens_per_zone, (z + 1) * self.tokens_per_zone
                zone_mask[b, q_s:q_e, s:e] = False
        # nn.MultiheadAttention takes a 3-D mask of shape (B * num_heads, L, S).
        zone_mask = zone_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1).reshape(
            B * self.num_heads, self.num_regions, S)
        kpm = (mask == 0) if mask is not None else None
        region_hidden, attn_w = self.attention(
            queries, x_proj, x_proj, attn_mask=zone_mask, key_padding_mask=kpm, need_weights=True)
        return self.output_proj(region_hidden), attn_w
class VisualEncoder(nn.Module):
    """Turn a variable-length stack of slice embeddings into fixed LLM-space tokens.

    Pipeline:
      1. ``region_query`` (ZonedCrossAttention) pools slices into region tokens.
      2. ``global_self_attn`` runs one Transformer encoder layer over the regions.
      3. ``jepa_predictor`` maps the regions to ``llm_dim`` (Linear + LayerNorm).
      4. The result is multiplied by the ``_norm_scale`` buffer.

    ``num_regions`` is accepted but not read. The region count is set by
    ``num_zones * tokens_per_zone``.
    """

    def __init__(self, slice_dim=1024, hidden_dim=1024, llm_dim=3072,
                 num_regions=32, num_zones=32, tokens_per_zone=1,
                 num_heads=16, dropout=0.1):
        super().__init__()
        self.region_query = ZonedCrossAttention(
            slice_dim, hidden_dim, num_zones, tokens_per_zone, num_heads, dropout
        )
        self.global_self_attn = nn.TransformerEncoderLayer(
            d_model=slice_dim,
            nhead=num_heads,
            batch_first=True,
            dropout=dropout,
            dim_feedforward=4 * slice_dim,
            activation="gelu",
        )
        # Positions matter: the state_dict keys are derived from the indices,
        # with the Linear at index 1 and the LayerNorm at index 2.
        predictor_layers = [
            nn.Dropout(0.0),  # inactive (p = 0)
            nn.Linear(slice_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.Dropout(0.0),  # inactive (p = 0)
        ]
        self.jepa_predictor = nn.Sequential(*predictor_layers)
        # Plain float attribute that forward() never reads; the scale actually
        # applied is the `_norm_scale` buffer.
        self.norm_calibrator_scale = 1.0
        self.register_buffer("_norm_scale", torch.tensor(1.0))

    def forward(self, slices, mask=None):
        """Encode slice embeddings into visual tokens.

        Args:
            slices: ``(B, num_slices, slice_dim)`` pre-computed slice embeddings.
            mask: Optional ``(B, num_slices)`` mask, 1 = valid slice, 0 = padding.

        Returns:
            ``(B, num_zones * tokens_per_zone, llm_dim)`` tokens scaled by
            ``_norm_scale``. With the defaults this is ``(B, 32, 3072)``.
        """
        pooled, _attn = self.region_query(slices, mask)
        contextualised = self.global_self_attn(pooled)
        return self.jepa_predictor(contextualised) * self._norm_scale
class GatedCrossAttentionLayer(nn.Module):
    """Flamingo-style cross-attention from text hidden states to visual tokens.

    Queries come from LayerNorm-ed text states; keys and values come from the
    visual tokens. All four projections are bias-free.

    A ``gate`` parameter is created (value 0.0, not trainable), but
    ``forward`` does not use it. The returned update is therefore un-gated,
    and any residual addition is left to the caller.
    """

    def __init__(self, hidden_dim=3072, num_heads=16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Keep q, k, v, o creation order: it fixes the state_dict layout and the
        # order in which seeded random initialisation is consumed.
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, proj_name, nn.Linear(hidden_dim, hidden_dim, bias=False))
        self.q_norm = nn.LayerNorm(hidden_dim)
        self.gate = nn.Parameter(torch.tensor(0.0), requires_grad=False)

    def _split_heads(self, projected, batch, n_tokens):
        """Reshape ``(batch, tokens, hidden)`` to ``(batch, heads, tokens, head_dim)``."""
        return projected.view(batch, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, text_hidden, visual_tokens):
        """Return the cross-attention update for ``text_hidden``.

        Both inputs are cast to the projection weight dtype. The output has
        the same ``(B, S, hidden_dim)`` shape as ``text_hidden``.
        """
        batch, seq_len, _ = text_hidden.shape
        weight_dtype = self.q_proj.weight.dtype
        text_hidden = text_hidden.to(weight_dtype)
        visual_tokens = visual_tokens.to(weight_dtype)

        queries = self._split_heads(self.q_proj(self.q_norm(text_hidden)), batch, seq_len)
        keys = self._split_heads(self.k_proj(visual_tokens), batch, -1)
        values = self._split_heads(self.v_proj(visual_tokens), batch, -1)

        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=False)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_dim)
        return self.o_proj(merged)
# ---------------------------------------------------------------------------
# Main inference model
# ---------------------------------------------------------------------------
class KerVLJEPA(nn.Module):
    """
    Ker-VLJEPA-3B inference model.

    Takes pre-computed LeJEPA slice embeddings (1024-d per slice) and generates
    free-text radiology reports using Llama 3.2 3B + LoRA + cross-attention bridge.

    Visual information enters the language model in two ways:

    * Grafting: the input embeddings of the ``NUM_REGIONS`` ``VISUAL_TOKEN``
      placeholders in the prompt are replaced by the visual tokens, rescaled
      to ``text_embed_norm``.
    * Injection: forward hooks on the decoder layers in ``INJECTION_LAYERS``
      do two things. On the prompt pass they add a projected copy of the
      visual tokens at the placeholder positions. On every pass they add a
      cross-attention update at all positions.
    """

    INJECTION_LAYERS = [7, 14, 21]
    NUM_REGIONS = 32
    VISUAL_TOKEN = "<|visual_region|>"
    # Label order for the 18 sigmoid outputs returned by classify().
    CLASS_NAMES = (
        "Medical material", "Arterial wall calcification",
        "Cardiomegaly", "Pericardial effusion",
        "Coronary artery wall calcification", "Hiatal hernia",
        "Lymphadenopathy", "Emphysema",
        "Atelectasis", "Lung nodule",
        "Lung opacity", "Pulmonary fibrotic sequela",
        "Pleural effusion", "Mosaic attenuation pattern",
        "Peribronchial thickening", "Consolidation",
        "Bronchiectasis", "Interlobular septal thickening",
    )

    def __init__(self, llm_path: str, weights_dir: str = "weights", device: str = "cuda"):
        """Load the tokenizer, LLM + LoRA, visual encoder and bridge components.

        Args:
            llm_path: Local directory of meta-llama/Llama-3.2-3B.
            weights_dir: Package ``weights/`` directory, containing
                ``tokenizer/``, ``lora_adapters/``,
                ``visual_encoder.safetensors`` and
                ``bridge_components.safetensors``.
            device: Device string. It is used as the LLM ``device_map`` and
                as the target for the visual encoder and bridge modules.
        """
        super().__init__()
        weights_dir = Path(weights_dir)
        self.device = torch.device(device)
        self._load_tokenizer(llm_path, weights_dir / "tokenizer")
        llm_dim = self._load_llm(llm_path, weights_dir / "lora_adapters", device)
        self._load_visual_encoder(weights_dir / "visual_encoder.safetensors", llm_dim)
        self._load_bridge(weights_dir / "bridge_components.safetensors", llm_dim)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def _load_tokenizer(self, llm_path: str, tok_dir: Path) -> None:
        """Load the base tokenizer, add the visual placeholder token and the custom chat template."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True, padding_side="left",
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.add_tokens([self.VISUAL_TOKEN], special_tokens=True)
        self.vis_token_id = self.tokenizer.convert_tokens_to_ids(self.VISUAL_TOKEN)
        tmpl_path = tok_dir / "chat_template.jinja"
        if tmpl_path.exists():
            # Explicit encoding: the platform default is not UTF-8 everywhere.
            self.tokenizer.chat_template = tmpl_path.read_text(encoding="utf-8")

    def _load_llm(self, llm_path: str, lora_dir: Path, device: str) -> int:
        """Load Llama 3.2 3B in bfloat16, attach the LoRA adapters, and return the hidden size."""
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_path, torch_dtype=torch.bfloat16,
            device_map=device, trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
        # Grow the embedding matrix to cover the added <|visual_region|> token.
        self.llm.resize_token_embeddings(len(self.tokenizer))
        self.llm = PeftModel.from_pretrained(self.llm, str(lora_dir), is_trainable=False)
        self.llm.eval()
        return self.llm.config.hidden_size  # 3072

    def _load_visual_encoder(self, ckpt_path: Path, llm_dim: int) -> None:
        """Build the VisualEncoder and load the inference-relevant subset of its checkpoint."""
        import warnings

        self.visual_encoder = VisualEncoder(
            slice_dim=1024, hidden_dim=1024, llm_dim=llm_dim,
            num_regions=self.NUM_REGIONS, num_zones=32, tokens_per_zone=1,
            num_heads=16, dropout=0.0,
        )
        ve_state = load_file(str(ckpt_path))
        # The checkpoint also carries training-only heads and an output LayerNorm
        # that this inference module does not define; those keys are dropped.
        skipped_prefixes = (
            "norm_calibrator.", "region_classifier.",
            "slice_organ_classifier.", "output_ln.",
        )
        mapped = {}
        for key, value in ve_state.items():
            if key == "norm_calibrator.scale":
                self.visual_encoder._norm_scale = value.clone()
                continue
            if key == "_last_attention_weights" or key.startswith(skipped_prefixes):
                continue
            mapped[key] = value
        result = self.visual_encoder.load_state_dict(mapped, strict=False)
        # strict=False is needed because `_norm_scale` is filled above and
        # `region_query.div_term` is computed in __init__. Any other mismatch
        # means weights were silently left at their initial values.
        expected_missing = {"_norm_scale", "region_query.div_term"}
        missing = [k for k in result.missing_keys if k not in expected_missing]
        if missing or result.unexpected_keys:
            warnings.warn(
                f"visual encoder checkpoint mismatch: missing={missing}, "
                f"unexpected={list(result.unexpected_keys)}"
            )
        self.visual_encoder = self.visual_encoder.to(self.device).to(torch.bfloat16)
        self.visual_encoder.eval()

    def _load_bridge(self, ckpt_path: Path, llm_dim: int) -> None:
        """Load the grafting norm, layer projectors and cross-attention adapters."""
        bridge = load_file(str(ckpt_path))
        # Target norm that grafted visual embeddings are rescaled to.
        self.register_buffer("text_embed_norm", bridge["text_embed_norm"].to(self.device))
        # Loaded from the checkpoint but not used by generate() or classify().
        self.layernorm = nn.LayerNorm(llm_dim).to(self.device).to(torch.bfloat16)
        self.layernorm.weight.data.copy_(bridge["layernorm.weight"])
        self.layernorm.bias.data.copy_(bridge["layernorm.bias"])
        self.cross_attn_adapters = nn.ModuleDict()
        self.layer_projectors = nn.ModuleDict()
        for layer_idx in self.INJECTION_LAYERS:
            li = str(layer_idx)
            prefix = f"cross_attn_adapters.{li}"
            adapter = GatedCrossAttentionLayer(llm_dim, 16)
            for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                getattr(adapter, proj_name).weight.data.copy_(bridge[f"{prefix}.{proj_name}.weight"])
            adapter.q_norm.weight.data.copy_(bridge[f"{prefix}.q_norm.weight"])
            adapter.q_norm.bias.data.copy_(bridge[f"{prefix}.q_norm.bias"])
            # Note: adapter.gate is not loaded (and is not used by the adapter).
            self.cross_attn_adapters[li] = adapter.to(self.device).to(torch.bfloat16)
            proj = nn.Linear(llm_dim, llm_dim, bias=False)
            proj.weight.data.copy_(bridge[f"layer_projectors.{li}.weight"])
            self.layer_projectors[li] = proj.to(self.device).to(torch.bfloat16)
        # No case-classifier head is built by this loader. classify() needs
        # one to be attached here (a module mapping pooled tokens to 18 logits).
        self._case_classifier = None

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _get_llm_layers(self):
        """Follow ``base_model`` / ``model`` attributes (up to 10 hops) and return ``.layers``."""
        model = self.llm
        for _ in range(10):
            if hasattr(model, "base_model"):
                model = model.base_model
            elif hasattr(model, "model"):
                model = model.model
            else:
                break
        return model.layers

    def _get_embed_layer(self):
        """Return the first non-None ``get_input_embeddings()`` found while unwrapping ``self.llm``."""
        model = self.llm
        for _ in range(10):
            if hasattr(model, "get_input_embeddings"):
                emb = model.get_input_embeddings()
                if emb is not None:
                    return emb
            if hasattr(model, "base_model"):
                model = model.base_model
            elif hasattr(model, "model"):
                model = model.model
            else:
                break
        raise RuntimeError("Could not find embedding layer")

    def _normalize_visual(self, visual_embeds):
        """Rescale each visual vector to unit L2 norm, then multiply by ``text_embed_norm``."""
        norms = visual_embeds.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        return visual_embeds / norms * self.text_embed_norm

    def _prepare_inputs(
        self, slice_embeddings: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Move inputs to the model device: slices as bfloat16, mask as bool (True = valid).

        A bool mask keeps valid-slice counts exact downstream. A bfloat16 mask
        only represents integer sums exactly up to 256.
        """
        slices = slice_embeddings.to(self.device, dtype=torch.bfloat16)
        if mask is not None:
            mask = mask.to(self.device) != 0
        return slices, mask

    @staticmethod
    def _clean_text(text: str) -> str:
        """Trim artefacts from a decoded report.

        * Keeps only the text before a second ``findings:`` header
          (case-insensitive).
        * Cuts at the first triple newline found after character 20.
        * Cuts at a chat-role word (``User``, ``user``, ``assistant``,
          ``system``) that starts a line after character 20.
        """
        import re

        text = text.strip()
        if not text:
            return text
        header = "findings:"
        low = text.lower()
        first = low.find(header)
        if first >= 0:
            second = low.find(header, first + len(header))
            if second > 0:
                text = text[:second].strip()
        triple_newline = text.find("\n\n\n")
        if triple_newline > 20:
            text = text[:triple_newline].strip()
        # Role words count as leaked chat headers only at the start of a line.
        # Matching them anywhere cut ordinary prose such as
        # "tracheobronchial system ..." short.
        role_marker = re.compile(r"^[ \t]*(?:User|user|assistant|system)\b", re.MULTILINE)
        for match in role_marker.finditer(text):
            if match.start() > 20:
                text = text[: match.start()].strip()
                break
        return text

    def _build_prompt_ids(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize the fixed chat prompt (with visual placeholders); return ``(input_ids, attention_mask)``."""
        placeholders = self.VISUAL_TOKEN * self.NUM_REGIONS
        messages = [
            {"role": "system", "content": "You are a radiology reporting assistant. Describe thoracic findings based on the provided CT scan visual features. Report only what you observe."},
            {"role": "user", "content": f"Based on the visual features from this CT scan, describe the thoracic findings. {placeholders}"},
        ]
        tokenized = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
        if hasattr(tokenized, "input_ids"):
            input_ids = tokenized["input_ids"].to(self.device)
            attention_mask = tokenized["attention_mask"].to(self.device)
        else:
            input_ids = tokenized.to(self.device)
            attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask

    def _graft_visual_tokens(
        self, input_ids: torch.Tensor, visual_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``input_ids`` and overwrite the placeholder positions with ``visual_tokens``.

        Returns:
            ``(inputs_embeds, vis_mask)``, where ``vis_mask`` marks the
            placeholder positions in ``input_ids``.

        Raises:
            RuntimeError: If the prompt does not contain exactly
                ``NUM_REGIONS`` placeholders.
        """
        embed_layer = self._get_embed_layer()
        inputs_embeds = embed_layer(input_ids).clone()
        vis_mask = input_ids == self.vis_token_id
        vis_positions = vis_mask[0].nonzero(as_tuple=True)[0]
        if len(vis_positions) != self.NUM_REGIONS:
            raise RuntimeError(
                f"Expected {self.NUM_REGIONS} visual tokens, found {len(vis_positions)}"
            )
        # The i-th placeholder (in sequence order) receives the i-th visual token.
        inputs_embeds[0, vis_positions] = visual_tokens[0, : len(vis_positions)].to(inputs_embeds.dtype)
        return inputs_embeds, vis_mask

    @staticmethod
    def _make_injection_hook(proj, adapter, visual_tokens, vis_mask, seq_len):
        """Build a forward hook that adds visual information to a decoder layer's output.

        When the layer output spans ``seq_len`` positions (the prompt pass),
        ``proj(visual_tokens)`` is added at the placeholder positions. On
        every call, ``adapter(hidden, visual_tokens)`` is added at all
        positions. Tuple outputs keep their trailing elements.
        """
        vis_pos = vis_mask[0].nonzero(as_tuple=True)[0]

        def hook_fn(module, args, output):
            hidden = output[0] if isinstance(output, tuple) else output
            if hidden.shape[1] == seq_len:
                projected = proj(visual_tokens.to(hidden.dtype))
                hidden[0, vis_pos] = hidden[0, vis_pos] + projected[0, : len(vis_pos)]
            xattn_out = adapter(hidden, visual_tokens.to(hidden.dtype))
            modified = hidden + xattn_out.to(hidden.dtype)
            return (modified,) + output[1:] if isinstance(output, tuple) else modified

        return hook_fn

    def _register_injection_hooks(self, hooks: list, visual_tokens, vis_mask, seq_len: int) -> None:
        """Register one injection hook per layer in ``INJECTION_LAYERS``, appending handles to ``hooks``.

        Handles are appended as they are created, so a caller's ``finally``
        can remove the ones already registered if a later registration fails.
        """
        llm_layers = self._get_llm_layers()
        for layer_idx in self.INJECTION_LAYERS:
            li = str(layer_idx)
            hook = self._make_injection_hook(
                self.layer_projectors[li], self.cross_attn_adapters[li],
                visual_tokens, vis_mask, seq_len,
            )
            hooks.append(llm_layers[layer_idx].register_forward_hook(hook))

    def _stop_token_ids(self) -> List[int]:
        """Return EOS plus ``<|eot_id|>`` and ``<|start_header_id|>`` when they resolve to distinct ids."""
        eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        start_header = self.tokenizer.convert_tokens_to_ids("<|start_header_id|>")
        stop_ids = [self.tokenizer.eos_token_id]
        if eot_id is not None and eot_id != self.tokenizer.eos_token_id:
            stop_ids.append(eot_id)
        if start_header is not None and start_header not in stop_ids:
            stop_ids.append(start_header)
        return stop_ids

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate(
        self,
        slice_embeddings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 384,
        temperature: float = 0.6,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        no_repeat_ngram_size: int = 4,
    ) -> str:
        """
        Generate a radiology report from pre-computed LeJEPA slice embeddings.

        Args:
            slice_embeddings: (1, num_slices, 1024) stacked 1024-d embeddings,
                one per CT slice. Padding is allowed.
            mask: (1, num_slices) binary mask where 1=real slice, 0=padding.
                If None, all slices are treated as valid.
            max_new_tokens: maximum tokens to generate (default 384).
            temperature: sampling temperature (default 0.6). A value <= 0
                selects greedy decoding instead of sampling.
            top_p: nucleus sampling threshold (default 0.9); only used when sampling.
            repetition_penalty: penalize repeated tokens (default 1.1).
            no_repeat_ngram_size: prevent repeating n-grams (default 4).

        Returns:
            Generated report text (str).

        Raises:
            ValueError: If ``slice_embeddings`` is not a 3-D tensor with batch size 1.
        """
        if slice_embeddings.ndim != 3 or slice_embeddings.shape[0] != 1:
            raise ValueError(
                f"slice_embeddings must have shape (1, num_slices, dim), got {tuple(slice_embeddings.shape)}"
            )
        # 1. Visual forward
        slices, mask = self._prepare_inputs(slice_embeddings, mask)
        visual_tokens = self.visual_encoder(slices, mask)
        visual_tokens = self._normalize_visual(visual_tokens.to(torch.bfloat16))
        # 2. Prompt and 3. grafting
        input_ids, attention_mask = self._build_prompt_ids()
        inputs_embeds, vis_mask = self._graft_visual_tokens(input_ids, visual_tokens)
        # 5. Decoding settings
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self._stop_token_ids(),
        )
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            gen_kwargs["do_sample"] = False
        # 4. Hooks are registered inside the try so they are always removed.
        # A leaked hook would keep injecting these visual tokens into later calls.
        hooks = []
        try:
            self._register_injection_hooks(hooks, visual_tokens, vis_mask, input_ids.shape[1])
            generated_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **gen_kwargs,
            )
        finally:
            for handle in hooks:
                handle.remove()
        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return self._clean_text(text)

    @torch.no_grad()
    def classify(
        self,
        slice_embeddings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Run the auxiliary 18-class abnormality classifier on slice embeddings.

        The visual tokens (not norm-rescaled) are mean-pooled over regions and
        passed to ``self._case_classifier``.

        Returns:
            A dict mapping each name in ``CLASS_NAMES`` to its sigmoid
            probability, rounded to 4 decimals.

        Raises:
            NotImplementedError: If no classifier head has been attached as
                ``_case_classifier``. The loader in this class does not build one.
        """
        classifier = getattr(self, "_case_classifier", None)
        if classifier is None:
            raise NotImplementedError(
                "classify() requires a case classifier head, which this package does not load; "
                f"attach a module mapping pooled (1, hidden) visual tokens to "
                f"{len(self.CLASS_NAMES)} logits as `model._case_classifier`."
            )
        slices, mask = self._prepare_inputs(slice_embeddings, mask)
        visual_tokens = self.visual_encoder(slices, mask)  # (1, 32, 3072)
        pooled = visual_tokens.mean(dim=1)  # (1, 3072)
        logits = classifier(pooled)  # (1, 18)
        # Sigmoid in float32 so the 4-decimal rounding is not dominated by bf16 error.
        probs = torch.sigmoid(logits.float()).squeeze(0).cpu().tolist()
        return {name: round(p, 4) for name, p in zip(self.CLASS_NAMES, probs)}
def load_model(llm_path: str, weights_dir: str = "weights", device: str = "cuda") -> KerVLJEPA:
    """Construct a KerVLJEPA instance and return it in evaluation mode.

    Args:
        llm_path: Directory of a local meta-llama/Llama-3.2-3B checkpoint.
        weights_dir: This package's ``weights/`` directory.
        device: Device string forwarded to the model (default ``"cuda"``).

    Returns:
        The loaded model, already switched to eval mode.
    """
    # nn.Module.eval() returns the module itself, so this can be chained.
    return KerVLJEPA(llm_path=llm_path, weights_dir=weights_dir, device=device).eval()