File size: 3,854 Bytes
012fd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f35ee86
9c41927
 
 
 
012fd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
This file runs and saves the outputs for all example prompts.
"""

import os
import hashlib
import pickle, json
from dataclasses import asdict
from src.smc.inference import (
    infer_pretrained,
    infer_smc_grad,
    infer_ft,
    PretrainedInferenceConfig,
    SMCGradInferenceConfig,
    FTInferenceConfig,
    InferenceOutput,
)


# Prompts to run through every inference method in main(). Each prompt is
# hashed with short_hash() to name its output directory under EXAMPLES_DIR,
# so editing a prompt's text (even punctuation, such as the typographic
# apostrophe below) moves its outputs to a different directory.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A green stop sign in a red field",
    "A pink bicycle leaning against a fence near a river",
    "A cat in the style of Van Gogh’s Starry Night",
    "A stylish dog wearing sunglasses",
    "A photo of a blue clock and a white cup",
    "A dog on the moon",
]

# Root directory (relative to the current working directory) for all outputs.
EXAMPLES_DIR = "examples"

def short_hash(s):
    """Return the first 8 hex characters of the MD5 digest of ``s``.

    The string is encoded as UTF-8 before hashing. Used to turn a prompt into
    a short, filesystem-safe directory name (not for security).
    """
    full_digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    return full_digest[0:8]

def dataclass_to_json(obj, pretty=False):
    """Serialize a dataclass instance to a JSON string with sorted keys.

    Args:
        obj: A dataclass instance; it is converted recursively via ``asdict``.
        pretty: If True, indent with 4 spaces for human reading; otherwise
            emit the most compact form (no whitespace after separators).

    Returns:
        The JSON text. Keys are always sorted so equal field values give
        identical output, which makes the result suitable for hashing.

    Raises:
        TypeError: If ``obj`` has no ``__dataclass_fields__`` attribute.
    """
    if not hasattr(obj, "__dataclass_fields__"):
        raise TypeError("Object must be a dataclass instance")

    layout = {"indent": 4} if pretty else {"separators": (",", ":")}
    return json.dumps(asdict(obj), sort_keys=True, **layout)

def hash_dataclass(obj, algo="blake2s", digest_size=8):
    """Return a deterministic hex digest of a dataclass instance's field values.

    The instance is serialized with ``dataclass_to_json`` (compact, sorted
    keys) and hashed with the ``hashlib`` algorithm named ``algo``. The hex
    digest is then cut to ``digest_size`` bytes, i.e. ``2 * digest_size`` hex
    characters. Truncation is applied to the full digest rather than passing
    a native digest size to the hash constructor.
    """
    payload = dataclass_to_json(obj).encode()
    hex_digest = hashlib.new(algo, payload).hexdigest()
    return hex_digest[: 2 * digest_size]

def does_out_exist(out_dir):
    """Return True if ``out_dir`` already contains an entry named ``out.pickle``."""
    marker_path = os.path.join(out_dir, "out.pickle")
    return os.path.exists(marker_path)

def save_out(out_dir, out: InferenceOutput):
    """Persist an inference result into ``out_dir``.

    Writes every image in ``out.images`` as ``<index>.png`` via its ``save``
    method, then pickles the whole ``out`` object to ``out.pickle``.

    ``out.pickle`` is the marker that ``does_out_exist`` checks to decide a run
    is complete. It is therefore written last, and atomically: first to a
    temporary file, then moved into place with ``os.replace``. An interruption
    can thus never leave a marker that points at a truncated pickle or
    missing images.

    Args:
        out_dir: Existing directory to write into.
        out: The inference output to store.
    """
    for i, img in enumerate(out.images):
        img.save(os.path.join(out_dir, f"{i}.png"))

    pickle_path = os.path.join(out_dir, "out.pickle")
    tmp_path = pickle_path + ".tmp"
    # `with` guarantees the handle is flushed and closed before the rename.
    with open(tmp_path, "wb") as f:
        pickle.dump(out, f)
    os.replace(tmp_path, pickle_path)
        
def get_out_if_exists(method, config):
    """Load a previously saved inference output for ``config``, if any.

    Looks in ``EXAMPLES_DIR/<short_hash(prompt)>/<method>/<hash_dataclass(config)>``,
    the same layout ``main`` writes to.

    Args:
        method: Method directory name (e.g. ``"pretrained"``, ``"smc_grad"``, ``"ft"``).
        config: Inference config dataclass; must have a ``prompt`` attribute.

    Returns:
        The unpickled output object, or None if ``out.pickle`` does not exist.
    """
    out_dir = os.path.join(EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config))
    if not does_out_exist(out_dir):
        return None
    # NOTE: pickle executes arbitrary code on load. This is only safe because the
    # file is a local cache written by save_out; never point it at untrusted data.
    with open(os.path.join(out_dir, "out.pickle"), "rb") as f:
        return pickle.load(f)

def _run_method(prompt_dir, method, config, infer_fn, device):
    """Run one inference method for one prompt, skipping it if output already exists."""
    out_dir = os.path.join(prompt_dir, method, hash_dataclass(config))
    if does_out_exist(out_dir):
        return
    os.makedirs(out_dir, exist_ok=True)
    # Record the exact config next to its outputs (pretty-printed for every method).
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write(dataclass_to_json(config, pretty=True))
    out = infer_fn(config, device=device)
    save_out(out_dir, out)

def main(device="cuda"):
    """Run every example prompt through each inference method and save the results.

    For each prompt a directory ``EXAMPLES_DIR/<short_hash(prompt)>`` is created
    and the prompt text is written to ``prompt.txt`` there. Each method then
    gets ``<method>/<hash_dataclass(config)>`` containing ``config.json`` and the
    outputs written by ``save_out``. Method runs whose ``out.pickle`` already
    exists are skipped, so re-running only fills in missing results.

    Args:
        device: Passed as ``device=`` to each ``infer_*`` call (default ``"cuda"``).
    """
    # (output directory name, config class, inference function), run in this order.
    methods = (
        ("pretrained", PretrainedInferenceConfig, infer_pretrained),
        ("smc_grad", SMCGradInferenceConfig, infer_smc_grad),
        ("ft", FTInferenceConfig, infer_ft),
    )
    for prompt in examples:
        prompt_dir = os.path.join(EXAMPLES_DIR, short_hash(prompt))
        os.makedirs(prompt_dir, exist_ok=True)

        print(f"Running prompt: {prompt}")

        # Explicit UTF-8: prompts may contain non-ASCII characters (e.g. "’"),
        # which the platform default encoding may not be able to represent.
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)

        for method, config_cls, infer_fn in methods:
            _run_method(prompt_dir, method, config_cls(prompt=prompt), infer_fn, device)
    
# Run all examples only when executed as a script, not when imported.
if __name__ == "__main__":
    main()