osmr committed on
Commit
91a06e4
·
1 Parent(s): b06f841

Add controlnet support

Browse files
Files changed (2) hide show
  1. .gitignore +186 -0
  2. app.py +127 -41
.gitignore ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ yolov8n.pt
2
+ lora_ghibli/
3
+
4
+ # PyCharm ###
5
+ .idea/
6
+
7
+ # Visual Studio ###
8
+ Release/
9
+ Debug/
10
+ .vs/
11
+ *.VC.db
12
+ *.sdf
13
+ *.suo
14
+ *.opendb
15
+ *.psess
16
+ *.vsp
17
+ *.vspx
18
+ *.sln
19
+ *.pyproj
20
+ x64
21
+
22
+ # R ###
23
+ .Rhistory
24
+
25
+ # Byte-compiled / optimized / DLL files
26
+ __pycache__/
27
+ *.py[cod]
28
+ *$py.class
29
+
30
+ # C extensions
31
+ *.so
32
+
33
+ # Distribution / packaging
34
+ .Python
35
+ build/
36
+ develop-eggs/
37
+ dist/
38
+ downloads/
39
+ eggs/
40
+ .eggs/
41
+ lib/
42
+ lib64/
43
+ parts/
44
+ sdist/
45
+ var/
46
+ wheels/
47
+ share/python-wheels/
48
+ *.egg-info/
49
+ .installed.cfg
50
+ *.egg
51
+ MANIFEST
52
+
53
+ # PyInstaller
54
+ # Usually these files are written by a python script from a template
55
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
56
+ *.manifest
57
+ *.spec
58
+
59
+ # Installer logs
60
+ pip-log.txt
61
+ pip-delete-this-directory.txt
62
+
63
+ # Unit test / coverage reports
64
+ htmlcov/
65
+ .tox/
66
+ .nox/
67
+ .coverage
68
+ .coverage.*
69
+ .cache
70
+ nosetests.xml
71
+ coverage.xml
72
+ *.cover
73
+ *.py,cover
74
+ .hypothesis/
75
+ .pytest_cache/
76
+ cover/
77
+
78
+ # Translations
79
+ *.mo
80
+ *.pot
81
+
82
+ # Django stuff:
83
+ *.log
84
+ local_settings.py
85
+ db.sqlite3
86
+ db.sqlite3-journal
87
+
88
+ # Flask stuff:
89
+ instance/
90
+ .webassets-cache
91
+
92
+ # Scrapy stuff:
93
+ .scrapy
94
+
95
+ # Sphinx documentation
96
+ docs/_build/
97
+
98
+ # PyBuilder
99
+ .pybuilder/
100
+ target/
101
+
102
+ # Jupyter Notebook
103
+ .ipynb_checkpoints
104
+
105
+ # IPython
106
+ profile_default/
107
+ ipython_config.py
108
+
109
+ # pyenv
110
+ # For a library or package, you might want to ignore these files since the code is
111
+ # intended to run in multiple environments; otherwise, check them in:
112
+ # .python-version
113
+
114
+ # pipenv
115
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
116
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
117
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
118
+ # install all needed dependencies.
119
+ #Pipfile.lock
120
+
121
+ # poetry
122
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
123
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
124
+ # commonly ignored for libraries.
125
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
126
+ #poetry.lock
127
+
128
+ # pdm
129
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
130
+ #pdm.lock
131
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
132
+ # in version control.
133
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
134
+ .pdm.toml
135
+ .pdm-python
136
+ .pdm-build/
137
+
138
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
139
+ __pypackages__/
140
+
141
+ # Celery stuff
142
+ celerybeat-schedule
143
+ celerybeat.pid
144
+
145
+ # SageMath parsed files
146
+ *.sage.py
147
+
148
+ # Environments
149
+ .env
150
+ .venv
151
+ env/
152
+ venv/
153
+ ENV/
154
+ env.bak/
155
+ venv.bak/
156
+
157
+ # Spyder project settings
158
+ .spyderproject
159
+ .spyproject
160
+
161
+ # Rope project settings
162
+ .ropeproject
163
+
164
+ # mkdocs documentation
165
+ /site
166
+
167
+ # mypy
168
+ .mypy_cache/
169
+ .dmypy.json
170
+ dmypy.json
171
+
172
+ # Pyre type checker
173
+ .pyre/
174
+
175
+ # pytype static type analyzer
176
+ .pytype/
177
+
178
+ # Cython debug symbols
179
+ cython_debug/
180
+
181
+ # PyCharm
182
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
183
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
184
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
185
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
186
+ #.idea/
app.py CHANGED
@@ -4,7 +4,7 @@ import random
4
  from typing import Optional
