| | |
| | import os, re, time |
| |
|
# Redirect both TORCH_HOME and XDG_CACHE_HOME to a ".cache" folder inside the
# current working directory. This is done before the DeOldify imports further
# down — presumably so torch and XDG-aware libraries pick up this cache
# location (confirm against how those libraries read these variables).
os.environ.update(
    dict.fromkeys(("TORCH_HOME", "XDG_CACHE_HOME"), os.path.join(os.getcwd(), ".cache"))
)
| |
|
| | import streamlit as st |
| | import PIL |
| | import cv2 |
| | import numpy as np |
| | import uuid |
| | from zipfile import ZipFile, ZIP_DEFLATED |
| | from io import BytesIO |
| | from random import randint |
| | from datetime import datetime |
| |
|
| | from src.deoldify import device |
| | from src.deoldify.device_id import DeviceId |
| | from src.deoldify.visualize import * |
| | from src.app_utils import get_model_bin |
| |
|
| |
|
# Select the CPU as DeOldify's compute device. This runs at import time, before
# load_model() is called further down — presumably so the colorizer it builds
# runs on CPU (confirm in src.deoldify.device).
device.set(device=DeviceId.CPU)
| |
|
| |
|
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    """Fetch the DeOldify generator weights for *option* and build a colorizer.

    Streamlit caches the result per argument combination, so the weights are
    fetched and the model is built once rather than on every script rerun.

    Args:
        model_dir: Directory the weights file is written to (joined with the
            weights filename and handed to ``get_model_bin``).
        option: Colorizer flavour, ``"Artistic"`` or ``"Stable"``
            (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer(artistic=...)``.

    Raises:
        ValueError: If *option* is not a supported flavour. (Previously an
            unknown option fell through both branches and raised
            UnboundLocalError.)

    NOTE(review): ``get_image_colorizer`` is not given *model_dir*; it
    presumably loads weights from its own default location, so *model_dir*
    must match that location — confirm.
    """
    # lower-cased option -> (weights URL, weights filename, artistic flag)
    model_specs = {
        "artistic": (
            "https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth",
            "ColorizeArtistic_gen.pth",
            True,
        ),
        "stable": (
            # dl=1 makes Dropbox serve the raw file; dl=0 returns an HTML preview page.
            "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=1",
            "ColorizeStable_gen.pth",
            False,
        ),
    }

    try:
        model_url, weights_name, artistic = model_specs[option.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown colorizer option {option!r} (expected 'Artistic' or 'Stable')"
        ) from None

    get_model_bin(model_url, os.path.join(model_dir, weights_name))
    return get_image_colorizer(artistic=artistic)
