ivus-segmentation / scripts /finetune /multitask /train_multitask_deepivus.py
Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""Multitask fine-tuning: DeepIVUS lumen segmentation + bifurcation classification.
Uses:
- all bifurcation-labeled samples for classification
- only lumen-polygon samples for segmentation loss (masked per-sample)
Data are loaded on-the-fly from frame bank JSONL + source DICOMs.
"""
from __future__ import annotations
import argparse
import datetime
import json
import math
import re
from pathlib import Path
import numpy as np
import tensorflow as tf
from PIL import Image
from deepivus.config import resolve_lumen_model_dir
from scripts.finetune.shared.common import (
load_bifurcation_annotations,
load_lumen_annotations,
load_preprocessed_stack,
load_split_ids,
polygon_to_mask,
)
try:
from tqdm.auto import tqdm
except Exception:
tqdm = None
# Scalar intensity offset (as a 1-element float32 tensor) subtracted from every input
# pixel before the base model is called. Presumably the training-set mean of the
# original DeepIVUS model — TODO confirm.
IMG_MEAN = tf.constant(60.3486, dtype=tf.float32, shape=(1,))
def _extract_logits(model_output: object) -> tf.Tensor:
    """Pick the primary logits tensor out of whatever the model call returned.

    - dict: the first value in iteration order;
    - tuple/list: the first element;
    - tensor: returned as-is.

    Raises:
        RuntimeError: for an empty dict/sequence or any other output type.
    """
    if isinstance(model_output, dict):
        if model_output:
            return next(iter(model_output.values()))
        raise RuntimeError("Model returned an empty dict output.")
    if isinstance(model_output, (tuple, list)):
        if len(model_output) == 0:
            raise RuntimeError("Model returned an empty sequence output.")
        return model_output[0]
    if not tf.is_tensor(model_output):
        raise RuntimeError(f"Unsupported model output type: {type(model_output)!r}")
    return model_output
def _prepare_batch(images: tf.Tensor) -> tf.Tensor:
    """Mean-centre a batch of images and give it three channels for the base model.

    ``IMG_MEAN`` is subtracted after casting to float32. An input of rank < 4 gets a
    channel axis appended at position 3. The channel axis is then tiled 3x whenever its
    static size is not 3 (this includes an unknown size).
    """
    centred = tf.cast(images, tf.float32) - IMG_MEAN
    if len(centred.shape) < 4:
        centred = tf.expand_dims(centred, axis=3)
    needs_tiling = centred.shape[-1] != 3
    return tf.tile(centred, [1, 1, 1, 3]) if needs_tiling else centred
def _select_trainable_vars(
    all_trainable: list[tf.Variable],
    include_regex: str,
    exclude_regex: str,
    max_trainable: int,
) -> list[tf.Variable]:
    """Filter variables by name with include/exclude regexes, preserving input order.

    A variable is kept when its ``name`` matches ``include_regex`` (``re.search``)
    and does not match ``exclude_regex``. An empty regex disables that filter. When
    ``max_trainable > 0``, only the first ``max_trainable`` kept variables are returned.
    """
    include_pattern = re.compile(include_regex) if include_regex else None
    exclude_pattern = re.compile(exclude_regex) if exclude_regex else None

    def _wanted(var: tf.Variable) -> bool:
        name = var.name
        if include_pattern is not None and include_pattern.search(name) is None:
            return False
        return exclude_pattern is None or exclude_pattern.search(name) is None

    kept = [var for var in all_trainable if _wanted(var)]
    return kept[:max_trainable] if max_trainable > 0 else kept
def _binary_logit_from_multiclass(logits: tf.Tensor, lumen_class: int) -> tf.Tensor:
    """Collapse multi-class logits into a one-vs-rest binary logit for ``lumen_class``.

    With p = softmax(logits)[lumen_class], log(p / (1 - p)) equals
    ``logits[lumen_class] - logsumexp(other logits)``. That form is computed here and
    avoids forming the softmax explicitly. The class axis is the last axis, so the result
    has the input shape minus that axis.

    NOTE(review): with a single class there are no "other" logits, and the logsumexp of
    an empty reduction is presumably -inf, so the result would be +inf. Callers should
    pass at least two classes.

    Raises:
        ValueError: if ``lumen_class`` is outside ``[0, num_classes)``.
    """
    num_classes = int(logits.shape[-1])
    if not 0 <= lumen_class < num_classes:
        raise ValueError(f"Invalid lumen_class={lumen_class} for num_classes={num_classes}")
    other_logits = tf.concat(
        [logits[..., :lumen_class], logits[..., lumen_class + 1 :]],
        axis=-1,
    )
    return logits[..., lumen_class] - tf.reduce_logsumexp(other_logits, axis=-1)
def _augment_multitask_batch(
    images: tf.Tensor,
    masks: tf.Tensor,
    bif_labels: tf.Tensor,
    has_mask: tf.Tensor,
    seed: int,
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Apply paired geometric + photometric augmentation to one batch.

    Every random draw below has shape ``[]``, so one set of parameters is sampled per
    call and applied to the whole batch. The geometric ops are applied identically to
    ``images`` and ``masks``:
    - a random left-right flip;
    - a random up-down flip;
    - a rotation by k*90 degrees;
    - a small shear/perspective warp.
    Brightness/contrast jitter and clipping to [-255, 255] are applied to ``images``
    only. ``bif_labels`` and ``has_mask`` are returned unchanged.

    All randomness is stateless. Each draw is keyed by ``[seed, k]`` with a fixed per-op
    ``k``, so the same ``seed`` reproduces the same augmentation.

    Args:
        images: batch of images; assumed (batch, H, W, C) float — TODO confirm
            (ImageProjectiveTransformV3 requires 4-D input).
        masks: batch of masks without a channel axis. A trailing axis is added for the
            image ops and squeezed off again on return.
        bif_labels: passed through unchanged.
        has_mask: passed through unchanged.
        seed: integer seed (``_build_dataset`` passes a scalar int32 tensor).

    Returns:
        ``(images, masks, bif_labels, has_mask)`` with augmented images and masks.
    """
    stateless_seed = tf.convert_to_tensor([seed, 0], dtype=tf.int32)
    flip_lr = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5
    images = tf.cond(flip_lr, lambda: tf.image.flip_left_right(images), lambda: images)
    # Masks gain a trailing channel axis here (on both branches) so the image ops below accept them.
    masks = tf.cond(flip_lr, lambda: tf.image.flip_left_right(masks[..., tf.newaxis]), lambda: masks[..., tf.newaxis])
    stateless_seed = tf.convert_to_tensor([seed, 1], dtype=tf.int32)
    flip_ud = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5
    images = tf.cond(flip_ud, lambda: tf.image.flip_up_down(images), lambda: images)
    masks = tf.cond(flip_ud, lambda: tf.image.flip_up_down(masks), lambda: masks)
    stateless_seed = tf.convert_to_tensor([seed, 2], dtype=tf.int32)
    # Rotation by k * 90 degrees, k in {0, 1, 2, 3}.
    k = tf.random.stateless_uniform([], seed=stateless_seed, minval=0, maxval=4, dtype=tf.int32)
    images = tf.image.rot90(images, k=k)
    masks = tf.image.rot90(masks, k=k)
    # Small shear (shx, shy) and perspective (p0, p1) coefficients.
    shx = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 30], dtype=tf.int32), minval=-0.08, maxval=0.08)
    shy = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 31], dtype=tf.int32), minval=-0.08, maxval=0.08)
    p0 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 32], dtype=tf.int32), minval=-8e-4, maxval=8e-4)
    p1 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 33], dtype=tf.int32), minval=-8e-4, maxval=8e-4)
    # 8-parameter projective transform [a0, a1, a2, b0, b1, b2, c0, c1]. Per the TF op
    # convention it maps output pixel coordinates to input coordinates. Here it is the
    # identity plus shear (a1, b0) and perspective (c0, c1), with no translation.
    base_t = tf.stack([
        tf.constant(1.0, tf.float32), shx, tf.constant(0.0, tf.float32),
        shy, tf.constant(1.0, tf.float32), tf.constant(0.0, tf.float32),
        p0, p1,
    ])
    # One identical transform row per batch element.
    transforms = tf.tile(base_t[tf.newaxis, :], [tf.shape(images)[0], 1])
    out_shape = tf.shape(images)[1:3]
    images = tf.raw_ops.ImageProjectiveTransformV3(
        images=images,
        transforms=transforms,
        output_shape=out_shape,
        interpolation="BILINEAR",
        fill_mode="REFLECT",
        fill_value=0.0,
    )
    # NEAREST keeps mask values unblended. NOTE(review): REFLECT fill also mirrors mask
    # content into out-of-bounds regions near the borders.
    masks = tf.raw_ops.ImageProjectiveTransformV3(
        images=masks,
        transforms=transforms,
        output_shape=out_shape,
        interpolation="NEAREST",
        fill_mode="REFLECT",
        fill_value=0.0,
    )
    # Photometric jitter applies to images only.
    stateless_seed = tf.convert_to_tensor([seed, 3], dtype=tf.int32)
    images = tf.image.stateless_random_brightness(images, max_delta=10.0, seed=stateless_seed)
    stateless_seed = tf.convert_to_tensor([seed, 4], dtype=tf.int32)
    images = tf.image.stateless_random_contrast(images, lower=0.9, upper=1.1, seed=stateless_seed)
    images = tf.clip_by_value(images, -255.0, 255.0)
    # Drop the temporary mask channel axis added above.
    return images, tf.squeeze(masks, axis=-1), bif_labels, has_mask
