# 3DCine — utils/process_utils.py
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_mri as tfmri
import tqdm
import os
import pydicom as dicom
import glob
from utils.unet3plusnew import *
from utils.custom_unet_code import *
from pydicom.tag import Tag
import imageio
from PIL import Image
import streamlit as st
import zipfile
from collections import defaultdict
import re
from pydicom.tag import Tag
from collections import defaultdict, Counter
import io, shutil, zipfile, time
# Resizes image
def resize(t1, x, y):
    """Append a channel axis to *t1*, then center-crop or zero-pad it to (x, y)."""
    with_channels = tf.expand_dims(t1, -1)  # tf.image ops expect a channels-last axis
    return tf.image.resize_with_crop_or_pad(with_channels, x, y)
# Function that normalises image
def norm(t1):
    """Shift *t1* so its minimum is 0, then divide by the ORIGINAL maximum.

    NOTE(review): this is (x - min) / max, not standard min-max scaling
    (x - min) / (max - min); the result only spans [0, 1] when min(t1) == 0,
    and it divides by zero when max(t1) == 0. Kept as-is because the trained
    models downstream presumably expect this scaling — confirm before changing.
    """
    lo = np.min(t1)
    hi = np.max(t1)  # computed on the un-shifted input, matching the original
    return (t1 - lo) / hi
# Applies debanding model to any number of slices
def apply_debanding_model(input_im, frames=32):
    """Run the pretrained debanding UNet over the first *frames* volumes.

    The saved model is loaded only to harvest its weights; a fresh UNet with
    fully dynamic spatial dimensions is rebuilt so volumes of any size can be
    processed.

    Parameters
    ----------
    input_im : sequence where input_im[i] is one 3-D volume (batch/channel
        dims are added here before predict).
    frames : number of leading entries of input_im to process (default 32).

    Returns
    -------
    list of per-frame model predictions.
    """
    # FIX: removed the dead local `debanding_model = "./models_final/Deband_model"`
    # — it was never used; the model is loaded from MODELS_BASE like every
    # other model in this module.
    debanding = tf.keras.models.load_model(os.path.join(MODELS_BASE, "Deband_model"), compile=False)
    weights = debanding.get_weights()
    inputs = tf.keras.Input(shape=[None, None, None, 1])
    unet = tfmri.models.UNet3D([32, 64, 128], kernel_size=3, out_channels=1, use_global_residual=False)
    DB = unet(inputs)
    de_banding_model = tf.keras.Model(inputs=inputs, outputs=DB)
    de_banding_model.set_weights(weights)
    de_banded = []
    for i in range(frames):
        # Add batch and channel dims: (Z, H, W) -> (1, Z, H, W, 1)
        temp = de_banding_model.predict(tf.expand_dims(tf.expand_dims(input_im[i], 0), -1), verbose=0)
        de_banded.append(temp)
    return de_banded
# Function that applies deformations to 28 slice data
def deformation_28(x):
    """Warp 28 sagittal slices with per-slice dense displacement fields.

    Used as the body of a Keras Lambda layer (see apply_resp_model_28).

    Parameters
    ----------
    x : sequence of three tensors, each with a leading batch dim of 1 —
        x[0]: image volume, x[1]: y-displacements, x[2]: x-displacements.
        Assumed layout (1, 28, H, W) for all three — TODO confirm with caller.

    Returns
    -------
    Tensor of shape (1, 28, H, W, 1): every slice warped by its field.
    """
    sagittal_deformed = []
    for i in range(28):
        # Slice i of the volume with a trailing channel axis: (H, W, 1)
        input_img = tf.expand_dims(x[0][0, i, :, :], -1)
        # Per-slice displacement components, each expanded to (1, H, W, 1)
        dy = tf.expand_dims(tf.expand_dims(x[1][0, i, :, :], -1), 0)
        dx = tf.expand_dims(tf.expand_dims(x[2][0, i, :, :], -1), 0)
        # Drop the batch dim again and stack (dy, dx) on the last axis -> (H, W, 2)
        displacement = tf.concat((dy[0, ...], dx[0, ...]), axis=-1)
        img = tf.image.convert_image_dtype(tf.expand_dims(input_img, 0), tf.dtypes.float32)
        displacement = tf.image.convert_image_dtype(displacement, tf.dtypes.float32)
        dense_img_warp = tfa.image.dense_image_warp(img, displacement)
        # Remove the batch dim before collecting the warped slice
        im_deformed = tf.squeeze(dense_img_warp, 0)
        sagittal_deformed.append(im_deformed)
    # Stack the 28 warped slices and restore a leading batch dim
    sagittal_deformed = tf.image.convert_image_dtype(sagittal_deformed, tf.dtypes.float32)
    sagittal_deformed = tf.expand_dims(sagittal_deformed, axis=0)
    return sagittal_deformed
