| | import numpy as np |
| | from src.models.utils import get_image_arr, load_model |
| | from src.data import TAIMGANTokenizer |
| | from torchvision import transforms |
| | from src.config import config_dict |
| | from pathlib import Path |
| | from enum import IntEnum, auto |
| | from PIL import Image |
| | import gradio as gr |
| | import torch |
| | from src.models.modules import ( |
| | VGGEncoder, |
| | InceptionEncoder, |
| | TextEncoder, |
| | Generator |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
# Image geometry constants; IMG_HW is the square side length that input
# images are resized to before encoding.
IMG_CHANS, IMG_HW = 3, 256

# LSTM hidden size for the text encoder; C (twice that) is the embedding width
# handed to the generator (D), text encoder (emb_dim) and Inception encoder (D).
# The factor of two presumably reflects a bidirectional LSTM — TODO confirm.
HIDDEN_DIM = 128
C = HIDDEN_DIM * 2
| |
|
# Generator hyper-parameters pulled from the project configuration:
# Ng, the conditioning dimension and the noise-vector length.
Ng, cond_dim, z_dim = (
    config_dict[key] for key in ("Ng", "condition_dim", "noise_dim")
)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Checkpoint registry keyed by the name offered in the UI dropdown. Each entry
# initially holds only the directory containing that model's files; the loader
# loop that follows adds the tokenizer and networks to the same dict.
_MODEL_DIRS = (
    ("COCO", "weights/coco"),
    ("Bird", "weights/bird"),
    ("UTKFace", "weights/utkface"),
)
models = {name: {"dir": weights_dir} for name, weights_dir in _MODEL_DIRS}
| |
|
# Populate every registry entry once at import time: tokenizer, networks put in
# eval mode, then load_model restores their saved weights from the entry's
# directory onto the CPU.
for model_name, entry in models.items():
    weights_dir = entry["dir"]

    entry["tokenizer"] = TAIMGANTokenizer(captions_path=f"{weights_dir}/captions.pickle")
    # The text encoder's embedding table is sized from the tokenizer vocabulary.
    vocab_size = len(entry["tokenizer"].word_to_ix)

    entry["generator"] = Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval()
    entry["lstm"] = TextEncoder(vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM).eval()
    entry["vgg"] = VGGEncoder().eval()
    entry["inception"] = InceptionEncoder(D=C).eval()

    # NOTE(review): no VGG module is passed to load_model, so the VGG encoder
    # keeps whatever weights VGGEncoder() initialises itself with — confirm.
    load_model(
        generator=entry["generator"],
        discriminator=None,
        image_encoder=entry["inception"],
        text_encoder=entry["lstm"],
        output_dir=Path(weights_dir),
        device=torch.device("cpu"),
    )
| |
|
| |
|
# Preprocessing applied to every input image; built once at import time rather
# than on every request. ToTensor yields a CHW float tensor in [0, 1], and
# Normalize with mean = std = 0.5 then maps it to [-1, 1].
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_HW, IMG_HW)),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5)
    )
])


def change_image_with_text(image: Image.Image, text: str, model_name: str) -> Image.Image:
    """
    Generate a new image from ``image`` guided by the caption ``text``.

    The image is encoded by the selected model's VGG and Inception encoders,
    the caption by its tokenizer and LSTM text encoder, and the generator
    combines both (plus random noise) into a new image. Nothing is written to
    disk; the generated image is returned.

    :param PIL.Image.Image image: Source image supplied by the Gradio input.
    :param str text: Desired caption describing the modification.
    :param str model_name: Key into ``models`` selecting the checkpoint
        (e.g. "COCO", "Bird" or "UTKFace").
    :return: The generated image.
    :rtype: PIL.Image.Image
    :raises ValueError: If no image is given or ``model_name`` is unknown.
    """
    if image is None:
        raise ValueError("An input image is required.")
    if model_name not in models:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {list(models)}."
        )

    bundle = models[model_name]
    tokenizer = bundle["tokenizer"]
    G = bundle["generator"]
    lstm = bundle["lstm"]
    inception = bundle["inception"]
    vgg = bundle["vgg"]

    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        # NOTE(review): torch.rand samples U[0, 1); confirm this matches the
        # noise distribution used during training (many GANs use torch.randn).
        noise = torch.rand(z_dim).unsqueeze(0)

        tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch of one
        mask = (tokens == tokenizer.pad_token_id)  # True at padding positions
        word_embs, sent_embs = lstm(tokens)

        # Normalize carries three per-channel statistics, so force 3-channel
        # RGB; grayscale, RGBA or palette images would otherwise raise.
        image_tensor = _IMAGE_TRANSFORM(image.convert("RGB")).unsqueeze(0)

        vgg_features = vgg(image_tensor)
        local_features, global_features = inception(image_tensor)

        fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
                             local_features, vgg_features, mask)

    # get_image_arr (src.models.utils) is assumed to convert the generator
    # output into a batch of uint8 HWC arrays; index 0 is our single image.
    return Image.fromarray(get_image_arr(fake_image)[0])
| |
|
| |
|
| | |
| | |
| | |
# Gradio UI: source image + caption + model choice in, generated image out.
demo = gr.Interface(
    fn=change_image_with_text,
    inputs=[
        gr.Image(type="pil"),
        "text",
        # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4, so use
        # the gr.Dropdown component directly. A default selection keeps
        # model_name from arriving as None when the user picks nothing.
        gr.Dropdown(choices=list(models.keys()), value=next(iter(models))),
    ],
    outputs=gr.Image(type="pil"),
    examples=[
        ["src/data/stubs/bird.jpg", "black bird with blue wings", "Bird"],
        ["src/data/stubs/lady.jpg", "lady with blue eyes", "UTKFace"],
        ["src/data/stubs/bird.jpg", "white bird with black wings", "Bird"]
    ]
)
demo.launch(debug=True)
| |
|