#!/usr/bin/env python3 """Multitask fine-tuning: DeepIVUS lumen segmentation + bifurcation classification. Uses: - all bifurcation-labeled samples for classification - only lumen-polygon samples for segmentation loss (masked per-sample) Data are loaded on-the-fly from frame bank JSONL + source DICOMs. """ from __future__ import annotations import argparse import datetime import json import math import re from pathlib import Path import numpy as np import tensorflow as tf from PIL import Image from deepivus.config import resolve_lumen_model_dir from scripts.finetune.shared.common import ( load_bifurcation_annotations, load_lumen_annotations, load_preprocessed_stack, load_split_ids, polygon_to_mask, ) try: from tqdm.auto import tqdm except Exception: tqdm = None IMG_MEAN = tf.constant([60.3486], dtype=tf.float32) def _extract_logits(model_output: object) -> tf.Tensor: if isinstance(model_output, dict): if not model_output: raise RuntimeError("Model returned an empty dict output.") return next(iter(model_output.values())) if isinstance(model_output, (tuple, list)): if not model_output: raise RuntimeError("Model returned an empty sequence output.") return model_output[0] if tf.is_tensor(model_output): return model_output raise RuntimeError(f"Unsupported model output type: {type(model_output)!r}") def _prepare_batch(images: tf.Tensor) -> tf.Tensor: x = tf.cast(images, tf.float32) - IMG_MEAN if len(x.shape) < 4: x = tf.expand_dims(x, axis=3) if x.shape[-1] != 3: x = tf.tile(x, [1, 1, 1, 3]) return x def _select_trainable_vars( all_trainable: list[tf.Variable], include_regex: str, exclude_regex: str, max_trainable: int, ) -> list[tf.Variable]: include = re.compile(include_regex) if include_regex else None exclude = re.compile(exclude_regex) if exclude_regex else None selected: list[tf.Variable] = [] for var in all_trainable: name = var.name if include is not None and not include.search(name): continue if exclude is not None and exclude.search(name): continue selected.append(var) if max_trainable > 0 and len(selected) >= max_trainable: break return selected def _binary_logit_from_multiclass(logits: tf.Tensor, lumen_class: int) -> tf.Tensor: num_classes = int(logits.shape[-1]) if lumen_class < 0 or lumen_class >= num_classes: raise ValueError(f"Invalid lumen_class={lumen_class} for num_classes={num_classes}") positive = logits[..., lumen_class] negatives = tf.concat([logits[..., :lumen_class], logits[..., lumen_class + 1 :]], axis=-1) negative_logsumexp = tf.reduce_logsumexp(negatives, axis=-1) return positive - negative_logsumexp def _augment_multitask_batch( images: tf.Tensor, masks: tf.Tensor, bif_labels: tf.Tensor, has_mask: tf.Tensor, seed: int, ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: stateless_seed = tf.convert_to_tensor([seed, 0], dtype=tf.int32) flip_lr = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5 images = tf.cond(flip_lr, lambda: tf.image.flip_left_right(images), lambda: images) masks = tf.cond(flip_lr, lambda: tf.image.flip_left_right(masks[..., tf.newaxis]), lambda: masks[..., tf.newaxis]) stateless_seed = tf.convert_to_tensor([seed, 1], dtype=tf.int32) flip_ud = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5 images = tf.cond(flip_ud, lambda: tf.image.flip_up_down(images), lambda: images) masks = tf.cond(flip_ud, lambda: tf.image.flip_up_down(masks), lambda: masks) stateless_seed = tf.convert_to_tensor([seed, 2], dtype=tf.int32) k = tf.random.stateless_uniform([], seed=stateless_seed, minval=0, maxval=4, dtype=tf.int32) images = tf.image.rot90(images, k=k) masks = tf.image.rot90(masks, k=k) shx = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 30], dtype=tf.int32), minval=-0.08, maxval=0.08) shy = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 31], dtype=tf.int32), minval=-0.08, maxval=0.08) p0 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 32], dtype=tf.int32), minval=-8e-4, maxval=8e-4) p1 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 33], dtype=tf.int32), minval=-8e-4, maxval=8e-4) base_t = tf.stack([ tf.constant(1.0, tf.float32), shx, tf.constant(0.0, tf.float32), shy, tf.constant(1.0, tf.float32), tf.constant(0.0, tf.float32), p0, p1, ]) transforms = tf.tile(base_t[tf.newaxis, :], [tf.shape(images)[0], 1]) out_shape = tf.shape(images)[1:3] images = tf.raw_ops.ImageProjectiveTransformV3( images=images, transforms=transforms, output_shape=out_shape, interpolation="BILINEAR", fill_mode="REFLECT", fill_value=0.0, ) masks = tf.raw_ops.ImageProjectiveTransformV3( images=masks, transforms=transforms, output_shape=out_shape, interpolation="NEAREST", fill_mode="REFLECT", fill_value=0.0, ) stateless_seed = tf.convert_to_tensor([seed, 3], dtype=tf.int32) images = tf.image.stateless_random_brightness(images, max_delta=10.0, seed=stateless_seed) stateless_seed = tf.convert_to_tensor([seed, 4], dtype=tf.int32) images = tf.image.stateless_random_contrast(images, lower=0.9, upper=1.1, seed=stateless_seed) images = tf.clip_by_value(images, -255.0, 255.0) return images, tf.squeeze(masks, axis=-1), bif_labels, has_mask def _build_dataset( images: np.ndarray, masks: np.ndarray, bif_labels: np.ndarray, has_mask: np.ndarray, batch_size: int, shuffle: bool, seed: int, augment: bool, ) -> tf.data.Dataset: ds = tf.data.Dataset.from_tensor_slices((images, masks, bif_labels, has_mask)) if shuffle: ds = ds.shuffle(buffer_size=len(images), seed=seed, reshuffle_each_iteration=True) ds = ds.batch(batch_size) def _map_fn(batch_images, batch_masks, batch_bif, batch_has): x = _prepare_batch(batch_images) y_mask = tf.cast(batch_masks, tf.float32) y_bif = tf.cast(batch_bif, tf.float32) y_has = tf.cast(batch_has, tf.float32) return x, y_mask, y_bif, y_has ds = ds.map(_map_fn, num_parallel_calls=tf.data.AUTOTUNE) if augment: ds = ds.enumerate().map( lambda i, data: _augment_multitask_batch(data[0], data[1], data[2], data[3], seed=seed + tf.cast(i, tf.int32)), num_parallel_calls=tf.data.AUTOTUNE, ) return ds.prefetch(tf.data.AUTOTUNE) def _tile_to_uint8(tile: np.ndarray) -> np.ndarray: x = np.asarray(tile, dtype=np.float32) if x.ndim == 3 and x.shape[-1] > 1: x = x[..., 0] if not np.isfinite(x).any(): return np.zeros_like(x, dtype=np.uint8) mn = float(np.nanmin(x)) mx = float(np.nanmax(x)) if mx <= mn + 1e-6: return np.clip(x, 0, 255).astype(np.uint8) x = (x - mn) / (mx - mn) return (x * 255.0).clip(0, 255).astype(np.uint8) def _save_augment_grid_from_dataset(ds: tf.data.Dataset, out_path: Path) -> bool: batch = None for x, _, _, _ in ds.take(1): batch = x.numpy() if batch is None or batch.shape[0] == 0: return False n = min(16, int(batch.shape[0])) h = int(batch.shape[1]) w = int(batch.shape[2]) grid = Image.new("L", (4 * w, 4 * h), color=0) for i in range(n): r = i // 4 c = i % 4 grid.paste(Image.fromarray(_tile_to_uint8(batch[i]), mode="L"), (c * w, r * h)) out_path.parent.mkdir(parents=True, exist_ok=True) grid.save(out_path) return True def _build_arrays(annotations_bif, lumen_map: dict[str, object], diameter: int): grouped: dict[Path, list] = {} for ann in annotations_bif: grouped.setdefault(ann.dicom_path, []).append(ann) images = [] masks = [] bif_labels = [] has_mask = [] sample_ids = [] for dicom_path, ann_list in grouped.items(): stack = load_preprocessed_stack(dicom_path, diameter=diameter) h, w = int(stack.shape[1]), int(stack.shape[2]) for ann in ann_list: fidx = ann.frame_idx if fidx < 0 or fidx >= stack.shape[0]: continue sid = ann.sample_id images.append(stack[fidx]) bif_labels.append(1.0 if ann.bifurcation else 0.0) sample_ids.append(sid) lann = lumen_map.get(sid) if lann is None: masks.append(np.zeros((h, w), dtype=np.float32)) has_mask.append(0.0) else: m = polygon_to_mask(lann.lumen_x, lann.lumen_y, (h, w)).astype(np.float32) masks.append(m) has_mask.append(1.0 if np.any(m > 0.5) else 0.0) if not images: raise RuntimeError("No valid samples produced from frame bank.") return ( np.stack(images, axis=0), np.stack(masks, axis=0), np.asarray(bif_labels, dtype=np.float32), np.asarray(has_mask, dtype=np.float32), sample_ids, ) def _make_cls_head(num_seg_classes: int) -> tf.keras.Model: inp = tf.keras.Input(shape=(None, None, num_seg_classes), name="seg_logits") x = tf.keras.layers.GlobalAveragePooling2D()(inp) x = tf.keras.layers.Dense(96, activation="relu")(x) x = tf.keras.layers.Dropout(0.3)(x) out = tf.keras.layers.Dense(1, activation="sigmoid", name="bifurcation_prob")(x) return tf.keras.Model(inp, out, name="bifurcation_head") def _compute_metrics( base_model, cls_head, ds: tf.data.Dataset, lumen_class: int, pos_weight: float, seg_weight: float, cls_weight: float, ) -> dict[str, float]: seg_losses = [] cls_losses = [] total_losses = [] seg_inter = 0.0 seg_union = 0.0 seg_pred_sum = 0.0 seg_gt_sum = 0.0 seg_count = 0.0 y_true = [] y_prob = [] for x, y_mask, y_bif, y_has in ds: logits = _extract_logits(base_model(x, training=False)) resized_logits = tf.image.resize(logits, (tf.shape(y_mask)[1], tf.shape(y_mask)[2])) bin_logit = _binary_logit_from_multiclass(resized_logits, lumen_class=lumen_class) bce_map = tf.nn.weighted_cross_entropy_with_logits(labels=y_mask, logits=bin_logit, pos_weight=pos_weight) bce_per_sample = tf.reduce_mean(bce_map, axis=[1, 2]) probs_mask = tf.math.sigmoid(bin_logit) inter = tf.reduce_sum(probs_mask * y_mask, axis=[1, 2]) denom = tf.reduce_sum(probs_mask + y_mask, axis=[1, 2]) dice_loss_per_sample = 1.0 - (2.0 * inter + 1e-6) / (denom + 1e-6) seg_per_sample = bce_per_sample + 0.3 * dice_loss_per_sample valid = y_has > 0.5 valid_count = tf.reduce_sum(tf.cast(valid, tf.float32)) seg_loss = tf.where( valid_count > 0, tf.reduce_sum(tf.where(valid, seg_per_sample, tf.zeros_like(seg_per_sample))) / valid_count, tf.constant(0.0, dtype=tf.float32), ) bif_prob = cls_head(resized_logits, training=False) bif_prob = tf.reshape(bif_prob, [-1]) cls_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_bif, bif_prob)) total = seg_weight * seg_loss + cls_weight * cls_loss seg_losses.append(float(seg_loss.numpy())) cls_losses.append(float(cls_loss.numpy())) total_losses.append(float(total.numpy())) y_true.append(y_bif.numpy().reshape(-1)) y_prob.append(bif_prob.numpy().reshape(-1)) pred_mask = probs_mask >= 0.5 gt_mask = y_mask >= 0.5 valid_np = valid.numpy().astype(bool) pred_np = pred_mask.numpy() gt_np = gt_mask.numpy() for i in range(pred_np.shape[0]): if not valid_np[i]: continue pi = pred_np[i] gi = gt_np[i] seg_inter += float(np.logical_and(pi, gi).sum()) seg_union += float(np.logical_or(pi, gi).sum()) seg_pred_sum += float(pi.sum()) seg_gt_sum += float(gi.sum()) seg_count += 1.0 y_true_arr = np.concatenate(y_true) if y_true else np.zeros((0,), dtype=np.float32) y_prob_arr = np.concatenate(y_prob) if y_prob else np.zeros((0,), dtype=np.float32) y_pred_arr = (y_prob_arr >= 0.5).astype(np.int32) y_true_int = y_true_arr.astype(np.int32) tp = int(np.sum((y_pred_arr == 1) & (y_true_int == 1))) fp = int(np.sum((y_pred_arr == 1) & (y_true_int == 0))) fn = int(np.sum((y_pred_arr == 0) & (y_true_int == 1))) tn = int(np.sum((y_pred_arr == 0) & (y_true_int == 0))) cls_acc = float((tp + tn) / max(tp + tn + fp + fn, 1)) cls_prec = float(tp / max(tp + fp, 1)) cls_rec = float(tp / max(tp + fn, 1)) cls_f1 = float((2.0 * cls_prec * cls_rec) / max(cls_prec + cls_rec, 1e-12)) if y_true_arr.size > 1 and len(np.unique(y_true_int)) > 1: auc_metric = tf.keras.metrics.AUC(curve="ROC") auc_metric.update_state(y_true_arr, y_prob_arr) cls_auc = float(auc_metric.result().numpy()) else: cls_auc = float("nan") seg_iou = seg_inter / max(seg_union, 1.0) seg_dice = (2.0 * seg_inter) / max(seg_pred_sum + seg_gt_sum, 1.0) return { "total_loss": float(np.mean(total_losses)) if total_losses else float("nan"), "seg_loss": float(np.mean(seg_losses)) if seg_losses else float("nan"), "cls_loss": float(np.mean(cls_losses)) if cls_losses else float("nan"), "seg_iou": float(seg_iou), "seg_dice": float(seg_dice), "seg_count": float(seg_count), "cls_accuracy": cls_acc, "cls_precision": cls_prec, "cls_recall": cls_rec, "cls_f1": cls_f1, "cls_auc": cls_auc, } def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model-dir", type=Path, default=resolve_lumen_model_dir()) parser.add_argument("--frame-bank-root", type=Path, default=Path("evals/frame_bank_merged")) parser.add_argument("--split-json", type=Path, default=Path("evals/splits/ivus_split_merged_600.json")) parser.add_argument("--output-dir", type=Path, default=Path("models/multitask")) parser.add_argument("--lumen-class", type=int, default=1) parser.add_argument("--epochs", type=int, default=50) parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--learning-rate", type=float, default=1e-5) parser.add_argument("--seg-loss-weight", type=float, default=1.0) parser.add_argument("--cls-loss-weight", type=float, default=1.0) parser.add_argument("--augment", action="store_true") parser.add_argument("--seed", type=int, default=7) parser.add_argument("--diameter", type=int, default=67) parser.add_argument("--clip-norm", type=float, default=1.0) parser.add_argument("--include-var-regex", type=str, default=".*") parser.add_argument("--exclude-var-regex", type=str, default="") parser.add_argument("--max-trainable", type=int, default=0) parser.add_argument("--early-stop-patience", type=int, default=10) parser.add_argument("--early-stop-min-delta", type=float, default=1e-4) parser.add_argument("--tb-logdir", type=Path, default=Path("output/tensorboard")) parser.add_argument("--tb-run-name", type=str, default=None) args = parser.parse_args() tf.random.set_seed(args.seed) np.random.seed(args.seed) bif_anns = load_bifurcation_annotations(args.frame_bank_root) lumen_anns = load_lumen_annotations(args.frame_bank_root) if not bif_anns: raise RuntimeError(f"No bifurcation annotations under: {args.frame_bank_root}") split_ids = load_split_ids(args.split_json) train_ids = split_ids["train"] val_ids = split_ids["val"] test_ids = split_ids["test"] bif_anns = [a for a in bif_anns if a.sample_id in (train_ids | val_ids | test_ids)] lumen_map = {a.sample_id: a for a in lumen_anns} images, masks, bif_labels, has_mask, sample_ids = _build_arrays( bif_anns, lumen_map=lumen_map, diameter=args.diameter, ) idx_train = np.asarray([i for i, sid in enumerate(sample_ids) if sid in train_ids], dtype=np.int64) idx_val = np.asarray([i for i, sid in enumerate(sample_ids) if sid in val_ids], dtype=np.int64) idx_test = np.asarray([i for i, sid in enumerate(sample_ids) if sid in test_ids], dtype=np.int64) if len(idx_train) == 0: raise RuntimeError("Train split is empty.") tr_x, tr_m, tr_b, tr_h = images[idx_train], masks[idx_train], bif_labels[idx_train], has_mask[idx_train] va_x, va_m, va_b, va_h = images[idx_val], masks[idx_val], bif_labels[idx_val], has_mask[idx_val] te_x, te_m, te_b, te_h = images[idx_test], masks[idx_test], bif_labels[idx_test], has_mask[idx_test] print( f"Samples train/val/test: {len(tr_x)}/{len(va_x)}/{len(te_x)} | " f"lumen-supervised train/val/test: {int(np.sum(tr_h))}/{int(np.sum(va_h))}/{int(np.sum(te_h))}" ) train_ds = _build_dataset(tr_x, tr_m, tr_b, tr_h, args.batch_size, shuffle=True, seed=args.seed, augment=args.augment) val_ds = _build_dataset(va_x, va_m, va_b, va_h, args.batch_size, shuffle=False, seed=args.seed, augment=False) if len(va_x) else None test_ds = _build_dataset(te_x, te_m, te_b, te_h, args.batch_size, shuffle=False, seed=args.seed, augment=False) if len(te_x) else None base_model = tf.saved_model.load(str(args.model_dir)) all_trainable = list(getattr(base_model, "trainable_variables", [])) if not all_trainable: raise RuntimeError("Loaded base model has no trainable variables.") base_vars = _select_trainable_vars( all_trainable, include_regex=args.include_var_regex, exclude_regex=args.exclude_var_regex, max_trainable=args.max_trainable, ) if not base_vars: raise RuntimeError("No base-model trainable variables selected.") # Build class head on top of segmentation logits. probe = _extract_logits(base_model(_prepare_batch(tf.convert_to_tensor(tr_x[:1], dtype=tf.float32)), training=False)) num_seg_classes = int(probe.shape[-1]) cls_head = _make_cls_head(num_seg_classes=num_seg_classes) pos_pixels = float(np.sum(tr_m > 0.5)) total_pixels = float(np.prod(tr_m.shape)) neg_pixels = max(total_pixels - pos_pixels, 1.0) pos_weight = float(np.clip(neg_pixels / max(pos_pixels, 1.0), 1.0, 40.0)) optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate) args.output_dir.mkdir(parents=True, exist_ok=True) run_name = args.tb_run_name or datetime.datetime.now().strftime("multitask_%Y%m%d_%H%M%S") tb_run_dir = args.tb_logdir / run_name tb_run_dir.mkdir(parents=True, exist_ok=True) tb_writer = tf.summary.create_file_writer(str(tb_run_dir)) print(f"TensorBoard logdir: {tb_run_dir}") if args.augment: aug_preview_path = tb_run_dir / "augment_preview_4x4.png" if _save_augment_grid_from_dataset(train_ds, aug_preview_path): print(f"Saved augmentation preview: {aug_preview_path}") ckpt_dir = args.output_dir / "checkpoints" ckpt_dir.mkdir(parents=True, exist_ok=True) ckpt = tf.train.Checkpoint(base=base_model, cls_head=cls_head, optimizer=optimizer) manager = tf.train.CheckpointManager(ckpt, directory=str(ckpt_dir), max_to_keep=1) steps_per_epoch = int(math.ceil(len(tr_x) / float(args.batch_size))) best_val = None wait = 0 history = [] for epoch in range(1, args.epochs + 1): if tqdm is not None: pbar = tqdm(total=steps_per_epoch, desc=f"Epoch {epoch}/{args.epochs}", leave=False) else: pbar = None epoch_total = [] epoch_seg = [] epoch_cls = [] for step, (x, y_mask, y_bif, y_has) in enumerate(train_ds, start=1): with tf.GradientTape() as tape: logits = _extract_logits(base_model(x, training=True)) resized_logits = tf.image.resize(logits, (tf.shape(y_mask)[1], tf.shape(y_mask)[2])) bin_logit = _binary_logit_from_multiclass(resized_logits, lumen_class=args.lumen_class) bce_map = tf.nn.weighted_cross_entropy_with_logits(labels=y_mask, logits=bin_logit, pos_weight=pos_weight) bce_per_sample = tf.reduce_mean(bce_map, axis=[1, 2]) probs_mask = tf.math.sigmoid(bin_logit) inter = tf.reduce_sum(probs_mask * y_mask, axis=[1, 2]) denom = tf.reduce_sum(probs_mask + y_mask, axis=[1, 2]) dice_loss_per_sample = 1.0 - (2.0 * inter + 1e-6) / (denom + 1e-6) seg_per_sample = bce_per_sample + 0.3 * dice_loss_per_sample valid = y_has > 0.5 valid_count = tf.reduce_sum(tf.cast(valid, tf.float32)) seg_loss = tf.where( valid_count > 0, tf.reduce_sum(tf.where(valid, seg_per_sample, tf.zeros_like(seg_per_sample))) / valid_count, tf.constant(0.0, dtype=tf.float32), ) bif_prob = cls_head(resized_logits, training=True) bif_prob = tf.reshape(bif_prob, [-1]) cls_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_bif, bif_prob)) total_loss = args.seg_loss_weight * seg_loss + args.cls_loss_weight * cls_loss train_vars = base_vars + cls_head.trainable_variables grads = tape.gradient(total_loss, train_vars) grad_var_pairs = [(g, v) for g, v in zip(grads, train_vars) if g is not None] if args.clip_norm > 0 and grad_var_pairs: grad_tensors = [g for g, _ in grad_var_pairs] clipped, _ = tf.clip_by_global_norm(grad_tensors, clip_norm=args.clip_norm) grad_var_pairs = list(zip(clipped, [v for _, v in grad_var_pairs])) optimizer.apply_gradients(grad_var_pairs) epoch_total.append(float(total_loss.numpy())) epoch_seg.append(float(seg_loss.numpy())) epoch_cls.append(float(cls_loss.numpy())) if pbar is not None: pbar.update(1) pbar.set_postfix(loss=f"{epoch_total[-1]:.4f}", seg=f"{epoch_seg[-1]:.4f}", cls=f"{epoch_cls[-1]:.4f}") elif step % 10 == 0: print(f"Step {step}/{steps_per_epoch} loss={epoch_total[-1]:.4f} seg={epoch_seg[-1]:.4f} cls={epoch_cls[-1]:.4f}") if pbar is not None: pbar.close() train_metrics = { "total_loss": float(np.mean(epoch_total)) if epoch_total else float("nan"), "seg_loss": float(np.mean(epoch_seg)) if epoch_seg else float("nan"), "cls_loss": float(np.mean(epoch_cls)) if epoch_cls else float("nan"), } val_metrics = _compute_metrics( base_model, cls_head, val_ds, lumen_class=args.lumen_class, pos_weight=pos_weight, seg_weight=args.seg_loss_weight, cls_weight=args.cls_loss_weight, ) if val_ds is not None else {"total_loss": float("nan")} print( f"Epoch {epoch}/{args.epochs} train_total={train_metrics['total_loss']:.4f} " f"val_total={val_metrics.get('total_loss', float('nan')):.4f} " f"val_seg_dice={val_metrics.get('seg_dice', float('nan')):.4f} " f"val_cls_auc={val_metrics.get('cls_auc', float('nan')):.4f}" ) history_row = {"epoch": epoch, **train_metrics} history_row.update({f"val_{k}": v for k, v in val_metrics.items()}) history.append(history_row) with tb_writer.as_default(): for k, v in train_metrics.items(): tf.summary.scalar(f"train/{k}", v, step=epoch) for k, v in val_metrics.items(): tf.summary.scalar(f"val/{k}", v, step=epoch) monitor = val_metrics.get("total_loss", float("nan")) improved = np.isfinite(monitor) and (best_val is None or monitor < (best_val - args.early_stop_min_delta)) if improved: best_val = float(monitor) wait = 0 manager.save(checkpoint_number=epoch) print(f"Saved best checkpoint at epoch {epoch} (val_total_loss={best_val:.6f})") else: wait += 1 print(f"No val improvement: wait {wait}/{args.early_stop_patience}") if wait >= args.early_stop_patience: print("Early stopping triggered.") break tb_writer.flush() if manager.latest_checkpoint: ckpt.restore(manager.latest_checkpoint).expect_partial() print(f"Restored best checkpoint: {manager.latest_checkpoint}") base_out = args.output_dir / "lumen_multitask_base" cls_out = args.output_dir / "bifurcation_head.keras" tf.saved_model.save(base_model, str(base_out)) cls_head.save(str(cls_out)) test_metrics = _compute_metrics( base_model, cls_head, test_ds, lumen_class=args.lumen_class, pos_weight=pos_weight, seg_weight=args.seg_loss_weight, cls_weight=args.cls_loss_weight, ) if test_ds is not None else {} with tb_writer.as_default(): for k, v in test_metrics.items(): tf.summary.scalar(f"test/{k}", v, step=len(history)) tb_writer.flush() tb_writer.close() summary = { "model_dir": str(args.model_dir), "output_dir": str(args.output_dir), "split_json": str(args.split_json), "num_train": int(len(tr_x)), "num_val": int(len(va_x)), "num_test": int(len(te_x)), "num_train_with_lumen": int(np.sum(tr_h > 0.5)), "num_val_with_lumen": int(np.sum(va_h > 0.5)), "num_test_with_lumen": int(np.sum(te_h > 0.5)), "selected_trainable_variables": [v.name for v in base_vars], "lumen_class": int(args.lumen_class), "pos_weight": float(pos_weight), "seg_loss_weight": float(args.seg_loss_weight), "cls_loss_weight": float(args.cls_loss_weight), "learning_rate": float(args.learning_rate), "epochs": int(args.epochs), "early_stop_patience": int(args.early_stop_patience), "early_stop_min_delta": float(args.early_stop_min_delta), "best_val_total_loss": None if best_val is None else float(best_val), "tensorboard_run_dir": str(tb_run_dir), "saved_base_model": str(base_out), "saved_cls_head": str(cls_out), "history": history, "test_metrics": test_metrics, } summary_path = args.output_dir / "multitask_summary.json" with summary_path.open("w", encoding="utf-8") as fp: json.dump(summary, fp, indent=2) print(f"Saved multitask base model: {base_out}") print(f"Saved multitask classifier head: {cls_out}") print(f"Saved summary: {summary_path}") if __name__ == "__main__": main()