# Applies respiratory correction model
def apply_resp_model_28(input_im, frames=32):
    """Apply the pretrained respiratory-correction model frame by frame.

    Rebuilds the registration network (deformation-field UNet followed by a
    warping Lambda that calls deformation_28), loads its saved weights, and
    runs each frame through it.

    Parameters
    ----------
    input_im : sequence where input_im[i] is one 5-D batch with spatial size
        256x128 and a single channel — presumably (1, 28, 256, 128, 1);
        TODO confirm the slice count against callers.
    frames : number of frames of input_im to process (default 32).

    Returns
    -------
    (deformations, resp_corrected) : two lists of per-frame model outputs —
    the predicted deformation fields and the warped (corrected) volumes.
    """
    inputs = tf.keras.Input(shape=[None, 256, 128, 1])
    unet = build_3d_unet_resp([None, 256, 128, 1], 2)  # Acts as a deformation field generator
    deformation_fields = unet(inputs)  # Outputs the deformation fields
    lambda_deformation = tf.keras.layers.Lambda(deformation_28)
    out_2 = lambda_deformation([inputs[:, :, :, :, 0], deformation_fields[:, :, :, :, 0], deformation_fields[:, :, :, :, 1]])  # Outputs the deformed volume
    outputs = [deformation_fields, out_2]
    complete_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    # Weights are restored straight from the SavedModel's variables checkpoint
    complete_model.load_weights(os.path.join(MODELS_BASE, "Resp_Correction_model", "variables", "variables"))
    resp_corrected = []
    deformations = []
    for i in range(frames):
        def_fields, resp_cor = complete_model.predict(input_im[i][:, :, :, :, :], verbose=0)
        resp_corrected.append(resp_cor)
        deformations.append(def_fields)
    return deformations, resp_corrected
# Applies super resolution model
def apply_SR_model(input_im, frames=32):
    """Apply the super-resolution stage of the end-to-end model per frame.

    Loads the full E2E SavedModel but keeps only the tail of its weight list
    for a freshly built SR UNet with fully dynamic spatial dimensions.

    Parameters
    ----------
    input_im : sequence where input_im[i] is one batched volume already in
        the shape model.predict expects — TODO confirm (1, Z, H, W, 1).
    frames : number of leading entries of input_im to process (default 32).

    Returns
    -------
    list of per-frame super-resolved predictions.
    """
    E2E_model = os.path.join(MODELS_BASE, "E2E_SR_model")
    E2E = tf.keras.models.load_model(E2E_model, compile=False)
    weights = E2E.get_weights()
    # NOTE(review): weights[:22] presumably belong to an earlier stage of the
    # end-to-end model; the SR sub-network's weights start at index 22 —
    # verify if the saved architecture ever changes.
    sr_weights = weights[22:]
    inputs = tf.keras.Input(shape=[None, None, None, 1])
    SR_model = build_3d_unet(input_shape=(None, None, None, 1), num_classes=1)
    SR = SR_model(inputs)
    SR_model_done = tf.keras.Model(inputs=inputs, outputs=SR)
    SR_model_done.set_weights(sr_weights)
    super_resed = []
    for i in range(frames):
        super_resed.append(SR_model_done.predict(input_im[i], verbose=0))
    return super_resed
