"""Streamlit app: segment cardiac MRI DICOMs and report LV volumes and mass.

The user uploads a ZIP of DICOM files. The app finds the directory holding the
most DICOMs and groups files by the ``<patient>_slXX_..._frYY`` naming scheme.
It runs the segmentation model per patient, picks ED/ES frames from the
predicted blood-pool volumes, and offers per-patient GIF overlays plus a CSV of
metrics for download.
"""
import io
import os
import re
import shutil
import time
import zipfile

# Device selection must be configured before TensorFlow initialises its devices.
# An empty CUDA_VISIBLE_DEVICES hides every GPU, so inference runs on the CPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import matplotlib.animation as animation  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import pydicom  # noqa: E402
import streamlit as st  # noqa: E402
import tensorflow as tf  # noqa: E402
from skimage.measure import label, regionprops  # noqa: E402

from utils.layer_util import *  # noqa: E402,F401,F403  (provides ResizeAndConcatenate)

# Streamlit re-executes this script on every interaction, so results must live
# in session_state to survive reruns (e.g. the rerun caused by a download click).
if "artifacts" not in st.session_state:
    st.session_state.artifacts = {}  # {"gifs": [(patient_id, gif_bytes), ...], "csv_bytes": ...}
if "processed" not in st.session_state:
    st.session_state.processed = False

# Scratch directory the uploaded ZIP is extracted into.
EXTRACT_ROOT = '/tmp/out_dicoms'

# Where to write outputs users can download.
output_root = os.path.join("/tmp", "DICOM_OUTPUTS")
gifs_dir = os.path.join(output_root, 'GIFs')
csv_dir = os.path.join(output_root, 'CSV')
os.makedirs(gifs_dir, exist_ok=True)
os.makedirs(csv_dir, exist_ok=True)

# Base name (without extension) of the CSV/XLSX metric tables.
CSV_BASENAME = 'Cardiac_Volumes_And_Mass_fromPredictions'

# Path to trained model
model_path = './models_final/SEG459.h5'

# Network input size and class map from training
SIZE_X = 256
SIZE_Y = 256
N_CLASSES = 3  # 0=background, 1=myocardium, 2=blood pool

# For myocardium mass
myocardium_density = 1.05  # mg/mm^3 (≈1.05 g/ml)

# --- Island removal toggle & parameters ---
ENABLE_ISLAND_REMOVAL = True
ISLAND_MIN_SLICES = 2          # component must span at least this many slices
ISLAND_MIN_AREA = 10           # minimum component size in voxels (3-D count)
ISLAND_DISTANCE_THRESH = 40    # max centroid distance (voxels) from dominant component

# Filename tokens carrying slice / frame numbers, e.g. "sl07" and "fr12".
_SLICE_TOKEN = re.compile(r'sl\d+')
_FRAME_TOKEN = re.compile(r'fr\d+')


# =========================================================
# GT-FREE MASK PREPROCESS
# =========================================================
def extract_details_for_sorting(filename):
    """Parse ``<patient>_slXX_..._frYY...`` style DICOM filenames.

    The stem (extension removed) is split on ``_`` and ``-``. The first token
    that is exactly ``sl<digits>`` gives the slice number. The first token that
    is exactly ``fr<digits>`` gives the frame number. All tokens before the slice
    token form the patient ID, re-joined with ``_``. Exact-token matching keeps
    patient IDs such as ``slow`` or ``fred`` from being mistaken for the
    slice/frame markers.

    Returns:
        (patient_id, slice_number, frame_number)

    Raises:
        ValueError: if the slice or frame token is missing.
    """
    stem = os.path.splitext(os.path.basename(filename))[0]
    parts = re.split(r'[_-]', stem)
    slice_idx = next((i for i, p in enumerate(parts) if _SLICE_TOKEN.fullmatch(p)), None)
    frame_idx = next((i for i, p in enumerate(parts) if _FRAME_TOKEN.fullmatch(p)), None)
    if slice_idx is None or frame_idx is None:
        raise ValueError(f"Filename has no slXX/frYY tokens: {filename!r}")
    patient_id = '_'.join(parts[:slice_idx])
    return patient_id, int(parts[slice_idx][2:]), int(parts[frame_idx][2:])


