File size: 4,472 Bytes
daa13d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import requests
import os
from PIL import Image, ImageOps
import cv2
import numpy as np
import socket
import torchvision.transforms.functional as TF
from modules.shared import opts
from .general_utils import clean_gradio_path_strings

# Whether Deforum's debug mode is switched on in the WebUI settings.
# Looked up once, when this module is first imported; defaults to False.
DEBUG_MODE = opts.data.get(
    "deforum_debug_mode_enabled",
    False,
)

def load_img(path: str, shape=None, use_alpha_as_mask=False):
    """
    Load an init image, optionally splitting its alpha channel off as a mask.

    Args:
        path: local file path or http(s) URL, resolved by load_image.
        shape: optional (width, height) to resize to, using LANCZOS resampling.
        use_alpha_as_mask: if True, read the image's alpha channel and return
            it as an 'L' mask image.

    Returns:
        (image, mask_image): image is RGB. mask_image is the alpha channel as
        an 'L' image. It is None when use_alpha_as_mask is False, or when the
        alpha channel is uniformly 0 or uniformly 255 (carries no mask).
    """
    if use_alpha_as_mask:
        # Load without load_image's default RGB conversion: converting to RGB
        # first would discard the alpha channel, so the "mask" would always be
        # a uniform 255 and get thrown away below.
        image = load_image(path, mode=None).convert('RGBA')
    else:
        image = load_image(path)  # load_image returns RGB by default

    if shape is not None:
        image = image.resize(shape, resample=Image.LANCZOS)

    mask_image = None
    if use_alpha_as_mask:
        # Split the alpha channel into a single-band ('L') mask image.
        mask_image = image.getchannel('A')
        image = image.convert('RGB')

        # An alpha channel that is uniformly transparent or uniformly opaque
        # carries no mask information, so ignore it.
        if mask_image.getextrema() in ((0, 0), (255, 255)):
            print("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")
            print("ignoring alpha as mask.")
            mask_image = None

    return image, mask_image

def load_image(image_path: str, mode='RGB', timeout=30):
    """
    Load an image from a local path or an http(s) URL.

    Args:
        image_path: path or URL; first normalized by clean_gradio_path_strings.
        mode: PIL mode to convert the image to (default 'RGB', the historical
            behavior). Pass None to keep the image's own mode, e.g. to
            preserve an alpha channel.
        timeout: seconds requests may wait on the server (connect and each
            read) when image_path is a URL.

    Returns:
        A fully loaded PIL image; the source file or HTTP connection is closed.

    Raises:
        ConnectionError: the connectivity probe failed, the download raised a
            requests error, or the server did not answer with HTTP 200.
        RuntimeError: the local path does not exist.
    """
    image_path = clean_gradio_path_strings(image_path)

    if image_path.startswith(('http://', 'https://')):
        # Quick connectivity probe so offline users get a clear message.
        try:
            host = socket.gethostbyname("www.google.com")
            with socket.create_connection((host, 80), 2):
                pass
        except OSError as err:  # covers gaierror and timeouts, not Ctrl-C
            raise ConnectionError("There is no active internet connection available - please use local masks and init files only.") from err

        try:
            response = requests.get(image_path, stream=True, timeout=timeout)
        except requests.exceptions.RequestException as e:
            raise ConnectionError("Failed to download image due to no internet connection. Error: {}".format(e)) from e

        # Closing the response releases the streamed connection.
        with response:
            if response.status_code != 200:
                raise ConnectionError("Init image url or mask image url is not valid")
            # Let urllib3 undo gzip/deflate content-encoding before PIL reads the raw stream.
            response.raw.decode_content = True
            with Image.open(response.raw) as img:
                return img.convert(mode) if mode else img.copy()

    if not os.path.exists(image_path):
        raise RuntimeError("Init image path or mask image path is not valid")
    # convert()/copy() return a fully loaded image, so the file can be closed.
    with Image.open(image_path) as img:
        return img.convert(mode) if mode else img.copy()

