Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import os | |
| import sys | |
| import warnings | |
| from PIL import Image | |
| # 1. Setup paths | |
| sys.path.append(os.path.abspath("stylegan2-ada-pytorch")) | |
| from gap_engine import StyleGAN2AgingEngine | |
# Quiet the runtime so the diagnostic report stays readable: drop UserWarning
# output, and (presumably) let a duplicate Intel OpenMP runtime load without
# aborting.
# NOTE(review): torch and numpy are already imported above this point; the
# KMP_DUPLICATE_LIB_OK flag may need to be set before those imports to take
# effect — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
warnings.filterwarnings("ignore", category=UserWarning)

# Pretrained generator checkpoint plus the latent-direction files for age and
# gender. All paths are relative, so they resolve against the working directory.
WEIGHTS = "weights/stylegan2_ffhq.pkl"
AGE_V = "weights/age_boundary.npy"
GENDER_V = "weights/gender_boundary.npy"

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def _seed_to_latent(engine, seed):
    """Return the W latent that ``engine.G.mapping`` produces for a fixed integer seed."""
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, engine.G.z_dim)).to(DEVICE)
    with torch.no_grad():
        return engine.G.mapping(z, None)


def check_model():
    """Smoke-test the StyleGAN2 aging engine and report results on stdout.

    Steps:
      1. Confirm the generator weights exist (fatal if missing) and warn about
         any missing age/gender boundary file.
      2. Construct a ``StyleGAN2AgingEngine`` and load the model.
      3. Generate faces for seeds 100, 200 and 300 at age 0 / gender 0 and save
         them side by side to ``test_generator_output.png``.
      4. Re-render the first seed at age offsets -5.0 and +5.0 (gender 0),
         check that the two images differ, and save a young | neutral | old
         strip to ``test_age_shift.png``.

    Returns:
        True if every check passed; False if the weights are missing, the age
        offset had no effect, or any exception was raised along the way.
    """
    import traceback

    print("--- Diagnostic Check: StyleGAN2 Model ---")
    if not os.path.exists(WEIGHTS):
        print("FAIL: Model weights missing.")
        return False
    for boundary_path in (AGE_V, GENDER_V):
        if not os.path.exists(boundary_path):
            # Not treated as fatal here, but it makes a later engine failure
            # much easier to diagnose than the bare exception message alone.
            print(f"WARNING: Latent boundary file missing: {boundary_path}")

    try:
        engine = StyleGAN2AgingEngine(WEIGHTS, AGE_V, GENDER_V, device=DEVICE)
        engine.load_model()
        print("SUCCESS: Model loaded into memory.")

        # Test 1: Generate 3 random faces (Identity check)
        print("Testing Generator (3 seeds)...")
        latents = []
        results = []
        for seed in (100, 200, 300):
            w = _seed_to_latent(engine, seed)
            with torch.no_grad():
                img = engine.generate_at_age(w, 0, 0)  # Neutral age/gender
            latents.append(w)
            results.append(img)
            print(f" - Seed {seed} generated successfully.")
        # Side-by-side grid for visual inspection.
        # Assumes generate_at_age returns equal-height HxW(xC) arrays that
        # Image.fromarray accepts — TODO confirm against the engine.
        Image.fromarray(np.concatenate(results, axis=1)).save("test_generator_output.png")
        print("SUCCESS: Diagnostic image saved to 'test_generator_output.png'.")

        # Test 2: Verify Latent Math (Age Vector check), reusing seed 100's latent.
        print("Testing Latent Math (Age shift)...")
        w = latents[0]
        with torch.no_grad():
            # Gender is pinned to 0 so both renders match the neutral
            # reference from Test 1 in everything except age.
            img_young = engine.generate_at_age(w, -5.0, 0)
            img_old = engine.generate_at_age(w, 5.0, 0)
        shifted = not np.array_equal(img_young, img_old)
        if shifted:
            print("SUCCESS: Latent age vector shifts the image.")
        else:
            print("FAIL: Latent age vector has NO effect.")
        # young | neutral | old strip, saved either way as visual evidence.
        Image.fromarray(np.concatenate([img_young, results[0], img_old], axis=1)).save("test_age_shift.png")
        return shifted
    except Exception as e:
        # Top-level diagnostic boundary: report the failure with a traceback
        # instead of crashing, and signal it through the return value.
        print(f"CRITICAL FAIL: {e}")
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Make the process exit status reflect the diagnostic so the script can
    # gate CI jobs or container health checks. Only an explicit False result
    # is treated as failure; any other result keeps the default zero status.
    sys.exit(1 if check_model() is False else 0)