"""Utilities for a DICOM cardiac-cine MRI pipeline.

Covers: loading cine stacks from flat/nested DICOM folders, applying
pretrained de-banding / respiratory-correction / super-resolution models,
GIF preview generation, DICOM export, and zip helpers for Streamlit.
"""

import glob
import io
import os
import re
import shutil
import time
import zipfile
from collections import Counter, defaultdict

import imageio
import numpy as np
import pydicom as dicom
import streamlit as st
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_mri as tfmri
import tqdm
from PIL import Image
from pydicom.tag import Tag

from utils.unet3plusnew import *
from utils.custom_unet_code import *

# BUGFIX: MODELS_BASE was referenced in several functions but never defined
# anywhere in this file (NameError at first use). The dead string literal
# "./models_final/Deband_model" that used to sit in apply_debanding_model
# shows the intended root. Defined defensively so we do not clobber a value
# that the star imports above might already supply.
if "MODELS_BASE" not in globals():
    MODELS_BASE = "./models_final"


def resize(t1, x, y):
    """Center-crop or zero-pad ``t1`` to (x, y).

    A trailing channel axis is added before resizing, so the returned
    tensor has shape (..., x, y, 1).
    """
    t1 = tf.expand_dims(t1, -1)  # add channel axis expected by tf.image
    return tf.image.resize_with_crop_or_pad(t1, x, y)


def norm(t1):
    """Shift-and-scale an image: (x - min(x)) / max(x).

    NOTE(review): this is NOT standard min-max normalisation
    ((x - min) / (max - min)); the result is only guaranteed to lie in
    [0, 1] when min(x) >= 0. Kept as-is because the downstream models were
    presumably trained with this exact scaling — confirm before changing.
    """
    im1 = t1
    return (im1 - np.min(im1)) / np.max(im1)


def apply_debanding_model(input_im, frames=32):
    """Run the pretrained de-banding 3-D UNet on each temporal frame.

    Parameters
    ----------
    input_im : indexable of per-frame volumes; ``input_im[i]`` is assumed to
        be a 3-D array (Z, H, W) — TODO confirm against callers.
    frames : number of frames to process.

    Returns
    -------
    list of model predictions, one per frame.
    """
    # Load the saved model only to harvest its weights, then rebuild the
    # same architecture with fully dynamic spatial dims so that volumes of
    # any size can be pushed through.
    debanding = tf.keras.models.load_model(
        os.path.join(MODELS_BASE, "Deband_model"), compile=False)
    weights = debanding.get_weights()

    inputs = tf.keras.Input(shape=[None, None, None, 1])
    unet = tfmri.models.UNet3D([32, 64, 128], kernel_size=3,
                               out_channels=1, use_global_residual=False)
    de_banding_model = tf.keras.Model(inputs=inputs, outputs=unet(inputs))
    de_banding_model.set_weights(weights)

    de_banded = []
    for i in range(frames):
        # (Z, H, W) -> (1, Z, H, W, 1): add batch and channel axes.
        pred = de_banding_model.predict(
            tf.expand_dims(tf.expand_dims(input_im[i], 0), -1), verbose=0)
        de_banded.append(pred)
    return de_banded


def deformation_28(x):
    """Warp 28 sagittal slices by per-slice 2-D displacement fields.

    Parameters
    ----------
    x : 3-tuple of tensors — (image volume, dy field, dx field), each
        indexed as ``x[k][0, i, :, :]`` for slice ``i``.

    Returns
    -------
    Tensor of shape (1, 28, H, W, 1) with the warped slices.
    """
    warped_slices = []
    for i in range(28):
        slice_img = tf.expand_dims(x[0][0, i, :, :], -1)
        dy = tf.expand_dims(tf.expand_dims(x[1][0, i, :, :], -1), 0)
        dx = tf.expand_dims(tf.expand_dims(x[2][0, i, :, :], -1), 0)
        # dense_image_warp expects flow as (..., H, W, 2) = (dy, dx).
        displacement = tf.concat((dy[0, ...], dx[0, ...]), axis=-1)
        img = tf.image.convert_image_dtype(
            tf.expand_dims(slice_img, 0), tf.dtypes.float32)
        displacement = tf.image.convert_image_dtype(
            displacement, tf.dtypes.float32)
        warped = tfa.image.dense_image_warp(img, displacement)
        warped_slices.append(tf.squeeze(warped, 0))
    out = tf.image.convert_image_dtype(warped_slices, tf.dtypes.float32)
    return tf.expand_dims(out, axis=0)