def _build_dataset(
    images: np.ndarray,
    masks: np.ndarray,
    bif_labels: np.ndarray,
    has_mask: np.ndarray,
    batch_size: int,
    shuffle: bool,
    seed: int,
    augment: bool,
) -> tf.data.Dataset:
    """Build a batched ``(x, y_mask, y_bif, y_has)`` pipeline from in-memory arrays.

    The pipeline runs in this order:
    1. Optionally shuffle the whole set (buffer = number of samples, reshuffled on every
       iteration).
    2. Batch.
    3. Run ``_prepare_batch`` on the images and cast the targets to float32.
    4. When ``augment`` is set, run ``_augment_multitask_batch`` on each batch with the
       seed ``seed + batch_index``.
    5. Prefetch.

    NOTE(review): ``enumerate`` restarts at 0 each time the dataset is iterated. A given
    batch position therefore receives the same augmentation parameters on every pass;
    shuffling only changes which samples land in it. Consider folding an epoch counter
    into the seed.
    """

    def _to_model_inputs(batch_images, batch_masks, batch_bif, batch_has):
        return (
            _prepare_batch(batch_images),
            tf.cast(batch_masks, tf.float32),
            tf.cast(batch_bif, tf.float32),
            tf.cast(batch_has, tf.float32),
        )

    def _augment_indexed(batch_index, batch):
        x, y_mask, y_bif, y_has = batch
        return _augment_multitask_batch(
            x, y_mask, y_bif, y_has, seed=seed + tf.cast(batch_index, tf.int32)
        )

    dataset = tf.data.Dataset.from_tensor_slices((images, masks, bif_labels, has_mask))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(images), seed=seed, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size).map(_to_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        dataset = dataset.enumerate().map(_augment_indexed, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)
def _tile_to_uint8(tile: np.ndarray) -> np.ndarray:
    """Min-max scale an image tile to a 2-D ``uint8`` array for preview rendering.

    Channel handling: a 3-D tile (``H x W x C``) is reduced to its first channel,
    including the ``C == 1`` case. The result is therefore always 2-D and can be handed
    to PIL as a greyscale image.

    Scaling:
    - Only finite values set the scaling range.
    - Non-finite pixels (NaN/inf) are rendered as the darkest value.
    - A (near-)constant tile is clipped to ``[0, 255]`` rather than stretched.
    - A tile with no finite values yields all zeros.
    """
    x = np.asarray(tile, dtype=np.float32)
    if x.ndim == 3:
        x = x[..., 0]
    finite = np.isfinite(x)
    if not finite.any():
        return np.zeros(x.shape, dtype=np.uint8)
    # Range from finite pixels only: nanmin/nanmax would still let +/-inf through.
    mn = float(x[finite].min())
    mx = float(x[finite].max())
    # Replace NaN/inf before scaling; casting them to uint8 is undefined.
    x = np.where(finite, x, mn)
    if mx <= mn + 1e-6:
        return np.clip(x, 0, 255).astype(np.uint8)
    x = (x - mn) / (mx - mn)
    return (x * 255.0).clip(0, 255).astype(np.uint8)
def _save_augment_grid_from_dataset(ds: tf.data.Dataset, out_path: Path) -> bool:
    """Render up to 16 input images from ``ds`` as a 4x4 greyscale PNG grid.

    Tiles are gathered from as many batches as needed to fill the grid. A single batch
    would fill only part of it (``main`` defaults to a batch size of 4). Each tile goes
    through ``_tile_to_uint8``, and unused cells stay black.

    Returns:
        True if the PNG was written to ``out_path`` (parent dirs are created),
        False if ``ds`` yielded no samples.
    """
    grid_side = 4
    capacity = grid_side * grid_side
    tiles: list[np.ndarray] = []
    for x, _, _, _ in ds:
        for tile in x.numpy():
            tiles.append(tile)
            if len(tiles) == capacity:
                break
        if len(tiles) == capacity:
            break
    if not tiles:
        return False
    # Size cells by the largest tile extent, so tiles of different shapes never overlap.
    # For example, non-square images can be rotated by 90 degrees in one batch but not another.
    cell_h = max(int(t.shape[0]) for t in tiles)
    cell_w = max(int(t.shape[1]) for t in tiles)
    grid = Image.new("L", (grid_side * cell_w, grid_side * cell_h), color=0)
    for i, tile in enumerate(tiles):
        row, col = divmod(i, grid_side)
        # PIL infers mode "L" from a 2-D uint8 array; the explicit mode= argument is
        # deprecated in recent Pillow releases.
        grid.paste(Image.fromarray(_tile_to_uint8(tile)), (col * cell_w, row * cell_h))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    grid.save(out_path)
    return True
def _build_arrays(annotations_bif, lumen_map: dict[str, object], diameter: int):
    """Materialise per-frame training arrays from bifurcation annotations.

    Annotations are grouped by ``dicom_path``, so each DICOM is loaded once via
    ``load_preprocessed_stack``. An annotation whose ``frame_idx`` falls outside the
    loaded stack is skipped; the number skipped is printed as a warning.

    For every kept frame:
    - the image is a copy of ``stack[frame_idx]``;
    - the bifurcation label is 1.0 or 0.0 from ``ann.bifurcation``;
    - when ``lumen_map`` has an entry for the sample id, its lumen polygon is rasterised
      with ``polygon_to_mask``, and ``has_mask`` is 1.0 if any mask pixel is > 0.5;
    - otherwise an all-zero placeholder mask is used with ``has_mask`` 0.0.

    NOTE(review): ``np.stack`` requires every stack to share H x W, which presumably
    holds because ``diameter`` fixes the preprocessed size. Confirm this.

    Returns:
        ``(images, masks, bif_labels, has_mask, sample_ids)``:
        - ``images`` and ``masks`` are stacked along a new leading axis;
        - labels and flags are float32 1-D arrays;
        - ids are a list in the same order.

    Raises:
        RuntimeError: if no annotation yields a valid frame.
    """
    grouped: dict[Path, list] = {}
    for ann in annotations_bif:
        grouped.setdefault(ann.dicom_path, []).append(ann)
    images: list[np.ndarray] = []
    masks: list[np.ndarray] = []
    bif_labels: list[float] = []
    has_mask: list[float] = []
    sample_ids: list = []
    skipped = 0
    for dicom_path, ann_list in grouped.items():
        stack = load_preprocessed_stack(dicom_path, diameter=diameter)
        num_frames = int(stack.shape[0])
        h, w = int(stack.shape[1]), int(stack.shape[2])
        for ann in ann_list:
            fidx = ann.frame_idx
            if fidx < 0 or fidx >= num_frames:
                skipped += 1
                continue
            sid = ann.sample_id
            # Copy the frame. If the stack is a NumPy array (presumably), a plain index
            # returns a view. That view would keep the whole multi-frame stack alive
            # until the final np.stack, so peak memory would grow with every DICOM loaded.
            images.append(np.array(stack[fidx], copy=True))
            bif_labels.append(1.0 if ann.bifurcation else 0.0)
            sample_ids.append(sid)
            lann = lumen_map.get(sid)
            if lann is None:
                masks.append(np.zeros((h, w), dtype=np.float32))
                has_mask.append(0.0)
            else:
                m = polygon_to_mask(lann.lumen_x, lann.lumen_y, (h, w)).astype(np.float32)
                masks.append(m)
                # A polygon that rasterises to an empty mask gives no lumen supervision.
                has_mask.append(1.0 if np.any(m > 0.5) else 0.0)
    if skipped:
        print(f"Warning: skipped {skipped} annotation(s) with frame_idx outside the loaded stack.")
    if not images:
        raise RuntimeError("No valid samples produced from frame bank.")
    return (
        np.stack(images, axis=0),
        np.stack(masks, axis=0),
        np.asarray(bif_labels, dtype=np.float32),
        np.asarray(has_mask, dtype=np.float32),
        sample_ids,
    )
def _make_cls_head(num_seg_classes: int) -> tf.keras.Model:
    """Build the bifurcation classifier that consumes segmentation logits.

    Input ``seg_logits`` is ``(H, W, num_seg_classes)``; any spatial size is accepted.
    The layers run in this order:
    1. global average pooling;
    2. a 96-unit ReLU dense layer;
    3. dropout 0.3;
    4. a single sigmoid unit named ``bifurcation_prob``, so the model outputs
       probabilities rather than logits.
    """
    layers = tf.keras.layers
    seg_logits = tf.keras.Input(shape=(None, None, num_seg_classes), name="seg_logits")
    hidden = seg_logits
    for layer in (
        layers.GlobalAveragePooling2D(),
        layers.Dense(96, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(1, activation="sigmoid", name="bifurcation_prob"),
    ):
        hidden = layer(hidden)
    return tf.keras.Model(seg_logits, hidden, name="bifurcation_head")
def _compute_metrics(
    base_model,
    cls_head,
    ds: tf.data.Dataset,
    lumen_class: int,
    pos_weight: float,
    seg_weight: float,
    cls_weight: float,
) -> dict[str, float]:
    """Evaluate segmentation and bifurcation classification on ``ds`` (inference mode).

    Losses use the same formulas as the training loop but are averaged per sample, not
    per batch:
    - ``seg_loss`` is weighted BCE + 0.3 * soft Dice, averaged over lumen-supervised
      samples (``y_has > 0.5``). It is 0.0 when ``ds`` has samples but none supervised.
    - ``cls_loss`` is BCE averaged over all samples.
    - ``total_loss`` is ``seg_weight * seg_loss + cls_weight * cls_loss``.
    Averaging per-batch means instead would let batches without supervised samples
    contribute a placeholder 0 and would over-weight a short final batch. All three
    losses are NaN when ``ds`` yields no samples.

    Segmentation metrics:
    - ``seg_iou`` and ``seg_dice`` are pooled over all pixels of the supervised samples,
      with both prediction and ground truth thresholded at 0.5.
    - ``seg_count`` is the number of supervised samples.

    Classification metrics:
    - accuracy, precision, recall and F1 use a 0.5 threshold;
    - ROC AUC is NaN unless both labels are present.
    """
    seg_loss_sum = 0.0  # sum of per-sample seg losses over supervised samples
    cls_loss_sum = 0.0  # sum of per-sample classification BCE over all samples
    num_samples = 0
    seg_inter = 0.0
    seg_union = 0.0
    seg_pred_sum = 0.0
    seg_gt_sum = 0.0
    seg_count = 0.0
    y_true: list[np.ndarray] = []
    y_prob: list[np.ndarray] = []
    for x, y_mask, y_bif, y_has in ds:
        logits = _extract_logits(base_model(x, training=False))
        resized_logits = tf.image.resize(logits, (tf.shape(y_mask)[1], tf.shape(y_mask)[2]))
        bin_logit = _binary_logit_from_multiclass(resized_logits, lumen_class=lumen_class)

        # Per-sample segmentation loss (same formula as the training loop).
        bce_map = tf.nn.weighted_cross_entropy_with_logits(labels=y_mask, logits=bin_logit, pos_weight=pos_weight)
        bce_per_sample = tf.reduce_mean(bce_map, axis=[1, 2])
        probs_mask = tf.math.sigmoid(bin_logit)
        inter = tf.reduce_sum(probs_mask * y_mask, axis=[1, 2])
        denom = tf.reduce_sum(probs_mask + y_mask, axis=[1, 2])
        dice_loss_per_sample = 1.0 - (2.0 * inter + 1e-6) / (denom + 1e-6)
        seg_per_sample = (bce_per_sample + 0.3 * dice_loss_per_sample).numpy()
        valid = y_has.numpy() > 0.5
        seg_loss_sum += float(seg_per_sample[valid].sum())

        # Per-sample classification BCE. Reshape to (batch, 1) so Keras reduces over the
        # singleton axis only, not over the batch.
        bif_prob = tf.reshape(cls_head(resized_logits, training=False), [-1])
        cls_per_sample = tf.keras.losses.binary_crossentropy(
            tf.reshape(y_bif, [-1, 1]),
            tf.reshape(bif_prob, [-1, 1]),
        )
        cls_loss_sum += float(tf.reduce_sum(cls_per_sample).numpy())
        num_samples += int(bif_prob.shape[0])
        y_true.append(y_bif.numpy().reshape(-1))
        y_prob.append(bif_prob.numpy().reshape(-1))

        # Hard-threshold overlap statistics, pooled over supervised samples only.
        pred_mask = probs_mask.numpy()[valid] >= 0.5
        gt_mask = y_mask.numpy()[valid] >= 0.5
        seg_inter += float(np.logical_and(pred_mask, gt_mask).sum())
        seg_union += float(np.logical_or(pred_mask, gt_mask).sum())
        seg_pred_sum += float(pred_mask.sum())
        seg_gt_sum += float(gt_mask.sum())
        seg_count += float(valid.sum())

    if num_samples == 0:
        seg_loss = cls_loss = total_loss = float("nan")
    else:
        cls_loss = cls_loss_sum / num_samples
        # Mirrors the training objective, where no lumen supervision means a zero seg term.
        seg_loss = seg_loss_sum / seg_count if seg_count > 0 else 0.0
        total_loss = seg_weight * seg_loss + cls_weight * cls_loss

    y_true_arr = np.concatenate(y_true) if y_true else np.zeros((0,), dtype=np.float32)
    y_prob_arr = np.concatenate(y_prob) if y_prob else np.zeros((0,), dtype=np.float32)
    y_pred_arr = (y_prob_arr >= 0.5).astype(np.int32)
    y_true_int = y_true_arr.astype(np.int32)
    tp = int(np.sum((y_pred_arr == 1) & (y_true_int == 1)))
    fp = int(np.sum((y_pred_arr == 1) & (y_true_int == 0)))
    fn = int(np.sum((y_pred_arr == 0) & (y_true_int == 1)))
    tn = int(np.sum((y_pred_arr == 0) & (y_true_int == 0)))
    cls_acc = float((tp + tn) / max(tp + tn + fp + fn, 1))
    cls_prec = float(tp / max(tp + fp, 1))
    cls_rec = float(tp / max(tp + fn, 1))
    cls_f1 = float((2.0 * cls_prec * cls_rec) / max(cls_prec + cls_rec, 1e-12))
    # ROC AUC is undefined unless both classes are present.
    if y_true_arr.size > 1 and len(np.unique(y_true_int)) > 1:
        auc_metric = tf.keras.metrics.AUC(curve="ROC")
        auc_metric.update_state(y_true_arr, y_prob_arr)
        cls_auc = float(auc_metric.result().numpy())
    else:
        cls_auc = float("nan")
    seg_iou = seg_inter / max(seg_union, 1.0)
    seg_dice = (2.0 * seg_inter) / max(seg_pred_sum + seg_gt_sum, 1.0)
    return {
        "total_loss": float(total_loss),
        "seg_loss": float(seg_loss),
        "cls_loss": float(cls_loss),
        "seg_iou": float(seg_iou),
        "seg_dice": float(seg_dice),
        "seg_count": float(seg_count),
        "cls_accuracy": cls_acc,
        "cls_precision": cls_prec,
        "cls_recall": cls_rec,
        "cls_f1": cls_f1,
        "cls_auc": cls_auc,
    }
def _compute_pos_weight(masks: np.ndarray, has_mask: np.ndarray) -> float:
    """Return the positive-pixel weight for the weighted-BCE segmentation loss.

    The weight is the negative/positive pixel ratio, clipped to ``[1, 40]``. Only
    lumen-supervised samples (``has_mask > 0.5``) are counted. Unsupervised samples carry
    all-zero placeholder masks and are excluded from the segmentation loss, so counting
    their pixels as negatives would inflate the weight. With no supervised samples the
    result is 1.0; the segmentation term is then always zero anyway.
    """
    supervised = np.asarray(masks)[np.asarray(has_mask) > 0.5]
    pos_pixels = float(np.sum(supervised > 0.5))
    neg_pixels = max(float(supervised.size) - pos_pixels, 1.0)
    return float(np.clip(neg_pixels / max(pos_pixels, 1.0), 1.0, 40.0))


def main() -> None:
    """Run multitask fine-tuning end to end.

    Steps:
    1. Parse CLI options.
    2. Load bifurcation/lumen annotations and the train/val/test split.
    3. Build in-memory arrays and tf.data pipelines.
    4. Load the base SavedModel and attach a bifurcation head on its segmentation logits.
    5. Train with two loss terms: weighted BCE + Dice segmentation (lumen-supervised
       samples only) and BCE classification.
    6. Checkpoint on validation improvement, with early stopping.
    7. Restore the best checkpoint saved by *this* run.
    8. Save the base model, the head and a JSON summary (including test metrics)
       under ``--output-dir``.

    Without a validation split there is nothing to monitor. Training then runs for all
    ``--epochs`` and keeps the final weights: no early stopping and no restore.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model-dir", type=Path, default=resolve_lumen_model_dir())
    parser.add_argument("--frame-bank-root", type=Path, default=Path("evals/frame_bank_merged"))
    parser.add_argument("--split-json", type=Path, default=Path("evals/splits/ivus_split_merged_600.json"))
    parser.add_argument("--output-dir", type=Path, default=Path("models/multitask"))
    parser.add_argument("--lumen-class", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=1e-5)
    parser.add_argument("--seg-loss-weight", type=float, default=1.0)
    parser.add_argument("--cls-loss-weight", type=float, default=1.0)
    parser.add_argument("--augment", action="store_true")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--diameter", type=int, default=67)
    parser.add_argument("--clip-norm", type=float, default=1.0)
    parser.add_argument("--include-var-regex", type=str, default=".*")
    parser.add_argument("--exclude-var-regex", type=str, default="")
    parser.add_argument("--max-trainable", type=int, default=0)
    parser.add_argument("--early-stop-patience", type=int, default=10)
    parser.add_argument("--early-stop-min-delta", type=float, default=1e-4)
    parser.add_argument("--tb-logdir", type=Path, default=Path("output/tensorboard"))
    parser.add_argument("--tb-run-name", type=str, default=None)
    args = parser.parse_args()
    tf.random.set_seed(args.seed)
    np.random.seed(args.seed)

    bif_anns = load_bifurcation_annotations(args.frame_bank_root)
    lumen_anns = load_lumen_annotations(args.frame_bank_root)
    if not bif_anns:
        raise RuntimeError(f"No bifurcation annotations under: {args.frame_bank_root}")
    split_ids = load_split_ids(args.split_json)
    train_ids = split_ids["train"]
    val_ids = split_ids["val"]
    test_ids = split_ids["test"]
    # Build the union once rather than once per annotation.
    known_ids = train_ids | val_ids | test_ids
    bif_anns = [a for a in bif_anns if a.sample_id in known_ids]
    lumen_map = {a.sample_id: a for a in lumen_anns}
    images, masks, bif_labels, has_mask, sample_ids = _build_arrays(
        bif_anns,
        lumen_map=lumen_map,
        diameter=args.diameter,
    )
    idx_train = np.asarray([i for i, sid in enumerate(sample_ids) if sid in train_ids], dtype=np.int64)
    idx_val = np.asarray([i for i, sid in enumerate(sample_ids) if sid in val_ids], dtype=np.int64)
    idx_test = np.asarray([i for i, sid in enumerate(sample_ids) if sid in test_ids], dtype=np.int64)
    if len(idx_train) == 0:
        raise RuntimeError("Train split is empty.")
    tr_x, tr_m, tr_b, tr_h = images[idx_train], masks[idx_train], bif_labels[idx_train], has_mask[idx_train]
    va_x, va_m, va_b, va_h = images[idx_val], masks[idx_val], bif_labels[idx_val], has_mask[idx_val]
    te_x, te_m, te_b, te_h = images[idx_test], masks[idx_test], bif_labels[idx_test], has_mask[idx_test]
    print(
        f"Samples train/val/test: {len(tr_x)}/{len(va_x)}/{len(te_x)} | "
        f"lumen-supervised train/val/test: {int(np.sum(tr_h))}/{int(np.sum(va_h))}/{int(np.sum(te_h))}"
    )
    train_ds = _build_dataset(tr_x, tr_m, tr_b, tr_h, args.batch_size, shuffle=True, seed=args.seed, augment=args.augment)
    val_ds = _build_dataset(va_x, va_m, va_b, va_h, args.batch_size, shuffle=False, seed=args.seed, augment=False) if len(va_x) else None
    test_ds = _build_dataset(te_x, te_m, te_b, te_h, args.batch_size, shuffle=False, seed=args.seed, augment=False) if len(te_x) else None

    base_model = tf.saved_model.load(str(args.model_dir))
    all_trainable = list(getattr(base_model, "trainable_variables", []))
    if not all_trainable:
        raise RuntimeError("Loaded base model has no trainable variables.")
    base_vars = _select_trainable_vars(
        all_trainable,
        include_regex=args.include_var_regex,
        exclude_regex=args.exclude_var_regex,
        max_trainable=args.max_trainable,
    )
    if not base_vars:
        raise RuntimeError("No base-model trainable variables selected.")
    # Build the classification head on top of the segmentation logits; a one-sample
    # probe call reveals the number of segmentation classes.
    probe = _extract_logits(base_model(_prepare_batch(tf.convert_to_tensor(tr_x[:1], dtype=tf.float32)), training=False))
    num_seg_classes = int(probe.shape[-1])
    cls_head = _make_cls_head(num_seg_classes=num_seg_classes)
    # Class-balance weight for the BCE term, from lumen-supervised training samples only.
    pos_weight = _compute_pos_weight(tr_m, tr_h)
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    run_name = args.tb_run_name or datetime.datetime.now().strftime("multitask_%Y%m%d_%H%M%S")
    tb_run_dir = args.tb_logdir / run_name
    tb_run_dir.mkdir(parents=True, exist_ok=True)
    tb_writer = tf.summary.create_file_writer(str(tb_run_dir))
    print(f"TensorBoard logdir: {tb_run_dir}")
    if args.augment:
        aug_preview_path = tb_run_dir / "augment_preview_4x4.png"
        if _save_augment_grid_from_dataset(train_ds, aug_preview_path):
            print(f"Saved augmentation preview: {aug_preview_path}")

    ckpt_dir = args.output_dir / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt = tf.train.Checkpoint(base=base_model, cls_head=cls_head, optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt, directory=str(ckpt_dir), max_to_keep=1)
    steps_per_epoch = int(math.ceil(len(tr_x) / float(args.batch_size)))
    best_val = None
    # Path of the best checkpoint saved by this run. The manager's latest_checkpoint is
    # not used, because it also reports checkpoints left in ckpt_dir by earlier runs.
    best_ckpt_path = None
    wait = 0
    history = []
    for epoch in range(1, args.epochs + 1):
        if tqdm is not None:
            pbar = tqdm(total=steps_per_epoch, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        else:
            pbar = None
        epoch_total = []
        epoch_seg = []
        epoch_cls = []
        for step, (x, y_mask, y_bif, y_has) in enumerate(train_ds, start=1):
            with tf.GradientTape() as tape:
                logits = _extract_logits(base_model(x, training=True))
                resized_logits = tf.image.resize(logits, (tf.shape(y_mask)[1], tf.shape(y_mask)[2]))
                bin_logit = _binary_logit_from_multiclass(resized_logits, lumen_class=args.lumen_class)
                bce_map = tf.nn.weighted_cross_entropy_with_logits(labels=y_mask, logits=bin_logit, pos_weight=pos_weight)
                bce_per_sample = tf.reduce_mean(bce_map, axis=[1, 2])
                probs_mask = tf.math.sigmoid(bin_logit)
                inter = tf.reduce_sum(probs_mask * y_mask, axis=[1, 2])
                denom = tf.reduce_sum(probs_mask + y_mask, axis=[1, 2])
                dice_loss_per_sample = 1.0 - (2.0 * inter + 1e-6) / (denom + 1e-6)
                seg_per_sample = bce_per_sample + 0.3 * dice_loss_per_sample
                # Segmentation loss only over samples that carry a lumen mask.
                valid = y_has > 0.5
                valid_count = tf.reduce_sum(tf.cast(valid, tf.float32))
                seg_loss = tf.where(
                    valid_count > 0,
                    tf.reduce_sum(tf.where(valid, seg_per_sample, tf.zeros_like(seg_per_sample))) / valid_count,
                    tf.constant(0.0, dtype=tf.float32),
                )
                bif_prob = cls_head(resized_logits, training=True)
                bif_prob = tf.reshape(bif_prob, [-1])
                cls_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_bif, bif_prob))
                total_loss = args.seg_loss_weight * seg_loss + args.cls_loss_weight * cls_loss
            train_vars = base_vars + cls_head.trainable_variables
            grads = tape.gradient(total_loss, train_vars)
            # Variables outside the loss path get None gradients; skip them.
            grad_var_pairs = [(g, v) for g, v in zip(grads, train_vars) if g is not None]
            if args.clip_norm > 0 and grad_var_pairs:
                grad_tensors = [g for g, _ in grad_var_pairs]
                clipped, _ = tf.clip_by_global_norm(grad_tensors, clip_norm=args.clip_norm)
                grad_var_pairs = list(zip(clipped, [v for _, v in grad_var_pairs]))
            optimizer.apply_gradients(grad_var_pairs)
            epoch_total.append(float(total_loss.numpy()))
            epoch_seg.append(float(seg_loss.numpy()))
            epoch_cls.append(float(cls_loss.numpy()))
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(loss=f"{epoch_total[-1]:.4f}", seg=f"{epoch_seg[-1]:.4f}", cls=f"{epoch_cls[-1]:.4f}")
            elif step % 10 == 0:
                print(f"Step {step}/{steps_per_epoch} loss={epoch_total[-1]:.4f} seg={epoch_seg[-1]:.4f} cls={epoch_cls[-1]:.4f}")
        if pbar is not None:
            pbar.close()
        train_metrics = {
            "total_loss": float(np.mean(epoch_total)) if epoch_total else float("nan"),
            "seg_loss": float(np.mean(epoch_seg)) if epoch_seg else float("nan"),
            "cls_loss": float(np.mean(epoch_cls)) if epoch_cls else float("nan"),
        }
        val_metrics = _compute_metrics(
            base_model,
            cls_head,
            val_ds,
            lumen_class=args.lumen_class,
            pos_weight=pos_weight,
            seg_weight=args.seg_loss_weight,
            cls_weight=args.cls_loss_weight,
        ) if val_ds is not None else {"total_loss": float("nan")}
        print(
            f"Epoch {epoch}/{args.epochs} train_total={train_metrics['total_loss']:.4f} "
            f"val_total={val_metrics.get('total_loss', float('nan')):.4f} "
            f"val_seg_dice={val_metrics.get('seg_dice', float('nan')):.4f} "
            f"val_cls_auc={val_metrics.get('cls_auc', float('nan')):.4f}"
        )
        history_row = {"epoch": epoch, **train_metrics}
        history_row.update({f"val_{k}": v for k, v in val_metrics.items()})
        history.append(history_row)
        with tb_writer.as_default():
            for k, v in train_metrics.items():
                tf.summary.scalar(f"train/{k}", v, step=epoch)
            for k, v in val_metrics.items():
                tf.summary.scalar(f"val/{k}", v, step=epoch)
        # Checkpointing / early stopping only make sense with a validation split; without
        # one, the monitored loss is always NaN and would stop training after `patience` epochs.
        if val_ds is not None:
            monitor = val_metrics.get("total_loss", float("nan"))
            improved = np.isfinite(monitor) and (best_val is None or monitor < (best_val - args.early_stop_min_delta))
            if improved:
                best_val = float(monitor)
                wait = 0
                best_ckpt_path = manager.save(checkpoint_number=epoch)
                print(f"Saved best checkpoint at epoch {epoch} (val_total_loss={best_val:.6f})")
            else:
                wait += 1
                print(f"No val improvement: wait {wait}/{args.early_stop_patience}")
                if wait >= args.early_stop_patience:
                    print("Early stopping triggered.")
                    break
        tb_writer.flush()
    if best_ckpt_path is not None:
        ckpt.restore(best_ckpt_path).expect_partial()
        print(f"Restored best checkpoint: {best_ckpt_path}")

    base_out = args.output_dir / "lumen_multitask_base"
    cls_out = args.output_dir / "bifurcation_head.keras"
    tf.saved_model.save(base_model, str(base_out))
    cls_head.save(str(cls_out))
    test_metrics = _compute_metrics(
        base_model,
        cls_head,
        test_ds,
        lumen_class=args.lumen_class,
        pos_weight=pos_weight,
        seg_weight=args.seg_loss_weight,
        cls_weight=args.cls_loss_weight,
    ) if test_ds is not None else {}
    with tb_writer.as_default():
        for k, v in test_metrics.items():
            tf.summary.scalar(f"test/{k}", v, step=len(history))
    tb_writer.flush()
    tb_writer.close()
    summary = {
        "model_dir": str(args.model_dir),
        "output_dir": str(args.output_dir),
        "split_json": str(args.split_json),
        "num_train": int(len(tr_x)),
        "num_val": int(len(va_x)),
        "num_test": int(len(te_x)),
        "num_train_with_lumen": int(np.sum(tr_h > 0.5)),
        "num_val_with_lumen": int(np.sum(va_h > 0.5)),
        "num_test_with_lumen": int(np.sum(te_h > 0.5)),
        "selected_trainable_variables": [v.name for v in base_vars],
        "lumen_class": int(args.lumen_class),
        "pos_weight": float(pos_weight),
        "seg_loss_weight": float(args.seg_loss_weight),
        "cls_loss_weight": float(args.cls_loss_weight),
        "learning_rate": float(args.learning_rate),
        "epochs": int(args.epochs),
        "early_stop_patience": int(args.early_stop_patience),
        "early_stop_min_delta": float(args.early_stop_min_delta),
        "best_val_total_loss": None if best_val is None else float(best_val),
        "tensorboard_run_dir": str(tb_run_dir),
        "saved_base_model": str(base_out),
        "saved_cls_head": str(cls_out),
        "history": history,
        "test_metrics": test_metrics,
    }
    summary_path = args.output_dir / "multitask_summary.json"
    with summary_path.open("w", encoding="utf-8") as fp:
        json.dump(summary, fp, indent=2)
    print(f"Saved multitask base model: {base_out}")
    print(f"Saved multitask classifier head: {cls_out}")
    print(f"Saved summary: {summary_path}")


if __name__ == "__main__":
    main()