# Private DICOM tag (0019,10D7) read by load_data_samples below; a value of
# 30 marks the end of one sagittal stack — presumably a per-stack slice
# counter from the acquisition protocol; verify against the scanner export.
t = Tag(0x0019, 0x10D7)
# Reads in example RT sagittal stack
def load_data_samples(path_to_data, stack_end_value=30):
    """Load an example real-time sagittal DICOM stack from a folder.

    Files are read in sorted filename order and accumulated until the private
    tag ``t`` (0019,10D7) equals *stack_end_value*, which closes one stack.

    Parameters
    ----------
    path_to_data : directory containing the DICOM files.
    stack_end_value : tag value that marks the last file of a stack
        (default 30, matching the original acquisition).

    Returns
    -------
    4-D array transposed so the per-stack slice axis comes first —
    presumably (slice, stack, H, W); confirm against callers.

    Raises
    ------
    FileNotFoundError : if *path_to_data* does not exist.
    """
    # FIX: raise the specific FileNotFoundError (still caught by callers
    # handling Exception) and removed the unused local `sag_volumes`.
    if not os.path.exists(path_to_data):
        raise FileNotFoundError(f"Error with file path: {path_to_data}")
    clean_ims_1 = []
    locations_1 = []
    clean_ims_final = []
    locations_final = []
    for file in sorted(glob.glob(f"{path_to_data}/*")):
        ds = dicom.dcmread(file)
        locations_1.append(ds.SliceLocation)
        clean_ims_1.append(ds.pixel_array)
        if ds[t].value == stack_end_value:
            # This tag value marks the final slice of the current stack.
            clean_ims_final.append(np.array(clean_ims_1))
            locations_final.append(locations_1)
            clean_ims_1 = []
            locations_1 = []
    final = np.array(clean_ims_final)
    # (stack, slice, H, W) -> (slice, stack, H, W)
    final = np.transpose(final, (1, 0, 2, 3))
    return final
def load_data_samples_from_folder(base_dir, number_of_scans=32):
    """
    Recursively find all DICOM files under the first valid subfolder of base_dir,
    group them by InstanceNumber (time), sort by SliceLocation (z), and return
    a NumPy array of shape (time, z, H, W).
    """
    # 1. Find the real nested folder (skip macOS junk)
    candidates = [
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and not d.startswith("._")
        and "__MACOSX" not in d
    ]
    if not candidates:
        st.error("No valid data folder found in ZIP.")
        return np.array([])
    nested_base = os.path.join(base_dir, candidates[0])
    # 2. Recursively collect every file; we'll filter for DICOMs next
    all_paths = glob.glob(os.path.join(nested_base, "**", "*"), recursive=True)
    all_paths = [p for p in all_paths if os.path.isfile(p) and not os.path.basename(p).startswith("._")]
    # 3. Filter valid DICOMs
    dicom_files = []
    for p in all_paths:
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            if hasattr(ds, "InstanceNumber"):
                dicom_files.append(p)
        # FIX: was a bare `except:`, which also swallows KeyboardInterrupt
        # and SystemExit; only ignore genuine parse failures.
        except Exception:
            continue
    st.write(f"🧾 Found {len(dicom_files)} DICOM files.")
    if not dicom_files:
        st.error("No valid DICOMs found.")
        return np.array([])
    # 4. Group by InstanceNumber (temporal frames)
    grouped = defaultdict(list)
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            inst = ds.InstanceNumber
            loc = getattr(ds, "SliceLocation", 0.0)
            grouped[inst].append((loc, ds.pixel_array))
        except Exception:  # FIX: narrowed from bare `except:`
            continue
    # 5. Build volume up to number_of_scans frames
    vols = []
    for inst in sorted(grouped.keys())[:number_of_scans]:
        slices = grouped[inst]
        # sort along z
        slices.sort(key=lambda x: x[0])
        vols.append([img for _, img in slices])
    volume = np.array(vols)  # shape (T, Z, H, W)
    st.write(f"✅ Found data shape: {volume.shape}")
    return volume
