three-model version
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +2 -0
- app.py +117 -0
- requirements.txt +5 -0
- src/__init__.py +2 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/__pycache__/config.cpython-39.pyc +0 -0
- src/config.py +47 -0
- src/data/.gitkeep +0 -0
- src/data/__init__.py +5 -0
- src/data/__pycache__/__init__.cpython-39.pyc +0 -0
- src/data/__pycache__/collate.cpython-39.pyc +0 -0
- src/data/__pycache__/datasets.cpython-39.pyc +0 -0
- src/data/__pycache__/tokenizer.cpython-39.pyc +0 -0
- src/data/collate.py +43 -0
- src/data/datasets.py +387 -0
- src/data/stubs/bird.jpg +0 -0
- src/data/stubs/pigeon.jpg +0 -0
- src/data/stubs/rohit.jpeg +0 -0
- src/data/tokenizer.py +23 -0
- src/features/.gitkeep +0 -0
- src/features/__init__.py +0 -0
- src/features/build_features.py +0 -0
- src/models/.gitkeep +0 -0
- src/models/__init__.py +4 -0
- src/models/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/__pycache__/losses.cpython-39.pyc +0 -0
- src/models/__pycache__/train_model.cpython-39.pyc +0 -0
- src/models/__pycache__/utils.cpython-39.pyc +0 -0
- src/models/losses.py +344 -0
- src/models/modules/__init__.py +12 -0
- src/models/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/acm.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/attention.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/cond_augment.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/conv_utils.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/discriminator.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/downsample.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/generator.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/image_encoder.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/residual.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/text_encoder.cpython-39.pyc +0 -0
- src/models/modules/__pycache__/upsample.cpython-39.pyc +0 -0
- src/models/modules/acm.py +37 -0
- src/models/modules/attention.py +88 -0
- src/models/modules/cond_augment.py +57 -0
- src/models/modules/conv_utils.py +78 -0
- src/models/modules/discriminator.py +144 -0
- src/models/modules/downsample.py +14 -0
- src/models/modules/generator.py +300 -0
- src/models/modules/image_encoder.py +138 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/*
|
| 2 |
+
.idea/*
|
app.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np # this should come first to mitigate mlk-service bug
|
| 2 |
+
from src.models.utils import get_image_arr, load_model
|
| 3 |
+
from src.data import TAIMGANTokenizer
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from src.config import config_dict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from enum import IntEnum, auto
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import torch
|
| 11 |
+
from src.models.modules import (
|
| 12 |
+
VGGEncoder,
|
| 13 |
+
InceptionEncoder,
|
| 14 |
+
TextEncoder,
|
| 15 |
+
Generator
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
##########
|
| 19 |
+
# PARAMS #
|
| 20 |
+
##########
|
| 21 |
+
|
| 22 |
+
IMG_CHANS = 3 # RGB channels for image
|
| 23 |
+
IMG_HW = 256 # height and width of images
|
| 24 |
+
HIDDEN_DIM = 128 # hidden dimensions of lstm cell in one direction
|
| 25 |
+
C = 2 * HIDDEN_DIM # length of embeddings
|
| 26 |
+
|
| 27 |
+
Ng = config_dict["Ng"]
|
| 28 |
+
cond_dim = config_dict["condition_dim"]
|
| 29 |
+
z_dim = config_dict["noise_dim"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
###############
|
| 33 |
+
# LOAD MODELS #
|
| 34 |
+
###############
|
| 35 |
+
|
| 36 |
+
models = {
|
| 37 |
+
"COCO": {
|
| 38 |
+
"dir": "weights/coco"
|
| 39 |
+
},
|
| 40 |
+
"Bird": {
|
| 41 |
+
"dir": "weights/bird"
|
| 42 |
+
},
|
| 43 |
+
"UTKFace": {
|
| 44 |
+
"dir": "weights/utkface"
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
for model_name in models:
|
| 49 |
+
# create tokenizer
|
| 50 |
+
models[model_name]["tokenizer"] = TAIMGANTokenizer(captions_path=f"{models[model_name]['dir']}/captions.pickle")
|
| 51 |
+
vocab_size = len(models[model_name]["tokenizer"].word_to_ix)
|
| 52 |
+
# instantiate models
|
| 53 |
+
models[model_name]["generator"] = Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval()
|
| 54 |
+
models[model_name]["lstm"] = TextEncoder(vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM).eval()
|
| 55 |
+
models[model_name]["vgg"] = VGGEncoder().eval()
|
| 56 |
+
models[model_name]["inception"] = InceptionEncoder(D=C).eval()
|
| 57 |
+
# load models
|
| 58 |
+
load_model(
|
| 59 |
+
generator=models[model_name]["generator"],
|
| 60 |
+
discriminator=None,
|
| 61 |
+
image_encoder=models[model_name]["inception"],
|
| 62 |
+
text_encoder=models[model_name]["lstm"],
|
| 63 |
+
output_dir=Path(models[model_name]["dir"]),
|
| 64 |
+
device=torch.device("cpu")
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def change_image_with_text(image: Image, text: str, model_name: str) -> Image:
|
| 69 |
+
"""
|
| 70 |
+
Create an image modified by text from the original image
|
| 71 |
+
and save it with _modified postfix
|
| 72 |
+
|
| 73 |
+
:param gr.Image image: Path to the image
|
| 74 |
+
:param str text: Desired caption
|
| 75 |
+
"""
|
| 76 |
+
global models
|
| 77 |
+
tokenizer = models[model_name]["tokenizer"]
|
| 78 |
+
G = models[model_name]["generator"]
|
| 79 |
+
lstm = models[model_name]["lstm"]
|
| 80 |
+
inception = models[model_name]["inception"]
|
| 81 |
+
vgg = models[model_name]["vgg"]
|
| 82 |
+
# generate some noise
|
| 83 |
+
noise = torch.rand(z_dim).unsqueeze(0)
|
| 84 |
+
# transform input text and get masks with embeddings
|
| 85 |
+
tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
|
| 86 |
+
mask = (tokens == tokenizer.pad_token_id)
|
| 87 |
+
word_embs, sent_embs = lstm(tokens)
|
| 88 |
+
# open the image and transform it to the tensor
|
| 89 |
+
image = transforms.Compose([
|
| 90 |
+
transforms.ToTensor(),
|
| 91 |
+
transforms.Resize((IMG_HW, IMG_HW)),
|
| 92 |
+
transforms.Normalize(
|
| 93 |
+
mean=(0.5, 0.5, 0.5),
|
| 94 |
+
std=(0.5, 0.5, 0.5)
|
| 95 |
+
)
|
| 96 |
+
])(image).unsqueeze(0)
|
| 97 |
+
# obtain visual features of the image
|
| 98 |
+
vgg_features = vgg(image)
|
| 99 |
+
local_features, global_features = inception(image)
|
| 100 |
+
# generate new image from the old one
|
| 101 |
+
fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
|
| 102 |
+
local_features, vgg_features, mask)
|
| 103 |
+
# denormalize the image
|
| 104 |
+
fake_image = Image.fromarray(get_image_arr(fake_image)[0])
|
| 105 |
+
# return image in gradio format
|
| 106 |
+
return fake_image
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
##########
|
| 110 |
+
# GRADIO #
|
| 111 |
+
##########
|
| 112 |
+
demo = gr.Interface(
|
| 113 |
+
fn=change_image_with_text,
|
| 114 |
+
inputs=[gr.Image(type="pil"), "text", gr.inputs.Dropdown(list(models.keys()))],
|
| 115 |
+
outputs=gr.Image(type="pil")
|
| 116 |
+
)
|
| 117 |
+
demo.launch(debug=True)
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Pillow
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
torchaudio
|
| 5 |
+
nltk
|
src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Config file for the project."""
|
| 2 |
+
from .config import config_dict, update_config
|
src/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (260 Bytes). View file
|
|
|
src/__pycache__/config.cpython-39.pyc
ADDED
|
Binary file (1.17 kB). View file
|
|
|
src/config.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configurations for the project."""
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 8 |
+
|
| 9 |
+
repo_path = Path(__file__).parent.parent.absolute()
|
| 10 |
+
output_path = repo_path / "models"
|
| 11 |
+
|
| 12 |
+
config_dict = {
|
| 13 |
+
"Ng": 32,
|
| 14 |
+
"D": 256,
|
| 15 |
+
"condition_dim": 100,
|
| 16 |
+
"noise_dim": 100,
|
| 17 |
+
"lr_config": {
|
| 18 |
+
"disc_lr": 2e-4,
|
| 19 |
+
"gen_lr": 2e-4,
|
| 20 |
+
"img_encoder_lr": 3e-3,
|
| 21 |
+
"text_encoder_lr": 3e-3,
|
| 22 |
+
},
|
| 23 |
+
"batch_size": 64,
|
| 24 |
+
"device": device,
|
| 25 |
+
"epochs": 200,
|
| 26 |
+
"output_dir": output_path,
|
| 27 |
+
"snapshot": 5,
|
| 28 |
+
"const_dict": {
|
| 29 |
+
"smooth_val_gen": 0.999,
|
| 30 |
+
"lambda1": 1,
|
| 31 |
+
"lambda2": 1,
|
| 32 |
+
"lambda3": 1,
|
| 33 |
+
"lambda4": 1,
|
| 34 |
+
"gamma1": 4,
|
| 35 |
+
"gamma2": 5,
|
| 36 |
+
"gamma3": 10,
|
| 37 |
+
},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def update_config(cfg_dict: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
|
| 42 |
+
"""
|
| 43 |
+
Function to update the configuration dictionary.
|
| 44 |
+
"""
|
| 45 |
+
for key, value in kwargs.items():
|
| 46 |
+
cfg_dict[key] = value
|
| 47 |
+
return cfg_dict
|
src/data/.gitkeep
ADDED
|
File without changes
|
src/data/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset and custom collate function to load"""
|
| 2 |
+
|
| 3 |
+
from .collate import custom_collate
|
| 4 |
+
from .datasets import TextImageDataset
|
| 5 |
+
from .tokenizer import TAIMGANTokenizer
|
src/data/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (372 Bytes). View file
|
|
|
src/data/__pycache__/collate.cpython-39.pyc
ADDED
|
Binary file (1.3 kB). View file
|
|
|
src/data/__pycache__/datasets.cpython-39.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
src/data/__pycache__/tokenizer.cpython-39.pyc
ADDED
|
Binary file (1.55 kB). View file
|
|
|
src/data/collate.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom collate function for the data loader."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def custom_collate(batch: List[Any], device: Any) -> Any:
|
| 10 |
+
"""
|
| 11 |
+
Custom collate function to be used in the data loader.
|
| 12 |
+
:param batch: list, with length equal to number of batches.
|
| 13 |
+
:return: processed batch of data [add padding to text, stack tensors in batch]
|
| 14 |
+
"""
|
| 15 |
+
img, correct_capt, curr_class, word_labels = zip(*batch)
|
| 16 |
+
batched_img = torch.stack(img, dim=0).to(
|
| 17 |
+
device
|
| 18 |
+
) # shape: (batch_size, 3, height, width)
|
| 19 |
+
correct_capt_len = torch.tensor(
|
| 20 |
+
[len(capt) for capt in correct_capt], dtype=torch.int64
|
| 21 |
+
).unsqueeze(
|
| 22 |
+
1
|
| 23 |
+
) # shape: (batch_size, 1)
|
| 24 |
+
batched_correct_capt = pad_sequence(
|
| 25 |
+
correct_capt, batch_first=True, padding_value=0
|
| 26 |
+
).to(
|
| 27 |
+
device
|
| 28 |
+
) # shape: (batch_size, max_seq_len)
|
| 29 |
+
batched_curr_class = torch.stack(curr_class, dim=0).to(
|
| 30 |
+
device
|
| 31 |
+
) # shape: (batch_size, 1)
|
| 32 |
+
batched_word_labels = pad_sequence(
|
| 33 |
+
word_labels, batch_first=True, padding_value=0
|
| 34 |
+
).to(
|
| 35 |
+
device
|
| 36 |
+
) # shape: (batch_size, max_seq_len)
|
| 37 |
+
return (
|
| 38 |
+
batched_img,
|
| 39 |
+
batched_correct_capt,
|
| 40 |
+
correct_capt_len,
|
| 41 |
+
batched_curr_class,
|
| 42 |
+
batched_word_labels,
|
| 43 |
+
)
|
src/data/datasets.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytorch Dataset classes for the datasets used in the project."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import nltk
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision.transforms.functional as F
|
| 13 |
+
from nltk.tokenize import RegexpTokenizer
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TextImageDataset(Dataset): # type: ignore
|
| 20 |
+
"""Custom PyTorch Dataset class to load Image and Text data."""
|
| 21 |
+
|
| 22 |
+
# pylint: disable=too-many-instance-attributes
|
| 23 |
+
# pylint: disable=too-many-locals
|
| 24 |
+
# pylint: disable=too-many-function-args
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self, data_path: str, split: str, num_captions: int, transform: Any = None
|
| 28 |
+
):
|
| 29 |
+
"""
|
| 30 |
+
:param data_path: Path to the data directory. [i.e. can be './birds/', or './coco/]
|
| 31 |
+
:param split: 'train' or 'test' split
|
| 32 |
+
:param num_captions: number of captions present per image.
|
| 33 |
+
[For birds, this is 10, for coco, this is 5]
|
| 34 |
+
:param transform: PyTorch transform to apply to the images.
|
| 35 |
+
"""
|
| 36 |
+
self.transform = transform
|
| 37 |
+
self.bound_box_map = None
|
| 38 |
+
self.file_names = self.load_filenames(data_path, split)
|
| 39 |
+
self.data_path = data_path
|
| 40 |
+
self.num_captions_per_image = num_captions
|
| 41 |
+
(
|
| 42 |
+
self.captions,
|
| 43 |
+
self.ix_to_word,
|
| 44 |
+
self.word_to_ix,
|
| 45 |
+
self.vocab_len,
|
| 46 |
+
) = self.get_capt_and_vocab(data_path, split)
|
| 47 |
+
self.normalize = transforms.Compose(
|
| 48 |
+
[
|
| 49 |
+
transforms.ToTensor(),
|
| 50 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
| 51 |
+
]
|
| 52 |
+
)
|
| 53 |
+
self.class_ids = self.get_class_id(data_path, split, len(self.file_names))
|
| 54 |
+
if self.data_path.endswith("birds/"):
|
| 55 |
+
self.bound_box_map = self.get_bound_box(data_path)
|
| 56 |
+
|
| 57 |
+
elif self.data_path.endswith("coco/"):
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
"Invalid data path. Please ensure the data [CUB/COCO] is stored in correct folders."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def __len__(self) -> int:
|
| 66 |
+
"""Return the length of the dataset."""
|
| 67 |
+
return len(self.file_names)
|
| 68 |
+
|
| 69 |
+
def __getitem__(self, idx: int) -> Any:
|
| 70 |
+
"""
|
| 71 |
+
Return the item at index idx.
|
| 72 |
+
:param idx: index of the item to return
|
| 73 |
+
:return img_tensor: image tensor
|
| 74 |
+
:return correct_caption: correct caption for the image [list of word indices]
|
| 75 |
+
:return curr_class_id: class id of the image
|
| 76 |
+
:return word_labels: POS_tagged word labels [1 for noun and adjective, 0 else]
|
| 77 |
+
|
| 78 |
+
"""
|
| 79 |
+
file_name = self.file_names[idx]
|
| 80 |
+
curr_class_id = self.class_ids[idx]
|
| 81 |
+
|
| 82 |
+
if self.bound_box_map is not None:
|
| 83 |
+
bbox = self.bound_box_map[file_name]
|
| 84 |
+
images_dir = os.path.join(self.data_path, "CUB_200_2011/images")
|
| 85 |
+
else:
|
| 86 |
+
bbox = None
|
| 87 |
+
images_dir = os.path.join(self.data_path, "images")
|
| 88 |
+
|
| 89 |
+
img_path = os.path.join(images_dir, file_name + ".jpg")
|
| 90 |
+
img_tensor = self.get_image(img_path, bbox, self.transform)
|
| 91 |
+
|
| 92 |
+
rand_sent_idx = np.random.randint(0, self.num_captions_per_image)
|
| 93 |
+
rand_sent_idx = idx * self.num_captions_per_image + rand_sent_idx
|
| 94 |
+
|
| 95 |
+
correct_caption = torch.tensor(self.captions[rand_sent_idx], dtype=torch.int64)
|
| 96 |
+
num_words = len(correct_caption)
|
| 97 |
+
|
| 98 |
+
capt_token_list = []
|
| 99 |
+
for i in range(num_words):
|
| 100 |
+
capt_token_list.append(self.ix_to_word[correct_caption[i].item()])
|
| 101 |
+
|
| 102 |
+
pos_tag_list = nltk.tag.pos_tag(capt_token_list)
|
| 103 |
+
word_labels = []
|
| 104 |
+
|
| 105 |
+
for pos_tag in pos_tag_list:
|
| 106 |
+
if (
|
| 107 |
+
"NN" in pos_tag[1] or "JJ" in pos_tag[1]
|
| 108 |
+
): # check for Nouns and Adjective only
|
| 109 |
+
word_labels.append(1)
|
| 110 |
+
else:
|
| 111 |
+
word_labels.append(0)
|
| 112 |
+
|
| 113 |
+
word_labels = torch.tensor(word_labels).float() # type: ignore
|
| 114 |
+
|
| 115 |
+
curr_class_id = torch.tensor(curr_class_id, dtype=torch.int64).unsqueeze(0)
|
| 116 |
+
|
| 117 |
+
return (
|
| 118 |
+
img_tensor,
|
| 119 |
+
correct_caption,
|
| 120 |
+
curr_class_id,
|
| 121 |
+
word_labels,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def get_capt_and_vocab(self, data_dir: str, split: str) -> Any:
|
| 125 |
+
"""
|
| 126 |
+
Helper function to get the captions, vocab dict for each image.
|
| 127 |
+
:param data_dir: path to the data directory [i.e. './birds/' or './coco/']
|
| 128 |
+
:param split: 'train' or 'test' split
|
| 129 |
+
:return captions: list of all captions for each image
|
| 130 |
+
:return ix_to_word: dictionary mapping index to word
|
| 131 |
+
:return word_to_ix: dictionary mapping word to index
|
| 132 |
+
:return num_words: number of unique words in the vocabulary
|
| 133 |
+
"""
|
| 134 |
+
captions_ckpt_path = os.path.join(data_dir, "stubs/captions.pickle")
|
| 135 |
+
if os.path.exists(
|
| 136 |
+
captions_ckpt_path
|
| 137 |
+
): # check if previously processed captions exist
|
| 138 |
+
with open(captions_ckpt_path, "rb") as ckpt_file:
|
| 139 |
+
captions = pickle.load(ckpt_file)
|
| 140 |
+
train_captions, test_captions = captions[0], captions[1]
|
| 141 |
+
ix_to_word, word_to_ix = captions[2], captions[3]
|
| 142 |
+
num_words = len(ix_to_word)
|
| 143 |
+
del captions
|
| 144 |
+
if split == "train":
|
| 145 |
+
return train_captions, ix_to_word, word_to_ix, num_words
|
| 146 |
+
return test_captions, ix_to_word, word_to_ix, num_words
|
| 147 |
+
|
| 148 |
+
else: # if not, process the captions and save them
|
| 149 |
+
train_files = self.load_filenames(data_dir, "train")
|
| 150 |
+
test_files = self.load_filenames(data_dir, "test")
|
| 151 |
+
|
| 152 |
+
train_captions_tokenized = self.get_tokenized_captions(
|
| 153 |
+
data_dir, train_files
|
| 154 |
+
)
|
| 155 |
+
test_captions_tokenized = self.get_tokenized_captions(
|
| 156 |
+
data_dir, test_files
|
| 157 |
+
) # we need both train and test captions to build the vocab
|
| 158 |
+
|
| 159 |
+
(
|
| 160 |
+
train_captions,
|
| 161 |
+
test_captions,
|
| 162 |
+
ix_to_word,
|
| 163 |
+
word_to_ix,
|
| 164 |
+
num_words,
|
| 165 |
+
) = self.build_vocab( # type: ignore
|
| 166 |
+
train_captions_tokenized, test_captions_tokenized, split
|
| 167 |
+
)
|
| 168 |
+
vocab_list = [train_captions, test_captions, ix_to_word, word_to_ix]
|
| 169 |
+
with open(captions_ckpt_path, "wb") as ckpt_file:
|
| 170 |
+
pickle.dump(vocab_list, ckpt_file)
|
| 171 |
+
|
| 172 |
+
if split == "train":
|
| 173 |
+
return train_captions, ix_to_word, word_to_ix, num_words
|
| 174 |
+
if split == "test":
|
| 175 |
+
return test_captions, ix_to_word, word_to_ix, num_words
|
| 176 |
+
raise ValueError("Invalid split. Please use 'train' or 'test'")
|
| 177 |
+
|
| 178 |
+
def build_vocab(
|
| 179 |
+
self, tokenized_captions_train: list, tokenized_captions_test: list # type: ignore
|
| 180 |
+
) -> Any:
|
| 181 |
+
"""
|
| 182 |
+
Helper function which builds the vocab dicts.
|
| 183 |
+
:param tokenized_captions_train: list containing all the
|
| 184 |
+
train tokenized captions in the dataset. This is list of lists.
|
| 185 |
+
:param tokenized_captions_test: list containing all the
|
| 186 |
+
test tokenized captions in the dataset. This is list of lists.
|
| 187 |
+
:return train_captions_int: list of all captions in training,
|
| 188 |
+
where each word is replaced by its index in the vocab
|
| 189 |
+
:return test_captions_int: list of all captions in test,
|
| 190 |
+
where each word is replaced by its index in the vocab
|
| 191 |
+
:return ix_to_word: dictionary mapping index to word
|
| 192 |
+
:return word_to_ix: dictionary mapping word to index
|
| 193 |
+
:return num_words: number of unique words in the vocabulary
|
| 194 |
+
"""
|
| 195 |
+
vocab = defaultdict(int) # type: ignore
|
| 196 |
+
total_captions = tokenized_captions_train + tokenized_captions_test
|
| 197 |
+
for caption in total_captions:
|
| 198 |
+
for word in caption:
|
| 199 |
+
vocab[word] += 1
|
| 200 |
+
|
| 201 |
+
# sort vocab dict by frequency in descending order
|
| 202 |
+
vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True) # type: ignore
|
| 203 |
+
|
| 204 |
+
ix_to_word = {}
|
| 205 |
+
word_to_ix = {}
|
| 206 |
+
ix_to_word[0] = "<end>"
|
| 207 |
+
word_to_ix["<end>"] = 0
|
| 208 |
+
|
| 209 |
+
word_idx = 1
|
| 210 |
+
for word, _ in vocab:
|
| 211 |
+
word_to_ix[word] = word_idx
|
| 212 |
+
ix_to_word[word_idx] = word
|
| 213 |
+
word_idx += 1
|
| 214 |
+
|
| 215 |
+
train_captions_int = [] # we want to convert words to indices in vocab.
|
| 216 |
+
for caption in tokenized_captions_train:
|
| 217 |
+
curr_caption_int = []
|
| 218 |
+
for word in caption:
|
| 219 |
+
curr_caption_int.append(word_to_ix[word])
|
| 220 |
+
|
| 221 |
+
train_captions_int.append(curr_caption_int)
|
| 222 |
+
|
| 223 |
+
test_captions_int = []
|
| 224 |
+
for caption in tokenized_captions_test:
|
| 225 |
+
curr_caption_int = []
|
| 226 |
+
for word in caption:
|
| 227 |
+
curr_caption_int.append(word_to_ix[word])
|
| 228 |
+
|
| 229 |
+
test_captions_int.append(curr_caption_int)
|
| 230 |
+
|
| 231 |
+
return (
|
| 232 |
+
train_captions_int,
|
| 233 |
+
test_captions_int,
|
| 234 |
+
ix_to_word,
|
| 235 |
+
word_to_ix,
|
| 236 |
+
len(ix_to_word),
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
def get_tokenized_captions(self, data_dir: str, filenames: list) -> Any: # type: ignore
|
| 240 |
+
"""
|
| 241 |
+
Helper function to tokenize and return captions for each image in filenames.
|
| 242 |
+
:param data_dir: path to the data directory [i.e. './birds/' or './coco/']
|
| 243 |
+
:param filenames: list of all filenames corresponding to the split
|
| 244 |
+
:return tokenized_captions: list of all tokenized captions for all files in filenames.
|
| 245 |
+
[this returns a list, where each element is again a list of tokens/words]
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
all_captions = []
|
| 249 |
+
for filename in filenames:
|
| 250 |
+
caption_path = os.path.join(data_dir, "text", filename + ".txt")
|
| 251 |
+
with open(caption_path, "r", encoding="utf8") as txt_file:
|
| 252 |
+
captions = txt_file.readlines()
|
| 253 |
+
count = 0
|
| 254 |
+
for caption in captions:
|
| 255 |
+
if len(caption) == 0:
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
caption = caption.replace("\ufffd\ufffd", " ")
|
| 259 |
+
tokenizer = RegexpTokenizer(r"\w+")
|
| 260 |
+
tokens = tokenizer.tokenize(
|
| 261 |
+
caption.lower()
|
| 262 |
+
) # splits current caption/line to list of words/tokens
|
| 263 |
+
if len(tokens) == 0:
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
tokens = [
|
| 267 |
+
t.encode("ascii", "ignore").decode("ascii") for t in tokens
|
| 268 |
+
]
|
| 269 |
+
tokens = [t for t in tokens if len(t) > 0]
|
| 270 |
+
|
| 271 |
+
all_captions.append(tokens)
|
| 272 |
+
count += 1
|
| 273 |
+
if count == self.num_captions_per_image:
|
| 274 |
+
break
|
| 275 |
+
if count < self.num_captions_per_image:
|
| 276 |
+
raise ValueError(
|
| 277 |
+
f"Number of captions for {filename} is only {count},\
|
| 278 |
+
which is less than {self.num_captions_per_image}."
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
return all_captions
|
| 282 |
+
|
| 283 |
+
def get_image(self, img_path: str, bbox: list, transform: Any) -> Any: # type: ignore
|
| 284 |
+
"""
|
| 285 |
+
Helper function to load and transform an image.
|
| 286 |
+
:param img_path: path to the image
|
| 287 |
+
:param bbox: bounding box coordinates [x, y, width, height]
|
| 288 |
+
:param transform: PyTorch transform to apply to the image
|
| 289 |
+
:return img_tensor: transformed image tensor
|
| 290 |
+
"""
|
| 291 |
+
img = Image.open(img_path).convert("RGB")
|
| 292 |
+
width, height = img.size
|
| 293 |
+
|
| 294 |
+
if bbox is not None:
|
| 295 |
+
r_val = int(np.maximum(bbox[2], bbox[3]) * 0.75)
|
| 296 |
+
|
| 297 |
+
center_x = int((2 * bbox[0] + bbox[2]) / 2)
|
| 298 |
+
center_y = int((2 * bbox[1] + bbox[3]) / 2)
|
| 299 |
+
y1_coord = np.maximum(0, center_y - r_val)
|
| 300 |
+
y2_coord = np.minimum(height, center_y + r_val)
|
| 301 |
+
x1_coord = np.maximum(0, center_x - r_val)
|
| 302 |
+
x2_coord = np.minimum(width, center_x + r_val)
|
| 303 |
+
|
| 304 |
+
img = img.crop(
|
| 305 |
+
[x1_coord, y1_coord, x2_coord, y2_coord]
|
| 306 |
+
) # This preprocessing steps seems to follow from
|
| 307 |
+
# Stackgan: Text to photo-realistic image synthesis
|
| 308 |
+
|
| 309 |
+
if transform is not None:
|
| 310 |
+
img_tensor = transform(img) # this scales to 304x304, i.e. 256 x (76/64).
|
| 311 |
+
x_val = np.random.randint(0, 48) # 304 - 256 = 48
|
| 312 |
+
y_val = np.random.randint(0, 48)
|
| 313 |
+
flip = np.random.rand() > 0.5
|
| 314 |
+
|
| 315 |
+
# crop
|
| 316 |
+
img_tensor = img_tensor.crop(
|
| 317 |
+
[x_val, y_val, x_val + 256, y_val + 256]
|
| 318 |
+
) # this crops to 256x256
|
| 319 |
+
if flip:
|
| 320 |
+
img_tensor = F.hflip(img_tensor)
|
| 321 |
+
|
| 322 |
+
img_tensor = self.normalize(img_tensor)
|
| 323 |
+
|
| 324 |
+
return img_tensor
|
| 325 |
+
|
| 326 |
+
def load_filenames(self, data_dir: str, split: str) -> Any:
|
| 327 |
+
"""
|
| 328 |
+
Helper function to get list of all image filenames.
|
| 329 |
+
:param data_dir: path to the data directory [i.e. './birds/' or './coco/']
|
| 330 |
+
:param split: 'train' or 'test' split
|
| 331 |
+
:return filenames: list of all image filenames
|
| 332 |
+
"""
|
| 333 |
+
filepath = f"{data_dir}{split}/filenames.pickle"
|
| 334 |
+
if os.path.isfile(filepath):
|
| 335 |
+
with open(filepath, "rb") as pick_file:
|
| 336 |
+
filenames = pickle.load(pick_file)
|
| 337 |
+
else:
|
| 338 |
+
raise ValueError(
|
| 339 |
+
"Invalid split. Please use 'train' or 'test',\
|
| 340 |
+
or make sure the filenames.pickle file exists."
|
| 341 |
+
)
|
| 342 |
+
return filenames
|
| 343 |
+
|
| 344 |
+
def get_class_id(self, data_dir: str, split: str, total_elems: int) -> Any:
|
| 345 |
+
"""
|
| 346 |
+
Helper function to get list of all image class ids.
|
| 347 |
+
:param data_dir: path to the data directory [i.e. './birds/' or './coco/']
|
| 348 |
+
:param split: 'train' or 'test' split
|
| 349 |
+
:param total_elems: total number of elements in the dataset
|
| 350 |
+
:return class_ids: list of all image class ids
|
| 351 |
+
"""
|
| 352 |
+
filepath = f"{data_dir}{split}/class_info.pickle"
|
| 353 |
+
if os.path.isfile(filepath):
|
| 354 |
+
with open(filepath, "rb") as class_file:
|
| 355 |
+
class_ids = pickle.load(class_file, encoding="latin1")
|
| 356 |
+
else:
|
| 357 |
+
class_ids = np.arange(total_elems)
|
| 358 |
+
return class_ids
|
| 359 |
+
|
| 360 |
+
def get_bound_box(self, data_path: str) -> Any:
|
| 361 |
+
"""
|
| 362 |
+
Helper function to get the bounding box for birds dataset.
|
| 363 |
+
:param data_path: path to birds data directory [i.e. './data/birds/']
|
| 364 |
+
:return imageToBox: dictionary mapping image name to bounding box coordinates
|
| 365 |
+
"""
|
| 366 |
+
bbox_path = os.path.join(data_path, "CUB_200_2011/bounding_boxes.txt")
|
| 367 |
+
df_bounding_boxes = pd.read_csv(
|
| 368 |
+
bbox_path, delim_whitespace=True, header=None
|
| 369 |
+
).astype(int)
|
| 370 |
+
|
| 371 |
+
filepath = os.path.join(data_path, "CUB_200_2011/images.txt")
|
| 372 |
+
df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None)
|
| 373 |
+
filenames = df_filenames[
|
| 374 |
+
1
|
| 375 |
+
].tolist() # df_filenames[0] just contains the index or ID.
|
| 376 |
+
|
| 377 |
+
img_to_box = { # type: ignore
|
| 378 |
+
img_file[:-4]: [] for img_file in filenames
|
| 379 |
+
} # remove the .jpg extension from the names
|
| 380 |
+
num_imgs = len(filenames)
|
| 381 |
+
|
| 382 |
+
for i in range(0, num_imgs):
|
| 383 |
+
bbox = df_bounding_boxes.iloc[i][1:].tolist()
|
| 384 |
+
key = filenames[i][:-4]
|
| 385 |
+
img_to_box[key] = bbox
|
| 386 |
+
|
| 387 |
+
return img_to_box
|
src/data/stubs/bird.jpg
ADDED
|
src/data/stubs/pigeon.jpg
ADDED
|
src/data/stubs/rohit.jpeg
ADDED
|
src/data/tokenizer.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import re
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TAIMGANTokenizer:
|
| 7 |
+
def __init__(self, captions_path):
|
| 8 |
+
with open(captions_path, "rb") as ckpt_file:
|
| 9 |
+
captions = pickle.load(ckpt_file)
|
| 10 |
+
self.ix_to_word = captions[2]
|
| 11 |
+
self.word_to_ix = captions[3]
|
| 12 |
+
self.token_regex = r'\w+'
|
| 13 |
+
self.pad_token_id = self.word_to_ix["<end>"]
|
| 14 |
+
self.pad_repr = "[PAD]"
|
| 15 |
+
|
| 16 |
+
def encode(self, text: str) -> List[int]:
|
| 17 |
+
return [self.word_to_ix.get(word, self.pad_token_id)
|
| 18 |
+
for word in re.findall(self.token_regex, text.lower())]
|
| 19 |
+
|
| 20 |
+
def decode(self, tokens: List[int]) -> str:
|
| 21 |
+
return ' '.join([self.ix_to_word[token]
|
| 22 |
+
if token != self.pad_token_id else self.pad_repr
|
| 23 |
+
for token in tokens])
|
src/features/.gitkeep
ADDED
|
File without changes
|
src/features/__init__.py
ADDED
|
File without changes
|
src/features/build_features.py
ADDED
|
File without changes
|
src/models/.gitkeep
ADDED
|
File without changes
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helper functions for training loop."""
|
| 2 |
+
from .losses import discriminator_loss, generator_loss, kl_loss
|
| 3 |
+
from .train_model import train
|
| 4 |
+
from .utils import copy_gen_params, define_optimizers, load_params, prepare_labels
|
src/models/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (461 Bytes). View file
|
|
|
src/models/__pycache__/losses.cpython-39.pyc
ADDED
|
Binary file (8.36 kB). View file
|
|
|
src/models/__pycache__/train_model.cpython-39.pyc
ADDED
|
Binary file (3.82 kB). View file
|
|
|
src/models/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (8.76 kB). View file
|
|
|
src/models/losses.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing the loss functions for the GANs."""
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
# pylint: disable=too-many-arguments
|
| 8 |
+
# pylint: disable=too-many-locals
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generator_loss(
|
| 12 |
+
logits: Dict[str, Dict[str, torch.Tensor]],
|
| 13 |
+
local_fake_incept_feat: torch.Tensor,
|
| 14 |
+
global_fake_incept_feat: torch.Tensor,
|
| 15 |
+
real_labels: torch.Tensor,
|
| 16 |
+
words_emb: torch.Tensor,
|
| 17 |
+
sent_emb: torch.Tensor,
|
| 18 |
+
match_labels: torch.Tensor,
|
| 19 |
+
cap_lens: torch.Tensor,
|
| 20 |
+
class_ids: torch.Tensor,
|
| 21 |
+
real_vgg_feat: torch.Tensor,
|
| 22 |
+
fake_vgg_feat: torch.Tensor,
|
| 23 |
+
const_dict: Dict[str, float],
|
| 24 |
+
) -> Any:
|
| 25 |
+
"""Calculate the loss for the generator.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
logits: Dictionary with fake/real and word-level/uncond/cond logits
|
| 29 |
+
|
| 30 |
+
local_fake_incept_feat: The local inception features for the fake images.
|
| 31 |
+
|
| 32 |
+
global_fake_incept_feat: The global inception features for the fake images.
|
| 33 |
+
|
| 34 |
+
real_labels: Label for "real" image as predicted by discriminator,
|
| 35 |
+
this is a tensor of ones. [shape: (batch_size, 1)].
|
| 36 |
+
|
| 37 |
+
word_labels: POS tagged word labels for the captions. [shape: (batch_size, L)]
|
| 38 |
+
|
| 39 |
+
words_emb: The embeddings for all the words in the captions.
|
| 40 |
+
shape: (batch_size, embedding_size, max_caption_length)
|
| 41 |
+
|
| 42 |
+
sent_emb: The embeddings for the sentences.
|
| 43 |
+
shape: (batch_size, embedding_size)
|
| 44 |
+
|
| 45 |
+
match_labels: Tensor of shape: (batch_size, 1).
|
| 46 |
+
This is of the form torch.tensor([0, 1, 2, ..., batch-1])
|
| 47 |
+
|
| 48 |
+
cap_lens: The length of the 'actual' captions in the batch [without padding]
|
| 49 |
+
shape: (batch_size, 1)
|
| 50 |
+
|
| 51 |
+
class_ids: The class ids for the instance. shape: (batch_size, 1)
|
| 52 |
+
|
| 53 |
+
real_vgg_feat: The vgg features for the real images. shape: (batch_size, 128, 128, 128)
|
| 54 |
+
fake_vgg_feat: The vgg features for the fake images. shape: (batch_size, 128, 128, 128)
|
| 55 |
+
|
| 56 |
+
const_dict: The dictionary containing the constants.
|
| 57 |
+
"""
|
| 58 |
+
lambda1 = const_dict["lambda1"]
|
| 59 |
+
total_error_g = 0.0
|
| 60 |
+
|
| 61 |
+
cond_logits = logits["fake"]["cond"]
|
| 62 |
+
cond_err_g = nn.BCEWithLogitsLoss()(cond_logits, real_labels)
|
| 63 |
+
|
| 64 |
+
uncond_logits = logits["fake"]["uncond"]
|
| 65 |
+
uncond_err_g = nn.BCEWithLogitsLoss()(uncond_logits, real_labels)
|
| 66 |
+
|
| 67 |
+
# add up the conditional and unconditional losses
|
| 68 |
+
loss_g = cond_err_g + uncond_err_g
|
| 69 |
+
total_error_g += loss_g
|
| 70 |
+
|
| 71 |
+
# DAMSM Loss from attnGAN.
|
| 72 |
+
loss_damsm = damsm_loss(
|
| 73 |
+
local_fake_incept_feat,
|
| 74 |
+
global_fake_incept_feat,
|
| 75 |
+
words_emb,
|
| 76 |
+
sent_emb,
|
| 77 |
+
match_labels,
|
| 78 |
+
cap_lens,
|
| 79 |
+
class_ids,
|
| 80 |
+
const_dict,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
total_error_g += loss_damsm
|
| 84 |
+
|
| 85 |
+
loss_per = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat) # perceptual loss
|
| 86 |
+
|
| 87 |
+
total_error_g += lambda1 * loss_per
|
| 88 |
+
|
| 89 |
+
return total_error_g
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def damsm_loss(
|
| 93 |
+
local_incept_feat: torch.Tensor,
|
| 94 |
+
global_incept_feat: torch.Tensor,
|
| 95 |
+
words_emb: torch.Tensor,
|
| 96 |
+
sent_emb: torch.Tensor,
|
| 97 |
+
match_labels: torch.Tensor,
|
| 98 |
+
cap_lens: torch.Tensor,
|
| 99 |
+
class_ids: torch.Tensor,
|
| 100 |
+
const_dict: Dict[str, float],
|
| 101 |
+
) -> Any:
|
| 102 |
+
"""Calculate the DAMSM loss from the attnGAN paper.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
|
| 106 |
+
|
| 107 |
+
global_incept_feat: The global inception features. [shape: (batch, D)]
|
| 108 |
+
|
| 109 |
+
words_emb: The embeddings for all the words in the captions.
|
| 110 |
+
|
| 111 |
+
shape: (batch, D, max_caption_length)
|
| 112 |
+
|
| 113 |
+
sent_emb: The embeddings for the sentences. shape: (batch_size, D)
|
| 114 |
+
|
| 115 |
+
match_labels: Tensor of shape: (batch_size, 1).
|
| 116 |
+
This is of the form torch.tensor([0, 1, 2, ..., batch-1])
|
| 117 |
+
|
| 118 |
+
cap_lens: The length of the 'actual' captions in the batch [without padding]
|
| 119 |
+
shape: (batch_size, 1)
|
| 120 |
+
|
| 121 |
+
class_ids: The class ids for the instance. shape: (batch, 1)
|
| 122 |
+
|
| 123 |
+
const_dict: The dictionary containing the constants.
|
| 124 |
+
"""
|
| 125 |
+
batch_size = match_labels.size(0)
|
| 126 |
+
# Mask mis-match samples, that come from the same class as the real sample
|
| 127 |
+
masks = []
|
| 128 |
+
|
| 129 |
+
match_scores = []
|
| 130 |
+
gamma1 = const_dict["gamma1"]
|
| 131 |
+
gamma2 = const_dict["gamma2"]
|
| 132 |
+
gamma3 = const_dict["gamma3"]
|
| 133 |
+
lambda3 = const_dict["lambda3"]
|
| 134 |
+
|
| 135 |
+
for i in range(batch_size):
|
| 136 |
+
mask = (class_ids == class_ids[i]).int()
|
| 137 |
+
# This ensures that "correct class" index is not included in the mask.
|
| 138 |
+
mask[i] = 0
|
| 139 |
+
masks.append(mask.reshape(1, -1)) # shape: (1, batch)
|
| 140 |
+
|
| 141 |
+
numb_words = int(cap_lens[i])
|
| 142 |
+
# shape: (1, D, L), this picks the caption at ith batch index.
|
| 143 |
+
query_words = words_emb[i, :, :numb_words].unsqueeze(0)
|
| 144 |
+
# shape: (batch, D, L), this expands the same caption for all batch indices.
|
| 145 |
+
query_words = query_words.repeat(batch_size, 1, 1)
|
| 146 |
+
|
| 147 |
+
c_i = compute_region_context_vector(
|
| 148 |
+
local_incept_feat, query_words, gamma1
|
| 149 |
+
) # Taken from attnGAN paper. shape: (batch, D, L)
|
| 150 |
+
|
| 151 |
+
query_words = query_words.transpose(1, 2) # shape: (batch, L, D)
|
| 152 |
+
c_i = c_i.transpose(1, 2) # shape: (batch, L, D)
|
| 153 |
+
query_words = query_words.reshape(
|
| 154 |
+
batch_size * numb_words, -1
|
| 155 |
+
) # shape: (batch * L, D)
|
| 156 |
+
c_i = c_i.reshape(batch_size * numb_words, -1) # shape: (batch * L, D)
|
| 157 |
+
|
| 158 |
+
r_i = compute_relevance(
|
| 159 |
+
c_i, query_words
|
| 160 |
+
) # cosine similarity, or R(c_i, e_i) from attnGAN paper. shape: (batch * L, 1)
|
| 161 |
+
r_i = r_i.view(batch_size, numb_words) # shape: (batch, L)
|
| 162 |
+
r_i = torch.exp(r_i * gamma2) # shape: (batch, L)
|
| 163 |
+
r_i = r_i.sum(dim=1, keepdim=True) # shape: (batch, 1)
|
| 164 |
+
r_i = torch.log(
|
| 165 |
+
r_i
|
| 166 |
+
) # This is image-text matching score b/w whole image and caption, shape: (batch, 1)
|
| 167 |
+
match_scores.append(r_i)
|
| 168 |
+
|
| 169 |
+
masks = torch.cat(masks, dim=0).bool() # type: ignore
|
| 170 |
+
match_scores = torch.cat(match_scores, dim=1) # type: ignore
|
| 171 |
+
|
| 172 |
+
# This corresponds to P(D|Q) from attnGAN.
|
| 173 |
+
match_scores = gamma3 * match_scores # type: ignore
|
| 174 |
+
match_scores.data.masked_fill_( # type: ignore
|
| 175 |
+
masks, -float("inf")
|
| 176 |
+
) # mask out the scores for mis-matched samples
|
| 177 |
+
|
| 178 |
+
match_scores_t = match_scores.transpose( # type: ignore
|
| 179 |
+
0, 1
|
| 180 |
+
) # This corresponds to P(Q|D) from attnGAN.
|
| 181 |
+
|
| 182 |
+
# This corresponds to L1_w from attnGAN.
|
| 183 |
+
l1_w = nn.CrossEntropyLoss()(match_scores, match_labels)
|
| 184 |
+
# This corresponds to L2_w from attnGAN.
|
| 185 |
+
l2_w = nn.CrossEntropyLoss()(match_scores_t, match_labels)
|
| 186 |
+
|
| 187 |
+
incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1)
|
| 188 |
+
sent_emb_norm = torch.linalg.norm(sent_emb, dim=1)
|
| 189 |
+
|
| 190 |
+
# shape: (batch, batch)
|
| 191 |
+
global_match_score = global_incept_feat @ (sent_emb.T)
|
| 192 |
+
|
| 193 |
+
global_match_score = (
|
| 194 |
+
global_match_score / torch.outer(incept_feat_norm, sent_emb_norm)
|
| 195 |
+
).clamp(min=1e-8)
|
| 196 |
+
global_match_score = gamma3 * global_match_score
|
| 197 |
+
|
| 198 |
+
# mask out the scores for mis-matched samples
|
| 199 |
+
global_match_score.data.masked_fill_(masks, -float("inf")) # type: ignore
|
| 200 |
+
|
| 201 |
+
global_match_t = global_match_score.T # shape: (batch, batch)
|
| 202 |
+
|
| 203 |
+
# This corresponds to L1_s from attnGAN.
|
| 204 |
+
l1_s = nn.CrossEntropyLoss()(global_match_score, match_labels)
|
| 205 |
+
# This corresponds to L2_s from attnGAN.
|
| 206 |
+
l2_s = nn.CrossEntropyLoss()(global_match_t, match_labels)
|
| 207 |
+
|
| 208 |
+
loss_damsm = lambda3 * (l1_w + l2_w + l1_s + l2_s)
|
| 209 |
+
|
| 210 |
+
return loss_damsm
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any:
|
| 214 |
+
"""Computes the cosine similarity between the region context vector and the query words.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
c_i: The region context vector. shape: (batch * L, D)
|
| 218 |
+
query_words: The query words. shape: (batch * L, D)
|
| 219 |
+
"""
|
| 220 |
+
prod = c_i * query_words # shape: (batch * L, D)
|
| 221 |
+
numr = torch.sum(prod, dim=1) # shape: (batch * L, 1)
|
| 222 |
+
norm_c = torch.linalg.norm(c_i, ord=2, dim=1)
|
| 223 |
+
norm_q = torch.linalg.norm(query_words, ord=2, dim=1)
|
| 224 |
+
denr = norm_c * norm_q
|
| 225 |
+
r_i = (numr / denr).clamp(min=1e-8).squeeze() # shape: (batch * L, 1)
|
| 226 |
+
return r_i
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def compute_region_context_vector(
|
| 230 |
+
local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float
|
| 231 |
+
) -> Any:
|
| 232 |
+
"""Compute the region context vector (c_i) from attnGAN paper.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
|
| 236 |
+
query_words: The embeddings for all the words in the captions. shape: (batch, D, L)
|
| 237 |
+
gamma1: The gamma1 value from attnGAN paper.
|
| 238 |
+
"""
|
| 239 |
+
batch, L = query_words.size(0), query_words.size(2) # pylint: disable=invalid-name
|
| 240 |
+
|
| 241 |
+
feat_height, feat_width = local_incept_feat.size(2), local_incept_feat.size(3)
|
| 242 |
+
N = feat_height * feat_width # pylint: disable=invalid-name
|
| 243 |
+
|
| 244 |
+
# Reshape the local inception features to (batch, D, N)
|
| 245 |
+
local_incept_feat = local_incept_feat.view(batch, -1, N)
|
| 246 |
+
# shape: (batch, N, D)
|
| 247 |
+
incept_feat_t = local_incept_feat.transpose(1, 2)
|
| 248 |
+
|
| 249 |
+
sim_matrix = incept_feat_t @ query_words # shape: (batch, N, L)
|
| 250 |
+
sim_matrix = sim_matrix.view(batch * N, L) # shape: (batch * N, L)
|
| 251 |
+
|
| 252 |
+
sim_matrix = nn.Softmax(dim=1)(sim_matrix) # shape: (batch * N, L)
|
| 253 |
+
sim_matrix = sim_matrix.view(batch, N, L) # shape: (batch, N, L)
|
| 254 |
+
|
| 255 |
+
sim_matrix = torch.transpose(sim_matrix, 1, 2) # shape: (batch, L, N)
|
| 256 |
+
sim_matrix = sim_matrix.reshape(batch * L, N) # shape: (batch * L, N)
|
| 257 |
+
|
| 258 |
+
alpha_j = gamma1 * sim_matrix # shape: (batch * L, N)
|
| 259 |
+
alpha_j = nn.Softmax(dim=1)(alpha_j) # shape: (batch * L, N)
|
| 260 |
+
alpha_j = alpha_j.view(batch, L, N) # shape: (batch, L, N)
|
| 261 |
+
alpha_j_t = torch.transpose(alpha_j, 1, 2) # shape: (batch, N, L)
|
| 262 |
+
|
| 263 |
+
c_i = (
|
| 264 |
+
local_incept_feat @ alpha_j_t
|
| 265 |
+
) # shape: (batch, D, L) [summing over N dimension in paper, so we multiply like this]
|
| 266 |
+
return c_i
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def discriminator_loss(
|
| 270 |
+
logits: Dict[str, Dict[str, torch.Tensor]],
|
| 271 |
+
labels: Dict[str, Dict[str, torch.Tensor]],
|
| 272 |
+
) -> Any:
|
| 273 |
+
"""
|
| 274 |
+
Calculate discriminator objective
|
| 275 |
+
|
| 276 |
+
:param dict[str, dict[str, torch.Tensor]] logits:
|
| 277 |
+
Dictionary with fake/real and word-level/uncond/cond logits
|
| 278 |
+
|
| 279 |
+
Example:
|
| 280 |
+
|
| 281 |
+
logits = {
|
| 282 |
+
"fake": {
|
| 283 |
+
"word_level": torch.Tensor (BxL)
|
| 284 |
+
"uncond": torch.Tensor (Bx1)
|
| 285 |
+
"cond": torch.Tensor (Bx1)
|
| 286 |
+
},
|
| 287 |
+
"real": {
|
| 288 |
+
"word_level": torch.Tensor (BxL)
|
| 289 |
+
"uncond": torch.Tensor (Bx1)
|
| 290 |
+
"cond": torch.Tensor (Bx1)
|
| 291 |
+
},
|
| 292 |
+
}
|
| 293 |
+
:param dict[str, dict[str, torch.Tensor]] labels:
|
| 294 |
+
Dictionary with fake/real and word-level/image labels
|
| 295 |
+
|
| 296 |
+
Example:
|
| 297 |
+
|
| 298 |
+
labels = {
|
| 299 |
+
"fake": {
|
| 300 |
+
"word_level": torch.Tensor (BxL)
|
| 301 |
+
"image": torch.Tensor (Bx1)
|
| 302 |
+
},
|
| 303 |
+
"real": {
|
| 304 |
+
"word_level": torch.Tensor (BxL)
|
| 305 |
+
"image": torch.Tensor (Bx1)
|
| 306 |
+
},
|
| 307 |
+
}
|
| 308 |
+
:param float lambda_4: Hyperparameter for word loss in paper
|
| 309 |
+
:return: Discriminator objective loss
|
| 310 |
+
:rtype: Any
|
| 311 |
+
"""
|
| 312 |
+
# define main loss functions for logit losses
|
| 313 |
+
tot_loss = 0.0
|
| 314 |
+
bce_logits = nn.BCEWithLogitsLoss()
|
| 315 |
+
bce = nn.BCELoss()
|
| 316 |
+
# calculate word-level loss
|
| 317 |
+
word_loss = bce(logits["real"]["word_level"], labels["real"]["word_level"])
|
| 318 |
+
# calculate unconditional adversarial loss
|
| 319 |
+
uncond_loss = bce_logits(logits["real"]["uncond"], labels["real"]["image"])
|
| 320 |
+
|
| 321 |
+
# calculate conditional adversarial loss
|
| 322 |
+
cond_loss = bce_logits(logits["real"]["cond"], labels["real"]["image"])
|
| 323 |
+
|
| 324 |
+
tot_loss = (uncond_loss + cond_loss) / 2.0
|
| 325 |
+
|
| 326 |
+
fake_uncond_loss = bce_logits(logits["fake"]["uncond"], labels["fake"]["image"])
|
| 327 |
+
fake_cond_loss = bce_logits(logits["fake"]["cond"], labels["fake"]["image"])
|
| 328 |
+
|
| 329 |
+
tot_loss += (fake_uncond_loss + fake_cond_loss) / 3.0
|
| 330 |
+
tot_loss += word_loss
|
| 331 |
+
|
| 332 |
+
return tot_loss
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any:
|
| 336 |
+
"""
|
| 337 |
+
Calculate KL loss
|
| 338 |
+
|
| 339 |
+
:param torch.Tensor mu_tensor: Mean of latent distribution
|
| 340 |
+
:param torch.Tensor logvar: Log variance of latent distribution
|
| 341 |
+
:return: KL loss [-0.5 * (1 + log(sigma) - mu^2 - sigma^2)]
|
| 342 |
+
:rtype: Any
|
| 343 |
+
"""
|
| 344 |
+
return torch.mean(-0.5 * (1 + 0.5 * logvar - mu_tensor.pow(2) - torch.exp(logvar)))
|
src/models/modules/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""All the modules used in creation of Generator and Discriminator"""
|
| 2 |
+
from .acm import ACM
|
| 3 |
+
from .attention import ChannelWiseAttention, SpatialAttention
|
| 4 |
+
from .cond_augment import CondAugmentation
|
| 5 |
+
from .conv_utils import calc_out_conv, conv1d, conv2d
|
| 6 |
+
from .discriminator import Discriminator, WordLevelLogits
|
| 7 |
+
from .downsample import down_sample
|
| 8 |
+
from .generator import Generator
|
| 9 |
+
from .image_encoder import InceptionEncoder, VGGEncoder
|
| 10 |
+
from .residual import ResidualBlock
|
| 11 |
+
from .text_encoder import TextEncoder
|
| 12 |
+
from .upsample import img_up_block, up_sample
|
src/models/modules/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (891 Bytes). View file
|
|
|
src/models/modules/__pycache__/acm.cpython-39.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
src/models/modules/__pycache__/attention.cpython-39.pyc
ADDED
|
Binary file (3.38 kB). View file
|
|
|
src/models/modules/__pycache__/cond_augment.cpython-39.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
src/models/modules/__pycache__/conv_utils.cpython-39.pyc
ADDED
|
Binary file (2.37 kB). View file
|
|
|
src/models/modules/__pycache__/discriminator.cpython-39.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
src/models/modules/__pycache__/downsample.cpython-39.pyc
ADDED
|
Binary file (598 Bytes). View file
|
|
|
src/models/modules/__pycache__/generator.cpython-39.pyc
ADDED
|
Binary file (9.03 kB). View file
|
|
|
src/models/modules/__pycache__/image_encoder.cpython-39.pyc
ADDED
|
Binary file (4.27 kB). View file
|
|
|
src/models/modules/__pycache__/residual.cpython-39.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
src/models/modules/__pycache__/text_encoder.cpython-39.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
src/models/modules/__pycache__/upsample.cpython-39.pyc
ADDED
|
Binary file (983 Bytes). View file
|
|
|
src/models/modules/acm.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ACM and its variations"""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from .conv_utils import conv2d
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ACM(nn.Module):
|
| 12 |
+
"""Affine Combination Module from ManiGAN"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, img_chans: int, text_chans: int, inner_dim: int = 64) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Initialize the convolutional layers
|
| 17 |
+
|
| 18 |
+
:param int img_chans: Channels in visual input
|
| 19 |
+
:param int text_chans: Channels of textual input
|
| 20 |
+
:param int inner_dim: Hyperparameters for inner dimensionality of features
|
| 21 |
+
"""
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.conv = conv2d(in_channels=img_chans, out_channels=inner_dim)
|
| 24 |
+
self.weights = conv2d(in_channels=inner_dim, out_channels=text_chans)
|
| 25 |
+
self.biases = conv2d(in_channels=inner_dim, out_channels=text_chans)
|
| 26 |
+
|
| 27 |
+
def forward(self, text: torch.Tensor, img: torch.Tensor) -> Any:
|
| 28 |
+
"""
|
| 29 |
+
Propagate the textual and visual input through the ACM module
|
| 30 |
+
|
| 31 |
+
:param torch.Tensor text: Textual input (can be hidden features)
|
| 32 |
+
:param torch.Tensor img: Image input
|
| 33 |
+
:return: Affine combination of text and image
|
| 34 |
+
:rtype: torch.Tensor
|
| 35 |
+
"""
|
| 36 |
+
img_features = self.conv(img)
|
| 37 |
+
return text * self.weights(img_features) + self.biases(img_features)
|
src/models/modules/attention.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Attention modules"""
|
| 2 |
+
from typing import Any, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from src.models.modules.conv_utils import conv1d
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ChannelWiseAttention(nn.Module):
|
| 11 |
+
"""ChannelWise attention adapted from ControlGAN"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, fm_size: int, text_d: int) -> None:
|
| 14 |
+
"""
|
| 15 |
+
Initialize the Channel-Wise attention module
|
| 16 |
+
|
| 17 |
+
:param int fm_size:
|
| 18 |
+
Height and width of feature map on k-th iteration of forward-pass.
|
| 19 |
+
In paper, it's H_k * W_k
|
| 20 |
+
:param int text_d: Dimensionality of sentence. From paper, it's D
|
| 21 |
+
"""
|
| 22 |
+
super().__init__()
|
| 23 |
+
# perception layer
|
| 24 |
+
self.text_conv = conv1d(text_d, fm_size)
|
| 25 |
+
# attention across channel dimension
|
| 26 |
+
self.softmax = nn.Softmax(2)
|
| 27 |
+
|
| 28 |
+
def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> Any:
|
| 29 |
+
"""
|
| 30 |
+
Apply attention to visual features taking into account features of words
|
| 31 |
+
|
| 32 |
+
:param torch.Tensor v_k: Visual context
|
| 33 |
+
:param torch.Tensor w_text: Textual features
|
| 34 |
+
:return: Fused hidden visual features and word features
|
| 35 |
+
:rtype: Any
|
| 36 |
+
"""
|
| 37 |
+
w_hat = self.text_conv(w_text)
|
| 38 |
+
m_k = v_k @ w_hat
|
| 39 |
+
a_k = self.softmax(m_k)
|
| 40 |
+
w_hat = torch.transpose(w_hat, 1, 2)
|
| 41 |
+
return a_k @ w_hat
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SpatialAttention(nn.Module):
|
| 45 |
+
"""Spatial attention module for attending textual context to visual features"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, d: int, d_hat: int) -> None:
|
| 48 |
+
"""
|
| 49 |
+
Set up softmax and conv layers
|
| 50 |
+
|
| 51 |
+
:param int d: Initial embedding size for textual features. D from paper
|
| 52 |
+
:param int d_hat: Height of image feature map. D_hat from paper
|
| 53 |
+
"""
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.softmax = nn.Softmax(2)
|
| 56 |
+
self.conv = conv1d(d, d_hat)
|
| 57 |
+
|
| 58 |
+
def forward(
|
| 59 |
+
self,
|
| 60 |
+
text_context: torch.Tensor,
|
| 61 |
+
image: torch.Tensor,
|
| 62 |
+
mask: Optional[torch.Tensor] = None,
|
| 63 |
+
) -> Any:
|
| 64 |
+
"""
|
| 65 |
+
Project image features into the latent space
|
| 66 |
+
of textual features and apply attention
|
| 67 |
+
|
| 68 |
+
:param torch.Tensor text_context: D x T tensor of hidden textual features
|
| 69 |
+
:param torch.Tensor image: D_hat x N visual features
|
| 70 |
+
:param Optional[torch.Tensor] mask:
|
| 71 |
+
Boolean tensor for masking the padded words. BxL
|
| 72 |
+
:return: Word features attended by visual features
|
| 73 |
+
:rtype: Any
|
| 74 |
+
"""
|
| 75 |
+
# number of features on image feature map H * W
|
| 76 |
+
feature_num = image.size(2)
|
| 77 |
+
# number of words in caption
|
| 78 |
+
len_caption = text_context.size(2)
|
| 79 |
+
text_context = self.conv(text_context)
|
| 80 |
+
image = torch.transpose(image, 1, 2)
|
| 81 |
+
s_i_j = image @ text_context
|
| 82 |
+
if mask is not None:
|
| 83 |
+
# duplicating mask and aligning dims with s_i_j
|
| 84 |
+
mask = mask.repeat(1, feature_num).view(-1, feature_num, len_caption)
|
| 85 |
+
s_i_j[mask] = -float("inf")
|
| 86 |
+
b_i_j = self.softmax(s_i_j)
|
| 87 |
+
c_i_j = b_i_j @ torch.transpose(text_context, 1, 2)
|
| 88 |
+
return torch.transpose(c_i_j, 1, 2)
|
src/models/modules/cond_augment.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conditioning Augmentation Module"""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CondAugmentation(nn.Module):
|
| 10 |
+
"""Conditioning Augmentation Module"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, D: int, conditioning_dim: int):
|
| 13 |
+
"""
|
| 14 |
+
:param D: Dimension of the text embedding space [D from AttnGAN paper]
|
| 15 |
+
:param conditioning_dim: Dimension of the conditioning space
|
| 16 |
+
"""
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.cond_dim = conditioning_dim
|
| 19 |
+
self.cond_augment = nn.Linear(D, conditioning_dim * 4, bias=True)
|
| 20 |
+
self.glu = nn.GLU(dim=1)
|
| 21 |
+
|
| 22 |
+
def encode(self, text_embedding: torch.Tensor) -> Any:
|
| 23 |
+
"""
|
| 24 |
+
This function encodes the text embedding into the conditioning space
|
| 25 |
+
:param text_embedding: Text embedding
|
| 26 |
+
:return: Conditioning embedding
|
| 27 |
+
"""
|
| 28 |
+
x_tensor = self.glu(self.cond_augment(text_embedding))
|
| 29 |
+
mu_tensor = x_tensor[:, : self.cond_dim]
|
| 30 |
+
logvar = x_tensor[:, self.cond_dim :]
|
| 31 |
+
return mu_tensor, logvar
|
| 32 |
+
|
| 33 |
+
def sample(self, mu_tensor: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
|
| 34 |
+
"""
|
| 35 |
+
This function samples from the Gaussian distribution
|
| 36 |
+
:param mu: Mean of the Gaussian distribution
|
| 37 |
+
:param logvar: Log variance of the Gaussian distribution
|
| 38 |
+
:return: Sample from the Gaussian distribution
|
| 39 |
+
"""
|
| 40 |
+
std = torch.exp(0.5 * logvar)
|
| 41 |
+
eps = torch.randn_like(
|
| 42 |
+
std
|
| 43 |
+
) # check if this should add requires_grad = True to this tensor?
|
| 44 |
+
return mu_tensor + eps * std
|
| 45 |
+
|
| 46 |
+
def forward(self, text_embedding: torch.Tensor) -> Any:
|
| 47 |
+
"""
|
| 48 |
+
This function encodes the text embedding into the conditioning space,
|
| 49 |
+
and samples from the Gaussian distribution.
|
| 50 |
+
:param text_embedding: Text embedding
|
| 51 |
+
:return c_hat: Conditioning embedding (C^ from StackGAN++ paper)
|
| 52 |
+
:return mu: Mean of the Gaussian distribution
|
| 53 |
+
:return logvar: Log variance of the Gaussian distribution
|
| 54 |
+
"""
|
| 55 |
+
mu_tensor, logvar = self.encode(text_embedding)
|
| 56 |
+
c_hat = self.sample(mu_tensor, logvar)
|
| 57 |
+
return c_hat, mu_tensor, logvar
|
src/models/modules/conv_utils.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Frequently used convolution modules"""
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def conv2d(
|
| 9 |
+
in_channels: int,
|
| 10 |
+
out_channels: int,
|
| 11 |
+
kernel_size: int = 3,
|
| 12 |
+
stride: int = 1,
|
| 13 |
+
padding: int = 1,
|
| 14 |
+
) -> nn.Conv2d:
|
| 15 |
+
"""
|
| 16 |
+
Template convolution which is typically used throughout the project
|
| 17 |
+
|
| 18 |
+
:param int in_channels: Number of input channels
|
| 19 |
+
:param int out_channels: Number of output channels
|
| 20 |
+
:param int kernel_size: Size of sliding kernel
|
| 21 |
+
:param int stride: How many steps kernel does when sliding
|
| 22 |
+
:param int padding: How many dimensions to pad
|
| 23 |
+
:return: Convolution layer with parameters
|
| 24 |
+
:rtype: nn.Conv2d
|
| 25 |
+
"""
|
| 26 |
+
return nn.Conv2d(
|
| 27 |
+
in_channels=in_channels,
|
| 28 |
+
out_channels=out_channels,
|
| 29 |
+
kernel_size=kernel_size,
|
| 30 |
+
stride=stride,
|
| 31 |
+
padding=padding,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def conv1d(
|
| 36 |
+
in_channels: int,
|
| 37 |
+
out_channels: int,
|
| 38 |
+
kernel_size: int = 1,
|
| 39 |
+
stride: int = 1,
|
| 40 |
+
padding: int = 0,
|
| 41 |
+
) -> nn.Conv1d:
|
| 42 |
+
"""
|
| 43 |
+
Template 1d convolution which is typically used throughout the project
|
| 44 |
+
|
| 45 |
+
:param int in_channels: Number of input channels
|
| 46 |
+
:param int out_channels: Number of output channels
|
| 47 |
+
:param int kernel_size: Size of sliding kernel
|
| 48 |
+
:param int stride: How many steps kernel does when sliding
|
| 49 |
+
:param int padding: How many dimensions to pad
|
| 50 |
+
:return: Convolution layer with parameters
|
| 51 |
+
:rtype: nn.Conv2d
|
| 52 |
+
"""
|
| 53 |
+
return nn.Conv1d(
|
| 54 |
+
in_channels=in_channels,
|
| 55 |
+
out_channels=out_channels,
|
| 56 |
+
kernel_size=kernel_size,
|
| 57 |
+
stride=stride,
|
| 58 |
+
padding=padding,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def calc_out_conv(
|
| 63 |
+
h_in: int, w_in: int, kernel_size: int = 3, stride: int = 1, padding: int = 0
|
| 64 |
+
) -> Tuple[int, int]:
|
| 65 |
+
"""
|
| 66 |
+
Calculate the dimensionalities of images propagated through conv layers
|
| 67 |
+
|
| 68 |
+
:param h_in: Height of the image
|
| 69 |
+
:param w_in: Width of the image
|
| 70 |
+
:param kernel_size: Size of sliding kernel
|
| 71 |
+
:param stride: How many steps kernel does when sliding
|
| 72 |
+
:param padding: How many dimensions to pad
|
| 73 |
+
:return: Height and width of image through convolution
|
| 74 |
+
:rtype: tuple[int, int]
|
| 75 |
+
"""
|
| 76 |
+
h_out = int((h_in + 2 * padding - kernel_size) / stride + 1)
|
| 77 |
+
w_out = int((w_in + 2 * padding - kernel_size) / stride + 1)
|
| 78 |
+
return h_out, w_out
|
src/models/modules/discriminator.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Discriminator providing word-level feedback"""
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from src.models.modules.conv_utils import conv1d, conv2d
|
| 8 |
+
from src.models.modules.image_encoder import InceptionEncoder
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WordLevelLogits(nn.Module):
|
| 12 |
+
"""API for converting regional feature maps into logits for multi-class classification"""
|
| 13 |
+
|
| 14 |
+
def __init__(self) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Instantiate the module with softmax on channel dimension
|
| 17 |
+
"""
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.softmax = nn.Softmax(dim=1)
|
| 20 |
+
# layer for flattening the feature maps
|
| 21 |
+
self.flat = nn.Flatten(start_dim=2)
|
| 22 |
+
# change dism of of textual embs to correlate with chans of inception
|
| 23 |
+
self.chan_reduction = conv1d(256, 128)
|
| 24 |
+
|
| 25 |
+
def forward(
|
| 26 |
+
self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor
|
| 27 |
+
) -> Any:
|
| 28 |
+
"""
|
| 29 |
+
Fuse two types of features together to get output for feeding into the classification loss
|
| 30 |
+
:param torch.Tensor visual_features:
|
| 31 |
+
Feature maps of an image after being processed by Inception encoder. Bx128x17x17
|
| 32 |
+
:param torch.Tensor word_embs:
|
| 33 |
+
Word-level embeddings from the text encoder Bx256xL
|
| 34 |
+
:return: Logits for each word in the picture. BxL
|
| 35 |
+
:rtype: Any
|
| 36 |
+
"""
|
| 37 |
+
# make textual and visual features have the same amount of channels
|
| 38 |
+
word_embs = self.chan_reduction(word_embs)
|
| 39 |
+
# flattening the feature maps
|
| 40 |
+
visual_features = self.flat(visual_features)
|
| 41 |
+
word_embs = torch.transpose(word_embs, 1, 2)
|
| 42 |
+
word_region_correlations = word_embs @ visual_features
|
| 43 |
+
# normalize across L dimension
|
| 44 |
+
m_norm_l = nn.functional.normalize(word_region_correlations, dim=1)
|
| 45 |
+
# normalize across H*W dimension
|
| 46 |
+
m_norm_hw = nn.functional.normalize(m_norm_l, dim=2)
|
| 47 |
+
m_norm_hw = torch.transpose(m_norm_hw, 1, 2)
|
| 48 |
+
weighted_img_feats = visual_features @ m_norm_hw
|
| 49 |
+
weighted_img_feats = torch.sum(weighted_img_feats, dim=1)
|
| 50 |
+
weighted_img_feats[mask] = -float("inf")
|
| 51 |
+
deltas = self.softmax(weighted_img_feats)
|
| 52 |
+
return deltas
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class UnconditionalLogits(nn.Module):
|
| 56 |
+
"""Head for retrieving logits from an image"""
|
| 57 |
+
|
| 58 |
+
def __init__(self) -> None:
|
| 59 |
+
"""Initialize modules that reduce the features down to a set of logits"""
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.conv = nn.Conv2d(128, 1, kernel_size=17)
|
| 62 |
+
# flattening BxLx1x1 into Bx1
|
| 63 |
+
self.flat = nn.Flatten()
|
| 64 |
+
|
| 65 |
+
def forward(self, visual_features: torch.Tensor) -> Any:
|
| 66 |
+
"""
|
| 67 |
+
Compute logits for unconditioned adversarial loss
|
| 68 |
+
|
| 69 |
+
:param visual_features: Local features from Inception network. Bx128x17x17
|
| 70 |
+
:return: Logits for unconditioned adversarial loss. Bx1
|
| 71 |
+
:rtype: Any
|
| 72 |
+
"""
|
| 73 |
+
# reduce channels and feature maps for visual features
|
| 74 |
+
visual_features = self.conv(visual_features)
|
| 75 |
+
# flatten Bx1x1x1 into Bx1
|
| 76 |
+
logits = self.flat(visual_features)
|
| 77 |
+
return logits
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ConditionalLogits(nn.Module):
|
| 81 |
+
"""Logits extractor for conditioned adversarial loss"""
|
| 82 |
+
|
| 83 |
+
def __init__(self) -> None:
|
| 84 |
+
super().__init__()
|
| 85 |
+
# layer for forming the feature maps out of textual info
|
| 86 |
+
self.text_to_fm = conv1d(256, 17 * 17)
|
| 87 |
+
# fitting the size of text channels to the size of visual channels
|
| 88 |
+
self.chan_aligner = conv2d(1, 128)
|
| 89 |
+
# for reduced textual + visual features down to 1x1 feature map
|
| 90 |
+
self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17)
|
| 91 |
+
# converting Bx1x1x1 into Bx1
|
| 92 |
+
self.flat = nn.Flatten()
|
| 93 |
+
|
| 94 |
+
def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any:
|
| 95 |
+
"""
|
| 96 |
+
Compute logits for conditional adversarial loss
|
| 97 |
+
|
| 98 |
+
:param torch.Tensor visual_features: Features from Inception encoder. Bx128x17x17
|
| 99 |
+
:param torch.Tensor sent_embs: Sentence embeddings from text encoder. Bx256
|
| 100 |
+
:return: Logits for conditional adversarial loss. BxL
|
| 101 |
+
:rtype: Any
|
| 102 |
+
"""
|
| 103 |
+
# make text and visual features have the same sizes of feature maps
|
| 104 |
+
# Bx256 -> Bx256x1 -> Bx289x1
|
| 105 |
+
sent_embs = sent_embs.view(-1, 256, 1)
|
| 106 |
+
sent_embs = self.text_to_fm(sent_embs)
|
| 107 |
+
# transform textual info into shape of visual feature maps
|
| 108 |
+
# Bx289x1 -> Bx1x17x17
|
| 109 |
+
sent_embs = sent_embs.view(-1, 1, 17, 17)
|
| 110 |
+
# propagate text embs through 1d conv to
|
| 111 |
+
# align dims with visual feature maps
|
| 112 |
+
sent_embs = self.chan_aligner(sent_embs)
|
| 113 |
+
# unite textual and visual features across the dim of channels
|
| 114 |
+
cross_features = torch.cat((visual_features, sent_embs), dim=1)
|
| 115 |
+
# reduce dims down to length of caption and form raw logits
|
| 116 |
+
cross_features = self.joint_conv(cross_features)
|
| 117 |
+
# form logits from Bx1x1x1 into Bx1
|
| 118 |
+
logits = self.flat(cross_features)
|
| 119 |
+
return logits
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Discriminator(nn.Module):
|
| 123 |
+
"""Simple CNN-based discriminator"""
|
| 124 |
+
|
| 125 |
+
def __init__(self) -> None:
|
| 126 |
+
"""Use a pretrained InceptionNet to extract features"""
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.encoder = InceptionEncoder(D=128)
|
| 129 |
+
# define different logit extractors for different losses
|
| 130 |
+
self.logits_word_level = WordLevelLogits()
|
| 131 |
+
self.logits_uncond = UnconditionalLogits()
|
| 132 |
+
self.logits_cond = ConditionalLogits()
|
| 133 |
+
|
| 134 |
+
def forward(self, images: torch.Tensor) -> Any:
|
| 135 |
+
"""
|
| 136 |
+
Retrieves image features encoded by the image encoder
|
| 137 |
+
|
| 138 |
+
:param torch.Tensor images: Images to be analyzed. Bx3x256x256
|
| 139 |
+
:return: image features encoded by image encoder. Bx128x17x17
|
| 140 |
+
"""
|
| 141 |
+
# only taking the local features from inception
|
| 142 |
+
# Bx3x256x256 -> Bx128x17x17
|
| 143 |
+
img_features, _ = self.encoder(images)
|
| 144 |
+
return img_features
|
src/models/modules/downsample.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""downsample module."""
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def down_sample(in_planes: int, out_planes: int) -> nn.Module:
|
| 7 |
+
"""UpSample module."""
|
| 8 |
+
return nn.Sequential(
|
| 9 |
+
nn.Conv2d(
|
| 10 |
+
in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False
|
| 11 |
+
),
|
| 12 |
+
nn.BatchNorm2d(out_planes),
|
| 13 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 14 |
+
)
|
src/models/modules/generator.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generator Module"""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from src.models.modules.acm import ACM
|
| 9 |
+
from src.models.modules.attention import ChannelWiseAttention, SpatialAttention
|
| 10 |
+
from src.models.modules.cond_augment import CondAugmentation
|
| 11 |
+
from src.models.modules.downsample import down_sample
|
| 12 |
+
from src.models.modules.residual import ResidualBlock
|
| 13 |
+
from src.models.modules.upsample import img_up_block, up_sample
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class InitStageG(nn.Module):
|
| 17 |
+
"""Initial Stage Generator Module"""
|
| 18 |
+
|
| 19 |
+
# pylint: disable=too-many-instance-attributes
|
| 20 |
+
# pylint: disable=too-many-arguments
|
| 21 |
+
# pylint: disable=invalid-name
|
| 22 |
+
# pylint: disable=too-many-locals
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
:param Ng: Number of channels.
|
| 29 |
+
:param Ng_init: Initial value of Ng, this is output channel of first image upsample.
|
| 30 |
+
:param conditioning_dim: Dimension of the conditioning space
|
| 31 |
+
:param D: Dimension of the text embedding space [D from AttnGAN paper]
|
| 32 |
+
:param noise_dim: Dimension of the noise space
|
| 33 |
+
"""
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.gf_dim = Ng
|
| 36 |
+
self.gf_init = Ng_init
|
| 37 |
+
self.in_dim = noise_dim + conditioning_dim + D
|
| 38 |
+
self.text_dim = D
|
| 39 |
+
|
| 40 |
+
self.define_module()
|
| 41 |
+
|
| 42 |
+
def define_module(self) -> None:
|
| 43 |
+
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
|
| 44 |
+
nz, ng = self.in_dim, self.gf_dim
|
| 45 |
+
self.fully_connect = nn.Sequential(
|
| 46 |
+
nn.Linear(nz, ng * 4 * 4 * 2, bias=False),
|
| 47 |
+
nn.BatchNorm1d(ng * 4 * 4 * 2),
|
| 48 |
+
nn.GLU(dim=1), # we start from 4 x 4 feat_map and return hidden_64.
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
self.upsample1 = up_sample(ng, ng // 2)
|
| 52 |
+
self.upsample2 = up_sample(ng // 2, ng // 4)
|
| 53 |
+
self.upsample3 = up_sample(ng // 4, ng // 8)
|
| 54 |
+
self.upsample4 = up_sample(
|
| 55 |
+
ng // 8 * 3, ng // 16
|
| 56 |
+
) # multiply channel by 3 because concat spatial and channel att
|
| 57 |
+
|
| 58 |
+
self.residual = self._make_layer(ResidualBlock, ng // 8 * 3)
|
| 59 |
+
self.acm_module = ACM(self.gf_init, ng // 8 * 3)
|
| 60 |
+
|
| 61 |
+
self.spatial_att = SpatialAttention(self.text_dim, ng // 8)
|
| 62 |
+
self.channel_att = ChannelWiseAttention(
|
| 63 |
+
32 * 32, self.text_dim
|
| 64 |
+
) # 32 x 32 is the feature map size
|
| 65 |
+
|
| 66 |
+
def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
|
| 67 |
+
layers = []
|
| 68 |
+
for _ in range(2): # number of residual blocks hardcoded to 2
|
| 69 |
+
layers.append(block(channel_num))
|
| 70 |
+
return nn.Sequential(*layers)
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self,
|
| 74 |
+
noise: torch.Tensor,
|
| 75 |
+
condition: torch.Tensor,
|
| 76 |
+
global_inception: torch.Tensor,
|
| 77 |
+
local_upsampled_inception: torch.Tensor,
|
| 78 |
+
word_embeddings: torch.Tensor,
|
| 79 |
+
mask: Optional[torch.Tensor] = None,
|
| 80 |
+
) -> Any:
|
| 81 |
+
"""
|
| 82 |
+
:param noise: Noise tensor
|
| 83 |
+
:param condition: Condition tensor (c^ from stackGAN++ paper)
|
| 84 |
+
:param global_inception: Global inception feature
|
| 85 |
+
:param local_upsampled_inception: Local inception feature, upsampled to 32 x 32
|
| 86 |
+
:param word_embeddings: Word embeddings [shape: D x L or D x T]
|
| 87 |
+
:param mask: Mask for padding tokens
|
| 88 |
+
:return: Hidden Image feature map Tensor of 64 x 64 size
|
| 89 |
+
"""
|
| 90 |
+
noise_concat = torch.cat((noise, condition), 1)
|
| 91 |
+
inception_concat = torch.cat((noise_concat, global_inception), 1)
|
| 92 |
+
hidden = self.fully_connect(inception_concat)
|
| 93 |
+
hidden = hidden.view(-1, self.gf_dim, 4, 4) # convert to 4x4 image feature map
|
| 94 |
+
hidden = self.upsample1(hidden)
|
| 95 |
+
hidden = self.upsample2(hidden)
|
| 96 |
+
hidden_32 = self.upsample3(hidden) # shape: (batch_size, gf_dim // 8, 32, 32)
|
| 97 |
+
hidden_32_view = hidden_32.view(
|
| 98 |
+
hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3]
|
| 99 |
+
) # this reshaping is done as attention module expects this shape.
|
| 100 |
+
|
| 101 |
+
spatial_att_feat = self.spatial_att(
|
| 102 |
+
word_embeddings, hidden_32_view, mask
|
| 103 |
+
) # spatial att shape: (batch, D^, 32 * 32)
|
| 104 |
+
channel_att_feat = self.channel_att(
|
| 105 |
+
spatial_att_feat, word_embeddings
|
| 106 |
+
) # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk* Wk) from controlGAN paper
|
| 107 |
+
spatial_att_feat = spatial_att_feat.view(
|
| 108 |
+
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
|
| 109 |
+
) # reshape to (batch, D^, 32, 32)
|
| 110 |
+
channel_att_feat = channel_att_feat.view(
|
| 111 |
+
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
|
| 112 |
+
) # reshape to (batch, D^, 32, 32)
|
| 113 |
+
|
| 114 |
+
spatial_concat = torch.cat(
|
| 115 |
+
(hidden_32, spatial_att_feat), 1
|
| 116 |
+
) # concat spatial attention feature with hidden_32
|
| 117 |
+
attn_concat = torch.cat(
|
| 118 |
+
(spatial_concat, channel_att_feat), 1
|
| 119 |
+
) # concat channel and spatial attention feature
|
| 120 |
+
|
| 121 |
+
hidden_32 = self.acm_module(attn_concat, local_upsampled_inception)
|
| 122 |
+
hidden_32 = self.residual(hidden_32)
|
| 123 |
+
hidden_64 = self.upsample4(hidden_32)
|
| 124 |
+
return hidden_64
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class NextStageG(nn.Module):
|
| 128 |
+
"""Next Stage Generator Module"""
|
| 129 |
+
|
| 130 |
+
# pylint: disable=too-many-instance-attributes
|
| 131 |
+
# pylint: disable=too-many-arguments
|
| 132 |
+
# pylint: disable=invalid-name
|
| 133 |
+
# pylint: disable=too-many-locals
|
| 134 |
+
|
| 135 |
+
def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
|
| 136 |
+
"""
|
| 137 |
+
:param Ng: Number of channels.
|
| 138 |
+
:param Ng_init: Initial value of Ng.
|
| 139 |
+
:param D: Dimension of the text embedding space [D from AttnGAN paper]
|
| 140 |
+
:param image_size: Size of the output image from previous generator stage.
|
| 141 |
+
"""
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.gf_dim = Ng
|
| 144 |
+
self.gf_init = Ng_init
|
| 145 |
+
self.text_dim = D
|
| 146 |
+
self.img_size = image_size
|
| 147 |
+
|
| 148 |
+
self.define_module()
|
| 149 |
+
|
| 150 |
+
def define_module(self) -> None:
|
| 151 |
+
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
|
| 152 |
+
ng = self.gf_dim
|
| 153 |
+
self.spatial_att = SpatialAttention(self.text_dim, ng)
|
| 154 |
+
self.channel_att = ChannelWiseAttention(
|
| 155 |
+
self.img_size * self.img_size, self.text_dim
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.residual = self._make_layer(ResidualBlock, ng * 3)
|
| 159 |
+
self.upsample = up_sample(ng * 3, ng)
|
| 160 |
+
self.acm_module = ACM(self.gf_init, ng * 3)
|
| 161 |
+
self.upsample2 = up_sample(ng, ng)
|
| 162 |
+
|
| 163 |
+
def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
|
| 164 |
+
layers = []
|
| 165 |
+
for _ in range(2): # no of residual layers hardcoded to 2
|
| 166 |
+
layers.append(block(channel_num))
|
| 167 |
+
return nn.Sequential(*layers)
|
| 168 |
+
|
| 169 |
+
def forward(
|
| 170 |
+
self,
|
| 171 |
+
hidden_feat: Any,
|
| 172 |
+
word_embeddings: torch.Tensor,
|
| 173 |
+
vgg64_feat: torch.Tensor,
|
| 174 |
+
mask: Optional[torch.Tensor] = None,
|
| 175 |
+
) -> Any:
|
| 176 |
+
"""
|
| 177 |
+
:param hidden_feat: Hidden feature from previous generator stage [i.e. hidden_64]
|
| 178 |
+
:param word_embeddings: Word embeddings
|
| 179 |
+
:param vgg64_feat: VGG feature map of size 64 x 64
|
| 180 |
+
:param mask: Mask for the padding tokens
|
| 181 |
+
:return: Image feature map of size 256 x 256
|
| 182 |
+
"""
|
| 183 |
+
hidden_view = hidden_feat.view(
|
| 184 |
+
hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3]
|
| 185 |
+
) # reshape to pass into attention modules.
|
| 186 |
+
spatial_att_feat = self.spatial_att(
|
| 187 |
+
word_embeddings, hidden_view, mask
|
| 188 |
+
) # spatial att shape: (batch, D^, 64 * 64), or D^ x N
|
| 189 |
+
channel_att_feat = self.channel_att(
|
| 190 |
+
spatial_att_feat, word_embeddings
|
| 191 |
+
) # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk* Wk) from controlGAN paper
|
| 192 |
+
spatial_att_feat = spatial_att_feat.view(
|
| 193 |
+
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
|
| 194 |
+
) # reshape to (batch, D^, 64, 64)
|
| 195 |
+
channel_att_feat = channel_att_feat.view(
|
| 196 |
+
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
|
| 197 |
+
) # reshape to (batch, D^, 64, 64)
|
| 198 |
+
|
| 199 |
+
spatial_concat = torch.cat(
|
| 200 |
+
(hidden_feat, spatial_att_feat), 1
|
| 201 |
+
) # concat spatial attention feature with hidden_64
|
| 202 |
+
attn_concat = torch.cat(
|
| 203 |
+
(spatial_concat, channel_att_feat), 1
|
| 204 |
+
) # concat channel and spatial attention feature
|
| 205 |
+
|
| 206 |
+
hidden_64 = self.acm_module(attn_concat, vgg64_feat)
|
| 207 |
+
hidden_64 = self.residual(hidden_64)
|
| 208 |
+
hidden_128 = self.upsample(hidden_64)
|
| 209 |
+
hidden_256 = self.upsample2(hidden_128)
|
| 210 |
+
return hidden_256
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class GetImageG(nn.Module):
|
| 214 |
+
"""Generates the Final Fake Image from the Image Feature Map"""
|
| 215 |
+
|
| 216 |
+
def __init__(self, Ng: int):
|
| 217 |
+
"""
|
| 218 |
+
:param Ng: Number of channels.
|
| 219 |
+
"""
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.img = nn.Sequential(
|
| 222 |
+
nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh()
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def forward(self, hidden_feat: torch.Tensor) -> Any:
|
| 226 |
+
"""
|
| 227 |
+
:param hidden_feat: Image feature map
|
| 228 |
+
:return: Final fake image
|
| 229 |
+
"""
|
| 230 |
+
return self.img(hidden_feat)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class Generator(nn.Module):
|
| 234 |
+
"""Generator Module"""
|
| 235 |
+
|
| 236 |
+
# pylint: disable=too-many-instance-attributes
|
| 237 |
+
# pylint: disable=too-many-arguments
|
| 238 |
+
# pylint: disable=invalid-name
|
| 239 |
+
# pylint: disable=too-many-locals
|
| 240 |
+
|
| 241 |
+
def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
|
| 242 |
+
"""
|
| 243 |
+
:param Ng: Number of channels. [Taken from StackGAN++ paper]
|
| 244 |
+
:param D: Dimension of the text embedding space
|
| 245 |
+
:param conditioning_dim: Dimension of the conditioning space
|
| 246 |
+
:param noise_dim: Dimension of the noise space
|
| 247 |
+
"""
|
| 248 |
+
super().__init__()
|
| 249 |
+
self.cond_augment = CondAugmentation(D, conditioning_dim)
|
| 250 |
+
self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
|
| 251 |
+
self.inception_img_upsample = img_up_block(
|
| 252 |
+
D, Ng
|
| 253 |
+
) # as channel size returned by inception encoder is D (Default in paper: 256)
|
| 254 |
+
self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
|
| 255 |
+
self.generate_img = GetImageG(Ng)
|
| 256 |
+
|
| 257 |
+
self.acm_module = ACM(Ng, Ng)
|
| 258 |
+
|
| 259 |
+
self.vgg_downsample = down_sample(D // 2, Ng)
|
| 260 |
+
self.upsample1 = up_sample(Ng, Ng)
|
| 261 |
+
self.upsample2 = up_sample(Ng, Ng)
|
| 262 |
+
|
| 263 |
+
def forward(
|
| 264 |
+
self,
|
| 265 |
+
noise: torch.Tensor,
|
| 266 |
+
sentence_embeddings: torch.Tensor,
|
| 267 |
+
word_embeddings: torch.Tensor,
|
| 268 |
+
global_inception_feat: torch.Tensor,
|
| 269 |
+
local_inception_feat: torch.Tensor,
|
| 270 |
+
vgg_feat: torch.Tensor,
|
| 271 |
+
mask: Optional[torch.Tensor] = None,
|
| 272 |
+
) -> Any:
|
| 273 |
+
"""
|
| 274 |
+
:param noise: Noise vector [shape: (batch, noise_dim)]
|
| 275 |
+
:param sentence_embeddings: Sentence embeddings [shape: (batch, D)]
|
| 276 |
+
:param word_embeddings: Word embeddings [shape: D x L, where L is length of sentence]
|
| 277 |
+
:param global_inception_feat: Global Inception feature map [shape: (batch, D)]
|
| 278 |
+
:param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)]
|
| 279 |
+
:param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)]
|
| 280 |
+
:param mask: Mask for the padding tokens
|
| 281 |
+
:return: Final fake image
|
| 282 |
+
"""
|
| 283 |
+
c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
|
| 284 |
+
hidden_32 = self.inception_img_upsample(local_inception_feat)
|
| 285 |
+
|
| 286 |
+
hidden_64 = self.hidden_net1(
|
| 287 |
+
noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
vgg_64 = self.vgg_downsample(vgg_feat)
|
| 291 |
+
|
| 292 |
+
hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)
|
| 293 |
+
|
| 294 |
+
vgg_128 = self.upsample1(vgg_64)
|
| 295 |
+
vgg_256 = self.upsample2(vgg_128)
|
| 296 |
+
|
| 297 |
+
hidden_256 = self.acm_module(hidden_256, vgg_256)
|
| 298 |
+
fake_img = self.generate_img(hidden_256)
|
| 299 |
+
|
| 300 |
+
return fake_img, mu_tensor, logvar
|
src/models/modules/image_encoder.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image Encoder Module"""
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from src.models.modules.conv_utils import conv2d
|
| 8 |
+
|
| 9 |
+
# build inception v3 image encoder
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class InceptionEncoder(nn.Module):
|
| 13 |
+
"""Image Encoder Module adapted from AttnGAN"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, D: int):
|
| 16 |
+
"""
|
| 17 |
+
:param D: Dimension of the text embedding space [D from AttnGAN paper]
|
| 18 |
+
"""
|
| 19 |
+
super().__init__()
|
| 20 |
+
|
| 21 |
+
self.text_emb_dim = D
|
| 22 |
+
|
| 23 |
+
model = torch.hub.load(
|
| 24 |
+
"pytorch/vision:v0.10.0", "inception_v3", pretrained=True
|
| 25 |
+
)
|
| 26 |
+
for param in model.parameters():
|
| 27 |
+
param.requires_grad = False
|
| 28 |
+
|
| 29 |
+
self.define_module(model)
|
| 30 |
+
self.init_trainable_weights()
|
| 31 |
+
|
| 32 |
+
def define_module(self, model: nn.Module) -> None:
|
| 33 |
+
"""
|
| 34 |
+
This function defines the modules of the image encoder
|
| 35 |
+
:param model: Pretrained Inception V3 model
|
| 36 |
+
"""
|
| 37 |
+
model.cust_upsample = nn.Upsample(size=(299, 299), mode="bilinear")
|
| 38 |
+
model.cust_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
|
| 39 |
+
model.cust_maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
|
| 40 |
+
model.cust_avgpool = nn.AvgPool2d(kernel_size=8)
|
| 41 |
+
|
| 42 |
+
attribute_list = [
|
| 43 |
+
"cust_upsample",
|
| 44 |
+
"Conv2d_1a_3x3",
|
| 45 |
+
"Conv2d_2a_3x3",
|
| 46 |
+
"Conv2d_2b_3x3",
|
| 47 |
+
"cust_maxpool1",
|
| 48 |
+
"Conv2d_3b_1x1",
|
| 49 |
+
"Conv2d_4a_3x3",
|
| 50 |
+
"cust_maxpool2",
|
| 51 |
+
"Mixed_5b",
|
| 52 |
+
"Mixed_5c",
|
| 53 |
+
"Mixed_5d",
|
| 54 |
+
"Mixed_6a",
|
| 55 |
+
"Mixed_6b",
|
| 56 |
+
"Mixed_6c",
|
| 57 |
+
"Mixed_6d",
|
| 58 |
+
"Mixed_6e",
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
self.feature_extractor = nn.Sequential(
|
| 62 |
+
*[getattr(model, name) for name in attribute_list]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
attribute_list2 = ["Mixed_7a", "Mixed_7b", "Mixed_7c", "cust_avgpool"]
|
| 66 |
+
|
| 67 |
+
self.feature_extractor2 = nn.Sequential(
|
| 68 |
+
*[getattr(model, name) for name in attribute_list2]
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.emb_features = conv2d(
|
| 72 |
+
768, self.text_emb_dim, kernel_size=1, stride=1, padding=0
|
| 73 |
+
)
|
| 74 |
+
self.emb_cnn_code = nn.Linear(2048, self.text_emb_dim)
|
| 75 |
+
|
| 76 |
+
def init_trainable_weights(self) -> None:
|
| 77 |
+
"""
|
| 78 |
+
This function initializes the trainable weights of the image encoder
|
| 79 |
+
"""
|
| 80 |
+
initrange = 0.1
|
| 81 |
+
self.emb_features.weight.data.uniform_(-initrange, initrange)
|
| 82 |
+
self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)
|
| 83 |
+
|
| 84 |
+
def forward(self, image_tensor: torch.Tensor) -> Any:
|
| 85 |
+
"""
|
| 86 |
+
:param image_tensor: Input image
|
| 87 |
+
:return: features: local feature matrix (v from attnGAN paper) [shape: (batch, D, 17, 17)]
|
| 88 |
+
:return: cnn_code: global image feature (v^ from attnGAN paper) [shape: (batch, D)]
|
| 89 |
+
"""
|
| 90 |
+
# this is the image size
|
| 91 |
+
# x.shape: 10 3 256 256
|
| 92 |
+
|
| 93 |
+
features = self.feature_extractor(image_tensor)
|
| 94 |
+
# 17 x 17 x 768
|
| 95 |
+
|
| 96 |
+
image_tensor = self.feature_extractor2(features)
|
| 97 |
+
|
| 98 |
+
image_tensor = image_tensor.view(image_tensor.size(0), -1)
|
| 99 |
+
# 2048
|
| 100 |
+
|
| 101 |
+
# global image features
|
| 102 |
+
cnn_code = self.emb_cnn_code(image_tensor)
|
| 103 |
+
|
| 104 |
+
if features is not None:
|
| 105 |
+
features = self.emb_features(features)
|
| 106 |
+
|
| 107 |
+
# feature.shape: 10 256 17 17
|
| 108 |
+
# cnn_code.shape: 10 256
|
| 109 |
+
return features, cnn_code
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class VGGEncoder(nn.Module):
|
| 113 |
+
"""Pre Trained VGG Encoder Module"""
|
| 114 |
+
|
| 115 |
+
def __init__(self) -> None:
|
| 116 |
+
"""
|
| 117 |
+
Initialize pre-trained VGG model with frozen parameters
|
| 118 |
+
"""
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.select = "8" ## We want to get the output of the 8th layer in VGG.
|
| 121 |
+
|
| 122 |
+
self.model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
|
| 123 |
+
|
| 124 |
+
for param in self.model.parameters():
|
| 125 |
+
param.resquires_grad = False
|
| 126 |
+
|
| 127 |
+
self.vgg_modules = self.model.features._modules
|
| 128 |
+
|
| 129 |
+
def forward(self, image_tensor: torch.Tensor) -> Any:
|
| 130 |
+
"""
|
| 131 |
+
:param x: Input image tensor [shape: (batch, 3, 256, 256)]
|
| 132 |
+
:return: VGG features [shape: (batch, 128, 128, 128)]
|
| 133 |
+
"""
|
| 134 |
+
for name, layer in self.vgg_modules.items():
|
| 135 |
+
image_tensor = layer(image_tensor)
|
| 136 |
+
if name == self.select:
|
| 137 |
+
return image_tensor
|
| 138 |
+
return None
|