Spaces:
Runtime error
Runtime error
| import base64 | |
| from huggingface_hub import hf_hub_download | |
| import streamlit as st | |
| import io | |
| import gc | |
| import json | |
| ######################################################################################################## | |
| # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM | |
| ######################################################################################################## | |
| MODEL_REPO = 'BlinkDL/clip-guided-binary-autoencoder' | |
| import torch, types | |
| import numpy as np | |
| from PIL import Image | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import torchvision as vision | |
| import torchvision.transforms as transforms | |
| from torchvision.transforms import functional as VF | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| IMG_BITS = 13 | |
class ResBlock(nn.Module):
    """Shape-preserving residual block built from two conv-conv branches.

    The first branch is ``BatchNorm -> mish -> conv -> mish -> conv``; the
    second is ``conv -> mish -> conv``. Each branch output is added back onto
    its own input, so the result has the same shape as ``x``.

    Args:
        c_x: Number of channels of the input (and of the output).
        c_hidden: Number of channels used between the two convs of a branch.
    """

    def __init__(self, c_x, c_hidden):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # NOTE: attribute names are the state_dict keys of saved checkpoints,
        # so they must not be renamed.
        self.B0 = nn.BatchNorm2d(c_x)
        self.C0 = conv3x3(c_x, c_hidden)
        self.C1 = conv3x3(c_hidden, c_x)
        self.C2 = conv3x3(c_x, c_hidden)
        self.C3 = conv3x3(c_hidden, c_x)

    def forward(self, x):
        """Apply both residual branches in sequence and return the result."""
        act = F.mish
        # Branch 1: normalised input through the C0/C1 pair.
        branch = self.C1(act(self.C0(act(self.B0(x)))))
        x = x + branch
        # Branch 2: no normalisation, C2/C3 pair.
        branch = self.C3(act(self.C2(x)))
        return x + branch