def apply_resp_model_28(input_im, frames=32):
    """Apply the respiratory-correction model to each temporal frame.

    Builds a UNet that predicts 2-channel deformation fields, warps the
    input through ``deformation_28``, then loads pretrained weights.

    Returns
    -------
    (deformations, resp_corrected) : two lists, one entry per frame.
    """
    inputs = tf.keras.Input(shape=[None, 256, 128, 1])
    # The UNet acts as a deformation-field generator (2 output channels).
    unet = build_3d_unet_resp([None, 256, 128, 1], 2)
    deformation_fields = unet(inputs)
    warp = tf.keras.layers.Lambda(deformation_28)
    deformed = warp([inputs[:, :, :, :, 0],
                     deformation_fields[:, :, :, :, 0],
                     deformation_fields[:, :, :, :, 1]])
    complete_model = tf.keras.Model(inputs=inputs,
                                    outputs=[deformation_fields, deformed])
    complete_model.load_weights(
        os.path.join(MODELS_BASE, "Resp_Correction_model",
                     "variables", "variables"))

    deformations, resp_corrected = [], []
    for i in range(frames):
        def_fields, resp_cor = complete_model.predict(
            input_im[i][:, :, :, :, :], verbose=0)
        resp_corrected.append(resp_cor)
        deformations.append(def_fields)
    return deformations, resp_corrected


def apply_SR_model(input_im, frames=32):
    """Apply the super-resolution sub-network of the end-to-end model.

    The E2E model's weight list is sliced from index 22 onward, which is
    assumed to correspond to the SR sub-network — TODO confirm if the E2E
    architecture ever changes.
    """
    E2E = tf.keras.models.load_model(
        os.path.join(MODELS_BASE, "E2E_SR_model"), compile=False)
    sr_weights = E2E.get_weights()[22:]

    inputs = tf.keras.Input(shape=[None, None, None, 1])
    SR_net = build_3d_unet(input_shape=(None, None, None, 1), num_classes=1)
    SR_model_done = tf.keras.Model(inputs=inputs, outputs=SR_net(inputs))
    SR_model_done.set_weights(sr_weights)

    return [SR_model_done.predict(input_im[i], verbose=0)
            for i in range(frames)]


# Private vendor tag used by this dataset to mark the cardiac phase.
t = Tag(0x0019, 0x10D7)


def load_data_samples(path_to_data):
    """Read an example RT sagittal stack from a flat DICOM folder.

    Files are read in sorted filename order and accumulated into one stack
    until a file whose private tag ``t`` equals 30 closes the stack
    (assumed to mark the last slice of a volume — TODO confirm).

    Returns
    -------
    np.ndarray transposed to put the stack axis second.
    """
    if not os.path.exists(path_to_data):
        raise Exception("Error with file path.")

    pending_ims, pending_locs = [], []
    stacks, stack_locs = [], []
    for file in sorted(glob.glob(f"{path_to_data}/*")):
        ds = dicom.dcmread(file)
        pending_locs.append(ds.SliceLocation)
        pending_ims.append(ds.pixel_array)
        if ds[t].value == 30:  # end-of-volume marker in the private tag
            stacks.append(np.array(pending_ims))
            stack_locs.append(pending_locs)
            pending_ims, pending_locs = [], []
    final = np.array(stacks)
    return np.transpose(final, (1, 0, 2, 3))