| |
|
| |
|
def resize_img(input_img, max_size):
    """Scale an image down so its longer side is *max_size*, keeping aspect ratio.

    Images whose longer side is already <= *max_size* are returned as a copy,
    unresized.

    Args:
        input_img: Image array indexed as ``img[row, col, ...]`` (e.g. the
            ``np.array`` of a PIL image). Only ``shape[0]`` (height) and
            ``shape[1]`` (width) are inspected.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new array; the input is never modified.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]

    if max(img_height, img_width) <= max_size:
        return input_img.copy()

    if img_height > img_width:
        # Portrait: the height is the long side.
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: the width is the long side, so the *height* is
        # the one that gets scaled (not the other way round).
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # cv2.resize takes dsize as (width, height). Clamp to 1 px so a very thin
    # image cannot truncate to an invalid zero-sized target.
    dsize = (max(1, int(new_width)), max(1, int(new_height)))
    return cv2.resize(input_img, dsize)
| |
|
| |
|
def colorize_image(pil_image, img_size=800, render_factor=35) -> "PIL.Image.Image":
    """Colorize *pil_image* with the module-level ``colorizer``.

    The image is converted to RGB and downscaled so its longer side is at most
    *img_size* (via ``resize_img``) before being passed to the model.

    Args:
        pil_image: Input PIL image (any mode convertible to RGB).
        img_size: Maximum length of the longer side fed to the model.
        render_factor: Forwarded to ``plot_transformed_pil_image``. The default
            of 35 matches the previously hard-coded value.

    Returns:
        The image returned by ``colorizer.plot_transformed_pil_image``, at the
        downscaled size. The caller resizes it back if needed.

    Raises:
        RuntimeError: If the global ``colorizer`` was not loaded (is None).
    """
    if colorizer is None:
        raise RuntimeError("Colorizer model is not loaded")

    rgb_array = np.array(pil_image.convert("RGB"))
    model_input = PIL.Image.fromarray(resize_img(rgb_array, img_size))

    return colorizer.plot_transformed_pil_image(
        model_input, render_factor=render_factor, compare=False
    )
| |
|
| |
|
def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    """Render a Streamlit download button that serves *pil_image* as a file.

    Args:
        pil_image: PIL image to encode.
        filename: Download file name without extension; ``.{fmt}`` is appended.
        fmt: ``"jpg"`` or ``"png"`` (case-sensitive).
        label: Button label.

    Returns:
        Whatever ``st.download_button`` returns.

    Raises:
        ValueError: If *fmt* is not a supported format. This is a subclass of
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    # fmt -> (PIL save format, MIME type)
    formats = {
        "jpg": ("JPEG", "image/jpeg"),
        "png": ("PNG", "image/png"),
    }
    if fmt not in formats:
        raise ValueError(
            f"Unknown image format {fmt!r} (Available: {', '.join(formats)} - case sensitive)"
        )
    pil_format, mime = formats[fmt]

    image = pil_image
    # PIL cannot write modes such as RGBA or P as JPEG (it raises OSError), so
    # convert those to RGB first.
    if pil_format == "JPEG" and image.mode not in {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}:
        image = image.convert("RGB")

    buf = BytesIO()
    image.save(buf, format=pil_format)

    return st.download_button(
        label=label,
        data=buf.getvalue(),
        file_name=f"{filename}.{fmt}",
        mime=mime,
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
# Colorizer flavour to load; load_model() handles "artistic" and "stable"
# (compared case-insensitively).
st_color_option = "Artistic"

try:
    with st.spinner("Loading..."):
        print('before loading the model')
        colorizer = load_model('models/', st_color_option)
        print('after loading the model')
except Exception as load_error:
    # Top-level boundary: do not crash the script. Leave `colorizer` as None so
    # code guarded by `colorizer is not None` is skipped, log the error to the
    # console, and show a short notice in the page.
    colorizer = None
    print('Error while loading the model. Please refresh the page', load_error, sep='\n')
    st.write("**App loading error. Please try again later.**")
| |
|
| |
|
| |
|
| | if colorizer is not None: |
| | st.title("AI ์ฌ์ง ์ปฌ๋ฌ ๋ณต์") |
| |
|
| | st.image(open("assets/demo.jpg", "rb").read()) |
| |
|
| | st.markdown( |
| | """ |
| | ํ๋ฐฑ ์ฌ์ง์ ์
๋ก๋ํ๋ฉด, AI๊ฐ ์๋์ผ๋ก ์ด๋ฅผ ์ธ์ ๋ฐ ๋ถ์ํ์ฌ ์ปฌ๋ฌ๋ก ๋ณต์์ ํฉ๋๋ค.** |
| | """ |
| | ) |
| | |
| | uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"]) |
| |
|
| | if uploaded_file is not None: |
| | bytes_data = uploaded_file.getvalue() |
| | img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB") |
| | |
| | with st.expander("Original photo", True): |
| | st.image(img_input) |
| |
|
| | if st.button("์ปฌ๋ฌ ๋ณต์ํ๊ธฐ") and uploaded_file is not None: |
| | |
| | with st.spinner("์ปฌ๋ฌ ๋ณต์์ค์
๋๋ค. ์ ์๋ง ๊ธฐ๋ค๋ฆฌ์ธ์."): |
| | img_output = colorize_image(img_input) |
| | img_output = img_output.resize(img_input.size) |
| | |
| | |
| | |
| | now = datetime.now().strftime("%Y%m%d-%H%M%S-%f") |
| | img_input.convert("RGB").save(f"./output/{now}-input.jpg") |
| | img_output.convert("RGB").save(f"./output/{now}-output.jpg") |
| | |
| | st.write("์ปฌ๋ฌ ๋ณต์์ค์
๋๋ค. ์ ์๋ง ๊ธฐ๋ค๋ฆฌ์ธ์.") |
| | st.image(img_output) |
| | |
| | |
| | uploaded_name = os.path.splitext(uploaded_file.name)[0] |
| | image_download_button( |
| | pil_image=img_output, |
| | filename=uploaded_name, |
| | fmt="jpg", |
| | label="Download Image" |
| | ) |
| | |
| |
|