Instructions to use Aditya2162/ivus-segmentation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use Aditya2162/ivus-segmentation with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

model = keras.saving.load_model("hf://Aditya2162/ivus-segmentation")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """Multitask fine-tuning: DeepIVUS lumen segmentation + bifurcation classification. | |
| Uses: | |
| - all bifurcation-labeled samples for classification | |
| - only lumen-polygon samples for segmentation loss (masked per-sample) | |
| Data are loaded on-the-fly from frame bank JSONL + source DICOMs. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import datetime | |
| import json | |
| import math | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
| from deepivus.config import resolve_lumen_model_dir | |
| from scripts.finetune.shared.common import ( | |
| load_bifurcation_annotations, | |
| load_lumen_annotations, | |
| load_preprocessed_stack, | |
| load_split_ids, | |
| polygon_to_mask, | |
| ) | |
| try: | |
| from tqdm.auto import tqdm | |
| except Exception: | |
| tqdm = None | |
# Scalar intensity offset subtracted from every input image in _prepare_batch.
# NOTE(review): presumably the mean pixel intensity the pretrained model was
# trained with (0-255 scale) — confirm against the original preprocessing.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)
def _extract_logits(model_output: object) -> tf.Tensor:
    """Pull the logits tensor out of whatever container the model returned.

    A dict yields its first value (insertion order), a tuple/list yields
    element 0, and a bare tensor is passed through unchanged.

    Raises:
        RuntimeError: if the container is empty, or the output is not a dict,
            tuple, list, or tensor.
    """
    if isinstance(model_output, dict):
        for first_value in model_output.values():
            return first_value
        raise RuntimeError("Model returned an empty dict output.")

    if isinstance(model_output, (tuple, list)):
        if len(model_output) == 0:
            raise RuntimeError("Model returned an empty sequence output.")
        return model_output[0]

    if not tf.is_tensor(model_output):
        raise RuntimeError(f"Unsupported model output type: {type(model_output)!r}")
    return model_output
def _prepare_batch(images: tf.Tensor) -> tf.Tensor:
    """Convert raw frames into the float32 tensor fed to the model.

    The input is cast to float32 and IMG_MEAN is subtracted. A trailing
    channel axis is appended when the rank is below 4, and the last axis is
    tiled three times whenever its size is not already 3.
    """
    centered = tf.cast(images, tf.float32) - IMG_MEAN
    needs_channel_axis = len(centered.shape) < 4
    batch = tf.expand_dims(centered, axis=3) if needs_channel_axis else centered
    if batch.shape[-1] == 3:
        return batch
    return tf.tile(batch, [1, 1, 1, 3])
def _select_trainable_vars(
    all_trainable: list[tf.Variable],
    include_regex: str,
    exclude_regex: str,
    max_trainable: int,
) -> list[tf.Variable]:
    """Filter variables by name, preserving their original order.

    A variable is kept when its ``name`` matches ``include_regex`` (if given)
    and does not match ``exclude_regex`` (if given); both use ``re.search``.
    An empty pattern disables that filter. When ``max_trainable`` is positive,
    selection stops as soon as that many variables have been kept.
    """
    include_pat = re.compile(include_regex) if include_regex else None
    exclude_pat = re.compile(exclude_regex) if exclude_regex else None

    def _wanted(name: str) -> bool:
        if include_pat is not None and include_pat.search(name) is None:
            return False
        return exclude_pat is None or exclude_pat.search(name) is None

    chosen: list[tf.Variable] = []
    for candidate in all_trainable:
        if not _wanted(candidate.name):
            continue
        chosen.append(candidate)
        if 0 < max_trainable <= len(chosen):
            break
    return chosen
def _binary_logit_from_multiclass(logits: tf.Tensor, lumen_class: int) -> tf.Tensor:
    """Collapse multi-class logits into a single lumen-vs-rest binary logit.

    The result is ``logits[..., lumen_class] - logsumexp(all other logits)``;
    its sigmoid equals the softmax probability of ``lumen_class``.

    Args:
        logits: tensor whose last axis indexes the segmentation classes.
        lumen_class: index of the lumen class on that last axis.

    Returns:
        A tensor with the last axis removed.

    Raises:
        ValueError: if the channel dimension is not statically known or
            ``lumen_class`` is out of range.
    """
    channels = logits.shape[-1]
    if channels is None:
        raise ValueError("logits must have a statically known channel dimension.")
    num_classes = int(channels)
    if lumen_class < 0 or lumen_class >= num_classes:
        raise ValueError(f"Invalid lumen_class={lumen_class} for num_classes={num_classes}")

    positive = logits[..., lumen_class]
    if num_classes == 1:
        # With a single channel there are no competing classes: logsumexp over
        # an empty axis is -inf and would turn every output into +inf. Treat
        # the lone channel as an already-binary logit instead.
        return positive

    negatives = tf.concat([logits[..., :lumen_class], logits[..., lumen_class + 1 :]], axis=-1)
    return positive - tf.reduce_logsumexp(negatives, axis=-1)
def _augment_multitask_batch(
    images: tf.Tensor,
    masks: tf.Tensor,
    bif_labels: tf.Tensor,
    has_mask: tf.Tensor,
    seed: int,
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Apply one seeded random augmentation to a whole batch.

    Geometric steps (left/right flip, up/down flip, rotation by a multiple of
    90 degrees, shear + perspective warp) are applied identically to images
    and masks; one set of parameters is drawn per call and shared by every
    sample in the batch. Brightness and contrast jitter touch the images only,
    which are finally clipped to [-255, 255].

    ``masks`` is given without a channel axis; one is added for the image ops
    and removed again before returning. ``bif_labels`` and ``has_mask`` are
    returned untouched. All randomness is stateless and derived from ``seed``.
    """

    def _seed_for(salt: int) -> tf.Tensor:
        # Each random draw gets its own (seed, salt) pair.
        return tf.convert_to_tensor([seed, salt], dtype=tf.int32)

    def _coin(salt: int) -> tf.Tensor:
        return tf.random.stateless_uniform([], seed=_seed_for(salt)) > 0.5

    def _symmetric(salt: int, limit: float) -> tf.Tensor:
        return tf.random.stateless_uniform([], seed=_seed_for(salt), minval=-limit, maxval=limit)

    masks = masks[..., tf.newaxis]

    # Flips.
    do_flip_lr = _coin(0)
    images = tf.cond(do_flip_lr, lambda: tf.image.flip_left_right(images), lambda: images)
    masks = tf.cond(do_flip_lr, lambda: tf.image.flip_left_right(masks), lambda: masks)
    do_flip_ud = _coin(1)
    images = tf.cond(do_flip_ud, lambda: tf.image.flip_up_down(images), lambda: images)
    masks = tf.cond(do_flip_ud, lambda: tf.image.flip_up_down(masks), lambda: masks)

    # Rotation by k * 90 degrees, k in {0, 1, 2, 3}.
    quarter_turns = tf.random.stateless_uniform([], seed=_seed_for(2), minval=0, maxval=4, dtype=tf.int32)
    images = tf.image.rot90(images, k=quarter_turns)
    masks = tf.image.rot90(masks, k=quarter_turns)

    # Shear + perspective, expressed as the 8-parameter projective transform.
    shear_x = _symmetric(30, 0.08)
    shear_y = _symmetric(31, 0.08)
    persp_0 = _symmetric(32, 8e-4)
    persp_1 = _symmetric(33, 8e-4)
    one = tf.constant(1.0, tf.float32)
    zero = tf.constant(0.0, tf.float32)
    params = tf.stack([one, shear_x, zero, shear_y, one, zero, persp_0, persp_1])
    transforms = tf.tile(params[tf.newaxis, :], [tf.shape(images)[0], 1])
    out_shape = tf.shape(images)[1:3]

    def _warp(batch: tf.Tensor, interpolation: str) -> tf.Tensor:
        return tf.raw_ops.ImageProjectiveTransformV3(
            images=batch,
            transforms=transforms,
            output_shape=out_shape,
            interpolation=interpolation,
            fill_mode="REFLECT",
            fill_value=0.0,
        )

    images = _warp(images, "BILINEAR")
    # Nearest-neighbour keeps the masks binary.
    masks = _warp(masks, "NEAREST")

    # Photometric jitter (images only).
    images = tf.image.stateless_random_brightness(images, max_delta=10.0, seed=_seed_for(3))
    images = tf.image.stateless_random_contrast(images, lower=0.9, upper=1.1, seed=_seed_for(4))
    images = tf.clip_by_value(images, -255.0, 255.0)

    return images, tf.squeeze(masks, axis=-1), bif_labels, has_mask
