Vbai-2.6TS
Description
Vbai-2.6TS is a 3D brain MRI segmentation model and the latest-generation member of the Vbai model family. Unlike previous versions, Vbai-2.6TS works exclusively with NIfTI files and is intended for professional research use. The Vbai-3D versions have been merged with the standard Vbai versions.
The model generates voxel-level segmentation masks instead of image-level labels and provides spatial localization of pathological regions in addition to quantitative tissue volume measurements.
Vbai-2.6TS also serves as the core engine of the HealFuture image processing library and can run each diagnostic task independently or in combination, depending on the clinical use case. This model is trained exclusively for tumors.
Audience / Target
Vbai models are developed exclusively for hospitals, universities, communities, health centres and science centres.
Architecture
| Input | Shared Encoder | Output |
|---|---|---|
| FLAIR + T1c (2ch) | ResNet3D + CBAM + SE + ASPP | → Tumor Decoder → Binary tumor mask |
| T1-weighted (1ch) | ResNet3D + CBAM + SE + ASPP (shared) | → Tissue Decoder → CSF / GM / WM maps |
- Encoder: Custom 3D ResNet with CBAM, Squeeze-and-Excitation, and ASPP modules
- Decoder heads: Two independent UNet-style decoders with attention gates
- Deep supervision: 3 auxiliary outputs per decoder during training
- Inference: Sliding window (96³ patches, 50% overlap) + optional TTA
Tasks & Classes
Tumor Segmentation
Localises intracranial tumours at the voxel level. Each output voxel is classified as:
| Class | Description |
|---|---|
| Tumor | Voxel belongs to a neoplastic region (glioma / high-grade) |
| No Tumor | Voxel is healthy parenchyma |
General Tests
| Test Size | Params | Accuracy | HD95 | F1 Score | Recall | Precision | IoU |
|---|---|---|---|---|---|---|---|
| 96³ | 43.9M | 99.7% | 1.41 | 87.6% | 92% | 85.7% | 77.9% |
*Tested on the BraTS 2021 dataset; the BraTS 2021 dataset was excluded from training.
Usage
Python Script
"""
HealFuture Vbai MRI Segmentation - Vbai-2.6TS
Professional Test Script
Comprehensive evaluation including:
- Number of parameters & model summary
- Dice Score, IoU, Precision, Recall, F2-Score, Specificity
- HD95, Volume Similarity
- Per-class metrics (CSF / Gray Matter / White Matter)
- Probability map visualization (Axial / Sagittal / Coronal)
- Per-subject & aggregate test-set evaluation
- Professional TXT report + PNG charts
Usage:
python test.py # Test all volumes in the config
python test.py --task tumor # tumor only
python test.py --task tissue # tissue only
python test.py --no-vis # do not generate visuals
python test.py --show # also display visuals on screen
"""
import os, sys, json, time, warnings
warnings.filterwarnings("ignore")
from datetime import datetime
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from scipy.ndimage import zoom
import matplotlib.cm as cm
_SHOW_LIVE = "--show" in sys.argv
import matplotlib
if not _SHOW_LIVE:
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
# ============================================================================
# CONFIGURATION
# ============================================================================
CHECKPOINT_PATH = (
r"vbai-2.6ts/model/file/path"
)
OUTPUT_DIR = (
r"results/dir"
)
_DATA = (
r"data/set/dir"
)
# Tumor volumes to be tested (ground truth mask optional)
# Candidates assigned to the test set with Seed=42 (approximately) 0010, 0012, 0020, 0025...
# Patients in the training set can also be used for visualization.
def _tumor_vol(subj_id):
    """Return the FLAIR / T1c / mask file paths of one UCSF-PDGM subject.

    Args:
        subj_id: Subject identifier as it appears in the folder and file
            names (e.g. ``"0017"``).

    Returns:
        Dict with the keys ``"flair"``, ``"t1c"`` and ``"mask"``; each value
        is a path below ``_DATA``. The files are not checked for existence.
    """
    prefix = f"UCSF-PDGM-{subj_id}"
    subject_dir = os.path.join(
        _DATA, f"3d-brain-mri/dataset-3d-brain/{prefix}_nifti")
    suffix_by_key = {
        "flair": "FLAIR",
        "t1c": "T1c",
        "mask": "tumor_segmentation",
    }
    return {
        key: os.path.join(subject_dir, f"{prefix}_{suffix}.nii.gz")
        for key, suffix in suffix_by_key.items()
    }
# One entry per subject, in evaluation order.
TUMOR_VOLUMES: List[Dict] = [
    _tumor_vol(subject)
    for subject in ("0017", "0018", "0019", "0013", "0023")
]

# Tissue volumes to be tested (ground-truth masks are optional)
TISSUE_VOLUMES: List[Dict] = [
    {
        key: os.path.join(_DATA, r".nii/file")
        for key in ("t1", "mask_csf", "mask_gm", "mask_wm")
    },
    # Copy and paste to add more:
    # {
    #     "t1": r"...",
    #     "mask_csf": r"...",
    #     "mask_gm": r"...",
    #     "mask_wm": r"...",
    # },
]

# Import-time side effect: make sure the figure directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ============================================================================
# MODEL ARCHITECTURE
# ============================================================================
class SEBlock3D(nn.Module):
    """Squeeze-and-Excitation gate for 5-D ``(B, C, D, H, W)`` feature maps.

    The input is globally average-pooled to one value per channel, passed
    through a two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weights rescale the input.

    Args:
        ch: Number of input (and output) channels.
        r: Reduction ratio of the bottleneck; the hidden width is
            ``max(ch // r, 4)``.
    """

    def __init__(self, ch, r=16):
        super().__init__()
        hidden = max(ch // r, 4)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(ch, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, ch),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.shape[:2]
        squeezed = self.pool(x).view(batch, channels)
        gate = self.fc(squeezed)
        # Broadcast the per-channel gate over the three spatial axes.
        return x * gate.view(batch, channels, 1, 1, 1)
class CBAM3D(nn.Module):
    """Channel-then-spatial attention (CBAM) for 5-D feature maps.

    Channel attention: global average- and max-pooled descriptors share one
    MLP, their outputs are summed and squashed by a sigmoid.
    Spatial attention: the channel-wise mean and max maps are concatenated,
    convolved (kernel ``ks``) and batch-normalised, then squashed by a sigmoid.

    Args:
        ch: Number of channels.
        r: Reduction ratio of the channel MLP (hidden width
            ``max(ch // r, 4)``).
        ks: Kernel size of the spatial-attention convolution.
    """

    def __init__(self, ch, r=16, ks=7):
        super().__init__()
        hidden = max(ch // r, 4)
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.mx = nn.AdaptiveMaxPool3d(1)
        self.ch_fc = nn.Sequential(
            nn.Linear(ch, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, ch),
        )
        self.sp = nn.Sequential(
            nn.Conv3d(2, 1, ks, padding=ks // 2, bias=False),
            nn.BatchNorm3d(1),
        )

    def forward(self, x):
        batch, channels = x.shape[:2]
        avg_desc = self.avg(x).view(batch, channels)
        max_desc = self.mx(x).view(batch, channels)
        channel_gate = torch.sigmoid(self.ch_fc(avg_desc) + self.ch_fc(max_desc))
        x = x * channel_gate.view(batch, channels, 1, 1, 1)
        # Spatial attention is computed on the channel-refined features.
        pooled = torch.cat([x.mean(1, True), x.max(1, True).values], 1)
        spatial_gate = torch.sigmoid(self.sp(pooled))
        return x * spatial_gate
class ResBlock3D(nn.Module):
    """Residual block: two 3x3x3 convs with optional SE and CBAM attention.

    Args:
        ic: Input channels.
        oc: Output channels.
        stride: Stride of the first convolution (and of the projection skip).
        drop: ``Dropout3d`` probability applied between the two convolutions.
        se: Insert an ``SEBlock3D`` after the convolutions.
        cbam: Insert a ``CBAM3D`` after the SE block.
    """

    def __init__(self, ic, oc, stride=1, drop=0.1, se=True, cbam=True):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ic, oc, 3, stride, 1, bias=False),
            nn.BatchNorm3d(oc),
            nn.ReLU(True),
            nn.Dropout3d(drop),
            nn.Conv3d(oc, oc, 3, 1, 1, bias=False),
            nn.BatchNorm3d(oc),
        )
        # A 1x1x1 projection is only needed when the shape changes.
        needs_projection = ic != oc or stride != 1
        if needs_projection:
            self.skip = nn.Sequential(
                nn.Conv3d(ic, oc, 1, stride, bias=False),
                nn.BatchNorm3d(oc),
            )
        else:
            self.skip = nn.Identity()
        self.se = SEBlock3D(oc) if se else nn.Identity()
        self.cbam = CBAM3D(oc) if cbam else nn.Identity()
        self.act = nn.ReLU(True)

    def forward(self, x):
        body = self.conv(x)
        body = self.cbam(self.se(body))
        shortcut = self.skip(x)
        return self.act(body + shortcut)
class ASPP3D(nn.Module):
    """3-D atrous spatial pyramid pooling.

    Runs one dilated 3x3x3 branch per entry of ``dils``, a global-pooling
    branch and a 1x1x1 branch in parallel, concatenates them and projects the
    result to ``oc`` channels.

    Args:
        ic: Input channels.
        oc: Output channels; every branch produces ``oc // (len(dils) + 2)``.
        dils: Dilation rates of the 3x3x3 branches.
    """

    def __init__(self, ic, oc, dils=(1, 3, 6)):
        super().__init__()
        n_branches = len(dils) + 2
        mid = oc // n_branches
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Conv3d(ic, mid, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm3d(mid),
                nn.ReLU(True),
            )
            for rate in dils
        )
        # The pooled branch has no BatchNorm: its spatial size is 1x1x1.
        self.gp = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(ic, mid, 1, bias=False),
            nn.ReLU(True),
        )
        self.pw = nn.Sequential(
            nn.Conv3d(ic, mid, 1, bias=False),
            nn.BatchNorm3d(mid),
            nn.ReLU(True),
        )
        self.proj = nn.Sequential(
            nn.Conv3d(mid * n_branches, oc, 1, bias=False),
            nn.BatchNorm3d(oc),
            nn.ReLU(True),
            nn.Dropout3d(0.1),
        )

    def forward(self, x):
        spatial = x.shape[2:]
        features = []
        for branch in self.branches:
            features.append(branch(x))
        pooled = self.gp(x)
        features.append(
            F.interpolate(pooled, spatial, mode="trilinear", align_corners=False))
        features.append(self.pw(x))
        return self.proj(torch.cat(features, 1))
