Update trainer.py
Browse files- trainer.py +254 -11
trainer.py
CHANGED
|
@@ -6,10 +6,13 @@
|
|
| 6 |
# - Activations at bottom
|
| 7 |
# =====================================================================================
|
| 8 |
from __future__ import annotations
|
| 9 |
-
import os, json, math, random
|
| 10 |
from dataclasses import dataclass, asdict
|
| 11 |
from pathlib import Path
|
| 12 |
from typing import Dict, List, Tuple, Optional
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
import torch
|
| 15 |
import torch.nn as nn
|
|
@@ -27,7 +30,7 @@ from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
|
|
| 27 |
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
|
| 28 |
|
| 29 |
# HF / safetensors
|
| 30 |
-
from huggingface_hub import snapshot_download
|
| 31 |
from safetensors.torch import load_file
|
| 32 |
|
| 33 |
|
|
@@ -60,7 +63,7 @@ class BaseConfig:
|
|
| 60 |
amp: bool = True
|
| 61 |
|
| 62 |
global_flow_weight: float = 1.0
|
| 63 |
-
block_penalty_weight: float = 0.
|
| 64 |
use_local_flow_heads: bool = False
|
| 65 |
local_flow_weight: float = 1.0
|
| 66 |
|
|
@@ -89,6 +92,11 @@ class BaseConfig:
|
|
| 89 |
# Inference
|
| 90 |
sample_steps: int = 30
|
| 91 |
guidance_scale: float = 7.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def __post_init__(self):
|
| 94 |
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
|
|
@@ -229,6 +237,7 @@ class StudentUNet(nn.Module):
|
|
| 229 |
self._ensure_heads(feats)
|
| 230 |
return v_hat, feats
|
| 231 |
|
|
|
|
| 232 |
# =====================================================================================
|
| 233 |
# 6) DAVID LOADER (HF) + ASSESSOR + FUSION
|
| 234 |
# =====================================================================================
|
|
@@ -360,6 +369,8 @@ class FlowMatchDavidTrainer:
|
|
| 360 |
def __init__(self, cfg: BaseConfig, device: str = "cuda"):
|
| 361 |
self.cfg = cfg
|
| 362 |
self.device = device
|
|
|
|
|
|
|
| 363 |
|
| 364 |
# Data
|
| 365 |
self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
|
|
@@ -382,9 +393,111 @@ class FlowMatchDavidTrainer:
|
|
| 382 |
self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
|
| 383 |
self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
# Logs
|
| 386 |
self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
# math helpers
|
| 389 |
def _v_star(self, x_t, t, eps_hat):
|
| 390 |
alpha, sigma = self.teacher.alpha_sigma(t)
|
|
@@ -401,10 +514,12 @@ class FlowMatchDavidTrainer:
|
|
| 401 |
# training
|
| 402 |
def train(self):
|
| 403 |
cfg = self.cfg
|
| 404 |
-
gstep =
|
| 405 |
-
|
|
|
|
| 406 |
self.student.train()
|
| 407 |
-
pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}"
|
|
|
|
| 408 |
acc = {"L":0.0, "Lf":0.0, "Lb":0.0}
|
| 409 |
|
| 410 |
for it, batch in enumerate(pbar):
|
|
@@ -465,6 +580,7 @@ class FlowMatchDavidTrainer:
|
|
| 465 |
acc["Lf"] += float(L_flow.item())
|
| 466 |
acc["Lb"] += float(L_blocks.item())
|
| 467 |
|
|
|
|
| 468 |
if it % 50 == 0:
|
| 469 |
self.writer.add_scalar("train/total", float(L_total.item()), gstep)
|
| 470 |
self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
|
|
@@ -473,9 +589,18 @@ class FlowMatchDavidTrainer:
|
|
| 473 |
for k in list(lam.keys())[:4]:
|
| 474 |
self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)
|
| 475 |
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial
|
| 478 |
|
|
|
|
|
|
|
| 479 |
n = len(self.loader)
|
| 480 |
print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
|
| 481 |
self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
|
|
@@ -488,16 +613,134 @@ class FlowMatchDavidTrainer:
|
|
| 488 |
self._save("final", gstep)
|
| 489 |
self.writer.close()
|
| 490 |
|
|
|
|
| 491 |
def _save(self, tag, gstep):
|
| 492 |
-
|
|
|
|
|
|
|
| 493 |
torch.save({
|
| 494 |
"cfg": asdict(self.cfg),
|
| 495 |
"student": self.student.state_dict(),
|
| 496 |
"opt": self.opt.state_dict(),
|
| 497 |
"sched": self.sched.state_dict(),
|
| 498 |
"gstep": gstep
|
| 499 |
-
},
|
| 500 |
-
print(f"β Saved: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
# ---------- Inference (v-pred sampling; use teacher VAE for decode) ----------
|
| 503 |
@torch.no_grad()
|
|
@@ -544,4 +787,4 @@ def main():
|
|
| 544 |
print("β Inference sanity done.")
|
| 545 |
|
| 546 |
if __name__ == "__main__":
|
| 547 |
-
main()
|
|
|
|
| 6 |
# - Activations at bottom
|
| 7 |
# =====================================================================================
|
| 8 |
from __future__ import annotations
|
| 9 |
+
import os, json, math, random, re
|
| 10 |
from dataclasses import dataclass, asdict
|
| 11 |
from pathlib import Path
|
| 12 |
from typing import Dict, List, Tuple, Optional
|
| 13 |
+
import urllib.request
|
| 14 |
+
import subprocess
|
| 15 |
+
import shutil
|
| 16 |
|
| 17 |
import torch
|
| 18 |
import torch.nn as nn
|
|
|
|
| 30 |
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
|
| 31 |
|
| 32 |
# HF / safetensors
|
| 33 |
+
from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download
|
| 34 |
from safetensors.torch import load_file
|
| 35 |
|
| 36 |
|
|
|
|
| 63 |
amp: bool = True
|
| 64 |
|
| 65 |
global_flow_weight: float = 1.0
|
| 66 |
+
block_penalty_weight: float = 0.2 # β NEW: Start very low!
|
| 67 |
use_local_flow_heads: bool = False
|
| 68 |
local_flow_weight: float = 1.0
|
| 69 |
|
|
|
|
| 92 |
# Inference
|
| 93 |
sample_steps: int = 30
|
| 94 |
guidance_scale: float = 7.5
|
| 95 |
+
|
| 96 |
+
# HuggingFace upload & resume
|
| 97 |
+
hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
|
| 98 |
+
upload_every_epoch: bool = True
|
| 99 |
+
continue_training: bool = True # Download latest checkpoint and resume
|
| 100 |
|
| 101 |
def __post_init__(self):
|
| 102 |
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
| 237 |
self._ensure_heads(feats)
|
| 238 |
return v_hat, feats
|
| 239 |
|
| 240 |
+
|
| 241 |
# =====================================================================================
|
| 242 |
# 6) DAVID LOADER (HF) + ASSESSOR + FUSION
|
| 243 |
# =====================================================================================
|
|
|
|
| 369 |
def __init__(self, cfg: BaseConfig, device: str = "cuda"):
|
| 370 |
self.cfg = cfg
|
| 371 |
self.device = device
|
| 372 |
+
self.start_epoch = 0
|
| 373 |
+
self.start_gstep = 0
|
| 374 |
|
| 375 |
# Data
|
| 376 |
self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
|
|
|
|
| 393 |
self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
|
| 394 |
self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
|
| 395 |
|
| 396 |
+
# Try to resume from HF if enabled
|
| 397 |
+
if cfg.continue_training:
|
| 398 |
+
self._load_latest_from_hf()
|
| 399 |
+
|
| 400 |
# Logs
|
| 401 |
self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))
|
| 402 |
|
| 403 |
+
def _load_latest_from_hf(self):
    """Download the most recent epoch checkpoint from HuggingFace and resume from it.

    Scans ``cfg.hf_repo_id`` for files matching ``*_e<num>.safetensors``, downloads
    the highest-numbered one, loads its UNet weights into the student, and sets
    ``self.start_epoch`` / ``self.start_gstep`` so training continues where it
    left off. Best-effort: any failure prints a warning and training starts
    from scratch.

    NOTE(review): only model weights are restored here — optimizer, scheduler
    and AMP-scaler state are not present in the safetensors file, so the resumed
    run's LR schedule does not pick up exactly where it stopped. Confirm this is
    acceptable for this training setup.
    """
    if not self.cfg.hf_repo_id:
        print("β οΈ continue_training=True but no hf_repo_id specified")
        return

    try:
        api = HfApi()
        print(f"\nπ Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

        # Probe the repo first so an unreachable/private repo fails fast.
        try:
            api.repo_info(repo_id=self.cfg.hf_repo_id, repo_type="model")
        except Exception as e:
            print(f"β οΈ Could not access repo: {e}")
            print("   Starting training from scratch")
            return

        files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
        if not files:
            print("βΉοΈ Repo is empty, starting from scratch")
            return

        print(f"π Found {len(files)} files in repo:")
        for f in files:
            print(f"  - {f}")

        # Collect (epoch, filename) for every *_e<num>.safetensors checkpoint.
        epochs = []
        for f in files:
            if not f.endswith('.safetensors'):
                continue
            match = re.search(r'_e(\d+)\.safetensors$', f)
            if match:
                epoch_num = int(match.group(1))
                epochs.append((epoch_num, f))
                print(f"β Found checkpoint: {f} (epoch {epoch_num})")

        if not epochs:
            print("βΉοΈ No checkpoint files found (looking for *_e<num>.safetensors)")
            return

        # Highest epoch number wins.
        latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
        print(f"\nπ₯ Downloading latest checkpoint: {latest_file} (epoch {latest_epoch})")

        local_path = hf_hub_download(
            repo_id=self.cfg.hf_repo_id,
            filename=latest_file,
            repo_type="model",
            cache_dir=self.cfg.ckpt_dir
        )
        # Original literal here was mojibake-garbled; restored to match siblings.
        print(f"β Downloaded to: {local_path}")

        # Rebuild a pipeline from the single-file checkpoint purely to get at
        # its UNet weights.
        print("π¦ Loading checkpoint into pipeline...")
        pipe = StableDiffusionPipeline.from_single_file(
            local_path,
            torch_dtype=torch.float16,
            safety_checker=None,
            load_safety_checker=False
        )
        unet_state = pipe.unet.state_dict()

        # strict=False: the student may carry extra heads that are absent from
        # the exported checkpoint (and vice versa).
        missing, unexpected = self.student.unet.load_state_dict(unet_state, strict=False)
        print(f"β Loaded student UNet from epoch {latest_epoch}")
        if missing:
            print(f"   Missing keys: {len(missing)}")
        if unexpected:
            print(f"   Unexpected keys: {len(unexpected)}")

        # Resume from the epoch after the one in the checkpoint filename.
        self.start_epoch = latest_epoch
        self.start_gstep = latest_epoch * len(self.loader)
        print(f"π― Resuming training from epoch {self.start_epoch + 1}")

        # Free the temporary pipeline before training allocates GPU memory.
        del pipe
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"β οΈ Failed to load checkpoint from HF: {e}")
        print("   Starting training from scratch")
        import traceback
        traceback.print_exc()
