harrydawitch committed on
Commit
ea23874
·
verified ·
1 Parent(s): 11ed273

Update backend/inference.py

Browse files
Files changed (1) hide show
  1. backend/inference.py +201 -201
backend/inference.py CHANGED
@@ -1,202 +1,202 @@
1
- import os
2
- from dotenv import load_dotenv
3
-
4
- import torch
5
- import numpy as np
6
- import cv2
7
- from PIL import Image
8
-
9
- from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
10
-
11
class ArchIntelligent:
    """Render building images with SDXL, a ControlNet and three LoRA adapters.

    Model identifiers and the Hugging Face token are read from environment
    variables (a ``.env`` file is honoured through ``load_dotenv``).  Call
    ``process_config`` with the user's settings, then ``generate``.
    """

    # Directory handed to ``from_pretrained`` as the download cache.
    _CACHE_DIR = "huggingface_cache"

    # Weight file and adapter name of the realism-enhancement LoRA.
    _REALISM_WEIGHT_FILE = "realistic.safetensors"
    _REALISM_ADAPTER = "Realism"

    # User-facing style label -> base name of the style LoRA weight file.
    _STYLE_LORAS = {
        "Modern": "Modernism",
        "Minimalism": "Minimalism",
        "Art Deco": "ArtDeco",
        "Art Nouveau": "ArtNouveau",
        "Baroque": "Baroque",
        "Brutalist": "Brutalist",
        "Classical": "Classical",
        "Neo-Classical": "Neo-Classical",
        "Cyberpunk": "Cyberpunk",
        "Deconstructivism": "Deconstructivism",
        "Futurism": "Futurism",
        "Gothic": "Gothic",
        "Neo-Futurism": "Neo-Futurism",
        "Sustainable": "Sustainable",
        "Victorian": "Victorian",
    }

    # User-facing building-type label -> base name of the functional LoRA
    # weight file.
    _FUNCTIONAL_LORAS = {
        "Residential": "Modern",
        "Villa": "Modern",
        "Office": "Office",
        "Skyscraper": "SkyScraper",
        "Hotel": "Hotel",
        "School Campus": "SchoolCampus",
        "Farmhouse": "Farmhouse",
        "Playground": "PlayGround",
        "Park": "Park",
        "Apartment": "Apartment",
        "Hospital": "Hospital",
        "Kindergarten": "KinderGarten",
        "Church": "Church",
        "Container": "Container",
        "Bridge": "Bridge",
        "Resort": "Resort",
        "Airport": "Airport",
        "Factory": "Factory",
        "Stadium": "Stadium",
        "Temple": "Temple",
        "Tree House": "TreeHouse",
    }

    # Fixed negative prompt stored under ``negprompt_2``.  Built from adjacent
    # literals so that source indentation can never leak into the prompt.
    _NEGATIVE_PROMPT_2 = (
        "((blurry)), details are low, overlapping, (grainy), multiple angles, "
        "deformed structures, unnatural, unrealistic, cartoon, "
        "anime, (painting), drawing, sketch, gibberish text, logo, noise, "
        "jpeg artifacts, mutation, (((worst quality))), ((low quality)), "
        "(((low resolution))), "
        "messy, watermark, signature, cut off, low contrast, underexposed, "
        "overexposed, draft, disfigured, ugly, tiling, out of frame"
    )

    def __init__(self):
        """Load the ControlNet and the SDXL pipeline named by the environment.

        Reads ``HF_TOKEN``, ``STYLE_MODELS``, ``FUNCTION_MODELS``,
        ``REALISM_ENHANCE``, ``CONTROLNET`` and ``BASEMODEL``.

        Raises:
            ValueError: if ``CONTROLNET`` or ``BASEMODEL`` is not set.
        """
        # Get private variables from environment
        load_dotenv()

        self.hf_token = os.getenv("HF_TOKEN")
        self.style_models = os.getenv("STYLE_MODELS")
        self.functional_models = os.getenv("FUNCTION_MODELS")
        self.enhancement = os.getenv("REALISM_ENHANCE")
        self.controlnet_model = os.getenv("CONTROLNET")
        self.base_model = os.getenv("BASEMODEL")

        self.model_config = {}

        # Fail early with a readable message instead of an obscure loader
        # error when a model identifier is missing.
        required = (
            ("CONTROLNET", self.controlnet_model),
            ("BASEMODEL", self.base_model),
        )
        missing = [name for name, value in required if not value]
        if missing:
            raise ValueError(
                "Missing required environment variable(s): " + ", ".join(missing)
            )

        # Configure ControlNet model
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_model,
            torch_dtype=torch.float16,
            cache_dir=self._CACHE_DIR,
            token=self.hf_token,
            variant='fp16',
        )

        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.base_model,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            cache_dir=self._CACHE_DIR,
            token=self.hf_token,
            variant='fp16',
        )

        # Enable memory-efficient optimizations.  Each one is attempted on its
        # own so that one failure (e.g. xFormers not installed) does not skip
        # the others.  Failures stay non-fatal: these are best-effort.
        optimizations = (
            ("xFormers", self.pipeline.enable_xformers_memory_efficient_attention),
            ("VAE Slicing mode", self.pipeline.enable_vae_slicing),
            ("Sequential CPU Offload", self.pipeline.enable_sequential_cpu_offload),
        )
        for label, enable in optimizations:
            try:
                enable()
            except Exception as e:
                print(f"Warning: {label} could not be enabled: {e}")
            else:
                print(f"{label} enabled")

    def img2canny(self, input_img):
        """
        Processing user's condition image into edge map

        Parameters
            input_img : PIL image

        Returns
            PIL image (1024 x 1024 Canny edge map)
        """
        # COLOR_RGB2GRAY needs a colour array; normalise grayscale / palette
        # PIL images first so they do not make cvtColor fail.
        if isinstance(input_img, Image.Image):
            input_img = input_img.convert("RGB")
        np_image = np.array(input_img)

        # Convert the image into a grayscale image then extract edge map
        gray = cv2.cvtColor(np_image, cv2.COLOR_RGB2GRAY)
        gray = cv2.resize(gray, (1024, 1024))
        edges = cv2.Canny(gray, 100, 200)

        return Image.fromarray(edges)

    def process_config(self, config: dict):
        """Derive prompts and LoRA selection from the user's settings.

        Reads ``style_names``, ``functional_names``, ``season``,
        ``landscape``, ``weather`` and ``time_of_day`` from ``config``, then
        adds ``posprompt_2``, ``negprompt_2``, ``LoRA_style``,
        ``LoRA_functional`` and ``adapter_weights`` to that same dict and
        stores it as ``self.model_config``.

        Raises:
            KeyError: if an input key is missing, or if the style or
                functional label is not a known option.  ``config`` is left
                unmodified in that case.
        """
        styles = config['style_names']
        functional = config['functional_names']

        season = config['season']
        landscape = config['landscape']
        weather = config['weather']
        day = config['time_of_day']

        # Resolve both LoRA names before touching ``config`` so an unknown
        # label cannot leave a half-updated configuration behind.
        try:
            style_lora = self._STYLE_LORAS[styles]
        except KeyError:
            raise KeyError(
                f"Unknown style {styles!r}; expected one of "
                f"{sorted(self._STYLE_LORAS)}"
            ) from None
        try:
            functional_lora = self._FUNCTIONAL_LORAS[functional]
        except KeyError:
            raise KeyError(
                f"Unknown functional type {functional!r}; expected one of "
                f"{sorted(self._FUNCTIONAL_LORAS)}"
            ) from None

        config['posprompt_2'] = (
            f"(((realistic))), (({styles})), (({functional})), ({landscape}), "
            f"({season}), ({weather}), ({day}), (high quality), "
            "(high resolution), 4k render, detail, beautiful, "
            "cinematic lighting, hyper-realistic"
        )
        config['negprompt_2'] = self._NEGATIVE_PROMPT_2

        config["LoRA_style"] = style_lora
        config["LoRA_functional"] = functional_lora
        # Order matches the adapters in ``generate``: style, functional,
        # realism enhancement.
        config['adapter_weights'] = [1.0, 1.0, 0.8]

        self.model_config = config

    def generate(self):
        """
        Generate building image using user's input arguments

        Reads prompts, input image, scales, step count and LoRA selection from
        ``self.model_config``, reloads the three LoRA adapters and runs the
        pipeline on the Canny edge map of the input image.

        Returns
            The first image produced by the pipeline.

        Raises
            KeyError if a required key is missing from ``self.model_config``
            (for example when ``process_config`` has not been called).
        """
        cfg = self.model_config

        # Read every setting up front so a missing key fails before the
        # pipeline's LoRA state is modified.
        first_prompt = cfg["posprompt_1"]
        second_prompt = cfg["posprompt_2"]
        first_negprompt = cfg["negprompt_1"]
        second_negprompt = cfg["negprompt_2"]

        input_image = cfg['image']
        controlnet_condition = cfg["condition_scale"]
        guidance_scale = cfg["guidance"]
        render_speed = cfg["render_speed"]

        # Get LoRA weight's name and their corresponding adapter weights
        LoRA_style_names = cfg['LoRA_style']
        LoRA_functional_names = cfg['LoRA_functional']
        LoRA_enhancement_names = self._REALISM_ADAPTER
        adapter_weights = cfg['adapter_weights']

        LoRA_names = [LoRA_style_names, LoRA_functional_names, LoRA_enhancement_names]

        # (repository, weight file, adapter name) for each LoRA, in the same
        # order as ``LoRA_names`` / ``adapter_weights``.
        lora_sources = (
            (self.style_models, f"{LoRA_style_names}.safetensors", LoRA_style_names),
            (self.functional_models, f"{LoRA_functional_names}.safetensors", LoRA_functional_names),
            (self.enhancement, self._REALISM_WEIGHT_FILE, LoRA_enhancement_names),
        )

        self.pipeline.unload_lora_weights()
        print("\n\nUNLOADED LORA WEIGHTS\n\n")

        os.environ['HF_HOME'] = self._CACHE_DIR
        for repo, weight_file, adapter in lora_sources:
            self.pipeline.load_lora_weights(
                repo,
                weight_name=weight_file,
                adapter_name=adapter,
            )

        print(f"Finished loadded 3 LoRA weights {LoRA_style_names}, {LoRA_functional_names} and {LoRA_enhancement_names}")

        self.pipeline.set_adapters(adapter_names=LoRA_names, adapter_weights=adapter_weights)
        print("Adapted 3 lora weights")

        # Transform the image into a Canny edge map used as the ControlNet
        # conditioning image
        conditional_image = self.img2canny(input_image)

        # Setup the pipeline then generate image
        result = self.pipeline(
            prompt=first_prompt,
            prompt_2=second_prompt,
            negative_prompt=first_negprompt,
            negative_prompt_2=second_negprompt,
            image=conditional_image,
            controlnet_conditioning_scale=controlnet_condition,
            num_inference_steps=render_speed,
            guidance_scale=guidance_scale,
        )

        return result.images[0]
