coralLight committed on
Commit
508e6a7
·
1 Parent(s): ddc7f94

add inference

Browse files
Files changed (1) hide show
  1. app.py +110 -117
app.py CHANGED
@@ -34,58 +34,112 @@ from customed_unipc_scheduler import CustomedUniPCMultistepScheduler
34
 
35
  precision_scope = autocast
36
 
37
- def chunk(it, size):
38
- it = iter(it)
39
- return iter(lambda: tuple(islice(it, size)), ())
40
-
41
-
42
- def numpy_to_pil(images):
43
- """
44
- Convert a numpy image or a batch of images to a PIL image.
45
- """
46
- if images.ndim == 3:
47
- images = images[None, ...]
48
- images = (images * 255).round().astype("uint8")
49
- pil_images = [Image.fromarray(image) for image in images]
50
-
51
- return pil_images
52
-
53
-
54
- def load_replacement(x):
55
- try:
56
- hwc = x.shape
57
- y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
58
- y = (np.array(y) / 255.0).astype(x.dtype)
59
- assert y.shape == x.shape
60
- return y
61
- except Exception:
62
- return x
63
-
64
-
65
- # Adapted from pipelines.StableDiffusionPipeline.encode_prompt
66
- def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
67
- captions = []
68
- for caption in prompt_batch:
69
- if random.random() < proportion_empty_prompts:
70
- captions.append("")
71
- elif isinstance(caption, str):
72
- captions.append(caption)
73
- elif isinstance(caption, (list, np.ndarray)):
74
- # take a random caption if there are multiple
75
- captions.append(random.choice(caption) if is_train else caption[0])
76
-
77
- with torch.no_grad():
78
- text_inputs = tokenizer(
79
- captions,
80
- padding="max_length",
81
- max_length=tokenizer.model_max_length,
82
- truncation=True,
83
- return_tensors="pt",
84
  )
85
- text_input_ids = text_inputs.input_ids
86
- prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- return prompt_embeds
89
 
90
  def chunk(it, size):
91
  it = iter(it)
@@ -95,67 +149,6 @@ def convert_caption_json_to_str(json):
95
  caption = json["caption"]
96
  return caption
97
 
98
- def prepare_sdxl_pipeline_step_parameter(pipe, prompts, need_cfg, device, negative_prompts, W = 1024, H = 1024):
99
- (
100
- prompt_embeds,
101
- negative_prompt_embeds,
102
- pooled_prompt_embeds,
103
- negative_pooled_prompt_embeds,
104
- ) = pipe.encode_prompt(
105
- prompt=prompts,
106
- negative_prompt=negative_prompts,
107
- device=device,
108
- do_classifier_free_guidance=need_cfg,
109
- )
110
- # timesteps = pipe.scheduler.timesteps
111
-
112
- prompt_embeds = prompt_embeds.to(device)
113
- add_text_embeds = pooled_prompt_embeds.to(device)
114
- original_size = (W, H)
115
- crops_coords_top_left = (0, 0)
116
- target_size = (W, H)
117
- text_encoder_projection_dim = None
118
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
119
- if pipe.text_encoder_2 is None:
120
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
121
- else:
122
- text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
123
- passed_add_embed_dim = (
124
- pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
125
- )
126
- expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
127
- if expected_add_embed_dim != passed_add_embed_dim:
128
- raise ValueError(
129
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
130
- )
131
- add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
132
- add_time_ids = add_time_ids.to(device)
133
- negative_add_time_ids = add_time_ids
134
-
135
- if need_cfg:
136
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
137
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
138
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
139
- ret_dict = {
140
- "text_embeds": add_text_embeds,
141
- "time_ids": add_time_ids
142
- }
143
- return prompt_embeds, ret_dict
144
-
145
-
146
- def model_closure(pipe):
147
- def model_fn(x, t, c):
148
- prompt = c[0]
149
- cond_kwargs = c[1] if len(c) > 1 else None
150
- # prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe=pipe,prompts = prompt, need_cfg=True, device=pipe.device,negative_prompts=negative_prompt)
151
- # prompt_embeds, cond_kwargs = c
152
- return pipe.unet(x
153
- , t
154
- , encoder_hidden_states=prompt.to(device=x.device, dtype=x.dtype)
155
- , added_cond_kwargs=cond_kwargs).sample
156
-
157
- return model_fn
158
-
159
 
160
  torch_dtype = torch.float16
161
  repo_id = "madebyollin/sdxl-vae-fp16-fix" # e.g., "distilbert/distilgpt2"
@@ -210,12 +203,12 @@ def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guid
210
  negative_prompts = 1 * [negative_prompts]
211
 
212
  prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
213
- , prompts
214
- , need_cfg=True
215
- , device=pipe.device
216
- , negative_prompt=negative_prompts
217
- , W=width
218
- , H=height)
219
  noise_pred = pipe.unet(latent_model_input
220
  , t
221
  , encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)
 
34
 
35
  precision_scope = autocast
36
 
