Spaces: Running on Zero
Commit · 508e6a7
1 Parent(s): ddc7f94
add inference
Browse files
app.py CHANGED
@@ -34,58 +34,112 @@ from customed_unipc_scheduler import CustomedUniPCMultistepScheduler

precision_scope = autocast

-def
-[removed lines 38-64: content not preserved in this view]
-# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
-def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
-    captions = []
-    for caption in prompt_batch:
-        if random.random() < proportion_empty_prompts:
-            captions.append("")
-        elif isinstance(caption, str):
-            captions.append(caption)
-        elif isinstance(caption, (list, np.ndarray)):
-            # take a random caption if there are multiple
-            captions.append(random.choice(caption) if is_train else caption[0])
-
-    with torch.no_grad():
-        text_inputs = tokenizer(
-            captions,
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-
-

-    return prompt_embeds

def chunk(it, size):
    it = iter(it)
@@ -95,67 +149,6 @@ def convert_caption_json_to_str(json):
    caption = json["caption"]
    return caption

-def prepare_sdxl_pipeline_step_parameter(pipe, prompts, need_cfg, device, negative_prompts, W = 1024, H = 1024):
-    (
-        prompt_embeds,
-        negative_prompt_embeds,
-        pooled_prompt_embeds,
-        negative_pooled_prompt_embeds,
-    ) = pipe.encode_prompt(
-        prompt=prompts,
-        negative_prompt=negative_prompts,
-        device=device,
-        do_classifier_free_guidance=need_cfg,
-    )
-    # timesteps = pipe.scheduler.timesteps
-
-    prompt_embeds = prompt_embeds.to(device)
-    add_text_embeds = pooled_prompt_embeds.to(device)
-    original_size = (W, H)
-    crops_coords_top_left = (0, 0)
-    target_size = (W, H)
-    text_encoder_projection_dim = None
-    add_time_ids = list(original_size + crops_coords_top_left + target_size)
-    if pipe.text_encoder_2 is None:
-        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
-    else:
-        text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
-    passed_add_embed_dim = (
-        pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
-    )
-    expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
-    if expected_add_embed_dim != passed_add_embed_dim:
-        raise ValueError(
-            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
-        )
-    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
-    add_time_ids = add_time_ids.to(device)
-    negative_add_time_ids = add_time_ids
-
-    if need_cfg:
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
-    ret_dict = {
-        "text_embeds": add_text_embeds,
-        "time_ids": add_time_ids
-    }
-    return prompt_embeds, ret_dict
-
-
-def model_closure(pipe):
-    def model_fn(x, t, c):
-        prompt = c[0]
-        cond_kwargs = c[1] if len(c) > 1 else None
-        # prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe=pipe, prompts=prompt, need_cfg=True, device=pipe.device, negative_prompts=negative_prompt)
-        # prompt_embeds, cond_kwargs = c
-        return pipe.unet(x
-                         , t
-                         , encoder_hidden_states=prompt.to(device=x.device, dtype=x.dtype)
-                         , added_cond_kwargs=cond_kwargs).sample
-
-    return model_fn
-

torch_dtype = torch.float16
repo_id = "madebyollin/sdxl-vae-fp16-fix"  # e.g., "distilbert/distilgpt2"
@@ -210,12 +203,12 @@ def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guid
    negative_prompts = 1 * [negative_prompts]

    prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
-    [removed lines 213-218: content not preserved in this view]
    noise_pred = pipe.unet(latent_model_input
                           , t
                           , encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)

precision_scope = autocast

+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+def prepare_sdxl_pipeline_step_parameter( pipe: StableDiffusionXLPipeline
+                                        , prompts
+                                        , need_cfg
+                                        , device
+                                        , negative_prompt = None
+                                        , W = 1024
+                                        , H = 1024):  # need to correct the format
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = pipe.encode_prompt(
+        prompt=prompts,
+        negative_prompt=negative_prompt,
+        device=device,
+        do_classifier_free_guidance=need_cfg,
+    )
+    # timesteps = pipe.scheduler.timesteps
+
+    prompt_embeds = prompt_embeds.to(device)
+    add_text_embeds = pooled_prompt_embeds.to(device)
+    original_size = (W, H)
+    crops_coords_top_left = (0, 0)
+    target_size = (W, H)
+    text_encoder_projection_dim = None
+    add_time_ids = list(original_size + crops_coords_top_left + target_size)
+    if pipe.text_encoder_2 is None:
+        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+    else:
+        text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
+    passed_add_embed_dim = (
+        pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+    )
+    expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
+    if expected_add_embed_dim != passed_add_embed_dim:
+        raise ValueError(
+            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+        )
+    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
+    add_time_ids = add_time_ids.to(device)
+    negative_add_time_ids = add_time_ids
+
+    if need_cfg:
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+    ret_dict = {
+        "text_embeds": add_text_embeds,
+        "time_ids": add_time_ids
+    }
+    return prompt_embeds, ret_dict
+
+
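For orientation, the six add_time_ids built above are the SDXL micro-conditioning inputs (original size, crop offset, target size), and the dimension check can be verified by hand. The numbers below are illustrative only and assume the stock SDXL-base configuration (addition_time_embed_dim = 256, text_encoder_2 projection_dim = 1280), which this file does not state:

W = H = 1024
add_time_ids = [W, H, 0, 0, W, H]                        # original_size + crops_coords_top_left + target_size
passed_add_embed_dim = 256 * len(add_time_ids) + 1280    # 256 * 6 + 1280 = 2816
# For that config, pipe.unet.add_embedding.linear_1.in_features is also 2816,
# so the ValueError branch above is not triggered.

With need_cfg=True, the returned prompt_embeds, text_embeds and time_ids are each concatenated as [negative, positive] along the batch dimension, so downstream code is expected to run the UNet on a correspondingly duplicated latent batch.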
+# New helper to load a list-of-dicts preference JSON
+# JSON schema: [ { 'human_preference': [int], 'prompt': str, 'file_path': [str] }, ... ]
+def load_preference_json(json_path: str) -> list[dict]:
+    """Load records from a JSON file formatted as a list of preference dicts."""
+    with open(json_path, 'r') as f:
+        data = json.load(f)
+    return data
+
+# New helper to extract just the prompts from the preference JSON
+# Returns a flat list of all 'prompt' values
+
+def extract_prompts_from_pref_json(json_path: str) -> list[str]:
+    """Load a JSON of preference records and return only the prompts."""
+    records = load_preference_json(json_path)
+    return [rec['prompt'] for rec in records]
+
+# Example usage:
+# prompts = extract_prompts_from_pref_json("path/to/preference.json")
+# print(prompts)
+
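As a quick, self-contained illustration of the preference-file schema documented above (the records and the temporary file are hypothetical, not taken from this repo):

import json
import tempfile

records = [  # hypothetical records following the documented schema
    {"human_preference": [1, 0], "prompt": "a photo of a corgi", "file_path": ["a.png", "b.png"]},
    {"human_preference": [0, 1], "prompt": "a watercolor city at dusk", "file_path": ["c.png", "d.png"]},
]
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(records, f)

print(extract_prompts_from_pref_json(f.name))
# ['a photo of a corgi', 'a watercolor city at dusk']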
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
+    """Constructs the noise schedule of Karras et al. (2022)."""
+    ramp = torch.linspace(0, 1, n)
+    min_inv_rho = sigma_min ** (1 / rho)
+    max_inv_rho = sigma_max ** (1 / rho)
+    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+    return append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
+
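To make the shape of this schedule concrete, a minimal endpoint check follows; the sigma range is illustrative, not a value used by this app. Because the interpolation happens in sigma**(1/rho) space with rho=7, the steps cluster near sigma_min.

sigmas = get_sigmas_karras(n=10, sigma_min=0.03, sigma_max=14.6)
print(sigmas.shape)       # torch.Size([11]): n values plus the appended terminal zero
print(float(sigmas[0]))   # ~14.6, the schedule starts at sigma_max
print(float(sigmas[-2]))  # ~0.03, and decreases to sigma_min
print(float(sigmas[-1]))  # 0.0, added by append_zero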
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]


def chunk(it, size):
    it = iter(it)
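The small tensor helpers added above (extract_into_tensor, which appears twice in this revision, and append_dims) are the usual utilities for broadcasting a per-sample schedule value over a latent batch. A short illustration with made-up shapes:

import torch

schedule = torch.linspace(1.0, 0.1, steps=1000)  # a 1-D per-timestep schedule (made-up values)
t = torch.tensor([10, 500])                      # one timestep index per sample
latents = torch.randn(2, 4, 64, 64)              # a (B, C, H, W) latent batch

coef = extract_into_tensor(schedule, t, latents.shape)
print(coef.shape)                                   # torch.Size([2, 1, 1, 1]), ready to scale latents
print(append_dims(t.float(), latents.ndim).shape)   # torch.Size([2, 1, 1, 1]) via append_dims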
    caption = json["caption"]
    return caption


torch_dtype = torch.float16
repo_id = "madebyollin/sdxl-vae-fp16-fix"  # e.g., "distilbert/distilgpt2"
    negative_prompts = 1 * [negative_prompts]

    prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
+                                                                     , prompts
+                                                                     , need_cfg=True
+                                                                     , device=pipe.device
+                                                                     , negative_prompt=negative_prompts
+                                                                     , W=width
+                                                                     , H=height)
    noise_pred = pipe.unet(latent_model_input
                           , t
                           , encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)
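The hunk ends mid-call, so the classifier-free-guidance combination that usually follows the UNet call is not visible in this diff. Purely as a sketch of the conventional pattern, with the duplication of the latents, the guidance_scale name, and the chunk/guidance step being assumptions rather than content of this commit:

# Conventional CFG step, for reference only; not part of the visible diff.
latent_model_input = torch.cat([latents] * 2)  # duplicate latents to match the [negative, positive] embeddings
noise_pred = pipe.unet(latent_model_input,
                       t,
                       encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype),
                       added_cond_kwargs=cond_kwargs).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)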