def _build_dataset(
    images: np.ndarray,
    masks: np.ndarray,
    bif_labels: np.ndarray,
    has_mask: np.ndarray,
    batch_size: int,
    shuffle: bool,
    seed: int,
    augment: bool,
) -> tf.data.Dataset:
    """Assemble the batched ``tf.data`` pipeline for one split.

    Each element is ``(x, y_mask, y_bif, y_has)`` with every component cast to
    float32 and ``x`` passed through ``_prepare_batch``. Shuffling (when
    enabled) happens per sample, before batching. Augmentation (when enabled)
    is applied per batch, seeded with ``seed`` plus the batch index.
    """

    def _to_model_inputs(batch_images, batch_masks, batch_bif, batch_has):
        return (
            _prepare_batch(batch_images),
            tf.cast(batch_masks, tf.float32),
            tf.cast(batch_bif, tf.float32),
            tf.cast(batch_has, tf.float32),
        )

    def _augment_indexed(batch_idx, data):
        x, y_mask, y_bif, y_has = data
        return _augment_multitask_batch(x, y_mask, y_bif, y_has, seed=seed + tf.cast(batch_idx, tf.int32))

    pipeline = tf.data.Dataset.from_tensor_slices((images, masks, bif_labels, has_mask))
    if shuffle:
        pipeline = pipeline.shuffle(buffer_size=len(images), seed=seed, reshuffle_each_iteration=True)
    pipeline = pipeline.batch(batch_size)
    pipeline = pipeline.map(_to_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        pipeline = pipeline.enumerate().map(_augment_indexed, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.prefetch(tf.data.AUTOTUNE)
def _tile_to_uint8(tile: np.ndarray) -> np.ndarray:
    """Rescale one image tile to uint8 for preview rendering.

    A multi-channel (H, W, C>1) tile is reduced to its first channel. The
    finite value range is stretched linearly to [0, 255]; a tile whose finite
    values are (nearly) constant is only clipped to [0, 255] instead.

    Non-finite pixels are handled explicitly: NaN and -inf map to the finite
    minimum, +inf to the finite maximum. A tile with no finite value at all
    comes back as zeros.

    Args:
        tile: array-like image data convertible to float32.

    Returns:
        A uint8 array with the tile's (channel-reduced) shape.
    """
    x = np.asarray(tile, dtype=np.float32)
    if x.ndim == 3 and x.shape[-1] > 1:
        x = x[..., 0]

    finite = np.isfinite(x)
    if not finite.any():
        return np.zeros_like(x, dtype=np.uint8)

    # Take the range from finite values only: nanmin/nanmax skip NaN but not
    # +/-inf, and an infinite range would collapse the whole tile.
    mn = float(x[finite].min())
    mx = float(x[finite].max())
    if not finite.all():
        # Casting NaN/inf to uint8 is undefined, so replace them first.
        x = np.nan_to_num(x, nan=mn, posinf=mx, neginf=mn)

    if mx <= mn + 1e-6:
        return np.clip(x, 0, 255).astype(np.uint8)
    x = (x - mn) / (mx - mn)
    return (x * 255.0).clip(0, 255).astype(np.uint8)
def _save_augment_grid_from_dataset(ds: tf.data.Dataset, out_path: Path) -> bool:
    """Write a 4x4 grayscale preview of the first batch of ``ds``.

    Up to 16 images from the first batch are converted with
    ``_tile_to_uint8`` and pasted row by row; unused cells stay black. Parent
    directories of ``out_path`` are created as needed.

    Returns:
        True when an image was written, False when the dataset yielded no
        batch or an empty one.
    """
    first_images = None
    for image_batch, _masks, _labels, _has in ds.take(1):
        first_images = image_batch.numpy()
    if first_images is None or first_images.shape[0] == 0:
        return False

    tile_h = int(first_images.shape[1])
    tile_w = int(first_images.shape[2])
    canvas = Image.new("L", (4 * tile_w, 4 * tile_h), color=0)
    for idx, tile in enumerate(first_images[:16]):
        row, col = divmod(idx, 4)
        canvas.paste(Image.fromarray(_tile_to_uint8(tile), mode="L"), (col * tile_w, row * tile_h))

    out_path.parent.mkdir(parents=True, exist_ok=True)
    canvas.save(out_path)
    return True
def _build_arrays(annotations_bif, lumen_map: dict[str, object], diameter: int):
    """Materialise images, masks and labels for every usable annotated frame.

    Annotations are grouped by DICOM path so each stack is loaded once via
    ``load_preprocessed_stack``. Frames whose index falls outside the stack
    are skipped. A sample with an entry in ``lumen_map`` gets a rasterised
    polygon mask; its ``has_mask`` flag is 1.0 only if that mask contains at
    least one foreground pixel. Samples without an entry get an all-zero mask
    and flag 0.0.

    Returns:
        ``(images, masks, bif_labels, has_mask, sample_ids)`` — four stacked
        NumPy arrays plus the list of sample ids in the same order.

    Raises:
        RuntimeError: if no annotation produced a sample.
    """
    by_dicom: dict[Path, list] = {}
    for ann in annotations_bif:
        by_dicom.setdefault(ann.dicom_path, []).append(ann)

    frames, frame_masks, labels, supervised_flags, ids = [], [], [], [], []
    for dicom_path, frame_anns in by_dicom.items():
        stack = load_preprocessed_stack(dicom_path, diameter=diameter)
        height, width = int(stack.shape[1]), int(stack.shape[2])
        num_frames = stack.shape[0]
        for ann in frame_anns:
            frame_idx = ann.frame_idx
            if not (0 <= frame_idx < num_frames):
                continue
            sample_id = ann.sample_id
            frames.append(stack[frame_idx])
            labels.append(1.0 if ann.bifurcation else 0.0)
            ids.append(sample_id)

            lumen_ann = lumen_map.get(sample_id)
            if lumen_ann is not None:
                mask = polygon_to_mask(lumen_ann.lumen_x, lumen_ann.lumen_y, (height, width)).astype(np.float32)
                flag = 1.0 if np.any(mask > 0.5) else 0.0
            else:
                mask = np.zeros((height, width), dtype=np.float32)
                flag = 0.0
            frame_masks.append(mask)
            supervised_flags.append(flag)

    if not frames:
        raise RuntimeError("No valid samples produced from frame bank.")

    return (
        np.stack(frames, axis=0),
        np.stack(frame_masks, axis=0),
        np.asarray(labels, dtype=np.float32),
        np.asarray(supervised_flags, dtype=np.float32),
        ids,
    )
def _make_cls_head(num_seg_classes: int) -> tf.keras.Model:
    """Build the bifurcation classifier that sits on top of segmentation logits.

    Architecture: global average pooling over the spatial axes, a 96-unit
    ReLU dense layer, dropout (rate 0.3), and a single sigmoid unit. The input
    accepts any spatial size with ``num_seg_classes`` channels.
    """
    layers = tf.keras.layers
    seg_logits = tf.keras.Input(shape=(None, None, num_seg_classes), name="seg_logits")
    pooled = layers.GlobalAveragePooling2D()(seg_logits)
    hidden = layers.Dense(96, activation="relu")(pooled)
    hidden = layers.Dropout(0.3)(hidden)
    prob = layers.Dense(1, activation="sigmoid", name="bifurcation_prob")(hidden)
    return tf.keras.Model(seg_logits, prob, name="bifurcation_head")
def _compute_metrics(
    base_model,
    cls_head,
    ds: tf.data.Dataset,
    lumen_class: int,
    pos_weight: float,
    seg_weight: float,
    cls_weight: float,
) -> dict[str, float]:
    """Evaluate both tasks over ``ds`` in inference mode.

    Losses are computed per batch with the same formulas as the training loop
    (weighted BCE + 0.3 * soft Dice for segmentation, restricted to samples
    whose ``y_has`` flag is set; binary cross-entropy for classification) and
    reported as the mean over batches. IoU and Dice are accumulated at pixel
    level over supervised samples using a 0.5 threshold. Classification
    metrics use a 0.5 threshold; AUC is NaN unless both classes are present.

    Returns:
        A dict with keys ``total_loss``, ``seg_loss``, ``cls_loss``,
        ``seg_iou``, ``seg_dice``, ``seg_count``, ``cls_accuracy``,
        ``cls_precision``, ``cls_recall``, ``cls_f1`` and ``cls_auc``.
    """

    def _mean_or_nan(values: list) -> float:
        return float(np.mean(values)) if values else float("nan")

    total_per_batch, seg_per_batch, cls_per_batch = [], [], []
    labels_per_batch, probs_per_batch = [], []
    inter_px = 0.0
    union_px = 0.0
    pred_px = 0.0
    gt_px = 0.0
    num_supervised = 0.0

    for x, y_mask, y_bif, y_has in ds:
        logits = _extract_logits(base_model(x, training=False))
        mask_hw = (tf.shape(y_mask)[1], tf.shape(y_mask)[2])
        resized_logits = tf.image.resize(logits, mask_hw)
        bin_logit = _binary_logit_from_multiclass(resized_logits, lumen_class=lumen_class)

        # Segmentation loss per sample: weighted BCE + 0.3 * soft Dice loss.
        bce_per_sample = tf.reduce_mean(
            tf.nn.weighted_cross_entropy_with_logits(labels=y_mask, logits=bin_logit, pos_weight=pos_weight),
            axis=[1, 2],
        )
        probs_mask = tf.math.sigmoid(bin_logit)
        overlap = tf.reduce_sum(probs_mask * y_mask, axis=[1, 2])
        mass = tf.reduce_sum(probs_mask + y_mask, axis=[1, 2])
        dice_loss_per_sample = 1.0 - (2.0 * overlap + 1e-6) / (mass + 1e-6)
        seg_per_sample = bce_per_sample + 0.3 * dice_loss_per_sample

        # Average only over samples that actually carry a lumen mask.
        supervised = y_has > 0.5
        num_valid = tf.reduce_sum(tf.cast(supervised, tf.float32))
        masked_sum = tf.reduce_sum(tf.where(supervised, seg_per_sample, tf.zeros_like(seg_per_sample)))
        seg_loss = tf.where(num_valid > 0, masked_sum / num_valid, tf.constant(0.0, dtype=tf.float32))

        bif_prob = tf.reshape(cls_head(resized_logits, training=False), [-1])
        cls_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_bif, bif_prob))
        total = seg_weight * seg_loss + cls_weight * cls_loss

        seg_per_batch.append(float(seg_loss.numpy()))
        cls_per_batch.append(float(cls_loss.numpy()))
        total_per_batch.append(float(total.numpy()))
        labels_per_batch.append(y_bif.numpy().reshape(-1))
        probs_per_batch.append(bif_prob.numpy().reshape(-1))

        # Pixel-level overlap statistics, supervised samples only.
        keep = supervised.numpy().astype(bool)
        pred_fg = (probs_mask >= 0.5).numpy()[keep]
        gt_fg = (y_mask >= 0.5).numpy()[keep]
        inter_px += float(np.logical_and(pred_fg, gt_fg).sum())
        union_px += float(np.logical_or(pred_fg, gt_fg).sum())
        pred_px += float(pred_fg.sum())
        gt_px += float(gt_fg.sum())
        num_supervised += float(keep.sum())

    empty = np.zeros((0,), dtype=np.float32)
    y_true_arr = np.concatenate(labels_per_batch) if labels_per_batch else empty
    y_prob_arr = np.concatenate(probs_per_batch) if probs_per_batch else empty
    y_true_int = y_true_arr.astype(np.int32)

    pred_pos = (y_prob_arr >= 0.5).astype(np.int32) == 1
    true_pos = y_true_int == 1
    true_neg = y_true_int == 0
    tp = int(np.sum(pred_pos & true_pos))
    fp = int(np.sum(pred_pos & true_neg))
    fn = int(np.sum(~pred_pos & true_pos))
    tn = int(np.sum(~pred_pos & true_neg))

    cls_acc = float((tp + tn) / max(tp + tn + fp + fn, 1))
    cls_prec = float(tp / max(tp + fp, 1))
    cls_rec = float(tp / max(tp + fn, 1))
    cls_f1 = float((2.0 * cls_prec * cls_rec) / max(cls_prec + cls_rec, 1e-12))

    # ROC AUC is undefined unless both classes occur.
    cls_auc = float("nan")
    if y_true_arr.size > 1 and len(np.unique(y_true_int)) > 1:
        auc_metric = tf.keras.metrics.AUC(curve="ROC")
        auc_metric.update_state(y_true_arr, y_prob_arr)
        cls_auc = float(auc_metric.result().numpy())

    seg_iou = inter_px / max(union_px, 1.0)
    seg_dice = (2.0 * inter_px) / max(pred_px + gt_px, 1.0)

    return {
        "total_loss": _mean_or_nan(total_per_batch),
        "seg_loss": _mean_or_nan(seg_per_batch),
        "cls_loss": _mean_or_nan(cls_per_batch),
        "seg_iou": float(seg_iou),
        "seg_dice": float(seg_dice),
        "seg_count": float(num_supervised),
        "cls_accuracy": cls_acc,
        "cls_precision": cls_prec,
        "cls_recall": cls_rec,
        "cls_f1": cls_f1,
        "cls_auc": cls_auc,
    }
