facehuggingjay commited on
Commit
e972aca
·
verified ·
1 Parent(s): 78d1206
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -99,13 +99,14 @@ transform = transforms.Compose([
99
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
100
  ])
101
  weight_dtype = torch.float16
 
102
 
103
  # line model
104
  line_model_path = model_global_path + '/LE/erika.pth'
105
  line_model = res_skip()
106
  line_model.load_state_dict(torch.load(line_model_path))
107
  line_model.eval()
108
- line_model.cuda()
109
 
110
  # screen model
111
  global opt
@@ -116,7 +117,7 @@ ScreenModel.setup(opt)
116
  ScreenModel.eval()
117
 
118
  image_processor = CLIPImageProcessor()
119
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to('cuda')
120
 
121
 
122
  examples = [
@@ -218,8 +219,8 @@ def load_ckpt(input_style):
218
  ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
219
  transformer.load_state_dict(ckpt_key_t, strict=False)
220
 
221
- transformer.to('cuda', dtype=weight_dtype)
222
- ColorGuider.to('cuda', dtype=weight_dtype)
223
 
224
  pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
225
  pretrained_model_name_or_path,
@@ -230,12 +231,12 @@ def load_ckpt(input_style):
230
  variant=None,
231
  torch_dtype=weight_dtype,
232
  )
233
- pipeline = pipeline.to("cuda")
234
  block_out_channels = [128, 128, 256, 512, 512]
235
 
236
  MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
237
  MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
238
- MultiResNetModel.to('cuda', dtype=weight_dtype)
239
 
240
  elif input_style == "GrayImage(ScreenStyle)":
241
  ckpt_path = model_global_path + '/GraySD/'
@@ -245,8 +246,8 @@ def load_ckpt(input_style):
245
  pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
246
  )
247
  ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
248
- ColorGuider.to('cuda', dtype=weight_dtype)
249
- unet.to('cuda', dtype=weight_dtype)
250
 
251
  pipeline = ColorFlowSDPipeline.from_pretrained(
252
  pretrained_model_name_or_path,
@@ -266,12 +267,12 @@ def load_ckpt(input_style):
266
  )
267
  pipeline.unet.add_adapter(unet_lora_config)
268
  pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
269
- pipeline = pipeline.to("cuda")
270
  block_out_channels = [128, 128, 256, 512, 512]
271
 
272
  MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
273
  MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
274
- MultiResNetModel.to('cuda', dtype=weight_dtype)
275
 
276
 
277
 
@@ -312,7 +313,7 @@ def extract_lines(image):
312
  patch = np.ones((1, 1, rows, cols), dtype="float32")
313
  patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
314
 
315
- tensor = torch.from_numpy(patch).cuda()
316
 
317
  with torch.no_grad():
318
  y = line_model(tensor)
@@ -440,7 +441,7 @@ def colorize_image(VAE_input, input_context, reference_images, resolution, seed,
440
  idx_x = idx_list[k][1]
441
  combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
442
  gr.Info("Model inference in progress...")
443
- generator = torch.Generator(device='cuda').manual_seed(seed)
444
  image = pipeline(
445
  "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
446
  ).images[0]
@@ -455,8 +456,8 @@ def colorize_image(VAE_input, input_context, reference_images, resolution, seed,
455
  bottom = top + new_height
456
  center_crop = image.crop((left, top, right, bottom))
457
  up_img = center_crop.resize(query_image_vae.size)
458
- test_low_color = transform(up_img).unsqueeze(0).to('cuda', dtype=weight_dtype)
459
- query_image_vae = transform(query_image_vae).unsqueeze(0).to('cuda', dtype=weight_dtype)
460
 
461
  h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
462
  h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)
 
99
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
100
  ])
101
  weight_dtype = torch.float16
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
 
104
  # line model
105
  line_model_path = model_global_path + '/LE/erika.pth'
106
  line_model = res_skip()
107
  line_model.load_state_dict(torch.load(line_model_path))
108
  line_model.eval()
109
+ line_model.to(device)
110
 
111
  # screen model
112
  global opt
 
117
  ScreenModel.eval()
118
 
119
  image_processor = CLIPImageProcessor()
120
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to(device)
121
 
122
 
123
  examples = [
 
219
  ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
220
  transformer.load_state_dict(ckpt_key_t, strict=False)
221
 
222
+ transformer.to(device, dtype=weight_dtype)
223
+ ColorGuider.to(device, dtype=weight_dtype)
224
 
225
  pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
226
  pretrained_model_name_or_path,
 
231
  variant=None,
232
  torch_dtype=weight_dtype,
233
  )
234
+ pipeline = pipeline.to(device)
235
  block_out_channels = [128, 128, 256, 512, 512]
236
 
237
  MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
238
  MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
239
+ MultiResNetModel.to(device, dtype=weight_dtype)
240
 
241
  elif input_style == "GrayImage(ScreenStyle)":
242
  ckpt_path = model_global_path + '/GraySD/'
 
246
  pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
247
  )
248
  ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
249
+ ColorGuider.to(device, dtype=weight_dtype)
250
+ unet.to(device, dtype=weight_dtype)
251
 
252
  pipeline = ColorFlowSDPipeline.from_pretrained(
253
  pretrained_model_name_or_path,
 
267
  )
268
  pipeline.unet.add_adapter(unet_lora_config)
269
  pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
270
+ pipeline = pipeline.to(device)
271
  block_out_channels = [128, 128, 256, 512, 512]
272
 
273
  MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
274
  MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
275
+ MultiResNetModel.to(device, dtype=weight_dtype)
276
 
277
 
278
 
 
313
  patch = np.ones((1, 1, rows, cols), dtype="float32")
314
  patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
315
 
316
+ tensor = torch.from_numpy(patch).to(device)
317
 
318
  with torch.no_grad():
319
  y = line_model(tensor)
 
441
  idx_x = idx_list[k][1]
442
  combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
443
  gr.Info("Model inference in progress...")
444
+ generator = torch.Generator(device=device).manual_seed(seed)
445
  image = pipeline(
446
  "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
447
  ).images[0]
 
456
  bottom = top + new_height
457
  center_crop = image.crop((left, top, right, bottom))
458
  up_img = center_crop.resize(query_image_vae.size)
459
+ test_low_color = transform(up_img).unsqueeze(0).to(device, dtype=weight_dtype)
460
+ query_image_vae = transform(query_image_vae).unsqueeze(0).to(device, dtype=weight_dtype)
461
 
462
  h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
463
  h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)