def load_data_samples_from_folder(base_dir, number_of_scans=32):
    """
    Recursively find all DICOM files under the first valid subfolder of
    base_dir, group them by InstanceNumber (time), sort by SliceLocation
    (z), and return a NumPy array of shape (time, z, H, W).
    """
    # 1. Find the real nested folder (skip macOS junk).
    candidates = [
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and not d.startswith("._") and "__MACOSX" not in d
    ]
    if not candidates:
        st.error("No valid data folder found in ZIP.")
        return np.array([])
    nested_base = os.path.join(base_dir, candidates[0])

    # 2. Recursively collect every file; filter for DICOMs next.
    all_paths = glob.glob(os.path.join(nested_base, "**", "*"), recursive=True)
    all_paths = [p for p in all_paths
                 if os.path.isfile(p)
                 and not os.path.basename(p).startswith("._")]

    # 3. Keep only files that parse as DICOM and carry an InstanceNumber.
    dicom_files = []
    for p in all_paths:
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            if hasattr(ds, "InstanceNumber"):
                dicom_files.append(p)
        except Exception:  # was a bare except; unreadable file -> skip
            continue

    st.write(f"🧾 Found {len(dicom_files)} DICOM files.")
    if not dicom_files:
        st.error("No valid DICOMs found.")
        return np.array([])

    # 4. Group by InstanceNumber (temporal frames).
    grouped = defaultdict(list)
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            loc = getattr(ds, "SliceLocation", 0.0)
            grouped[ds.InstanceNumber].append((loc, ds.pixel_array))
        except Exception:
            continue

    # 5. Build volume up to number_of_scans frames, slices sorted along z.
    vols = []
    for inst in sorted(grouped.keys())[:number_of_scans]:
        slices = grouped[inst]
        slices.sort(key=lambda x: x[0])
        vols.append([img for _, img in slices])

    volume = np.array(vols)  # (T, Z, H, W)
    st.write(f"✅ Found data shape: {volume.shape}")
    return volume


