# LightDiffusion-Next: tests/e2e/test_core_functionalities.py
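"""End-to-end tests for the core LightDiffusion-Next functionality.

Covers the default text-to-image pipeline, samplers, schedulers, optimizations
(Stable-Fast, multiscale diffusion, DeepCache), img2img, the HTTP API, and a
small optimization benchmark.

All tests are marked slow, so run them explicitly, e.g.:

    pytest tests/e2e/test_core_functionalities.py -m slow

Individual tests can also be driven from the command line via main(), e.g.:

    python tests/e2e/test_core_functionalities.py --normal --samplers
"""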
import argparse
import os
import time
import sys
import pytest
from pathlib import Path
# Add the project root to the Python path
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))
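# pytestmark applies the "slow" marker to every test in this module, so the whole
# file can be deselected with `pytest -m "not slow"`.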
pytestmark = pytest.mark.slow
from src.user.pipeline import pipeline
def get_absolute_path(relative_path):
return os.path.join(project_root, relative_path)
def run_test(test_function, *args, **kwargs):
"""Decorator to time and print test information."""
print(f"\n--- Running test: {test_function.__name__} ---")
start_time = time.perf_counter()
try:
test_function(*args, **kwargs)
end_time = time.perf_counter()
print(f"--- Test {test_function.__name__} finished in {end_time - start_time:.2f} seconds ---")
except Exception as e:
end_time = time.perf_counter()
print(f"--- Test {test_function.__name__} failed after {end_time - start_time:.2f} seconds ---")
print(f"Error: {e}")
@pytest.mark.slow
def test_normal_pipeline():
"""Tests the default text-to-image pipeline."""
print("Testing normal pipeline with default settings...")
pipeline(
prompt="a beautiful landscape",
w=512,
h=512,
number=1,
batch=1,
)
@pytest.mark.slow
def test_samplers():
"""Tests all available samplers."""
samplers = ["euler", "euler_ancestral", "euler_cfgpp", "euler_ancestral_cfgpp", "dpmpp_2m_cfgpp", "dpmpp_sde_cfgpp"]
for sampler in samplers:
print(f"Testing sampler: {sampler}...")
pipeline(
prompt=f"a beautiful landscape using {sampler} sampler",
w=512,
h=512,
number=1,
batch=1,
sampler=sampler,
)
@pytest.mark.slow
def test_schedulers():
"""Tests all available schedulers."""
schedulers = ["normal", "karras", "simple", "beta", "ays", "ays_sd15", "ays_sdxl"]
for scheduler in schedulers:
print(f"Testing scheduler: {scheduler}...")
pipeline(
prompt=f"a beautiful landscape using {scheduler} scheduler",
w=512,
h=512,
number=1,
batch=1,
scheduler=scheduler,
)
@pytest.mark.slow
def test_optimizations():
"""Tests various optimizations."""
def time_pipeline(description, **kwargs):
print(f"Testing {description}...")
start_time = time.perf_counter()
pipeline(
prompt="a beautiful landscape",
w=512,
h=512,
number=1,
batch=1,
**kwargs
)
end_time = time.perf_counter()
duration = end_time - start_time
print(f"{duration:.2f}")
return duration
baseline_time = time_pipeline("baseline (no optimizations)")
stable_fast_time = time_pipeline("Stable-Fast optimization", stable_fast=True)
multiscale_time = time_pipeline("multiscale diffusion", enable_multiscale=True, multiscale_preset="performance")
deepcache_time = time_pipeline("DeepCache", deepcache_enabled=True)
print("\n--- Speed Comparison Results ---")
print(f"Baseline time: {baseline_time:.2f}s")
@pytest.mark.slow
def test_img2img():
"""Tests the img2img pipeline."""
print("Testing img2img pipeline...")
dummy_image_path = get_absolute_path("tests/dummy_image.png")
if not os.path.exists(dummy_image_path):
from PIL import Image
        img = Image.new("RGB", (256, 256), color="red")
img.save(dummy_image_path)
pipeline(
prompt="a beautiful landscape",
w=512,
h=512,
number=1,
batch=1,
img2img=True,
img2img_image=dummy_image_path,
)
@pytest.mark.asyncio
@pytest.mark.slow
async def test_api_endpoints(monkeypatch, async_server_client):
"""Tests the API endpoints via the in-process ASGI transport."""
print("Testing /health endpoint via ASGI transport...")
async def fake_enqueue(_pending):
return {"image": "data:image/png;base64,xyz"}
import server
monkeypatch.setattr(server._generation_buffer, "enqueue", fake_enqueue)
# Health endpoint
response = await async_server_client.get("/health")
assert response.status_code == 200
# Test generate endpoint with a tiny steps value
print("Testing /api/generate endpoint via ASGI transport...")
payload = {
"prompt": "a beautiful landscape",
"width": 512,
"height": 512,
"steps": 1,
}
response = await async_server_client.post("/api/generate", json=payload)
assert response.status_code == 200
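# The async_server_client fixture is expected to come from the suite's conftest.
# A minimal sketch of how such a fixture could be wired (an assumption for
# illustration; the actual ASGI app attribute, e.g. server.app, is not confirmed
# by this file):
#
#     import httpx
#     import pytest_asyncio
#     import server
#
#     @pytest_asyncio.fixture
#     async def async_server_client():
#         transport = httpx.ASGITransport(app=server.app)
#         async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
#             yield client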
@pytest.mark.slow
def test_hires_fix():
    """Tests the hires_fix upscaling option."""
pipeline(
prompt="a beautiful landscape with hires_fix",
w=512,
h=512,
number=1,
batch=1,
hires_fix=True,
)
@pytest.mark.slow
def test_adetailer():
    """Tests the adetailer detail-refinement option."""
pipeline(
prompt="a beautiful landscape with adetailer",
w=512,
h=512,
number=1,
batch=1,
adetailer=True,
)
@pytest.mark.slow
def test_enhance_prompt():
    """Tests the enhance_prompt option."""
pipeline(
prompt="a beautiful landscape with enhance_prompt",
w=512,
h=512,
number=1,
batch=1,
enhance_prompt=True,
)
@pytest.mark.slow
def test_reuse_seed():
    """Tests the reuse_seed option."""
pipeline(
prompt="a beautiful landscape with reuse_seed",
w=512,
h=512,
number=1,
batch=1,
reuse_seed=True,
)
@pytest.mark.slow
def test_realistic_model():
    """Tests generation with the realistic_model option."""
pipeline(
prompt="a beautiful landscape with realistic_model",
w=512,
h=512,
number=1,
batch=1,
realistic_model=True,
)
@pytest.mark.slow
def test_all_features():
    """Tests all feature flags enabled together."""
pipeline(
prompt="a beautiful landscape with all features",
w=512,
h=512,
number=1,
batch=1,
hires_fix=True,
adetailer=True,
enhance_prompt=True,
reuse_seed=True,
realistic_model=True,
)
@pytest.mark.slow
def benchmark_optimizations():
    """Benchmarks each optimization against the baseline over multiple runs."""
    import statistics
def benchmark_pipeline(description, runs=3, **kwargs):
times = []
for i in range(runs):
print(f"Benchmarking {description} - run {i+1}/{runs}...")
start_time = time.perf_counter()
pipeline(
prompt="a beautiful landscape",
w=512,
h=512,
number=1,
batch=1,
**kwargs
)
end_time = time.perf_counter()
duration = end_time - start_time
times.append(duration)
print(f"{duration:.2f}")
avg_time = statistics.mean(times)
std_dev = statistics.stdev(times) if len(times) > 1 else 0
print(f"{avg_time:.2f}")
return avg_time, std_dev
print("=== Optimization Benchmark (3 runs each) ===")
baseline_avg, baseline_std = benchmark_pipeline("Baseline (no optimizations)")
optimizations = [
("Stable-Fast", {"stable_fast": True}),
("Multiscale Performance", {"enable_multiscale": True, "multiscale_preset": "performance"}),
("DeepCache", {"deepcache_enabled": True}),
]
results = []
for name, kwargs in optimizations:
avg, std = benchmark_pipeline(name, **kwargs)
results.append((name, avg, std))
print("\n=== Benchmark Results Summary ===")
print(f"{'Optimization':<25} {'Avg Time (s)':<15} {'Std Dev':<10} {'Speedup':<10}")
print("-" * 60)
print(f"{'Baseline':<25} {baseline_avg:<15.2f} {baseline_std:<10.2f} {'1.00x':<10}")
for name, avg, std in results:
speedup = baseline_avg / avg
print(f"{name:<25} {avg:<15.2f} {std:<10.2f} {speedup:<10.2f}x")
print("\nNote: Lower time = better performance. Speedup > 1 means faster than baseline.")
def main():
parser = argparse.ArgumentParser(description="Run LightDiffusion-Next core functionality tests.")
parser.add_argument("--all", action="store_true", help="Run all tests.")
parser.add_argument("--normal", action="store_true", help="Run normal pipeline test.")
parser.add_argument("--samplers", action="store_true", help="Run samplers test.")
parser.add_argument("--schedulers", action="store_true", help="Run schedulers test.")
parser.add_argument("--optimizations", action="store_true", help="Run optimizations test.")
parser.add_argument("--img2img", action="store_true", help="Run img2img test.")
parser.add_argument("--api", action="store_true", help="Run API endpoints test.")
parser.add_argument("--hires_fix", action="store_true", help="Run hires_fix test.")
parser.add_argument("--adetailer", action="store_true", help="Run adetailer test.")
parser.add_argument("--enhance_prompt", action="store_true", help="Run enhance_prompt test.")
parser.add_argument("--reuse_seed", action="store_true", help="Run reuse_seed test.")
parser.add_argument("--realistic_model", action="store_true", help="Run realistic_model test.")
parser.add_argument("--all_features", action="store_true", help="Run all features test.")
parser.add_argument("--benchmark", action="store_true", help="Run optimization benchmark.")
args = parser.parse_args()
if args.all or args.normal:
run_test(test_normal_pipeline)
if args.all or args.samplers:
run_test(test_samplers)
if args.all or args.schedulers:
run_test(test_schedulers)
if args.all or args.optimizations:
run_test(test_optimizations)
if args.all or args.img2img:
run_test(test_img2img)
    if args.all or args.api:
        # test_api_endpoints is an async test that depends on pytest fixtures
        # (monkeypatch, async_server_client), so it cannot be driven from this CLI.
        print("Skipping API endpoint test: run it with pytest instead (it needs the async_server_client fixture).")
if args.all or args.hires_fix:
run_test(test_hires_fix)
if args.all or args.adetailer:
run_test(test_adetailer)
if args.all or args.enhance_prompt:
run_test(test_enhance_prompt)
if args.all or args.reuse_seed:
run_test(test_reuse_seed)
if args.all or args.realistic_model:
run_test(test_realistic_model)
if args.all or args.all_features:
run_test(test_all_features)
if args.benchmark:
run_test(benchmark_optimizations)
if not any(vars(args).values()):
print("No tests selected. Use --all to run all tests or select specific tests.")
if __name__ == "__main__":
main()