def load_cine_any(
    base_dir: str,
    number_of_scans: int = None,  # if None, use all detected phases
    private_phase_tag: Tag = Tag(0x0019, 0x10D7),  # your private phase tag (if present)
    verbose: bool = True
):
    """
    Universal DICOM cine loader (flat or nested folders).
    Scans recursively from `base_dir`, detects cardiac phases, sorts slices,
    and returns (T, Z, H, W) along with the total number of phases detected.
    Returns:
        volume: np.ndarray with shape (T, Z, H, W)
        num_phases_detected: int (total phases found in the dataset)
    """
    # Log through Streamlit when a session is running, else fall back to stdout.
    def log(msg):
        if verbose:
            try:
                st.write(msg)
            except Exception:
                print(msg)
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"No such directory: {base_dir}")
    # --- Collect candidate files (recursive), skip junk/zips
    candidates = glob.glob(os.path.join(base_dir, "**", "*"), recursive=True)
    candidates = [
        p for p in candidates
        if os.path.isfile(p)
        and "__MACOSX" not in p
        and not os.path.basename(p).startswith("._")
        and not p.lower().endswith(".zip")
    ]
    if not candidates:
        log("No files found under the provided directory.")
        return np.array([]), 0
    # --- Keep only files that parse as DICOM headers
    dicom_files = []
    for p in candidates:
        try:
            # Parse headers only (no pixel data) just to test readability.
            _ = dicom.dcmread(p, force=True, stop_before_pixels=True)
            dicom_files.append(p)
        except Exception:
            pass
    log(f"🧾 Candidate DICOM files: {len(dicom_files)}")
    if not dicom_files:
        log("No valid DICOM files found.")
        return np.array([]), 0
    # --- NEW: detect flat folder layout (all files in the same directory)
    dicom_dirs = {os.path.dirname(p) for p in dicom_files}
    is_flat = (len(dicom_dirs) == 1)
    # --- Probe to choose the best phase key
    def _try_get(ds, tag):
        # Value of `tag` in dataset `ds`, or None if absent/unreadable.
        try:
            return ds[tag].value
        except Exception:
            return None
    # Sample up to 200 headers to find which phase identifier actually varies.
    uniq_priv, uniq_tpi, uniq_inst = set(), set(), set()
    for p in dicom_files[:min(len(dicom_files), 200)]:
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            v_priv = _try_get(ds, private_phase_tag)
            if v_priv is not None:
                try:
                    uniq_priv.add(int(v_priv))
                except Exception:
                    pass
            if hasattr(ds, "TemporalPositionIdentifier"):
                try:
                    uniq_tpi.add(int(ds.TemporalPositionIdentifier))
                except Exception:
                    pass
            if hasattr(ds, "InstanceNumber"):
                try:
                    uniq_inst.add(int(ds.InstanceNumber))
                except Exception:
                    pass
        except Exception:
            continue
    # Preference order: private tag > TemporalPositionIdentifier > InstanceNumber.
    # A key is usable only if it took more than one distinct value in the probe.
    if len(uniq_priv) > 1:
        phase_key = ("private", private_phase_tag)
    elif len(uniq_tpi) > 1:
        phase_key = ("tpi", None)
    elif len(uniq_inst) > 1:
        phase_key = ("instance", None)
    else:
        log("Could not determine a phase key (no variation in private/TPI/InstanceNumber).")
        return np.array([]), 0
    def _get_phase(ds):
        # Integer phase index for one dataset using the key chosen above, or None.
        if phase_key[0] == "private":
            v = _try_get(ds, phase_key[1])
            return int(v) if v is not None else None
        if phase_key[0] == "tpi":
            return int(getattr(ds, "TemporalPositionIdentifier", None)) \
                if hasattr(ds, "TemporalPositionIdentifier") else None
        if phase_key[0] == "instance":
            return int(getattr(ds, "InstanceNumber", None)) \
                if hasattr(ds, "InstanceNumber") else None
        return None
    def _get_z(ds):
        # z position: SliceLocation, else ImagePositionPatient[2], else 0.0.
        z = getattr(ds, "SliceLocation", None)
        if z is None:
            ipp = getattr(ds, "ImagePositionPatient", None)
            if ipp is not None and len(ipp) >= 3:
                try:
                    z = float(ipp[2])
                except Exception:
                    z = 0.0
            else:
                z = 0.0
        return float(z)
    # --- Group by phase; sort by z
    grouped = defaultdict(list)
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            ph = _get_phase(ds)
            if ph is None:
                continue
            grouped[int(ph)].append((_get_z(ds), ds.pixel_array))
        except Exception:
            continue
    if not grouped:
        log("No groups formed (no phase could be read).")
        return np.array([]), 0
    all_phase_ids = sorted(grouped.keys())
    num_phases_detected = len(all_phase_ids)
    phases_to_use = all_phase_ids if number_of_scans is None else all_phase_ids[:number_of_scans]
    stacks_T, slice_counts = [], []
    for ph in phases_to_use:
        pairs = grouped[ph]
        if not pairs:
            continue
        pairs.sort(key=lambda x: x[0])  # sort by z
        stack = [img for _, img in pairs]  # Z × H × W
        stacks_T.append(stack)
        slice_counts.append(len(stack))
    if not stacks_T:
        log("Groups existed but none had readable slices.")
        return np.array([]), num_phases_detected
    # Harmonize Z across phases (trim to the most common Z)
    if len(set(slice_counts)) > 1:
        common_Z = Counter(slice_counts).most_common(1)[0][0]
        stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
        if not stacks_T:
            log("All phases had inconsistent slice counts.")
            return np.array([]), num_phases_detected
    volume = np.array(stacks_T)  # (T, Z, H, W)
    # --- NEW: flip slice order (Z) if data came from a flat single folder
    if is_flat:
        volume = volume[:, ::-1, :, :]
    # Note: logs the per-phase (Z, H, W) shape, not the full 4-D shape.
    log(f"✅ Final volume shape: {volume[0, ...].shape} , Phases detected = {num_phases_detected}")
    return volume, num_phases_detected
