Spaces:
Running on Zero
Running on Zero
import argparse
import os
import random
import sys
import tempfile
import time
import traceback
from pathlib import Path

import pytest
import requests
# Three directory levels above this file is the repository root; putting it on
# sys.path lets `src.user.pipeline` import even when this file is run directly.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(os.fspath(project_root))

from src.user.pipeline import pipeline  # noqa: E402  (needs the sys.path entry above)

# Module-level pytest marker: every test in this file is tagged `slow`.
pytestmark = pytest.mark.slow
def get_absolute_path(relative_path):
    """Join *relative_path* onto the repository root and return it as a string.

    The path is joined with ``os.path.join`` and is not normalized or
    resolved; an absolute *relative_path* is returned unchanged.
    """
    root = os.fspath(project_root)
    joined = os.path.join(root, relative_path)
    return joined
def run_test(test_function, *args, **kwargs):
    """Run ``test_function(*args, **kwargs)``, timing it and printing the outcome.

    This is a plain helper used by the command-line entry point, not a
    decorator. Any ``Exception`` raised by the test is caught and reported
    together with its full traceback, so one failing test does not stop the
    remaining ones. ``BaseException`` subclasses such as ``KeyboardInterrupt``
    are not caught and still propagate.

    Args:
        test_function: Callable to execute. Its ``__name__`` (or ``repr`` when
            it has none, e.g. a ``functools.partial``) labels the log lines.
        *args: Positional arguments forwarded to ``test_function``.
        **kwargs: Keyword arguments forwarded to ``test_function``.

    Returns:
        ``True`` if the test completed without raising, ``False`` otherwise.
    """
    name = getattr(test_function, "__name__", repr(test_function))
    print(f"\n--- Running test: {name} ---")
    start_time = time.perf_counter()
    try:
        test_function(*args, **kwargs)
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        print(f"--- Test {name} failed after {elapsed:.2f} seconds ---")
        print(f"Error: {e}")
        # The message alone rarely locates a failure deep inside the pipeline;
        # keep the stack trace (printed to stderr).
        traceback.print_exc()
        return False
    elapsed = time.perf_counter() - start_time
    print(f"--- Test {name} finished in {elapsed:.2f} seconds ---")
    return True
def test_normal_pipeline():
    """Smoke-test one 512x512 text-to-image generation with default settings.

    Passes as long as ``pipeline`` does not raise; its return value is not checked.
    """
    print("Testing normal pipeline with default settings...")
    settings = {"w": 512, "h": 512, "number": 1, "batch": 1}
    pipeline(prompt="a beautiful landscape", **settings)
def test_samplers():
    """Run one 512x512 generation for each sampler name listed below.

    The first sampler whose generation raises fails the test. The preceding
    print line shows which sampler it was.
    """
    sampler_names = (
        "euler",
        "euler_ancestral",
        "euler_cfgpp",
        "euler_ancestral_cfgpp",
        "dpmpp_2m_cfgpp",
        "dpmpp_sde_cfgpp",
    )
    for name in sampler_names:
        print(f"Testing sampler: {name}...")
        pipeline(
            prompt=f"a beautiful landscape using {name} sampler",
            sampler=name,
            w=512,
            h=512,
            number=1,
            batch=1,
        )
def test_schedulers():
    """Run one 512x512 generation for each scheduler name listed below.

    The first scheduler whose generation raises fails the test. The preceding
    print line shows which scheduler it was.
    """
    scheduler_names = (
        "normal",
        "karras",
        "simple",
        "beta",
        "ays",
        "ays_sd15",
        "ays_sdxl",
    )
    for name in scheduler_names:
        print(f"Testing scheduler: {name}...")
        pipeline(
            prompt=f"a beautiful landscape using {name} scheduler",
            scheduler=name,
            w=512,
            h=512,
            number=1,
            batch=1,
        )