def main() -> None:
    """Fine-tune the segmentation model jointly with a bifurcation classifier.

    Steps: parse CLI arguments, load annotations and the train/val/test split,
    build in-memory arrays and ``tf.data`` pipelines, load the base SavedModel,
    train with a combined segmentation + classification loss, keep the best
    checkpoint by validation total loss (with early stopping), then export the
    base model, the classifier head, TensorBoard scalars and a JSON summary.

    Raises:
        RuntimeError: if there are no bifurcation annotations, the train split
            is empty, or the base model exposes no (selected) trainable
            variables.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model-dir", type=Path, default=resolve_lumen_model_dir())
    parser.add_argument("--frame-bank-root", type=Path, default=Path("evals/frame_bank_merged"))
    parser.add_argument("--split-json", type=Path, default=Path("evals/splits/ivus_split_merged_600.json"))
    parser.add_argument("--output-dir", type=Path, default=Path("models/multitask"))
    parser.add_argument("--lumen-class", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=1e-5)
    parser.add_argument("--seg-loss-weight", type=float, default=1.0)
    parser.add_argument("--cls-loss-weight", type=float, default=1.0)
    parser.add_argument("--augment", action="store_true")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--diameter", type=int, default=67)
    parser.add_argument("--clip-norm", type=float, default=1.0)
    parser.add_argument("--include-var-regex", type=str, default=".*")
    parser.add_argument("--exclude-var-regex", type=str, default="")
    parser.add_argument("--max-trainable", type=int, default=0)
    parser.add_argument("--early-stop-patience", type=int, default=10)
    parser.add_argument("--early-stop-min-delta", type=float, default=1e-4)
    parser.add_argument("--tb-logdir", type=Path, default=Path("output/tensorboard"))
    parser.add_argument("--tb-run-name", type=str, default=None)
    args = parser.parse_args()

    tf.random.set_seed(args.seed)
    np.random.seed(args.seed)

    # ---- Data -----------------------------------------------------------
    bif_anns = load_bifurcation_annotations(args.frame_bank_root)
    lumen_anns = load_lumen_annotations(args.frame_bank_root)
    if not bif_anns:
        raise RuntimeError(f"No bifurcation annotations under: {args.frame_bank_root}")
    split_ids = load_split_ids(args.split_json)
    train_ids = split_ids["train"]
    val_ids = split_ids["val"]
    test_ids = split_ids["test"]
    # Build the union once instead of once per annotation.
    known_ids = train_ids | val_ids | test_ids
    bif_anns = [a for a in bif_anns if a.sample_id in known_ids]
    lumen_map = {a.sample_id: a for a in lumen_anns}
    images, masks, bif_labels, has_mask, sample_ids = _build_arrays(
        bif_anns,
        lumen_map=lumen_map,
        diameter=args.diameter,
    )
    idx_train = np.asarray([i for i, sid in enumerate(sample_ids) if sid in train_ids], dtype=np.int64)
    idx_val = np.asarray([i for i, sid in enumerate(sample_ids) if sid in val_ids], dtype=np.int64)
    idx_test = np.asarray([i for i, sid in enumerate(sample_ids) if sid in test_ids], dtype=np.int64)
    if len(idx_train) == 0:
        raise RuntimeError("Train split is empty.")
    tr_x, tr_m, tr_b, tr_h = images[idx_train], masks[idx_train], bif_labels[idx_train], has_mask[idx_train]
    va_x, va_m, va_b, va_h = images[idx_val], masks[idx_val], bif_labels[idx_val], has_mask[idx_val]
    te_x, te_m, te_b, te_h = images[idx_test], masks[idx_test], bif_labels[idx_test], has_mask[idx_test]
    print(
        f"Samples train/val/test: {len(tr_x)}/{len(va_x)}/{len(te_x)} | "
        f"lumen-supervised train/val/test: {int(np.sum(tr_h))}/{int(np.sum(va_h))}/{int(np.sum(te_h))}"
    )
    train_ds = _build_dataset(tr_x, tr_m, tr_b, tr_h, args.batch_size, shuffle=True, seed=args.seed, augment=args.augment)
    val_ds = _build_dataset(va_x, va_m, va_b, va_h, args.batch_size, shuffle=False, seed=args.seed, augment=False) if len(va_x) else None
    test_ds = _build_dataset(te_x, te_m, te_b, te_h, args.batch_size, shuffle=False, seed=args.seed, augment=False) if len(te_x) else None

    # ---- Models ---------------------------------------------------------
    base_model = tf.saved_model.load(str(args.model_dir))
    all_trainable = list(getattr(base_model, "trainable_variables", []))
    if not all_trainable:
        raise RuntimeError("Loaded base model has no trainable variables.")
    base_vars = _select_trainable_vars(
        all_trainable,
        include_regex=args.include_var_regex,
        exclude_regex=args.exclude_var_regex,
        max_trainable=args.max_trainable,
    )
    if not base_vars:
        raise RuntimeError("No base-model trainable variables selected.")
    # Build class head on top of segmentation logits; a one-sample probe tells
    # us how many channels the base model emits.
    probe = _extract_logits(base_model(_prepare_batch(tf.convert_to_tensor(tr_x[:1], dtype=tf.float32)), training=False))
    num_seg_classes = int(probe.shape[-1])
    cls_head = _make_cls_head(num_seg_classes=num_seg_classes)
    # The variable set does not change between steps, so assemble it once.
    train_vars = list(base_vars) + list(cls_head.trainable_variables)

    # Class-balance weight for the segmentation BCE. Only lumen-supervised
    # frames enter the segmentation loss, so only their masks are counted;
    # the all-zero placeholder masks of unsupervised frames would otherwise
    # inflate the negative pixel count.
    supervised_masks = tr_m[tr_h > 0.5]
    pos_pixels = float(np.sum(supervised_masks > 0.5))
    total_pixels = float(supervised_masks.size)
    neg_pixels = max(total_pixels - pos_pixels, 1.0)
    pos_weight = float(np.clip(neg_pixels / max(pos_pixels, 1.0), 1.0, 40.0))

    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

    # ---- Logging / checkpointing ---------------------------------------
    args.output_dir.mkdir(parents=True, exist_ok=True)
    run_name = args.tb_run_name or datetime.datetime.now().strftime("multitask_%Y%m%d_%H%M%S")
    tb_run_dir = args.tb_logdir / run_name
    tb_run_dir.mkdir(parents=True, exist_ok=True)
    tb_writer = tf.summary.create_file_writer(str(tb_run_dir))
    print(f"TensorBoard logdir: {tb_run_dir}")
    if args.augment:
        aug_preview_path = tb_run_dir / "augment_preview_4x4.png"
        if _save_augment_grid_from_dataset(train_ds, aug_preview_path):
            print(f"Saved augmentation preview: {aug_preview_path}")
    ckpt_dir = args.output_dir / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt = tf.train.Checkpoint(base=base_model, cls_head=cls_head, optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt, directory=str(ckpt_dir), max_to_keep=1)

    steps_per_epoch = int(math.ceil(len(tr_x) / float(args.batch_size)))
    best_val = None
    wait = 0
    history = []
    if val_ds is None:
        # Without validation data the monitored loss is always NaN; counting
        # that as "no improvement" would abort training after `patience`
        # epochs, so early stopping is skipped entirely in that case.
        print("No validation split: early stopping and best-checkpoint selection are disabled.")

    # ---- Training loop --------------------------------------------------
    for epoch in range(1, args.epochs + 1):
        if tqdm is not None:
            pbar = tqdm(total=steps_per_epoch, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        else:
            pbar = None
        epoch_total = []
        epoch_seg = []
        epoch_cls = []
        for step, (x, y_mask, y_bif, y_has) in enumerate(train_ds, start=1):
            with tf.GradientTape() as tape:
                logits = _extract_logits(base_model(x, training=True))
                resized_logits = tf.image.resize(logits, (tf.shape(y_mask)[1], tf.shape(y_mask)[2]))
                bin_logit = _binary_logit_from_multiclass(resized_logits, lumen_class=args.lumen_class)
                bce_map = tf.nn.weighted_cross_entropy_with_logits(labels=y_mask, logits=bin_logit, pos_weight=pos_weight)
                bce_per_sample = tf.reduce_mean(bce_map, axis=[1, 2])
                probs_mask = tf.math.sigmoid(bin_logit)
                inter = tf.reduce_sum(probs_mask * y_mask, axis=[1, 2])
                denom = tf.reduce_sum(probs_mask + y_mask, axis=[1, 2])
                dice_loss_per_sample = 1.0 - (2.0 * inter + 1e-6) / (denom + 1e-6)
                seg_per_sample = bce_per_sample + 0.3 * dice_loss_per_sample
                valid = y_has > 0.5
                valid_count = tf.reduce_sum(tf.cast(valid, tf.float32))
                masked_seg_sum = tf.reduce_sum(tf.where(valid, seg_per_sample, tf.zeros_like(seg_per_sample)))
                # Clamp the denominator instead of wrapping `sum / count` in
                # tf.where: the gradient of an unselected 0/0 branch is NaN
                # and would poison every weight on a batch with no supervised
                # sample. The forward value is unchanged (0 / 1 == 0).
                seg_loss = masked_seg_sum / tf.maximum(valid_count, 1.0)
                bif_prob = cls_head(resized_logits, training=True)
                bif_prob = tf.reshape(bif_prob, [-1])
                cls_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_bif, bif_prob))
                total_loss = args.seg_loss_weight * seg_loss + args.cls_loss_weight * cls_loss
            grads = tape.gradient(total_loss, train_vars)
            # Variables that did not take part in the loss have no gradient.
            grad_var_pairs = [(g, v) for g, v in zip(grads, train_vars) if g is not None]
            if args.clip_norm > 0 and grad_var_pairs:
                grad_tensors = [g for g, _ in grad_var_pairs]
                clipped, _ = tf.clip_by_global_norm(grad_tensors, clip_norm=args.clip_norm)
                grad_var_pairs = list(zip(clipped, [v for _, v in grad_var_pairs]))
            optimizer.apply_gradients(grad_var_pairs)
            epoch_total.append(float(total_loss.numpy()))
            epoch_seg.append(float(seg_loss.numpy()))
            epoch_cls.append(float(cls_loss.numpy()))
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(loss=f"{epoch_total[-1]:.4f}", seg=f"{epoch_seg[-1]:.4f}", cls=f"{epoch_cls[-1]:.4f}")
            elif step % 10 == 0:
                print(f"Step {step}/{steps_per_epoch} loss={epoch_total[-1]:.4f} seg={epoch_seg[-1]:.4f} cls={epoch_cls[-1]:.4f}")
        if pbar is not None:
            pbar.close()

        train_metrics = {
            "total_loss": float(np.mean(epoch_total)) if epoch_total else float("nan"),
            "seg_loss": float(np.mean(epoch_seg)) if epoch_seg else float("nan"),
            "cls_loss": float(np.mean(epoch_cls)) if epoch_cls else float("nan"),
        }
        val_metrics = _compute_metrics(
            base_model,
            cls_head,
            val_ds,
            lumen_class=args.lumen_class,
            pos_weight=pos_weight,
            seg_weight=args.seg_loss_weight,
            cls_weight=args.cls_loss_weight,
        ) if val_ds is not None else {"total_loss": float("nan")}
        print(
            f"Epoch {epoch}/{args.epochs} train_total={train_metrics['total_loss']:.4f} "
            f"val_total={val_metrics.get('total_loss', float('nan')):.4f} "
            f"val_seg_dice={val_metrics.get('seg_dice', float('nan')):.4f} "
            f"val_cls_auc={val_metrics.get('cls_auc', float('nan')):.4f}"
        )
        history_row = {"epoch": epoch, **train_metrics}
        history_row.update({f"val_{k}": v for k, v in val_metrics.items()})
        history.append(history_row)
        with tb_writer.as_default():
            for k, v in train_metrics.items():
                tf.summary.scalar(f"train/{k}", v, step=epoch)
            for k, v in val_metrics.items():
                tf.summary.scalar(f"val/{k}", v, step=epoch)

        if val_ds is None:
            continue
        monitor = val_metrics.get("total_loss", float("nan"))
        improved = np.isfinite(monitor) and (best_val is None or monitor < (best_val - args.early_stop_min_delta))
        if improved:
            best_val = float(monitor)
            wait = 0
            manager.save(checkpoint_number=epoch)
            print(f"Saved best checkpoint at epoch {epoch} (val_total_loss={best_val:.6f})")
        else:
            wait += 1
            print(f"No val improvement: wait {wait}/{args.early_stop_patience}")
            if wait >= args.early_stop_patience:
                print("Early stopping triggered.")
                break
    tb_writer.flush()

    # ---- Export ---------------------------------------------------------
    if manager.latest_checkpoint:
        ckpt.restore(manager.latest_checkpoint).expect_partial()
        print(f"Restored best checkpoint: {manager.latest_checkpoint}")
    base_out = args.output_dir / "lumen_multitask_base"
    cls_out = args.output_dir / "bifurcation_head.keras"
    tf.saved_model.save(base_model, str(base_out))
    cls_head.save(str(cls_out))

    test_metrics = _compute_metrics(
        base_model,
        cls_head,
        test_ds,
        lumen_class=args.lumen_class,
        pos_weight=pos_weight,
        seg_weight=args.seg_loss_weight,
        cls_weight=args.cls_loss_weight,
    ) if test_ds is not None else {}
    with tb_writer.as_default():
        for k, v in test_metrics.items():
            tf.summary.scalar(f"test/{k}", v, step=len(history))
    tb_writer.flush()
    tb_writer.close()

    summary = {
        "model_dir": str(args.model_dir),
        "output_dir": str(args.output_dir),
        "split_json": str(args.split_json),
        "num_train": int(len(tr_x)),
        "num_val": int(len(va_x)),
        "num_test": int(len(te_x)),
        "num_train_with_lumen": int(np.sum(tr_h > 0.5)),
        "num_val_with_lumen": int(np.sum(va_h > 0.5)),
        "num_test_with_lumen": int(np.sum(te_h > 0.5)),
        "selected_trainable_variables": [v.name for v in base_vars],
        "lumen_class": int(args.lumen_class),
        "pos_weight": float(pos_weight),
        "seg_loss_weight": float(args.seg_loss_weight),
        "cls_loss_weight": float(args.cls_loss_weight),
        "learning_rate": float(args.learning_rate),
        "epochs": int(args.epochs),
        "early_stop_patience": int(args.early_stop_patience),
        "early_stop_min_delta": float(args.early_stop_min_delta),
        "best_val_total_loss": None if best_val is None else float(best_val),
        "tensorboard_run_dir": str(tb_run_dir),
        "saved_base_model": str(base_out),
        "saved_cls_head": str(cls_out),
        "history": history,
        "test_metrics": test_metrics,
    }
    summary_path = args.output_dir / "multitask_summary.json"
    with summary_path.open("w", encoding="utf-8") as fp:
        json.dump(summary, fp, indent=2)
    print(f"Saved multitask base model: {base_out}")
    print(f"Saved multitask classifier head: {cls_out}")
    print(f"Saved summary: {summary_path}")
| if __name__ == "__main__": | |
| main() | |