Update app.py
app.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import contextlib
 import gc
 import json
@@ -179,6 +180,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model
 global pipeline
 global MultiResNetModel
 
+@spaces.GPU
 def load_ckpt():
     global pipeline
     global MultiResNetModel
@@ -293,6 +295,7 @@ def load_ckpt():
 
 global cur_style
 cur_style = 'line + shadow'
+@spaces.GPU
 def change_ckpt(style):
     global pipeline
     global MultiResNetModel
@@ -334,6 +337,7 @@ def change_ckpt(style):
 
 load_ckpt()
 
+@spaces.GPU
 def fix_random_seeds(seed):
     random.seed(seed)
     np.random.seed(seed)
@@ -349,6 +353,7 @@ def process_multi_images(files):
         imgs.append(img)
     return imgs
 
+@spaces.GPU
 def extract_lines(image):
     src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
 
@@ -373,16 +378,17 @@ def extract_lines(image):
     torch.cuda.empty_cache()
     return outimg
 
+@spaces.GPU
 def extract_line_image(query_image_, resolution):
     tar_width, tar_height = resolution
     query_image = query_image_.resize((tar_width, tar_height))
-    # query_image.save('/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3/input.png')
     query_image = query_image.convert('L').convert('RGB')
     extracted_line = extract_lines(query_image)
     extracted_line = extracted_line.convert('L').convert('RGB')
     torch.cuda.empty_cache()
     return extracted_line, Image.new('RGB', (tar_width, tar_height), 'black')
 
+@spaces.GPU
 def extract_sketch_line_image(query_image_, input_style):
     global cur_style
     if input_style != cur_style:
@@ -418,6 +424,7 @@ def extract_sketch_line_image(query_image_, input_style):
 
     return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
 
+@spaces.GPU(duration=120)
 def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
     if extracted_line is None:
         gr.Info("Please preprocess the image first")
@@ -440,11 +447,6 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
     reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
     query_patches_pil = process_image_Q_varres(query_image_origin, tar_width, tar_height)
     reference_patches_pil = []
-    # Save reference_images
-    # save_path = '/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3'
-    # os.makedirs(save_path, exist_ok=True)
-    # for idx, ref_image in enumerate(reference_images):
-    #     ref_image.save(os.path.join(save_path, f'reference_image_{idx}.png'))
 
     for reference_image in reference_images:
         reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
@@ -695,4 +697,4 @@ with gr.Blocks() as demo:
     )
 
 
-demo.launch(
+demo.launch()
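
For context, the change follows Hugging Face's ZeroGPU pattern: `import spaces` must run before anything initializes CUDA, and each function that touches the GPU is wrapped in `@spaces.GPU`, which attaches a GPU to the process only while a decorated call executes. `@spaces.GPU(duration=120)` requests a longer allocation window than the default for slow calls such as `colorize_image`. A minimal sketch of the pattern, with an illustrative stand-in model rather than the actual ColorFlow pipeline:

import spaces   # must come before any CUDA initialization on a ZeroGPU Space
import torch

# Created on CPU at import time; a GPU exists only inside decorated calls.
# `model` is a hypothetical stand-in for the real pipeline.
model = torch.nn.Linear(16, 16)

@spaces.GPU                  # default allocation window
def preprocess(x: torch.Tensor) -> torch.Tensor:
    model.to("cuda")
    return model(x.to("cuda")).cpu()

@spaces.GPU(duration=120)    # longer window for slow calls like colorize_image
def generate(x: torch.Tensor, steps: int = 50) -> torch.Tensor:
    model.to("cuda")
    for _ in range(steps):   # e.g. iterative denoising steps
        x = model(x.to("cuda")).cpu()
    return x

Outside a ZeroGPU Space the decorator is effectively a no-op, which is why the same file still runs locally and why even CPU-safe helpers like fix_random_seeds can be decorated without changing local behavior.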