# smc_meissonic / run_examples.py
# Last commit 9c41927 by cp524: "Add more examples"
"""
This file runs all example prompts through each inference method and saves the outputs.
"""
import os
import hashlib
import pickle, json
from dataclasses import asdict
from src.smc.inference import (
infer_pretrained,
infer_smc_grad,
infer_ft,
PretrainedInferenceConfig,
SMCGradInferenceConfig,
FTInferenceConfig,
InferenceOutput,
)
# Prompts that ``main`` renders with every inference method.
# Each prompt's exact text is hashed (``short_hash``) to name its output
# directory, so editing a prompt -- even its curly apostrophe -- produces a new
# directory rather than reusing previously cached results.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A green stop sign in a red field",
    "A pink bicycle leaning against a fence near a river",
    "A cat in the style of Van Gogh’s Starry Night",
    "A stylish dog wearing sunglasses",
    "A photo of a blue clock and a white cup",
    "A dog on the moon",
]

# Root directory under which all per-prompt example outputs are written.
EXAMPLES_DIR = "examples"
def short_hash(s):
    """Return an 8-character hex fingerprint of ``s``.

    The fingerprint is the first 8 hex digits of the MD5 digest of the string's
    UTF-8 bytes. It is used to name per-prompt output directories and is not
    meant as a security measure.
    """
    digest = hashlib.md5()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()[0:8]
def dataclass_to_json(obj, pretty=False):
    """Serialize a dataclass instance to a JSON string with sorted keys.

    Keys are always sorted, so the same field values always produce the same
    text. That makes the compact form usable as hashing input.

    Args:
        obj: The object to serialize; must carry ``__dataclass_fields__``.
        pretty: If truthy, indent with 4 spaces for human reading; otherwise
            emit the most compact form (no whitespace after separators).

    Raises:
        TypeError: If ``obj`` has no ``__dataclass_fields__`` attribute.
    """
    if not hasattr(obj, "__dataclass_fields__"):
        raise TypeError("Object must be a dataclass instance")
    layout = {"indent": 4} if pretty else {"separators": (",", ":")}
    return json.dumps(asdict(obj), sort_keys=True, **layout)
def hash_dataclass(obj, algo="blake2s", digest_size=8):
    """Return a short, deterministic hex hash of a dataclass instance.

    The instance is serialized with ``dataclass_to_json`` (compact form,
    sorted keys), hashed with ``hashlib.new(algo)``, and the hex digest is
    truncated.

    Args:
        obj: Dataclass instance to hash.
        algo: Any algorithm name accepted by ``hashlib.new``.
        digest_size: Number of digest *bytes* to keep. The full digest is
            computed and then truncated to ``2 * digest_size`` hex characters.
            This value is not passed to the hash constructor, which for
            blake2s would yield different output.

    Returns:
        The truncated lowercase hex digest.
    """
    payload = dataclass_to_json(obj).encode()
    hex_chars = 2 * digest_size  # every byte renders as two hex characters
    return hashlib.new(algo, payload).hexdigest()[:hex_chars]
def does_out_exist(out_dir):
    """Report whether ``out_dir`` already holds a saved result.

    A result counts as present once ``out_dir/out.pickle`` exists. Any path at
    that location (file or directory) satisfies the check.
    """
    marker = os.path.join(out_dir, "out.pickle")
    return os.path.exists(marker)
def save_out(out_dir, out: "InferenceOutput"):
    """Persist an inference result to ``out_dir``.

    Writes each entry of ``out.images`` as ``<index>.png`` and then pickles the
    whole ``out`` object to ``out.pickle``.

    ``out.pickle`` is the marker ``does_out_exist`` uses to decide that a run is
    finished, so it is written last and atomically. Images are saved first, and
    the pickle is written to a temporary file that is then renamed into place.
    An interrupted run therefore never leaves a marker pointing at missing
    images or a truncated pickle.

    Args:
        out_dir: Directory to write into; created if it does not exist.
        out: Inference result. Assumes ``out.images`` is an iterable of objects
            with a ``save(path)`` method (presumably PIL images -- confirm) and
            that ``out`` itself is picklable.
    """
    os.makedirs(out_dir, exist_ok=True)
    for i, img in enumerate(out.images):
        img.save(os.path.join(out_dir, f"{i}.png"))

    final_path = os.path.join(out_dir, "out.pickle")
    tmp_path = final_path + ".tmp"
    # Use a context manager so the handle is closed (and data flushed) even if
    # pickling raises.
    with open(tmp_path, "wb") as f:
        pickle.dump(out, f)
    # os.replace renames atomically on POSIX, so a reader sees either no marker
    # or a complete pickle.
    os.replace(tmp_path, final_path)
def get_out_if_exists(method, config):
    """Load a previously saved inference result, or return ``None``.

    Rebuilds the output path ``EXAMPLES_DIR/<short_hash(config.prompt)>/
    <method>/<hash_dataclass(config)>``. This is the same layout ``main``
    writes to.

    Args:
        method: Method sub-directory name (``main`` uses "pretrained",
            "smc_grad" or "ft").
        config: Dataclass config instance with a ``prompt`` attribute.

    Returns:
        The unpickled result if ``out.pickle`` exists, else ``None``.

    Note:
        ``pickle.load`` can execute arbitrary code. Only point this at
        directories written by this script, never at untrusted files.
    """
    out_dir = os.path.join(
        EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config)
    )
    if not does_out_exist(out_dir):
        return None
    # Close the file deterministically instead of leaking the handle.
    with open(os.path.join(out_dir, "out.pickle"), "rb") as f:
        return pickle.load(f)
def _run_method(prompt_dir, method, config, infer_fn, device):
    """Run one inference method for a prompt unless its output is already saved."""
    out_dir = os.path.join(prompt_dir, method, hash_dataclass(config))
    if does_out_exist(out_dir):
        return
    os.makedirs(out_dir, exist_ok=True)
    # Every method records its config in the same human-readable format.
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write(dataclass_to_json(config, pretty=True))
    out = infer_fn(config, device=device)
    save_out(out_dir, out)


def main(device="cuda"):
    """Run every example prompt through each inference method and save results.

    For each prompt, the prompt text is written to
    ``EXAMPLES_DIR/<short_hash(prompt)>/prompt.txt``. Then each method
    ("pretrained", "smc_grad", "ft") is run with its default config built from
    the prompt, unless that config's output directory already contains a saved
    result.

    Args:
        device: Device string forwarded to each ``infer_*`` function
            (defaults to "cuda").
    """
    methods = (
        ("pretrained", PretrainedInferenceConfig, infer_pretrained),
        ("smc_grad", SMCGradInferenceConfig, infer_smc_grad),
        ("ft", FTInferenceConfig, infer_ft),
    )
    for prompt in examples:
        prompt_dir = os.path.join(EXAMPLES_DIR, short_hash(prompt))
        os.makedirs(prompt_dir, exist_ok=True)
        print(f"Running prompt: {prompt}")
        # Explicit UTF-8: some prompts contain non-ASCII characters (e.g. U+2019).
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)
        for method, config_cls, infer_fn in methods:
            _run_method(prompt_dir, method, config_cls(prompt=prompt), infer_fn, device)


if __name__ == "__main__":
    main()