|
| 499 |
+
|
| 500 |
+
|
| 501 |
# math helpers
|
| 502 |
def _v_star(self, x_t, t, eps_hat):
|
| 503 |
alpha, sigma = self.teacher.alpha_sigma(t)
|
|
|
|
| 514 |
# training
|
| 515 |
def train(self):
|
| 516 |
cfg = self.cfg
|
| 517 |
+
gstep = self.start_gstep
|
| 518 |
+
|
| 519 |
+
for ep in range(self.start_epoch, cfg.epochs):
|
| 520 |
self.student.train()
|
| 521 |
+
pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
|
| 522 |
+
dynamic_ncols=True, leave=True, position=0) # Add these params
|
| 523 |
acc = {"L":0.0, "Lf":0.0, "Lb":0.0}
|
| 524 |
|
| 525 |
for it, batch in enumerate(pbar):
|
|
|
|
| 580 |
acc["Lf"] += float(L_flow.item())
|
| 581 |
acc["Lb"] += float(L_blocks.item())
|
| 582 |
|
| 583 |
+
# Only log to tensorboard every 50 iterations
|
| 584 |
if it % 50 == 0:
|
| 585 |
self.writer.add_scalar("train/total", float(L_total.item()), gstep)
|
| 586 |
self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
|
|
|
|
| 589 |
for k in list(lam.keys())[:4]:
|
| 590 |
self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)
|
| 591 |
|
| 592 |
+
# Update progress bar less frequently to avoid double display
|
| 593 |
+
if it % 10 == 0 or it == len(self.loader) - 1: # Update every 10 iterations
|
| 594 |
+
pbar.set_postfix({
|
| 595 |
+
"L": f"{float(L_total.item()):.4f}",
|
| 596 |
+
"Lf": f"{float(L_flow.item()):.4f}",
|
| 597 |
+
"Lb": f"{float(L_blocks.item()):.4f}"
|
| 598 |
+
}, refresh=False) # Add refresh=False
|
| 599 |
+
|
| 600 |
del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial
|
| 601 |
|
| 602 |
+
pbar.close() # Explicitly close the progress bar
|
| 603 |
+
|
| 604 |
n = len(self.loader)
|
| 605 |
print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
|
| 606 |
self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
|
|
|
|
| 613 |
self._save("final", gstep)
|
| 614 |
self.writer.close()
|
| 615 |
|
| 616 |
+
|
| 617 |
def _save(self, tag, gstep):
    """Save a checkpoint, convert it to ComfyUI format, and upload it.

    Writes a temporary ``.pt`` with the full training state (weights, optimizer,
    scheduler, global step), converts it to a single-file safetensors checkpoint,
    optionally uploads that to HuggingFace, and finally removes the temporary
    ``.pt``. The ``.pt`` is kept when conversion fails, since it would otherwise
    be the only surviving copy of this checkpoint.
    """
    # 1. Save .pt first (for resuming training if needed)
    pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
    torch.save({
        "cfg": asdict(self.cfg),
        "student": self.student.state_dict(),
        "opt": self.opt.state_dict(),
        "sched": self.sched.state_dict(),
        "gstep": gstep
    }, pt_path)
    print(f"β Saved temp .pt: {pt_path}")

    # 2. Convert to ComfyUI safetensors
    safetensors_path = self._convert_to_comfyui(pt_path, tag)

    # 3. Upload to HF
    if self.cfg.upload_every_epoch and self.cfg.hf_repo_id and safetensors_path:
        self._upload_to_hf(safetensors_path, tag)

    # 4. Clean up the large .pt file — ONLY when conversion succeeded.
    # (Previously this unlinked unconditionally, destroying the checkpoint
    # whenever _convert_to_comfyui returned None.)
    if safetensors_path is not None:
        pt_path.unlink()
        print(f"β Cleaned up temp .pt file")
    else:
        print(f"β οΈ Conversion failed — keeping {pt_path} for recovery")