def load_data_samples_from_flat_folder(
    base_dir: str,
    number_of_scans: int = 32,
    frame_tag: Tag = Tag(0x0019, 0x10D7)  # private phase tag (adjust if needed)
) -> np.ndarray:
    """
    Robust loader when all DICOMs are under one folder (possibly nested).
    - Steps into the single subfolder if present (ignores upload.zip, macOS junk).
    - Recursively finds DICOMs (even without .dcm extension).
    - Groups by phase from `frame_tag` or fallback (0020,0100).
    - Sorts by SliceLocation/IPPs and returns (T, Z, H, W).
      (The array is built phase-first; an earlier docstring claimed (Z, T, H, W).)
    """
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"No such directory: {base_dir}")
    # --- Step 1: if there’s exactly one subfolder (plus upload.zip), dive into it
    entries = [e for e in os.listdir(base_dir) if not e.startswith("._")]
    subdirs = [os.path.join(base_dir, e) for e in entries
               if os.path.isdir(os.path.join(base_dir, e)) and "__MACOSX" not in e]
    # If precisely one subdir, prefer that as root; otherwise use base_dir as-is
    root = subdirs[0] if len(subdirs) == 1 else base_dir
    # --- Step 2: recursively collect candidate files (skip zips and junk)
    candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
    candidates = [
        p for p in candidates
        if os.path.isfile(p)
        and not os.path.basename(p).startswith("._")
        and "__MACOSX" not in p
        and not p.lower().endswith(".zip")
    ]
    if not candidates:
        st.error("No files found under the provided directory.")
        return np.array([])
    # --- Step 3: keep only files that parse as DICOM headers
    dicom_files = []
    for p in candidates:
        try:
            # ds is unused; dcmread runs only to test that the header parses.
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            dicom_files.append(p)
        except Exception:
            pass
    st.write(f"🧾 Candidate DICOM files: {len(dicom_files)}")
    if not dicom_files:
        st.error("No valid DICOM files found.")
        return np.array([])
    # --- Helper: determine phase index
    def _get_phase(ds):
        # Preferred: private tag (your dataset)
        if frame_tag in ds:
            try:
                return int(ds[frame_tag].value)
            except Exception:
                pass
        # Fallback: standard TemporalPositionIdentifier (0020,0100)
        if hasattr(ds, "TemporalPositionIdentifier"):
            try:
                return int(ds.TemporalPositionIdentifier)
            except Exception:
                pass
        # Last resort: AcquisitionNumber (not always phase, but useful fallback)
        if hasattr(ds, "AcquisitionNumber"):
            try:
                return int(ds.AcquisitionNumber)
            except Exception:
                pass
        return None
    # --- Step 4: group by phase; sort by z
    grouped = defaultdict(list)
    phase_missing = 0
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            phase = _get_phase(ds)
            if phase is None:
                phase_missing += 1
                continue
            # z-order: SliceLocation if present else IPP[2] else 0
            z = getattr(ds, "SliceLocation", None)
            if z is None:
                ipp = getattr(ds, "ImagePositionPatient", None)
                if ipp is not None and len(ipp) >= 3:
                    z = float(ipp[2])
                else:
                    z = 0.0
            img = ds.pixel_array
            grouped[int(phase)].append((z, img))
        except Exception:
            continue
    if not grouped:
        st.error(
            "Could not determine a phase tag for any files. "
            "Check for (0019,10D7) or (0020,0100) in your dataset."
        )
        st.write(f"Files missing phase: {phase_missing} / {len(dicom_files)}")
        # Optional: show attributes of one file to discover tags
        try:
            ds0 = dicom.dcmread(dicom_files[0], force=True, stop_before_pixels=True)
            st.write("Sample DICOM attributes:", ds0.dir())
        except Exception:
            pass
        return np.array([])
    # keep up to number_of_scans phases
    phases = sorted(grouped.keys())[:number_of_scans]
    stacks_T = []
    slice_counts = []
    for ph in phases:
        pairs = grouped[ph]
        if not pairs:
            continue
        pairs.sort(key=lambda x: x[0])  # sort by z
        stack = [img for _, img in pairs]  # Z × H × W
        stacks_T.append(stack)
        slice_counts.append(len(stack))
    if not stacks_T:
        st.error("No phases contained readable slices.")
        return np.array([])
    # Harmonize Z across phases (trim to the most common slice count)
    if len(set(slice_counts)) > 1:
        common_Z = Counter(slice_counts).most_common(1)[0][0]
        stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
        if not stacks_T:
            st.error("All phases had inconsistent slice counts.")
            return np.array([])
    vol = np.array(stacks_T)  # (T, Z, H, W)
    st.write(f"✅ Final volume shape: {vol.shape} (T, S, H, W)")
    return vol
