|
|
import argparse |
|
|
import functools |
|
|
import os |
|
|
from torch.profiler import record_function |
|
|
|
|
|
|
|
|
def create_parser():
    """Create the command-line argument parser for the FLUX inference script.

    Defaults are shown in ``--help`` output via
    ``argparse.ArgumentDefaultsHelpFormatter``. Option spellings are mixed:
    some use hyphens (``--cache-dir``) and some use underscores
    (``--num_inference_steps``). Either way, the attribute on the parsed
    namespace uses underscores (e.g. ``--cache-dir`` -> ``args.cache_dir``).

    Returns:
        argparse.ArgumentParser: A parser with model, prompt, I/O, profiling
        and optimization-toggle options registered. Calling
        ``parse_args()`` is left to the caller.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # --- Model, prompt and runtime inputs ---
    parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell",
                        help="Model checkpoint path")
    parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn",
                        help="Text prompt")
    # Use expanduser instead of expandvars("$HOME/..."). When HOME is unset,
    # expandvars leaves "$HOME" untouched and the default becomes a relative
    # path containing a literal "$HOME" directory. expanduser falls back to
    # the platform's home lookup, and gives the same string when HOME is set
    # on POSIX.
    parser.add_argument("--cache-dir", type=str, default=os.path.expanduser("~/.cache/flux-fast"),
                        help="Cache directory for storing exported models")
    parser.add_argument("--use-cached-model", action="store_true",
                        help="Attempt to use cached model only (don't re-export)")
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda",
                        help="Device to use")
    parser.add_argument("--num_inference_steps", type=int, default=4,
                        help="Number of denoising steps")
    parser.add_argument("--output-file", type=str, default="output.png",
                        help="Output image file path")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed to use")

    # --- Profiling ---
    # None means no trace file is requested.
    parser.add_argument("--trace-file", type=str, default=None,
                        help="Output PyTorch Profiler trace file path")

    # --- Optimization toggles ---
    # Each of these is a store_true opt-out flag, so the corresponding
    # optimization is on unless the flag is passed.
    parser.add_argument("--disable_bf16", action="store_true",
                        help="Disables usage of torch.bfloat16")
    parser.add_argument("--compile_export_mode", type=str, default="export_aoti",
                        choices=["compile", "export_aoti", "disabled"],
                        help="Configures how torch.compile or torch.export + AOTI are used")
    parser.add_argument("--disable_fused_projections", action="store_true",
                        help="Disables fused q,k,v projections")
    parser.add_argument("--disable_channels_last", action="store_true",
                        help="Disables usage of torch.channels_last memory format")
    parser.add_argument("--disable_fa3", action="store_true",
                        help="Disables use of Flash Attention V3")
    parser.add_argument("--disable_quant", action="store_true",
                        help="Disables usage of dynamic float8 quantization")
    parser.add_argument("--disable-nsfw-filter", action="store_true",
                        help="Disables NSFW content filter (allows all prompts)")
    parser.add_argument("--disable_inductor_tuning_flags", action="store_true",
                        help="Disables use of inductor tuning flags")

    return parser
|
|
|
|
|
|
|
|
|
|
|
def annotate(f, title):
    """Return a wrapper around ``f`` that records each call under ``title``.

    Every invocation of the returned callable runs ``f`` inside a
    ``torch.profiler.record_function(title)`` context, so the call shows up
    as a labelled range in profiler traces. Arguments are forwarded
    unchanged and ``f``'s return value is passed back. ``functools.wraps``
    copies ``f``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...)
    onto the wrapper.

    Note: the range only covers the synchronous call to ``f``. If ``f`` is a
    generator function, the call only creates the generator, so iteration
    happens outside the recorded range.

    Args:
        f: The callable to wrap.
        title: Label used for the profiler range.

    Returns:
        The wrapping callable.
    """
    def traced(*call_args, **call_kwargs):
        with record_function(title):
            result = f(*call_args, **call_kwargs)
        return result

    return functools.wraps(f)(traced)
|
|
|