def load_cine_any(
    base_dir: str,
    number_of_scans: int = None,            # if None, use all detected phases
    private_phase_tag: Tag = Tag(0x0019, 0x10D7),  # private phase tag (if present)
    verbose: bool = True
):
    """
    Universal DICOM cine loader (flat or nested folders).

    Scans recursively from `base_dir`, detects cardiac phases, sorts
    slices, and returns (T, Z, H, W) along with the total number of
    phases detected.

    Returns:
        volume: np.ndarray with shape (T, Z, H, W)
        num_phases_detected: int (total phases found in the dataset)
    """
    def log(msg):
        if verbose:
            try:
                st.write(msg)
            except Exception:
                print(msg)  # fall back when no Streamlit context exists

    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"No such directory: {base_dir}")

    # --- Collect candidate files (recursive), skip junk/zips.
    candidates = glob.glob(os.path.join(base_dir, "**", "*"), recursive=True)
    candidates = [
        p for p in candidates
        if os.path.isfile(p)
        and "__MACOSX" not in p
        and not os.path.basename(p).startswith("._")
        and not p.lower().endswith(".zip")
    ]
    if not candidates:
        log("No files found under the provided directory.")
        return np.array([]), 0

    # --- Keep only files that parse as DICOM headers.
    dicom_files = []
    for p in candidates:
        try:
            dicom.dcmread(p, force=True, stop_before_pixels=True)
            dicom_files.append(p)
        except Exception:
            pass
    log(f"🧾 Candidate DICOM files: {len(dicom_files)}")
    if not dicom_files:
        log("No valid DICOM files found.")
        return np.array([]), 0

    # --- Detect flat folder layout (all files in the same directory).
    dicom_dirs = {os.path.dirname(p) for p in dicom_files}
    is_flat = (len(dicom_dirs) == 1)

    # --- Probe a sample of headers to choose the best phase key.
    def _try_get(ds, tag):
        try:
            return ds[tag].value
        except Exception:
            return None

    uniq_priv, uniq_tpi, uniq_inst = set(), set(), set()
    for p in dicom_files[:200]:  # probing 200 files is enough to see variation
        try:
            ds = dicom.dcmread(p, force=True, stop_before_pixels=True)
            v_priv = _try_get(ds, private_phase_tag)
            if v_priv is not None:
                try:
                    uniq_priv.add(int(v_priv))
                except Exception:
                    pass
            if hasattr(ds, "TemporalPositionIdentifier"):
                try:
                    uniq_tpi.add(int(ds.TemporalPositionIdentifier))
                except Exception:
                    pass
            if hasattr(ds, "InstanceNumber"):
                try:
                    uniq_inst.add(int(ds.InstanceNumber))
                except Exception:
                    pass
        except Exception:
            continue

    # Prefer the private tag, then the standard temporal identifier, then
    # InstanceNumber — a key is usable only if it actually varies.
    if len(uniq_priv) > 1:
        phase_key = ("private", private_phase_tag)
    elif len(uniq_tpi) > 1:
        phase_key = ("tpi", None)
    elif len(uniq_inst) > 1:
        phase_key = ("instance", None)
    else:
        log("Could not determine a phase key (no variation in "
            "private/TPI/InstanceNumber).")
        return np.array([]), 0

    def _get_phase(ds):
        if phase_key[0] == "private":
            v = _try_get(ds, phase_key[1])
            return int(v) if v is not None else None
        if phase_key[0] == "tpi":
            return int(ds.TemporalPositionIdentifier) \
                if hasattr(ds, "TemporalPositionIdentifier") else None
        if phase_key[0] == "instance":
            return int(ds.InstanceNumber) \
                if hasattr(ds, "InstanceNumber") else None
        return None

    def _get_z(ds):
        # z-order: SliceLocation if present, else ImagePositionPatient[2],
        # else 0.0 (stable but arbitrary).
        z = getattr(ds, "SliceLocation", None)
        if z is None:
            ipp = getattr(ds, "ImagePositionPatient", None)
            if ipp is not None and len(ipp) >= 3:
                try:
                    z = float(ipp[2])
                except Exception:
                    z = 0.0
            else:
                z = 0.0
        return float(z)

    # --- Group by phase; collect (z, pixel_array) pairs.
    grouped = defaultdict(list)
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            ph = _get_phase(ds)
            if ph is None:
                continue
            grouped[int(ph)].append((_get_z(ds), ds.pixel_array))
        except Exception:
            continue

    if not grouped:
        log("No groups formed (no phase could be read).")
        return np.array([]), 0

    all_phase_ids = sorted(grouped.keys())
    num_phases_detected = len(all_phase_ids)
    phases_to_use = (all_phase_ids if number_of_scans is None
                     else all_phase_ids[:number_of_scans])

    stacks_T, slice_counts = [], []
    for ph in phases_to_use:
        pairs = grouped[ph]
        if not pairs:
            continue
        pairs.sort(key=lambda x: x[0])          # sort by z
        stack = [img for _, img in pairs]       # Z x H x W
        stacks_T.append(stack)
        slice_counts.append(len(stack))

    if not stacks_T:
        log("Groups existed but none had readable slices.")
        return np.array([]), num_phases_detected

    # Harmonize Z across phases (trim to the most common slice count).
    if len(set(slice_counts)) > 1:
        common_Z = Counter(slice_counts).most_common(1)[0][0]
        stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
        if not stacks_T:
            log("All phases had inconsistent slice counts.")
            return np.array([]), num_phases_detected

    volume = np.array(stacks_T)  # (T, Z, H, W)

    # Flip slice order (Z) if data came from a flat single folder.
    if is_flat:
        volume = volume[:, ::-1, :, :]

    log(f"✅ Final volume shape: {volume[0, ...].shape} , "
        f"Phases detected = {num_phases_detected}")
    return volume, num_phases_detected