def extract_zip(zip_path, extract_to):
    """Extract *zip_path* into *extract_to*, skipping macOS metadata entries.

    SECURITY: the archive is user-uploaded (untrusted), so entries with
    absolute paths or '..' components are also skipped to prevent zip-slip
    writes outside *extract_to*.
    """
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Filter out __MACOSX, dotfiles, and path-traversal entries
        valid_files = [
            f for f in zip_ref.namelist()
            if "__MACOSX" not in f
            and not os.path.basename(f).startswith("._")
            and not f.startswith(("/", "\\"))
            and ".." not in f.replace("\\", "/").split("/")
        ]
        zip_ref.extractall(extract_to, members=valid_files)
def make_gif(path, timepoints, axis=-1, slice=60, frame_rate=30):
    """Build a side-by-side animated GIF of the four pipeline stages.

    Loads per-timepoint .npy volumes named raw_*/debanded_*/resp_cor_*/
    3D_cine_*, takes one 2-D slice from each, normalises each group with its
    global (over time) min/max so brightness does not flicker, stitches the
    four images horizontally per frame, and writes <path>/temp.gif.

    Parameters
    ----------
    path : folder containing the .npy volumes; also receives the GIF.
    timepoints : required number of volumes per group.
    axis : slicing axis 0, 1 or 2; -1 means the last axis (alias for 2).
    slice : index along the chosen axis (name shadows the builtin, kept for
        caller compatibility).
    frame_rate : GIF frames per second.

    Returns
    -------
    Path of the written GIF.

    Raises
    ------
    ValueError : if any group does not contain exactly *timepoints* files.
    """
    # 1. Locate all .npy files
    all_files = glob.glob(os.path.join(path, "*.npy"))
    print("Found NPY files:", all_files)
    # 2. Group them by prefix
    scan_keys = ['raw', 'debanded', 'resp_cor', '3D_cine']
    groups = {k: [] for k in scan_keys}
    pattern = re.compile(r'(?P<prefix>raw|debanded|resp_cor|3D_cine)_(?P<index>\d+)\.npy')
    for p in all_files:
        fn = os.path.basename(p)
        match = pattern.match(fn)
        if match:
            prefix = match.group("prefix")
            t_idx = int(match.group("index"))
            groups[prefix].append((t_idx, p))
    # 3. Sanity check: do all groups exist & have equal lengths?
    Ts = [len(v) for v in groups.values()]
    print("Group counts:", Ts)
    if not all(T == timepoints for T in Ts):
        raise ValueError(f"Mismatch in timepoints across groups. Expected {timepoints}, got {Ts}")
    for k in groups:
        groups[k].sort(key=lambda x: x[0])  # sort by t_idx
    # 4. Determine normalization range per group
    stats = {}
    for k in scan_keys:
        mins, maxs = [], []
        for _, p in groups[k]:
            vol = np.load(p)
            # FIX: axis == 2 previously fell into the `else` branch and
            # computed stats on the wrong slice; branches now mirror step 5.
            if axis in (2, -1):
                slice_ = vol[:, :, slice]
            elif axis == 1:
                slice_ = vol[:, slice, :]
            else:
                slice_ = vol[slice, :, :]
            mins.append(slice_.min())
            maxs.append(slice_.max())
        stats[k] = (min(mins), max(maxs))
    # 5. Create frames
    frames = []
    for t in range(timepoints):
        imgs_t = []
        for k in scan_keys:
            _, p = groups[k][t]
            vol = np.load(p).astype(np.float32)
            # FIX: the default axis=-1 previously matched no branch and left
            # `img` undefined (NameError); -1 now aliases the last axis (2).
            if axis in (2, -1):
                img = vol[::-1, :, slice]
            elif axis == 1:
                img = vol[:, slice, :]
            elif axis == 0:
                img = vol[slice, :, :]
                img = np.transpose(img[:, ::-1])
            mn, mx = stats[k]
            img = np.clip(img, mn, mx)
            rng = (mx - mn) or 1.0  # guard against /0 on constant slices
            img8 = ((img - mn) / rng * 255).astype(np.uint8)
            img8 = img8.T[:, ::-1]  # flip + transpose
            imgs_t.append(img8)
        # Stitch side-by-side and upscale 3x with nearest-neighbour
        composite = np.concatenate(imgs_t, axis=1)
        resized = Image.fromarray(composite).resize((composite.shape[1] * 3, composite.shape[0] * 3), Image.NEAREST)
        frames.append(np.array(resized))
    # 6. Save and return
    out_path = os.path.join(path, "temp.gif")
    imageio.mimsave(out_path, frames, duration=1000 / frame_rate, loop=0)
    return out_path
