File size: 4,983 Bytes
dbcf7a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d907122
 
dbcf7a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import csv
import json
import re
import shlex
import subprocess
from subprocess import CompletedProcess

from test_configs import *


def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
    """Echo and run a shell command, optionally capturing its combined output.

    Args:
        command: Command line executed through the shell (``shell=True``).
        check: If True, a non-zero exit status raises
            ``subprocess.CalledProcessError``.
        capture_output: If True, stdout and stderr are merged, captured as text
            into ``ret.stdout`` and echoed; otherwise the child writes directly
            to this process's stdout/stderr and ``ret.stdout`` is None.

    Returns:
        The ``CompletedProcess`` of the finished command.

    Raises:
        subprocess.CalledProcessError: If ``check`` is True and the command
            exits non-zero (its captured output is echoed first).
    """
    print(command)
    if capture_output:
        try:
            ret = subprocess.run(command, shell=True, check=check, stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT, universal_newlines=True)
        except subprocess.CalledProcessError as err:
            # Echo the captured log before propagating so the failure reason is visible.
            print(err.output)
            raise
    else:
        ret = subprocess.run(command, shell=True, check=check)
    # Without capture ret.stdout is None; the child already wrote to the terminal.
    if ret.stdout is not None:
        print(ret.stdout)
    return ret


def _parse_log_field(pattern, output, label):
    """Return group 1 of the first match of ``pattern`` in ``output``; raise ValueError naming ``label`` if absent."""
    match = re.search(pattern, output)
    if match is None:
        raise ValueError(f"generation log is missing the '{label}' line")
    return match.group(1)


def parse_log(output):
    """Extract benchmark metrics from the log of one generation run.

    Each metric is read from a ``<label>: <value>`` occurrence (first match wins;
    the value runs to the end of that line), for the labels ``model``, ``steps``,
    ``cfg_weight``, ``img-size``, ``image number``, ``load model time``,
    ``preload model time``, ``update lora time``, ``quantize time``,
    ``generate image time``, ``total time``, ``save image to`` and
    ``output image size: (...)``.

    Args:
        output: Captured text output of one run.

    Returns:
        A 13-tuple of strings: model name, steps, cfg_weight, img size,
        img number, load model time, preload model time, update lora time,
        quantize time, generate image time, total time, output image path
        (only its last two ``/``-separated components), and output image size
        (the text inside the parentheses with ``", "`` replaced by ``"*"``).

    Raises:
        ValueError: If any of the labels above is missing from ``output``.
    """
    model_name = _parse_log_field(r"model: (.+)", output, "model")
    steps = _parse_log_field(r"steps: (.+)", output, "steps")
    cfg_weight = _parse_log_field(r"cfg_weight: (.+)", output, "cfg_weight")
    img_size = _parse_log_field(r"img-size: (.+)", output, "img-size")
    img_number = _parse_log_field(r"image number: (.+)", output, "image number")
    # \b keeps this from matching the tail of the "preload model time:" line.
    load_model_time = _parse_log_field(r"\bload model time: (.+)", output, "load model time")
    preload_model = _parse_log_field(r"preload model time: (.+)", output, "preload model time")
    update_lora_time = _parse_log_field(r"update lora time: (.+)", output, "update lora time")
    quantize_time = _parse_log_field(r"quantize time: (.+)", output, "quantize time")
    generate_time = _parse_log_field(r"generate image time: (.+)", output, "generate image time")
    total_time = _parse_log_field(r"total time: (.+)", output, "total time")
    out_image = _parse_log_field(r"save image to: (.+)", output, "save image to")
    # Keep only "<parent dir>/<file>" so the CSV stays readable.
    out_image = '/'.join(out_image.split("/")[-2:])
    out_image_size = _parse_log_field(r"output image size: \((.+)\)", output, "output image size")
    out_image_size = out_image_size.replace(', ', '*')
    return (model_name, steps, cfg_weight, img_size, img_number, load_model_time,
            preload_model, update_lora_time, quantize_time, generate_time, total_time, out_image, out_image_size)


def _get_cmd(prompt, **kwargs):
    """Build the shell command line that runs ``mlx_app/txt2image_lora.py``.

    Args:
        prompt: Text prompt, passed as the script's first positional argument.
        **kwargs: CLI options. Each ``key=value`` becomes ``--key value``;
            a value that is exactly ``True`` becomes the bare flag ``--key``.
            Keys are emitted verbatim (no ``_`` -> ``-`` conversion).

    Returns:
        One command string, suitable for running through a shell. The prompt
        and every value are quoted with ``shlex.quote``, so quotes, ``$``,
        backticks or spaces in them cannot break or alter the command.
    """
    parts = ["python", "mlx_app/txt2image_lora.py", shlex.quote(str(prompt))]
    for key, value in kwargs.items():
        parts.append(f"--{key}")
        if value is not True:
            parts.append(shlex.quote(str(value)))
    return " ".join(parts)