class AttGate3D(nn.Module):
    """Additive attention gate that re-weights skip features with a gate.

    Args:
        fc: Channels of the skip feature map ``feat``.
        gc: Channels of the gating signal ``gate``.
    """

    def __init__(self, fc, gc):
        super().__init__()
        inter = fc // 2
        self.Wf = nn.Sequential(nn.Conv3d(fc, inter, 1, bias=False), nn.BatchNorm3d(inter))
        self.Wg = nn.Sequential(nn.Conv3d(gc, inter, 1, bias=False), nn.BatchNorm3d(inter))
        self.ps = nn.Sequential(
            nn.Conv3d(inter, 1, 1, bias=False),
            nn.BatchNorm3d(1),
            nn.Sigmoid(),
        )
        self.r = nn.ReLU(True)

    def forward(self, feat, gate):
        target = feat.shape[2:]
        # Bring the gate onto the skip feature's spatial grid if needed.
        if gate.shape[2:] != target:
            gate = F.interpolate(gate, target, mode="trilinear", align_corners=False)
        mixed = self.r(self.Wf(feat) + self.Wg(gate))
        attention = self.ps(mixed)
        return feat * attention
class EncBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided-conv downsample.

    ``forward`` returns ``(skip, down)``: the full-resolution features for the
    decoder skip connection and the stride-2 downsampled features for the next
    stage.
    """

    def __init__(self, ic, oc, drop=0.1):
        super().__init__()
        self.blk = nn.Sequential(
            ResBlock3D(ic, oc, drop=drop),
            ResBlock3D(oc, oc, drop=drop),
        )
        self.down = nn.Sequential(
            nn.Conv3d(oc, oc, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm3d(oc),
            nn.ReLU(True),
        )

    def forward(self, x):
        skip = self.blk(x)
        downsampled = self.down(skip)
        return skip, downsampled
class DecBlock(nn.Module):
    """Decoder stage: upsample, (optionally) gate the skip, concat, refine.

    Args:
        ic: Channels of the incoming (coarser) feature map.
        sc: Channels of the skip connection.
        oc: Output channels.
        drop: Dropout probability forwarded to the residual blocks.
        ag: Filter the skip features through an ``AttGate3D`` first.
    """

    def __init__(self, ic, sc, oc, drop=0.1, ag=True):
        super().__init__()
        self.ag = AttGate3D(sc, ic) if ag else None
        self.blk = nn.Sequential(
            ResBlock3D(ic + sc, oc, drop=drop),
            ResBlock3D(oc, oc, drop=drop),
        )

    def forward(self, x, skip):
        # Match the skip's spatial size exactly instead of assuming x2.
        upsampled = F.interpolate(
            x, skip.shape[2:], mode="trilinear", align_corners=False)
        if self.ag is not None:
            skip = self.ag(skip, upsampled)
        merged = torch.cat([upsampled, skip], 1)
        return self.blk(merged)
class HFSegNetMultiTask(nn.Module):
    """Shared 3-D encoder with two task-specific decoders (tumor / tissue).

    Sub-module names are part of the checkpoint format and must not change:
    ``stem_t`` / ``stem_s`` (task stems, 2 and 1 input channels), ``e0``-``e3``
    (encoder), ``bn`` (bottleneck), ``dt0``-``dt3`` / ``ds0``-``ds3`` (tumor /
    tissue decoder stages), ``head_t`` / ``head_s`` (1 and 3 output channels)
    and, when ``ds`` is true, ``ds_t0``-``ds_t2`` / ``ds_s0``-``ds_s2``
    (deep-supervision heads).

    Args:
        bc: Base channel count.
        mults: Channel multipliers for the four encoder stages and the
            bottleneck.
        drop: Dropout probability used by every residual block.
        ds: Build and evaluate the deep-supervision heads.
    """

    def __init__(self, bc=32, mults=(1, 2, 4, 8, 10), drop=0.1, ds=True):
        super().__init__()
        ch = [bc * m for m in mults]
        self.ds = ds
        kw = dict(drop=drop)

        def make_stem(in_channels):
            return nn.Sequential(
                nn.Conv3d(in_channels, ch[0], 3, 1, 1, bias=False),
                nn.BatchNorm3d(ch[0]),
                nn.ReLU(True),
            )

        self.stem_t = make_stem(2)
        self.stem_s = make_stem(1)
        self.e0 = EncBlock(ch[0], ch[0], **kw)
        self.e1 = EncBlock(ch[0], ch[1], **kw)
        self.e2 = EncBlock(ch[1], ch[2], **kw)
        self.e3 = EncBlock(ch[2], ch[3], **kw)
        self.bn = nn.Sequential(ResBlock3D(ch[3], ch[4], **kw), ASPP3D(ch[4], ch[4]))

        dec_in = [ch[4], ch[3], ch[2], ch[1]]
        dec_skip = [ch[3], ch[2], ch[1], ch[0]]
        dec_out = [ch[3], ch[2], ch[1], ch[0]]
        # Tumor ("t") is built first, then tissue ("s"); keep this order so
        # that parameter initialisation consumes the RNG identically.
        for tag, n_classes in (("t", 1), ("s", 3)):
            for level, (ic, sc, oc) in enumerate(zip(dec_in, dec_skip, dec_out)):
                self.add_module(f"d{tag}{level}", DecBlock(ic, sc, oc, **kw))
            self.add_module(f"head_{tag}", nn.Conv3d(ch[0], n_classes, 1))
            if ds:
                for level, c in enumerate(dec_out[:3]):
                    self.add_module(f"ds_{tag}{level}", nn.Conv3d(c, n_classes, 1))
        self._init()

    def _init(self):
        """Kaiming-initialise every Conv3d and reset every BatchNorm3d."""
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _encode(self, x, task):
        """Run the task stem and shared encoder; return bottleneck + skips."""
        stem = self.stem_t if task == "tumor" else self.stem_s
        feat = stem(x)
        skips = []
        for stage in (self.e0, self.e1, self.e2, self.e3):
            skip, feat = stage(feat)
            skips.append(skip)
        k0, k1, k2, k3 = skips
        return self.bn(feat), k3, k2, k1, k0, x.shape[2:]

    def _decode(self, bn, k3, k2, k1, k0, inp_sz, tag):
        """Run decoder ``tag``; return ``(logits, aux)`` (``aux`` may be None)."""
        feat = bn
        stage_outputs = []
        for level, skip in enumerate((k3, k2, k1, k0)):
            feat = getattr(self, f"d{tag}{level}")(feat, skip)
            stage_outputs.append(feat)
        out = getattr(self, f"head_{tag}")(feat)
        if not self.ds:
            return out, None
        # Auxiliary heads sit on the three coarser stages and are upsampled
        # to the input resolution.
        aux = [
            F.interpolate(getattr(self, f"ds_{tag}{level}")(u), inp_sz,
                          mode="trilinear", align_corners=False)
            for level, u in enumerate(stage_outputs[:3])
        ]
        return out, aux

    def forward(self, x, task, return_aux=False):
        # Any task other than "tumor" is routed to the tissue branch.
        tag = "t" if task == "tumor" else "s"
        encoded = self._encode(x, task)
        out, aux = self._decode(*encoded, tag)
        if return_aux:
            return out, aux
        return out
# ============================================================================
# MODEL ANALYSIS
# ============================================================================
def print_model_summary(model: nn.Module):
    """Print a parameter-count / memory table for the multi-task model.

    Parameters are attributed to a component by substring match on their
    qualified name (the same rules as before):

    * tumor branch  - names containing ``stem_t``, ``dt``, ``head_t``, ``ds_t``
    * tissue branch - names containing ``stem_s``, ``ds``, ``head_s`` but not
      ``ds_t``
    * shared        - everything that matches none of the task markers

    The size column is computed from each parameter's real element size, so
    it is correct for fp16 / bf16 / fp64 models as well as fp32 (the former
    code always assumed 4 bytes per parameter).

    Args:
        model: Module whose parameters are summarised.
    """
    task_markers = ("stem_t", "stem_s", "dt", "ds_", "head_t",
                    "ds0", "ds1", "ds2", "ds3", "head_s")
    tumor_markers = ("stem_t", "dt", "head_t", "ds_t")
    tissue_markers = ("stem_s", "ds", "head_s")

    def is_shared(name):
        return not any(marker in name for marker in task_markers)

    def is_tumor(name):
        return any(marker in name for marker in tumor_markers)

    def is_tissue(name):
        return any(marker in name for marker in tissue_markers) and "ds_t" not in name

    named = list(model.named_parameters())

    def tally(selector):
        """Return (parameter count, size in MiB) of the selected parameters."""
        count = 0
        n_bytes = 0
        for name, param in named:
            if selector(name):
                count += param.numel()
                n_bytes += param.numel() * param.element_size()
        return count, n_bytes / 1024 / 1024

    def row(label, count, size_mb):
        print(f" {label:<35} {count:>12,} {size_mb:>10.2f}")

    total, total_mb = tally(lambda name: True)
    trainable = sum(p.numel() for _, p in named if p.requires_grad)

    print("\n" + "=" * 70)
    print(" MODEL PARAMS SUMMARY")
    print("=" * 70)
    print(f"\n {'Component':<35} {'Params':>12} {'Size (MB)':>12}")
    print(" " + "-" * 60)
    row("Shared Encoder + Bottleneck", *tally(is_shared))
    row("Tumor Decoder", *tally(is_tumor))
    row("Tissue Decoder", *tally(is_tissue))
    print(" " + "-" * 60)
    row("TOTAL PARAMS", total, total_mb)
    print(f" {'Trainable':<35} {trainable:>12,}")
    print("=" * 70 + "\n")
# ============================================================================
# PREDICTOR CLASS
# ============================================================================
class HFSegPredictor:
    """Load a checkpoint and run tumor / tissue segmentation on NIfTI volumes.

    Every input volume is z-scored and resampled to ``VSZ`` before inference;
    probability maps are therefore returned on the ``VSZ`` grid, not on the
    original image grid.
    """

    # Target grid (D, H, W) every volume is resampled to.
    VSZ = (96, 96, 96)

    def __init__(self, checkpoint_path: str, device: str = None):
        """Create the predictor.

        Args:
            checkpoint_path: Path of the checkpoint passed to ``torch.load``.
            device: Torch device string; defaults to CUDA when available,
                otherwise CPU.
        """
        self.device = torch.device(device if device else
                                   ("cuda" if torch.cuda.is_available() else "cpu"))
        print(f" Model loading: {checkpoint_path}")
        self.model, self.ckpt_info = self._load_model(checkpoint_path)
        print(f" [OK] Model is loaded -> {self.device}")

    def _load_model(self, path: str):
        """Build the network, load weights non-strictly, return (model, info)."""
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints that come from a trusted source.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        # Deep-supervision heads are not needed for inference; their weights
        # in the checkpoint are reported below as "Extra" keys.
        model = HFSegNetMultiTask(ds=False).to(self.device)
        state = ckpt.get("model", ckpt)
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f" [!] Missing ({len(missing)}): {missing[:2]}")
        if unexpected:
            print(f" [!] Extra ({len(unexpected)}): {unexpected[:2]}")
        model.eval()
        info = {
            "epoch": ckpt.get("epoch", "?"),
            "best_score": ckpt.get("best", "?"),
            "total_params": sum(p.numel() for p in model.parameters()),
        }
        return model, info

    @torch.no_grad()
    def _infer(self, vol: torch.Tensor, task: str,
               patch: int = 96, overlap: float = 0.5,
               use_tta: bool = True) -> np.ndarray:
        """Sliding-window inference with optional flip test-time augmentation.

        Args:
            vol: Input tensor of shape ``(C, D, H, W)``.
            task: ``"tumor"`` (1 output channel); any other value selects the
                3-channel tissue branch.
            patch: Cubic patch edge length.
            overlap: Fractional overlap between neighbouring patches.
            use_tta: Average the prediction with its three single-axis flips.

        Returns:
            Sigmoid probabilities, shape ``(C_out, D, H, W)``, averaged over
            all patches covering each voxel.
        """
        stride = max(1, int(patch * (1 - overlap)))
        _, D, H, W = vol.shape
        C_out = 1 if task == "tumor" else 3
        acc = np.zeros((C_out, D, H, W), np.float32)
        cnt = np.zeros_like(acc)

        def starts(dim):
            # Regular grid plus one final window flush with the far border.
            s = list(range(0, dim - patch + 1, stride))
            if not s or s[-1] + patch < dim:
                s.append(max(0, dim - patch))
            return s

        def run(batch):
            # Returns a (C_out, d, h, w) numpy probability patch.
            return torch.sigmoid(self.model(batch, task))[0].cpu().numpy()

        for d0 in starts(D):
            for h0 in starts(H):
                for w0 in starts(W):
                    p = vol[:, d0:d0+patch, h0:h0+patch, w0:w0+patch].unsqueeze(0).to(self.device)
                    # F.pad takes (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
                    pad = [0, max(0, patch - p.shape[4]),
                           0, max(0, patch - p.shape[3]),
                           0, max(0, patch - p.shape[2])]
                    if any(x > 0 for x in pad):
                        p = F.pad(p, pad)
                    prob = run(p)
                    if use_tta:
                        probs = [prob]
                        for ax in (2, 3, 4):
                            fp = run(torch.flip(p, [ax]))
                            # The batch axis has been dropped, so spatial axis
                            # `ax` of the 5-D input is axis `ax - 1` here.
                            # (Previously `ax - 2`, which un-flipped the wrong
                            # axis and corrupted the TTA average.)
                            probs.append(np.flip(fp, ax - 1).copy())
                        prob = np.mean(probs, axis=0)
                    pd = min(patch, D - d0)
                    ph = min(patch, H - h0)
                    pw = min(patch, W - w0)
                    acc[:, d0:d0+pd, h0:h0+ph, w0:w0+pw] += prob[:, :pd, :ph, :pw]
                    cnt[:, d0:d0+pd, h0:h0+ph, w0:w0+pw] += 1.
        return acc / np.maximum(cnt, 1e-8)

    def predict_tumor(self, flair_path: str, t1c_path: str,
                      mask_path: str = None, use_tta: bool = True) -> dict:
        """Segment the tumor in one FLAIR + T1c pair.

        Returns:
            On success a dict with ``file``, ``type``, ``prob_map``,
            ``volume_resized``, ``elapsed``, ``orig_shape`` and ``detection``;
            ``gt`` and ``metrics`` are added when ``mask_path`` exists.
            On failure a dict with an ``error`` message (no exception is
            raised).
        """
        if not os.path.exists(flair_path):
            return {"error": f"File not found: {flair_path}"}
        try:
            flair_raw = _load_nii(flair_path)
            t1c_raw = _load_nii(t1c_path)
            orig_shape = flair_raw.shape
            flair_r = _resamp(_zscore(flair_raw), self.VSZ)
            t1c_r = _resamp(_zscore(t1c_raw), self.VSZ)
            vol = torch.tensor(np.stack([flair_r, t1c_r]), dtype=torch.float32)
            t0 = time.time()
            prob = self._infer(vol, "tumor", use_tta=use_tta)
            elapsed = time.time() - t0
            result = {
                "file": os.path.basename(flair_path),
                "type": "tumor",
                "prob_map": prob[0],
                "volume_resized": flair_r,
                "elapsed": elapsed,
                "orig_shape": orig_shape,
                "detection": _tumor_detection_info(prob[0]),
            }
            if mask_path and os.path.exists(mask_path):
                # Nearest-neighbour resampling keeps the mask labels discrete.
                gt_r = _resamp(_load_nii(mask_path), self.VSZ, order=0)
                result["gt"] = gt_r
                result["metrics"] = _tumor_metrics(prob[0], gt_r)
            return result
        except Exception as e:
            # Best-effort boundary: one bad volume must not abort a batch run.
            return {"error": str(e), "file": flair_path}

    def predict_tissue(self, t1_path: str,
                       mask_csf: str = None, mask_gm: str = None,
                       mask_wm: str = None, use_tta: bool = True) -> dict:
        """Segment CSF / gray matter / white matter in one T1 volume.

        Returns:
            On success a dict with ``file``, ``type``, ``prob_map``
            (3 channels: CSF, GM, WM), ``volume_resized``, ``elapsed`` and
            ``orig_shape``; ``gt`` and ``metrics`` are added only when all
            three mask files exist. On failure a dict with an ``error``
            message.
        """
        if not os.path.exists(t1_path):
            return {"error": f"File not found: {t1_path}"}
        try:
            t1_raw = _load_nii(t1_path)
            orig_shape = t1_raw.shape
            t1_r = _resamp(_zscore(t1_raw), self.VSZ)
            vol = torch.tensor(t1_r[None], dtype=torch.float32)
            t0 = time.time()
            prob = self._infer(vol, "tissue", use_tta=use_tta)
            elapsed = time.time() - t0
            result = {
                "file": os.path.basename(t1_path),
                "type": "tissue",
                "prob_map": prob,  # (3, D, H, W) -> CSF/GM/WM
                "volume_resized": t1_r,
                "elapsed": elapsed,
                "orig_shape": orig_shape,
            }
            paths = [mask_csf, mask_gm, mask_wm]
            if all(p and os.path.exists(p) for p in paths):
                # Linear interpolation + clip: the masks are treated as soft maps.
                gt = np.stack([
                    np.clip(_resamp(_load_nii(p), self.VSZ, order=1), 0, 1)
                    for p in paths
                ])
                result["gt"] = gt
                result["metrics"] = _tissue_metrics(prob, gt)
            return result
        except Exception as e:
            # Best-effort boundary: one bad volume must not abort a batch run.
            return {"error": str(e), "file": t1_path}
# ============================================================================
# HELPER: NIfTI
# ============================================================================
def _load_nii(path: str) -> np.ndarray:
    """Read a NIfTI file into a finite float32 array.

    NaN and +/-infinity voxels are replaced by 0.

    Args:
        path: Path of the NIfTI file.

    Returns:
        The image data as a float32 array.

    Raises:
        RuntimeError: If the file cannot be read for any reason; the
            underlying exception is attached as ``__cause__``.
    """
    try:
        d = np.asarray(nib.load(path).dataobj, dtype=np.float32)
        return np.nan_to_num(d, nan=0., posinf=0., neginf=0.)
    except Exception as e:
        raise RuntimeError(f"The NIfTI file could not be read: {path} - {e}") from e
def _zscore(v: np.ndarray) -> np.ndarray:
mask = v > 0
if mask.any(): return (v - v[mask].mean()) / (v[mask].std() + 1e-8)
lo, hi = np.percentile(v, 1), np.percentile(v, 99)
return np.clip((v - lo) / (hi - lo + 1e-8), 0., 1.)
def _resamp(v: np.ndarray, tgt: tuple, order: int = 1) -> np.ndarray:
return zoom(v, [t/c for t, c in zip(tgt, v.shape)], order=order).astype(np.float32)
# ============================================================================
# METRICS
# ============================================================================
def _tumor_metrics(pred_prob: np.ndarray, gt: np.ndarray, thr: float = 0.5) -> dict:
pred = (pred_prob >= thr).astype(float).flatten()
true = (gt > 0.5).astype(float).flatten()
sm = 1e-7
tp = (pred*true).sum(); fp = (pred*(1-true)).sum()
fn = ((1-pred)*true).sum(); tn = ((1-pred)*(1-true)).sum()
dice = (2*tp+sm) / (2*tp+fp+fn+sm)
iou = (tp+sm) / (tp+fp+fn+sm)
prec = (tp+sm) / (tp+fp+sm)
rec = (tp+sm) / (tp+fn+sm)
f2 = (5*tp+sm) / (5*tp+4*fn+fp+sm)
spec = (tn+sm) / (tn+fp+sm)
vol_sim = 1 - abs(pred.sum()-true.sum()) / (pred.sum()+true.sum()+sm)
hd95 = float("nan")
try:
from scipy.ndimage import binary_erosion
from scipy.spatial import KDTree
pb = (pred_prob >= thr).astype(bool); gb = (gt > 0.5).astype(bool)
if pb.any() and gb.any():
def surf(m): return np.stack(np.where(m & ~binary_erosion(m)), 1).astype(float)
sp, sg = surf(pb), surf(gb)
if len(sp) and len(sg):
hd95 = float(np.percentile(
np.concatenate([KDTree(sg).query(sp)[0], KDTree(sp).query(sg)[0]]), 95))
except Exception:
pass
return {
"Dice": round(float(dice), 4),
"IoU": round(float(iou), 4),
"Precision": round(float(prec), 4),
"Recall": round(float(rec), 4),
"F2-Score": round(float(f2), 4),
"Specificity": round(float(spec), 4),
"HD95 (vx)": round(hd95, 2) if not np.isnan(hd95) else "N/A",
"Vol.Sim": round(float(vol_sim), 4),
}
def _tumor_detection_info(prob_map: np.ndarray, thr: float = 0.5) -> dict:
binary = (prob_map >= thr)
detected = bool(binary.any())
vol_vx = int(binary.sum())
vol_cm3 = round(vol_vx / 1000.0, 2)
max_conf = round(float(prob_map.max()), 4)
mean_conf = round(float(prob_map[binary].mean()), 4) if detected else 0.0
return {
"detected": detected,
"volume_vx": vol_vx,
"volume_cm3": vol_cm3,
"max_confidence": max_conf,
"mean_confidence": mean_conf,
}
def _tissue_metrics(pred_prob: np.ndarray, gt: np.ndarray) -> dict:
names = ["CSF", "GrayMatter", "WhiteMatter"]
result = {}; dices = []; sm = 1e-7
for i, name in enumerate(names):
p = pred_prob[i]; g = gt[i]
pb = (p>=0.5).astype(float); gb = (g>=0.5).astype(float)
tp = (pb*gb).sum(); fp = (pb*(1-gb)).sum(); fn = ((1-pb)*gb).sum()
dice = (2*tp+sm)/(2*tp+fp+fn+sm)
iou = (tp+sm)/(tp+fp+fn+sm)
mse = float(np.mean((p-g)**2))
corr = float(np.corrcoef(p.flatten(), g.flatten())[0,1]) if p.std()>1e-8 else 0.
result[name] = {"Dice": round(float(dice),4), "IoU": round(float(iou),4),
"MSE": round(mse,6), "Corr": round(corr,4)}
dices.append(float(dice))
result["Mean Dice"] = round(float(np.mean(dices)), 4)
result["Mean IoU"] = round(float(np.mean([result[n]["IoU"] for n in names])), 4)
return result
# ============================================================================
# THEME
# ============================================================================
# Six-stop colormap (dark blue -> blue -> cyan -> green -> yellow -> red),
# quantised to 256 levels.
MEDICAL_CMAP = LinearSegmentedColormap.from_list(
    "medical",
    ["#000033", "#0000FF", "#00FFFF", "#00FF00", "#FFFF00", "#FF0000"],
    N=256,
)

# Overlay colormaps: each fades from fully transparent black to one
# semi-opaque colour, so low values leave the underlying image visible.
TUMOR_CMAP = LinearSegmentedColormap.from_list(
    "tumor", ["#00000000", "#FF3333DD"])
CSF_CMAP = LinearSegmentedColormap.from_list(
    "csf", ["#00000000", "#3399FFDD"])
GM_CMAP = LinearSegmentedColormap.from_list(
    "gm", ["#00000000", "#33FF99DD"])
WM_CMAP = LinearSegmentedColormap.from_list(
    "wm", ["#00000000", "#FFAA33DD"])

# Figure background colour.
BG = "#0D0D0D"
def _evenly(dim, n): return [int(dim*(i+1)/(n+1)) for i in range(n)]
# ============================================================================
# IMAGE: TUMOR
# ============================================================================
def visualize_tumor_prediction(result: dict, save_path: str = None) -> Optional[str]:
    """Render a 3 x 4 figure (axial / sagittal / coronal) for one tumor result.

    Columns: MRI, thresholded mask, min-max-normalised probability map, and
    an MRI/heat-map blend with the ground-truth contour when available.

    Fixes relative to the previous revision: the note string was delimited
    by typographic quotes (a SyntaxError) and the status / title strings
    contained garbled characters.

    Args:
        result: Dict produced by ``HFSegPredictor.predict_tumor``.
        save_path: Output file; when ``None`` a timestamped name inside
            ``OUTPUT_DIR`` is used.

    Returns:
        Path of the written image, or ``None`` when ``result`` holds an error.
    """
    if "error" in result:
        print(f" Image is not generated: {result['error']}")
        return None

    # NOTE(review): predict_tumor stores the z-scored volume here, so the
    # fixed vmin=0 / vmax=1 window below clips it - confirm this is intended.
    vol = result["volume_resized"]
    prob = result["prob_map"]
    prob_n = (prob - prob.min()) / (prob.max() - prob.min() + 1e-8)
    binary = (prob >= 0.5).astype(float)
    gt = result.get("gt")
    D, H, W = vol.shape
    cd, ch, cw = D // 2, H // 2, W // 2

    det = result.get("detection", {})
    if det.get("detected"):
        det_line = (f"[!] TUMOR DETECTED | "
                    f"Volume: ~{det['volume_cm3']} cm³ | "
                    f"Confidence: {det['mean_confidence']:.1%} | "
                    f"Max: {det['max_confidence']:.1%}")
        det_color = "#FF5555"
    else:
        det_line = "[OK] No Tumor"
        det_color = "#55FF55"
    dice_str = f" | Dice: {result['metrics']['Dice']:.4f}" if "metrics" in result else ""
    note_str = "Note: Staging (Grade I-IV) requires a separate classification model."

    fig, axes = plt.subplots(3, 4, figsize=(22, 16), facecolor=BG)
    fig.text(0.5, 0.99, f"Vbai-2.6TS · Tumor Segmentation - {result['file']}",
             ha="center", va="top", color="white", fontsize=12, fontweight="bold")
    fig.text(0.5, 0.965, det_line + dice_str,
             ha="center", va="top", color=det_color, fontsize=11, fontweight="bold")
    fig.text(0.5, 0.945, note_str,
             ha="center", va="top", color="#888888", fontsize=8)

    col_labels = ["MRI (Referance)", "Tumor mask (>0.5)", "Prob Map (Jet)", "Overlay + GT"]
    col_colors = ["#AAAAAA", "#FF8888", "#88AAFF", "#88FF88"]

    def finish(ax, title, color):
        """Set the panel title and hide axis decorations."""
        ax.set_title(title, color=color, fontsize=9)
        ax.axis("off")
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Each view cuts the central slice of a (D, H, W) array; sagittal and
    # coronal slices are transposed for display.
    views = [
        ("Axial", lambda a: a[cd, :, :]),
        ("Sagittal", lambda a: a[:, :, cw].T),
        ("Coronal", lambda a: a[:, ch, :].T),
    ]
    for row, (view_name, take) in enumerate(views):
        mri_sl, bin_sl, pn_sl = take(vol), take(binary), take(prob_n)
        gt_sl = take(gt) if gt is not None else None
        mri_rgb = np.stack([mri_sl] * 3, axis=-1)  # (H, W, 3)

        ax = axes[row, 0]
        ax.imshow(mri_sl, cmap="gray", vmin=0, vmax=1, origin="lower", aspect="auto")
        finish(ax, f"{view_name} - {col_labels[0]}", col_colors[0])

        ax = axes[row, 1]
        ax.imshow(mri_sl, cmap="gray", vmin=0, vmax=1, origin="lower", aspect="auto")
        if bin_sl.any():
            ax.imshow(bin_sl, cmap=TUMOR_CMAP, vmin=0, vmax=1,
                      origin="lower", aspect="auto", alpha=0.65)
        finish(ax, f"{view_name} - {col_labels[1]}", col_colors[1])

        ax = axes[row, 2]
        ax.imshow(mri_sl, cmap="gray", vmin=0, vmax=1, origin="lower", aspect="auto")
        ax.imshow(pn_sl, cmap="jet", vmin=0, vmax=1,
                  origin="lower", aspect="auto", alpha=0.5)
        finish(ax, f"{view_name} - {col_labels[2]}", col_colors[2])

        heatmap = cm.jet(pn_sl)[:, :, :3]
        overlay = np.clip(0.55 * mri_rgb + 0.45 * heatmap, 0, 1)
        ax = axes[row, 3]
        ax.imshow(overlay, origin="lower", aspect="auto")
        if gt_sl is not None:
            gt_bin = (gt_sl > 0.5).astype(float)
            ax.contour(gt_bin, levels=[0.5], colors=["#00FF88"],
                       linewidths=1.8, origin="lower")
        title3 = col_labels[3] + (" (green=GT)" if gt_sl is not None else "")
        finish(ax, f"{view_name} - {title3}", col_colors[3])

    plt.tight_layout(rect=[0, 0, 1, 0.93])
    stem = result["file"].replace(".nii.gz", "").replace(".nii", "")
    return _save_or_show(fig, save_path, f"tumor_{stem}")
# ============================================================================
# IMAGE: TISSUE
# ============================================================================
def visualize_tissue_prediction(result: dict, save_path: str = None) -> Optional[str]:
    """Render per-tissue probability panels (and ground truth, if present).

    For each tissue one row of four slices is drawn (MRI, jet map, medical
    map, MRI + prediction overlay); when ground truth is available a second
    row shows the mask over the MRI.

    Fixes relative to the previous revision:
      * Metrics were looked up by display name ("Gray Matter"), but the
        metric keys are "GrayMatter" / "WhiteMatter", so the Dice label and
        the ground-truth row were silently skipped for those tissues.
      * ``.nii.gz`` is now stripped completely from the default file name
        (it used to leave a trailing ``.gz``).
      * Garbled characters in the figure title.

    Args:
        result: Dict produced by ``HFSegPredictor.predict_tissue``.
        save_path: Output file; when ``None`` a timestamped name inside
            ``OUTPUT_DIR`` is used.

    Returns:
        Path of the written image, or ``None`` when ``result`` holds an error.
    """
    if "error" in result:
        print(f" Image is not generated: {result['error']}")
        return None

    vol = result["volume_resized"]  # (D, H, W)
    prob = result["prob_map"]       # (3, D, H, W)
    D = vol.shape[0]
    slices = _evenly(D, 4)
    # (display name, metrics key, overlay colormap, label colour, channel)
    tissues = [
        ("CSF", "CSF", CSF_CMAP, "#3399FF", 0),
        ("Gray Matter", "GrayMatter", GM_CMAP, "#33FF99", 1),
        ("White Matter", "WhiteMatter", WM_CMAP, "#FFAA33", 2),
    ]
    has_gt = "gt" in result
    metrics = result.get("metrics", {})
    rows_per_tissue = 1 + int(has_gt)
    n_rows = len(tissues) * rows_per_tissue

    fig = plt.figure(figsize=(4 * 4.5, n_rows * 3.0), facecolor=BG)
    fig.suptitle(f"Vbai-2.6TS · Tissue Segmentation\n{result['file']}",
                 color="white", fontsize=13, fontweight="bold", y=1.01)
    gs = gridspec.GridSpec(n_rows, 4, figure=fig, hspace=0.04, wspace=0.04)

    def strip(ax):
        """Hide ticks, tick labels and spines but keep the y-axis label."""
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        for spine in ax.spines.values():
            spine.set_visible(False)

    row = 0
    for name, metric_key, cmap, color, ti in tissues:
        pr = prob[ti]
        pr_n = (pr - pr.min()) / (pr.max() - pr.min() + 1e-8)
        dice_str = f" Dice: {metrics[metric_key]['Dice']:.4f}" if metric_key in metrics else ""
        for col, z in enumerate(slices):
            mri_sl = vol[z]
            pr_sl = pr_n[z]
            ax = fig.add_subplot(gs[row, col])
            if col == 0:
                ax.imshow(mri_sl.T, cmap="gray", origin="lower", aspect="auto")
                ax.set_ylabel(f"{name}{dice_str}", color=color,
                              fontsize=8, rotation=90, va="center", labelpad=4)
            elif col == 1:
                ax.imshow(pr_sl.T, cmap="jet", vmin=0, vmax=1,
                          origin="lower", aspect="auto")
                if row == 0:
                    ax.set_title("Jet Map", color="#AAAAAA", fontsize=9)
            elif col == 2:
                ax.imshow(pr_sl.T, cmap=MEDICAL_CMAP, vmin=0, vmax=1,
                          origin="lower", aspect="auto")
                if row == 0:
                    ax.set_title("Medical Map", color="#AAAAAA", fontsize=9)
            else:
                heatmap = cm.jet(pr_sl.T)[:, :, :3]
                mri_rgb = np.stack([mri_sl.T] * 3, axis=-1)
                overlay = np.clip(0.6 * mri_rgb + 0.4 * heatmap, 0, 1)
                ax.imshow(overlay, origin="lower", aspect="auto")
                if row == 0:
                    ax.set_title("MRI + Prediction", color="#AAAAAA", fontsize=9)
                if has_gt:
                    gt_sl = result["gt"][ti][z]
                    ax.contour((gt_sl > 0.5).astype(float).T, levels=[0.5],
                               colors=["#00FF88"], linewidths=1.2, origin="lower")
            strip(ax)

        if has_gt:
            row_gt = row + 1
            for col, z in enumerate(slices):
                ax = fig.add_subplot(gs[row_gt, col])
                if col == 0:
                    ax.imshow(vol[z].T, cmap="gray", origin="lower", aspect="auto")
                    ax.set_ylabel("Real Mask", color="#AAAAAA",
                                  fontsize=8, rotation=90, va="center", labelpad=4)
                else:
                    ax.imshow(vol[z].T, cmap="gray", origin="lower", aspect="auto", alpha=0.6)
                    ax.imshow(result["gt"][ti][z].T, cmap=cmap, vmin=0, vmax=1,
                              origin="lower", aspect="auto", alpha=0.7)
                strip(ax)
        row += rows_per_tissue

    plt.tight_layout(rect=[0.03, 0, 1, 1.0])
    stem = result["file"].replace(".nii.gz", "").replace(".nii", "")
    return _save_or_show(fig, save_path, f"tissue_{stem}")
def _save_or_show(fig, path, default_name: str) -> str:
    """Write ``fig`` to disk, optionally display it, then close it.

    Args:
        fig: Matplotlib figure to save.
        path: Target file; when ``None`` the file is written to
            ``OUTPUT_DIR`` as ``<default_name>_<YYYYmmdd_HHMMSS>.png``.
        default_name: File-name stem used when ``path`` is ``None``.

    Returns:
        The path the figure was written to.
    """
    if path is None:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = os.path.join(OUTPUT_DIR, f"{default_name}_{ts}.png")
    # A caller-supplied path may point into a directory that does not exist
    # yet; create it instead of failing inside savefig.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)
    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG)
    print(f" Image -> {path}")
    if _SHOW_LIVE:
        plt.show()
    plt.close(fig)
    return path
# ============================================================================
# TEST FUNCTIONS
# ============================================================================
def test_tumor_volumes(predictor: HFSegPredictor,
                       volumes: List[Dict],
                       visualize: bool = True) -> List[dict]:
    """Run tumor segmentation on every listed volume and print a report.

    Same approach as the ``test_3d_volumes()`` function of HF-v2.

    Args:
        predictor: Loaded predictor.
        volumes: Dicts with ``"flair"`` and ``"t1c"`` paths and an optional
            ``"mask"`` path. Entries whose FLAIR file does not exist are
            skipped.
        visualize: Also render a figure for every result.

    Returns:
        One result dict per processed volume (error results included); an
        empty list when no entry has an existing FLAIR file.
    """
    print("\n" + "=" * 70)
    print(" TUMOR SEGMENTATION TEST")
    print("=" * 70)
    valid = [v for v in volumes if os.path.exists(v.get("flair", ""))]
    if not valid:
        print(" No valid tumor volume path found!")
        print(" -> Update list TUMOR_VOLUMES.")
        return []
    print(f" {len(valid)} volume test ediliyor...\n")

    results = []
    for i, vol in enumerate(valid, 1):
        name = os.path.basename(vol["flair"])
        print(f" [{i}/{len(valid)}] {name}")
        result = predictor.predict_tumor(
            vol["flair"], vol["t1c"],
            vol.get("mask"), use_tta=True
        )
        results.append(result)
        if "error" in result:
            print(f" ERR: {result['error']}")
        else:
            print(f" Süre : {result['elapsed']:.1f}s")
            det = result.get("detection", {})
            if det:
                if det["detected"]:
                    print(" [!] TUMOR DETECTED")
                    print(f" Vol : ~{det['volume_cm3']} cm³ ({det['volume_vx']} voxel)")
                    print(f" Conf : ort {det['mean_confidence']:.1%} | maks {det['max_confidence']:.1%}")
                    print(" (Note: A separate classification model is required for stage prediction)")
                else:
                    print(f" [OK] No Tumor Detected (maks prob: {det['max_confidence']:.1%})")
            if "metrics" in result:
                m = result["metrics"]
                print(f" Dice : {m['Dice']:.4f} IoU: {m['IoU']:.4f} "
                      f"Recall: {m['Recall']:.4f} HD95: {m['HD95 (vx)']}")
                for key in ("Dice", "IoU", "Recall", "F2-Score"):
                    _print_prob_bar(key, m[key])
            else:
                print(" (Metrics is not calculated)")
        # visualize_tumor_prediction reports and skips error results itself.
        if visualize:
            visualize_tumor_prediction(result)
        print()
    return results
def test_tissue_volumes(predictor: HFSegPredictor,
                        volumes: List[Dict],
                        visualize: bool = True) -> List[dict]:
    """Run tissue segmentation on every listed volume and print a report.

    Args:
        predictor: Loaded predictor.
        volumes: Dicts with a ``"t1"`` path and optional ``"mask_csf"``,
            ``"mask_gm"``, ``"mask_wm"`` paths. Entries whose T1 file does
            not exist are skipped.
        visualize: Also render a figure for every result.

    Returns:
        One result dict per processed volume (error results included); an
        empty list when no entry has an existing T1 file.
    """
    print("\n" + "=" * 70)
    print(" TISSUE SEGMENTATION TEST")
    print("=" * 70)
    valid = [v for v in volumes if os.path.exists(v.get("t1", ""))]
    if not valid:
        print(" The valid tissue volume path could not be found!")
        print(" -> Update list TISSUE_VOLUMES.")
        return []
    print(f" {len(valid)} volume testing...\n")

    results = []
    for i, vol in enumerate(valid, 1):
        name = os.path.basename(vol["t1"])
        print(f" [{i}/{len(valid)}] {name}")
        result = predictor.predict_tissue(
            vol["t1"],
            vol.get("mask_csf"), vol.get("mask_gm"), vol.get("mask_wm"),
            use_tta=True
        )
        results.append(result)
        if "error" in result:
            print(f" ERR: {result['error']}")
        else:
            print(f" Duration : {result['elapsed']:.1f}s")
            if "metrics" in result:
                m = result["metrics"]
                print(f" {'Tissue':<14} {'Dice':>7} {'IoU':>7} {'MSE':>9}")
                print(" " + "-" * 38)
                for t in ["CSF", "GrayMatter", "WhiteMatter"]:
                    print(f" {t:<14} {m[t]['Dice']:>7.4f} {m[t]['IoU']:>7.4f} {m[t]['MSE']:>9.6f}")
                # Same rule as above (the original glyph was garbled).
                print(" " + "-" * 38)
                print(f" {'Mean Dice':<14} {m['Mean Dice']:>7.4f} "
                      f"Mean IoU: {m['Mean IoU']:.4f}")
            else:
                print(" (Metrics is not calculated)")
        # visualize_tissue_prediction reports and skips error results itself.
        if visualize:
            visualize_tissue_prediction(result)
        print()
    return results
def _print_prob_bar(label: str, v, width: int = 20):
if not isinstance(v, float) or not (0 <= v <= 1): return
bar = "β" * int(v*width) + "β" * (width - int(v*width))
print(f" {label:<12} {bar} {v:.4f}")
# ============================================================================
# SUMMARY REPORT
# ============================================================================
def print_summary(results_tumor: list, results_tissue: list):
    """Print an aggregated console report for the tumor and tissue test runs.

    Results that carry an "error" key or lack a "metrics" key are ignored.
    Tumor metrics are summarised as min / max / mean / std over the volumes
    whose value for that metric is a float. Tissue metrics are summarised per
    tissue class as mean Dice, Dice std, mean IoU and mean MSE, followed by
    the average of the per-volume "Mean Dice".

    Args:
        results_tumor: Result dicts collected from the tumor test run.
        results_tissue: Result dicts collected from the tissue test run.
    """
    print("\n" + "=" * 80)
    print(" TEST REPORT")
    print("=" * 80)
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f" Date: {ts}\n")

    # -- Tumor summary --
    valid_t = [r for r in results_tumor if "error" not in r and "metrics" in r]
    if valid_t:
        print(" TUMOR SEGMENTATION")
        print(" " + "-" * 76)
        print(f" {'Metric':<20} {'Min':>8} {'Max':>8} {'Average':>10} {'Std':>8}")
        print(" " + "-" * 56)
        # (metric key, decimals): the ratio metrics print with four decimals,
        # HD95 with two.
        rows = [(k, 4) for k in ("Dice", "IoU", "Precision", "Recall",
                                 "F2-Score", "Specificity", "Vol.Sim")]
        rows.append(("HD95 (vx)", 2))
        for key, prec in rows:
            # Non-float entries (e.g. placeholder strings) are left out of the statistics.
            vals = [r["metrics"][key] for r in valid_t
                    if isinstance(r["metrics"].get(key), float)]
            if vals:
                print(f" {key:<20} {min(vals):>8.{prec}f} {max(vals):>8.{prec}f} "
                      f"{np.mean(vals):>10.{prec}f} {np.std(vals):>8.{prec}f}")
        print()

    # -- Tissue summary --
    valid_s = [r for r in results_tissue if "error" not in r and "metrics" in r]
    if valid_s:
        print(" TISSUE SEGMENTATION")
        print(" " + "-" * 76)
        print(f" {'Tissue':<16} {'Dice Avg':>10} {'Dice Std':>10} {'IoU Avg':>10} {'MSE Avg':>10}")
        print(" " + "-" * 58)
        for tissue in ("CSF", "GrayMatter", "WhiteMatter"):
            per_volume = [r["metrics"][tissue] for r in valid_s]
            dices = [t["Dice"] for t in per_volume]
            ious = [t["IoU"] for t in per_volume]
            mses = [t["MSE"] for t in per_volume]
            print(f" {tissue:<16} {np.mean(dices):>10.4f} {np.std(dices):>10.4f} "
                  f"{np.mean(ious):>10.4f} {np.mean(mses):>10.6f}")
        mean_dices = [r["metrics"]["Mean Dice"] for r in valid_s]
        print(f" {'─' * 58}")
        print(f" {'AVG':<16} {np.mean(mean_dices):>10.4f}\n")
    print("=" * 80)
# ============================================================================
# METRIC BAR CHART
# ============================================================================
def plot_metrics_chart(results: list, task: str, save_path: Optional[str] = None) -> Optional[str]:
    """Render per-volume metric bar charts for one task and save them as a PNG.

    Results that carry an "error" key or lack a "metrics" key are ignored.
    For ``task == "tumor"`` one panel is drawn per tumor metric; any other
    task value draws a Dice / IoU grid over the three tissue classes.

    Args:
        results: Result dicts; each usable one needs "file" and "metrics".
        task: "tumor" selects the tumor layout, anything else the tissue layout.
        save_path: Destination file. When None, a timestamped name inside
            OUTPUT_DIR is used.

    Returns:
        The path the figure was written to, or None when there is nothing to plot.
    """
    valid = [r for r in results if "error" not in r and "metrics" in r]
    if not valid:
        return None

    # Both layouts label rows by file name; the directory part would use up
    # the 20-character budget without identifying the volume.
    names = [os.path.basename(r["file"])[:20] for r in valid]

    if task == "tumor":
        keys = ["Dice", "IoU", "Precision", "Recall", "F2-Score", "Specificity", "Vol.Sim"]
        colors = ["#4488FF", "#44BBFF", "#FF8844", "#44FF88", "#FFAA44", "#AA44FF", "#FF4488"]
        n_bars = len(keys)
        fig, axes = plt.subplots(1, n_bars,
                                 figsize=(n_bars * 2.8, max(4, len(valid) * 0.6 + 2)),
                                 facecolor=BG)
        fig.suptitle("Tümör Segmentasyonu — Metrik Karşılaştırması",
                     color="white", fontsize=13, fontweight="bold")
        for ax, key, color in zip(axes, keys, colors):
            # Missing or non-float metrics are drawn as a zero-length bar.
            raw = [r["metrics"].get(key, 0) for r in valid]
            vals = [v if isinstance(v, float) else 0 for v in raw]
            _draw_metric_bars(ax, names, vals, color, key, show_names=(key == keys[0]))
    else:  # tissue
        tissues = ["CSF", "GrayMatter", "WhiteMatter"]
        t_colors = ["#3399FF", "#33FF99", "#FFAA33"]
        metric_keys = ["Dice", "IoU"]
        fig, axes = plt.subplots(len(metric_keys), len(tissues),
                                 figsize=(len(tissues) * 4, len(metric_keys) * 3.5),
                                 facecolor=BG)
        fig.suptitle("Tissue Segmentation — Per-Subject Metrics",
                     color="white", fontsize=13, fontweight="bold")
        for ri, metric in enumerate(metric_keys):
            for ci, (tissue, color) in enumerate(zip(tissues, t_colors)):
                vals = [r["metrics"][tissue][metric] for r in valid]
                _draw_metric_bars(axes[ri, ci], names, vals, color,
                                  f"{tissue} — {metric}", show_names=(ci == 0))

    try:
        plt.tight_layout()
        if save_path is None:
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = os.path.join(OUTPUT_DIR, f"metrics_{task}_{ts}.png")
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=BG)
        print(f" Metrics graph → {save_path}")
        if _SHOW_LIVE:
            plt.show()
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
    return save_path


def _draw_metric_bars(ax, names, vals, color, title, show_names):
    """Draw one horizontal-bar panel with a value label next to each bar."""
    y = np.arange(len(names))
    bars = ax.barh(y, vals, color=color, alpha=0.85)
    ax.set_xlim(0, 1.05)
    ax.set_yticks(y)
    ax.set_yticklabels(names if show_names else [], color="#AAAAAA", fontsize=7)
    ax.set_title(title, color="white", fontsize=9)
    ax.set_facecolor(BG)
    for bar, v in zip(bars, vals):
        # Clamp the label position so it stays inside the fixed x-range.
        ax.text(min(v + 0.02, 1.0), bar.get_y() + bar.get_height() / 2,
                f"{v:.3f}", va="center", color="white", fontsize=7)
    ax.tick_params(colors="#AAAAAA")
    ax.spines[:].set_color("#333333")
# ============================================================================
# PROFESSIONAL EVALUATION
# ============================================================================
def run_professional_evaluation(predictor: HFSegPredictor,
                                tumor_vols: Optional[List[Dict]] = None,
                                tissue_vols: Optional[List[Dict]] = None,
                                visualize: bool = True):
    """Run the selected test suites and write the chart, TXT and JSON reports.

    Args:
        predictor: Predictor whose ``model`` attribute is summarised and which
            is handed to the tumor / tissue test functions.
        tumor_vols: Tumor volume descriptors; the tumor run is skipped when
            None or empty.
        tissue_vols: Tissue volume descriptors; the tissue run is skipped
            when None or empty.
        visualize: Forwarded to the test functions.

    Returns:
        A ``(results_tumor, results_tissue)`` tuple of result lists; a list
        is empty when its suite was skipped.
    """
    print("\n" + "=" * 80)
    print(" Vbai-2.6TS · Professional Evaulate")
    print(" " + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print("=" * 80)
    print_model_summary(predictor.model)

    # Charts and both reports are written below OUTPUT_DIR; create it up front
    # so a missing directory cannot abort the run after inference has finished.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    results_tumor = []
    results_tissue = []
    if tumor_vols:
        results_tumor = test_tumor_volumes(predictor, tumor_vols, visualize=visualize)
        if any("metrics" in r for r in results_tumor):
            plot_metrics_chart(results_tumor, "tumor")
    if tissue_vols:
        results_tissue = test_tissue_volumes(predictor, tissue_vols, visualize=visualize)
        if any("metrics" in r for r in results_tissue):
            plot_metrics_chart(results_tissue, "tissue")

    print_summary(results_tumor, results_tissue)
    # TXT report save
    _save_txt_report(results_tumor, results_tissue)

    # JSON save: failed tumor volumes are dropped, tissue volumes are kept
    # only when they have metrics.
    all_results = {"tumor": [], "tissue": []}
    for r in results_tumor:
        if "error" in r:
            continue
        entry = {"file": r["file"]}
        for key in ("detection", "metrics"):
            if key in r:
                entry[key] = r[key]
        all_results["tumor"].append(entry)
    all_results["tissue"] = [{"file": r["file"], "metrics": r["metrics"]}
                             for r in results_tissue if "metrics" in r]

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    json_path = os.path.join(OUTPUT_DIR, f"evaluation_{stamp}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, default=_json_default)
    print(f"\n JSON Report → {json_path}")
    return results_tumor, results_tissue


def _json_default(obj):
    """json.dump fallback that converts NumPy scalars and arrays to native Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def _save_txt_report(results_tumor, results_tissue):
    """Write a plain-text, per-volume report into OUTPUT_DIR.

    Tumor results that carry an "error" key are left out; tissue results are
    included only when they have a "metrics" key. The file is named
    ``report_<timestamp>.txt`` and encoded as UTF-8.

    Args:
        results_tumor: Result dicts from the tumor test run.
        results_tissue: Result dicts from the tissue test run.
    """
    now = datetime.now()
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    path = os.path.join(OUTPUT_DIR, f"report_{now.strftime('%Y%m%d_%H%M%S')}.txt")
    lines = [
        "=" * 80,
        f" Vbai-2.6TS Segmentation Test Report — {now.strftime('%Y-%m-%d %H:%M:%S')}",
        "=" * 80, "",
    ]

    tumor_all = [r for r in results_tumor if "error" not in r]
    if tumor_all:
        lines += ["TUMOR SEGMENTATION", "-" * 40]
        for r in tumor_all:
            lines.append(f" {r['file']}")
            det = r.get("detection", {})
            if det:
                # .get: a detection dict without the flag is reported as "not detected".
                detected = bool(det.get("detected"))
                lines.append(f" Tumor State : {'DETECTED' if detected else 'NOT DETECTED'}")
                if detected:
                    lines.append(f" Volume : ~{det['volume_cm3']} cm³ ({det['volume_vx']} voxel)")
                    lines.append(f" Conf (avg) : {det['mean_confidence']:.1%}")
                    lines.append(f" Conf (max) : {det['max_confidence']:.1%}")
                    lines.append(" (Note: Stage prediction requires a separate classification model)")
            if "metrics" in r:
                lines.extend(f" {k:<20}: {v}" for k, v in r["metrics"].items())
            lines.append("")

    valid_s = [r for r in results_tissue if "metrics" in r]
    if valid_s:
        lines += ["TISSUE SEGMENTATION", "-" * 40]
        for r in valid_s:
            lines.append(f" {r['file']}")
            for tissue in ("CSF", "GrayMatter", "WhiteMatter"):
                lines.append(f" {tissue}: {r['metrics'][tissue]}")
            lines.append(f" Mean Dice: {r['metrics']['Mean Dice']}")
            lines.append("")

    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f" TXT raporu → {path}")
