AbstractPhil committed on
Commit
9a47041
·
verified Β·
1 Parent(s): 1e7e470

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +254 -11
trainer.py CHANGED
@@ -6,10 +6,13 @@
6
  # - Activations at bottom
7
  # =====================================================================================
8
  from __future__ import annotations
9
- import os, json, math, random
10
  from dataclasses import dataclass, asdict
11
  from pathlib import Path
12
  from typing import Dict, List, Tuple, Optional
 
 
 
13
 
14
  import torch
15
  import torch.nn as nn
@@ -27,7 +30,7 @@ from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
27
  from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
28
 
29
  # HF / safetensors
30
- from huggingface_hub import snapshot_download
31
  from safetensors.torch import load_file
32
 
33
 
@@ -60,7 +63,7 @@ class BaseConfig:
60
  amp: bool = True
61
 
62
  global_flow_weight: float = 1.0
63
- block_penalty_weight: float = 0.125 # ← NEW: Start very low!
64
  use_local_flow_heads: bool = False
65
  local_flow_weight: float = 1.0
66
 
@@ -89,6 +92,11 @@ class BaseConfig:
89
  # Inference
90
  sample_steps: int = 30
91
  guidance_scale: float = 7.5
 
 
 
 
 
92
 
93
  def __post_init__(self):
94
  Path(self.out_dir).mkdir(parents=True, exist_ok=True)
@@ -229,6 +237,7 @@ class StudentUNet(nn.Module):
229
  self._ensure_heads(feats)
230
  return v_hat, feats
231
 
 
232
  # =====================================================================================
233
  # 6) DAVID LOADER (HF) + ASSESSOR + FUSION
234
  # =====================================================================================
@@ -360,6 +369,8 @@ class FlowMatchDavidTrainer:
360
  def __init__(self, cfg: BaseConfig, device: str = "cuda"):
361
  self.cfg = cfg
362
  self.device = device
 
 
363
 
364
  # Data
365
  self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
@@ -382,9 +393,111 @@ class FlowMatchDavidTrainer:
382
  self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
383
  self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
384
 
 
 
 
 
385
  # Logs
386
  self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))
387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  # math helpers
389
  def _v_star(self, x_t, t, eps_hat):
390
  alpha, sigma = self.teacher.alpha_sigma(t)
@@ -401,10 +514,12 @@ class FlowMatchDavidTrainer:
401
  # training
402
  def train(self):
403
  cfg = self.cfg
404
- gstep = 0
405
- for ep in range(cfg.epochs):
 
406
  self.student.train()
407
- pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}")
 
408
  acc = {"L":0.0, "Lf":0.0, "Lb":0.0}
409
 
410
  for it, batch in enumerate(pbar):
@@ -465,6 +580,7 @@ class FlowMatchDavidTrainer:
465
  acc["Lf"] += float(L_flow.item())
466
  acc["Lb"] += float(L_blocks.item())
467
 
 
468
  if it % 50 == 0:
469
  self.writer.add_scalar("train/total", float(L_total.item()), gstep)
470
  self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
@@ -473,9 +589,18 @@ class FlowMatchDavidTrainer:
473
  for k in list(lam.keys())[:4]:
474
  self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)
475
 
476
- pbar.set_postfix({"L": f"{float(L_total.item()):.4f}", "Lf": f"{float(L_flow.item()):.4f}", "Lb": f"{float(L_blocks.item()):.4f}"})
 
 
 
 
 
 
 
477
  del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial
478
 
 
 
479
  n = len(self.loader)
480
  print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
481
  self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
@@ -488,16 +613,134 @@ class FlowMatchDavidTrainer:
488
  self._save("final", gstep)
489
  self.writer.close()
490
 
 
491
  def _save(self, tag, gstep):