def normalize_images(images):
    """Min-max normalise each image of a batch into [0, 1].

    Reductions run over axes 1 and 2, so *images* is expected to be
    (N, H, W, ...), e.g. (N, H, W, 1). Images with constant intensity
    (max == min) map to all zeros instead of keeping their raw values.

    Returns:
        A float64 tensor with the same shape as *images*.
    """
    images = tf.cast(images, tf.float64)
    min_val = tf.reduce_min(images, axis=[1, 2], keepdims=True)
    max_val = tf.reduce_max(images, axis=[1, 2], keepdims=True)
    return tf.math.divide_no_nan(images - min_val, max_val - min_val)


def _tf_resize_pad(img, sx=SIZE_X, sy=SIZE_Y):
    """Centre-crop or zero-pad a 2-D image to (sx, sy) without rescaling pixels."""
    return tf.image.resize_with_crop_or_pad(img[:, :, np.newaxis], sx, sy).numpy()[:, :, 0]


def load_process_data_no_gt(dicom_path, SIZE_X, SIZE_Y):
    """Build per-patient 4-D stacks (H, W, S, F) from the DICOMs in *dicom_path*.

    Only ``*.dcm`` files whose names parse with :func:`extract_details_for_sorting`
    are used. Other files are skipped. Images not already (SIZE_X, SIZE_Y) are
    centre-cropped/padded. A patient whose frames have different slice counts
    cannot be stacked, so it is skipped with a printed warning.

    Returns:
        all_frames_test_images: {patient_id: ndarray (H, W, S, F)}
        per_patient_used_files: {patient_id: [file paths in stack order]}
    """
    patients = {}
    for fname in sorted(os.listdir(dicom_path)):
        if not fname.lower().endswith('.dcm'):
            continue
        try:
            pid, sl, fr = extract_details_for_sorting(fname)
        except ValueError:
            continue  # name does not follow the slXX/frYY convention
        patients.setdefault(pid, {}).setdefault(fr, []).append((sl, fname))

    all_frames_test_images = {}
    per_patient_used_files = {}
    for pid, frames in patients.items():
        per_frame_stacks = []
        used_files = []
        for fr in sorted(frames):
            imgs = []
            for _, fname in sorted(frames[fr], key=lambda item: item[0]):
                fp = os.path.join(dicom_path, fname)
                img = pydicom.dcmread(fp).pixel_array.astype(np.float32)
                if img.shape[0] != SIZE_X or img.shape[1] != SIZE_Y:
                    img = _tf_resize_pad(img, SIZE_X, SIZE_Y)
                imgs.append(img)
                used_files.append(fp)
            if imgs:
                per_frame_stacks.append(np.stack(imgs, axis=-1))  # (H, W, S)

        if not per_frame_stacks:
            continue
        slice_counts = sorted({stack.shape[-1] for stack in per_frame_stacks})
        if len(slice_counts) > 1:
            print(f"Skipping patient {pid}: frames have differing slice counts {slice_counts}")
            continue
        all_frames_test_images[pid] = np.stack(per_frame_stacks, axis=-1)  # (H, W, S, F)
        per_patient_used_files[pid] = used_files

    return all_frames_test_images, per_patient_used_files


# =========================================================
# VOLUME & MASS HELPERS
# =========================================================
def calculate_volume_from_mask(mask, row_mm, col_mm, slice_thickness):
    """Return the blood-pool (label 2) volume in mm^3 (== µL)."""
    pixel_area = row_mm * col_mm
    blood_pool_area = np.sum(mask == 2)
    return blood_pool_area * pixel_area * slice_thickness


