# test_mlx / test_mlx.py
# Last commit by yujuanqin: "save result when KeyboardInterrupt" (d907122)
import json
import re
import subprocess
import csv
from subprocess import CompletedProcess
from test_configs import *
def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
    """Run ``command`` through the shell, echoing the command (and captured output).

    Args:
        command: Full shell command line. It is executed with ``shell=True``, so the
            caller is responsible for quoting its arguments.
        check: If true, raise ``subprocess.CalledProcessError`` on a non-zero exit.
        capture_output: If true, stdout and stderr are merged, captured as text into
            ``ret.stdout`` and echoed; otherwise the child writes directly to this
            process's stdout/stderr and nothing extra is printed.

    Returns:
        The ``CompletedProcess`` returned by ``subprocess.run``.

    Raises:
        subprocess.CalledProcessError: when ``check`` is true and the command fails.
            In capture mode the captured output is printed before re-raising.
    """
    print(command)
    if not capture_output:
        # Output already went to the terminal; ret.stdout would be None here.
        return subprocess.run(command, shell=True, check=check)
    try:
        ret = subprocess.run(command, shell=True, check=check, stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT, universal_newlines=True)
    except subprocess.CalledProcessError as e:
        # Show the failing command's log before propagating; otherwise it is lost.
        print(e.stdout)
        raise
    print(ret.stdout)
    return ret
def parse_log(output):
    """Extract benchmark metrics from the text printed by the generation script.

    Each value is taken from a ``"<label>: <value>"`` occurrence in ``output`` and is
    the rest of that line. Illustrative excerpt (labels as matched below)::

        model: <name>
        steps: <n>
        preload model time: <t>
        load model time: <t>
        save image to: /some/dir/<file>.png
        output image size: (512, 512)

    Args:
        output: Captured stdout/stderr text of one run.

    Returns:
        Tuple of strings in CSV column order: model name, steps, cfg_weight, img size,
        image number, load model time, preload model time, update lora time, quantize
        time, generate image time, total time, output image (last two path components
        only) and output image size formatted as ``"W*H"``.

    Raises:
        ValueError: if any expected label is missing from ``output``.
    """
    def _field(pattern):
        # Return group 1 of the first match; fail loudly instead of with None.group().
        match = re.search(pattern, output)
        if match is None:
            raise ValueError(f"pattern {pattern!r} not found in log output")
        return match.group(1)

    model_name = _field(r"model: (.+)")
    steps = _field(r"steps: (.+)")
    cfg_weight = _field(r"cfg_weight: (.+)")
    img_size = _field(r"img-size: (.+)")
    img_number = _field(r"image number: (.+)")
    # \b stops "load model time" from matching inside "preload model time".
    load_model_time = _field(r"\bload model time: (.+)")
    preload_model = _field(r"preload model time: (.+)")
    update_lora_time = _field(r"update lora time: (.+)")
    quantize_time = _field(r"quantize time: (.+)")
    generate_time = _field(r"generate image time: (.+)")
    total_time = _field(r"total time: (.+)")
    # Keep only "<parent dir>/<file name>" of the saved image path.
    out_image = '/'.join(_field(r"save image to: (.+)").split("/")[-2:])
    # "(W, H)" -> "W*H"
    out_image_size = _field(r"output image size: \((.+)\)").replace(', ', '*')
    return (model_name, steps, cfg_weight, img_size, img_number, load_model_time,
            preload_model, update_lora_time, quantize_time, generate_time, total_time, out_image, out_image_size)
