|
|
|
|
|
import time |
|
|
import argparse |
|
|
import torch |
|
|
from diffusers import FluxPipeline |
|
|
|
|
|
def benchmark_load_lora(
    base_model: str,
    lora_source: str,
    weight_name: str | None = None,
    adapter_name: str | None = None,
    dtype=torch.bfloat16,
    runs: int = 3,
):
    """Benchmark repeated ``load_lora_weights`` calls on a Flux pipeline.

    Loads the base model once, moves it to GPU (or CPU fallback), then times
    ``runs`` consecutive LoRA load/unload cycles and reports per-run and
    average timings.

    Args:
        base_model: Hugging Face repo ID or local path of the base Flux model.
        lora_source: LoRA adapter repo ID or local folder / file path.
        weight_name: Optional specific weights file inside ``lora_source``
            (forwarded to ``load_lora_weights``).
        adapter_name: Name to register the adapter under; defaults to "lora".
        dtype: Torch dtype for the pipeline weights.
        runs: Number of timed load iterations; must be >= 1.

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Benchmarking on device {device}, torch.cuda.device_count()={torch.cuda.device_count()}.")

    def _sync() -> None:
        # torch.cuda.synchronize raises on CPU-only hosts, but the code above
        # deliberately falls back to CPU — so only synchronize on CUDA.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    print("1/4. Loading base Flux.1-dev model …")
    t0 = time.time()
    pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=dtype, use_safetensors=True)
    base_load_s = time.time() - t0
    print(f" Base model loaded in {base_load_s:.3f} s")

    print("2/4. Moving pipeline to GPU …")
    t1 = time.time()
    pipe = pipe.to(device)
    _sync()
    move_s = time.time() - t1
    print(f" to('cuda') took {move_s:.3f} s")

    # Honor the caller's adapter_name instead of silently clobbering it.
    if adapter_name is None:
        adapter_name = "lora"

    durations: list[float] = []
    for i in range(runs):
        print(f"3.{i+1}/4. Running load_lora_weights (run {i+1}/{runs}) …")
        start = time.time()
        # Forward weight_name only when given, so a plain repo/folder source
        # keeps diffusers' auto-detection behavior.
        lora_kwargs = {"adapter_name": adapter_name}
        if weight_name is not None:
            lora_kwargs["weight_name"] = weight_name
        pipe.load_lora_weights(lora_source, **lora_kwargs)
        _sync()
        duration = time.time() - start
        durations.append(duration)
        print(f" → run {i+1}: load_lora_weights took {duration:.3f} s")

        # Unload between runs so each iteration measures a fresh load; keep
        # the adapter attached after the final run.
        if i < runs - 1:
            print(" Unloading LoRA …")
            pipe.unload_lora_weights(reset_to_overwritten_params=True)
            _sync()

    print("All runs complete.")
    # Report a true mean — the original printed the last run's time as the
    # "average", which misrepresents multi-run results.
    avg = sum(durations) / len(durations)
    print(f"☆ Final run time: {durations[-1]:.3f} s")
    print(f"― average over {runs} runs ≈ {avg:.3f} s")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark Flux.1‑dev load_lora_weights timing"
    )
    parser.add_argument("--model", default="black-forest-labs/FLUX.1-dev")
    parser.add_argument("--lora", required=True, help="LoRA adapter repo ID or local folder / file path")
    # Expose the function's pass-through parameters on the CLI; defaults of
    # None preserve the previous behavior exactly.
    parser.add_argument("--weight-name", default=None,
                        help="Specific weights file inside the LoRA repo/folder")
    parser.add_argument("--adapter-name", default=None,
                        help="Name to register the LoRA adapter under")
    parser.add_argument("--runs", type=int, default=3)
    args = parser.parse_args()

    benchmark_load_lora(
        base_model=args.model,
        lora_source=args.lora,
        weight_name=args.weight_name,
        adapter_name=args.adapter_name,
        runs=args.runs,
    )
|
|
|