# ============================================================================
# CLI
# ============================================================================
def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the test script.

    Args:
        argv: Argument list to parse. None (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        An ``argparse.Namespace`` with ``task``, ``checkpoint``, ``no_vis``
        and ``show``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Vbai-2.6TS Professional Test Script",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--task", choices=["tumor", "tissue", "both"], default="both")
    parser.add_argument("--checkpoint", default=CHECKPOINT_PATH)
    parser.add_argument("--no-vis", action="store_true", help="Dont generate image")
    parser.add_argument("--show", action="store_true", help="Open image on the screen")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the checkpoint and run the
    # suites selected with --task.
    args = parse_args()
    # NOTE(review): --show is parsed but nothing in the visible code reads it,
    # while plot_metrics_chart gates plt.show() on the module-level _SHOW_LIVE.
    # The flag can only switch live display on here, never off, so an earlier
    # True value is preserved. Confirm how _SHOW_LIVE is initialised.
    _SHOW_LIVE = bool(_SHOW_LIVE) or args.show
    predictor = HFSegPredictor(args.checkpoint)
    run_both = args.task == "both"
    run_professional_evaluation(
        predictor,
        tumor_vols=TUMOR_VOLUMES if run_both or args.task == "tumor" else None,
        tissue_vols=TISSUE_VOLUMES if run_both or args.task == "tissue" else None,
        visualize=not args.no_vis,
    )
Requirements
- Python ≥ 3.9
- PyTorch ≥ 2.0
- CUDA-capable GPU, ≥ 8 GB VRAM recommended (tested on an NVIDIA RTX 5060 with 8 GB of VRAM)
- See `requirements.txt` for the full dependency list
License
CC-BY-NC-SA 4.0 - see LICENSE file for details.
Support
- Website: Neurazum - HealFuture
- Email: contact@neurazum.com