""" Ker-VLJEPA-3B — Inference-only model. Loads the Llama 3.2 3B backbone, applies LoRA adapters, visual encoder, and cross-attention bridge components for CT report generation. Requires: - A local copy of meta-llama/Llama-3.2-3B (user-provided) - Weight files in weights/ (shipped with this package) """ import math import torch import torch.nn as nn import torch.nn.functional as F from pathlib import Path from typing import Optional, List, Tuple, Dict from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel from safetensors.torch import load_file # --------------------------------------------------------------------------- # Visual encoder components (inference-only, no training heads) # --------------------------------------------------------------------------- def sinusoidal_positional_encoding(n: int, d: int) -> torch.Tensor: pe = torch.zeros(n, d) pos = torch.arange(n, dtype=torch.float).unsqueeze(1) div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d)) pe[:, 0::2] = torch.sin(pos * div) pe[:, 1::2] = torch.cos(pos * div) return pe.unsqueeze(0) class ZonedCrossAttention(nn.Module): """Z-Zoned cross-attention: each zone's queries attend only to their spatial slice range.""" def __init__(self, slice_dim=1024, hidden_dim=1024, num_zones=32, tokens_per_zone=1, num_heads=16, dropout=0.1): super().__init__() self.num_zones = num_zones self.tokens_per_zone = tokens_per_zone self.num_regions = num_zones * tokens_per_zone self.num_heads = num_heads self.hidden_dim = hidden_dim self.zone_queries = nn.Parameter(torch.zeros(num_zones, tokens_per_zone, hidden_dim)) self.zone_embed = nn.Embedding(num_zones, hidden_dim) self.slice_proj = nn.Linear(slice_dim, hidden_dim) self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=dropout) self.output_proj = nn.Linear(hidden_dim, slice_dim) self.fallback_pos_embed = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02) # Physical Z-positional encoding buffers div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)) self.register_buffer("div_term", div_term) def _add_z_pos(self, x_proj, mask): """Sinusoidal positional encoding based on slice index.""" B, S, D = x_proj.shape device, dtype = x_proj.device, x_proj.dtype z_positions = torch.zeros(B, S, device=device) for b in range(B): n = int(mask[b].sum().item()) if mask is not None else S pos = torch.arange(n, device=device, dtype=torch.float) * 2.5 / 600.0 * 100.0 z_positions[b, :n] = pos pos = z_positions.unsqueeze(-1) pe = torch.zeros(B, S, D, device=device, dtype=dtype) pe[:, :, 0::2] = torch.sin(pos * self.div_term) pe[:, :, 1::2] = torch.cos(pos * self.div_term) if mask is not None: pe = pe * mask.unsqueeze(-1) return x_proj + pe def forward(self, x, mask=None, metadata=None): B, S, _ = x.shape x_proj = self.slice_proj(x) if mask is not None: x_proj = self._add_z_pos(x_proj, mask) else: x_proj = x_proj + self.fallback_pos_embed queries = self.zone_queries.reshape(self.num_regions, self.hidden_dim) zone_ids = torch.arange(self.num_zones, device=x.device) zone_emb = self.zone_embed(zone_ids).unsqueeze(1).expand(-1, self.tokens_per_zone, -1) queries = (queries + zone_emb.reshape(self.num_regions, self.hidden_dim)).unsqueeze(0).expand(B, -1, -1) # Zone mask zone_mask = torch.ones(B, self.num_regions, S, device=x.device, dtype=torch.bool) for b in range(B): n = int(mask[b].sum().item()) if mask is not None else S zone_size = n / self.num_zones for z in range(self.num_zones): s, e = int(round(z * zone_size)), min(max(int(round((z + 1) * zone_size)), int(round(z * zone_size)) + 1), n) q_s, q_e = z * self.tokens_per_zone, (z + 1) * self.tokens_per_zone zone_mask[b, q_s:q_e, s:e] = False zone_mask = zone_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1).reshape(B * self.num_heads, self.num_regions, S) kpm = (mask == 0) if mask is not None else None region_hidden, attn_w = self.attention(queries, x_proj, x_proj, attn_mask=zone_mask, key_padding_mask=kpm, need_weights=True) return self.output_proj(region_hidden), attn_w class VisualEncoder(nn.Module): """Compresses variable-length slice embeddings into fixed visual tokens in LLM space.""" def __init__(self, slice_dim=1024, hidden_dim=1024, llm_dim=3072, num_regions=32, num_zones=32, tokens_per_zone=1, num_heads=16, dropout=0.1): super().__init__() self.region_query = ZonedCrossAttention(slice_dim, hidden_dim, num_zones, tokens_per_zone, num_heads, dropout) self.global_self_attn = nn.TransformerEncoderLayer( d_model=slice_dim, nhead=num_heads, batch_first=True, dropout=dropout, dim_feedforward=slice_dim * 4, activation="gelu", ) self.jepa_predictor = nn.Sequential( nn.Dropout(0.0), # disabled at inference nn.Linear(slice_dim, llm_dim), nn.LayerNorm(llm_dim), nn.Dropout(0.0), ) self.norm_calibrator_scale = 1.0 # set from checkpoint buffer self.register_buffer("_norm_scale", torch.tensor(1.0)) def forward(self, slices, mask=None): """ Args: slices: (B, num_slices, 1024) pre-computed LeJEPA embeddings mask: (B, num_slices) binary mask, 1=valid 0=pad Returns: (B, 32, 3072) visual tokens in LLM hidden space """ regions, _ = self.region_query(slices, mask) regions = self.global_self_attn(regions) predicted = self.jepa_predictor(regions) return predicted * self._norm_scale class GatedCrossAttentionLayer(nn.Module): """Flamingo-style cross-attention: text hidden states attend to visual tokens.""" def __init__(self, hidden_dim=3072, num_heads=16): super().__init__() self.hidden_dim = hidden_dim self.num_heads = num_heads self.head_dim = hidden_dim // num_heads self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.q_norm = nn.LayerNorm(hidden_dim) self.gate = nn.Parameter(torch.tensor(0.0), requires_grad=False) def forward(self, text_hidden, visual_tokens): B, S, _ = text_hidden.shape dt = self.q_proj.weight.dtype text_hidden = text_hidden.to(dt) visual_tokens = visual_tokens.to(dt) q = self.q_proj(self.q_norm(text_hidden)).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(visual_tokens).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(visual_tokens).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) out = F.scaled_dot_product_attention(q, k, v, is_causal=False) return self.o_proj(out.transpose(1, 2).contiguous().view(B, S, self.hidden_dim)) # --------------------------------------------------------------------------- # Main inference model # --------------------------------------------------------------------------- class KerVLJEPA(nn.Module): """ Ker-VLJEPA-3B inference model. Takes pre-computed LeJEPA slice embeddings (1024-d per slice) and generates free-text radiology reports using Llama 3.2 3B + LoRA + cross-attention bridge. """ INJECTION_LAYERS = [7, 14, 21] NUM_REGIONS = 32 VISUAL_TOKEN = "<|visual_region|>" def __init__(self, llm_path: str, weights_dir: str = "weights", device: str = "cuda"): super().__init__() weights_dir = Path(weights_dir) self.device = torch.device(device) # --- 1. Load tokenizer --- tok_dir = weights_dir / "tokenizer" self.tokenizer = AutoTokenizer.from_pretrained( llm_path, trust_remote_code=True, padding_side="left", ) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.add_tokens([self.VISUAL_TOKEN], special_tokens=True) self.vis_token_id = self.tokenizer.convert_tokens_to_ids(self.VISUAL_TOKEN) # Load custom chat template tmpl_path = tok_dir / "chat_template.jinja" if tmpl_path.exists(): self.tokenizer.chat_template = tmpl_path.read_text() # --- 2. Load Llama 3.2 3B + LoRA --- self.llm = AutoModelForCausalLM.from_pretrained( llm_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True, attn_implementation="flash_attention_2", ) self.llm.resize_token_embeddings(len(self.tokenizer)) lora_dir = weights_dir / "lora_adapters" self.llm = PeftModel.from_pretrained(self.llm, str(lora_dir), is_trainable=False) self.llm.eval() llm_dim = self.llm.config.hidden_size # 3072 # --- 3. Visual encoder --- self.visual_encoder = VisualEncoder( slice_dim=1024, hidden_dim=1024, llm_dim=llm_dim, num_regions=self.NUM_REGIONS, num_zones=32, tokens_per_zone=1, num_heads=16, dropout=0.0, ) ve_state = load_file(str(weights_dir / "visual_encoder.safetensors")) # Map checkpoint keys (which include output_ln, norm_calibrator, classifiers) # into our simplified VisualEncoder mapped = {} for k, v in ve_state.items(): if k == "norm_calibrator.scale": self.visual_encoder._norm_scale = v.clone() continue if k.startswith("norm_calibrator.") or k.startswith("region_classifier.") or \ k.startswith("slice_organ_classifier.") or k == "_last_attention_weights" or \ k.startswith("output_ln."): continue mapped[k] = v self.visual_encoder.load_state_dict(mapped, strict=False) self.visual_encoder = self.visual_encoder.to(self.device).to(torch.bfloat16) self.visual_encoder.eval() # --- 4. Bridge components --- bridge = load_file(str(weights_dir / "bridge_components.safetensors")) # Text embedding norm (for grafting normalization) self.register_buffer("text_embed_norm", bridge["text_embed_norm"].to(self.device)) # LayerNorm for grafting self.layernorm = nn.LayerNorm(llm_dim).to(self.device).to(torch.bfloat16) self.layernorm.weight.data.copy_(bridge["layernorm.weight"]) self.layernorm.bias.data.copy_(bridge["layernorm.bias"]) # Cross-attention adapters + layer projectors self.cross_attn_adapters = nn.ModuleDict() self.layer_projectors = nn.ModuleDict() for layer_idx in self.INJECTION_LAYERS: li = str(layer_idx) adapter = GatedCrossAttentionLayer(llm_dim, 16) adapter.q_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.q_proj.weight"]) adapter.k_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.k_proj.weight"]) adapter.v_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.v_proj.weight"]) adapter.o_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.o_proj.weight"]) adapter.q_norm.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.q_norm.weight"]) adapter.q_norm.bias.data.copy_(bridge[f"cross_attn_adapters.{li}.q_norm.bias"]) self.cross_attn_adapters[li] = adapter.to(self.device).to(torch.bfloat16) proj = nn.Linear(llm_dim, llm_dim, bias=False) proj.weight.data.copy_(bridge[f"layer_projectors.{li}.weight"]) self.layer_projectors[li] = proj.to(self.device).to(torch.bfloat16) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _get_llm_layers(self): model = self.llm for _ in range(10): if hasattr(model, "base_model"): model = model.base_model elif hasattr(model, "model"): model = model.model else: break return model.layers def _get_embed_layer(self): model = self.llm for _ in range(10): if hasattr(model, "get_input_embeddings"): emb = model.get_input_embeddings() if emb is not None: return emb if hasattr(model, "base_model"): model = model.base_model elif hasattr(model, "model"): model = model.model else: break raise RuntimeError("Could not find embedding layer") def _normalize_visual(self, visual_embeds): norms = visual_embeds.norm(dim=-1, keepdim=True).clamp(min=1e-8) return visual_embeds / norms * self.text_embed_norm @staticmethod def _clean_text(text: str) -> str: text = text.strip() if not text: return text low = text.lower() first = low.find("findings:") if first >= 0: second = low.find("findings:", first + 10) if second > 0: text = text[:second].strip() for marker in ["\n\n\n", "User", "user", "assistant", "system"]: idx = text.find(marker) if idx > 20: text = text[:idx].strip() return text # ------------------------------------------------------------------ # Generation # ------------------------------------------------------------------ @torch.no_grad() def generate( self, slice_embeddings: torch.Tensor, mask: Optional[torch.Tensor] = None, max_new_tokens: int = 384, temperature: float = 0.6, top_p: float = 0.9, repetition_penalty: float = 1.1, no_repeat_ngram_size: int = 4, ) -> str: """ Generate a radiology report from pre-computed LeJEPA slice embeddings. Args: slice_embeddings: (1, num_slices, 1024) — stacked 1024-d embeddings, one per CT slice. Padding is allowed. mask: (1, num_slices) — binary mask where 1=real slice, 0=padding. If None, all slices are treated as valid. max_new_tokens: maximum tokens to generate (default 384). temperature: sampling temperature (default 0.6). top_p: nucleus sampling threshold (default 0.9). repetition_penalty: penalize repeated tokens (default 1.1). no_repeat_ngram_size: prevent repeating n-grams (default 4). Returns: Generated report text (str). """ assert slice_embeddings.ndim == 3 and slice_embeddings.shape[0] == 1 # 1. Visual forward slices = slice_embeddings.to(self.device, dtype=torch.bfloat16) if mask is not None: mask = mask.to(self.device, dtype=torch.bfloat16) visual_tokens = self.visual_encoder(slices, mask) visual_tokens = self._normalize_visual(visual_tokens.to(torch.bfloat16)) # 2. Build prompt placeholders = self.VISUAL_TOKEN * self.NUM_REGIONS messages = [ {"role": "system", "content": "You are a radiology reporting assistant. Describe thoracic findings based on the provided CT scan visual features. Report only what you observe."}, {"role": "user", "content": f"Based on the visual features from this CT scan, describe the thoracic findings. {placeholders}"}, ] tokenized = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" ) if hasattr(tokenized, "input_ids"): input_ids = tokenized["input_ids"].to(self.device) attention_mask = tokenized["attention_mask"].to(self.device) else: input_ids = tokenized.to(self.device) attention_mask = torch.ones_like(input_ids) # 3. Graft visual tokens into embedding sequence embed_layer = self._get_embed_layer() inputs_embeds = embed_layer(input_ids).clone() vis_mask = (input_ids == self.vis_token_id) vis_positions = vis_mask[0].nonzero(as_tuple=True)[0] assert len(vis_positions) == self.NUM_REGIONS, \ f"Expected {self.NUM_REGIONS} visual tokens, found {len(vis_positions)}" for idx, pos in enumerate(vis_positions): inputs_embeds[0, pos] = visual_tokens[0, idx].to(inputs_embeds.dtype) # 4. Register cross-attention hooks hooks = [] llm_layers = self._get_llm_layers() for layer_idx in self.INJECTION_LAYERS: li = str(layer_idx) proj = self.layer_projectors[li] adapter = self.cross_attn_adapters[li] def make_hook(p, a, v_tokens, v_mask, seq_len): def hook_fn(module, args, output): hidden = output[0] if isinstance(output, tuple) else output # Additive injection at visual positions (prefill only) if hidden.shape[1] == seq_len: projected = p(v_tokens.to(hidden.dtype)) vis_pos = v_mask[0].nonzero(as_tuple=True)[0] hidden[0, vis_pos] = hidden[0, vis_pos] + projected[0, :len(vis_pos)] # Cross-attention (every token) xattn_out = a(hidden, v_tokens.to(hidden.dtype)) modified = hidden + xattn_out.to(hidden.dtype) return (modified,) + output[1:] if isinstance(output, tuple) else modified return hook_fn h = llm_layers[layer_idx].register_forward_hook( make_hook(proj, adapter, visual_tokens, vis_mask, input_ids.shape[1]) ) hooks.append(h) # 5. Generate eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") start_header = self.tokenizer.convert_tokens_to_ids("<|start_header_id|>") stop_ids = [self.tokenizer.eos_token_id] if eot_id is not None and eot_id != self.tokenizer.eos_token_id: stop_ids.append(eot_id) if start_header is not None and start_header not in stop_ids: stop_ids.append(start_header) try: generated_ids = self.llm.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, do_sample=True, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=stop_ids, ) finally: for h in hooks: h.remove() text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) return self._clean_text(text) @torch.no_grad() def classify( self, slice_embeddings: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> Dict[str, float]: """ Run the auxiliary 18-class abnormality classifier on slice embeddings. Returns a dict mapping each CT-RATE condition name to its sigmoid probability. """ CLASS_NAMES = [ "Medical material", "Arterial wall calcification", "Cardiomegaly", "Pericardial effusion", "Coronary artery wall calcification", "Hiatal hernia", "Lymphadenopathy", "Emphysema", "Atelectasis", "Lung nodule", "Lung opacity", "Pulmonary fibrotic sequela", "Pleural effusion", "Mosaic attenuation pattern", "Peribronchial thickening", "Consolidation", "Bronchiectasis", "Interlobular septal thickening", ] slices = slice_embeddings.to(self.device, dtype=torch.bfloat16) if mask is not None: mask = mask.to(self.device, dtype=torch.bfloat16) visual_tokens = self.visual_encoder(slices, mask) # (1, 32, 3072) pooled = visual_tokens.mean(dim=1) # (1, 3072) # case_classifier is loaded from bridge but we need to add it logits = self._case_classifier(pooled) # (1, 18) probs = torch.sigmoid(logits).squeeze(0).cpu().tolist() return {name: round(p, 4) for name, p in zip(CLASS_NAMES, probs)} def load_model(llm_path: str, weights_dir: str = "weights", device: str = "cuda") -> KerVLJEPA: """ Load the Ker-VLJEPA-3B model for inference. Args: llm_path: Path to local Llama 3.2 3B model directory. weights_dir: Path to the weights/ directory from this package. device: CUDA device string (default "cuda"). Returns: Ready-to-use KerVLJEPA model instance. """ model = KerVLJEPA(llm_path=llm_path, weights_dir=weights_dir, device=device) model.eval() return model