5
 
6
  # import spaces #[uncomment to use ZeroGPU]
7
- from diffusers import StableDiffusionPipeline
8
  import torch
9
 
10
 
@@ -23,6 +23,8 @@ DEFAULT_HEIGHT = 512
23
  DEFAULT_GS = 7.5
24
  DEFAULT_LS = 1.0
25
  DEFAULT_NUM_INF_STEPS = 50
 
 
26
 
27
 
28
  # @spaces.GPU #[uncomment to use ZeroGPU]
@@ -36,58 +38,100 @@ def infer(lora_model_id: Optional[str] = "osmr/stable-diffusion-v1-4-lora-iv-ghi
36
  guidance_scale: Optional[float] = DEFAULT_GS,
37
  lora_scale: Optional[float] = DEFAULT_LS,
38
  num_inference_steps: Optional[int] = DEFAULT_NUM_INF_STEPS,
 
 
 
 
 
 
 
39
  progress = gr.Progress(track_tqdm=True)):
40
- if lora_model_id == "osmr/stable-diffusion-v1-4-lora-iv-ghibli":
41
- model_id = "CompVis/stable-diffusion-v1-4"
42
- elif lora_model_id == "osmr/stable-diffusion-v1-4-lora-db-ghibli":
43
- model_id = "CompVis/stable-diffusion-v1-4"
44
- elif lora_model_id == "osmr/stable-diffusion-v1-5-lora-iv-ghibli":
45
- model_id = "runwayml/stable-diffusion-v1-5"
46
- elif lora_model_id == "osmr/stable-diffusion-v1-5-lora-db-ghibli":
47
- model_id = "runwayml/stable-diffusion-v1-5"
48
- elif lora_model_id == "CompVis/stable-diffusion-v1-4":
49
- model_id = "CompVis/stable-diffusion-v1-4"
50
- lora_model_id = None
51
- elif lora_model_id == "runwayml/stable-diffusion-v1-5":
52
- model_id = "runwayml/stable-diffusion-v1-5"
53
  lora_model_id = None
54
  else:
55
- model_id = "CompVis/stable-diffusion-v1-4"
56
- lora_model_id = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  if randomize_seed:
59
  seed = random.randint(0, MAX_SEED)
60
 
61
  generator = torch.Generator().manual_seed(seed)
62
 
63
- pipe = StableDiffusionPipeline.from_pretrained(
64
- pretrained_model_name_or_path=model_id,
65
- torch_dtype=torch_dtype)
66
- if lora_model_id:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  pipe.load_lora_weights(lora_model_id)
 
 
 
 
68
  pipe = pipe.to(device)
69
 
70
- if lora_model_id:
71
- image = pipe(
72
- prompt=prompt,
73
- negative_prompt=negative_prompt,
74
- guidance_scale=guidance_scale,
75
- num_inference_steps=num_inference_steps,
76
- width=width,
77
- height=height,
78
- generator=generator,
79
- cross_attention_kwargs={"scale": lora_scale}
80
- ).images[0]
81
- else:
82
- image = pipe(
83
- prompt=prompt,
84
- negative_prompt=negative_prompt,
85
- guidance_scale=guidance_scale,
86
- num_inference_steps=num_inference_steps,
87
- width=width,
88
- height=height,
89
- generator=generator,
90
- ).images[0]
91
 
92
  return image, seed
93
 
