# smc_meissonic / run_examples.py
# Added by cp524 in commit 012fd50 ("Add run examples script").
"""
This file runs and saves the outputs for all example prompts.
"""
import os
import hashlib
import pickle, json
from dataclasses import asdict
from src.smc.inference import (
infer_pretrained,
infer_smc_grad,
infer_ft,
PretrainedInferenceConfig,
SMCGradInferenceConfig,
FTInferenceConfig,
InferenceOutput,
)
# Prompts to generate outputs for. Each prompt is stored in its own
# sub-directory of EXAMPLES_DIR, named by short_hash(prompt).
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A green stop sign in a red field",
]

# Root directory under which all example outputs are written.
EXAMPLES_DIR = "examples"
def short_hash(s):
    """Return the first 8 hex characters of the MD5 digest of ``s``.

    The string is UTF-8 encoded before hashing. The result is used as a
    compact directory name for a prompt; MD5 here is only a naming hash.
    """
    digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    return digest[:8]
def dataclass_to_json(obj, pretty=False):
    """Serialize a dataclass instance to a JSON string with sorted keys.

    Args:
        obj: A dataclass instance (anything exposing ``__dataclass_fields__``).
        pretty: If True, indent with 4 spaces for human reading. Otherwise use
            the most compact separators, which is the form hash_dataclass hashes.

    Returns:
        The JSON text of ``dataclasses.asdict(obj)``.

    Raises:
        TypeError: If ``obj`` has no ``__dataclass_fields__`` attribute.
    """
    if not hasattr(obj, "__dataclass_fields__"):
        raise TypeError("Object must be a dataclass instance")
    # sort_keys=True makes the output independent of field declaration order,
    # so equal configs always serialize (and therefore hash) identically.
    layout = {"indent": 4} if pretty else {"separators": (",", ":")}
    return json.dumps(asdict(obj), sort_keys=True, **layout)
def hash_dataclass(obj, algo="blake2s", digest_size=8):
    """Return a short, deterministic hex hash of a dataclass instance.

    The instance is serialized with ``dataclass_to_json`` (compact form, sorted
    keys) and hashed with ``hashlib.new(algo)``. The full hex digest is then
    truncated to ``digest_size`` bytes, i.e. ``2 * digest_size`` hex characters.
    ``digest_size`` is NOT passed to the hash constructor; it only truncates.
    """
    hasher = hashlib.new(algo, dataclass_to_json(obj).encode())
    n_hex_chars = 2 * digest_size  # each byte renders as two hex characters
    return hasher.hexdigest()[:n_hex_chars]
def does_out_exist(out_dir):
    """Return True if ``out_dir`` already contains a saved ``out.pickle``."""
    pickle_path = os.path.join(out_dir, "out.pickle")
    return os.path.exists(pickle_path)
def save_out(out_dir, out: "InferenceOutput"):
    """Save an inference result into ``out_dir``.

    Each entry of ``out.images`` is written via its ``save`` method as
    ``<index>.png``. The whole ``out`` object is then pickled to ``out.pickle``.

    ``out.pickle`` is what ``does_out_exist`` checks, so it is written LAST and
    atomically (temp file + ``os.replace``). A failure while saving images, or
    a crash mid-dump, therefore never leaves a cache entry that looks complete.

    Args:
        out_dir: Existing directory to write into.
        out: Inference output; must expose an iterable ``images`` attribute
            whose items have a ``save(path)`` method, and must be picklable.
    """
    for i, img in enumerate(out.images):
        img.save(os.path.join(out_dir, f"{i}.png"))

    final_path = os.path.join(out_dir, "out.pickle")
    tmp_path = final_path + ".tmp"
    # Context manager guarantees the handle is flushed and closed before the rename.
    with open(tmp_path, "wb") as f:
        pickle.dump(out, f)
    os.replace(tmp_path, final_path)
def get_out_if_exists(method, config):
    """Load a previously saved result for ``config`` under ``method``, if present.

    The lookup path is
    ``EXAMPLES_DIR/<short_hash(config.prompt)>/<method>/<hash_dataclass(config)>/out.pickle``,
    the same layout main() writes to.

    Args:
        method: Name of the inference method sub-directory (e.g. ``"pretrained"``).
        config: Inference config dataclass; must have a ``prompt`` attribute.

    Returns:
        The unpickled output object, or ``None`` if no saved result exists.
    """
    out_dir = os.path.join(
        EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config)
    )
    # NOTE: pickle.load can execute arbitrary code; only use this on outputs
    # this script produced itself, never on downloaded/untrusted files.
    try:
        with open(os.path.join(out_dir, "out.pickle"), "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
def _run_method(prompt_dir, method, config, infer_fn, device):
    """Run one inference method for one prompt unless its output is already saved.

    Results go to ``prompt_dir/<method>/<hash_dataclass(config)>/``, together
    with a pretty-printed ``config.json`` describing the run.
    """
    out_dir = os.path.join(prompt_dir, method, hash_dataclass(config))
    if does_out_exist(out_dir):
        return
    os.makedirs(out_dir, exist_ok=True)
    # config.json is written before inference runs, so it is present even if
    # inference fails.
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write(dataclass_to_json(config, pretty=True))
    out = infer_fn(config, device=device)
    save_out(out_dir, out)


def main(device="cuda"):
    """Generate and save outputs for every prompt in ``examples`` with every method.

    For each prompt, the text is saved to ``prompt.txt`` in its directory. The
    pretrained, SMC-grad and fine-tuned inference methods are then run. Runs
    whose output already exists on disk are skipped.

    Args:
        device: Device string forwarded to each inference function
            (default ``"cuda"``, as before).
    """
    for prompt in examples:
        prompt_dir = os.path.join(EXAMPLES_DIR, short_hash(prompt))
        os.makedirs(prompt_dir, exist_ok=True)
        print(f"Running prompt: {prompt}")
        # Save prompt in file so the hashed directory name can be mapped back.
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)
        # One table entry per method keeps the three runs in lockstep (the
        # previous copy-pasted stanzas had drifted: "ft" wrote compact JSON).
        for method, config_cls, infer_fn in (
            ("pretrained", PretrainedInferenceConfig, infer_pretrained),
            ("smc_grad", SMCGradInferenceConfig, infer_smc_grad),
            ("ft", FTInferenceConfig, infer_ft),
        ):
            _run_method(prompt_dir, method, config_cls(prompt=prompt), infer_fn, device)


if __name__ == "__main__":
    main()