def test_optimizations():
    """Run the pipeline once per optimization and report the wall-clock times.

    Configurations, in order: baseline (no optimizations), Stable-Fast,
    multiscale diffusion with the "performance" preset, and DeepCache. Each
    is timed with ``time.perf_counter``. The summary prints every measured
    time plus its speedup relative to the baseline. This is a smoke test:
    it fails only if a generation raises, and nothing is asserted about the
    timings. The baseline runs first, so any one-time setup cost the pipeline
    incurs on its first call is included in the baseline figure.
    """

    def time_pipeline(description, **kwargs):
        # Time a single 512x512 generation with the given optimization flags.
        print(f"Testing {description}...")
        start_time = time.perf_counter()
        pipeline(
            prompt="a beautiful landscape",
            w=512,
            h=512,
            number=1,
            batch=1,
            **kwargs,
        )
        duration = time.perf_counter() - start_time
        print(f"{description} took {duration:.2f}s")
        return duration

    baseline_time = time_pipeline("baseline (no optimizations)")
    # List elements are evaluated left to right, so the run order is unchanged.
    optimized_times = [
        ("Stable-Fast", time_pipeline("Stable-Fast optimization", stable_fast=True)),
        (
            "Multiscale diffusion",
            time_pipeline(
                "multiscale diffusion",
                enable_multiscale=True,
                multiscale_preset="performance",
            ),
        ),
        ("DeepCache", time_pipeline("DeepCache", deepcache_enabled=True)),
    ]

    print("\n--- Speed Comparison Results ---")
    print(f"Baseline time: {baseline_time:.2f}s")
    for label, duration in optimized_times:
        # Guard the division against a zero-duration measurement.
        speedup = baseline_time / duration if duration > 0 else float("inf")
        print(f"{label} time: {duration:.2f}s ({speedup:.2f}x vs baseline)")
def test_img2img():
    """Tests the img2img pipeline.

    Uses ``tests/dummy_image.png`` under the project root when that file
    already exists. Otherwise a solid red 256x256 RGB placeholder is created
    in a temporary directory that is deleted afterwards. The test therefore
    no longer leaves an untracked image behind in the source tree. Passes as
    long as ``pipeline`` does not raise.
    """
    print("Testing img2img pipeline...")
    dummy_image_path = get_absolute_path("tests/dummy_image.png")
    with tempfile.TemporaryDirectory() as tmp_dir:
        if not os.path.exists(dummy_image_path):
            # PIL is only needed when a placeholder has to be generated.
            from PIL import Image

            dummy_image_path = os.path.join(tmp_dir, "dummy_image.png")
            Image.new("RGB", (256, 256), color="red").save(dummy_image_path)
        # Run inside the `with` so the placeholder still exists while the pipeline reads it.
        pipeline(
            prompt="a beautiful landscape",
            w=512,
            h=512,
            number=1,
            batch=1,
            img2img=True,
            img2img_image=dummy_image_path,
        )
async def test_api_endpoints(monkeypatch, async_server_client):
    """Check that /health and /api/generate answer HTTP 200 via the in-process ASGI client.

    ``server._generation_buffer.enqueue`` is replaced with a stub returning a
    fixed data-URL image. Presumably this keeps /api/generate from running a
    real generation; confirm against ``server``. Only status codes are
    asserted, not response bodies.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; the stub is undone after the test.
        async_server_client: async HTTP client fixture for the app. It is not
            defined in this file and is presumably provided by a conftest.
    """
    print("Testing /health endpoint via ASGI transport...")

    async def stub_enqueue(_pending):
        # Ignore the queued request and hand back a canned image payload.
        return {"image": "data:image/png;base64,xyz"}

    import server

    monkeypatch.setattr(server._generation_buffer, "enqueue", stub_enqueue)

    health = await async_server_client.get("/health")
    assert health.status_code == 200

    # steps=1 keeps the request as small as possible.
    print("Testing /api/generate endpoint via ASGI transport...")
    generate_request = {
        "prompt": "a beautiful landscape",
        "width": 512,
        "height": 512,
        "steps": 1,
    }
    generated = await async_server_client.post("/api/generate", json=generate_request)
    assert generated.status_code == 200
