# Vbai-2.6AD

## Description
Vbai-2.6AD is a 3D brain MRI segmentation model and the latest-generation member of the Vbai model family. Unlike previous versions, Vbai-2.6AD is a multimodal model designed for professional research purposes, capable of working with NIfTI files and blood values. The Vbai-3D versions have been integrated into the standard Vbai versions.
The model generates voxel-level segmentation masks instead of image-level labels, providing the spatial localization of pathological regions in addition to quantitative tissue-volume measurements. It aims to reach an accurate diagnosis by placing the imaging findings in context with blood values and patient information.
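To illustrate how a voxel-level mask translates into a tissue-volume measurement, here is a minimal sketch (not part of the model code; the mask and spacing values are illustrative, with spacing as it would come from a NIfTI header):

```python
import numpy as np

# Hypothetical example: derive a tissue volume from a voxel-level
# segmentation mask and the scan's voxel spacing.
mask = np.zeros((96, 96, 96), dtype=bool)
mask[30:60, 30:60, 30:60] = True            # 30*30*30 = 27,000 segmented voxels

voxel_spacing_mm = (1.0, 1.0, 1.0)          # e.g. from the NIfTI header zooms
voxel_volume_mm3 = float(np.prod(voxel_spacing_mm))

volume_ml = mask.sum() * voxel_volume_mm3 / 1000.0  # 1 mL = 1000 mm^3
print(f"Segmented tissue volume: {volume_ml:.1f} mL")  # 27.0 mL
```

In practice the spacing is read per scan (it is rarely exactly 1 mm isotropic), so the mask count alone is not comparable across scans.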
Vbai-2.6AD also serves as the core engine of the HealFuture image processing library and can run each diagnostic task independently or in combination, depending on the clinical use case. This model has been trained exclusively for the early diagnosis of atrophy and Alzheimer’s disease.
## Audience / Target
Vbai models are developed exclusively for hospitals, universities, communities, health centres, and science centres.
## Architecture
This model is a multimodal deep learning architecture that combines 3D MRI data and clinical biomarkers for Alzheimer's classification and progression prediction.
| Input Modality | Encoder | Fusion Layer | Prediction Heads (Output) |
|---|---|---|---|
| 3D MRI Volume (96×96×96, 1 ch) | 3D ResNet + CBAM + SE Blocks + ASPP | Bidirectional Cross-Attention (MRI ↔ Tabular) + Gated Residuals | → Diagnosis: CN / MCI / AD classification |
| Clinical Data (13 features + masks) | MLP (LayerNorm + GELU) + Linear Projection | | → Progression: MCI-to-AD conversion & timeline estimation |
| Auxiliary Supervision (training mode) | Contrastive Learning Alignment + Modality-specific Logits | | |
- MRI Encoder: Custom 3D ResNet architecture featuring ASPP (Atrous Spatial Pyramid Pooling) for capturing multi-scale context and a dual-attention mechanism (CBAM/SE) for spatial focus.
- Tabular Encoder: MLP with a 256-dimensional projection layer, processing 13 clinical biomarkers and missing-value masks.
- Fusion Module: Multi-Head Bidirectional Cross-Attention and a 512-dimensional shared latent space, enabling each modality to cross-weight the other.
- Deep Supervision: Independent auxiliary classifiers for both MRI and Tabular streams to enhance robustness against noisy data during training.
- Training Strategy: Includes modality dropout, random feature masking, and contrastive loss for cross-modal alignment.
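The training strategy above is not included in the released inference script. As a rough sketch of what modality dropout and a contrastive (InfoNCE-style) alignment loss could look like, assuming both modalities are projected to a shared feature dimension (the function names here are illustrative, not from the release):

```python
import torch
import torch.nn.functional as F

def modality_dropout(mri_feat, tab_feat, p=0.2):
    """Occasionally zero out one modality so the fused model learns to
    cope with missing inputs (illustrative sketch, not the released code)."""
    if torch.rand(1).item() < p:
        mri_feat = torch.zeros_like(mri_feat)
    elif torch.rand(1).item() < p:
        tab_feat = torch.zeros_like(tab_feat)
    return mri_feat, tab_feat

def contrastive_alignment_loss(mri_proj, tab_proj, temperature=0.07):
    """InfoNCE-style loss pulling matched MRI/tabular pairs together and
    pushing mismatched pairs in the batch apart."""
    mri_proj = F.normalize(mri_proj, dim=-1)
    tab_proj = F.normalize(tab_proj, dim=-1)
    logits = mri_proj @ tab_proj.t() / temperature   # (B, B) similarities
    targets = torch.arange(logits.size(0))           # diagonal = positive pairs
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.t(), targets))

# Toy usage with random features in a shared 512-d space
mri = torch.randn(4, 512)
tab = torch.randn(4, 512)
mri, tab = modality_dropout(mri, tab, p=0.2)
loss = contrastive_alignment_loss(mri, tab)
```

The exact losses, dropout rates, and masking schedule used in training are not published; this only shows the general shape of the technique.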
## General Tests
| Input Size | Params | Accuracy | ROC-AUC | F1 Score | Recall | Precision |
|---|---|---|---|---|---|---|
| 96³ | 16.85M | 80.6% | 95.4% | 78.6% | 80.6% | 84.1% |
*Evaluated on the ADNI T1 and ADNIMERGE datasets; both were excluded from training.
*Trained for only 14 epochs.
*No transfer learning or pre-trained weights were used.
## Usage

### Python Script
"""
Vbai-2.6AD Standalone Inference Script
===============================================
A self-contained script containing the full model architecture and inference logic.
Designed for open-source distribution.
Usage:
python vbai-2.6ad_test.py --weights path/to/model.pt --mri path/to/scan.nii --clinical "Age:75.2, Sex:1, MMSE:25"
"""
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
try:
import nibabel as nib
from scipy.ndimage import zoom
HAS_NIBABEL = True
except ImportError:
HAS_NIBABEL = False
# ============================================================
# Configuration Constants
# ============================================================
FEATURE_NAMES = [
"Age", "Sex", "MMSE", "CDRSB", "APOE4_count",
"CSF_ABETA42", "CSF_TAU", "CSF_PTAU", "CSF_AB42_AB40",
"PLASMA_PTAU", "PLASMA_NFL", "PLASMA_AB42_AB40", "PLASMA_GFAP"
]
CLASS_NAMES = ["CN", "MCI", "AD"]
class ModelConfig:
def __init__(self):
self.mri_input_shape = (1, 96, 96, 96)
self.mri_encoder_channels = [32, 64, 128, 256]
self.mri_bottleneck_channels = 512
self.mri_feature_dim = 512
self.mri_dropout = 0.4
self.use_cbam = True
self.use_se_block = True
self.num_tabular_inputs = len(FEATURE_NAMES) * 2
self.tabular_hidden_dims = [128, 256]
self.tabular_feature_dim = 256
self.tabular_dropout = 0.3
self.fusion_dim = 512
self.fusion_num_heads = 8
self.fusion_dropout = 0.3
self.num_classes = 3
self.progression_hidden_dim = 256
self.max_progression_months = 120
self.num_time_bins = 24
# ============================================================
# Attention Modules
# ============================================================
class ChannelAttention3D(nn.Module):
def __init__(self, ch, r=16):
super().__init__()
m = max(ch // r, 8)
self.mlp = nn.Sequential(nn.Linear(ch, m), nn.ReLU(inplace=True), nn.Linear(m, ch))
def forward(self, x):
a = x.mean(dim=[2, 3, 4]); b = x.amax(dim=[2, 3, 4])
attn = torch.sigmoid(self.mlp(a) + self.mlp(b))
return x * attn[..., None, None, None]
class SpatialAttention3D(nn.Module):
def __init__(self, k=7):
super().__init__()
self.conv = nn.Conv3d(2, 1, k, padding=k // 2, bias=False)
def forward(self, x):
avg = x.mean(dim=1, keepdim=True); mx = x.amax(dim=1, keepdim=True)
attn = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
return x * attn
class CBAM3D(nn.Module):
def __init__(self, ch, r=16):
super().__init__()
self.c = ChannelAttention3D(ch, r); self.s = SpatialAttention3D()
def forward(self, x): return self.s(self.c(x))
class SEBlock3D(nn.Module):
def __init__(self, ch, r=16):
super().__init__()
m = max(ch // r, 8)
self.fc = nn.Sequential(nn.Linear(ch, m), nn.ReLU(True), nn.Linear(m, ch), nn.Sigmoid())
def forward(self, x):
s = x.mean(dim=[2, 3, 4]); s = self.fc(s)[..., None, None, None]
return x * s
# ============================================================
# Encoders
# ============================================================
class ResBlock3D(nn.Module):
def __init__(self, in_ch, out_ch, stride=1, use_cbam=True, use_se=True, drop_path=0.0):
super().__init__()
self.conv1 = nn.Conv3d(in_ch, out_ch, 3, stride, 1, bias=False)
self.bn1 = nn.BatchNorm3d(out_ch)
self.conv2 = nn.Conv3d(out_ch, out_ch, 3, 1, 1, bias=False)
self.bn2 = nn.BatchNorm3d(out_ch)
self.act = nn.GELU()
self.cbam = CBAM3D(out_ch) if use_cbam else nn.Identity()
self.se = SEBlock3D(out_ch) if use_se else nn.Identity()
self.skip = nn.Identity() if (in_ch == out_ch and stride == 1) else nn.Sequential(
nn.Conv3d(in_ch, out_ch, 1, stride, bias=False), nn.BatchNorm3d(out_ch))
def forward(self, x):
identity = self.skip(x)
out = self.act(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = self.cbam(out); out = self.se(out)
return self.act(out + identity)
class ASPP3D(nn.Module):
def __init__(self, in_ch, out_ch, dilations=(1, 6, 12, 18)):
super().__init__()
per = out_ch // len(dilations)
self.branches = nn.ModuleList([
nn.Sequential(nn.Conv3d(in_ch, per, 3, padding=d, dilation=d, bias=False),
nn.BatchNorm3d(per), nn.GELU())
for d in dilations
])
self.gp = nn.Sequential(
nn.AdaptiveAvgPool3d(1),
nn.Conv3d(in_ch, per, 1, bias=False),
nn.BatchNorm3d(per), nn.GELU())
self.fuse = nn.Sequential(nn.Conv3d(per * (len(dilations) + 1), out_ch, 1, bias=False),
nn.BatchNorm3d(out_ch), nn.GELU())
def forward(self, x):
feats = [b(x) for b in self.branches]
g = self.gp(x)
g = F.interpolate(g, size=x.shape[2:], mode="trilinear", align_corners=False)
feats.append(g)
return self.fuse(torch.cat(feats, dim=1))
class MRIEncoder3D(nn.Module):
def __init__(self, mcfg: ModelConfig):
super().__init__()
ch = mcfg.mri_encoder_channels
self.stem = nn.Sequential(
nn.Conv3d(1, ch[0], 7, 2, 3, bias=False), nn.BatchNorm3d(ch[0]), nn.GELU(),
nn.MaxPool3d(3, 2, 1))
self.stage1 = self._make(ch[0], ch[0], 2, 1, mcfg)
self.stage2 = self._make(ch[0], ch[1], 2, 2, mcfg)
self.stage3 = self._make(ch[1], ch[2], 2, 2, mcfg)
self.stage4 = self._make(ch[2], ch[3], 2, 2, mcfg)
self.aspp = ASPP3D(ch[3], mcfg.mri_bottleneck_channels)
self.pool = nn.AdaptiveAvgPool3d(1)
self.proj = nn.Sequential(
nn.Linear(mcfg.mri_bottleneck_channels, mcfg.mri_feature_dim),
nn.GELU(), nn.Dropout(mcfg.mri_dropout))
def _make(self, in_ch, out_ch, n, stride, mcfg):
layers = [ResBlock3D(in_ch, out_ch, stride, mcfg.use_cbam, mcfg.use_se_block)]
for _ in range(1, n):
layers.append(ResBlock3D(out_ch, out_ch, 1, mcfg.use_cbam, mcfg.use_se_block))
return nn.Sequential(*layers)
def forward(self, x):
x = self.stem(x)
x = self.stage1(x); x = self.stage2(x); x = self.stage3(x); x = self.stage4(x)
x = self.aspp(x); x = self.pool(x).flatten(1)
return self.proj(x)
class TabularEncoder(nn.Module):
def __init__(self, mcfg: ModelConfig):
super().__init__()
prev = mcfg.num_tabular_inputs
layers = []
for h in mcfg.tabular_hidden_dims:
layers += [nn.Linear(prev, h), nn.LayerNorm(h), nn.GELU(), nn.Dropout(mcfg.tabular_dropout)]
prev = h
layers += [nn.Linear(prev, mcfg.tabular_feature_dim)]
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
# ============================================================
# Fusion & Heads
# ============================================================
class CrossModalFusion(nn.Module):
def __init__(self, mri_dim, tab_dim, fdim, heads=8, dropout=0.1):
super().__init__()
self.pm = nn.Linear(mri_dim, fdim); self.pt = nn.Linear(tab_dim, fdim)
self.a_mt = nn.MultiheadAttention(fdim, heads, dropout=dropout, batch_first=True)
self.a_tm = nn.MultiheadAttention(fdim, heads, dropout=dropout, batch_first=True)
self.lnm = nn.LayerNorm(fdim); self.lnt = nn.LayerNorm(fdim)
self.gate = nn.Sequential(nn.Linear(fdim * 2, fdim), nn.Sigmoid())
self.out = nn.Sequential(nn.Linear(fdim * 2, fdim), nn.GELU(), nn.Dropout(dropout))
def forward(self, m, t):
m1 = self.pm(m).unsqueeze(1); t1 = self.pt(t).unsqueeze(1)
ma, _ = self.a_mt(m1, t1, t1); ta, _ = self.a_tm(t1, m1, m1)
m2 = self.lnm(m1 + ma).squeeze(1); t2 = self.lnt(t1 + ta).squeeze(1)
cat = torch.cat([m2, t2], dim=-1)
g = self.gate(cat); o = self.out(cat)
return g * m2 + (1 - g) * t2 + o
class ClsHead(nn.Module):
def __init__(self, in_dim, num_classes, dropout=0.3):
super().__init__()
self.h = nn.Sequential(
nn.Linear(in_dim, 256), nn.GELU(), nn.Dropout(dropout),
nn.Linear(256, 128), nn.GELU(), nn.Dropout(dropout),
nn.Linear(128, num_classes))
def forward(self, x): return self.h(x)
class ProgressionHead(nn.Module):
def __init__(self, in_dim, hidden=256, max_months=120, n_bins=24):
super().__init__()
self.max_months = float(max_months); self.n_bins = n_bins
self.shared = nn.Sequential(nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(0.3))
self.binary = nn.Linear(hidden, 1)
self.time = nn.Sequential(nn.Linear(hidden, 64), nn.GELU(), nn.Linear(64, 1))
def forward(self, x):
h = self.shared(x)
logits = self.binary(h).squeeze(-1)
return {
"will_progress": torch.sigmoid(logits),
"time_to_conversion": torch.clamp(F.softplus(self.time(h)).squeeze(-1), min=0.0, max=self.max_months),
}
# ============================================================
# Main Model Class
# ============================================================
class HFv3AModel(nn.Module):
def __init__(self, mcfg: ModelConfig = None):
super().__init__()
self.cfg = mcfg or ModelConfig()
self.mri_encoder = MRIEncoder3D(self.cfg)
self.tab_encoder = TabularEncoder(self.cfg)
self.mri_classifier = ClsHead(self.cfg.mri_feature_dim, self.cfg.num_classes, self.cfg.mri_dropout)
self.tab_classifier = ClsHead(self.cfg.tabular_feature_dim, self.cfg.num_classes, self.cfg.tabular_dropout)
self.fusion = CrossModalFusion(self.cfg.mri_feature_dim, self.cfg.tabular_feature_dim, self.cfg.fusion_dim, self.cfg.fusion_num_heads, self.cfg.fusion_dropout)
self.fused_classifier = ClsHead(self.cfg.fusion_dim, self.cfg.num_classes, self.cfg.fusion_dropout)
self.progression_head = ProgressionHead(self.cfg.fusion_dim, self.cfg.progression_hidden_dim, self.cfg.max_progression_months, self.cfg.num_time_bins)
def forward(self, mri=None, tab=None):
out = {}
m_feat = t_feat = None
if mri is not None:
m_feat = self.mri_encoder(mri)
out["mri_logits"] = self.mri_classifier(m_feat)
if tab is not None:
t_feat = self.tab_encoder(tab)
out["tab_logits"] = self.tab_classifier(t_feat)
if m_feat is not None and t_feat is not None:
f = self.fusion(m_feat, t_feat)
out["fused_logits"] = self.fused_classifier(f)
out["progression"] = self.progression_head(f)
elif m_feat is not None:
out["fused_logits"] = out["mri_logits"]
elif t_feat is not None:
out["fused_logits"] = out["tab_logits"]
return out
@torch.no_grad()
def predict(self, mri=None, tab=None):
self.eval()
out = self.forward(mri=mri, tab=tab)
probs = F.softmax(out["fused_logits"], dim=-1)
pred = probs.argmax(dim=-1)
result = {
"pred_class": pred,
"class_probs": probs,
"class_name": CLASS_NAMES[pred.item()]
}
if "progression" in out:
p = out["progression"]
result["will_progress"] = p["will_progress"].item()
result["time_to_conversion_months"] = p["time_to_conversion"].item()
return result
# ============================================================
# Preprocessing Helpers
# ============================================================
def load_mri_tensor(path: str, target_shape=(96, 96, 96)):
if not HAS_NIBABEL:
raise ImportError("Please install nibabel and scipy to process NIfTI MRI images: pip install nibabel scipy")
img = nib.load(path)
data = img.get_fdata().astype(np.float32)
if data.ndim == 4:
data = data[..., 0]
mask = data > 0
if mask.sum() > 0:
vals = data[mask]
lo, hi = np.percentile(vals, [1.0, 99.0])
data = np.clip(data, lo, hi)
m, s = vals.mean(), vals.std()
if s > 0:
data = (data - m) / s
data[~mask] = 0
if data.shape != target_shape:
f = [t / s for t, s in zip(target_shape, data.shape)]
data = zoom(data, f, order=1)
tensor = torch.from_numpy(np.ascontiguousarray(data)).unsqueeze(0).unsqueeze(0).float()
return tensor
def parse_clinical_data(clinical_str: str):
"""
Parses a string like "Age:75.2, Sex:1, MMSE:25" into a tabular tensor.
If normalizer values are not provided, it passes raw values.
"""
pairs = [p.strip().split(':') for p in clinical_str.split(',') if ':' in p]
val_dict = {k.strip(): float(v.strip()) for k, v in pairs}
vals = []
masks = []
for feat in FEATURE_NAMES:
if feat in val_dict:
vals.append(val_dict[feat])
masks.append(1.0)
else:
vals.append(0.0)
masks.append(0.0)
# Note: Without the original training normalizer state, scaling might be inaccurate.
# We pass the raw values here. For accurate production use, you should apply your normalizer parameters.
tab_tensor = torch.tensor(vals + masks, dtype=torch.float32).unsqueeze(0)
return tab_tensor
# ============================================================
# Inference Pipeline
# ============================================================
def run_inference(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[*] Running on device: {device}")
print("[*] Initializing Model...")
model = HFv3AModel().to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"[*] Model Parameters: {total_params:,} ({(total_params/1e6):.2f}M)")
if args.weights:
if os.path.exists(args.weights):
print(f"[*] Loading weights from {args.weights}")
try:
ckpt = torch.load(args.weights, map_location=device, weights_only=False)
# Flexible loading depending on how weights were saved
state_dict = ckpt["model"] if "model" in ckpt else (ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt)
model.load_state_dict(state_dict, strict=False)
except Exception as e:
print(f"[!] Error loading weights: {e}")
else:
print(f"[!] Warning: Weights file not found at {args.weights}. Using untrained model.")
model.eval()
mri_tensor = None
tab_tensor = None
if args.mri:
if os.path.exists(args.mri):
print(f"[*] Processing MRI: {args.mri}")
mri_tensor = load_mri_tensor(args.mri).to(device)
else:
print(f"[!] Error: MRI file not found at {args.mri}")
return
if args.clinical:
print(f"[*] Processing Clinical Data: {args.clinical}")
tab_tensor = parse_clinical_data(args.clinical).to(device)
if mri_tensor is None and tab_tensor is None:
print("[!] Error: You must provide either --mri or --clinical inputs.")
return
print("[*] Running Prediction...")
result = model.predict(mri=mri_tensor, tab=tab_tensor)
print("\n" + "="*40)
print(" PREDICTION RESULTS ")
print("="*40)
print(f"Diagnosis : {result['class_name']} (Class {result['pred_class'].item()})")
probs = result['class_probs'].squeeze().tolist()
print(f"Confidence (CN) : {probs[0]:.4f}")
print(f"Confidence (MCI) : {probs[1]:.4f}")
print(f"Confidence (AD) : {probs[2]:.4f}")
if "will_progress" in result:
print("-" * 40)
print(f"Progression Risk : {result['will_progress']:.2%}")
print(f"Est. Time to Convert: {result['time_to_conversion_months']:.1f} months")
print("="*40)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Vbai-2.6AD Inference Script")
parser.add_argument("--weights", type=str, help="Path to the model .pt weights file")
parser.add_argument("--mri", type=str, help="Path to the input NIfTI (.nii / .nii.gz) MRI scan")
parser.add_argument("--clinical", type=str, help='Clinical data string, e.g., "Age:75.2, Sex:1, MMSE:25"')
args = parser.parse_args()
if not any([args.mri, args.clinical]):
parser.print_help()
else:
run_inference(args)
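The `--clinical` string maps onto a 26-dimensional tabular input: 13 feature values in the order of `FEATURE_NAMES`, followed by 13 presence masks (1.0 where a value was supplied, 0.0 otherwise). A minimal standalone sketch of that layout, mirroring `parse_clinical_data` in plain Python (the helper name `clinical_vector` is illustrative):

```python
# Feature order used by the model's tabular input (13 biomarkers).
FEATURE_NAMES = [
    "Age", "Sex", "MMSE", "CDRSB", "APOE4_count",
    "CSF_ABETA42", "CSF_TAU", "CSF_PTAU", "CSF_AB42_AB40",
    "PLASMA_PTAU", "PLASMA_NFL", "PLASMA_AB42_AB40", "PLASMA_GFAP",
]

def clinical_vector(clinical_str: str) -> list[float]:
    """Parse "Name:value, ..." into [13 values] + [13 masks] (1.0 = provided)."""
    given = {}
    for part in clinical_str.split(","):
        if ":" in part:
            name, value = part.split(":", 1)
            given[name.strip()] = float(value.strip())
    vals = [given.get(f, 0.0) for f in FEATURE_NAMES]
    masks = [1.0 if f in given else 0.0 for f in FEATURE_NAMES]
    return vals + masks  # length 26 = num_tabular_inputs

vec = clinical_vector("Age:75.2, Sex:1, MMSE:25")
print(len(vec))    # 26
print(vec[:3])     # [75.2, 1.0, 25.0]
print(vec[13:16])  # [1.0, 1.0, 1.0] -> first three features were provided
```

Missing features are sent as 0.0 with a 0.0 mask, so the encoder can distinguish a genuine zero from an absent measurement.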
## Requirements
- Python ≥ 3.9
- PyTorch ≥ 2.0
- CUDA-capable GPU with ≥ 8 GB VRAM recommended (tested on an NVIDIA RTX 5060 with 8 GB VRAM; trained on an NVIDIA L4 with 24 GB VRAM)
- See `requirements.txt` for the full dependency list
## License
CC-BY-NC-SA 4.0 - see LICENSE file for details.
## Support
- Website: Neurazum - HealFuture
- Email: contact@neurazum.com