def calculate_myocardium_mass(mask, row_mm, col_mm, slice_thickness, density):
    """Return myocardium (label 1) mass: volume in mm^3 multiplied by *density*.

    With *density* in mg/mm^3 the result is in mg.
    """
    pixel_area = row_mm * col_mm
    myocardium_area = np.sum(mask == 1)
    myocardium_volume = myocardium_area * pixel_area * slice_thickness
    return myocardium_volume * density


# =========================================================
# DICOM SPACING & THICKNESS
# =========================================================
def read_spacing_thickness_from_files(filepaths):
    """Return (row_mm, col_mm, slice_thickness_mm) from DICOM headers.

    The first PixelSpacing and SliceThickness found are used. If no file has
    SliceThickness, the thickness falls back to the median non-zero step
    between sorted ImagePositionPatient z values. Unreadable files are skipped
    (best effort). Any value still unknown defaults to 1.0 mm, and a warning is
    printed because the resulting volumes are then not physically calibrated.
    """
    row_mm = col_mm = th = None
    z_positions = []

    for fp in filepaths:
        try:
            ds = pydicom.dcmread(fp, stop_before_pixels=True)
        except Exception:
            continue  # best effort: one bad header should not abort the patient

        if row_mm is None and col_mm is None and getattr(ds, 'PixelSpacing', None) is not None:
            row_mm, col_mm = map(float, ds.PixelSpacing)  # mm
        if th is None and getattr(ds, 'SliceThickness', None) is not None:
            th = float(ds.SliceThickness)
        ipp = getattr(ds, 'ImagePositionPatient', None)
        if ipp is not None and len(ipp) == 3:
            z_positions.append(float(ipp[2]))

    if th is None and len(z_positions) > 1:
        steps = np.diff(np.sort(z_positions))
        # Files from different frames share z positions; drop those zero steps.
        steps = steps[np.abs(steps) > 1e-6]
        if steps.size:
            th = float(np.median(np.abs(steps)))

    missing = [name for name, val in (('row', row_mm), ('col', col_mm), ('thickness', th)) if val is None]
    if missing:
        print(f"Warning: DICOM spacing missing ({', '.join(missing)}); defaulting to 1.0 mm")
    return (
        1.0 if row_mm is None else row_mm,
        1.0 if col_mm is None else col_mm,
        1.0 if th is None else th,
    )


# =========================================================
# ISLAND REMOVAL, METRICS & LOSS (needed to load the model)
# =========================================================
def remove_inconsistent_slices(
    segmentation_slices, min_slices=2, min_area=10, distance_threshold=40
):
    """Drop spurious connected components from a 3-D (H, W, S) label volume.

    For each class in {1, 2} separately:
    - find 6-connected 3-D components;
    - ignore components smaller than *min_area* voxels (the total 3-D voxel
      count, not a per-slice area);
    - treat the component spanning the most slices as dominant;
    - keep components spanning >= *min_slices* slices whose centroid lies within
      *distance_threshold* voxels of the dominant centroid.

    Returns:
        A uint8 volume of the same shape containing only the kept components.
    """
    cleaned = np.zeros(segmentation_slices.shape, dtype=np.uint8)

    def n_slices(region):
        return np.unique(region.coords[:, 2]).size

    for class_label in (1, 2):  # 1 = myocardium, 2 = blood pool
        labels_3d = label((segmentation_slices == class_label).astype(np.uint8), connectivity=1)
        regions = [r for r in regionprops(labels_3d) if r.area >= min_area]
        if not regions:
            continue

        dominant_centroid = np.asarray(max(regions, key=n_slices).centroid)
        for region in regions:
            distance = np.linalg.norm(np.asarray(region.centroid) - dominant_centroid)
            if n_slices(region) >= min_slices and distance <= distance_threshold:
                cleaned[tuple(region.coords.T)] = class_label

    return cleaned