def load_data_samples_from_flat_folder(
    base_dir: str,
    number_of_scans: int = 32,
    frame_tag: Tag = Tag(0x0019, 0x10D7)  # private phase tag (adjust if needed)
) -> np.ndarray:
    """
    Robust loader when all DICOMs are under one folder (possibly nested).

    - Steps into the single subfolder if present (ignores upload.zip,
      macOS junk).
    - Recursively finds DICOMs (even without .dcm extension).
    - Groups by phase from `frame_tag` or fallback (0020,0100).
    - Sorts by SliceLocation/IPP and returns (T, Z, H, W).
    """
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"No such directory: {base_dir}")

    # --- Step 1: if there is exactly one subfolder, dive into it.
    entries = [e for e in os.listdir(base_dir) if not e.startswith("._")]
    subdirs = [os.path.join(base_dir, e) for e in entries
               if os.path.isdir(os.path.join(base_dir, e))
               and "__MACOSX" not in e]
    root = subdirs[0] if len(subdirs) == 1 else base_dir

    # --- Step 2: recursively collect candidate files (skip zips/junk).
    candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
    candidates = [
        p for p in candidates
        if os.path.isfile(p)
        and not os.path.basename(p).startswith("._")
        and "__MACOSX" not in p
        and not p.lower().endswith(".zip")
    ]
    if not candidates:
        st.error("No files found under the provided directory.")
        return np.array([])

    # --- Step 3: keep only files that parse as DICOM headers.
    dicom_files = []
    for p in candidates:
        try:
            dicom.dcmread(p, force=True, stop_before_pixels=True)
            dicom_files.append(p)
        except Exception:
            pass
    st.write(f"🧾 Candidate DICOM files: {len(dicom_files)}")
    if not dicom_files:
        st.error("No valid DICOM files found.")
        return np.array([])

    # --- Helper: determine phase index with graded fallbacks.
    def _get_phase(ds):
        # Preferred: private tag (this dataset).
        if frame_tag in ds:
            try:
                return int(ds[frame_tag].value)
            except Exception:
                pass
        # Fallback: standard TemporalPositionIdentifier (0020,0100).
        if hasattr(ds, "TemporalPositionIdentifier"):
            try:
                return int(ds.TemporalPositionIdentifier)
            except Exception:
                pass
        # Last resort: AcquisitionNumber (not always phase, but useful).
        if hasattr(ds, "AcquisitionNumber"):
            try:
                return int(ds.AcquisitionNumber)
            except Exception:
                pass
        return None

    # --- Step 4: group by phase; sort by z.
    grouped = defaultdict(list)
    phase_missing = 0
    for p in dicom_files:
        try:
            ds = dicom.dcmread(p, force=True)
            phase = _get_phase(ds)
            if phase is None:
                phase_missing += 1
                continue
            # z-order: SliceLocation if present else IPP[2] else 0.
            z = getattr(ds, "SliceLocation", None)
            if z is None:
                ipp = getattr(ds, "ImagePositionPatient", None)
                if ipp is not None and len(ipp) >= 3:
                    z = float(ipp[2])
                else:
                    z = 0.0
            grouped[int(phase)].append((z, ds.pixel_array))
        except Exception:
            continue

    if not grouped:
        st.error(
            "Could not determine a phase tag for any files. "
            "Check for (0019,10D7) or (0020,0100) in your dataset."
        )
        st.write(f"Files missing phase: {phase_missing} / {len(dicom_files)}")
        # Show attributes of one file to help discover usable tags.
        try:
            ds0 = dicom.dcmread(dicom_files[0], force=True,
                                stop_before_pixels=True)
            st.write("Sample DICOM attributes:", ds0.dir())
        except Exception:
            pass
        return np.array([])

    # Keep up to number_of_scans phases.
    phases = sorted(grouped.keys())[:number_of_scans]
    stacks_T, slice_counts = [], []
    for ph in phases:
        pairs = grouped[ph]
        if not pairs:
            continue
        pairs.sort(key=lambda x: x[0])       # sort by z
        stack = [img for _, img in pairs]    # Z x H x W
        stacks_T.append(stack)
        slice_counts.append(len(stack))

    if not stacks_T:
        st.error("No phases contained readable slices.")
        return np.array([])

    # Harmonize Z across phases (trim to the most common slice count).
    if len(set(slice_counts)) > 1:
        common_Z = Counter(slice_counts).most_common(1)[0][0]
        stacks_T = [s[:common_Z] for s in stacks_T if len(s) >= common_Z]
        if not stacks_T:
            st.error("All phases had inconsistent slice counts.")
            return np.array([])

    vol = np.array(stacks_T)  # (T, Z, H, W)
    st.write(f"✅ Final volume shape: {vol.shape} (T, S, H, W)")
    return vol


