| | |
| | |
| | """ |
| | streamlit app demo |
| | how to run: |
| | streamlit run app.py --server.port 8501 |
| | |
| | @author: Tu Bui @surrey.ac.uk |
| | """ |
| | import os, sys, torch |
| | import argparse |
| | from pathlib import Path |
| | import numpy as np |
| | import pickle |
| | import pytorch_lightning as pl |
| | from torchvision import transforms |
| | import argparse |
| | from ldm.util import instantiate_from_config |
| | from omegaconf import OmegaConf |
| | from PIL import Image |
| | from tools.augment_imagenetc import RandomImagenetC |
| | from io import BytesIO |
| | from tools.helpers import welcome_message |
| | from tools.ecc import BCH, RSC |
| |
|
| | import streamlit as st |
| | from streamlit.source_util import ( |
| | page_icon_and_name, |
| | calc_md5, |
| | get_pages, |
| | _on_pages_changed |
| | ) |
| |
|
# Model names offered in the UI selectbox. Each name must be handled by
# load_model / embed_secret / decode_secret, which raise NotImplementedError otherwise.
model_names = ["UNet"]
| |
|
| |
|
def delete_page(main_script_path_str, page_name):
    """Remove the first registered Streamlit page whose name equals *page_name*.

    Each page name inspected is printed to stdout. After the search, whether or
    not a page was removed, ``_on_pages_changed`` is signalled.

    NOTE(review): this relies on ``get_pages`` returning the live page registry,
    so that the in-place deletion is seen by Streamlit. Confirm this for the
    Streamlit version in use; ``streamlit.source_util`` is a private module.

    Args:
        main_script_path_str: path of the main Streamlit script.
        page_name: display name of the page to remove.
    """
    registry = get_pages(main_script_path_str)
    for page_hash in registry:
        entry = registry[page_hash]
        print(entry['page_name'])
        if entry['page_name'] != page_name:
            continue
        # Deleting during iteration is only safe because we stop right away.
        del registry[page_hash]
        break
    _on_pages_changed.send()
| |
|
| |
|
def add_page(main_script_path_str, page_name):
    """Register a script as a Streamlit page and signal the page change.

    The search covers ``pages/*.py`` next to the main script, followed by
    ``*.py`` beside it. The first file whose name contains *page_name* as a
    substring is used. The entry is keyed by ``calc_md5`` of the script's
    resolved path.

    Args:
        main_script_path_str: path of the main Streamlit script.
        page_name: substring to look for in candidate script file names.

    Raises:
        IndexError: if no candidate file name contains *page_name*.
    """
    registry = get_pages(main_script_path_str)
    script_dir = Path(main_script_path_str).parent
    candidates = list((script_dir / "pages").glob("*.py")) + list(script_dir.glob("*.py"))
    matches = [candidate for candidate in candidates if page_name in candidate.name]
    script_path = matches[0]

    resolved = str(script_path.resolve())
    icon, name = page_icon_and_name(script_path)
    page_hash = calc_md5(resolved)
    registry[page_hash] = {
        "page_script_hash": page_hash,
        "page_name": name,
        "icon": icon,
        "script_path": resolved,
    }
    _on_pages_changed.send()
| |
|
def unormalize(x):
    """Convert a 4-D tensor with values in [-1, 1] into a uint8 numpy array.

    The steps are:
    - scale values to [0, 255] and clamp them;
    - move axis 1 to the last position (presumably NCHW -> NHWC);
    - cast to uint8, which truncates any fractional part.
    """
    scaled = ((x + 1) * 127.5).clamp(0, 255)
    channels_last = scaled.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy().astype(np.uint8)
