Spaces:
Build error
Build error
| from pathlib import Path | |
| from rembg import remove | |
| import io | |
| # Apply the transformations needed | |
| from torch import autocast, nn | |
| import torch | |
| import torch.nn as nn | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torchvision.utils as utils | |
| import torch.nn as nn | |
| import pyrootutils | |
| from PIL import Image | |
| import numpy as np | |
| from utils.photo_wct import PhotoWCT | |
| from utils.photo_smooth import Propagator | |
| #from utils.smooth_filter import smooth_filter | |
# Load the style-transfer components once, at import time.
root = Path.cwd()
device = "cuda" if torch.cuda.is_available() else "cpu"
# PhotoWCT: the stylization network; its weights come from a checkpoint under ./models.
p_wct = PhotoWCT().to(device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved from GPU tensors still loads on a CPU-only host.
p_wct.load_state_dict(
    torch.load(root / "models/components/photo_wct.pth", map_location=device)
)
# Propagator: the smoothing step; no checkpoint is loaded for it here.
p_pro = Propagator().to(device)
stylization_module = p_wct
smoothing_module = p_pro
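# Dependency (install before running): pip install replicate
# Authentication: the Replicate client reads REPLICATE_API_TOKEN from the
# environment. Provide it through the deployment (e.g. a Space secret, or
# `export REPLICATE_API_TOKEN=<your token>` in a shell); never commit the token.
# NOTE(review): a real token was previously committed in this file — revoke and
# rotate it.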
| import replicate | |
| import os | |
| import requests | |
def stableDiffusionAPICall(text_prompt):
    """Generate an image from a text prompt with Replicate's Stable Diffusion model.

    Authentication comes from the ``REPLICATE_API_TOKEN`` environment variable,
    which the deployment must set. The token is deliberately not hard-coded.

    Args:
        text_prompt: Prompt for the model, e.g. ``"photorealistic, elf fighting Sauron"``.

    Returns:
        PIL.Image.Image: The first generated image, decoded into memory.

    Raises:
        RuntimeError: If ``REPLICATE_API_TOKEN`` is not set.
        requests.HTTPError: If downloading the generated image fails.
    """
    # Fail fast with a clear message instead of leaking/embedding a credential.
    if not os.environ.get("REPLICATE_API_TOKEN"):
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it before calling "
            "stableDiffusionAPICall()"
        )
    model = replicate.models.get("stability-ai/stable-diffusion")
    # The first element of the prediction output is fetched as an image URL.
    gen_bg_img_url = model.predict(prompt=text_prompt)[0]
    response = requests.get(gen_bg_img_url, timeout=60)
    # Without this, an HTTP error page would be handed to Image.open.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    # Image.open is lazy; decode now so the result does not depend on the buffer.
    img.load()
    return img
def memory_limit_image_resize(cont_img):
    """Clamp the size of a PIL image in place and report its resulting size.

    Two guards run, in this order:

    1. If the longer side is under 400 px, a BICUBIC ``thumbnail`` is requested
       whose shorter side is 400 px (aspect ratio kept).
    2. If the shorter side is over 800 px, a BICUBIC ``thumbnail`` is requested
       whose longer side is 800 px (aspect ratio kept).

    NOTE(review): ``Image.thumbnail`` only ever shrinks an image, so guard 1
    cannot actually enlarge small inputs. Images whose shorter side is at most
    800 px are never shrunk, however long the other side is.

    Args:
        cont_img: PIL image; modified in place by ``thumbnail``.

    Returns:
        tuple: ``(width, height)`` after resizing. A
        ``Resize image: (w,h)->(w,h)`` line is printed as a side effect.
    """
    floor_px, ceil_px = 400, 800
    before_w, before_h = cont_img.width, cont_img.height

    w, h = cont_img.width, cont_img.height
    if max(w, h) < floor_px:
        # Aim for the shorter side to reach floor_px (no effect in practice, see NOTE).
        if w > h:
            target = (int(w / h * floor_px), floor_px)
        else:
            target = (floor_px, int(h / w * floor_px))
        cont_img.thumbnail(target, Image.BICUBIC)

    # Re-read the size: guard 1 may have touched the image.
    w, h = cont_img.width, cont_img.height
    if min(w, h) > ceil_px:
        # Shrink so the longer side becomes ceil_px.
        if w > h:
            target = (ceil_px, int(h / w * ceil_px))
        else:
            target = (int(w / h * ceil_px), ceil_px)
        cont_img.thumbnail(target, Image.BICUBIC)

    print(f"Resize image: ({before_w},{before_h})->({cont_img.width},{cont_img.height})")
    return cont_img.width, cont_img.height
def superimpose(input_img, back_img):
    """Cut the foreground out of ``input_img`` and paste it onto ``back_img``.

    ``rembg.remove`` produces the cut-out, which is pasted at the top-left
    corner (0, 0). The cut-out also serves as its own paste mask; presumably it
    is RGBA with a transparent background, so only the subject is copied.
    Confirm this against the installed rembg version.

    Args:
        input_img: Foreground image (e.g. a portrait).
        back_img: Background image. It is modified in place.

    Returns:
        tuple: ``(back_img, input_img)``, i.e. the composited background
        followed by the untouched foreground.
    """
    cutout = remove(input_img)
    back_img.paste(cutout, box=(0, 0), mask=cutout)
    composited = back_img
    return composited, input_img
def style_transfer(cont_img, styl_img):
    """Stylize ``cont_img`` with the look of ``styl_img`` using the PhotoWCT module.

    Both images are first size-clamped by ``memory_limit_image_resize``, in
    place for RGB inputs. No segmentation is used: empty label arrays are
    passed to ``stylization_module.transform``.

    Args:
        cont_img: Content PIL image.
        styl_img: Style PIL image.

    Returns:
        PIL.Image.Image: The stylized image, before any smoothing (see ``smoother``).
    """
    # ToTensor turns RGBA / L / P images into 4- or 1-channel tensors. The
    # stylization network presumably expects RGB (upstream FastPhotoStyle
    # converts with .convert('RGB')); confirm this for PhotoWCT. Converting
    # only when needed means RGB inputs are still resized in place as before.
    if cont_img.mode != "RGB":
        cont_img = cont_img.convert("RGB")
    if styl_img.mode != "RGB":
        styl_img = styl_img.convert("RGB")

    with torch.no_grad():
        new_cw, new_ch = memory_limit_image_resize(cont_img)
        memory_limit_image_resize(styl_img)  # called for its in-place resize only
        cw, ch = cont_img.width, cont_img.height

        to_tensor = transforms.ToTensor()
        cont_tensor = to_tensor(cont_img).unsqueeze(0)
        styl_tensor = to_tensor(styl_img).unsqueeze(0)
        if device == "cuda":
            cont_tensor = cont_tensor.to(device)
            styl_tensor = styl_tensor.to(device)
            stylization_module.to(device)

        # Empty segmentation labels: stylize the whole image as a single region.
        cont_seg = np.asarray([])
        styl_seg = np.asarray([])
        stylized = stylization_module.transform(cont_tensor, styl_tensor, cont_seg, styl_seg)

        # Restore the content size if it differs from the size that was stylized.
        # `upsample` is deprecated; `interpolate` with align_corners=False is the
        # same bilinear behaviour.
        if (ch, cw) != (new_ch, new_cw):
            stylized = nn.functional.interpolate(
                stylized, size=(ch, cw), mode="bilinear", align_corners=False
            )

        grid = utils.make_grid(stylized.data, nrow=1, padding=0)
        # Scale [0, 1] floats to uint8 and reorder CHW -> HWC for PIL.
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(ndarr)
def smoother(stylized_img, over_img):
    """Run the smoothing module on a stylized image, guided by a reference image.

    Args:
        stylized_img: Output of ``style_transfer``.
        over_img: Reference image forwarded to ``smoothing_module.process``
            (presumably the content / composited image; confirm with Propagator).

    Returns:
        Whatever ``smoothing_module.process`` returns.
    """
    on_gpu = device == 'cuda'
    if on_gpu:
        smoothing_module.to(device)
    smoothed = smoothing_module.process(stylized_img, over_img)
    return smoothed
if __name__ == "__main__":
    # Demo: composite a foreground portrait onto a background image and save it.
    root = pyrootutils.setup_root(__file__, pythonpath=True)
    fg_path = root / "notebooks/profile_new.png"
    bg_path = root / "notebooks/back_img.png"
    ckpt_path = root / "src/models/MODNet/pretrained/modnet_photographic_portrait_matting.ckpt"
    # stableDiffusionAPICall("Photorealistic scenery of a concert")
    # Same size for both, so pasting at (0, 0) aligns them exactly.
    fg_img = Image.open(fg_path).resize((800, 800))
    bg_img = Image.open(bg_path).resize((800, 800))
    # img = combined_display(fg_img, bg_img, ckpt_path)
    # superimpose() returns (composited_background, foreground). Calling .save on
    # the tuple raised AttributeError, so unpack and keep only the composite.
    img, _ = superimpose(fg_img, bg_img)
    img.save(root / "notebooks/overlay.png")