@@ -96,6 +140,7 @@ examples = [
96
  "GBL, a man and a woman sitting at a table with glasses of wine in front of them",
97
  "a man and a woman sitting at a table with glasses of wine in front of them",
98
  "GBL, a man sitting at a desk in a library with a book open in front of him",
 
99
  ]
100
 
101
  css = """
@@ -199,6 +244,42 @@ with gr.Blocks(css=css) as demo:
199
  value=DEFAULT_LS,
200
  )
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  gr.Examples(examples=examples, inputs=[prompt])
203
  gr.on(
204
  triggers=[run_button.click, prompt.submit],
@@ -214,6 +295,11 @@ with gr.Blocks(css=css) as demo:
214
  guidance_scale,
215
  lora_scale,
216
  num_inference_steps,
 
 
 
 
 
217
  ],
218
  outputs=[result, seed],
219
  )
 
4
  from typing import Optional
5
 
6
  # import spaces #[uncomment to use ZeroGPU]
7
+ from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
8
  import torch
9
 
10
 
 
23
  DEFAULT_GS = 7.5
24
  DEFAULT_LS = 1.0
25
  DEFAULT_NUM_INF_STEPS = 50
26
+ DEFAULT_CN_COND_SCALE = 1.0
27
+ DEFAULT_IPA_SCALE = 0.5
28
 
29
 
30
  # @spaces.GPU #[uncomment to use ZeroGPU]
 
38
  guidance_scale: Optional[float] = DEFAULT_GS,
39
  lora_scale: Optional[float] = DEFAULT_LS,
40
  num_inference_steps: Optional[int] = DEFAULT_NUM_INF_STEPS,
41
+
42
+ controlnet_type: str = "Edge-Detection",
43
+ controlnet_cond_scale: float = DEFAULT_CN_COND_SCALE,
44
+ controlnet_image: object = None,
45
+ ipadapter_scale: float = DEFAULT_IPA_SCALE,
46
+ ipadapter_image: object = None,
47
+
48
  progress = gr.Progress(track_tqdm=True)):
49
+
50
+ use_lora = (lora_model_id in [
51
+ "osmr/stable-diffusion-v1-4-lora-iv-ghibli",
52
+ "osmr/stable-diffusion-v1-4-lora-db-ghibli",
53
+ "osmr/stable-diffusion-v1-5-lora-iv-ghibli",
54
+ "osmr/stable-diffusion-v1-5-lora-db-ghibli",
55
+ ])
56
+ if not use_lora:
57
+ model_id = lora_model_id
 
 
 
 
58
  lora_model_id = None
59
  else:
60
+ if lora_model_id == "osmr/stable-diffusion-v1-4-lora-iv-ghibli":
61
+ model_id = "CompVis/stable-diffusion-v1-4"
62
+ elif lora_model_id == "osmr/stable-diffusion-v1-4-lora-db-ghibli":
63
+ model_id = "CompVis/stable-diffusion-v1-4"
64
+ elif lora_model_id == "osmr/stable-diffusion-v1-5-lora-iv-ghibli":
65
+ model_id = "runwayml/stable-diffusion-v1-5"
66
+ elif lora_model_id == "osmr/stable-diffusion-v1-5-lora-db-ghibli":
67
+ model_id = "runwayml/stable-diffusion-v1-5"
68
+ else:
69
+ model_id = lora_model_id
70
+ lora_model_id = None
71
+
72
+ sd_version = "1.5" if (model_id == "runwayml/stable-diffusion-v1-5") else "1.4"
73
+
74
+ use_controlnet = (controlnet_image is not None)
75
+ use_ipadapter = (ipadapter_image is not None)
76
 
77
  if randomize_seed:
78
  seed = random.randint(0, MAX_SEED)
79
 
80
  generator = torch.Generator().manual_seed(seed)
81
 
82
+ if use_controlnet:
83
+ if sd_version == "1.4":
84
+ if controlnet_type == "Edge-Detection":
85
+ controlnet_id = "lllyasviel/sd-controlnet-canny"
86
+ else:
87
+ controlnet_id = "lllyasviel/sd-controlnet-openpose"
88
+ else:
89
+ if controlnet_type == "Edge-Detection":
90
+ controlnet_id = "lllyasviel/control_v11p_sd15_canny"
91
+ else:
92
+ controlnet_id = "lllyasviel/control_v11p_sd15_openpose"
93
+
94
+ controlnet = ControlNetModel.from_pretrained(
95
+ pretrained_model_name_or_path=controlnet_id,
96
+ torch_dtype=torch_dtype)
97
+
98
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
99
+ pretrained_model_name_or_path=model_id,
100
+ controlnet=controlnet,
101
+ torch_dtype=torch_dtype)
102
+ else:
103
+ pipe = StableDiffusionPipeline.from_pretrained(
104
+ pretrained_model_name_or_path=model_id,
105
+ torch_dtype=torch_dtype)
106
+
107
+ if use_ipadapter:
108
+ pipe.load_ip_adapter(
109
+ "h94/IP-Adapter",
110
+ subfolder="models",
111
+ weight_name="ip-adapter_sd15.bin")
112
+ pipe.set_ip_adapter_scale(ipadapter_scale)
113
+
114
+ if use_lora:
115
  pipe.load_lora_weights(lora_model_id)
116
+ cross_attention_kwargs = {"scale": lora_scale}
117
+ else:
118
+ cross_attention_kwargs = None
119
+
120
  pipe = pipe.to(device)
121
 
122
+ image = pipe(
123
+ prompt=prompt,
124
+ negative_prompt=negative_prompt,
125
+ guidance_scale=guidance_scale,
126
+ num_inference_steps=num_inference_steps,
127
+ width=width,
128
+ height=height,
129
+ generator=generator,
130
+ cross_attention_kwargs=cross_attention_kwargs,
131
+ image=controlnet_image,
132
+ controlnet_conditioning_scale=(float(controlnet_cond_scale) if use_controlnet else None),
133
+ ip_adapter_image=ipadapter_image
134
+ ).images[0]
 
 
 
 
 
 
 
 
135
 
136
  return image, seed
137
 
 
140
  "GBL, a man and a woman sitting at a table with glasses of wine in front of them",
141
  "a man and a woman sitting at a table with glasses of wine in front of them",
142
  "GBL, a man sitting at a desk in a library with a book open in front of him",
143
+ "GBL, a cartoon woman is standing in front of a wall",
144
  ]
145
 
146
  css = """
 
