Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1027,6 +1027,7 @@ def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
|
|
| 1027 |
|
| 1028 |
def load_lyra_vae_xl(
|
| 1029 |
repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
|
|
|
|
| 1030 |
device: str = "cuda"
|
| 1031 |
):
|
| 1032 |
"""Load Lyra VAE v2 (SDXL/Illustrious version) from HuggingFace."""
|
|
@@ -1037,7 +1038,9 @@ def load_lyra_vae_xl(
|
|
| 1037 |
print(f"π΅ Loading Lyra VAE v2 from {repo_id}...")
|
| 1038 |
|
| 1039 |
try:
|
| 1040 |
-
|
|
|
|
|
|
|
| 1041 |
print(" π₯ Downloading config.json...")
|
| 1042 |
config_path = hf_hub_download(
|
| 1043 |
repo_id=repo_id,
|
|
@@ -1048,23 +1051,53 @@ def load_lyra_vae_xl(
|
|
| 1048 |
with open(config_path, 'r') as f:
|
| 1049 |
config_dict = json.load(f)
|
| 1050 |
|
| 1051 |
-
print(f" β Config
|
| 1052 |
|
| 1053 |
-
#
|
| 1054 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
checkpoint_path = hf_hub_download(
|
| 1056 |
repo_id=repo_id,
|
| 1057 |
-
filename=
|
| 1058 |
repo_type="model"
|
| 1059 |
)
|
| 1060 |
|
| 1061 |
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 1062 |
|
| 1063 |
-
# Build config
|
| 1064 |
vae_config = LyraV2Config(
|
| 1065 |
-
modality_dims=config_dict.get('modality_dims', {
|
| 1066 |
-
|
| 1067 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1068 |
latent_dim=config_dict.get('latent_dim', 2048),
|
| 1069 |
seq_len=config_dict.get('seq_len', 77),
|
| 1070 |
encoder_layers=config_dict.get('encoder_layers', 3),
|
|
@@ -1078,28 +1111,48 @@ def load_lyra_vae_xl(
|
|
| 1078 |
cantor_local_window=config_dict.get('cantor_local_window', 3),
|
| 1079 |
alpha_init=config_dict.get('alpha_init', 1.0),
|
| 1080 |
beta_init=config_dict.get('beta_init', 0.3),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1081 |
)
|
| 1082 |
|
|
|
|
| 1083 |
lyra_model = LyraV2(vae_config)
|
| 1084 |
|
| 1085 |
-
# Load weights
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
|
|
|
|
|
|
|
|
|
| 1090 |
|
| 1091 |
-
# Keep Lyra in float32 for stability - inputs will be upcast
|
| 1092 |
lyra_model.to(device)
|
| 1093 |
lyra_model.eval()
|
| 1094 |
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
print(f"
|
| 1098 |
-
print(f"
|
|
|
|
|
|
|
| 1099 |
if 'global_step' in checkpoint:
|
| 1100 |
-
print(f"
|
| 1101 |
if 'best_loss' in checkpoint:
|
| 1102 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1103 |
|
| 1104 |
return lyra_model
|
| 1105 |
|
|
@@ -1132,9 +1185,9 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
|
|
| 1132 |
|
| 1133 |
# T5-XL for Lyra
|
| 1134 |
print("Loading T5-XL encoder...")
|
| 1135 |
-
t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-
|
| 1136 |
t5_encoder = T5EncoderModel.from_pretrained(
|
| 1137 |
-
"google/t5-
|
| 1138 |
torch_dtype=torch.float16
|
| 1139 |
).to(device)
|
| 1140 |
t5_encoder.eval()
|
|
|
|
| 1027 |
|
| 1028 |
def load_lyra_vae_xl(
|
| 1029 |
repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
|
| 1030 |
+
checkpoint_filename: str = None, # Auto-detect if None
|
| 1031 |
device: str = "cuda"
|
| 1032 |
):
|
| 1033 |
"""Load Lyra VAE v2 (SDXL/Illustrious version) from HuggingFace."""
|
|
|
|
| 1038 |
print(f"π΅ Loading Lyra VAE v2 from {repo_id}...")
|
| 1039 |
|
| 1040 |
try:
|
| 1041 |
+
from huggingface_hub import list_repo_files
|
| 1042 |
+
|
| 1043 |
+
# Download config.json
|
| 1044 |
print(" π₯ Downloading config.json...")
|
| 1045 |
config_path = hf_hub_download(
|
| 1046 |
repo_id=repo_id,
|
|
|
|
| 1051 |
with open(config_path, 'r') as f:
|
| 1052 |
config_dict = json.load(f)
|
| 1053 |
|
| 1054 |
+
print(f" β Config: {config_dict.get('fusion_strategy', 'unknown')} fusion, latent_dim={config_dict.get('latent_dim')}")
|
| 1055 |
|
| 1056 |
+
# Auto-detect checkpoint if not specified
|
| 1057 |
+
if checkpoint_filename is None:
|
| 1058 |
+
repo_files = list_repo_files(repo_id, repo_type="model")
|
| 1059 |
+
checkpoint_files = [f for f in repo_files if f.endswith('.pt') or f.endswith('.safetensors')]
|
| 1060 |
+
checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower() or 'model' in f.lower()]
|
| 1061 |
+
|
| 1062 |
+
if not checkpoint_files:
|
| 1063 |
+
raise FileNotFoundError(f"No checkpoint found in {repo_id}")
|
| 1064 |
+
|
| 1065 |
+
# Prefer newest checkpoint (highest step number)
|
| 1066 |
+
def extract_step(name):
|
| 1067 |
+
import re
|
| 1068 |
+
match = re.search(r'(\d+)\.pt', name)
|
| 1069 |
+
return int(match.group(1)) if match else 0
|
| 1070 |
+
|
| 1071 |
+
checkpoint_files.sort(key=extract_step, reverse=True)
|
| 1072 |
+
checkpoint_filename = checkpoint_files[0]
|
| 1073 |
+
print(f" β Auto-selected checkpoint: {checkpoint_filename}")
|
| 1074 |
+
|
| 1075 |
+
# Download checkpoint
|
| 1076 |
+
print(f" π₯ Downloading {checkpoint_filename}...")
|
| 1077 |
checkpoint_path = hf_hub_download(
|
| 1078 |
repo_id=repo_id,
|
| 1079 |
+
filename=checkpoint_filename,
|
| 1080 |
repo_type="model"
|
| 1081 |
)
|
| 1082 |
|
| 1083 |
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 1084 |
|
| 1085 |
+
# Build config with all v2 fields
|
| 1086 |
vae_config = LyraV2Config(
|
| 1087 |
+
modality_dims=config_dict.get('modality_dims', {
|
| 1088 |
+
"clip_l": 768, "clip_g": 1280,
|
| 1089 |
+
"t5_xl_l": 2048, "t5_xl_g": 2048
|
| 1090 |
+
}),
|
| 1091 |
+
modality_seq_lens=config_dict.get('modality_seq_lens', {
|
| 1092 |
+
"clip_l": 77, "clip_g": 77,
|
| 1093 |
+
"t5_xl_l": 512, "t5_xl_g": 512
|
| 1094 |
+
}),
|
| 1095 |
+
binding_config=config_dict.get('binding_config', {
|
| 1096 |
+
"clip_l": {"t5_xl_l": 0.3},
|
| 1097 |
+
"clip_g": {"t5_xl_g": 0.3},
|
| 1098 |
+
"t5_xl_l": {},
|
| 1099 |
+
"t5_xl_g": {}
|
| 1100 |
+
}),
|
| 1101 |
latent_dim=config_dict.get('latent_dim', 2048),
|
| 1102 |
seq_len=config_dict.get('seq_len', 77),
|
| 1103 |
encoder_layers=config_dict.get('encoder_layers', 3),
|
|
|
|
| 1111 |
cantor_local_window=config_dict.get('cantor_local_window', 3),
|
| 1112 |
alpha_init=config_dict.get('alpha_init', 1.0),
|
| 1113 |
beta_init=config_dict.get('beta_init', 0.3),
|
| 1114 |
+
alpha_lr_scale=config_dict.get('alpha_lr_scale', 0.1),
|
| 1115 |
+
beta_lr_scale=config_dict.get('beta_lr_scale', 1.0),
|
| 1116 |
+
beta_kl=config_dict.get('beta_kl', 0.1),
|
| 1117 |
+
beta_reconstruction=config_dict.get('beta_reconstruction', 1.0),
|
| 1118 |
+
beta_cross_modal=config_dict.get('beta_cross_modal', 0.0),
|
| 1119 |
+
beta_alpha_regularization=config_dict.get('beta_alpha_regularization', 0.01),
|
| 1120 |
+
kl_clamp_max=config_dict.get('kl_clamp_max', 1.0),
|
| 1121 |
+
logvar_clamp_min=config_dict.get('logvar_clamp_min', -10.0),
|
| 1122 |
+
logvar_clamp_max=config_dict.get('logvar_clamp_max', 10.0),
|
| 1123 |
)
|
| 1124 |
|
| 1125 |
+
# Initialize model
|
| 1126 |
lyra_model = LyraV2(vae_config)
|
| 1127 |
|
| 1128 |
+
# Load weights
|
| 1129 |
+
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
| 1130 |
+
missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
|
| 1131 |
+
|
| 1132 |
+
if missing:
|
| 1133 |
+
print(f" β οΈ Missing keys: {len(missing)} (using initialized weights)")
|
| 1134 |
+
if unexpected:
|
| 1135 |
+
print(f" β οΈ Unexpected keys: {len(unexpected)} (ignored)")
|
| 1136 |
|
|
|
|
| 1137 |
lyra_model.to(device)
|
| 1138 |
lyra_model.eval()
|
| 1139 |
|
| 1140 |
+
# Print summary
|
| 1141 |
+
total_params = sum(p.numel() for p in lyra_model.parameters())
|
| 1142 |
+
print(f"β
Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
|
| 1143 |
+
print(f" Fusion: {vae_config.fusion_strategy}")
|
| 1144 |
+
print(f" Latent: {vae_config.latent_dim}, Hidden: {vae_config.hidden_dim}")
|
| 1145 |
+
|
| 1146 |
if 'global_step' in checkpoint:
|
| 1147 |
+
print(f" Trained steps: {checkpoint['global_step']:,}")
|
| 1148 |
if 'best_loss' in checkpoint:
|
| 1149 |
+
print(f" Best loss: {checkpoint['best_loss']:.4f}")
|
| 1150 |
+
|
| 1151 |
+
# Print binding info
|
| 1152 |
+
fusion_params = lyra_model.get_fusion_params()
|
| 1153 |
+
if fusion_params.get('alphas'):
|
| 1154 |
+
alpha_vals = {k: torch.sigmoid(v).item() for k, v in fusion_params['alphas'].items()}
|
| 1155 |
+
print(f" Alphas: {alpha_vals}")
|
| 1156 |
|
| 1157 |
return lyra_model
|
| 1158 |
|
|
|
|
| 1185 |
|
| 1186 |
# T5-XL for Lyra
|
| 1187 |
print("Loading T5-XL encoder...")
|
| 1188 |
+
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
|
| 1189 |
t5_encoder = T5EncoderModel.from_pretrained(
|
| 1190 |
+
"google/flan-t5-xl",
|
| 1191 |
torch_dtype=torch.float16
|
| 1192 |
).to(device)
|
| 1193 |
t5_encoder.eval()
|