# ideogram-4-sdnq-uint4 / benchmark / followup_runner.py
# Author: WaveCut
# Commit message: Add RTX 4090 SDNQ vs NF4 follow-up benchmark
# Commit: 98ad5d3 (verified)
from __future__ import annotations
import argparse
import csv
import gc
import json
import os
import shutil
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Any, Callable
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image, ImageDraw, ImageFont
from ideogram4 import Ideogram4Pipeline, Ideogram4PipelineConfig, PRESETS
# Hugging Face Hub repo ids of the two checkpoints this script compares.
SDNQ_REPO = "WaveCut/ideogram-4-sdnq-uint4"
NF4_REPO = "ideogram-ai/ideogram-4-nf4"
# Dtype handed to both pipeline loaders in load_pipeline().
DTYPE = torch.bfloat16
def read_json(path: Path) -> Any:
    """Decode the UTF-8 JSON document stored at ``path`` and return it."""
    with path.open("r", encoding="utf-8") as source:
        decoded = json.load(source)
    return decoded
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as 2-space-indented UTF-8 JSON.

    Missing parent directories are created first, non-ASCII characters are
    written verbatim, and the file ends with a single trailing newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.write(json.dumps(payload, ensure_ascii=False, indent=2))
        sink.write("\n")
def prompt_to_string(prompt_case: dict[str, Any]) -> str:
    """Return the case's ``caption`` value as compact JSON text.

    Separators carry no whitespace and non-ASCII characters are kept as-is.
    Raises ``KeyError`` when the case has no ``caption`` entry.
    """
    caption = prompt_case["caption"]
    tight_separators = (",", ":")
    return json.dumps(caption, separators=tight_separators, ensure_ascii=False)
def current_gpu_mb() -> int | None:
    """Return the largest ``memory.used`` value reported by ``nvidia-smi``.

    One reading is parsed per non-blank output line and the maximum is
    returned. Any failure (tool missing, timeout after 5 seconds, unparsable
    or empty output) yields ``None`` so callers can treat the metric as
    best-effort.
    """
    query = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]
    try:
        raw = subprocess.check_output(query, text=True, timeout=5)
        readings = [int(entry) for entry in map(str.strip, raw.splitlines()) if entry]
        return max(readings)
    except Exception:
        # Deliberately broad: a broken probe must never abort a benchmark run.
        return None
class GpuPeakMonitor:
    """Background sampler that records GPU memory readings while work runs.

    ``start()`` launches a daemon thread that repeatedly calls
    ``current_gpu_mb()`` and stores every non-``None`` reading in
    ``samples``; ``stop()`` ends the sampling and returns the largest reading.
    """

    def __init__(self, interval: float = 0.05) -> None:
        """Create an idle monitor that pauses ``interval`` seconds between samples."""
        self.interval = interval
        self.samples: list[int] = []
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Discard earlier samples and begin sampling on a new daemon thread."""
        self.samples = []
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self) -> int | None:
        """Stop sampling and return the peak reading, or ``None`` if there is none.

        Waits at most two seconds for the sampler thread to finish.
        """
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=2)
        # Snapshot first: if the join timed out the thread may still append.
        readings = list(self.samples)
        return max(readings) if readings else None

    def _run(self) -> None:
        """Sampling loop executed on the background thread."""
        while not self._stop.is_set():
            value = current_gpu_mb()
            if value is not None:
                self.samples.append(value)
            # Wait on the event rather than sleeping so stop() wakes the loop
            # immediately instead of after a full interval.
            self._stop.wait(self.interval)
def cuda_cleanup() -> None:
    """Run garbage collection and, when CUDA is available, reset its memory state.

    On CUDA this empties the cache, resets the peak-memory statistics and then
    synchronizes, in that order.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    cuda = torch.cuda
    cuda.empty_cache()
    cuda.reset_peak_memory_stats()
    cuda.synchronize()
