Spaces:
Runtime error
Runtime error
claude
Browse files
app.py
CHANGED
|
@@ -99,13 +99,14 @@ transform = transforms.Compose([
|
|
| 99 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 100 |
])
|
| 101 |
weight_dtype = torch.float16
|
|
|
|
| 102 |
|
| 103 |
# line model
|
| 104 |
line_model_path = model_global_path + '/LE/erika.pth'
|
| 105 |
line_model = res_skip()
|
| 106 |
line_model.load_state_dict(torch.load(line_model_path))
|
| 107 |
line_model.eval()
|
| 108 |
-
line_model.[… rest of removed line truncated in extraction]
|
| 109 |
|
| 110 |
# screen model
|
| 111 |
global opt
|
|
@@ -116,7 +117,7 @@ ScreenModel.setup(opt)
|
|
| 116 |
ScreenModel.eval()
|
| 117 |
|
| 118 |
image_processor = CLIPImageProcessor()
|
| 119 |
-
image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to([… rest of removed line truncated in extraction]
|
| 120 |
|
| 121 |
|
| 122 |
examples = [
|
|
@@ -218,8 +219,8 @@ def load_ckpt(input_style):
|
|
| 218 |
ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
|
| 219 |
transformer.load_state_dict(ckpt_key_t, strict=False)
|
| 220 |
|
| 221 |
-
transformer.to([… rest of removed line truncated in extraction]
|
| 222 |
-
ColorGuider.to([… rest of removed line truncated in extraction]
|
| 223 |
|
| 224 |
pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
|
| 225 |
pretrained_model_name_or_path,
|
|
@@ -230,12 +231,12 @@ def load_ckpt(input_style):
|
|
| 230 |
variant=None,
|
| 231 |
torch_dtype=weight_dtype,
|
| 232 |
)
|
| 233 |
-
pipeline = pipeline.to([… rest of removed line truncated in extraction]
|
| 234 |
block_out_channels = [128, 128, 256, 512, 512]
|
| 235 |
|
| 236 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
| 237 |
MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
|
| 238 |
-
MultiResNetModel.to([… rest of removed line truncated in extraction]
|
| 239 |
|
| 240 |
elif input_style == "GrayImage(ScreenStyle)":
|
| 241 |
ckpt_path = model_global_path + '/GraySD/'
|
|
@@ -245,8 +246,8 @@ def load_ckpt(input_style):
|
|
| 245 |
pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
|
| 246 |
)
|
| 247 |
ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
|
| 248 |
-
ColorGuider.to([… rest of removed line truncated in extraction]
|
| 249 |
-
unet.to([… rest of removed line truncated in extraction]
|
| 250 |
|
| 251 |
pipeline = ColorFlowSDPipeline.from_pretrained(
|
| 252 |
pretrained_model_name_or_path,
|
|
@@ -266,12 +267,12 @@ def load_ckpt(input_style):
|
|
| 266 |
)
|
| 267 |
pipeline.unet.add_adapter(unet_lora_config)
|
| 268 |
pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
|
| 269 |
-
pipeline = pipeline.to([… rest of removed line truncated in extraction]
|
| 270 |
block_out_channels = [128, 128, 256, 512, 512]
|
| 271 |
|
| 272 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
| 273 |
MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
|
| 274 |
-
MultiResNetModel.to([… rest of removed line truncated in extraction]
|
| 275 |
|
| 276 |
|
| 277 |
|
|
@@ -312,7 +313,7 @@ def extract_lines(image):
|
|
| 312 |
patch = np.ones((1, 1, rows, cols), dtype="float32")
|
| 313 |
patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
|
| 314 |
|
| 315 |
-
tensor = torch.from_numpy(patch).[… rest of removed line truncated in extraction]
|
| 316 |
|
| 317 |
with torch.no_grad():
|
| 318 |
y = line_model(tensor)
|
|
@@ -440,7 +441,7 @@ def colorize_image(VAE_input, input_context, reference_images, resolution, seed,
|
|
| 440 |
idx_x = idx_list[k][1]
|
| 441 |
combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
|
| 442 |
gr.Info("Model inference in progress...")
|
| 443 |
-
generator = torch.Generator(device=[… rest of removed line truncated in extraction]
|
| 444 |
image = pipeline(
|
| 445 |
"manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
|
| 446 |
).images[0]
|
|
@@ -455,8 +456,8 @@ def colorize_image(VAE_input, input_context, reference_images, resolution, seed,
|
|
| 455 |
bottom = top + new_height
|
| 456 |
center_crop = image.crop((left, top, right, bottom))
|
| 457 |
up_img = center_crop.resize(query_image_vae.size)
|
| 458 |
-
test_low_color = transform(up_img).unsqueeze(0).to([… rest of removed line truncated in extraction]
|
| 459 |
-
query_image_vae = transform(query_image_vae).unsqueeze(0).to([… rest of removed line truncated in extraction]
|
| 460 |
|
| 461 |
h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
|
| 462 |
h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)
|
|
|
|
| 99 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 100 |
])
|
| 101 |
weight_dtype = torch.float16
|
| 102 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 103 |
|
| 104 |
# line model
|
| 105 |
line_model_path = model_global_path + '/LE/erika.pth'
|
| 106 |
line_model = res_skip()
|
| 107 |
line_model.load_state_dict(torch.load(line_model_path))
|
| 108 |
line_model.eval()
|
| 109 |
+
line_model.to(device)
|
| 110 |
|
| 111 |
# screen model
|
| 112 |
global opt
|
|
|
|
| 117 |
ScreenModel.eval()
|
| 118 |
|
| 119 |
image_processor = CLIPImageProcessor()
|
| 120 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to(device)
|
| 121 |
|
| 122 |
|
| 123 |
examples = [
|
|
|
|
| 219 |
ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
|
| 220 |
transformer.load_state_dict(ckpt_key_t, strict=False)
|
| 221 |
|
| 222 |
+
transformer.to(device, dtype=weight_dtype)
|
| 223 |
+
ColorGuider.to(device, dtype=weight_dtype)
|
| 224 |
|
| 225 |
pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
|
| 226 |
pretrained_model_name_or_path,
|
|
|
|
| 231 |
variant=None,
|
| 232 |
torch_dtype=weight_dtype,
|
| 233 |
)
|
| 234 |
+
pipeline = pipeline.to(device)
|
| 235 |
block_out_channels = [128, 128, 256, 512, 512]
|
| 236 |
|
| 237 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
| 238 |
MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
|
| 239 |
+
MultiResNetModel.to(device, dtype=weight_dtype)
|
| 240 |
|
| 241 |
elif input_style == "GrayImage(ScreenStyle)":
|
| 242 |
ckpt_path = model_global_path + '/GraySD/'
|
|
|
|
| 246 |
pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
|
| 247 |
)
|
| 248 |
ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
|
| 249 |
+
ColorGuider.to(device, dtype=weight_dtype)
|
| 250 |
+
unet.to(device, dtype=weight_dtype)
|
| 251 |
|
| 252 |
pipeline = ColorFlowSDPipeline.from_pretrained(
|
| 253 |
pretrained_model_name_or_path,
|
|
|
|
| 267 |
)
|
| 268 |
pipeline.unet.add_adapter(unet_lora_config)
|
| 269 |
pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
|
| 270 |
+
pipeline = pipeline.to(device)
|
| 271 |
block_out_channels = [128, 128, 256, 512, 512]
|
| 272 |
|
| 273 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
| 274 |
MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
|
| 275 |
+
MultiResNetModel.to(device, dtype=weight_dtype)
|
| 276 |
|
| 277 |
|
| 278 |
|
|
|
|
| 313 |
patch = np.ones((1, 1, rows, cols), dtype="float32")
|
| 314 |
patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
|
| 315 |
|
| 316 |
+
tensor = torch.from_numpy(patch).to(device)
|
| 317 |
|
| 318 |
with torch.no_grad():
|
| 319 |
y = line_model(tensor)
|
|
|
|
| 441 |
idx_x = idx_list[k][1]
|
| 442 |
combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
|
| 443 |
gr.Info("Model inference in progress...")
|
| 444 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 445 |
image = pipeline(
|
| 446 |
"manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
|
| 447 |
).images[0]
|
|
|
|
| 456 |
bottom = top + new_height
|
| 457 |
center_crop = image.crop((left, top, right, bottom))
|
| 458 |
up_img = center_crop.resize(query_image_vae.size)
|
| 459 |
+
test_low_color = transform(up_img).unsqueeze(0).to(device, dtype=weight_dtype)
|
| 460 |
+
query_image_vae = transform(query_image_vae).unsqueeze(0).to(device, dtype=weight_dtype)
|
| 461 |
|
| 462 |
h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
|
| 463 |
h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)
|