244
  value=DEFAULT_LS,
245
  )
246
 
247
+ with gr.Accordion("ControlNet Settings", open=False):
248
+ controlnet_type = gr.Dropdown(
249
+ choices=[
250
+ "Edge-Detection",
251
+ "Pose-Estimation"],
252
+ interactive=True,
253
+ label="ControlNet Type",
254
+ )
255
+
256
+ controlnet_cond_scale = gr.Slider(
257
+ label="ControlNet Conditioning Scale",
258
+ minimum=0.0,
259
+ maximum=2.0,
260
+ step=0.1,
261
+ value=DEFAULT_CN_COND_SCALE
262
+ )
263
+
264
+ controlnet_image = gr.Image(
265
+ label="Control Image",
266
+ type="pil",
267
+ show_label=True)
268
+
269
+ with gr.Accordion("IP-adapter Settings", open=False):
270
+ ipadapter_scale = gr.Slider(
271
+ label="IP-adapter Scale",
272
+ minimum=0.0,
273
+ maximum=1.0,
274
+ step=0.1,
275
+ value=DEFAULT_IPA_SCALE
276
+ )
277
+
278
+ ipadapter_image = gr.Image(
279
+ label="IP-adapter Image",
280
+ type="pil",
281
+ show_label=True)
282
+
283
  gr.Examples(examples=examples, inputs=[prompt])
284
  gr.on(
285
  triggers=[run_button.click, prompt.submit],
 
295
  guidance_scale,
296
  lora_scale,
297
  num_inference_steps,
298
+ controlnet_type,
299
+ controlnet_cond_scale,
300
+ controlnet_image,
301
+ ipadapter_scale,
302
+ ipadapter_image,
303
  ],
304
  outputs=[result, seed],
305
  )