def measure(name: str, fn: Callable[[], Any], extra: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
    """Run ``fn`` once and return its result together with a metrics row.

    Args:
        name: Value stored under the row's ``name`` key.
        fn: Zero-argument callable to time.
        extra: Optional entries merged into the row after the measured keys
            (and able to override them).

    Returns:
        ``(result, row)`` where ``row`` holds ``elapsed_seconds``, the
        ``current_gpu_mb()`` readings taken before and after the call, the
        peak seen by a ``GpuPeakMonitor`` during the call, and torch's peak
        allocated/reserved memory in MiB (``None`` without CUDA).

    Raises:
        Whatever ``fn`` raises; the monitor thread is stopped first.
    """
    # cuda_cleanup() resets torch's peak statistics, so the torch peaks read
    # below cover only this call.
    cuda_cleanup()
    has_cuda = torch.cuda.is_available()
    before = current_gpu_mb()
    monitor = GpuPeakMonitor()
    monitor.start()
    start = time.perf_counter()
    try:
        result = fn()
        if has_cuda:
            # Wait for queued kernels so the elapsed time covers the real work.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    finally:
        # Always stop the sampler; previously a failing fn() left the daemon
        # thread running and polling nvidia-smi for the rest of the process.
        nvidia_peak = monitor.stop()
    after = current_gpu_mb()
    row = {
        "name": name,
        "elapsed_seconds": elapsed,
        "gpu_before_mb": before,
        "gpu_after_mb": after,
        "gpu_peak_mb": nvidia_peak,
        "torch_peak_allocated_mb": (
            torch.cuda.max_memory_allocated() / 1024 / 1024 if has_cuda else None
        ),
        "torch_peak_reserved_mb": (
            torch.cuda.max_memory_reserved() / 1024 / 1024 if has_cuda else None
        ),
    }
    if extra:
        row.update(extra)
    return result, row
def append_jsonl(path: Path, row: dict[str, Any]) -> None:
    """Append ``row`` to ``path`` as one line of JSON.

    Parent directories are created when missing. Values JSON cannot encode
    natively are stringified via ``default=str``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as sink:
        encoded = json.dumps(row, ensure_ascii=False, default=str)
        sink.write(f"{encoded}\n")
def write_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as CSV with a header row.

    Columns are the union of all row keys in first-seen order. Nothing is
    written (and no directory is created) when ``rows`` is empty.
    """
    if not rows:
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    # A dict keeps insertion order, so this records each key at the position
    # where it first appeared across the rows.
    columns: dict[str, None] = {}
    for record in rows:
        columns.update(dict.fromkeys(record))
    with path.open("w", encoding="utf-8", newline="") as sink:
        table = csv.DictWriter(sink, fieldnames=list(columns))
        table.writeheader()
        table.writerows(rows)
def load_prompts(path: Path) -> list[dict[str, Any]]:
    """Load the prompt cases from ``path``, or from the Hub when it is missing.

    When ``path`` does not exist, ``prompts.json`` is downloaded from
    ``SDNQ_REPO`` and parsed instead.
    """
    if path.exists():
        source = path
    else:
        source = Path(hf_hub_download(SDNQ_REPO, filename="prompts.json"))
    return read_json(source)
def ensure_sdnq_helper() -> None:
    """Download the SDNQ pipeline helper module and make it importable.

    Fetches ``ideogram4_sdnq_pipeline.py`` from ``SDNQ_REPO`` and puts the
    directory holding the downloaded file at the front of ``sys.path``.
    """
    local_copy = hf_hub_download(SDNQ_REPO, filename="ideogram4_sdnq_pipeline.py")
    helper_dir = Path(local_copy).parent
    sys.path.insert(0, str(helper_dir))
def load_pipeline(variant: str, device: str):
    """Build the pipeline for ``variant`` (``"sdnq"`` or ``"nf4"``) on ``device``.

    Both variants are loaded with ``DTYPE``. The SDNQ branch first fetches its
    helper module and imports the pipeline class lazily.

    Raises:
        ValueError: ``variant`` is neither ``"sdnq"`` nor ``"nf4"``.
    """
    if variant == "nf4":
        return Ideogram4Pipeline.from_pretrained(
            config=Ideogram4PipelineConfig(weights_repo=NF4_REPO),
            device=device,
            dtype=DTYPE,
        )
    if variant != "sdnq":
        raise ValueError(f"unknown variant: {variant}")

    # The helper module only becomes importable after it has been downloaded.
    ensure_sdnq_helper()
    from ideogram4_sdnq_pipeline import Ideogram4SDNQPipeline

    return Ideogram4SDNQPipeline.from_pretrained(
        SDNQ_REPO,
        device=device,
        dtype=DTYPE,
        use_quantized_matmul=False,
        dequantize_fp32=False,
    )
def command_generate(args: argparse.Namespace) -> None:
    """Handle the ``generate`` sub-command: render every prompt for one variant.

    Loads the pipeline selected by ``args.variant`` (timed via ``measure``),
    then renders each prompt case, saving the image under
    ``<output_dir>/images`` and appending one metrics row per step to
    ``<variant>_metrics.jsonl``. Any previous JSONL file is removed first.
    All rows are finally written to ``<variant>_metrics.csv``.
    """
    variant = args.variant
    out_root = Path(args.output_dir)
    images_root = out_root / "images"
    images_root.mkdir(parents=True, exist_ok=True)

    # Start from a clean log so rows from an earlier run are not mixed in.
    jsonl_path = out_root / f"{variant}_metrics.jsonl"
    if jsonl_path.exists():
        jsonl_path.unlink()

    cases = load_prompts(Path(args.prompts))
    preset = PRESETS[args.preset]

    pipe, load_metrics = measure(
        f"{variant}_load",
        lambda: load_pipeline(variant, args.device),
        {"variant": variant, "hardware": args.hardware, "preset": args.preset},
    )
    append_jsonl(jsonl_path, load_metrics)
    collected = [load_metrics]

    for position, case in enumerate(cases):
        caption = prompt_to_string(case)
        # Per-case values win; otherwise fall back to the index / CLI defaults.
        seed = int(case.get("seed", position))
        height = int(case.get("height", args.height))
        width = int(case.get("width", args.width))

        # Key order is kept as-is: it determines the JSONL/CSV column order.
        details = {
            "variant": variant,
            "hardware": args.hardware,
            "case_id": case["id"],
            "case_index": position,
            "seed": seed,
            "height": height,
            "width": width,
            "preset": args.preset,
            # Only the very first request is labelled cold.
            "request_temperature": "cold" if position == 0 else "hot",
        }
        # The lambda is invoked inside measure() before the loop advances, so
        # capturing the loop variables is safe.
        image, metrics = measure(
            f"{variant}_generate",
            lambda: pipe(
                caption,
                height=height,
                width=width,
                num_steps=preset.num_steps,
                guidance_schedule=preset.guidance_schedule,
                mu=preset.mu,
                std=preset.std,
                seed=seed,
                raise_on_caption_issues=False,
            )[0],
            details,
        )

        destination = images_root / f"{position + 1:02d}_{case['id']}_{variant}.png"
        image.save(destination)
        metrics["image"] = str(destination)
        append_jsonl(jsonl_path, metrics)
        collected.append(metrics)
        print(json.dumps(metrics, ensure_ascii=False, default=str), flush=True)

    write_csv(out_root / f"{variant}_metrics.csv", collected)
def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Parse a JSON-lines file into a list, skipping blank lines.

    Returns an empty list when ``path`` does not exist.
    """
    if not path.exists():
        return []
    records = []
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        if raw_line.strip():
            records.append(json.loads(raw_line))
    return records
def summarize_variant(rows: list[dict[str, Any]], variant: str) -> dict[str, Any]:
    """Condense one variant's metric rows into a single summary record.

    Uses the first ``<variant>_load`` row, the first ``<variant>_generate``
    row marked ``cold``, and all generate rows marked ``hot``. Aggregates
    ignore values that are ``None`` or ``""``; an aggregate with no usable
    values is ``None``. ``cases`` counts every generate row.
    """
    load_name = f"{variant}_load"
    generate_name = f"{variant}_generate"

    def numeric(key: str, items: list[dict[str, Any]]) -> list[float]:
        # Keep only the rows that actually carry a value for this metric.
        return [float(item[key]) for item in items if item.get(key) not in (None, "")]

    def mean(key: str, items: list[dict[str, Any]]) -> float | None:
        values = numeric(key, items)
        if not values:
            return None
        return sum(values) / len(values)

    def maxv(key: str, items: list[dict[str, Any]]) -> float | None:
        values = numeric(key, items)
        if not values:
            return None
        return max(values)

    load = next((row for row in rows if row.get("name") == load_name), {})
    gens = [row for row in rows if row.get("name") == generate_name]
    cold = next((row for row in gens if row.get("request_temperature") == "cold"), {})
    hot = [row for row in gens if row.get("request_temperature") == "hot"]

    summary: dict[str, Any] = {"variant": variant}
    summary["load_seconds"] = load.get("elapsed_seconds")
    summary["load_peak_reserved_mb"] = load.get("torch_peak_reserved_mb")
    summary["load_peak_nvidia_mb"] = load.get("gpu_peak_mb")
    summary["cold_request_seconds"] = cold.get("elapsed_seconds")
    summary["cold_request_peak_reserved_mb"] = cold.get("torch_peak_reserved_mb")
    summary["cold_request_peak_nvidia_mb"] = cold.get("gpu_peak_mb")
    summary["hot_request_mean_seconds"] = mean("elapsed_seconds", hot)
    summary["hot_request_max_seconds"] = maxv("elapsed_seconds", hot)
    summary["generation_peak_reserved_mb"] = maxv("torch_peak_reserved_mb", gens)
    summary["generation_peak_nvidia_mb"] = maxv("gpu_peak_mb", gens)
    summary["cases"] = len(gens)
    return summary
def fmt(value: Any) -> str:
    """Render a table cell: blanks stay blank, text passes through, numbers get two decimals."""
    if value is None:
        return ""
    if isinstance(value, str):
        # Covers the empty string as well, which is returned unchanged.
        return value
    return format(float(value), ".2f")
def markdown_table(rows: list[dict[str, Any]], keys: list[tuple[str, str]]) -> str:
    """Build a pipe-delimited Markdown table.

    ``keys`` is a list of ``(column label, row key)`` pairs. Each row becomes
    one line whose cells are ``fmt(row.get(key))``; missing keys render blank.
    """

    def as_line(cells) -> str:
        return "| " + " | ".join(cells) + " |"

    lines = [
        as_line(label for label, _ in keys),
        as_line("---" for _ in keys),
    ]
    for row in rows:
        lines.append(as_line(fmt(row.get(field)) for _, field in keys))
    return "\n".join(lines)
def load_font(size: int) -> ImageFont.ImageFont:
    """Return the first loadable TrueType font at ``size``, else PIL's default font.

    Tries DejaVu Sans and then Liberation Sans at fixed system paths.
    """
    candidates = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/truetype/liberation2/LiberationSans-Regular.ttf",
    )
    for candidate in candidates:
        try:
            font = ImageFont.truetype(candidate, size)
        except Exception:
            # Font missing or unreadable on this machine: try the next one.
            continue
        return font
    return ImageFont.load_default()
def draw_centered(draw: ImageDraw.ImageDraw, xy: tuple[int, int, int, int], text: str, font: ImageFont.ImageFont, fill: tuple[int, int, int]) -> None:
    """Draw ``text`` centered inside the rectangle ``xy``.

    Args:
        draw: Drawing context; its ``textbbox`` and ``text`` methods are used.
        xy: Target rectangle as ``(left, top, right, bottom)``.
        text: String to render.
        font: Font passed to both ``textbbox`` and ``text``.
        fill: Text colour.
    """
    left, top, right, bottom = xy
    ink_left, ink_top, ink_right, ink_bottom = draw.textbbox((0, 0), text, font=font)
    ink_w = ink_right - ink_left
    ink_h = ink_bottom - ink_top
    # The bounding box measured at the origin usually does not start at
    # (0, 0). Subtracting its left/top offset centres the rendered glyphs;
    # without it the text lands right of and below centre.
    x = left + (right - left - ink_w) // 2 - ink_left
    y = top + (bottom - top - ink_h) // 2 - ink_top
    draw.text((x, y), text, font=font, fill=fill)
def make_side_by_side_matrix(
    left_images: list[Path],
    right_images: list[Path],
    left_label: str,
    right_label: str,
    output_path: Path,
) -> None:
    """Compose a two-column comparison sheet and save it as a WEBP file.

    Row ``i`` shows ``left_images[i]`` next to ``right_images[i]`` on an
    8192x8192 canvas with a labelled header band. Each image is shrunk to fit
    its cell and captioned with its row number and a label taken from the
    file name.

    Args:
        left_images: Image paths for the left column.
        right_images: Image paths for the right column, same length as left.
        left_label: Header text for the left column.
        right_label: Header text for the right column.
        output_path: Destination file; parent directories are created.

    Raises:
        ValueError: The lists differ in length, or both are empty.
    """
    if len(left_images) != len(right_images):
        raise ValueError("left and right image counts differ")
    count = len(left_images)
    if count == 0:
        # Fail with a clear message instead of the ZeroDivisionError the
        # row-height computation below would raise.
        raise ValueError("no images to arrange")
    canvas_size = 8192
    header_h = 160
    row_h = (canvas_size - header_h) // count
    col_w = canvas_size // 2
    # Largest square that fits a cell, minus a small margin.
    tile = min(col_w, row_h) - 18
    bg = (18, 18, 18)
    line = (58, 58, 58)
    canvas = Image.new("RGB", (canvas_size, canvas_size), bg)
    draw = ImageDraw.Draw(canvas)
    header_font = load_font(82)
    label_font = load_font(36)
    # Header band with one centred title per column and a vertical divider.
    draw.rectangle((0, 0, canvas_size, header_h), fill=(28, 28, 28))
    draw_centered(draw, (0, 0, col_w, header_h), left_label, header_font, (245, 245, 245))
    draw_centered(draw, (col_w, 0, canvas_size, header_h), right_label, header_font, (245, 245, 245))
    draw.line((col_w, 0, col_w, canvas_size), fill=line, width=3)
    for idx, (left_path, right_path) in enumerate(zip(left_images, right_images)):
        y = header_h + idx * row_h
        draw.line((0, y, canvas_size, y), fill=line, width=1)
        for col, path in enumerate([left_path, right_path]):
            with Image.open(path) as img:
                img = img.convert("RGB")
                # thumbnail() shrinks in place and keeps the aspect ratio.
                img.thumbnail((tile, tile), Image.Resampling.LANCZOS)
                x0 = col * col_w
                px = x0 + (col_w - img.width) // 2
                py = y + (row_h - img.height) // 2
                canvas.paste(img, (px, py))
            # Drop everything up to the first "_" and from the last "_" on,
            # i.e. the leading index and trailing variant of the file stem.
            label = path.stem.split("_", 1)[-1].rsplit("_", 1)[0]
            draw.text((col * col_w + 28, y + 16), f"{idx + 1:02d} {label}", font=label_font, fill=(230, 230, 230))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    canvas.save(output_path, "WEBP", quality=95, method=6)
def command_collect(args: argparse.Namespace) -> None:
    """Handle the ``collect`` sub-command: gather both variants' results for publishing.

    Reads the SDNQ and NF4 metric logs under ``args.results_dir``, then writes
    into ``args.publish_dir``: a JSON summary, the side-by-side image sheet,
    renamed copies of whichever metric files exist, and a ``README_APPEND.md``
    holding the summary table. The table and the sheet path are printed.
    """
    results_root = Path(args.results_dir)
    publish_root = Path(args.publish_dir)
    publish_root.mkdir(parents=True, exist_ok=True)

    sdnq_rows = read_jsonl(results_root / "sdnq" / "sdnq_metrics.jsonl")
    nf4_rows = read_jsonl(results_root / "nf4" / "nf4_metrics.jsonl")
    summaries = [
        summarize_variant(sdnq_rows, "sdnq"),
        summarize_variant(nf4_rows, "nf4"),
    ]
    write_json(publish_root / "summary_4090_sdnq_vs_nf4.json", summaries)

    # Sorting by file name lines the two columns up via the numeric prefix.
    sdnq_images = sorted((results_root / "sdnq" / "images").glob("*_sdnq.png"))
    nf4_images = sorted((results_root / "nf4" / "images").glob("*_nf4.png"))
    matrix_path = publish_root / "sdnq_vs_nf4_4090_side_by_side.webp"
    make_side_by_side_matrix(sdnq_images, nf4_images, "SDNQ UInt4", "Official NF4", matrix_path)

    # Publish the raw metric files under names that carry the hardware tag.
    metric_files = (
        "sdnq/sdnq_metrics.jsonl",
        "sdnq/sdnq_metrics.csv",
        "nf4/nf4_metrics.jsonl",
        "nf4/nf4_metrics.csv",
    )
    for relative in metric_files:
        origin = results_root / relative
        if not origin.exists():
            continue
        published_name = origin.name.replace("_metrics", "_4090_metrics")
        shutil.copy2(origin, publish_root / published_name)

    columns = [
        ("Variant", "variant"),
        ("Cases", "cases"),
        ("Load s", "load_seconds"),
        ("Load peak reserved MB", "load_peak_reserved_mb"),
        ("Load peak nvidia MB", "load_peak_nvidia_mb"),
        ("Cold request s", "cold_request_seconds"),
        ("Hot mean s", "hot_request_mean_seconds"),
        ("Hot max s", "hot_request_max_seconds"),
        ("Gen peak reserved MB", "generation_peak_reserved_mb"),
        ("Gen peak nvidia MB", "generation_peak_nvidia_mb"),
    ]
    table = markdown_table(summaries, columns)

    readme_text = f"""## RTX 4090 Follow-up: SDNQ UInt4 vs Official NF4
Hardware: RunPod NVIDIA GeForce RTX 4090, 24 GB VRAM, single process, concurrency 1. Both variants used the same 10 structured captions from `prompts.json`, 1024x1024, `V4_DEFAULT_20`, and no magic-prompt expansion. `nf4` uses the official `ideogram-ai/ideogram-4-nf4` checkpoint through the upstream `ideogram4` loader.
{table}
![SDNQ vs official NF4 on RTX 4090](assets/sdnq_vs_nf4_4090_side_by_side.webp)
"""
    (publish_root / "README_APPEND.md").write_text(readme_text, encoding="utf-8")

    print(table)
    print(matrix_path)
def main() -> None:
    """Parse the command line and dispatch to the ``generate`` or ``collect`` handler."""
    cli = argparse.ArgumentParser()
    commands = cli.add_subparsers(dest="command", required=True)

    # generate: run one variant over the prompt set and record metrics.
    generate = commands.add_parser("generate")
    generate.add_argument("--variant", choices=["sdnq", "nf4"], required=True)
    generate.add_argument("--prompts", default="/workspace/ideogram4_followup/prompts.json")
    generate.add_argument("--output-dir", required=True)
    generate.add_argument("--device", default="cuda")
    generate.add_argument("--height", type=int, default=1024)
    generate.add_argument("--width", type=int, default=1024)
    generate.add_argument("--preset", default="V4_DEFAULT_20", choices=sorted(PRESETS))
    generate.add_argument("--hardware", default="NVIDIA GeForce RTX 4090")
    generate.set_defaults(func=command_generate)

    # collect: merge both variants' outputs into the publish directory.
    gather = commands.add_parser("collect")
    gather.add_argument("--results-dir", default="/workspace/ideogram4_followup/results")
    gather.add_argument("--publish-dir", default="/workspace/ideogram4_followup/publish")
    gather.set_defaults(func=command_collect)

    parsed = cli.parse_args()
    # Opt in to the flag unless the caller already set it.
    # NOTE(review): this runs after huggingface_hub was imported at module
    # load; confirm the library reads the variable lazily and not at import.
    os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
    parsed.func(parsed)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()