197
-
198
-
199
if __name__ == "__main__":
    # Manual smoke test: constructing the class loads the ControlNet and the
    # SDXL pipeline, so reaching "Finished" means model loading succeeded.
    print("Loading")
    pipe = ArchIntelligent()
    print("Finished")
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+
9
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
10
+
11
class ArchIntelligent:
    """Render building images with SDXL, a ControlNet and three LoRA adapters.

    Model identifiers and the Hugging Face token are read from environment
    variables (a ``.env`` file is honoured through ``load_dotenv``).  Call
    ``process_config`` with the user's settings, then ``generate``.
    """

    # Directory handed to ``from_pretrained`` as the download cache.
    _CACHE_DIR = "huggingface_cache"

    # Weight file and adapter name of the realism-enhancement LoRA.
    _REALISM_WEIGHT_FILE = "realarchvis_xlV10.safetensors"
    _REALISM_ADAPTER = "Realism"

    # User-facing style label -> base name of the style LoRA weight file.
    _STYLE_LORAS = {
        "Modern": "Modernism",
        "Minimalism": "Minimalism",
        "Art Deco": "ArtDeco",
        "Art Nouveau": "ArtNouveau",
        "Baroque": "Baroque",
        "Brutalist": "Brutalist",
        "Classical": "Classical",
        "Neo-Classical": "Neo-Classical",
        "Cyberpunk": "Cyberpunk",
        "Deconstructivism": "Deconstructivism",
        "Futurism": "Futurism",
        "Gothic": "Gothic",
        "Neo-Futurism": "Neo-Futurism",
        "Sustainable": "Sustainable",
        "Victorian": "Victorian",
    }

    # User-facing building-type label -> base name of the functional LoRA
    # weight file.
    _FUNCTIONAL_LORAS = {
        "Residential": "Modern",
        "Villa": "Modern",
        "Office": "Office",
        "Skyscraper": "SkyScraper",
        "Hotel": "Hotel",
        "School Campus": "SchoolCampus",
        "Farmhouse": "Farmhouse",
        "Playground": "PlayGround",
        "Park": "Park",
        "Apartment": "Apartment",
        "Hospital": "Hospital",
        "Kindergarten": "KinderGarten",
        "Church": "Church",
        "Container": "Container",
        "Bridge": "Bridge",
        "Resort": "Resort",
        "Airport": "Airport",
        "Factory": "Factory",
        "Stadium": "Stadium",
        "Temple": "Temple",
        "Tree House": "TreeHouse",
    }

    # Fixed negative prompt stored under ``negprompt_2``.  Built from adjacent
    # literals so that source indentation can never leak into the prompt.
    _NEGATIVE_PROMPT_2 = (
        "((blurry)), details are low, overlapping, (grainy), multiple angles, "
        "deformed structures, unnatural, unrealistic, cartoon, "
        "anime, (painting), drawing, sketch, gibberish text, logo, noise, "
        "jpeg artifacts, mutation, (((worst quality))), ((low quality)), "
        "(((low resolution))), "
        "messy, watermark, signature, cut off, low contrast, underexposed, "
        "overexposed, draft, disfigured, ugly, tiling, out of frame"
    )

    def __init__(self):
        """Load the ControlNet and the SDXL pipeline named by the environment.

        Reads ``HF_TOKEN``, ``STYLE_MODELS``, ``FUNCTION_MODELS``,
        ``REALISM_ENHANCE``, ``CONTROLNET`` and ``BASEMODEL``.

        Raises:
            ValueError: if ``CONTROLNET`` or ``BASEMODEL`` is not set.
        """
        # Get private variables from environment
        load_dotenv()

        self.hf_token = os.getenv("HF_TOKEN")
        self.style_models = os.getenv("STYLE_MODELS")
        self.functional_models = os.getenv("FUNCTION_MODELS")
        self.enhancement = os.getenv("REALISM_ENHANCE")
        self.controlnet_model = os.getenv("CONTROLNET")
        self.base_model = os.getenv("BASEMODEL")

        self.model_config = {}

        # Fail early with a readable message instead of an obscure loader
        # error when a model identifier is missing.
        required = (
            ("CONTROLNET", self.controlnet_model),
            ("BASEMODEL", self.base_model),
        )
        missing = [name for name, value in required if not value]
        if missing:
            raise ValueError(
                "Missing required environment variable(s): " + ", ".join(missing)
            )

        # Configure ControlNet model
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_model,
            torch_dtype=torch.float16,
            cache_dir=self._CACHE_DIR,
            token=self.hf_token,
            variant='fp16',
        )

        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.base_model,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            cache_dir=self._CACHE_DIR,
            token=self.hf_token,
            variant='fp16',
        )

        # Enable memory-efficient optimizations.  Each one is attempted on its
        # own so that one failure (e.g. xFormers not installed) does not skip
        # the others.  Failures stay non-fatal: these are best-effort.
        optimizations = (
            ("xFormers", self.pipeline.enable_xformers_memory_efficient_attention),
            ("VAE Slicing mode", self.pipeline.enable_vae_slicing),
            ("Sequential CPU Offload", self.pipeline.enable_sequential_cpu_offload),
        )
        for label, enable in optimizations:
            try:
                enable()
            except Exception as e:
                print(f"Warning: {label} could not be enabled: {e}")
            else:
                print(f"{label} enabled")

    def img2canny(self, input_img):
        """
        Processing user's condition image into edge map

        Parameters
            input_img : PIL image

        Returns
            PIL image (1024 x 1024 Canny edge map)
        """
        # COLOR_RGB2GRAY needs a colour array; normalise grayscale / palette
        # PIL images first so they do not make cvtColor fail.
        if isinstance(input_img, Image.Image):
            input_img = input_img.convert("RGB")
        np_image = np.array(input_img)

        # Convert the image into a grayscale image then extract edge map
        gray = cv2.cvtColor(np_image, cv2.COLOR_RGB2GRAY)
        gray = cv2.resize(gray, (1024, 1024))
        edges = cv2.Canny(gray, 100, 200)

        return Image.fromarray(edges)

    def process_config(self, config: dict):
        """Derive prompts and LoRA selection from the user's settings.

        Reads ``style_names``, ``functional_names``, ``season``,
        ``landscape``, ``weather`` and ``time_of_day`` from ``config``, then
        adds ``posprompt_2``, ``negprompt_2``, ``LoRA_style``,
        ``LoRA_functional`` and ``adapter_weights`` to that same dict and
        stores it as ``self.model_config``.

        Raises:
            KeyError: if an input key is missing, or if the style or
                functional label is not a known option.  ``config`` is left
                unmodified in that case.
        """
        styles = config['style_names']
        functional = config['functional_names']

        season = config['season']
        landscape = config['landscape']
        weather = config['weather']
        day = config['time_of_day']

        # Resolve both LoRA names before touching ``config`` so an unknown
        # label cannot leave a half-updated configuration behind.
        try:
            style_lora = self._STYLE_LORAS[styles]
        except KeyError:
            raise KeyError(
                f"Unknown style {styles!r}; expected one of "
                f"{sorted(self._STYLE_LORAS)}"
            ) from None
        try:
            functional_lora = self._FUNCTIONAL_LORAS[functional]
        except KeyError:
            raise KeyError(
                f"Unknown functional type {functional!r}; expected one of "
                f"{sorted(self._FUNCTIONAL_LORAS)}"
            ) from None

        config['posprompt_2'] = (
            f"(((realistic))), (({styles})), (({functional})), ({landscape}), "
            f"({season}), ({weather}), ({day}), (high quality), "
            "(high resolution), 4k render, detail, beautiful, "
            "cinematic lighting, hyper-realistic"
        )
        config['negprompt_2'] = self._NEGATIVE_PROMPT_2

        config["LoRA_style"] = style_lora
        config["LoRA_functional"] = functional_lora
        # Order matches the adapters in ``generate``: style, functional,
        # realism enhancement.
        config['adapter_weights'] = [1.0, 1.0, 0.8]

        self.model_config = config

    def generate(self):
        """
        Generate building image using user's input arguments

        Reads prompts, input image, scales, step count and LoRA selection from
        ``self.model_config``, reloads the three LoRA adapters and runs the
        pipeline on the Canny edge map of the input image.

        Returns
            The first image produced by the pipeline.

        Raises
            KeyError if a required key is missing from ``self.model_config``
            (for example when ``process_config`` has not been called).
        """
        cfg = self.model_config

        # Read every setting up front so a missing key fails before the
        # pipeline's LoRA state is modified.
        first_prompt = cfg["posprompt_1"]
        second_prompt = cfg["posprompt_2"]
        first_negprompt = cfg["negprompt_1"]
        second_negprompt = cfg["negprompt_2"]

        input_image = cfg['image']
        controlnet_condition = cfg["condition_scale"]
        guidance_scale = cfg["guidance"]
        render_speed = cfg["render_speed"]

        # Get LoRA weight's name and their corresponding adapter weights
        LoRA_style_names = cfg['LoRA_style']
        LoRA_functional_names = cfg['LoRA_functional']
        LoRA_enhancement_names = self._REALISM_ADAPTER
        adapter_weights = cfg['adapter_weights']

        LoRA_names = [LoRA_style_names, LoRA_functional_names, LoRA_enhancement_names]

        # (repository, weight file, adapter name) for each LoRA, in the same
        # order as ``LoRA_names`` / ``adapter_weights``.
        lora_sources = (
            (self.style_models, f"{LoRA_style_names}.safetensors", LoRA_style_names),
            (self.functional_models, f"{LoRA_functional_names}.safetensors", LoRA_functional_names),
            (self.enhancement, self._REALISM_WEIGHT_FILE, LoRA_enhancement_names),
        )

        self.pipeline.unload_lora_weights()
        print("\n\nUNLOADED LORA WEIGHTS\n\n")

        os.environ['HF_HOME'] = self._CACHE_DIR
        for repo, weight_file, adapter in lora_sources:
            self.pipeline.load_lora_weights(
                repo,
                weight_name=weight_file,
                adapter_name=adapter,
            )

        print(f"Finished loadded 3 LoRA weights {LoRA_style_names}, {LoRA_functional_names} and {LoRA_enhancement_names}")

        self.pipeline.set_adapters(adapter_names=LoRA_names, adapter_weights=adapter_weights)
        print("Adapted 3 lora weights")

        # Transform the image into a Canny edge map used as the ControlNet
        # conditioning image
        conditional_image = self.img2canny(input_image)

        # Setup the pipeline then generate image
        result = self.pipeline(
            prompt=first_prompt,
            prompt_2=second_prompt,
            negative_prompt=first_negprompt,
            negative_prompt_2=second_negprompt,
            image=conditional_image,
            controlnet_conditioning_scale=controlnet_condition,
            num_inference_steps=render_speed,
            guidance_scale=guidance_scale,
        )

        return result.images[0]
197
+
198
+
199
if __name__ == "__main__":
    # Manual smoke test: constructing the class loads the ControlNet and the
    # SDXL pipeline, so reaching "Finished" means model loading succeeded.
    print("Loading")
    pipe = ArchIntelligent()
    print("Finished")