File size: 24,218 Bytes
a95e79a 9af2926 a95e79a 9af2926 a95e79a 9af2926 a95e79a 9af2926 a95e79a 9af2926 a95e79a 9af2926 a95e79a 9af2926 a95e79a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 | """EC-SimToken training script.
Adds existence head + synthetic null augmentation to SimToken.
Key differences from train.py:
- Uses ECSimtoken_ForCausalLM (adds existence_head: Linear(256,1))
- Audio-swap null augmentation: each batch item independently, with
  probability p_null, has its audio replaced with another sample's audio
  → a synthetic null (audio/visual mismatch) example
- is_null tensor passed to model.forward to gate the mask loss
- test_n evaluation uses the existence head (p_exist threshold) for Null_S
Usage (training):
python train_ec_simtoken.py \
--data_dir data \
--mllm Chat-UniVi/Chat-UniVi-7B-v1.5 \
--vision_pretrained path/to/sam_vit_h_4b8939.pth \
--name ec_simtoken_v1 \
--epochs 10 \
--batch_size 12 \
--null_aug_prob 0.25 \
--exist_loss_weight 1.0
Usage (eval only):
python train_ec_simtoken.py --run eval \
--saved_model checkpoints/ec_simtoken_v1.pth \
--eval_splits test_s,test_u,test_n
"""
import argparse
import os
import random
import warnings
from functools import partial
import numpy as np
import torch
import torch.multiprocessing as mp
import transformers
from peft import LoraConfig, get_peft_model
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, get_cosine_schedule_with_warmup, logging
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
import re
# ── Token constants ──────────────────────────────────────────────────────────
# Sentinel ids. tokenizer_image_audio_token splices IMAGE_TOKEN_INDEX (for
# <image>, and num_frames copies for <video>) and AUDIO_TOKEN_INDEX (for
# <audio>) into input_ids. collate_fn writes IGNORE_INDEX into label
# positions that should not be supervised (presumably honoured by the model's
# CE loss; -100 is PyTorch's default ignore_index).
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
AUDIO_TOKEN_INDEX = -300
# ── Args: base (from configs) + EC-SimToken extensions ───────────────────────
# NOTE(review): `configs` presumably parses sys.argv at import time; this
# assumes it tolerates the extra EC flags below (e.g. via parse_known_args).
# Confirm, otherwise `--null_aug_prob` etc. would be rejected there.
from configs import args as base_args # parses base SimToken args
# add_help=False + parse_known_args: this parser only picks out the EC flags
# and ignores everything that belongs to the base parser.
_ec_parser = argparse.ArgumentParser(add_help=False)
_ec_parser.add_argument("--null_aug_prob", type=float, default=0.25,
                        help="Fraction of batch items with swapped audio (null aug)")
_ec_parser.add_argument("--exist_loss_weight", type=float, default=1.0,
                        help="Weight for BCE existence loss")
_ec_parser.add_argument("--exist_threshold", type=float, default=0.5,
                        help="p_exist sigmoid threshold for null classification")