def clean_predictions_per_frame(preds_4d, min_slices=2, min_area=10, distance_threshold=40):
    """Apply :func:`remove_inconsistent_slices` to every frame of an (H, W, S, F) array.

    Returns:
        A new array with the same shape and dtype as *preds_4d*.
    """
    cleaned = np.empty_like(preds_4d)
    for f in range(preds_4d.shape[3]):
        cleaned[..., f] = remove_inconsistent_slices(
            preds_4d[..., f],
            min_slices=min_slices,
            min_area=min_area,
            distance_threshold=distance_threshold,
        )
    return cleaned


def dice(y_true, y_pred, smooth=1e-6):
    """Dice coefficient over all elements (custom object for model loading)."""
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.clip_by_value(y_pred, 0.0, 1.0), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return (2. * intersection + smooth) / (union + smooth)


def dice_coef_class(class_index, name=None, smooth=1e-6):
    """Return a Keras mean metric computing Dice for channel *class_index*."""
    def wrapped_dice(y_true, y_pred):
        y_true_f = tf.reshape(tf.cast(y_true[..., class_index], tf.float32), [-1])
        y_pred_f = tf.reshape(tf.clip_by_value(y_pred[..., class_index], 0.0, 1.0), [-1])
        intersection = tf.reduce_sum(y_true_f * y_pred_f)
        union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
        return (2. * intersection + smooth) / (union + smooth)
    return tf.keras.metrics.MeanMetricWrapper(wrapped_dice, name=name or f'dice_class_{class_index}')


def dice_coef_no_bkg(y_true, y_pred, smooth=1e-6):
    """Mean per-class Dice over all channels except channel 0 (background)."""
    y_true_fg = tf.cast(y_true, tf.float32)[..., 1:]
    y_pred_fg = tf.cast(y_pred, tf.float32)[..., 1:]
    y_true_f = tf.reshape(y_true_fg, [-1, tf.shape(y_true_fg)[-1]])
    y_pred_f = tf.reshape(y_pred_fg, [-1, tf.shape(y_pred_fg)[-1]])
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    denominator = tf.reduce_sum(y_true_f + y_pred_f, axis=0)
    dice_vals = (2. * intersection + smooth) / (denominator + smooth)
    return tf.reduce_mean(dice_vals)