492
- path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_{tag}.pt"
 
 
493
  torch.save({
494
  "cfg": asdict(self.cfg),
495
  "student": self.student.state_dict(),
496
  "opt": self.opt.state_dict(),
497
  "sched": self.sched.state_dict(),
498
  "gstep": gstep
499
- }, path)
500
- print(f"βœ“ Saved: {path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  # ---------- Inference (v-pred sampling; use teacher VAE for decode) ----------
503
  @torch.no_grad()
@@ -544,4 +787,4 @@ def main():
544
  print("βœ“ Inference sanity done.")
545
 
546
  if __name__ == "__main__":
547
- main()
 
6
  # - Activations at bottom
7
  # =====================================================================================
8
  from __future__ import annotations
9
+ import os, json, math, random, re
10
  from dataclasses import dataclass, asdict
11
  from pathlib import Path
12
  from typing import Dict, List, Tuple, Optional
13
+ import urllib.request
14
+ import subprocess
15
+ import shutil
16
 
17
  import torch
18
  import torch.nn as nn
 
30
  from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
31
 
32
  # HF / safetensors
33
+ from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download
34
  from safetensors.torch import load_file
35
 
36
 
 
63
  amp: bool = True
64
 
65
  global_flow_weight: float = 1.0
66
+ block_penalty_weight: float = 0.2 # ← NEW: Start very low!
67
  use_local_flow_heads: bool = False
68
  local_flow_weight: float = 1.0
69
 
 
92
  # Inference
93
  sample_steps: int = 30
94
  guidance_scale: float = 7.5
95
+
96
+ # HuggingFace upload & resume
97
+ hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
98
+ upload_every_epoch: bool = True
99
+ continue_training: bool = True # Download latest checkpoint and resume
100
 
101
  def __post_init__(self):
102
  Path(self.out_dir).mkdir(parents=True, exist_ok=True)
 
237
  self._ensure_heads(feats)
238
  return v_hat, feats
239
 
240
+
241
  # =====================================================================================
242
  # 6) DAVID LOADER (HF) + ASSESSOR + FUSION
243
  # =====================================================================================
 
369
  def __init__(self, cfg: BaseConfig, device: str = "cuda"):
370
  self.cfg = cfg
371
  self.device = device
372
+ self.start_epoch = 0
373
+ self.start_gstep = 0
374
 
375
  # Data
376
  self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
 
393
  self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
394
  self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
395
 
396
+ # Try to resume from HF if enabled
397
+ if cfg.continue_training:
398
+ self._load_latest_from_hf()
399
+
400
  # Logs
401
  self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))
402
 
403
+ def _load_latest_from_hf(self):
404
+ """Download and load the latest checkpoint from HuggingFace."""
405
+ if not self.cfg.hf_repo_id:
406
+ print("⚠️ continue_training=True but no hf_repo_id specified")
407
+ return
408
+
409
+ try:
410
+ api = HfApi()
411
+ print(f"\nπŸ” Searching for latest checkpoint in {self.cfg.hf_repo_id}...")
412
+
413
+ # Check if repo exists
414
+ try:
415
+ repo_info = api.repo_info(repo_id=self.cfg.hf_repo_id, repo_type="model")
416
+ except Exception as e:
417
+ print(f"⚠️ Could not access repo: {e}")
418
+ print(" Starting training from scratch")
419
+ return
420
+
421
+ # List all files in repo
422
+ files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")
423
+
424
+ if not files:
425
+ print("ℹ️ Repo is empty, starting from scratch")
426
+ return
427
+
428
+ print(f"πŸ“‚ Found {len(files)} files in repo:")
429
+ for f in files:
430
+ print(f" - {f}")
431
+
432
+ # Find all .safetensors files with epoch numbers
433
+ # Try multiple patterns
434
+ epochs = []
435
+
436
+ for f in files:
437
+ if not f.endswith('.safetensors'):
438
+ continue
439
+
440
+ # Look for _e<number> pattern anywhere in filename
441
+ match = re.search(r'_e(\d+)\.safetensors$', f)
442
+ if match:
443
+ epoch_num = int(match.group(1))
444
+ epochs.append((epoch_num, f))
445
+ print(f"βœ“ Found checkpoint: {f} (epoch {epoch_num})")
446
+
447
+ if not epochs:
448
+ print("ℹ️ No checkpoint files found (looking for *_e<num>.safetensors)")
449
+ return
450
+
451
+ # Get latest epoch
452
+ latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
453
+ print(f"\nπŸ“₯ Downloading latest checkpoint: {latest_file} (epoch {latest_epoch})")
454
+
455
+ # Download the safetensors file
456
+ local_path = hf_hub_download(
457
+ repo_id=self.cfg.hf_repo_id,
458
+ filename=latest_file,
459
+ repo_type="model",
460
+ cache_dir=self.cfg.ckpt_dir
461
+ )
462
+ print(f"βœ“ Downloaded to: {local_path}")
463
+
464
+ # Load the checkpoint using from_single_file
465
+ print("πŸ“¦ Loading checkpoint into pipeline...")
466
+ pipe = StableDiffusionPipeline.from_single_file(
467
+ local_path,
468
+ torch_dtype=torch.float16,
469
+ safety_checker=None,
470
+ load_safety_checker=False
471
+ )
472
+
473
+ # Extract UNet state dict
474
+ unet_state = pipe.unet.state_dict()
475
+
476
+ # Load into student
477
+ missing, unexpected = self.student.unet.load_state_dict(unet_state, strict=False)
478
+ print(f"βœ“ Loaded student UNet from epoch {latest_epoch}")
479
+ if missing:
480
+ print(f" Missing keys: {len(missing)}")
481
+ if unexpected:
482
+ print(f" Unexpected keys: {len(unexpected)}")
483
+
484
+ # Set starting epoch (resume from next epoch)
485
+ self.start_epoch = latest_epoch
486
+ self.start_gstep = latest_epoch * len(self.loader)
487
+
488
+ print(f"🎯 Resuming training from epoch {self.start_epoch + 1}")
489
+
490
+ # Clean up
491
+ del pipe
492
+ torch.cuda.empty_cache()
493
+
494
+ except Exception as e:
495
+ print(f"⚠️ Failed to load checkpoint from HF: {e}")
496
+ print(" Starting training from scratch")
497
+ import traceback
498
+ traceback.print_exc()
499
+
500
+
501
  # math helpers