| |
|
def to_bytes(x, mime):
    """Encode image array *x* to in-memory file bytes.

    JPEG is used when *mime* is exactly ``'image/jpeg'``; any other value
    produces PNG.

    NOTE(review): JPEG is lossy and Pillow's default quality is used here.
    This may weaken an embedded watermark, so confirm it is acceptable.
    """
    fmt = "PNG"
    if mime == 'image/jpeg':
        fmt = "JPEG"
    with BytesIO() as buffer:
        Image.fromarray(x).save(buffer, format=fmt)
        return buffer.getvalue()
| |
|
| |
|
def _fetch_if_remote(path_or_url, cache_dir='./weights'):
    """Return a local path for *path_or_url*, downloading it once if it is a URL.

    Strings starting with ``'http'`` are saved in *cache_dir* under the URL's
    last path component, and an existing cached copy is reused. Any other
    string is treated as a local path and returned unchanged.
    """
    if not path_or_url.startswith('http'):
        return path_or_url
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    local_path = cache / path_or_url.split('/')[-1]
    if not local_path.exists():
        import wget  # only needed when a remote file must be fetched
        print(f'Downloading {path_or_url}...')
        with st.spinner("Downloading model... this may take awhile!"):
            wget.download(path_or_url, str(local_path))
    return str(local_path)


def load_UNet(args):
    """Build the UNet watermarking model from a config file and load its checkpoint.

    Args:
        args: namespace with ``weight`` (checkpoint) and ``config`` (YAML)
            attributes. Each may be a local path or an http(s) URL; URLs are
            downloaded into ``./weights`` and cached there.

    Returns:
        ``(model, secret_len)``. The model is moved to CUDA when available
        and put in eval mode. ``secret_len`` comes from the config's
        ``model.params.secret_len``.
    """
    print('args: ', args)
    # Resolve each file on its own. Previously the config was fetched only
    # when the *weight* was a URL (and only if the weight was not cached),
    # so mixed local/remote combinations broke.
    weight_file = _fetch_if_remote(args.weight)
    config_file = _fetch_if_remote(args.config)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    config = OmegaConf.load(config_file).model
    secret_len = config.params.secret_len
    print(f'Secret length: {secret_len}')
    model = instantiate_from_config(config)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point --weight at trusted checkpoints.
    checkpoint = torch.load(weight_file, map_location=torch.device('cpu'))
    if 'global_step' in checkpoint:
        print(f'Global step: {checkpoint["global_step"]}, epoch: {checkpoint.get("epoch")}')

    # Accept either a training checkpoint that wraps 'state_dict' or a bare state dict.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    misses, ignores = model.load_state_dict(state_dict, strict=False)
    print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
    model = model.to(device)
    model.eval()
    return model, secret_len
| |
|
def embed_secret(model_name, model, cover, tform, secret):
    """Embed *secret* into *cover* and return the stego image as uint8.

    The model runs at the resolution produced by *tform*. Only the residual
    (clamped stego minus model input) is bilinearly resized back to the
    cover's size and added to the full-resolution cover, so cover detail is
    kept.

    Args:
        model_name: only ``'UNet'`` is supported.
        model: callable ``model(im, secret) -> (stego, _)`` with a ``.device``.
        cover: image with ``.size == (w, h)`` that ``np.array`` converts to
            an (h, w, C) array. In this app it is an RGB PIL image.
        tform: maps *cover* to a CHW tensor. It is expected to produce
            values in [-1, 1], matching the ``/127.5 - 1`` scaling below.
        secret: secret tensor passed straight to the model.

    Returns:
        ``np.uint8`` array of shape (h, w, C).

    Raises:
        NotImplementedError: for any other *model_name*.
    """
    if model_name != 'UNet':
        raise NotImplementedError
    w, h = cover.size
    with torch.no_grad():
        im = tform(cover).unsqueeze(0).to(model.device)
        stego, _ = model(im, secret)
        res = stego.clamp(-1, 1) - im
        res = torch.nn.functional.interpolate(res, (h, w), mode='bilinear')
        res = res.permute(0, 2, 3, 1).cpu().numpy()
    stego_float = np.clip(res[0] + np.array(cover) / 127.5 - 1., -1, 1) * 127.5 + 127.5
    # Round instead of truncating. The [0,255] -> [-1,1] -> [0,255] round trip
    # can land just below an integer (e.g. 254.99999), which astype alone would
    # turn into 254, shifting even unmodified pixels down by one.
    return np.rint(stego_float).astype(np.uint8)