def extract_zip(zip_path, extract_to):
    """Extract a zip archive, skipping macOS metadata entries."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Filter out __MACOSX and dotfiles.
        valid_files = [
            f for f in zip_ref.namelist()
            if "__MACOSX" not in f
            and not os.path.basename(f).startswith("._")
        ]
        zip_ref.extractall(extract_to, members=valid_files)


def make_gif(path, timepoints, axis=-1, slice=60, frame_rate=30):
    """Build a side-by-side comparison GIF from saved .npy volumes.

    Expects files named ``<prefix>_<t>.npy`` in ``path`` where prefix is
    one of raw/debanded/resp_cor/3D_cine.  ``slice`` (kept despite
    shadowing the builtin, for caller compatibility) selects the slice
    along ``axis`` (-1 is treated as the last axis, i.e. 2).

    Returns the path to the written GIF.
    """
    # 1. Locate all .npy files.
    all_files = glob.glob(os.path.join(path, "*.npy"))
    print("Found NPY files:", all_files)

    # 2. Group them by prefix.
    # BUGFIX: the original pattern had lost its group names
    # ('(?Praw|...' / '(?P\d+'), which is invalid regex syntax and raised
    # re.error at compile time.
    scan_keys = ['raw', 'debanded', 'resp_cor', '3D_cine']
    groups = {k: [] for k in scan_keys}
    pattern = re.compile(
        r'(?P<prefix>raw|debanded|resp_cor|3D_cine)_(?P<index>\d+)\.npy')
    for p in all_files:
        match = pattern.match(os.path.basename(p))
        if match:
            groups[match.group("prefix")].append(
                (int(match.group("index")), p))

    # 3. Sanity check: all groups must have exactly `timepoints` entries.
    Ts = [len(v) for v in groups.values()]
    print("Group counts:", Ts)
    if not all(T == timepoints for T in Ts):
        raise ValueError(
            f"Mismatch in timepoints across groups. Expected {timepoints}, got {Ts}")
    for k in groups:
        groups[k].sort(key=lambda x: x[0])  # sort by time index

    # BUGFIX: the original checked axis == -1 in the stats pass but only
    # axis == 2 in the frame pass, so the default axis=-1 crashed with an
    # undefined `img`.  Normalise -1 -> 2 once, up front.
    if axis == -1:
        axis = 2

    # 4. Determine normalization range per group (shared across time so
    #    intensities are comparable between frames).
    stats = {}
    for k in scan_keys:
        mins, maxs = [], []
        for _, p in groups[k]:
            vol = np.load(p)
            if axis == 2:
                slc = vol[:, :, slice]
            elif axis == 1:
                slc = vol[:, slice, :]
            else:
                slc = vol[slice, :, :]
            mins.append(slc.min())
            maxs.append(slc.max())
        stats[k] = (min(mins), max(maxs))

    # 5. Create frames.
    frames = []
    for t_idx in range(timepoints):
        imgs_t = []
        for k in scan_keys:
            _, p = groups[k][t_idx]
            vol = np.load(p).astype(np.float32)
            if axis == 2:
                img = vol[::-1, :, slice]
            elif axis == 1:
                img = vol[:, slice, :]
            else:
                img = vol[slice, :, :]
            img = np.transpose(img[:, ::-1])
            mn, mx = stats[k]
            rng = (mx - mn) if mx > mn else 1.0  # guard constant slices
            img = np.clip(img, mn, mx)
            img8 = ((img - mn) / rng * 255).astype(np.uint8)
            img8 = img8.T[:, ::-1]  # flip + transpose for display
            imgs_t.append(img8)
        # Stitch the four views side-by-side and upscale 3x for visibility.
        composite = np.concatenate(imgs_t, axis=1)
        resized = Image.fromarray(composite).resize(
            (composite.shape[1] * 3, composite.shape[0] * 3), Image.NEAREST)
        frames.append(np.array(resized))

    # 6. Save and return.
    out_path = os.path.join(path, "temp.gif")
    imageio.mimsave(out_path, frames, duration=1000 / frame_rate, loop=0)
    return out_path


def to_dicom(cardiac_frames, patient_number):
    """Export reconstructed 3D-cine volumes back to DICOM files.

    Reads ``./out_dir/3D_cine_<i>.npy`` for each frame and writes one
    DICOM per slice to ``./out_dicoms``, stamping geometry and identity
    fields onto a copy of a template header.  Returns 42 (legacy sentinel
    kept for caller compatibility).
    """
    filename = "./utils/dicom_headerfile.dcm"
    for file in glob.glob(filename):
        ds = dicom.dcmread(file)  # was deprecated dicom.read_file
    for i in range(cardiac_frames):
        volume = np.load(f'./out_dir/3D_cine_{i}.npy')
        print(f"Volume: {i}")
        for j in range(volume.shape[0]):
            # Scale normalised floats to the stored 16-bit range.
            PixelData = (volume[j, :, :] * 255).astype(np.uint16)
            Dicoms = ds.copy()
            Dicoms.InstanceNumber = j
            Dicoms.PatientID = 'Mark'
            Dicoms.PatientName = 'Mark'
            Dicoms.StudyDescription = '3D Cine'
            Dicoms.SeriesDescription = 'HR'
            Dicoms.StudyInstanceUID = (
                '1.3.12.2.1107.5.2.41.169828.3001002301121546102500000000'
                + str(patient_number))
            Dicoms.SliceThickness = 1.5
            Dicoms.Rows = 256
            Dicoms.Columns = 128
            Dicoms.AcquisitionMatrix = [0, 256, 128, 0]
            Dicoms.ImageOrientationPatient = [1.0, 0.0, 0.0, 0.0, 0.0, -1.0]
            Dicoms.SliceLocation = -100.0 + ((j - 1) * Dicoms.SliceThickness)
            Dicoms.SamplesPerPixel = 1
            Dicoms.BitsAllocated = 16
            Dicoms.BitsStored = 12
            Dicoms.HighBit = 11
            Dicoms.PixelRepresentation = 0
            Dicoms.AcquisitionNumber = i
            Dicoms.SeriesNumber = i
            Dicoms.PixelSpacing = [1.5, 1.5]
            Dicoms.SmallestImagePixelValue = 0
            Dicoms.LargestImagePixelValue = 255
            Dicoms.PixelData = PixelData.tobytes()
            Dicoms.SeriesInstanceUID = (
                '1.3.12.2.1107.5.2.41.169828.300100230112154610250000001'
                + str(i))
            Dicoms.SOPInstanceUID = dicom.uid.generate_uid()
            Dicoms.AcquisitionTime = str(i)
            Dicoms.SeriesTime = str(i)
            dicom.filewriter.dcmwrite(
                filename=f'./out_dicoms/MARK_PATIENT_{patient_number}_VOL_{i}_SLICE_{j}.dcm',
                dataset=Dicoms)
    return 42


def zip_dir_to_memory(dir_path: str) -> io.BytesIO:
    """Zip an entire directory tree into an in-memory buffer.

    Paths inside the archive are relative to ``dir_path``.  The returned
    BytesIO is rewound and ready to stream (e.g. via st.download_button).
    """
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(dir_path):
            for f in files:
                full = os.path.join(root, f)
                arc = os.path.relpath(full, dir_path)  # keep relative paths
                zf.write(full, arc)
    buf.seek(0)
    return buf