def test_lora(result):
    """Run the LoRA benchmark for every base model / LoRA / prompt combination.

    For each model in ``base_models`` and each entry of its ``"loras"`` list,
    one command per prompt is built that applies the LoRA (plus its optional
    ``"lora-scale"``) and appends the LoRA's trigger words to the prompt. Two
    further variants (no LoRA; LoRA without trigger words) are kept below,
    commented out. Each command is run via ``cmd`` and its log parsed with
    ``parse_log``.

    Args:
        result: Rows collected so far; parsed rows are appended in place.

    Returns:
        ``result``. A failing run is reported and skipped; Ctrl-C stops the
        loop and returns the rows gathered so far.
    """
    commands = {
        "no_lora": [],
        "no_trigger": [],
        "with_trigger": []
    }
    for model, config in base_models.items():
        # Read the LoRA list without mutating the shared ``base_models`` config
        # (the old ``pop`` made a second call fail), and keep it out of the CLI options.
        loras = config.get("loras", [])
        cli_config = {k: v for k, v in config.items() if k != "loras"}
        for lora in loras:
            trigger_words = lora.get("trigger_words")
            lora_name = lora.get("lora").split('/')[-1].split('.')[0]
            lora_scale = lora.get("lora-scale")
            for i, p in prompts.items():
                # 1. run with no lora
                # paras = {"model": model, "output": str(output / model / f"{lora_name}-{i}-a_no_lora.png")}
                # paras.update(cli_config)
                # commands["no_lora"].append(_get_cmd(p, **paras))
                #
                # # 2. run with lora, but no trigger words
                # paras = {"model": model, "output": str(output / model / f"{lora_name}-{i}-b_no_trigger.png")}
                # paras.update(cli_config)
                # paras["lora"] = lora.get("lora")
                # if lora_scale is not None:
                #     paras["lora-scale"] = lora_scale
                # commands["no_trigger"].append(_get_cmd(p, **paras))

                # 3. run with lora, with trigger words
                paras = {"model": model, "output": str(output / f"{model}-{lora_name}-{i}-c_with_trigger.png"),
                         "n_images": 4}
                paras.update(cli_config)
                paras["lora"] = lora.get("lora")
                # ``is not None`` so an explicit scale of 0 is still forwarded.
                if lora_scale is not None:
                    paras["lora-scale"] = lora_scale
                # Avoid a literal ", None" suffix when a LoRA has no trigger words.
                prompt = f"{p}, {trigger_words}" if trigger_words else p
                commands["with_trigger"].append(_get_cmd(prompt, **paras))

    for cmds in commands.values():
        for c in cmds:
            try:
                ret = cmd(c, capture_output=True)
                result.append(parse_log(ret.stdout))
            except KeyboardInterrupt:
                # Same as test_base_model: keep partial results so they can still be written out.
                return result
            except Exception as e:
                print("Exception: ", e)
    return result


def test_base_model(result: list):
    """Benchmark every base model on every prompt, without LoRAs.

    Each run generates 4 images (decoding batch size 4) unless the model's
    config overrides those options. The output file is
    ``output / "<model>_<prompt id>.png"``; ``output`` comes from test_configs
    (presumably a ``pathlib.Path`` — verify there).

    Args:
        result: Rows collected so far; parsed rows are appended in place.

    Returns:
        ``result``. A failing run is reported and skipped; Ctrl-C stops the
        loop and returns the rows gathered so far.
    """
    for model, config in base_models.items():
        # "loras" holds LoRA specs that test_lora passes individually via --lora;
        # it is not a CLI option, so it must not be forwarded as "--loras [...]".
        cli_config = {k: v for k, v in config.items() if k != "loras"}
        for i, p in prompts.items():
            paras = {"model": model, "output": str(output / f"{model}_{i}.png"), "n_images": 4, "decoding_batch_size": 4}
            paras.update(cli_config)
            command = _get_cmd(p, **paras)
            try:
                ret = cmd(command, capture_output=True)
                result.append(parse_log(ret.stdout))
            except KeyboardInterrupt:
                return result
            except Exception as e:
                print("Exception: ", e)
    return result


def main():
    """Run the base-model benchmark and write its rows, under a header row, to result_mlx.csv."""
    header = ['model name', 'steps', 'cfg_weight', 'img size', 'img number', 'load model', 'preload model',
              'update lora', 'quantize', 'generate image', 'total time', 'output image', 'output image size']
    rows = test_base_model([header])
    # Enable to append the LoRA benchmark rows as well:
    # rows = test_lora(rows)

    with open("result_mlx.csv", 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(rows)


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()