AbstractPhil committed (verified)
Commit e5ffd07 · Parent(s): 0f1416b

Update app.py

Files changed (1):
  1. app.py +76 -23
app.py CHANGED
@@ -1027,6 +1027,7 @@ def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
 
 def load_lyra_vae_xl(
     repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
+    checkpoint_filename: str = None,  # Auto-detect if None
     device: str = "cuda"
 ):
     """Load Lyra VAE v2 (SDXL/Illustrious version) from HuggingFace."""
@@ -1037,7 +1038,9 @@ def load_lyra_vae_xl(
     print(f"🎵 Loading Lyra VAE v2 from {repo_id}...")
 
     try:
-        # Download config.json first to get model architecture
+        from huggingface_hub import list_repo_files
+
+        # Download config.json
         print("  📥 Downloading config.json...")
         config_path = hf_hub_download(
             repo_id=repo_id,
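For reference, list_repo_files from huggingface_hub returns the repository's file paths as plain strings, which is what the auto-detection in the next hunk filters. A small sketch (the listed filenames are examples; actual repo contents may differ):

from huggingface_hub import list_repo_files

files = list_repo_files("AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious", repo_type="model")
# e.g. ["config.json", "checkpoint_lyra_illustrious_37000.pt", ...]
print([f for f in files if f.endswith((".pt", ".safetensors"))])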
@@ -1048,23 +1051,53 @@ def load_lyra_vae_xl(
         with open(config_path, 'r') as f:
             config_dict = json.load(f)
 
-        print(f"  ✓ Config loaded: {config_dict.get('fusion_strategy', 'unknown')} fusion")
+        print(f"  ✓ Config: {config_dict.get('fusion_strategy', 'unknown')} fusion, latent_dim={config_dict.get('latent_dim')}")
 
-        # Download model weights
-        print("  📥 Downloading model.pt...")
+        # Auto-detect checkpoint if not specified
+        if checkpoint_filename is None:
+            repo_files = list_repo_files(repo_id, repo_type="model")
+            checkpoint_files = [f for f in repo_files if f.endswith('.pt') or f.endswith('.safetensors')]
+            checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower() or 'model' in f.lower()]
+
+            if not checkpoint_files:
+                raise FileNotFoundError(f"No checkpoint found in {repo_id}")
+
+            # Prefer newest checkpoint (highest step number)
+            def extract_step(name):
+                import re
+                match = re.search(r'(\d+)\.pt', name)
+                return int(match.group(1)) if match else 0
+
+            checkpoint_files.sort(key=extract_step, reverse=True)
+            checkpoint_filename = checkpoint_files[0]
+            print(f"  ✓ Auto-selected checkpoint: {checkpoint_filename}")
+
+        # Download checkpoint
+        print(f"  📥 Downloading {checkpoint_filename}...")
         checkpoint_path = hf_hub_download(
             repo_id=repo_id,
-            filename="checkpoint_lyra_illustrious_37000.pt",
+            filename=checkpoint_filename,
             repo_type="model"
         )
 
         checkpoint = torch.load(checkpoint_path, map_location="cpu")
 
-        # Build config from repo's config.json
+        # Build config with all v2 fields
         vae_config = LyraV2Config(
-            modality_dims=config_dict.get('modality_dims', {"clip_l": 768, "clip_g": 1280, "t5_xl_l": 2048, "t5_xl_g": 2048}),
-            modality_seq_lens=config_dict.get('modality_seq_lens', {"clip_l": 77, "clip_g": 77, "t5_xl_l": 512, "t5_xl_g": 512}),
-            binding_config=config_dict.get('binding_config'),
+            modality_dims=config_dict.get('modality_dims', {
+                "clip_l": 768, "clip_g": 1280,
+                "t5_xl_l": 2048, "t5_xl_g": 2048
+            }),
+            modality_seq_lens=config_dict.get('modality_seq_lens', {
+                "clip_l": 77, "clip_g": 77,
+                "t5_xl_l": 512, "t5_xl_g": 512
+            }),
+            binding_config=config_dict.get('binding_config', {
+                "clip_l": {"t5_xl_l": 0.3},
+                "clip_g": {"t5_xl_g": 0.3},
+                "t5_xl_l": {},
+                "t5_xl_g": {}
+            }),
             latent_dim=config_dict.get('latent_dim', 2048),
             seq_len=config_dict.get('seq_len', 77),
             encoder_layers=config_dict.get('encoder_layers', 3),
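The checkpoint picker orders candidates by the step number parsed from the filename. A standalone sketch of that ordering; one caveat of the regex as written is that it only matches ".pt" names, so ".safetensors" files fall back to step 0 and sort last:

import re

def extract_step(name):
    match = re.search(r'(\d+)\.pt', name)
    return int(match.group(1)) if match else 0

names = [
    "checkpoint_lyra_illustrious_37000.pt",  # the previously hard-coded file
    "checkpoint_lyra_illustrious_42000.pt",  # hypothetical newer checkpoint
    "model.safetensors",                     # no parsable step -> 0
]
print(sorted(names, key=extract_step, reverse=True))
# ['checkpoint_lyra_illustrious_42000.pt', 'checkpoint_lyra_illustrious_37000.pt', 'model.safetensors']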
@@ -1078,28 +1111,48 @@ def load_lyra_vae_xl(
             cantor_local_window=config_dict.get('cantor_local_window', 3),
             alpha_init=config_dict.get('alpha_init', 1.0),
             beta_init=config_dict.get('beta_init', 0.3),
+            alpha_lr_scale=config_dict.get('alpha_lr_scale', 0.1),
+            beta_lr_scale=config_dict.get('beta_lr_scale', 1.0),
+            beta_kl=config_dict.get('beta_kl', 0.1),
+            beta_reconstruction=config_dict.get('beta_reconstruction', 1.0),
+            beta_cross_modal=config_dict.get('beta_cross_modal', 0.0),
+            beta_alpha_regularization=config_dict.get('beta_alpha_regularization', 0.01),
+            kl_clamp_max=config_dict.get('kl_clamp_max', 1.0),
+            logvar_clamp_min=config_dict.get('logvar_clamp_min', -10.0),
+            logvar_clamp_max=config_dict.get('logvar_clamp_max', 10.0),
         )
 
+        # Initialize model
         lyra_model = LyraV2(vae_config)
 
-        # Load weights from checkpoint
-        if 'model_state_dict' in checkpoint:
-            lyra_model.load_state_dict(checkpoint['model_state_dict'])
-        else:
-            lyra_model.load_state_dict(checkpoint)
+        # Load weights
+        state_dict = checkpoint.get('model_state_dict', checkpoint)
+        missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
+
+        if missing:
+            print(f"  ⚠️ Missing keys: {len(missing)} (using initialized weights)")
+        if unexpected:
+            print(f"  ⚠️ Unexpected keys: {len(unexpected)} (ignored)")
 
-        # Keep Lyra in float32 for stability - inputs will be upcast
         lyra_model.to(device)
         lyra_model.eval()
 
-        print(f"✅ Lyra VAE v2 loaded")
-        print(f"  Fusion: {config_dict.get('fusion_strategy')}")
-        print(f"  Latent dim: {config_dict.get('latent_dim')}")
-        print(f"  Hidden dim: {config_dict.get('hidden_dim')}")
+        # Print summary
+        total_params = sum(p.numel() for p in lyra_model.parameters())
+        print(f"✅ Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
+        print(f"  Fusion: {vae_config.fusion_strategy}")
+        print(f"  Latent: {vae_config.latent_dim}, Hidden: {vae_config.hidden_dim}")
+
         if 'global_step' in checkpoint:
-            print(f"  Step: {checkpoint['global_step']:,}")
+            print(f"  Trained steps: {checkpoint['global_step']:,}")
         if 'best_loss' in checkpoint:
-            print(f"  Loss: {checkpoint['best_loss']:.4f}")
+            print(f"  Best loss: {checkpoint['best_loss']:.4f}")
+
+        # Print binding info
+        fusion_params = lyra_model.get_fusion_params()
+        if fusion_params.get('alphas'):
+            alpha_vals = {k: torch.sigmoid(v).item() for k, v in fusion_params['alphas'].items()}
+            print(f"  Alphas: {alpha_vals}")
 
         return lyra_model
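The weight loading is now tolerant of both bare state dicts and full training checkpoints (checkpoint.get('model_state_dict', checkpoint) with strict=False), and the summary reads optional metadata keys. A small sketch for peeking at that metadata without building the model (assumes the checkpoint layout this loader handles):

import torch

ckpt = torch.load("checkpoint_lyra_illustrious_37000.pt", map_location="cpu")
print("model_state_dict" in ckpt)                      # True for full training checkpoints
print(ckpt.get("global_step"), ckpt.get("best_loss"))  # None if absent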
 
@@ -1132,9 +1185,9 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
 
     # T5-XL for Lyra
     print("Loading T5-XL encoder...")
-    t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xl")
+    t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    t5_encoder = T5EncoderModel.from_pretrained(
-        "google/t5-v1_1-xl",
+        "google/flan-t5-xl",
        torch_dtype=torch.float16
    ).to(device)
    t5_encoder.eval()
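The swap from google/t5-v1_1-xl to google/flan-t5-xl keeps the same hidden size (2048), so it still matches the t5_xl_* dims and 512-token sequence lengths in the config above. A minimal encoding sketch under that assumption (the prompt text is illustrative):

import torch
from transformers import T5Tokenizer, T5EncoderModel

t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
t5_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=torch.float16).to("cuda").eval()

with torch.no_grad():
    batch = t5_tokenizer("a scenic mountain lake", return_tensors="pt",
                         padding="max_length", max_length=512, truncation=True).to("cuda")
    t5_states = t5_encoder(**batch).last_hidden_state  # shape (1, 512, 2048)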
 