# Age-Synthesis / check_model_integrity.py
# Bhaskr: "Initial clean deployment" (commit 2ab0040)
import torch
import numpy as np
import os
import sys
import warnings
from PIL import Image
# 1. Setup paths
# Put the local "stylegan2-ada-pytorch" directory on the import path. The
# relative name is resolved against the current working directory, so the
# script is expected to be launched from the directory that contains it.
# Presumably gap_engine (or the pickled network) imports modules from that
# checkout, which is why this precedes the import below -- TODO confirm.
sys.path.append(os.path.abspath("stylegan2-ada-pytorch"))
from gap_engine import StyleGAN2AgingEngine
# Suppress warnings
# NOTE(review): KMP_DUPLICATE_LIB_OK is set after torch/numpy/gap_engine have
# already been imported; confirm the OpenMP runtime still honours it this late.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# Only UserWarning is silenced, and only for warnings raised from here on.
warnings.filterwarnings("ignore", category=UserWarning)
# 2. Config
# All paths are relative to the current working directory.
WEIGHTS = "weights/stylegan2_ffhq.pkl"  # generator pickle, checked by check_model()
AGE_V = "weights/age_boundary.npy"  # passed to the engine as its second argument
GENDER_V = "weights/gender_boundary.npy"  # passed to the engine as its third argument
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _latent_for_seed(engine, seed):
    """Return the mapping-network output for a z vector drawn from ``seed``.

    The z vector comes from ``np.random.RandomState(seed)``, so the same seed
    always feeds the same input to ``engine.G.mapping``.
    """
    z = torch.from_numpy(
        np.random.RandomState(seed).randn(1, engine.G.z_dim)
    ).to(DEVICE)
    with torch.no_grad():
        return engine.G.mapping(z, None)


def check_model():
    """Run the StyleGAN2 diagnostic and report whether every check passed.

    Steps:
      1. Verify the weights file exists (boundary files are only warned about).
      2. Build the engine and load the model.
      3. Generate one image per seed and save them side by side to
         'test_generator_output.png'.
      4. Generate the first seed at age shifts -5.0 and +5.0, compare the two
         images, and save a comparison strip to 'test_age_shift.png'.

    Progress and results are printed to stdout.

    Returns:
        bool: True when the model loaded, all images were generated, and the
        two age-shifted images differ; False on any failure, including an
        exception raised by the engine (its traceback is printed, not raised).
    """
    print("--- Diagnostic Check: StyleGAN2 Model ---")
    if not os.path.exists(WEIGHTS):
        print("FAIL: Model weights missing.")
        return False

    # The boundary files are handed to the engine too. Flag a missing one up
    # front so a later load error is easier to interpret; whether the engine
    # strictly requires them is not visible here, so this is only a warning.
    for label, path in (("Age boundary", AGE_V), ("Gender boundary", GENDER_V)):
        if not os.path.exists(path):
            print(f"WARNING: {label} file not found at '{path}'.")

    seeds = (100, 200, 300)
    try:
        engine = StyleGAN2AgingEngine(WEIGHTS, AGE_V, GENDER_V, device=DEVICE)
        engine.load_model()
        print("SUCCESS: Model loaded into memory.")

        # Test 1: Generate 3 random faces (Identity check)
        print("Testing Generator (3 seeds)...")
        results = []
        for seed in seeds:
            w = _latent_for_seed(engine, seed)
            with torch.no_grad():
                img = engine.generate_at_age(w, 0, 0)  # Neutral age/gender
            results.append(img)
            print(f" - Seed {seed} generated successfully.")

        # Save a grid to verify visually
        final_img = Image.fromarray(np.concatenate(results, axis=1))
        final_img.save("test_generator_output.png")
        print("SUCCESS: Diagnostic image saved to 'test_generator_output.png'.")

        # Test 2: Verify Latent Math (Age Vector check)
        print("Testing Latent Math (Age shift)...")
        # Recompute the latent rather than reusing the tensor from Test 1, in
        # case the engine modifies its input in place.
        w = _latent_for_seed(engine, seeds[0])
        with torch.no_grad():
            # Generate young and old
            img_young = engine.generate_at_age(w, -5.0)
            img_old = engine.generate_at_age(w, 5.0)

        # Identical outputs mean the age direction was not applied at all.
        age_shift_ok = not np.array_equal(img_young, img_old)
        if age_shift_ok:
            print("SUCCESS: Latent age vector shifts the image.")
        else:
            print("FAIL: Latent age vector has NO effect.")

        # Always keep the young / neutral / old strip: it is the quickest way
        # to inspect the result by eye, whether the check passed or not.
        Image.fromarray(
            np.concatenate([img_young, results[0], img_old], axis=1)
        ).save("test_age_shift.png")
        return age_shift_ok
    except Exception as e:
        # Top-level boundary of a diagnostic script: report everything and
        # turn it into a failed result instead of crashing.
        print(f"CRITICAL FAIL: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    # Propagate the outcome as the process exit status so shells and CI jobs
    # can detect a failed integrity check.
    sys.exit(0 if check_model() else 1)