| |
|
def identity(x):
    """Return the argument unchanged."""
    result = x
    return result
| |
|
def decode_secret(model_name, model, im, tform):
    """Run the model's decoder on *im* and threshold its output at zero.

    Args:
        model_name: ``'RoSteALS'`` or ``'UNet'``.
        model: object exposing ``.decoder`` and ``.device``.
        im: image accepted by *tform*.
        tform: transform producing a CHW tensor. A batch axis is added here.

    Returns:
        Boolean numpy array, True where the decoder output is > 0.

    Raises:
        NotImplementedError: for any other *model_name*.
    """
    if model_name not in ('RoSteALS', 'UNet'):
        raise NotImplementedError
    with torch.no_grad():
        batch = tform(im).unsqueeze(0).to(model.device)
        logits = model.decoder(batch)
        bits = (logits > 0).cpu().numpy()
    return bits
| |
|
| |
|
@st.cache_resource
def load_model(model_name, _args):
    """Load *model_name* together with its embedding and detection transforms.

    The result is cached by Streamlit. ``_args`` carries a leading underscore,
    so Streamlit does not hash it and the cache key is effectively
    *model_name*.

    Returns:
        ``(model, tform_emb, tform_det, secret_len)``. ``tform_emb`` resizes
        to 256x256 and ``tform_det`` to 224x224. Both then convert to a
        tensor and normalise with mean/std 0.5.

    Raises:
        NotImplementedError: for any model other than ``'UNet'``.
    """
    if model_name != 'UNet':
        raise NotImplementedError

    def _square_normalized(side):
        # Resize to side x side, convert to tensor, normalise each channel with mean=std=0.5.
        return transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    tform_emb = _square_normalized(256)
    tform_det = _square_normalized(224)
    model, secret_len = load_UNet(_args)
    return model, tform_emb, tform_det, secret_len
| |
|
| |
|
@st.cache_resource
def load_ecc(ecc_name, secret_len):
    """Build the error-correcting codec used to map text to secret bits.

    Supported combinations:
    - ``'BCH'`` with ``secret_len`` 160: ``BCH(285, 10, ...)``.
    - ``'BCH'`` with ``secret_len`` 100: ``BCH(137, 5, ...)``.
    - ``'RSC'``: 16 data bytes plus 4 ECC bytes.

    NOTE(review): the RSC branch ignores *secret_len*. 16+4 bytes is 160 bits,
    which presumably targets a 160-bit model, so confirm before using it with
    others.

    Raises:
        NotImplementedError: for an unsupported *ecc_name* / *secret_len*.
            Previously this surfaced as an UnboundLocalError on ``ecc``.
    """
    if ecc_name == 'BCH':
        if secret_len == 160:
            return BCH(285, 10, secret_len, verbose=True)
        if secret_len == 100:
            return BCH(137, 5, payload_len=secret_len, verbose=True)
        raise NotImplementedError(
            f'BCH is only configured for secret_len 100 or 160, got {secret_len}')
    if ecc_name == 'RSC':
        return RSC(data_bytes=16, ecc_bytes=4, verbose=True)
    raise NotImplementedError(f'Unknown ECC scheme: {ecc_name!r}')
| |
|
| |
|
class Resize:
    """Resize a PIL image or numpy array to a fixed (width, height) size.

    The resampling filter is LANCZOS when the input's shortest side is larger
    than the target's shortest side (downscaling), and BILINEAR otherwise.
    The result is always returned as a numpy array.
    """

    def __init__(self, size=None) -> None:
        # Default target size. It can be overridden per call.
        self.size = size

    def __call__(self, x, size=None):
        target = self.size if size is None else size
        img = Image.fromarray(x) if isinstance(x, np.ndarray) else x
        shrinking = min(img.size) > min(target)
        resample = Image.LANCZOS if shrinking else Image.BILINEAR
        return np.array(img.resize(target, resample))
