| """ |
| Ker-VLJEPA-3B — Inference-only model. |
| |
| Loads the Llama 3.2 3B backbone, applies LoRA adapters, visual encoder, |
| and cross-attention bridge components for CT report generation. |
| |
| Requires: |
| - A local copy of meta-llama/Llama-3.2-3B (user-provided) |
| - Weight files in weights/ (shipped with this package) |
| """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from pathlib import Path |
| from typing import Optional, List, Tuple, Dict |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
| from safetensors.torch import load_file |
|
|
|
|
| |
| |
| |
|
|
def sinusoidal_positional_encoding(n: int, d: int) -> torch.Tensor:
    """Build a fixed sinusoidal positional-encoding table.

    Even feature columns hold ``sin(pos * freq)`` and odd feature columns hold
    ``cos(pos * freq)``, where ``freq_i = exp(-2i * ln(10000) / d)``.

    Args:
        n: Number of positions (rows of the table).
        d: Feature dimension. Odd values are supported; the final column is
            then a sine column with no cosine partner.

    Returns:
        A float32 tensor of shape ``(1, n, d)``.
    """
    pe = torch.zeros(n, d)
    pos = torch.arange(n, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
    angles = pos * div
    pe[:, 0::2] = torch.sin(angles)
    # There are ceil(d / 2) frequencies but only floor(d / 2) odd columns, so
    # trim the angle table before filling the cosine half (matters for odd d).
    pe[:, 1::2] = torch.cos(angles[:, : d // 2])
    return pe.unsqueeze(0)
|
|
|
|
class ZonedCrossAttention(nn.Module):
    """Z-zoned cross-attention: each zone's queries attend only to their own slice range.

    The valid slices of every sample are split into ``num_zones`` contiguous
    windows along the slice axis. Each zone owns ``tokens_per_zone`` learned
    queries, and an attention mask restricts those queries to the keys inside
    the zone's window. Valid slices are assumed to occupy the leading
    positions of the slice axis (padding at the end).
    """

    def __init__(self, slice_dim=1024, hidden_dim=1024, num_zones=32,
                 tokens_per_zone=1, num_heads=16, dropout=0.1):
        """Create the zone queries, projections and attention module.

        Args:
            slice_dim: Feature size of the incoming slice embeddings and of
                the returned region tokens.
            hidden_dim: Internal attention width.
            num_zones: Number of contiguous zones along the slice axis.
            tokens_per_zone: Learned queries per zone.
            num_heads: Attention heads of the multi-head attention.
            dropout: Dropout probability inside the attention module.
        """
        super().__init__()
        self.num_zones = num_zones
        self.tokens_per_zone = tokens_per_zone
        self.num_regions = num_zones * tokens_per_zone
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim

        self.zone_queries = nn.Parameter(torch.zeros(num_zones, tokens_per_zone, hidden_dim))
        self.zone_embed = nn.Embedding(num_zones, hidden_dim)
        self.slice_proj = nn.Linear(slice_dim, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dropout=dropout)
        self.output_proj = nn.Linear(hidden_dim, slice_dim)
        self.fallback_pos_embed = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)

        # Frequencies of the sinusoidal slice-position encoding.
        div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim))
        self.register_buffer("div_term", div_term)

    @staticmethod
    def _valid_counts(mask, batch_size, num_slices):
        """Return the number of valid slices per sample as a list of Python ints.

        Counting is done on a boolean view of the mask so the result is exact
        for any mask dtype. Summing a bfloat16/float16 mask directly rounds
        the count once it exceeds the dtype's integer precision (e.g. 301
        becomes 300 in bfloat16).
        """
        if mask is None:
            return [num_slices] * batch_size
        return (mask != 0).sum(dim=1).tolist()

    def _add_z_pos(self, x_proj, mask):
        """Add a sinusoidal positional encoding based on slice index.

        Args:
            x_proj: ``(B, S, hidden_dim)`` projected slice features.
            mask: ``(B, S)`` validity mask (non-zero = valid) or ``None``.

        Returns:
            ``x_proj`` plus the positional encoding; padded positions receive
            no positional signal when a mask is given.
        """
        B, S, D = x_proj.shape
        device, dtype = x_proj.device, x_proj.dtype
        z_positions = torch.zeros(B, S, device=device)
        for b, n in enumerate(self._valid_counts(mask, B, S)):
            # NOTE(review): the constants look like a 2.5 mm slice spacing
            # normalised by a 600 mm extent and scaled by 100 -- confirm.
            z_positions[b, :n] = torch.arange(n, device=device, dtype=torch.float) * 2.5 / 600.0 * 100.0
        pos = z_positions.unsqueeze(-1)
        pe = torch.zeros(B, S, D, device=device, dtype=dtype)
        pe[:, :, 0::2] = torch.sin(pos * self.div_term)
        pe[:, :, 1::2] = torch.cos(pos * self.div_term)
        if mask is not None:
            # Cast the mask so the product keeps x_proj's dtype.
            pe = pe * mask.unsqueeze(-1).to(dtype)
        return x_proj + pe

    def forward(self, x, mask=None, metadata=None):
        """Pool variable-length slice features into one token per zone query.

        Args:
            x: ``(B, S, slice_dim)`` slice embeddings.
            mask: Optional ``(B, S)`` validity mask (non-zero = valid slice).
                Without a mask every slice is treated as valid and a learned
                fallback positional embedding is used instead of the
                sinusoidal one.
            metadata: Unused; accepted for interface compatibility.

        Returns:
            A tuple ``(regions, attn_weights)`` where ``regions`` has shape
            ``(B, num_zones * tokens_per_zone, slice_dim)`` and
            ``attn_weights`` are the weights returned by the attention module.
        """
        B, S, _ = x.shape
        x_proj = self.slice_proj(x)
        if mask is not None:
            x_proj = self._add_z_pos(x_proj, mask)
        else:
            x_proj = x_proj + self.fallback_pos_embed

        queries = self.zone_queries.reshape(self.num_regions, self.hidden_dim)
        zone_ids = torch.arange(self.num_zones, device=x.device)
        zone_emb = self.zone_embed(zone_ids).unsqueeze(1).expand(-1, self.tokens_per_zone, -1)
        queries = (queries + zone_emb.reshape(self.num_regions, self.hidden_dim)).unsqueeze(0).expand(B, -1, -1)

        # Boolean attention mask: True = key is NOT visible to that query.
        zone_mask = torch.ones(B, self.num_regions, S, device=x.device, dtype=torch.bool)
        for b, n in enumerate(self._valid_counts(mask, B, S)):
            zone_size = n / self.num_zones
            for z in range(self.num_zones):
                s = int(round(z * zone_size))
                e = min(max(int(round((z + 1) * zone_size)), s + 1), n)
                # With fewer valid slices than zones the rounded start can land
                # on n, leaving an empty window; a fully masked query row makes
                # the softmax return NaN. Pull the start back so every zone
                # sees at least one valid slice (when any exist).
                s = min(s, max(e - 1, 0))
                q_s, q_e = z * self.tokens_per_zone, (z + 1) * self.tokens_per_zone
                zone_mask[b, q_s:q_e, s:e] = False
        # nn.MultiheadAttention expects a 3-D mask of shape (B * num_heads, L, S).
        zone_mask = zone_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1).reshape(B * self.num_heads, self.num_regions, S)

        kpm = (mask == 0) if mask is not None else None
        region_hidden, attn_w = self.attention(queries, x_proj, x_proj, attn_mask=zone_mask, key_padding_mask=kpm, need_weights=True)
        return self.output_proj(region_hidden), attn_w