|
| 640 |
+
|
| 641 |
+
def _convert_to_comfyui(self, pt_path: Path, tag) -> Optional[Path]:
    """Convert a training ``.pt`` checkpoint into a ComfyUI-compatible safetensors file.

    Rebuilds a diffusers SD1.5 pipeline around the student UNet weights, saves it,
    then runs the official diffusers conversion script as a subprocess.

    Returns the path to the safetensors checkpoint, or ``None`` on any failure.
    """
    import sys  # local import: used for sys.executable below

    try:
        temp_pipeline = Path(self.cfg.ckpt_dir) / f"temp_pipeline_e{tag}"
        output_safetensors = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.safetensors"

        # Download the official diffusers converter script if we don't have it.
        # NOTE(review): this downloads and later EXECUTES code fetched from the
        # network off the `main` branch — consider pinning a commit hash.
        converter_path = Path(self.cfg.ckpt_dir) / "convert_diffusers_to_original_stable_diffusion.py"
        if not converter_path.exists():
            print("π₯ Downloading official converter...")
            url = "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_diffusers_to_original_stable_diffusion.py"
            urllib.request.urlretrieve(url, str(converter_path))
            print("β Converter downloaded")

        # Pull the student weights out of the training checkpoint.
        print(f"π¦ Creating diffusers pipeline from checkpoint...")
        checkpoint = torch.load(pt_path, map_location='cpu')
        student_state = checkpoint.get('student', checkpoint)

        # Load the base UNet and overwrite it with the student weights.
        # strict=False tolerates extra student-only heads in the state dict.
        print("π₯ Loading base UNet...")
        unet = UNet2DConditionModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="unet",
            torch_dtype=torch.float16
        )
        unet.load_state_dict(student_state, strict=False)
        print("β Loaded student weights into UNet")

        # Assemble a full SD1.5 pipeline around the student UNet.
        print("π₯ Loading base SD1.5 pipeline...")
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            safety_checker=None
        )
        pipe.unet = unet
        print("β Replaced UNet with student")

        # Save as a diffusers pipeline directory for the converter to consume.
        print(f"πΎ Saving diffusers pipeline...")
        pipe.save_pretrained(str(temp_pipeline), safe_serialization=True)
        print(f"β Pipeline saved to {temp_pipeline}")

        # Run the converter with the CURRENT interpreter (sys.executable),
        # not whatever bare `python` resolves to on PATH — the previous
        # hard-coded "python" broke inside venvs/conda envs.
        print(f"π Converting to ComfyUI format...")
        cmd = [
            sys.executable, str(converter_path),
            "--model_path", str(temp_pipeline),
            "--checkpoint_path", str(output_safetensors),
            "--half"
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"β Conversion failed: {result.stderr}")
            # Don't leak the multi-GB temp pipeline directory on failure.
            shutil.rmtree(temp_pipeline, ignore_errors=True)
            return None

        # Verify the converter actually produced the file it was asked for.
        if output_safetensors.exists():
            size_mb = output_safetensors.stat().st_size / 1e6
            print(f"β Converted: {output_safetensors.name} ({size_mb:.1f}MB)")

            shutil.rmtree(temp_pipeline)
            print("β Cleaned up temp pipeline")

            return output_safetensors
        else:
            print(f"β Output file not created")
            shutil.rmtree(temp_pipeline, ignore_errors=True)
            return None

    except Exception as e:
        print(f"β Conversion failed: {e}")
        import traceback
        traceback.print_exc()
        return None