def focal_tversky_loss(y_true, y_pred, alpha=0.5, beta=0.5, gamma=1.0, smooth=1e-6):
    """Focal Tversky loss averaged over the N_CLASSES output channels."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, smooth, 1.0 - smooth)
    loss = 0.0
    for c in range(N_CLASSES):
        y_true_c = y_true[..., c]
        y_pred_c = y_pred[..., c]
        true_pos = tf.reduce_sum(y_true_c * y_pred_c)
        false_neg = tf.reduce_sum(y_true_c * (1 - y_pred_c))
        false_pos = tf.reduce_sum((1 - y_true_c) * y_pred_c)
        tversky_index = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
        loss += tf.pow((1 - tversky_index), gamma)
    return loss / tf.cast(N_CLASSES, tf.float32)


CUSTOM_OBJECTS = {
    'focal_tversky_loss': focal_tversky_loss,
    'dice_coef_no_bkg': dice_coef_no_bkg,
    'ResizeAndConcatenate': ResizeAndConcatenate,
    'dice_myo': dice_coef_class(1, name='dice_myo'),
    'dice_blood': dice_coef_class(2, name='dice_blood'),
    'dice': dice,
}


# =========================================================
# PREDICTION
# =========================================================
def predict_patient_images(model, images_4d):
    """Segment every frame of an (H, W, S, F) stack.

    Each frame is fed as a batch of S images shaped (S, H, W, 1) after per-image
    min-max normalisation. If the model returns a list (deep supervision), the
    last output is used. Argmax over the channel axis gives labels {0, 1, 2}.

    Returns:
        uint8 ndarray (H, W, S, F) of predicted labels.
    """
    H, W, S, F = images_4d.shape
    preds = np.zeros((H, W, S, F), dtype=np.uint8)
    for f in range(F):
        batch = np.moveaxis(images_4d[..., f], -1, 0)[..., np.newaxis].astype(np.float32)  # (S,H,W,1)
        batch = normalize_images(batch).numpy()
        out = model.predict(batch, verbose=0)
        if isinstance(out, list):
            out = out[-1]
        labels = np.argmax(out, axis=-1).astype(np.uint8)  # (S,H,W)
        preds[..., f] = np.moveaxis(labels, 0, -1)  # (H,W,S)
    return preds


# =========================================================
# ED/ES PICKER FROM PRED
# =========================================================
def pick_ed_es_from_predictions(
    preds_4d,
    row_mm,
    col_mm,
    slice_thickness_mm,
    prefer_frame0: bool = True,
    rel_tolerance: float = 0.05,      # keep frame 0 as ED if within 5% of the max volume
    abs_tolerance_uL: float = 0.0,    # ...or within this absolute gap (µL); 0 = ignore
    min_temporal_separation: int = 0  # ES must be at least this many frames from ED (0 = off)
):
    """Pick end-diastolic (ED) and end-systolic (ES) frames from predicted volumes.

    - The blood-pool (label 2) volume is computed for every frame.
    - ED is argmax(volume). When *prefer_frame0* is set, frame 0 is kept as ED
      if its volume is within *rel_tolerance* (relative) or *abs_tolerance_uL*
      (absolute) of the maximum, or if the maximum is <= 0.
    - ES is the smallest-volume frame other than ED that is at least
      *min_temporal_separation* frames from ED. If no frame satisfies the
      separation, the smallest-volume non-ED frame is used. With a single frame,
      ES == ED.

    Returns:
        ed_idx (int), es_idx (int), frame_vols (list[float], µL per frame)

    Raises:
        ValueError: if *preds_4d* has no frames.
    """
    n_frames = preds_4d.shape[3]
    if n_frames == 0:
        raise ValueError("preds_4d contains no time frames")

    frame_vols = [
        float(calculate_volume_from_mask(preds_4d[..., f], row_mm, col_mm, slice_thickness_mm))
        for f in range(n_frames)
    ]

    # --- ED selection ---
    ed_idx = int(np.argmax(frame_vols))
    if prefer_frame0:
        vmax = frame_vols[ed_idx]
        gap = abs(vmax - frame_vols[0])
        close_by_rel = gap <= rel_tolerance * vmax
        close_by_abs = abs_tolerance_uL > 0.0 and gap <= abs_tolerance_uL
        if vmax <= 0 or close_by_rel or close_by_abs:
            ed_idx = 0

    # --- ES selection ---
    # Stable sort so ties resolve to the earliest frame, matching np.argmin.
    candidates = [int(i) for i in np.argsort(frame_vols, kind="stable") if int(i) != ed_idx]
    es_idx = ed_idx  # single-frame case
    if candidates:
        es_idx = next(
            (i for i in candidates
             if min_temporal_separation <= 0 or abs(i - ed_idx) >= min_temporal_separation),
            candidates[0],  # nothing far enough from ED: smallest non-ED volume
        )

    return int(ed_idx), int(es_idx), frame_vols


# =========================================================
# GIF MAKER (NO GT)
# =========================================================
def gif_animation_for_patient_pred_only(images_4d, preds_4d, patient_id, ed_idx, es_idx, output_dir):
    """Write an ED|ES side-by-side overlay GIF that steps through the slices.

    Myocardium (label 1) is overlaid in 'Blues' and blood pool (label 2) in
    'jet', both at 50% opacity.

    Returns:
        Path of the saved ``<patient_id>_pred.gif`` inside *output_dir*.
    """
    os.makedirs(output_dir, exist_ok=True)

    def overlay(ax, img, pred):
        ax.imshow(img, cmap='gray')
        ax.imshow((pred == 1), alpha=(pred == 1) * 0.5, cmap='Blues')  # myocardium
        ax.imshow((pred == 2), alpha=(pred == 2) * 0.5, cmap='jet')    # blood pool
        ax.axis('off')

    n_slices = images_4d.shape[2]
    fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
    plt.tight_layout(rect=[0, 0, 1, 0.92])  # leave space at top for patient ID
    # Large, centred patient ID at the top (unaffected by clearing the axes).
    fig.suptitle(f'Patient ID: {patient_id}', fontsize=14, y=0.98)

    def update(slice_idx):
        for ax, frame_idx, tag in ((axarr[0], ed_idx, 'ED'), (axarr[1], es_idx, 'ES')):
            ax.clear()
            overlay(ax, images_4d[:, :, slice_idx, frame_idx], preds_4d[:, :, slice_idx, frame_idx])
            ax.set_title(f'{tag} (frame {frame_idx}) | Slice {slice_idx}')

    anim = animation.FuncAnimation(fig, update, frames=n_slices, interval=700)
    out_path = os.path.join(output_dir, f"{patient_id}_pred.gif")
    anim.save(out_path, writer='pillow')
    plt.close(fig)
    return out_path


# =========================================================
# UPLOAD / FILESYSTEM HELPERS
# =========================================================
def _is_dicom_file(path):
    """Return True if *path* looks like a DICOM file.

    Fast path: a DICOM Part 10 file has a 128-byte preamble followed by
    ``DICM``. Files without the preamble are accepted only if pydicom can parse
    their header in force mode and it contains SOPClassUID or Modality. Force
    mode on its own would accept almost any file.
    """
    if not os.path.isfile(path):
        return False
    try:
        with open(path, "rb") as f:
            f.seek(128)
            if f.read(4) == b"DICM":
                return True
    except OSError:
        return False
    try:
        ds = pydicom.dcmread(path, stop_before_pixels=True, force=True)
    except Exception:
        return False
    return "SOPClassUID" in ds or "Modality" in ds


def _count_dicoms_here(dir_path, max_check=500):
    """Count DICOM files directly inside dir_path (non-recursive), stopping at max_check."""
    count = 0
    for name in os.listdir(dir_path):
        if _is_dicom_file(os.path.join(dir_path, name)):
            count += 1
            if count >= max_check:
                break
    return count


def find_dicom_series_dirs(root, min_files=3):
    """Return [(dir, dicom_count), ...] for dirs holding >= min_files DICOMs directly.

    Hidden directories and ``__MACOSX`` are skipped. The walk does not descend
    below a directory once it qualifies. Results are sorted by count, largest
    first.
    """
    series = []
    for curr, dirnames, _filenames in os.walk(root, topdown=True):
        dirnames[:] = [d for d in dirnames if not d.startswith('.') and d != '__MACOSX']
        dicom_count = _count_dicoms_here(curr)
        if dicom_count >= min_files:
            series.append((curr, dicom_count))
            dirnames[:] = []  # don't descend further under a series dir
    series.sort(key=lambda t: t[1], reverse=True)
    return series


def clear_dir(path):
    """Remove all contents of a directory but keep it; create it if missing."""
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
        return
    for fname in os.listdir(path):
        fpath = os.path.join(path, fname)
        if os.path.isfile(fpath) or os.path.islink(fpath):
            os.remove(fpath)
        elif os.path.isdir(fpath):
            shutil.rmtree(fpath)


def process_zip_and_make_artifacts(uploaded_zip):
    """Placeholder returning demo (gif_bytes, csv_bytes); *uploaded_zip* is unused.

    Kept for interface compatibility; the real pipeline lives in main().
    """
    gif_buf = io.BytesIO()  # would receive anim.save(gif_buf, writer="pillow")
    csv_bytes = b"col1,col2\n1,2\n3,4\n"
    return gif_buf.getvalue(), csv_bytes


def _extract_zip(zip_path, extract_to):
    """Extract *zip_path* into *extract_to*, skipping macOS resource-fork entries."""
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        valid_files = [
            f for f in zip_ref.namelist()
            if "__MACOSX" not in f and not os.path.basename(f).startswith("._")
        ]
        zip_ref.extractall(extract_to, members=valid_files)


def _analyse_patient(model, pid, images_4d, used_files):
    """Predict, clean, pick ED/ES, render the GIF and build the CSV row for one patient.

    Returns:
        (row, gif_path): the metrics dict for the CSV and the saved GIF path.
    """
    print(f"\nProcessing patient: {pid} | 4D shape: {images_4d.shape}")
    row_mm, col_mm, slice_thickness_mm = read_spacing_thickness_from_files(used_files)
    print(f"Spacing/thickness: row={row_mm:.4f} mm, col={col_mm:.4f} mm, th={slice_thickness_mm:.4f} mm")

    preds_4d = predict_patient_images(model, images_4d)
    if ENABLE_ISLAND_REMOVAL:
        preds_4d = clean_predictions_per_frame(
            preds_4d,
            min_slices=ISLAND_MIN_SLICES,
            min_area=ISLAND_MIN_AREA,
            distance_threshold=ISLAND_DISTANCE_THRESH,
        )

    ed_idx, es_idx, frame_vols = pick_ed_es_from_predictions(preds_4d, row_mm, col_mm, slice_thickness_mm)
    print(f"Selected ED frame: {ed_idx}, ES frame: {es_idx}")

    # Volumes come straight from the per-frame computation (mm^3 == µL).
    EDV_uL = frame_vols[ed_idx]
    ESV_uL = frame_vols[es_idx]
    SV_uL = EDV_uL - ESV_uL
    EF_pct = (SV_uL / EDV_uL * 100.0) if EDV_uL > 0 else 0.0
    myo_mass_ED_mg = calculate_myocardium_mass(preds_4d[..., ed_idx], row_mm, col_mm, slice_thickness_mm, myocardium_density)
    myo_mass_ES_mg = calculate_myocardium_mass(preds_4d[..., es_idx], row_mm, col_mm, slice_thickness_mm, myocardium_density)

    gif_path = gif_animation_for_patient_pred_only(images_4d, preds_4d, pid, ed_idx, es_idx, gifs_dir)
    print(f"GIF saved: {gif_path}")

    row = {
        'Patient_ID': pid,
        'EDV_uL': EDV_uL,
        'ESV_uL': ESV_uL,
        'SV_uL': SV_uL,
        'EF_%': EF_pct,
        'MyocardiumMass_ED_mg': myo_mass_ED_mg,
        'MyocardiumMass_ES_mg': myo_mass_ES_mg,
        'ED_frame_index': ed_idx,
        'ES_frame_index': es_idx,
        'PixelSpacing_row_mm': row_mm,
        'PixelSpacing_col_mm': col_mm,
        'SliceThickness_mm': slice_thickness_mm,
    }
    return row, gif_path


def _write_tables(rows):
    """Write the metric rows as CSV (always) and XLSX (if an Excel engine exists).

    Returns:
        Path of the written CSV file.
    """
    df = pd.DataFrame(rows)
    os.makedirs(csv_dir, exist_ok=True)
    csv_path = os.path.join(csv_dir, f'{CSV_BASENAME}.csv')
    xlsx_path = os.path.join(csv_dir, f'{CSV_BASENAME}.xlsx')
    df.to_csv(csv_path, index=False)
    print(f"\nCSV written: {csv_path}")
    try:
        df.to_excel(xlsx_path, index=False)
    except ImportError as err:  # e.g. openpyxl not installed; CSV is still available
        print(f"Excel export skipped: {err}")
    else:
        print(f"Excel written: {xlsx_path}")
    return csv_path


def _process_upload(uploaded_zip):
    """Run the full pipeline on an uploaded ZIP.

    Returns:
        Artifacts dict ``{"gifs": [(pid, gif_bytes), ...], "csv_bytes": bytes}``.
    """
    os.makedirs(EXTRACT_ROOT, exist_ok=True)
    zip_path = os.path.join(EXTRACT_ROOT, "upload.zip")
    with open(zip_path, "wb") as f:
        f.write(uploaded_zip.getvalue())
    _extract_zip(zip_path, EXTRACT_ROOT)
    os.remove(zip_path)  # keep the archive itself out of the DICOM directory scan

    # 1) Pick the directory holding the most DICOM files.
    series = find_dicom_series_dirs(EXTRACT_ROOT, min_files=3)
    if not series:
        st.error("No DICOM series found in the uploaded ZIP.")
        st.stop()
    dicom_dir, dicom_count = series[0]
    st.success(f"Detected DICOM folder: {dicom_dir} (≈{dicom_count} files)")

    all_frames_test_images, per_patient_used_files = load_process_data_no_gt(dicom_dir, SIZE_X, SIZE_Y)
    print(f"Total patients found: {len(all_frames_test_images)}")
    if not all_frames_test_images:
        st.error("No DICOM files named like '<patient>_slXX_frYY.dcm' were found.")
        st.stop()

    # 2) Load model
    model = tf.keras.models.load_model(model_path, custom_objects=CUSTOM_OBJECTS)
    print('Loaded model successfully')
    print(f"Model input shape: {model.input_shape}")
    print(f"Model output shape: {model.output_shape}")

    # 3) Per patient: predict -> (optional clean) -> ED/ES -> GIF -> CSV row
    rows = []
    gifs = []
    for pid, images_4d in sorted(all_frames_test_images.items()):
        row, gif_path = _analyse_patient(model, pid, images_4d, per_patient_used_files.get(pid, []))
        rows.append(row)
        with open(gif_path, "rb") as f:
            gifs.append((pid, f.read()))

    # 4) Save CSV/Excel and keep the CSV bytes for download.
    csv_path = _write_tables(rows)
    with open(csv_path, "rb") as f:
        csv_bytes = f.read()
    return {"gifs": gifs, "csv_bytes": csv_bytes}


def _render_artifacts(artifacts):
    """Show every patient's GIF with its download button, then the CSV download."""
    for pid, gif_bytes in artifacts.get("gifs", []):
        st.image(gif_bytes, caption=f"Generated GIF ({pid})", use_container_width=True)
        st.download_button(
            label="📥 Download GIF",
            data=gif_bytes,
            file_name=f"{pid}_mask.gif",
            mime="image/gif",
            key=f"download_gif_{pid}",
        )
    csv_bytes = artifacts.get("csv_bytes")
    if csv_bytes is not None:
        st.download_button(
            label="Download CSV",
            data=csv_bytes,
            file_name=f"{CSV_BASENAME}.csv",
            mime="text/csv",
            key="download_csv",
        )


# =========================================================
# MAIN
# =========================================================
def main():
    """Streamlit entry point: upload, process on demand, and render stored results.

    Results are kept in ``st.session_state.artifacts`` as bytes. Reruns, such
    as the one triggered by clicking a download button, can then re-render
    them without reprocessing, and without touching files that may have been
    cleared.
    """
    st.header("Data Upload")
    uploaded_zip = st.file_uploader("Upload ZIP file of MRI folders", type="zip")

    if uploaded_zip is not None and st.button("Process Data"):
        # Start each processing run from empty scratch/output directories.
        for path in (EXTRACT_ROOT, csv_dir, gifs_dir):
            clear_dir(path)
        st.session_state.artifacts = {}
        st.session_state.processed = False
        with st.spinner("Processing ZIP..."):
            st.session_state.artifacts = _process_upload(uploaded_zip)
        st.session_state.processed = True

    if st.session_state.processed:
        _render_artifacts(st.session_state.artifacts)


if __name__ == '__main__':
    main()