def test_hires_fix():
    """Smoke-test one 512x512 generation with ``hires_fix`` enabled."""
    options = dict(w=512, h=512, number=1, batch=1, hires_fix=True)
    pipeline(prompt="a beautiful landscape with hires_fix", **options)
def test_adetailer():
    """Smoke-test one 512x512 generation with ``adetailer`` enabled."""
    options = dict(w=512, h=512, number=1, batch=1, adetailer=True)
    pipeline(prompt="a beautiful landscape with adetailer", **options)
def test_enhance_prompt():
    """Smoke-test one 512x512 generation with ``enhance_prompt`` enabled."""
    options = dict(w=512, h=512, number=1, batch=1, enhance_prompt=True)
    pipeline(prompt="a beautiful landscape with enhance_prompt", **options)
def test_reuse_seed():
    """Smoke-test one 512x512 generation with ``reuse_seed`` enabled."""
    options = dict(w=512, h=512, number=1, batch=1, reuse_seed=True)
    pipeline(prompt="a beautiful landscape with reuse_seed", **options)
def test_realistic_model():
    """Smoke-test one 512x512 generation with ``realistic_model`` enabled."""
    options = dict(w=512, h=512, number=1, batch=1, realistic_model=True)
    pipeline(prompt="a beautiful landscape with realistic_model", **options)
def test_all_features():
    """Smoke-test one 512x512 generation with every optional feature flag enabled together."""
    feature_flags = (
        "hires_fix",
        "adetailer",
        "enhance_prompt",
        "reuse_seed",
        "realistic_model",
    )
    options = {"w": 512, "h": 512, "number": 1, "batch": 1}
    options.update({flag: True for flag in feature_flags})
    pipeline(prompt="a beautiful landscape with all features", **options)
def benchmark_optimizations(runs=3, warmup=1):
    """Benchmark each optimization against the unoptimized baseline and print a table.

    Each configuration first runs ``warmup`` untimed 512x512 generations,
    then ``runs`` timed ones. The summary reports the mean time, the sample
    standard deviation (0 when only one run is timed), and the speedup of
    each optimization relative to the baseline mean.

    Args:
        runs: Number of timed runs per configuration; must be at least 1.
        warmup: Number of untimed runs per configuration before timing. This
            keeps any one-time cost a configuration may pay on its first call
            (e.g. loading or compilation, not visible from this file) out of
            the measurement. Pass 0 for the old warmup-free behaviour.

    Raises:
        ValueError: If ``runs`` < 1 (``statistics.mean`` needs data) or ``warmup`` < 0.
    """
    import statistics

    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    def run_once(**kwargs):
        # One 512x512 generation; shared by the warmup and timed loops.
        pipeline(
            prompt="a beautiful landscape",
            w=512,
            h=512,
            number=1,
            batch=1,
            **kwargs,
        )

    def benchmark_pipeline(description, **kwargs):
        # Return (mean, sample stdev) in seconds over `runs` timed generations.
        for i in range(warmup):
            print(f"Warming up {description} - run {i + 1}/{warmup}...")
            run_once(**kwargs)
        times = []
        for i in range(runs):
            print(f"Benchmarking {description} - run {i + 1}/{runs}...")
            start_time = time.perf_counter()
            run_once(**kwargs)
            duration = time.perf_counter() - start_time
            times.append(duration)
            print(f"{description} run {i + 1}: {duration:.2f}s")
        avg_time = statistics.mean(times)
        std_dev = statistics.stdev(times) if len(times) > 1 else 0.0
        print(f"{description} average: {avg_time:.2f}s")
        return avg_time, std_dev

    print(f"=== Optimization Benchmark ({runs} runs each) ===")
    baseline_avg, baseline_std = benchmark_pipeline("Baseline (no optimizations)")

    optimizations = [
        ("Stable-Fast", {"stable_fast": True}),
        ("Multiscale Performance", {"enable_multiscale": True, "multiscale_preset": "performance"}),
        ("DeepCache", {"deepcache_enabled": True}),
    ]
    results = []
    for name, kwargs in optimizations:
        avg, std = benchmark_pipeline(name, **kwargs)
        results.append((name, avg, std))

    print("\n=== Benchmark Results Summary ===")
    print(f"{'Optimization':<25} {'Avg Time (s)':<15} {'Std Dev':<10} {'Speedup':<10}")
    print("-" * 63)
    print(f"{'Baseline':<25} {baseline_avg:<15.2f} {baseline_std:<10.2f} {'1.00x':<10}")
    for name, avg, std in results:
        # Build "N.NNx" first and pad the whole string, so the suffix stays
        # attached and the column lines up with the baseline's '1.00x'.
        speedup = f"{baseline_avg / avg:.2f}x" if avg > 0 else "inf"
        print(f"{name:<25} {avg:<15.2f} {std:<10.2f} {speedup:<10}")
    print("\nNote: Lower time = better performance. Speedup > 1 means faster than baseline.")
