Spaces:
Sleeping
Sleeping
| """ | |
| This file runs every example prompt and saves the outputs. | |
| """ | |
| import os | |
| import hashlib | |
| import pickle, json | |
| from dataclasses import asdict | |
| from src.smc.inference import ( | |
| infer_pretrained, | |
| infer_smc_grad, | |
| infer_ft, | |
| PretrainedInferenceConfig, | |
| SMCGradInferenceConfig, | |
| FTInferenceConfig, | |
| InferenceOutput, | |
| ) | |
# Prompts that main() iterates over; each one is run through every
# inference method.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A green stop sign in a red field",
    "A pink bicycle leaning against a fence near a river",
    "A cat in the style of Van Gogh’s Starry Night",
    "A stylish dog wearing sunglasses",
    "A photo of a blue clock and a white cup",
    "A dog on the moon",
]
# Root directory for saved outputs; per-prompt subdirectories are named by
# the short hash of the prompt.
EXAMPLES_DIR = "examples"
def short_hash(s):
    """Return the first 8 hex characters of the MD5 digest of ``s``.

    The string is encoded with the default ``str.encode()`` codec (UTF-8)
    before hashing. The result is used as a short, stable directory name
    for a prompt, not as a security measure.
    """
    digest = hashlib.md5(s.encode())
    hex_text = digest.hexdigest()
    return hex_text[:8]
def dataclass_to_json(obj, pretty=False):
    """Serialize a dataclass instance to a JSON string with sorted keys.

    Args:
        obj: Object carrying ``__dataclass_fields__``; it is converted with
            ``dataclasses.asdict``.
        pretty: When true, emit 4-space indented JSON; otherwise emit the
            most compact form (no whitespace after separators).

    Returns:
        The JSON text.

    Raises:
        TypeError: If ``obj`` has no ``__dataclass_fields__`` attribute.
    """
    looks_like_dataclass = hasattr(obj, "__dataclass_fields__")
    if not looks_like_dataclass:
        raise TypeError("Object must be a dataclass instance")

    payload = asdict(obj)
    # Keys are always sorted so the text is stable regardless of field order;
    # only the whitespace layout depends on ``pretty``.
    layout = {"indent": 4} if pretty else {"separators": (",", ":")}
    return json.dumps(payload, sort_keys=True, **layout)
def hash_dataclass(obj, algo="blake2s", digest_size=8):
    """Return a deterministic hex hash of a dataclass instance.

    The instance is serialized with ``dataclass_to_json`` (compact, sorted
    keys), hashed with the ``hashlib`` algorithm named by ``algo``, and the
    hex digest is truncated to ``digest_size`` bytes. The truncation is
    applied to the algorithm's full digest; ``digest_size`` is not passed
    to the hash constructor.
    """
    payload = dataclass_to_json(obj).encode()
    hex_len = digest_size * 2  # two hex characters encode one byte
    return hashlib.new(algo, payload).hexdigest()[:hex_len]
def does_out_exist(out_dir):
    """Report whether ``out_dir`` already contains an ``out.pickle`` entry."""
    pickle_path = os.path.join(out_dir, "out.pickle")
    return os.path.exists(pickle_path)
def save_out(out_dir, out: InferenceOutput):
    """Persist an inference result under ``out_dir``.

    Writes the whole ``out`` object to ``out.pickle`` and then saves each
    element of ``out.images`` as ``<index>.png``.

    Args:
        out_dir: Directory to write into. It must already exist; it is not
            created here.
        out: Inference result. It must be picklable and expose an ``images``
            iterable whose items provide a ``save(path)`` method.
    """
    # A context manager guarantees the pickle file is flushed and closed,
    # even if pickling raises, instead of relying on garbage collection.
    with open(os.path.join(out_dir, "out.pickle"), "wb") as f:
        pickle.dump(out, f)
    for i, img in enumerate(out.images):
        img.save(os.path.join(out_dir, f"{i}.png"))
def get_out_if_exists(method, config):
    """Load a previously saved inference output, if there is one.

    The output directory is
    ``EXAMPLES_DIR/<short_hash(config.prompt)>/<method>/<hash_dataclass(config)>``.

    Args:
        method: Name of the method subdirectory (for example ``"pretrained"``).
        config: Dataclass config instance with a ``prompt`` attribute.

    Returns:
        The unpickled object stored in ``out.pickle``, or ``None`` when no
        such file exists for this method/config pair.
    """
    out_dir = os.path.join(
        EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config)
    )
    if not does_out_exist(out_dir):
        return None
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # out.pickle files from a trusted source.
    # The context manager closes the handle deterministically after loading.
    with open(os.path.join(out_dir, "out.pickle"), "rb") as f:
        return pickle.load(f)
def _run_method(prompt_dir, method, config, infer_fn, device):
    """Run one inference method for a prompt unless its output is already saved."""
    out_dir = os.path.join(prompt_dir, method, hash_dataclass(config))
    if does_out_exist(out_dir):
        # out.pickle is present, so this method/config pair was already run.
        return
    os.makedirs(out_dir, exist_ok=True)
    # Store the human-readable config next to the outputs. Every method uses
    # the pretty layout so the saved files are formatted consistently.
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write(dataclass_to_json(config, pretty=True))
    save_out(out_dir, infer_fn(config, device=device))


def main(device="cuda"):
    """Run every example prompt through each inference method and save results.

    For each prompt in ``examples`` this creates
    ``EXAMPLES_DIR/<short_hash(prompt)>/``, writes the prompt text to
    ``prompt.txt``, and then runs the pretrained, SMC-grad and fine-tuned
    methods in that order. A method is skipped when its output directory
    already contains ``out.pickle``.

    Args:
        device: Device string forwarded to each inference function.
    """
    # (subdirectory name, config class, inference function), in run order.
    methods = (
        ("pretrained", PretrainedInferenceConfig, infer_pretrained),
        ("smc_grad", SMCGradInferenceConfig, infer_smc_grad),
        ("ft", FTInferenceConfig, infer_ft),
    )
    for prompt in examples:
        prompt_dir = os.path.join(EXAMPLES_DIR, short_hash(prompt))
        os.makedirs(prompt_dir, exist_ok=True)
        print(f"Running prompt: {prompt}")
        # Save the prompt text. The explicit encoding keeps non-ASCII
        # characters in prompts independent of the platform default.
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)
        for method, config_cls, infer_fn in methods:
            _run_method(prompt_dir, method, config_cls(prompt=prompt), infer_fn, device)
# Run the full example sweep only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()