cp524 commited on
Commit
012fd50
·
1 Parent(s): 9a9c18b

Add run examples script

Browse files
Files changed (1) hide show
  1. run_examples.py +104 -0
run_examples.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This files runs and saves the outputs for all example prompts.
3
+ """
4
+
5
+ import os
6
+ import hashlib
7
+ import pickle, json
8
+ from dataclasses import asdict
9
+ from src.smc.inference import (
10
+ infer_pretrained,
11
+ infer_smc_grad,
12
+ infer_ft,
13
+ PretrainedInferenceConfig,
14
+ SMCGradInferenceConfig,
15
+ FTInferenceConfig,
16
+ InferenceOutput,
17
+ )
18
+
19
+
20
+ examples = [
21
+ "A photo of a yellow bird and a black motorcycle",
22
+ "A green stop sign in a red field",
23
+ ]
24
+
25
+ EXAMPLES_DIR = "examples"
26
+
27
+ def short_hash(s):
28
+ return hashlib.md5(s.encode()).hexdigest()[:8]
29
+
30
+ def dataclass_to_json(obj, pretty=False):
31
+ """Convert a dataclass instance to a JSON string."""
32
+ if not hasattr(obj, "__dataclass_fields__"):
33
+ raise TypeError("Object must be a dataclass instance")
34
+
35
+ # Convert to dict and sort keys to ensure stable serialization
36
+ data = asdict(obj)
37
+ if pretty:
38
+ return json.dumps(data, indent=4, sort_keys=True)
39
+ else:
40
+ return json.dumps(data, separators=(",", ":"), sort_keys=True)
41
+
42
+ def hash_dataclass(obj, algo="blake2s", digest_size=8):
43
+ """Compute a deterministic hash for a dataclass instance."""
44
+ s = dataclass_to_json(obj)
45
+ h = hashlib.new(algo)
46
+ h.update(s.encode())
47
+ return h.hexdigest()[:digest_size * 2] # 2 hex chars per byte
48
+
49
+ def does_out_exist(out_dir):
50
+ return os.path.exists(os.path.join(out_dir, "out.pickle"))
51
+
52
+ def save_out(out_dir, out: InferenceOutput):
53
+ pickle.dump(out, open(os.path.join(out_dir, "out.pickle"), "wb"))
54
+ for i, img in enumerate(out.images):
55
+ img.save(os.path.join(out_dir, f"{i}.png"))
56
+
57
+ def get_out_if_exists(method, config):
58
+ out_dir = os.path.join(EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config))
59
+ if does_out_exist(out_dir):
60
+ return pickle.load(open(os.path.join(out_dir, "out.pickle"), "rb"))
61
+ else:
62
+ return None
63
+
64
+ def main():
65
+ for prompt in examples:
66
+ prompt_hash = short_hash(prompt)
67
+ prompt_dir = os.path.join(EXAMPLES_DIR, prompt_hash)
68
+ os.makedirs(prompt_dir, exist_ok=True)
69
+
70
+ print(f"Running prompt: {prompt}")
71
+
72
+ # Save prompt in file
73
+ with open(os.path.join(prompt_dir, "prompt.txt"), "w") as f:
74
+ f.write(prompt)
75
+
76
+ config = PretrainedInferenceConfig(prompt=prompt)
77
+ out_dir = os.path.join(prompt_dir, "pretrained", hash_dataclass(config))
78
+ if not does_out_exist(out_dir):
79
+ os.makedirs(out_dir, exist_ok=True)
80
+ with open(os.path.join(out_dir, "config.json"), "w") as f:
81
+ f.write(dataclass_to_json(config, pretty=True))
82
+ out = infer_pretrained(config, device="cuda")
83
+ save_out(out_dir, out)
84
+
85
+ config = SMCGradInferenceConfig(prompt=prompt)
86
+ out_dir = os.path.join(prompt_dir, "smc_grad", hash_dataclass(config))
87
+ if not does_out_exist(out_dir):
88
+ os.makedirs(out_dir, exist_ok=True)
89
+ with open(os.path.join(out_dir, "config.json"), "w") as f:
90
+ f.write(dataclass_to_json(config, pretty=True))
91
+ out = infer_smc_grad(config, device="cuda")
92
+ save_out(out_dir, out)
93
+
94
+ config = FTInferenceConfig(prompt=prompt)
95
+ out_dir = os.path.join(prompt_dir, "ft", hash_dataclass(config))
96
+ if not does_out_exist(out_dir):
97
+ os.makedirs(out_dir, exist_ok=True)
98
+ with open(os.path.join(out_dir, "config.json"), "w") as f:
99
+ f.write(dataclass_to_json(config))
100
+ out = infer_ft(config, device="cuda")
101
+ save_out(out_dir, out)
102
+
103
+ if __name__ == "__main__":
104
+ main()