def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):
    """
    Build a single-channel ('L') mask image for use in the WebUI.

    mask_input: a PIL image, or a path/URL handed to load_image.
    mask_shape: target size, passed straight to Image.resize (LANCZOS).
    mask_brightness_adjust / mask_contrast_adjust: factors for torchvision's
        adjust_brightness / adjust_contrast; a value of 1 skips that step.
    """
    source = mask_input if isinstance(mask_input, Image.Image) else load_image(mask_input)
    adjusted = source.resize(mask_shape, resample=Image.LANCZOS)

    # Brightness is applied before contrast; each step is a no-op at 1.
    if mask_brightness_adjust != 1:
        adjusted = TF.adjust_brightness(adjusted, mask_brightness_adjust)
    if mask_contrast_adjust != 1:
        adjusted = TF.adjust_contrast(adjusted, mask_contrast_adjust)

    return adjusted.convert('L')

# "check_mask_for_errors" may have prevented errors in composable masks,
# but it CAUSES errors on any frame where the mask is entirely black.
# get_mask and get_mask_from_file therefore bypass this check until a better fix is found.
# This may break composable masks, but it makes ACTUAL masks usable.
def check_mask_for_errors(mask_input, invert_mask=False):
    """
    Return mask_input, or None if the mask would be blank once used.

    Args:
        mask_input: an image exposing getextrema(). A single-band ('L') mask
            is assumed, so the extrema are one (min, max) pair.
        invert_mask: True if the mask is going to be inverted. An all-white
            (255, 255) mask then becomes blank.

    Returns:
        None when the (possibly inverted) mask would be entirely black,
        otherwise mask_input unchanged.
    """
    extrema = mask_input.getextrema()
    if invert_mask:
        if extrema == (255, 255):
            print("after inverting mask will be blank. ignoring mask")
            return None
    elif extrema == (0, 0):
        print("mask is blank. ignoring mask")
        return None
    # Fall-through: previously a non-blank mask with invert_mask=True reached
    # the end of the function and implicitly returned None.
    return mask_input
 
def get_mask(args):
    """
    Load and prepare the mask named by args.mask_file at size (args.W, args.H).

    args must provide mask_file, W, H, mask_brightness_adjust and
    mask_contrast_adjust. check_mask_for_errors is intentionally not applied,
    because it returns None for an all-black mask.
    """
    # Pass the adjustments by keyword. prepare_mask takes brightness before
    # contrast, and the previous positional call swapped the two factors.
    return prepare_mask(
        args.mask_file,
        (args.W, args.H),
        mask_brightness_adjust=args.mask_brightness_adjust,
        mask_contrast_adjust=args.mask_contrast_adjust,
    )

def get_mask_from_file(mask_file, args):
    """
    Load and prepare mask_file (PIL image, path or URL) at size (args.W, args.H).

    args must provide W, H, mask_brightness_adjust and mask_contrast_adjust.
    check_mask_for_errors is intentionally not applied, because it returns
    None for an all-black mask.
    """
    # Pass the adjustments by keyword. prepare_mask takes brightness before
    # contrast, and the previous positional call swapped the two factors.
    return prepare_mask(
        mask_file,
        (args.W, args.H),
        mask_brightness_adjust=args.mask_brightness_adjust,
        mask_contrast_adjust=args.mask_contrast_adjust,
    )

def blank_if_none(mask, w, h, mode):
    """Return mask unchanged, or a new all-zero (w, h) image of the given mode if mask is None."""
    if mask is not None:
        return mask
    return Image.new(mode, (w, h), 0)

def none_if_blank(mask):
    """
    Return None if mask is None or every pixel is zero, otherwise mask itself.

    This is the counterpart of blank_if_none, so None passes straight through.
    For multi-band images, getextrema() returns one (min, max) pair per band,
    so the mask counts as blank only when every band is (0, 0).
    """
    if mask is None:
        return None
    extrema = mask.getextrema()
    if extrema and isinstance(extrema[0], tuple):
        is_blank = all(band == (0, 0) for band in extrema)
    else:
        is_blank = extrema == (0, 0)
    return None if is_blank else mask