def main():
    """Parse command-line flags and run the selected tests outside of pytest.

    Each selected test runs through ``run_test``, which reports the outcome
    and continues after a failure. ``--all`` selects every test except the
    benchmark, which runs only with ``--benchmark``.

    ``test_api_endpoints`` is an async test that takes the ``monkeypatch``
    and ``async_server_client`` pytest fixtures. Calling it with no arguments
    can only raise TypeError, so selecting it prints the pytest command to
    use instead.
    """
    parser = argparse.ArgumentParser(description="Run LightDiffusion-Next core functionality tests.")
    parser.add_argument("--all", action="store_true", help="Run all tests (except the benchmark).")
    parser.add_argument("--normal", action="store_true", help="Run normal pipeline test.")
    parser.add_argument("--samplers", action="store_true", help="Run samplers test.")
    parser.add_argument("--schedulers", action="store_true", help="Run schedulers test.")
    parser.add_argument("--optimizations", action="store_true", help="Run optimizations test.")
    parser.add_argument("--img2img", action="store_true", help="Run img2img test.")
    parser.add_argument("--api", action="store_true", help="Run API endpoints test.")
    parser.add_argument("--hires_fix", action="store_true", help="Run hires_fix test.")
    parser.add_argument("--adetailer", action="store_true", help="Run adetailer test.")
    parser.add_argument("--enhance_prompt", action="store_true", help="Run enhance_prompt test.")
    parser.add_argument("--reuse_seed", action="store_true", help="Run reuse_seed test.")
    parser.add_argument("--realistic_model", action="store_true", help="Run realistic_model test.")
    parser.add_argument("--all_features", action="store_true", help="Run all features test.")
    parser.add_argument("--benchmark", action="store_true", help="Run optimization benchmark.")
    args = parser.parse_args()

    if not any(vars(args).values()):
        print("No tests selected. Use --all to run all tests or select specific tests.")
        return

    # (argparse dest, test function) in execution order; every entry is part of --all.
    suite = (
        ("normal", test_normal_pipeline),
        ("samplers", test_samplers),
        ("schedulers", test_schedulers),
        ("optimizations", test_optimizations),
        ("img2img", test_img2img),
        ("api", test_api_endpoints),
        ("hires_fix", test_hires_fix),
        ("adetailer", test_adetailer),
        ("enhance_prompt", test_enhance_prompt),
        ("reuse_seed", test_reuse_seed),
        ("realistic_model", test_realistic_model),
        ("all_features", test_all_features),
    )
    for dest, test_function in suite:
        if not (args.all or getattr(args, dest)):
            continue
        if test_function is test_api_endpoints:
            # Needs pytest fixtures (and an event loop); it cannot be called directly.
            print(
                "\n--- Skipping test: test_api_endpoints (needs pytest fixtures; "
                f"run `pytest {__file__}::test_api_endpoints` instead) ---"
            )
            continue
        run_test(test_function)

    if args.benchmark:
        run_test(benchmark_optimizations)
if __name__ == "__main__":
    # Script entry point only; when pytest imports this module, __name__ differs and this is skipped.
    main()