Spaces:
Sleeping
Sleeping
File size: 3,854 Bytes
012fd50 f35ee86 9c41927 012fd50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
"""
This file runs inference for all example prompts and saves the outputs.
"""
import os
import hashlib
import pickle, json
from dataclasses import asdict
from src.smc.inference import (
infer_pretrained,
infer_smc_grad,
infer_ft,
PretrainedInferenceConfig,
SMCGradInferenceConfig,
FTInferenceConfig,
InferenceOutput,
)
# Prompts whose outputs main() generates and saves under EXAMPLES_DIR.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A green stop sign in a red field",
    "A pink bicycle leaning against a fence near a river",
    "A cat in the style of Van Gogh’s Starry Night",
    "A stylish dog wearing sunglasses",
    "A photo of a blue clock and a white cup",
    "A dog on the moon",
]

# Root directory for saved outputs; main() gives each prompt its own
# subdirectory named by short_hash(prompt).
EXAMPLES_DIR = "examples"
def short_hash(s):
    """Return the first 8 hex characters of the MD5 digest of ``s``.

    The string is UTF-8 encoded before hashing. The result is used as a
    compact directory name for a prompt; it is a naming key, not a
    security measure.
    """
    full_digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    return full_digest[:8]
def dataclass_to_json(obj, pretty=False):
    """Serialize a dataclass instance to a JSON string with sorted keys.

    Args:
        obj: A dataclass instance; nested dataclasses are converted too,
            via ``dataclasses.asdict``.
        pretty: If true, indent with 4 spaces for human reading; otherwise
            emit the compact form (no spaces after ``,`` or ``:``).

    Returns:
        The JSON text. Keys are sorted so equal objects always serialize
        to the same string, which ``hash_dataclass`` depends on.

    Raises:
        TypeError: If ``obj`` is not a dataclass instance.
    """
    if not hasattr(obj, "__dataclass_fields__"):
        raise TypeError("Object must be a dataclass instance")
    if pretty:
        layout = {"indent": 4}
    else:
        layout = {"separators": (",", ":")}
    return json.dumps(asdict(obj), sort_keys=True, **layout)
def hash_dataclass(obj, algo="blake2s", digest_size=8):
    """Return a short, deterministic hex fingerprint of a dataclass instance.

    The instance is serialized with ``dataclass_to_json`` (compact, sorted
    keys), hashed with the ``hashlib`` algorithm named by ``algo``, and the
    hex digest is truncated to ``digest_size`` bytes, i.e.
    ``2 * digest_size`` hex characters.

    Note: ``digest_size`` only truncates the output; it is not passed to
    the hash constructor, so e.g. blake2s still runs with its default size.
    """
    payload = dataclass_to_json(obj).encode()
    hasher = hashlib.new(algo, payload)
    return hasher.hexdigest()[: 2 * digest_size]
def does_out_exist(out_dir):
    """Return True if ``out_dir`` already holds a saved ``out.pickle``.

    The presence of ``out.pickle`` is the marker used to skip re-running
    inference for a given method/config.
    """
    pickle_path = os.path.join(out_dir, "out.pickle")
    return os.path.exists(pickle_path)
def save_out(out_dir, out: "InferenceOutput"):
    """Persist an inference result into ``out_dir`` (which must already exist).

    Writes each element of ``out.images`` as ``{i}.png`` and the whole
    ``out`` object as ``out.pickle``.

    ``out.pickle`` is the marker ``does_out_exist`` checks to decide whether
    a run can be skipped. It is therefore written last, and atomically
    (temp file + ``os.replace``). An interrupted or failed save leaves no
    ``out.pickle`` behind, so the run is redone next time instead of being
    treated as complete with missing images or a truncated pickle.
    """
    for i, img in enumerate(out.images):
        # presumably PIL images (anything with a ``save(path)`` method) -- TODO confirm
        img.save(os.path.join(out_dir, f"{i}.png"))

    final_path = os.path.join(out_dir, "out.pickle")
    tmp_path = final_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(out, f)
        # os.replace is atomic when both paths are on the same filesystem.
        os.replace(tmp_path, final_path)
    except BaseException:
        # Don't leave a half-written temp file behind; propagate the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
def get_out_if_exists(method, config):
    """Load a previously saved inference output, if there is one.

    Looks in ``EXAMPLES_DIR/<short_hash(config.prompt)>/<method>/<hash_dataclass(config)>``,
    the same directory layout ``main`` writes outputs to.

    Args:
        method: Name of the method subdirectory (``main`` uses "pretrained",
            "smc_grad" and "ft").
        config: Inference config dataclass; must have a ``prompt`` attribute.

    Returns:
        The unpickled output object, or ``None`` if no ``out.pickle`` exists.
    """
    out_dir = os.path.join(
        EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config)
    )
    if not does_out_exist(out_dir):
        return None
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # this script wrote itself, never untrusted ones.
    with open(os.path.join(out_dir, "out.pickle"), "rb") as f:
        return pickle.load(f)
def _run_and_cache(prompt_dir, method, config, infer_fn, device):
    """Run one inference method for a prompt unless its output is already saved.

    Outputs go to ``prompt_dir/<method>/<hash_dataclass(config)>``. The
    config is written there as pretty-printed ``config.json`` before
    inference, so the hashed directory name can be traced back to its settings.
    """
    out_dir = os.path.join(prompt_dir, method, hash_dataclass(config))
    if does_out_exist(out_dir):
        return
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write(dataclass_to_json(config, pretty=True))
    out = infer_fn(config, device=device)
    save_out(out_dir, out)


def main(device="cuda"):
    """Generate and save outputs of every inference method for every example prompt.

    For each prompt in ``examples``, creates ``EXAMPLES_DIR/<short_hash(prompt)>/``,
    writes ``prompt.txt`` there, then runs the pretrained, SMC-grad and
    fine-tuned methods in that order. Methods whose output already exists
    for the same config are skipped.

    Args:
        device: Device string forwarded to each ``infer_*`` function
            (defaults to "cuda", as before).
    """
    # (method directory name, config class, inference function)
    methods = (
        ("pretrained", PretrainedInferenceConfig, infer_pretrained),
        ("smc_grad", SMCGradInferenceConfig, infer_smc_grad),
        ("ft", FTInferenceConfig, infer_ft),
    )
    for prompt in examples:
        prompt_dir = os.path.join(EXAMPLES_DIR, short_hash(prompt))
        os.makedirs(prompt_dir, exist_ok=True)
        print(f"Running prompt: {prompt}")
        # Save the prompt text; explicit UTF-8 because prompts contain non-ASCII (e.g. "’").
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)
        for method, config_cls, infer_fn in methods:
            _run_and_cache(prompt_dir, method, config_cls(prompt=prompt), infer_fn, device)


if __name__ == "__main__":
    main()
|