thisiswooyeol committed on
Commit
4201802
·
verified ·
1 Parent(s): 6a42d2d

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +86 -0
README.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: creativeml-openrail-m
3
+ language:
4
+ - en
5
+ base_model:
6
+ - CompVis/stable-diffusion-v1-4
7
+ - limuloo1999/MIGC
8
+ pipeline_tag: text-to-image
9
+ ---
10
+ # About file
11
+
12
+ <!-- Provide a quick summary of what the model is/does. -->
13
+
14
+ Diffusers version of MIGC adapter state dict. The actual values are identical to the original checkpoint file [MIGC_SD14.ckpt](https://huggingface.co/limuloo1999/MIGC).
15
+ Please see the details of MIGC in the [MIGC repository](https://github.com/limuloo/MIGC).
16
+
17
+
18
+ # How to use
19
+
20
+ Please use the modified pipeline class in the `pipeline_stable_diffusion_migc.py` file.
21
+
22
+ ```python
23
+ import random
24
+
25
+ import numpy as np
26
+ import safetensors.torch
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ from pipeline_stable_diffusion_migc import StableDiffusionMIGCPipeline
31
+
32
+
33
+ DEVICE="cuda"
34
+ SEED=42
35
+
36
+ pipe = StableDiffusionMIGCPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(DEVICE)
37
+ adapter_path = hf_hub_download(repo_id="thisiswooyeol/MIGC-diffusers", filename="migc_adapter_weights.safetensors")
38
+
39
+ # Load MIGC adapter to UNet attn2 layers
40
+ state_dict = safetensors.torch.load_file(adapter_path)
41
+ for name, module in pipe.unet.named_modules():
42
+ if hasattr(module, "migc"):
43
+ print(f"Found MIGC in {name}")
44
+
45
+ # Get the state dict with the incorrect keys
46
+ state_dict_to_load = {k: v for k, v in state_dict.items() if k.startswith(name)}
47
+
48
+ # Create a new state dict, removing the "attn2." prefix from each key
49
+ new_state_dict = {k.replace(f"{name}.migc.", "", 1): v for k, v in state_dict_to_load.items()}
50
+
51
+ # Load the corrected state dict
52
+ module.migc.load_state_dict(new_state_dict)
53
+ module.to(device=pipe.unet.device, dtype=pipe.unet.dtype)
54
+
55
+
56
+ # Sample inference !
57
+ prompt = "bestquality, detailed, 8k.a photo of a black potted plant and a yellow refrigerator and a brown surfboard"
58
+ phrases = [
59
+ "a black potted plant",
60
+ "a brown surfboard",
61
+ "a yellow refrigerator",
62
+ ]
63
+ bboxes = [
64
+ [0.5717187499999999, 0.0, 0.8179531250000001, 0.29807511737089204],
65
+ [0.85775, 0.058755868544600943, 0.9991875, 0.646525821596244],
66
+ [0.6041562500000001, 0.284906103286385, 0.799046875, 0.9898591549295774],
67
+ ]
68
+
69
+ def seed_everything(seed):
70
+ random.seed(seed)
71
+ np.random.seed(seed)
72
+ torch.manual_seed(seed)
73
+ torch.cuda.manual_seed_all(seed)
74
+
75
+
76
+ seed_everything(SEED)
77
+
78
+ image = pipe(
79
+ prompt=prompt,
80
+ phrases=phrases,
81
+ bboxes=bboxes,
82
+ negative_prompt="worst quality, low quality, bad anatomy",
83
+ generator=torch.Generator(DEVICE).manual_seed(SEED),
84
+ ).images[0]
85
+ image.save("image.png")
86
+ ```