class REncoderSmall(nn.Module):
    """Small encoder: RGB image -> ``IMG_BITS``-channel code at 1/8 resolution.

    The base channel width is fixed at ``dd = 8``. The image goes through a
    stem conv, then three ``pixel_unshuffle(2)`` stages (channels x4 each
    time), each followed by two residual conv branches. A long skip connection
    takes the stem output, unshuffles it by 8 in a single step, batch-norms
    it and adds it in just before the output conv.
    """

    def __init__(self):
        super().__init__()
        dd = 8

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # NOTE: attribute names are the state_dict keys of the published
        # checkpoints, so they must not be renamed.
        self.Bxx = nn.BatchNorm2d(dd * 64)

        self.CIN = conv3x3(3, dd)
        self.Cx0 = conv3x3(dd, 32)
        self.Cx1 = conv3x3(32, dd)

        # Stage 0: dd * 4 channels.
        self.B00 = nn.BatchNorm2d(dd * 4)
        self.C00 = conv3x3(dd * 4, 256)
        self.C01 = conv3x3(256, dd * 4)
        self.C02 = conv3x3(dd * 4, 256)
        self.C03 = conv3x3(256, dd * 4)

        # Stage 1: dd * 16 channels.
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = conv3x3(dd * 16, 256)
        self.C11 = conv3x3(256, dd * 16)
        self.C12 = conv3x3(dd * 16, 256)
        self.C13 = conv3x3(256, dd * 16)

        # Stage 2: dd * 64 channels.
        self.B20 = nn.BatchNorm2d(dd * 64)
        self.C20 = conv3x3(dd * 64, 256)
        self.C21 = conv3x3(256, dd * 64)
        self.C22 = conv3x3(dd * 64, 256)
        self.C23 = conv3x3(256, dd * 64)

        self.COUT = conv3x3(dd * 64, IMG_BITS)

    def forward(self, img):
        """Encode ``img`` and return sigmoid activations (one per code bit)."""
        act = F.mish

        stem = self.CIN(img)
        # Long skip: fold the stem output down by 8 in one step.
        skip = self.Bxx(F.pixel_unshuffle(stem, 8))

        h = stem + self.Cx1(act(self.Cx0(stem)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C01(act(self.C00(act(self.B00(h)))))
        h = h + self.C03(act(self.C02(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C11(act(self.C10(act(self.B10(h)))))
        h = h + self.C13(act(self.C12(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C21(act(self.C20(act(self.B20(h)))))
        h = h + self.C23(act(self.C22(h)))

        logits = self.COUT(h + skip)
        return torch.sigmoid(logits)
class RDecoderSmall(nn.Module):
    """Small decoder: ``IMG_BITS``-channel code -> RGB image at 8x resolution.

    The base channel width is fixed at ``dd = 8``. The code is lifted to
    ``dd * 64`` channels, then three stages of two residual conv branches are
    each followed by ``pixel_shuffle(2)`` (channels / 4, spatial size x 2).
    A final residual branch and output conv produce 3 channels.
    """

    def __init__(self):
        super().__init__()
        dd = 8

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # NOTE: attribute names are the state_dict keys of the published
        # checkpoints, so they must not be renamed.
        self.CIN = conv3x3(IMG_BITS, dd * 64)

        # Stage 0: dd * 64 channels.
        self.B00 = nn.BatchNorm2d(dd * 64)
        self.C00 = conv3x3(dd * 64, 256)
        self.C01 = conv3x3(256, dd * 64)
        self.C02 = conv3x3(dd * 64, 256)
        self.C03 = conv3x3(256, dd * 64)

        # Stage 1: dd * 16 channels.
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = conv3x3(dd * 16, 256)
        self.C11 = conv3x3(256, dd * 16)
        self.C12 = conv3x3(dd * 16, 256)
        self.C13 = conv3x3(256, dd * 16)

        # Stage 2: dd * 4 channels.
        self.B20 = nn.BatchNorm2d(dd * 4)
        self.C20 = conv3x3(dd * 4, 256)
        self.C21 = conv3x3(256, dd * 4)
        self.C22 = conv3x3(dd * 4, 256)
        self.C23 = conv3x3(256, dd * 4)

        self.Cx0 = conv3x3(dd, 32)
        self.Cx1 = conv3x3(32, dd)
        self.COUT = conv3x3(dd, 3)

    def forward(self, code):
        """Decode ``code`` and return sigmoid-activated RGB values."""
        act = F.mish

        h = self.CIN(code)

        h = h + self.C01(act(self.C00(act(self.B00(h)))))
        h = h + self.C03(act(self.C02(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.C11(act(self.C10(act(self.B10(h)))))
        h = h + self.C13(act(self.C12(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.C21(act(self.C20(act(self.B20(h)))))
        h = h + self.C23(act(self.C22(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.Cx1(act(self.Cx0(h)))
        rgb = self.COUT(h)
        return torch.sigmoid(rgb)
class REncoderLarge(nn.Module):
    """Configurable encoder: RGB image -> ``IMG_BITS``-channel code at 1/8 size.

    A stem conv plus batch norm and one residual branch run at full
    resolution; then three ``pixel_unshuffle(2)`` steps are each followed by
    a ``ResBlock``; a final conv maps to ``IMG_BITS`` channels.

    Args:
        dd: Channel width after the stem conv (x4 after every unshuffle).
        ee: Hidden width of the full-resolution residual branch.
        ff: Hidden width handed to every ``ResBlock``.
    """

    def __init__(self, dd, ee, ff):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # NOTE: attribute names are the state_dict keys of the published
        # checkpoints, so they must not be renamed.
        self.CXX = conv3x3(3, dd)
        self.BXX = nn.BatchNorm2d(dd)
        self.CX0 = conv3x3(dd, ee)
        self.CX1 = conv3x3(ee, dd)
        self.R0 = ResBlock(dd * 4, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 64, ff)
        self.CZZ = conv3x3(dd * 64, IMG_BITS)

    def forward(self, x):
        """Encode ``x`` and return sigmoid activations (one per code bit)."""
        act = F.mish

        h = self.BXX(self.CXX(x))
        h = h + self.CX1(act(self.CX0(h)))

        # Each step trades a 2x spatial reduction for 4x the channels.
        h = self.R0(F.pixel_unshuffle(h, 2))
        h = self.R1(F.pixel_unshuffle(h, 2))
        h = self.R2(F.pixel_unshuffle(h, 2))

        logits = self.CZZ(h)
        return torch.sigmoid(logits)
class RDecoderLarge(nn.Module):
    """Configurable decoder: ``IMG_BITS``-channel code -> RGB image at 8x size.

    The code is lifted to ``dd * 64`` channels and batch-normed; then three
    ``ResBlock`` stages are each followed by ``pixel_shuffle(2)``; a final
    residual branch and conv produce 3 channels.

    Args:
        dd: Channel width at full resolution (x4 for every earlier stage).
        ee: Hidden width of the full-resolution residual branch.
        ff: Hidden width handed to every ``ResBlock``.
    """

    def __init__(self, dd, ee, ff):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # NOTE: attribute names are the state_dict keys of the published
        # checkpoints, so they must not be renamed.
        self.CZZ = conv3x3(IMG_BITS, dd * 64)
        self.BZZ = nn.BatchNorm2d(dd * 64)
        self.R0 = ResBlock(dd * 64, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 4, ff)
        self.CX0 = conv3x3(dd, ee)
        self.CX1 = conv3x3(ee, dd)
        self.CXX = conv3x3(dd, 3)

    def forward(self, x):
        """Decode ``x`` and return sigmoid-activated RGB values."""
        act = F.mish

        h = self.BZZ(self.CZZ(x))

        # Each step trades 4x fewer channels for a 2x spatial enlargement.
        h = F.pixel_shuffle(self.R0(h), 2)
        h = F.pixel_shuffle(self.R1(h), 2)
        h = F.pixel_shuffle(self.R2(h), 2)

        h = h + self.CX1(act(self.CX0(h)))
        rgb = self.CXX(h)
        return torch.sigmoid(rgb)
def prepare_model(model_prefix):
    """Build the encoder/decoder pair for ``model_prefix`` and load weights.

    The architecture is chosen from the prefix: the exact d8 prefix selects
    the small models, while prefixes containing ``d16_512`` or ``d32_1024``
    select the large models with matching widths. Weights are fetched from
    ``MODEL_REPO`` as ``<model_prefix>-E.pth`` and ``<model_prefix>-D.pth``.

    Args:
        model_prefix: Checkpoint name prefix inside ``MODEL_REPO``.

    Returns:
        ``(encoder, decoder)``, both in eval mode and moved to ``device``.

    Raises:
        ValueError: If ``model_prefix`` matches no known architecture.
    """
    gc.collect()
    if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
        R_ENCODER, R_DECODER = REncoderSmall(), RDecoderSmall()
    else:
        if 'd16_512' in model_prefix:
            dd, ee, ff = 16, 64, 512
        elif 'd32_1024' in model_prefix:
            dd, ee, ff = 32, 128, 1024
        else:
            # Previously this fell through with dd/ee/ff unbound and failed
            # with an UnboundLocalError a few lines further down.
            raise ValueError(f'Unknown model prefix: {model_prefix!r}')
        R_ENCODER = REncoderLarge(dd, ee, ff)
        R_DECODER = RDecoderLarge(dd, ee, ff)
    encoder = R_ENCODER.eval().to(device)
    decoder = R_DECODER.eval().to(device)
    # map_location remaps tensors saved from a GPU onto the current device;
    # without it torch.load raises on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a repository you trust.
    encoder.load_state_dict(
        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-E.pth'),
                   map_location=device))
    decoder.load_state_dict(
        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-D.pth'),
                   map_location=device))
    return encoder, decoder
def compute_padding(img_shape):
    """Return the padding that rounds both image sides up to a multiple of 8.

    Args:
        img_shape: Sequence whose first two items are ``(height, width)``.

    Returns:
        ``(left, top, right, bottom)``. When the padding along an axis is
        odd, the extra pixel goes to the right / bottom side.
    """
    height, width = img_shape[0], img_shape[1]
    # ``-n % 8`` is the distance from n up to the next multiple of 8.
    hpad = -width % 8
    vpad = -height % 8
    left, h_extra = divmod(hpad, 2)
    top, v_extra = divmod(vpad, 2)
    return left, top, left + h_extra, top + v_extra
def encode(model_prefix, img, keep_shape):
    """Encode a PIL image into a JSON string holding its binary code.

    Args:
        model_prefix: Checkpoint prefix passed on to ``prepare_model``.
        img: PIL image; it is converted to RGB before encoding.
        keep_shape: If true, edge-pad the image up to multiples of 8 and
            encode it at its own size; otherwise resize it to 224x224.

    Returns:
        JSON string with the keys ``img_shape`` (height and width before
        padding/resizing), ``z_shape`` (height and width of the code),
        ``keep_shape`` and ``data`` (base64 of a ``.npy`` file holding the
        bit-packed code).
    """
    gc.collect()
    encoder = prepare_model(model_prefix)[0]

    with torch.no_grad():
        pixels = VF.convert_image_dtype(VF.pil_to_tensor(img.convert("RGB")))
        batch = pixels.unsqueeze(0).to(device)
        img_shape = batch.shape[2:]
        if not keep_shape:
            batch = VF.resize(batch, [224, 224])
        else:
            left, top, right, bottom = compute_padding(img_shape)
            batch = VF.pad(batch, [left, top, right, bottom],
                           padding_mode='edge')
        # Round the sigmoid outputs to hard 0/1 bits.
        z = torch.floor(encoder(batch) + 0.5)

    packed_bits = np.packbits(z.cpu().numpy().astype('bool'))
    with io.BytesIO() as buffer:
        np.save(buffer, packed_bits)
        raw_bytes = buffer.getvalue()
    z_b64 = base64.b64encode(raw_bytes).decode()

    payload = {
        "img_shape": img_shape,
        "z_shape": z.shape[2:],
        "keep_shape": keep_shape,
        "data": z_b64,
    }
    return json.dumps(payload)
def decode(model_prefix, z_str):
    """Decode a JSON string produced by ``encode`` back into a PIL image.

    Args:
        model_prefix: Checkpoint prefix passed on to ``prepare_model``.
        z_str: JSON string with the keys ``img_shape``, ``z_shape``,
            ``keep_shape`` and ``data`` (base64 of a ``.npy`` file holding
            the bit-packed code).

    Returns:
        The reconstructed PIL image. When ``keep_shape`` is set, the padding
        computed from ``img_shape`` is cropped away again.
    """
    gc.collect()
    _, decoder = prepare_model(model_prefix)
    z_json = json.loads(z_str)
    with io.BytesIO(base64.b64decode(z_json["data"])) as buffer:
        # The payload may be pasted in by a user, so never allow pickled
        # object arrays to be loaded from it.
        z = np.load(buffer, allow_pickle=False)
    img_shape = z_json["img_shape"]
    z_shape = list(z_json["z_shape"])
    keep_shape = z_json["keep_shape"]
    # Bits are packed in whole bytes; keep only the bits the code needs.
    n_bits = IMG_BITS * z_shape[0] * z_shape[1]
    z = np.unpackbits(z)[:n_bits].astype('float')
    z = z.reshape([1, IMG_BITS] + z_shape)
    # Inference only: skip building an autograd graph, as encode() does.
    with torch.no_grad():
        img = decoder(torch.Tensor(z).to(device))
    if keep_shape:
        left, top, right, bottom = compute_padding(img_shape)
        img = img[0, :, top:img.shape[2] - bottom, left:img.shape[3] - right]
    else:
        img = img[0]
    return VF.to_pil_image(img)
# ---------------------------------------------------------------------------
# Streamlit page: model selector plus an "Encode" tab and a "Decode" tab.
# ---------------------------------------------------------------------------
st.title("Clip Guided Binary Autoencoder")
st.write(
    "Model is from [@BlinkDL](https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder)"
)

model_prefix = st.selectbox(
    'The model to use',
    (
        'out-v7c_d8_256-224-13bit-OB32x0.5-745',
        'out-v7d_d16_512-224-13bit-OB32x0.5-2487',
        'out-v7d_d32_1024-224-13bit-OB32x0.5-5560',
    ),
)

encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])

# Encode tab: upload an image, show its code and a decoded preview.
with encoder_tab:
    col_in, col_out = st.columns(2)
    keep_shape = col_in.checkbox(
        'Use original size of input image instead of rescaling (Experimental)')
    uploaded_file = col_in.file_uploader('Choose an Image')
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        col_in.image(image, 'Input Image')

        z_str = encode(model_prefix, image, keep_shape)
        col_out.write("Encoded to:")
        col_out.code(z_str, language=None)

        preview = decode(model_prefix, z_str)
        col_out.image(preview, 'Output Image preview')

# Decode tab: paste a code string and show the reconstructed image.
with decoder_tab:
    col_in, col_out = st.columns(2)
    z_str = col_in.text_area('Paste encoded string here:')
    if z_str:
        image = decode(model_prefix, z_str)
        col_out.image(image, 'Output Image')