502
  def _v_star(self, x_t, t, eps_hat):
503
  alpha, sigma = self.teacher.alpha_sigma(t)
 
514
  # training
515
  def train(self):
516
  cfg = self.cfg
517
+ gstep = self.start_gstep
518
+
519
+ for ep in range(self.start_epoch, cfg.epochs):
520
  self.student.train()
521
+ pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
522
+ dynamic_ncols=True, leave=True, position=0) # Add these params
523
  acc = {"L":0.0, "Lf":0.0, "Lb":0.0}
524
 
525
  for it, batch in enumerate(pbar):
 
580
  acc["Lf"] += float(L_flow.item())
581
  acc["Lb"] += float(L_blocks.item())
582
 
583
+ # Only log to tensorboard every 50 iterations
584
  if it % 50 == 0:
585
  self.writer.add_scalar("train/total", float(L_total.item()), gstep)
586
  self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
 
589
  for k in list(lam.keys())[:4]:
590
  self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)
591
 
592
+ # Update progress bar less frequently to avoid double display
593
+ if it % 10 == 0 or it == len(self.loader) - 1: # Update every 10 iterations
594
+ pbar.set_postfix({
595
+ "L": f"{float(L_total.item()):.4f}",
596
+ "Lf": f"{float(L_flow.item()):.4f}",
597
+ "Lb": f"{float(L_blocks.item()):.4f}"
598
+ }, refresh=False) # Add refresh=False
599
+
600
  del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial
601
 
602
+ pbar.close() # Explicitly close the progress bar
603
+
604
  n = len(self.loader)
605
  print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
606
  self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
 
613
  self._save("final", gstep)
614
  self.writer.close()
615
 
616
+
617
  def _save(self, tag, gstep):