37
def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep values and reshape them for broadcasting.

    Picks ``a[t]`` along the last axis and returns the result with shape
    ``(batch, 1, 1, ...)`` so it broadcasts against a tensor of shape
    ``x_shape``.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # One singleton dim for every non-batch axis of the target shape.
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
41
+
42
+
43
def append_zero(x):
    """Return ``x`` with a single zero appended (same dtype/device)."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
45
+
46
# Builds the UNet conditioning inputs for one SDXL denoising step:
# encodes the prompts via the pipeline's text encoders and assembles the
# SDXL "added conditions" (pooled text embeds + size/crop time ids).
def prepare_sdxl_pipeline_step_parameter( pipe: StableDiffusionXLPipeline
        , prompts
        , need_cfg
        , device
        , negative_prompt = None
        , W = 1024
        , H = 1024): # need to correct the format
    """Encode prompts and assemble SDXL added-condition kwargs.

    Args:
        pipe: A loaded ``StableDiffusionXLPipeline`` (used for its text
            encoders, tokenizers and UNet config).
        prompts: Prompt string or list of prompt strings.
        need_cfg: If True, negative embeddings are prepended so the result
            can be fed to a classifier-free-guidance batched UNet call.
        device: Target device for the returned tensors.
        negative_prompt: Negative prompt(s); only used when ``need_cfg``.
        W, H: Output image width/height, used for the SDXL size time ids.

    Returns:
        Tuple ``(prompt_embeds, {"text_embeds": ..., "time_ids": ...})``,
        suitable for ``unet(..., encoder_hidden_states=prompt_embeds,
        added_cond_kwargs=ret_dict)``.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompts,
        negative_prompt=negative_prompt,
        device=device,
        do_classifier_free_guidance=need_cfg,
    )
    # timesteps = pipe.scheduler.timesteps

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = pooled_prompt_embeds.to(device)
    # SDXL micro-conditioning: original size, crop offset and target size
    # are concatenated into the "time ids" vector (6 values).
    original_size = (W, H)
    crops_coords_top_left = (0, 0)
    target_size = (W, H)
    text_encoder_projection_dim = None
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    if pipe.text_encoder_2 is None:
        # Fall back to the pooled-embedding width when no second encoder.
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
    # Sanity check: the UNet's add_embedding input width must match what we
    # are about to feed it, otherwise the checkpoint config is inconsistent.
    passed_add_embed_dim = (
        pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
    add_time_ids = add_time_ids.to(device)
    # NOTE(review): negative time ids reuse the positive ones (no separate
    # negative size conditioning is supported here).
    negative_add_time_ids = add_time_ids

    if need_cfg:
        # CFG layout: negative (unconditional) batch first, positive second.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
    ret_dict = {
        "text_embeds": add_text_embeds,
        "time_ids": add_time_ids
    }
    return prompt_embeds, ret_dict
98
+
99
+
100
# Helper to load a list-of-dicts preference JSON.
# JSON schema: [ { 'human_preference': [int], 'prompt': str, 'file_path': [str] }, ... ]
def load_preference_json(json_path: str) -> list[dict]:
    """Load records from a JSON file formatted as a list of preference dicts.

    Args:
        json_path: Path to the JSON file.

    Returns:
        The decoded list of preference record dicts.
    """
    # Fix: open with an explicit encoding — JSON is defined as UTF-8, while
    # the platform default encoding may differ (e.g. on Windows).
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data
107
+
108
# Helper that pulls only the 'prompt' field out of each record in a
# preference JSON file (see the schema of the preference records above it
# in the original: list of dicts each carrying a 'prompt' key).
def extract_prompts_from_pref_json(json_path: str) -> list[str]:
    """Load a JSON of preference records and return only the prompts."""
    with open(json_path, 'r') as f:
        records = json.load(f)
    return [entry['prompt'] for entry in records]
115
+
116
+ # Example usage:
117
+ # prompts = extract_prompts_from_pref_json("path/to/preference.json")
118
+ # print(prompts)
119
+
120
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu',need_append_zero = True):
    """Noise schedule of Karras et al. (2022): ``n`` sigmas spaced in
    rho-warped space from sigma_max down to sigma_min, optionally with a
    trailing zero appended."""
    steps = torch.linspace(0, 1, n)
    inv_rho_min = sigma_min ** (1 / rho)
    inv_rho_max = sigma_max ** (1 / rho)
    sigmas = (inv_rho_max + steps * (inv_rho_min - inv_rho_max)) ** rho
    if need_append_zero:
        # Terminal zero sigma so samplers can step to a clean image.
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    return sigmas.to(device)
127
+
128
# NOTE(review): duplicate of the extract_into_tensor defined earlier in this
# file — the later definition wins at import time. Consider removing one copy.
def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` (last axis) and reshape the result to
    (batch, 1, 1, ...) so it broadcasts against a tensor of shape x_shape."""
    # Batch size comes from the leading dimension of the index tensor.
    b, *_ = t.shape
    out = a.gather(-1, t)
    # One singleton dim per non-batch axis of the target shape.
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
132
+
133
# NOTE(review): duplicate of the append_zero defined earlier in this file.
def append_zero(x):
    """Return ``x`` with one zero appended (same dtype and device)."""
    return torch.cat([x, x.new_zeros([1])])
135
+
136
def append_dims(x, target_dims):
    """Right-pad ``x`` with singleton dimensions until it has
    ``target_dims`` dimensions; raises ValueError if ``x`` already has
    more."""
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (Ellipsis,) + (None,) * missing
    return x[index]
142
 
 
143
 
144
  def chunk(it, size):
145
  it = iter(it)
 
149
  caption = json["caption"]
150
  return caption
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  torch_dtype = torch.float16
154
  repo_id = "madebyollin/sdxl-vae-fp16-fix" # e.g., "distilbert/distilgpt2"
 
203
  negative_prompts = 1 * [negative_prompts]
204
 
205
  prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
206
+ , prompts
207
+ , need_cfg=True
208
+ , device=pipe.device
209
+ , negative_prompt=negative_prompts
210
+ , W=width
211
+ , H=height)
212
  noise_pred = pipe.unet(latent_model_input
213
  , t
214
  , encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)