def _get_cmd(prompt, **kwargs):
base_cmd = f'python mlx_app/txt2image_lora.py "{prompt}"'
for k, v in kwargs.items():
if v is True:
base_cmd += f" --{k}"
else:
base_cmd += f" --{k} {v}"
return base_cmd
def test_lora(result):
    """Run each base model with each of its LoRAs and append parsed metrics to ``result``.

    Uses ``base_models``, ``prompts`` and ``output`` from ``test_configs``. For every
    model, every entry of its ``"loras"`` list and every prompt, one command is built
    that loads the LoRA (plus its optional ``lora-scale``) and appends the LoRA's
    trigger words to the prompt. The "no_lora" and "no_trigger" variants are kept
    below as commented-out code.

    Args:
        result: List of CSV rows; one row per successfully parsed run is appended in place.

    Returns:
        ``result``. On Ctrl-C, the rows collected so far are returned so they can be saved.
    """
    commands = {
        "no_lora": [],
        "no_trigger": [],
        "with_trigger": [],
    }
    for model, config in base_models.items():
        # Read "loras" without popping it so the shared base_models entry is left intact
        # (pop made a second call fail with KeyError). The other keys are CLI options.
        loras = config.get("loras", [])
        model_config = {k: v for k, v in config.items() if k != "loras"}
        for lora_cfg in loras:
            trigger_words = lora_cfg.get("trigger_words")
            lora_name = lora_cfg.get("lora").split('/')[-1].split('.')[0]
            for i, p in prompts.items():
                # 1. run with no lora
                # paras = {"model": model, "output": str(output / model / f"{lora_name}-{i}-a_no_lora.png")}
                # paras.update(model_config)
                # commands["no_lora"].append(_get_cmd(p, **paras))
                #
                # # 2. run with lora, but no trigger words
                # paras = {"model": model, "output": str(output / model / f"{lora_name}-{i}-b_no_trigger.png")}
                # paras.update(model_config)
                # paras["lora"] = lora_cfg.get("lora")
                # if lora_cfg.get("lora-scale"):
                #     paras["lora-scale"] = lora_cfg.get("lora-scale")
                # commands["no_trigger"].append(_get_cmd(p, **paras))
                # 3. run with lora, with trigger words
                paras = {"model": model, "output": str(output / f"{model}-{lora_name}-{i}-c_with_trigger.png"),
                         "n_images": 4}
                paras.update(model_config)
                paras["lora"] = lora_cfg.get("lora")
                if lora_cfg.get("lora-scale"):
                    paras["lora-scale"] = lora_cfg.get("lora-scale")
                # Without trigger words, keep the prompt as-is instead of appending ", None".
                prompt = f"{p}, {trigger_words}" if trigger_words else p
                commands["with_trigger"].append(_get_cmd(prompt, **paras))
    for cmds in commands.values():
        for c in cmds:
            try:
                ret = cmd(c, capture_output=True)
                result.append(parse_log(ret.stdout))
            except KeyboardInterrupt:
                # Stop early but hand back what was collected so the caller can save it.
                return result
            except Exception as e:
                print("Exception: ", e)
    return result
def test_base_model(result: list):
    """Run each base model on each prompt (no LoRA) and append parsed metrics to ``result``.

    Uses ``base_models``, ``prompts`` and ``output`` from ``test_configs``. Each run
    requests 4 images with ``decoding_batch_size`` 4. The model's config entries are
    applied last and override those defaults.

    Args:
        result: List of CSV rows; one row per successfully parsed run is appended in place.

    Returns:
        ``result``. On Ctrl-C, the rows collected so far are returned so they can be saved.
    """
    for model, config in base_models.items():
        for i, p in prompts.items():
            paras = {"model": model, "output": str(output / f"{model}_{i}.png"), "n_images": 4, "decoding_batch_size": 4}
            # "loras" (used by test_lora) is a list of dicts, not a CLI option. Passing
            # it through would emit a bogus "--loras [...]" argument.
            paras.update((k, v) for k, v in config.items() if k != "loras")
            command = _get_cmd(p, **paras)
            try:
                ret = cmd(command, capture_output=True)
                result.append(parse_log(ret.stdout))
            except KeyboardInterrupt:
                return result
            except Exception as e:
                print("Exception: ", e)
    return result
def main():
    """Run the benchmarks and write the header plus all collected rows to ``result_mlx.csv``."""
    result = [
        ['model name', 'steps', 'cfg_weight', 'img size', 'img number', 'load model', 'preload model', 'update lora',
         'quantize', 'generate image', 'total time', 'output image', 'output image size']
    ]
    try:
        result = test_base_model(result)
        # result = test_lora(result)
    finally:
        # The test functions append rows to `result` in place, so partial progress is
        # still saved if an exception or Ctrl-C escapes them.
        with open("result_mlx.csv", 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerows(result)
# Run the benchmarks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()