|
|
|
|
class VisualEncoder(nn.Module):
    """Compress variable-length slice embeddings into a fixed set of visual tokens in LLM space.

    Pipeline: zoned cross-attention pooling -> one transformer encoder layer
    over the pooled tokens -> linear projection + LayerNorm into the LLM
    hidden size -> multiplication by the ``_norm_scale`` buffer.
    """

    def __init__(self, slice_dim=1024, hidden_dim=1024, llm_dim=3072,
                 num_regions=32, num_zones=32, tokens_per_zone=1,
                 num_heads=16, dropout=0.1):
        """Build the encoder.

        Args:
            slice_dim: Feature size of the input slice embeddings.
            hidden_dim: Internal width of the zoned cross-attention.
            llm_dim: Hidden size of the language model the tokens feed into.
            num_regions: Accepted for interface compatibility; not read here.
            num_zones: Number of zones used by the zoned cross-attention.
            tokens_per_zone: Learned queries per zone.
            num_heads: Attention heads for both attention modules.
            dropout: Dropout probability for both attention modules.
        """
        super().__init__()
        self.region_query = ZonedCrossAttention(
            slice_dim=slice_dim,
            hidden_dim=hidden_dim,
            num_zones=num_zones,
            tokens_per_zone=tokens_per_zone,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.global_self_attn = nn.TransformerEncoderLayer(
            slice_dim,
            num_heads,
            dim_feedforward=4 * slice_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        # The two zero-probability Dropout entries keep the Linear at index 1
        # and the LayerNorm at index 2 of the Sequential.
        predictor_stages = [
            nn.Dropout(0.0),
            nn.Linear(slice_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.Dropout(0.0),
        ]
        self.jepa_predictor = nn.Sequential(*predictor_stages)
        self.norm_calibrator_scale = 1.0
        self.register_buffer("_norm_scale", torch.tensor(1.0))

    def forward(self, slices, mask=None):
        """Encode slice embeddings into visual tokens.

        Args:
            slices: ``(B, num_slices, slice_dim)`` pre-computed slice embeddings.
            mask: Optional ``(B, num_slices)`` binary mask, 1 = valid, 0 = pad.

        Returns:
            ``(B, num_zones * tokens_per_zone, llm_dim)`` visual tokens; with
            the default arguments this is ``(B, 32, 3072)``.
        """
        zone_tokens, _attn_weights = self.region_query(slices, mask)
        contextualised = self.global_self_attn(zone_tokens)
        return self.jepa_predictor(contextualised) * self._norm_scale
|
|
|
|
class GatedCrossAttentionLayer(nn.Module):
    """Flamingo-style cross-attention: text hidden states attend to visual tokens.

    Queries come from the layer-normalised text states; keys and values come
    from the visual tokens. The ``gate`` parameter is created (frozen, 0.0)
    but is not read by ``forward``.
    """

    def __init__(self, hidden_dim=3072, num_heads=16):
        """Create the bias-free projections, query LayerNorm and frozen gate.

        Args:
            hidden_dim: Width of both the text and the visual token streams.
            num_heads: Number of attention heads; ``hidden_dim`` is split
                evenly across them.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.q_norm = nn.LayerNorm(hidden_dim)
        self.gate = nn.Parameter(torch.tensor(0.0), requires_grad=False)

    def _to_heads(self, projected, batch):
        """Reshape ``(B, T, hidden_dim)`` into ``(B, num_heads, T, head_dim)``."""
        per_head = projected.view(batch, -1, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, text_hidden, visual_tokens):
        """Return the cross-attention update for the text stream.

        Args:
            text_hidden: ``(B, S, hidden_dim)`` text hidden states.
            visual_tokens: ``(B, V, hidden_dim)`` visual tokens.

        Returns:
            ``(B, S, hidden_dim)`` tensor in the dtype of the projection
            weights. The residual addition is left to the caller.
        """
        batch, text_len, _ = text_hidden.shape
        # Run everything in the dtype the projection weights are stored in.
        compute_dtype = self.q_proj.weight.dtype
        text_hidden = text_hidden.to(compute_dtype)
        visual_tokens = visual_tokens.to(compute_dtype)

        queries = self._to_heads(self.q_proj(self.q_norm(text_hidden)), batch)
        keys = self._to_heads(self.k_proj(visual_tokens), batch)
        values = self._to_heads(self.v_proj(visual_tokens), batch)

        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=False)
        merged = attended.transpose(1, 2).contiguous().view(batch, text_len, self.hidden_dim)
        return self.o_proj(merged)
|
|
|
|
| |
| |
| |
|
|
class KerVLJEPA(nn.Module):
    """
    Ker-VLJEPA-3B inference model.

    Takes pre-computed LeJEPA slice embeddings (1024-d per slice) and generates
    free-text radiology reports using Llama 3.2 3B + LoRA + cross-attention bridge.

    Visual information reaches the language model in two ways: the visual
    tokens replace the embeddings of ``VISUAL_TOKEN`` placeholders in the
    prompt, and forward hooks on the decoder layers listed in
    ``INJECTION_LAYERS`` add a projected copy of the tokens plus a
    cross-attention update to the hidden states.
    """

    INJECTION_LAYERS = [7, 14, 21]
    NUM_REGIONS = 32
    VISUAL_TOKEN = "<|visual_region|>"

    def __init__(self, llm_path: str, weights_dir: str = "weights", device: str = "cuda"):
        """Load tokenizer, LLM + LoRA adapters, visual encoder and bridge weights.

        Args:
            llm_path: Directory (or hub id) of the base language model.
            weights_dir: Directory containing ``tokenizer/``, ``lora_adapters/``,
                ``visual_encoder.safetensors`` and ``bridge_components.safetensors``.
            device: Device string the model components are placed on.
        """
        super().__init__()
        weights_dir = Path(weights_dir)
        self.device = torch.device(device)

        # --- Tokenizer: loaded from the base model, extended with the visual token.
        tok_dir = weights_dir / "tokenizer"
        self.tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True, padding_side="left",
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.add_tokens([self.VISUAL_TOKEN], special_tokens=True)
        self.vis_token_id = self.tokenizer.convert_tokens_to_ids(self.VISUAL_TOKEN)

        # Optional chat template shipped next to the weights overrides the default.
        tmpl_path = tok_dir / "chat_template.jinja"
        if tmpl_path.exists():
            self.tokenizer.chat_template = tmpl_path.read_text()

        # --- Base LLM; the embedding table is resized for the added token
        # before the LoRA adapters are applied.
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_path, torch_dtype=torch.bfloat16,
            device_map=device, trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
        self.llm.resize_token_embeddings(len(self.tokenizer))

        lora_dir = weights_dir / "lora_adapters"
        self.llm = PeftModel.from_pretrained(self.llm, str(lora_dir), is_trainable=False)
        self.llm.eval()

        llm_dim = self.llm.config.hidden_size

        # --- Visual encoder.
        self.visual_encoder = VisualEncoder(
            slice_dim=1024, hidden_dim=1024, llm_dim=llm_dim,
            num_regions=self.NUM_REGIONS, num_zones=32, tokens_per_zone=1,
            num_heads=16, dropout=0.0,
        )
        ve_state = load_file(str(weights_dir / "visual_encoder.safetensors"))

        # Checkpoint entries under these prefixes have no counterpart in the
        # inference-only VisualEncoder and are dropped.
        skipped_prefixes = (
            "norm_calibrator.", "region_classifier.",
            "slice_organ_classifier.", "output_ln.",
        )
        mapped = {}
        for k, v in ve_state.items():
            if k == "norm_calibrator.scale":
                # The calibrator scale is kept as the encoder's output scale.
                self.visual_encoder._norm_scale = v.clone()
                continue
            if k.startswith(skipped_prefixes) or k == "_last_attention_weights":
                continue
            mapped[k] = v
        # strict=False: keys missing from the checkpoint keep their initial values.
        self.visual_encoder.load_state_dict(mapped, strict=False)
        self.visual_encoder = self.visual_encoder.to(self.device).to(torch.bfloat16)
        self.visual_encoder.eval()

        # --- Bridge components.
        bridge = load_file(str(weights_dir / "bridge_components.safetensors"))

        # Target norm the visual tokens are rescaled to (see _normalize_visual).
        self.register_buffer("text_embed_norm", bridge["text_embed_norm"].to(self.device))

        self.layernorm = nn.LayerNorm(llm_dim).to(self.device).to(torch.bfloat16)
        self.layernorm.weight.data.copy_(bridge["layernorm.weight"])
        self.layernorm.bias.data.copy_(bridge["layernorm.bias"])

        # One cross-attention adapter and one linear projector per injection layer.
        self.cross_attn_adapters = nn.ModuleDict()
        self.layer_projectors = nn.ModuleDict()
        for layer_idx in self.INJECTION_LAYERS:
            li = str(layer_idx)
            adapter = GatedCrossAttentionLayer(llm_dim, 16)
            adapter.q_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.q_proj.weight"])
            adapter.k_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.k_proj.weight"])
            adapter.v_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.v_proj.weight"])
            adapter.o_proj.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.o_proj.weight"])
            adapter.q_norm.weight.data.copy_(bridge[f"cross_attn_adapters.{li}.q_norm.weight"])
            adapter.q_norm.bias.data.copy_(bridge[f"cross_attn_adapters.{li}.q_norm.bias"])
            self.cross_attn_adapters[li] = adapter.to(self.device).to(torch.bfloat16)

            proj = nn.Linear(llm_dim, llm_dim, bias=False)
            proj.weight.data.copy_(bridge[f"layer_projectors.{li}.weight"])
            self.layer_projectors[li] = proj.to(self.device).to(torch.bfloat16)

    def _get_llm_layers(self):
        """Unwrap the PEFT/HF wrappers and return the decoder layer list."""
        model = self.llm
        # Bounded descent through `.base_model` / `.model` wrappers.
        for _ in range(10):
            if hasattr(model, "base_model"):
                model = model.base_model
            elif hasattr(model, "model"):
                model = model.model
            else:
                break
        return model.layers

    def _get_embed_layer(self):
        """Return the input embedding module of the (wrapped) language model.

        Raises:
            RuntimeError: If no wrapper level exposes an embedding layer.
        """
        model = self.llm
        for _ in range(10):
            if hasattr(model, "get_input_embeddings"):
                emb = model.get_input_embeddings()
                if emb is not None:
                    return emb
            if hasattr(model, "base_model"):
                model = model.base_model
            elif hasattr(model, "model"):
                model = model.model
            else:
                break
        raise RuntimeError("Could not find embedding layer")

    def _normalize_visual(self, visual_embeds):
        """Rescale every visual token to the stored ``text_embed_norm``."""
        norms = visual_embeds.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        return visual_embeds / norms * self.text_embed_norm

    def _encode_visual(self, slice_embeddings, mask):
        """Run the visual encoder on device-resident bfloat16 slices.

        The mask is moved as an integer tensor. A bfloat16 mask cannot
        represent odd slice counts above 256, so summing it to obtain the
        number of valid slices would silently drop the last slice.
        """
        slices = slice_embeddings.to(self.device, dtype=torch.bfloat16)
        if mask is not None:
            mask = mask.to(self.device, dtype=torch.long)
        return self.visual_encoder(slices, mask)

    @staticmethod
    def _make_injection_hook(projector, adapter, v_tokens, v_mask, prompt_len):
        """Build a forward hook that injects visual information into a decoder layer.

        On the prompt pass (sequence length equals ``prompt_len``) the
        projected visual tokens are added at the placeholder positions. On
        every pass the cross-attention update over the visual tokens is added
        to the layer's hidden states.
        """
        def hook_fn(module, args, output):
            is_tuple = isinstance(output, tuple)
            hidden = output[0] if is_tuple else output

            if hidden.shape[1] == prompt_len:
                projected = projector(v_tokens.to(hidden.dtype))
                vis_pos = v_mask[0].nonzero(as_tuple=True)[0]
                hidden[0, vis_pos] = hidden[0, vis_pos] + projected[0, :len(vis_pos)]

            xattn_out = adapter(hidden, v_tokens.to(hidden.dtype))
            modified = hidden + xattn_out.to(hidden.dtype)
            return (modified,) + output[1:] if is_tuple else modified
        return hook_fn

    @staticmethod
    def _clean_text(text: str) -> str:
        """Trim repeated sections and leaked chat-role markers from generated text.

        The text is cut at a second occurrence of "findings:" (case-insensitive)
        and at the first occurrence of each marker that appears after
        character 20.
        """
        text = text.strip()
        if not text:
            return text
        low = text.lower()
        first = low.find("findings:")
        if first >= 0:
            second = low.find("findings:", first + 10)
            if second > 0:
                text = text[:second].strip()
        # NOTE(review): markers are matched as plain substrings, so a report
        # that legitimately contains e.g. the word "system" is truncated there.
        for marker in ["\n\n\n", "User", "user", "assistant", "system"]:
            idx = text.find(marker)
            if idx > 20:
                text = text[:idx].strip()
        return text

    @torch.no_grad()
    def generate(
        self,
        slice_embeddings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 384,
        temperature: float = 0.6,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        no_repeat_ngram_size: int = 4,
    ) -> str:
        """
        Generate a radiology report from pre-computed LeJEPA slice embeddings.

        Args:
            slice_embeddings: (1, num_slices, 1024) — stacked 1024-d embeddings,
                one per CT slice. Padding is allowed.
            mask: (1, num_slices) — binary mask where 1=real slice, 0=padding.
                If None, all slices are treated as valid.
            max_new_tokens: maximum tokens to generate (default 384).
            temperature: sampling temperature (default 0.6).
            top_p: nucleus sampling threshold (default 0.9).
            repetition_penalty: penalize repeated tokens (default 1.1).
            no_repeat_ngram_size: prevent repeating n-grams (default 4).

        Returns:
            Generated report text (str).

        Raises:
            ValueError: If ``slice_embeddings`` is not a 3-D tensor with batch size 1.
            RuntimeError: If the tokenized prompt does not contain exactly
                ``NUM_REGIONS`` visual placeholder tokens.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if slice_embeddings.ndim != 3 or slice_embeddings.shape[0] != 1:
            raise ValueError(
                "slice_embeddings must have shape (1, num_slices, dim), "
                f"got {tuple(slice_embeddings.shape)}"
            )

        # --- Visual tokens, rescaled to the text-embedding norm.
        visual_tokens = self._encode_visual(slice_embeddings, mask)
        visual_tokens = self._normalize_visual(visual_tokens.to(torch.bfloat16))

        # --- Prompt with one placeholder per visual token.
        placeholders = self.VISUAL_TOKEN * self.NUM_REGIONS
        messages = [
            {"role": "system", "content": "You are a radiology reporting assistant. Describe thoracic findings based on the provided CT scan visual features. Report only what you observe."},
            {"role": "user", "content": f"Based on the visual features from this CT scan, describe the thoracic findings. {placeholders}"},
        ]
        tokenized = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
        # Depending on the tokenizer version this is a plain tensor or a
        # dict-like encoding.
        if hasattr(tokenized, "input_ids"):
            input_ids = tokenized["input_ids"].to(self.device)
            attention_mask = tokenized["attention_mask"].to(self.device)
        else:
            input_ids = tokenized.to(self.device)
            attention_mask = torch.ones_like(input_ids)

        # --- Replace the placeholder embeddings with the visual tokens.
        embed_layer = self._get_embed_layer()
        inputs_embeds = embed_layer(input_ids).clone()
        vis_mask = (input_ids == self.vis_token_id)
        vis_positions = vis_mask[0].nonzero(as_tuple=True)[0]
        if len(vis_positions) != self.NUM_REGIONS:
            raise RuntimeError(
                f"Expected {self.NUM_REGIONS} visual tokens, found {len(vis_positions)}"
            )
        inputs_embeds[0, vis_positions] = visual_tokens[0, :len(vis_positions)].to(inputs_embeds.dtype)

        # --- Stop tokens: EOS plus the chat end-of-turn / header markers.
        eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        start_header = self.tokenizer.convert_tokens_to_ids("<|start_header_id|>")
        stop_ids = [self.tokenizer.eos_token_id]
        if eot_id is not None and eot_id != self.tokenizer.eos_token_id:
            stop_ids.append(eot_id)
        if start_header is not None and start_header not in stop_ids:
            stop_ids.append(start_header)

        # --- Register the injection hooks inside the try block so they are
        # always removed, even if registration or generation fails part-way.
        hooks = []
        try:
            llm_layers = self._get_llm_layers()
            prompt_len = input_ids.shape[1]
            for layer_idx in self.INJECTION_LAYERS:
                li = str(layer_idx)
                hook = self._make_injection_hook(
                    self.layer_projectors[li], self.cross_attn_adapters[li],
                    visual_tokens, vis_mask, prompt_len,
                )
                hooks.append(llm_layers[layer_idx].register_forward_hook(hook))

            generated_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=stop_ids,
            )
        finally:
            for h in hooks:
                h.remove()

        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return self._clean_text(text)

    @torch.no_grad()
    def classify(
        self,
        slice_embeddings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Run the auxiliary 18-class abnormality classifier on slice embeddings.

        Returns a dict mapping each CT-RATE condition name to its sigmoid probability.

        Raises:
            RuntimeError: If no classifier head is attached as
                ``_case_classifier``. The constructor does not create one, so
                the head must be attached before calling this method.
        """
        CLASS_NAMES = [
            "Medical material", "Arterial wall calcification",
            "Cardiomegaly", "Pericardial effusion",
            "Coronary artery wall calcification", "Hiatal hernia",
            "Lymphadenopathy", "Emphysema",
            "Atelectasis", "Lung nodule",
            "Lung opacity", "Pulmonary fibrotic sequela",
            "Pleural effusion", "Mosaic attenuation pattern",
            "Peribronchial thickening", "Consolidation",
            "Bronchiectasis", "Interlobular septal thickening",
        ]
        # Fail fast with a clear message instead of an opaque AttributeError
        # raised only after the encoder has already run.
        classifier = getattr(self, "_case_classifier", None)
        if classifier is None:
            raise RuntimeError(
                "classify() requires a classifier head attached as "
                "'_case_classifier'; none was loaded for this model."
            )

        visual_tokens = self._encode_visual(slice_embeddings, mask)
        pooled = visual_tokens.mean(dim=1)

        logits = classifier(pooled)
        probs = torch.sigmoid(logits).squeeze(0).cpu().tolist()
        return {name: round(p, 4) for name, p in zip(CLASS_NAMES, probs)}
|
|
|
|
def load_model(llm_path: str, weights_dir: str = "weights", device: str = "cuda") -> KerVLJEPA:
    """
    Build a Ker-VLJEPA-3B model and switch it to evaluation mode.

    Args:
        llm_path: Path to local Llama 3.2 3B model directory.
        weights_dir: Path to the weights/ directory from this package.
        device: Device string forwarded to the model (default "cuda").

    Returns:
        The constructed KerVLJEPA instance, in eval mode.
    """
    # nn.Module.eval() returns the module itself, so the call can be chained.
    return KerVLJEPA(
        llm_path=llm_path,
        weights_dir=weights_dir,
        device=device,
    ).eval()
|
|