|
| 718 |
+
|
| 719 |
+
def _upload_to_hf(self, path: Path, tag):
    """Push a converted safetensors checkpoint to the configured HuggingFace repo.

    Best-effort: ensures the repo exists, uploads the file under its own name,
    and swallows (but reports) any failure so training is never interrupted.
    """
    repo = self.cfg.hf_repo_id
    try:
        hub = HfApi()

        # Make sure the target repo exists; ignore errors since exist_ok
        # already covers the common case and upload will surface real ones.
        try:
            create_repo(repo, exist_ok=True, private=False, repo_type="model")
            print(f"β Repo ready: {repo}")
        except Exception:
            pass

        print(f"π€ Uploading to {repo}...")
        hub.upload_file(
            path_or_fileobj=str(path),
            path_in_repo=path.name,
            repo_id=repo,
            repo_type="model",
            commit_message=f"Epoch {tag}"
        )
        print(f"β Uploaded: https://huggingface.co/{repo}/{path.name}")

    except Exception as e:
        print(f"β οΈ Upload failed: {e}")
|
| 744 |
|
| 745 |
# ---------- Inference (v-pred sampling; use teacher VAE for decode) ----------
|
| 746 |
@torch.no_grad()
|
|
|
|
| 787 |
print("β Inference sanity done.")
|
| 788 |
|
| 789 |
if __name__ == "__main__":
|
| 790 |
+
main()
|