618
+ """Save and convert to ComfyUI format, then upload."""
619
+ # 1. Save .pt first (for resuming training if needed)
620
+ pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
621
  torch.save({
622
  "cfg": asdict(self.cfg),
623
  "student": self.student.state_dict(),
624
  "opt": self.opt.state_dict(),
625
  "sched": self.sched.state_dict(),
626
  "gstep": gstep
627
+ }, pt_path)
628
+ print(f"βœ“ Saved temp .pt: {pt_path}")
629
+
630
+ # 2. Convert to ComfyUI safetensors
631
+ safetensors_path = self._convert_to_comfyui(pt_path, tag)
632
+
633
+ # 3. Upload to HF
634
+ if self.cfg.upload_every_epoch and self.cfg.hf_repo_id and safetensors_path:
635
+ self._upload_to_hf(safetensors_path, tag)
636
+
637
+ # 4. Clean up large .pt file
638
+ pt_path.unlink()
639
+ print(f"βœ“ Cleaned up temp .pt file")
640
+
641
+ def _convert_to_comfyui(self, pt_path: Path, tag) -> Optional[Path]:
642
+ """Convert .pt to ComfyUI-compatible safetensors."""
643
+ try:
644
+ temp_pipeline = Path(self.cfg.ckpt_dir) / f"temp_pipeline_e{tag}"
645
+ output_safetensors = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.safetensors"
646
+
647
+ # Download converter if needed
648
+ converter_path = Path(self.cfg.ckpt_dir) / "convert_diffusers_to_original_stable_diffusion.py"
649
+ if not converter_path.exists():
650
+ print("πŸ“₯ Downloading official converter...")
651
+ url = "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_diffusers_to_original_stable_diffusion.py"
652
+ urllib.request.urlretrieve(url, str(converter_path))
653
+ print("βœ“ Converter downloaded")
654
+
655
+ # Load checkpoint
656
+ print(f"πŸ“¦ Creating diffusers pipeline from checkpoint...")
657
+ checkpoint = torch.load(pt_path, map_location='cpu')
658
+ student_state = checkpoint.get('student', checkpoint)
659
+
660
+ # Load base UNet and replace with student weights
661
+ print("πŸ“₯ Loading base UNet...")
662
+ unet = UNet2DConditionModel.from_pretrained(
663
+ "runwayml/stable-diffusion-v1-5",
664
+ subfolder="unet",
665
+ torch_dtype=torch.float16
666
+ )
667
+ unet.load_state_dict(student_state, strict=False)
668
+ print("βœ“ Loaded student weights into UNet")
669
+
670
+ # Load full pipeline and replace UNet
671
+ print("πŸ“₯ Loading base SD1.5 pipeline...")
672
+ pipe = StableDiffusionPipeline.from_pretrained(
673
+ "runwayml/stable-diffusion-v1-5",
674
+ torch_dtype=torch.float16,
675
+ safety_checker=None
676
+ )
677
+ pipe.unet = unet
678
+ print("βœ“ Replaced UNet with student")
679
+
680
+ # Save as pipeline
681
+ print(f"πŸ’Ύ Saving diffusers pipeline...")
682
+ pipe.save_pretrained(str(temp_pipeline), safe_serialization=True)
683
+ print(f"βœ“ Pipeline saved to {temp_pipeline}")
684
+
685
+ # Convert to checkpoint
686
+ print(f"πŸ”„ Converting to ComfyUI format...")
687
+ cmd = [
688
+ "python", str(converter_path),
689
+ "--model_path", str(temp_pipeline),
690
+ "--checkpoint_path", str(output_safetensors),
691
+ "--half"
692
+ ]
693
+
694
+ result = subprocess.run(cmd, capture_output=True, text=True)
695
+ if result.returncode != 0:
696
+ print(f"❌ Conversion failed: {result.stderr}")
697
+ return None
698
+
699
+ # Verify output
700
+ if output_safetensors.exists():
701
+ size_mb = output_safetensors.stat().st_size / 1e6
702
+ print(f"βœ“ Converted: {output_safetensors.name} ({size_mb:.1f}MB)")
703
+
704
+ # Clean up temp pipeline
705
+ shutil.rmtree(temp_pipeline)
706
+ print("βœ“ Cleaned up temp pipeline")
707
+
708
+ return output_safetensors
709
+ else:
710
+ print(f"❌ Output file not created")
711
+ return None
712
+
713
+ except Exception as e:
714
+ print(f"❌ Conversion failed: {e}")
715
+ import traceback
716
+ traceback.print_exc()
717
+ return None
718
+
719
+ def _upload_to_hf(self, path: Path, tag):
720
+ """Upload safetensors to HuggingFace."""
721
+ try:
722
+ api = HfApi()
723
+
724
+ # Create repo if doesn't exist
725
+ try:
726
+ create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
727
+ print(f"βœ“ Repo ready: {self.cfg.hf_repo_id}")
728
+ except Exception:
729
+ pass
730
+
731
+ # Upload
732
+ print(f"πŸ“€ Uploading to {self.cfg.hf_repo_id}...")
733
+ api.upload_file(
734
+ path_or_fileobj=str(path),
735
+ path_in_repo=path.name,
736
+ repo_id=self.cfg.hf_repo_id,
737
+ repo_type="model",
738
+ commit_message=f"Epoch {tag}"
739
+ )
740
+ print(f"βœ… Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")
741
+
742
+ except Exception as e:
743
+ print(f"⚠️ Upload failed: {e}")
744
 
745
  # ---------- Inference (v-pred sampling; use teacher VAE for decode) ----------
746
  @torch.no_grad()
 
787
  print("βœ“ Inference sanity done.")
788
 
789
  if __name__ == "__main__":
790
+ main()