| |
|
| |
|
def parse_st_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before. Passing a list allows
            programmatic use and testing.

    Returns:
        ``argparse.Namespace`` with ``weight`` and ``config`` attributes.

    NOTE(review): the defaults are absolute paths on the author's machine.
    Override them with --weight / --config elsewhere.
    """
    parser = argparse.ArgumentParser(description='Streamlit watermarking demo')
    parser.add_argument(
        '--weight',
        default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt',
        help='Model checkpoint: local path or http(s) URL.')
    parser.add_argument(
        '--config',
        default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml',
        help='Model config YAML file.')
    return parser.parse_args(argv)
| |
|
| |
|
def app(args):
    """Render the Streamlit watermarking demo.

    Flow:
    1. Pick a model and upload an image, then type a text secret.
    2. The secret is BCH-encoded and embedded into the image.
    3. The result is shown and offered for download.
    4. The secret is decoded back from the stego image, and the recovered
       text and bit accuracy are reported.

    Args:
        args: parsed CLI namespace forwarded to ``load_model``.
    """
    st.title('Watermarking Demo')

    model_name = st.selectbox("Choose the model", model_names)
    model, tform_emb, tform_det, secret_len = load_model(model_name, args)
    display_width = 300

    ecc = load_ecc('BCH', secret_len)

    # ---- Input ---------------------------------------------------------------
    st.subheader("Input")
    image_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    secret_ok = False  # True only when an image is uploaded and the secret fits
    if image_file is not None:
        print('Image: ', image_file.name)
        # Lower-case the extension so e.g. "photo.JPG" maps to image/jpeg below,
        # rather than "image/JPG", which to_bytes would silently encode as PNG.
        ext = image_file.name.split('.')[-1].lower()
        im = Image.open(image_file).convert('RGB')
        st.image(im, width=display_width)
        secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
        # Report an over-long secret in the UI instead of an AssertionError
        # traceback. An assert would also be stripped under `python -O`.
        secret_ok = len(secret_text) <= ecc.data_len
        if not secret_ok:
            st.error(f'The secret is {len(secret_text)} chars long; the maximum is {ecc.data_len}.')

    # ---- Embed and self-check ------------------------------------------------
    st.subheader("Embed results")
    status = st.empty()
    # Resize + centre-crop applied to the stego image before decoding.
    # NOTE(review): presumably mimics the detector's expected input; confirm.
    prep = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
    ])
    if secret_ok:
        secret = ecc.encode_text([secret_text])
        secret = torch.from_numpy(secret).float().to(model.device)

        stego = embed_secret(model_name, model, im, tform_emb, secret)
        st.image(stego, width=display_width)

        mime = 'image/jpeg' if ext == 'jpg' else f'image/{ext}'
        stego_bytes = to_bytes(stego, mime)
        st.download_button(label='Download image', data=stego_bytes, file_name=f'stego.{ext}', mime=mime)

        stego_processed = prep(Image.fromarray(stego))
        secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
        bit_acc = (secret_pred == secret.cpu().numpy()).mean()
        secret_pred = ecc.decode_text(secret_pred)[0]
        # st.empty() holds a single element. Two successive .markdown() calls
        # would leave only the last one (the recovered secret was never shown),
        # so render both lines in one call. unsafe_allow_html is dropped
        # because secret_pred is derived from user input.
        status.markdown(f'**Secret recovery check:** {secret_pred}\n\n**Bit accuracy:** {bit_acc}')
| |
|
if __name__ == '__main__':
    # Entry point: parse CLI options, then render the demo page.
    args = parse_st_args()
    app(args)
| | |
| |
|
| | |