def to_dicom(cardiac_frames, patient_number):
    """Export the 3D cine volumes in ./out_dir as DICOM slices in ./out_dicoms.

    For each of *cardiac_frames* volumes (./out_dir/3D_cine_<i>.npy), every
    z-slice is written as one DICOM file, using ./utils/dicom_headerfile.dcm
    as the header template.

    Parameters
    ----------
    cardiac_frames : number of 3D_cine_<i>.npy volumes to export.
    patient_number : appended to the Study UID and output filenames.

    Returns
    -------
    42 (legacy sentinel, kept for caller compatibility).
    """
    template_path = "./utils/dicom_headerfile.dcm"
    # FIX: read the template directly with dcmread (read_file is deprecated);
    # the old glob loop silently left `ds` unbound when the file was missing,
    # producing a confusing NameError later instead of FileNotFoundError.
    ds = dicom.dcmread(template_path)
    os.makedirs('./out_dicoms', exist_ok=True)  # ensure the output folder exists
    for i in range(cardiac_frames):
        volume = np.load(f'./out_dir/3D_cine_{i}.npy')
        print(f"Volume: {i}")
        for j in range(volume.shape[0]):
            PixelData = volume[j, :, :]
            # NOTE(review): assumes the volume is normalised to [0, 1] so the
            # stored range is 0-255 despite BitsStored=12 — confirm upstream.
            PixelData = (PixelData * 255).astype(np.uint16)
            Dicoms = ds.copy()
            Dicoms.InstanceNumber = j
            Dicoms.PatientID = 'Mark'
            Dicoms.PatientName = 'Mark'
            Dicoms.StudyDescription = '3D Cine'
            Dicoms.SeriesDescription = 'HR'
            Dicoms.StudyInstanceUID = '1.3.12.2.1107.5.2.41.169828.3001002301121546102500000000' + str(patient_number)
            Dicoms.SliceThickness = 1.5
            Dicoms.Rows = 256
            Dicoms.Columns = 128
            Dicoms.AcquisitionMatrix = [0, 256, 128, 0]
            Dicoms.ImageOrientationPatient = [1.0, 0.0, 0.0, 0.0, 0.0, -1.0]
            Dicoms.SliceLocation = -100.0 + ((j - 1) * Dicoms.SliceThickness)
            Dicoms.SamplesPerPixel = 1
            Dicoms.BitsAllocated = 16
            Dicoms.BitsStored = 12
            Dicoms.HighBit = 11
            Dicoms.PixelRepresentation = 0
            Dicoms.AcquisitionNumber = i
            Dicoms.SeriesNumber = i
            Dicoms.PixelSpacing = [1.5, 1.5]
            Dicoms.SmallestImagePixelValue = 0
            Dicoms.LargestImagePixelValue = 255
            Dicoms.PixelData = PixelData.tobytes()
            Dicoms.SeriesInstanceUID = '1.3.12.2.1107.5.2.41.169828.300100230112154610250000001' + str(i)
            Dicoms.SOPInstanceUID = dicom.uid.generate_uid()
            Dicoms.AcquisitionTime = str(i)
            Dicoms.SeriesTime = str(i)
            dicom.filewriter.dcmwrite(filename=f'./out_dicoms/MARK_PATIENT_{patient_number}_VOL_{i}_SLICE_{j}.dcm', dataset=Dicoms)
    return 42
def zip_dir_to_memory(dir_path: str) -> io.BytesIO:
    """Zip the contents of *dir_path* (recursively) into an in-memory buffer.

    Entry names are stored relative to *dir_path*, and the stream is rewound
    so it can be handed straight to a download/upload call.
    """
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for folder, _dirs, names in os.walk(dir_path):
            for name in names:
                absolute = os.path.join(folder, name)
                archive.write(absolute, os.path.relpath(absolute, dir_path))
    buffer.seek(0)
    return buffer