ec_args, _ = _ec_parser.parse_known_args()
# Merge EC-SimToken args into the base args namespace
args = base_args
args.null_aug_prob = ec_args.null_aug_prob
args.exist_loss_weight = ec_args.exist_loss_weight
args.exist_threshold = ec_args.exist_threshold
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised,
# which is presumably why the project imports below come after this line.
# NOTE(review): os.environ requires a str; assumes args.gpu_id is a string.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
# NOTE(review): `datasets` must resolve to the project-local module, not the
# Hugging Face `datasets` package; confirm sys.path order.
from datasets import REFAVS
from models.ec_simtoken_model import ECSimtoken_ForCausalLM
from utils import utility
# ── Utilities ────────────────────────────────────────────────────────────────
def set_seed(seed: int = 42):
    """Seed every RNG the script relies on.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch on the CPU and
    on all CUDA devices. Also exports ``PYTHONHASHSEED``, which only affects
    Python processes started afterwards (presumably the spawned DataLoader
    workers).

    cuDNN benchmark mode is switched ON deliberately. It lets cuDNN pick the
    fastest convolution algorithms (important for SAM), at the cost of
    bitwise run-to-run reproducibility.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.benchmark = True
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed ``random`` and NumPy per worker.

    Takes the worker's torch seed, folds it into 32 bits (NumPy's seed range)
    and applies it to Python's and NumPy's global RNGs. ``worker_id`` is
    unused but required by the callback signature.
    """
    worker_seed = torch.initial_seed() % (1 << 32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
def dict_to_cuda(d: dict) -> dict:
    """Move tensor values of *d* to the current CUDA device, in place.

    A value is moved if it is a tensor, or a non-empty list whose first
    element is a tensor (then every element is moved). All other values are
    left untouched. Returns the same dict object.
    """
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            moved = value.cuda(non_blocking=True)
        elif isinstance(value, list) and value and isinstance(value[0], torch.Tensor):
            moved = [t.cuda(non_blocking=True) for t in value]
        else:
            continue
        d[key] = moved
    return d
# ── Null augmentation ────────────────────────────────────────────────────────
def apply_null_augmentation(
    audio_feats: list, p_null: float = 0.25, group_ids=None
) -> tuple[list, torch.BoolTensor]:
    """Randomly replace some samples' audio features with another sample's.

    Each sample is independently selected with probability ``p_null``. A
    selected sample receives a clone of a *different* sample's ORIGINAL audio
    features, producing a synthetic audio/visual mismatch ("null" target).

    Args:
        audio_feats: per-sample audio feature tensors. Mutated in place and
            also returned.
        p_null: per-sample swap probability; ``<= 0`` disables augmentation.
        group_ids: optional per-sample keys (e.g. video ids). When given, a
            sample never receives audio from a donor with the same key, since
            such a swap would not actually be mismatched. A selected sample
            with no eligible donor is left unchanged and not marked null.

    Returns:
        ``(audio_feats, is_null)``, where ``is_null[i]`` is True iff sample
        ``i``'s audio was replaced.

    Raises:
        ValueError: if ``group_ids`` is given with a length other than
            ``len(audio_feats)``.
    """
    B = len(audio_feats)
    is_null = torch.zeros(B, dtype=torch.bool)
    if B < 2 or p_null <= 0.0:
        return audio_feats, is_null
    if group_ids is not None and len(group_ids) != B:
        raise ValueError(
            f"group_ids has {len(group_ids)} entries, expected {B}"
        )
    # Draw donors from a snapshot of the unmodified features. Reading from the
    # list being mutated lets swaps chain (0 <- 1, then 1 <- new 0 == old 1),
    # handing a sample back its own audio while still labelling it null.
    originals = list(audio_feats)
    for i in range(B):
        if random.random() >= p_null:
            continue
        candidates = [
            j for j in range(B)
            if j != i and (group_ids is None or group_ids[j] != group_ids[i])
        ]
        if not candidates:
            continue
        j = random.choice(candidates)
        audio_feats[i] = originals[j].clone()
        is_null[i] = True
    return audio_feats, is_null
# ── Collate (derived from train.py) ──────────────────────────────────────────
def tokenizer_image_audio_token(
    prompt,
    tokenizer,
    image_token_index=IMAGE_TOKEN_INDEX,
    audio_token_index=AUDIO_TOKEN_INDEX,
    num_frames=10,
    return_tensors=None,
):
    """Tokenize *prompt*, splicing sentinel ids in for media placeholders.

    Text between placeholders is tokenized with ``tokenizer(text).input_ids``.
    Each ``<image>`` becomes one ``image_token_index``, each ``<audio>`` one
    ``audio_token_index``, and each ``<video>`` becomes ``num_frames`` copies
    of ``image_token_index``. Text and placeholders are emitted in the order
    they appear in *prompt*. This also holds when the prompt starts with a
    placeholder or contains adjacent placeholders, cases the old
    text/placeholder re-interleaving misordered or dropped.

    BOS handling: if the first text chunk's tokens start with
    ``tokenizer.bos_token_id``, one BOS is emitted at the very start, and the
    first token of EVERY text chunk is dropped (the tokenizer is assumed to
    prepend BOS to each chunk).

    Returns:
        A list of ints, or a 1-D ``torch.long`` tensor when
        ``return_tensors == "pt"``.
    """
    # Capturing group keeps the placeholders; empty strings appear between
    # adjacent delimiters and are dropped.
    chunks = [c for c in re.split(r"(<image>|<audio>|<video>)", prompt) if c]
    placeholder_ids = {
        "<image>": [image_token_index],
        "<audio>": [audio_token_index],
        "<video>": [image_token_index] * num_frames,
    }
    # Tokenize all text chunks up front, in prompt order.
    text_tokens = [
        tokenizer(c).input_ids for c in chunks if c not in placeholder_ids
    ]

    input_ids = []
    offset = 0
    if text_tokens and text_tokens[0] and text_tokens[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(text_tokens[0][0])

    next_text = iter(text_tokens)
    for chunk in chunks:
        if chunk in placeholder_ids:
            input_ids.extend(placeholder_ids[chunk])
        else:
            input_ids.extend(next(next_text)[offset:])

    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids
def collate_fn(batch, tokenizer=None):
    """Collate REFAVS samples into one batch dict for ``model.forward``.

    Per-sample fields are gathered into lists in batch order. Conversation
    prompts are tokenized with ``tokenizer_image_audio_token`` and
    right-padded with ``tokenizer.pad_token_id``; ``attention_masks`` marks
    the non-pad positions. Only the first referring expression of each
    sample is tokenized into ``ref_ids``; ``refs_num`` records how many it
    had.

    ``labels`` starts as a copy of ``input_ids``. The following are then set
    to IGNORE_INDEX: the first token, every span preceding an occurrence of
    "Sure, it is [SEG]", and everything after the last occurrence. Span
    lengths come from re-tokenizing each piece (minus 2 for a preceding
    piece, minus 1 for the answer).
    NOTE(review): these offsets assume the tokenizer's BOS / word-boundary
    behaviour lines up with how input_ids was produced; confirm.
    """
    def column(field):
        # One list per dataset field, in batch order.
        return [sample[field] for sample in batch]

    vids = column("vid")
    images = column("image")
    image_clips = column("img_clip")
    masks = column("mask")
    conversations = column("conversation")
    audio_feats = column("feat_aud")
    resizes = column("resize")
    orgsizes = column("orgsize")
    image_feats = column("feat_sam")
    fids = column("fids")
    all_refs = column("ref")
    refs_num = [len(r) for r in all_refs]
    refs = [r[0] for r in all_refs]  # only the first expression is tokenized

    conv_ids = [
        tokenizer_image_audio_token(conv, tokenizer, return_tensors="pt")
        for conv in conversations
    ]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        conv_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_masks = input_ids.ne(tokenizer.pad_token_id)
    ref_ids = [
        tokenizer_image_audio_token(ref, tokenizer, return_tensors="pt")
        for ref in refs
    ]

    labels = input_ids.clone()
    answer = "Sure, it is [SEG]"
    # Same for every conversation, so computed once.
    answer_len = len(tokenizer_image_audio_token(answer, tokenizer)) - 1
    for conv, target in zip(conversations, labels):
        target[:1] = IGNORE_INDEX
        pos = 1
        for prefix in conv.split(answer)[:-1]:
            prefix_len = len(tokenizer_image_audio_token(prefix, tokenizer)) - 2
            target[pos : pos + prefix_len] = IGNORE_INDEX
            pos += prefix_len + answer_len
        target[pos:] = IGNORE_INDEX

    return {
        "vids": vids,
        "images": images,
        "images_clip": image_clips,
        "masks": masks,
        "convs": conversations,
        "input_ids": input_ids,
        "attention_masks": attention_masks,
        "labels": labels,
        "audio_feats": audio_feats,
        "resizes": resizes,
        "orgsizes": orgsizes,
        "image_feats": image_feats,
        "ref_ids": ref_ids,
        "refs_num": refs_num,
        "fids": fids,
    }
# ── Build model ──────────────────────────────────────────────────────────────
def build_model(args, tokenizer, seg_token_idx) -> ECSimtoken_ForCausalLM:
    """Build the EC-SimToken model, wrap it with LoRA, and place it on CUDA in bf16.

    The steps are order-dependent; each operates on modules created by the
    previous ones:
      1. Load ``ECSimtoken_ForCausalLM`` from ``args.mllm`` in bf16 with the
         loss weights / SimToken options in ``model_args``.
      2. Copy eos/bos/pad token ids from *tokenizer* into the model config.
      3. Initialise the vision tower (bf16, on CUDA), then the cluster modules
         (from a fresh ``AutoConfig`` of ``args.mllm`` with clustering options
         set), then the LISA modules. Freeze the vision tower and
         ``mm_projector``.
      4. Apply LoRA (r=8, alpha=16, dropout=0.05) to every ``torch.nn.Linear``
         whose name contains ``q_proj``/``v_proj`` and none of the skipped
         sub-module names.
      5. Move to CUDA, cast everything to bf16, resize token embeddings to
         ``len(tokenizer)``, and force ``requires_grad=True`` on parameters
         whose names contain lm_head / embed_tokens / mask_decoder /
         text_hidden_fcs / audio_feature_layer / existence_head.

    Args:
        args: parsed CLI namespace. Reads mllm, vision_pretrained,
            vision_tower, compress, start and exist_loss_weight.
        tokenizer: HF tokenizer (main() adds ``[SEG]`` to it before this call).
        seg_token_idx: token id of ``[SEG]``, forwarded to the model.

    Returns:
        The object returned by ``get_peft_model`` (a PEFT wrapper around
        ECSimtoken_ForCausalLM, despite the return annotation).
    """
    model_args = {
        "train_mask_decoder": True,
        "out_dim": 256,
        "ce_loss_weight": 1.0,
        "dice_loss_weight": 0.5,
        "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_im_start_end": False,
        "compress": args.compress,
        "start": args.start,
        "exist_loss_weight": args.exist_loss_weight,
    }
    model = ECSimtoken_ForCausalLM.from_pretrained(
        args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.enable_input_require_grads()
    # Gradient checkpointing trades compute for memory (recomputes activations
    # during backward instead of storing them). Measured memory at batch=16:
    # 78 GB / 82 GB — too close to OOM, so this must stay enabled.
    # NOTE(review): despite the comment above, the call below is commented
    # out, so gradient checkpointing is currently DISABLED. Either re-enable
    # it or update the comment; confirm which is intended.
    # model.gradient_checkpointing_enable()
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16, device="cuda")
    # Clustering config for initialize_cluster_modules.
    # NOTE(review): values presumably mirror the upstream SimToken /
    # Chat-UniVi setup; confirm before changing.
    cfg_pt = AutoConfig.from_pretrained(args.mllm)
    cfg_pt.use_cluster = True
    cfg_pt.freeze = False
    cfg_pt.mm_tune = True
    cfg_pt.spatial_cluster_rate0 = 64
    cfg_pt.spatial_cluster_rate1 = 32
    cfg_pt.spatial_cluster_rate2 = 16
    cfg_pt.temporal_cluster_rate = 0.0625
    cfg_pt.vision_tune = False
    model.get_model().initialize_cluster_modules(cfg_pt)
    model.get_model().initialize_lisa_modules(model.get_model().config)
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False
    # LoRA
    lora_r = 8
    def find_linear_layers(m, targets):
        # Names of Linear sub-modules matching any target substring, excluding
        # the vision / projector / SimToken-specific heads (these are either
        # frozen or trained in full below, not through LoRA).
        names = set()
        skip = {"visual_model", "vision_tower", "mm_projector",
                "text_hidden_fcs", "audio_feature_layer", "existence_head"}
        for name, mod in m.named_modules():
            if (
                isinstance(mod, torch.nn.Linear)
                and not any(s in name for s in skip)
                and any(t in name for t in targets)
            ):
                names.add(name)
        return sorted(names)
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=16,
        target_modules=find_linear_layers(model, ["q_proj", "v_proj"]),
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    model = model.to("cuda")
    # Cast everything to bfloat16 — from_pretrained only converts checkpoint tensors;
    # modules added post-init (existence_head, audio_feature_layer) default to fp32.
    model = model.to(torch.bfloat16)
    # Grow the embedding tables to cover tokens added to the tokenizer ([SEG]).
    model.resize_token_embeddings(len(tokenizer))
    # Ensure key modules are trainable
    for n, p in model.named_parameters():
        if any(x in n for x in ["lm_head", "embed_tokens", "mask_decoder",
                                "text_hidden_fcs", "audio_feature_layer",
                                "existence_head"]):
            p.requires_grad = True
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable:,}")
    return model
# ── Evaluation ───────────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate(model, dataloader, split_name: str, exist_threshold: float = 0.5):
    """Evaluate *model* on one Ref-AVS split; print and return its metrics.

    For any split other than ``"test_n"``, reports mIoU and F-score, each
    averaged over samples weighted by ``pred.shape[0] * pred.shape[1]``.

    For ``"test_n"`` (null references), the existence head gates the output.
    A sample with ``sigmoid(exist_logit) < exist_threshold`` is scored as if
    the model predicted an all-zero mask (counted in ``null_tp``). Otherwise
    its predicted mask is scored (counted in ``null_fn``). Scoring uses
    ``utility.metric_s_for_null``, with the same weighting. Lower Null_S is
    better.

    Args:
        model: EC-SimToken model. Switched to eval mode and left there.
        dataloader: yields batches produced by ``collate_fn``.
        split_name: ``"test_n"`` selects the null metric; anything else the
            segmentation metrics. Also used in progress/print labels.
        exist_threshold: p_exist below which a sample is treated as null.

    Returns:
        ``{"null_s", "null_tp", "null_fn"}`` for test_n, otherwise
        ``{"miou", "fscore"}``.
    """
    model.eval()
    null_split = split_name == "test_n"
    total_iou = total_fscore = count = 0.0
    # For test_n: existence-gated null metric (lower Null_S is better)
    total_null_metric = null_count = 0.0
    null_tp = 0  # correctly predicted null (p_exist < threshold)
    null_fn = 0  # missed null detection (p_exist >= threshold)
    for batch in tqdm(dataloader, desc=f"Eval {split_name}", leave=False):
        batch = dict_to_cuda(batch)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            out = model.forward(
                images=batch["images"],
                images_clip=batch["images_clip"],
                audio_features=batch["audio_feats"],
                image_features=batch["image_feats"],
                input_ids=batch["input_ids"],
                labels=batch["labels"],
                attention_masks=batch["attention_masks"],
                masks_list=batch["masks"],
                resize_list=batch["resizes"],
                orgsize_list=batch["orgsizes"],
                conversation_list=batch["convs"],
                refs_num=batch["refs_num"],
                fids=batch["fids"],
                vids=batch["vids"],
                ref_ids=batch["ref_ids"],
                inference=True,
            )
        pred_masks = out["pred_masks"]
        gt_masks = out["gt_masks"]
        # exist_logit is expected as [seg_num, 1] with seg_num == B (one ref
        # per sample), so p_exist[i] pairs with pred_masks[i].
        # .float(): apply the sigmoid in fp32. If the logit arrives in bf16
        #   (autocast), a bf16 sigmoid has ~0.004 resolution near 0.5, which
        #   quantizes the threshold comparison.
        # .reshape(-1): always 1-D. squeeze(-1) would turn a single-element
        #   1-D logit into a 0-d tensor that cannot be indexed.
        p_exist = torch.sigmoid(out["exist_logit"].float()).reshape(-1).cpu()
        for i in range(len(pred_masks)):
            pred_i = pred_masks[i]
            gt_i = gt_masks[i]
            weight = pred_i.shape[0] * pred_i.shape[1]
            if null_split:
                if p_exist[i].item() < exist_threshold:
                    # Existence head says "no target": score an empty mask.
                    null_score = utility.metric_s_for_null(torch.zeros_like(pred_i))
                    null_tp += 1
                else:
                    null_score = utility.metric_s_for_null(pred_i)
                    null_fn += 1
                total_null_metric += null_score * weight
                null_count += weight
            else:
                iou = utility.mask_iou(pred_i, gt_i)
                fscore = utility.Eval_Fmeasure(pred_i, gt_i, None)
                total_iou += iou * weight
                total_fscore += fscore * weight
                count += weight
    if null_split:
        null_s = total_null_metric / (null_count + 1e-8)
        total_n = null_tp + null_fn
        print(f"\n  [{split_name}] Null_S={null_s:.4f}  "
              f"null_tp={null_tp}/{total_n}  null_fn={null_fn}/{total_n}")
        return {"null_s": null_s, "null_tp": null_tp, "null_fn": null_fn}
    miou = total_iou / (count + 1e-8)
    fscore = total_fscore / (count + 1e-8)
    print(f"\n  [{split_name}] mIoU={miou:.4f}  F={fscore:.4f}")
    return {"miou": miou, "fscore": fscore}
# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # sys.exit instead of the bare `exit` builtin: `exit` is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    import sys

    mp.set_start_method("spawn")
    set_seed(42)
    os.makedirs(args.log_root, exist_ok=True)
    os.makedirs(args.checkpoint_root, exist_ok=True)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    # Pad with <unk>; register [SEG] as a new token and look up its id, which
    # build_model forwards to the model as seg_token_idx.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    print(f"seg_token_idx: {seg_token_idx}")

    # ── Datasets ─────────────────────────────────────────────────────────────
    train_dataset = REFAVS("train", args, tokenizer, input_type="refer")
    val_dataset_s = REFAVS("test_s", args, tokenizer, input_type="refer")
    val_dataset_u = REFAVS("test_u", args, tokenizer, input_type="refer")
    val_dataset_n = REFAVS("test_n", args, tokenizer, input_type="refer")
    g = torch.Generator()
    g.manual_seed(42)  # fixed shuffle order for train_loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        worker_init_fn=seed_worker,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        generator=g,
        pin_memory=True,
        persistent_workers=False,  # True caused worker-restart deadlock after shm error
        prefetch_factor=2,
    )
    val_loader_s = DataLoader(
        val_dataset_s, batch_size=4, shuffle=False, num_workers=4,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        pin_memory=True, persistent_workers=False,
    )
    val_loader_u = DataLoader(
        val_dataset_u, batch_size=4, shuffle=False, num_workers=4,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        pin_memory=True, persistent_workers=False,
    )
    val_loader_n = DataLoader(
        val_dataset_n, batch_size=4, shuffle=False, num_workers=4,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        pin_memory=True, persistent_workers=False,
    )

    # ── Model ────────────────────────────────────────────────────────────────
    model = build_model(args, tokenizer, seg_token_idx)
    if args.saved_model and os.path.exists(args.saved_model):
        ckpt = torch.load(args.saved_model, map_location="cuda")
        # Support both raw state dict and {"model": ...} dicts
        state = ckpt.get("model", ckpt)
        missing, unexpected = model.load_state_dict(state, strict=False)
        print(f"Loaded {args.saved_model}  missing={len(missing)}  unexpected={len(unexpected)}")
    elif args.saved_model:
        # Never skip a mistyped path silently: `--run eval` would otherwise
        # score freshly initialised weights with no hint that loading failed.
        print(f"WARNING: saved_model not found: {args.saved_model} "
              "(continuing with freshly initialized weights)")

    if args.run == "eval":
        for split, loader in [("test_s", val_loader_s),
                              ("test_u", val_loader_u),
                              ("test_n", val_loader_n)]:
            if split in args.eval_splits:
                evaluate(model, loader, split, args.exist_threshold)
        sys.exit(0)

    # ── Training ─────────────────────────────────────────────────────────────
    model.train()
    optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
    # 16 // batch_size micro-batches per optimizer step: an effective batch of
    # at most 16 samples (or batch_size itself when it exceeds 16).
    gradient_accumulation_steps = max(1, int(16 // args.batch_size))
    steps_per_epoch = len(train_loader) // gradient_accumulation_steps
    total_steps = args.epochs * steps_per_epoch
    warmup_steps = max(1, int(total_steps * 0.1))
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )
    log_path = os.path.join(args.log_root, f"{args.name}.txt")
    for epoch in range(args.epochs):
        model.train()
        # Also discards gradients from a trailing partial accumulation window
        # of the previous epoch; this matches total_steps, which counts only
        # full windows.
        optimizer.zero_grad()
        running = {"loss": 0.0, "ce": 0.0, "mask": 0.0, "exist": 0.0}
        # Running sums grow once per forward pass, so they are averaged over
        # forward passes. Dividing by optimizer steps would inflate every
        # logged value by gradient_accumulation_steps.
        n_micro = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for step, batch in enumerate(loop):
            # ── Null augmentation ────────────────────────────────────────────
            batch["audio_feats"], is_null = apply_null_augmentation(
                batch["audio_feats"], p_null=args.null_aug_prob
            )
            batch = dict_to_cuda(batch)
            is_null = is_null.cuda()
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out = model.forward(
                    images=batch["images"],
                    images_clip=batch["images_clip"],
                    audio_features=batch["audio_feats"],
                    image_features=batch["image_feats"],
                    input_ids=batch["input_ids"],
                    labels=batch["labels"],
                    attention_masks=batch["attention_masks"],
                    masks_list=batch["masks"],
                    resize_list=batch["resizes"],
                    orgsize_list=batch["orgsizes"],
                    conversation_list=batch["convs"],
                    refs_num=batch["refs_num"],
                    fids=batch["fids"],
                    vids=batch["vids"],
                    ref_ids=batch["ref_ids"],
                    epoch=epoch,
                    inference=False,
                    contrast=args.ct_weight,
                    is_null=is_null,
                )
            loss = out["loss"] / gradient_accumulation_steps
            loss.backward()
            for k, key in [("loss", "loss"), ("ce", "ce_loss"),
                           ("mask", "mask_loss"), ("exist", "exist_loss")]:
                v = out.get(key, torch.tensor(0.0))
                running[k] += v.item() if isinstance(v, torch.Tensor) else v
            n_micro += 1
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                lr = scheduler.get_last_lr()[0]
                avg = {k: running[k] / n_micro for k in running}
                loop.set_postfix(
                    lr=f"{lr:.2e}",
                    loss=f"{avg['loss']:.4f}",
                    exist=f"{avg['exist']:.4f}",
                )

        # ── End of epoch eval ────────────────────────────────────────────────
        denom = max(n_micro, 1)
        epoch_loss = running["loss"] / denom
        print(
            f"Epoch {epoch+1}  loss={epoch_loss:.4f}  "
            f"ce={running['ce']/denom:.4f}  "
            f"mask={running['mask']/denom:.4f}  "
            f"exist={running['exist']/denom:.4f}  "
            f"lr={scheduler.get_last_lr()[0]:.2e}"
        )
        with open(log_path, "a") as f:
            f.write(
                f"epoch={epoch+1} loss={epoch_loss:.4f} "
                f"ce={running['ce']/denom:.4f} "
                f"mask={running['mask']/denom:.4f} "
                f"exist={running['exist']/denom:.4f}\n"
            )
        # Per-epoch checkpoint — keep only the last 2: after saving
        # ep{epoch+1}, delete ep{epoch-1} so ep{epoch} and ep{epoch+1} remain.
        ckpt_ep = os.path.join(args.checkpoint_root, f"{args.name}_ep{epoch+1}.pth")
        torch.save(model.state_dict(), ckpt_ep)
        print(f"Saved: {ckpt_ep}")
        prev_ckpt = os.path.join(args.checkpoint_root, f"{args.name}_ep{epoch-1}.pth")
        if epoch >= 2 and os.path.exists(prev_ckpt):
            os.remove(prev_ckpt)
        evaluate(model, val_loader_s, "test_s", args.exist_threshold)
        evaluate(model, val_loader_u, "test_u", args.exist_threshold)
        evaluate(model, val_loader_n, "test_n", args.exist_threshold)

    # ── Save ─────────────────────────────────────────────────────────────────
    ckpt_path = os.path.join(args.checkpoint_root, f"{args.name}.pth")
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved: {ckpt_path}")
|