# File size: 6,062 Bytes
# 844e1a2 8dee918 844e1a2
import streamlit as st
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import io
class ConditionalAugmentation(nn.Module):
    """Sample a conditioning vector from a Gaussian predicted from text.

    A single linear layer maps the text embedding to ``2 * projected_dim``
    values, which are split along dim 1 into a mean half and a
    log-variance half. The output is one reparameterised sample.
    """

    def __init__(self, text_dim, projected_dim):
        """Build the projection layer.

        Args:
            text_dim: Size of the incoming text embedding (dim 1).
            projected_dim: Size of the sampled conditioning vector.
        """
        super().__init__()
        # One layer emits the mean and the log-variance side by side.
        self.proj = nn.Linear(text_dim, 2 * projected_dim)

    def forward(self, text_embedding):
        """Return ``mu + eps * std`` with ``eps`` drawn from N(0, 1).

        Noise is drawn on every call (there is no check of the module's
        training flag), so repeated calls with the same embedding give
        different samples unless the RNG is re-seeded.
        """
        stats = self.proj(text_embedding)
        mu, logvar = torch.chunk(stats, 2, dim=1)
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std
class Stage1Generator(nn.Module):
    """Stage-I generator: text embedding plus noise to a 3 x 64 x 64 image.

    The text embedding is passed through ``ConditionalAugmentation``,
    concatenated with the noise vector, projected to a 128 x 8 x 8 feature
    map and upsampled three times (8 -> 16 -> 32 -> 64). The final ``tanh``
    keeps pixel values in [-1, 1].
    """

    def __init__(self, text_embedding_dim, noise_dim, img_size):
        """Build the layers.

        Args:
            text_embedding_dim: Size of the text embedding fed to ``forward``.
            noise_dim: Size of the noise vector fed to ``forward``.
            img_size: Stored on the instance only; the layers below always
                produce 64 x 64 output regardless of this value.
        """
        super().__init__()
        # The embedding width used to be hard-coded as 768 here and in the
        # augmentation layer, so any other ``text_embedding_dim`` broke
        # forward(). With 768 the parameter shapes are unchanged, so
        # existing checkpoints still load.
        self.fc1 = nn.Linear(text_embedding_dim + noise_dim, 128 * 8 * 8)
        # Not used by forward(); kept so the module's state_dict keys stay
        # the same as before.
        self.reduced_embeddings = nn.Linear(text_embedding_dim, 128)
        self.bn1 = nn.BatchNorm1d(128 * 8 * 8)
        self.relu = nn.ReLU(inplace=True)
        self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.upsample3 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
        self.tanh = nn.Tanh()
        self.augment = ConditionalAugmentation(text_embedding_dim, text_embedding_dim)
        self.img_size = img_size

    def forward(self, text_embedding, noise):
        """Generate images.

        Args:
            text_embedding: Tensor of shape ``(batch, text_embedding_dim)``.
            noise: Tensor of shape ``(batch, noise_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, 64, 64)`` with values in [-1, 1].
        """
        x = self.augment(text_embedding)
        x = torch.cat((x, noise), dim=1)
        x = self.relu(self.bn1(self.fc1(x)))
        # Reinterpret the flat projection as a 128-channel 8 x 8 feature map.
        x = x.view(-1, 128, 8, 8)
        x = self.relu(self.bn2(self.upsample1(x)))
        x = self.relu(self.bn3(self.upsample2(x)))
        x = self.tanh(self.upsample3(x))
        return x
# Stage-I generator instance; its trained weights are loaded further down.
stage1_generator = Stage1Generator(
    text_embedding_dim=768,
    noise_dim=100,
    img_size=64,
)
class Stage2Generator(nn.Module):
    """Stage-II generator: refine a Stage-I image into a 3 x 128 x 128 image.

    The flattened Stage-I image is concatenated with an augmented text
    embedding, projected to a 128 x 16 x 16 feature map and upsampled three
    times (16 -> 32 -> 64 -> 128). The final ``tanh`` keeps pixel values in
    [-1, 1].
    """

    def __init__(self, text_embedding_dim, img_size):
        """Build the layers.

        Args:
            text_embedding_dim: Size of the text embedding fed to ``forward``.
            img_size: Side length of the square, 3-channel Stage-I image;
                it sizes the input of the first linear layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(text_embedding_dim + 3 * img_size * img_size, 128 * 16 * 16)
        self.bn1 = nn.BatchNorm1d(128 * 16 * 16)
        self.relu = nn.ReLU(inplace=True)
        self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.upsample3 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
        self.tanh = nn.Tanh()
        # The augmentation width used to be hard-coded as 768 while fc1
        # used ``text_embedding_dim``; the two only agreed at 768. Tying
        # both to the argument keeps shapes identical at 768.
        self.augment = ConditionalAugmentation(text_embedding_dim, text_embedding_dim)
        self.img_size = img_size

    def forward(self, text_embedding, stage1_img):
        """Generate refined images.

        Args:
            text_embedding: Tensor of shape ``(batch, text_embedding_dim)``.
            stage1_img: Tensor of shape ``(batch, 3, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, 3, 128, 128)`` with values in [-1, 1].
        """
        # flatten() also accepts non-contiguous inputs, where view() raises.
        stage1_img_flat = stage1_img.flatten(start_dim=1)
        text_embedding = self.augment(text_embedding)
        x = torch.cat((text_embedding, stage1_img_flat), dim=1)
        x = self.relu(self.bn1(self.fc1(x)))
        # Reinterpret the flat projection as a 128-channel 16 x 16 feature map.
        x = x.view(-1, 128, 16, 16)
        x = self.relu(self.bn2(self.upsample1(x)))
        x = self.relu(self.bn3(self.upsample2(x)))
        x = self.tanh(self.upsample3(x))
        return x
# Stage-II generator instance; img_size is the side length of the Stage-I image.
stage2_generator = Stage2Generator(
    text_embedding_dim=768,
    img_size=64,
)

# Checkpoint tensors are remapped onto this device while loading.
device = 'cpu'

# Restore the trained weights, then put both generators in evaluation mode.
stage1_generator.load_state_dict(
    torch.load('stage1Generator_weights.pth', map_location=device)
)
stage2_generator.load_state_dict(
    torch.load('stage2Generator_weights_UPDATED.pth', map_location=device)
)
stage1_generator.eval()
stage2_generator.eval()
print("Models loaded successfully")

# Pretrained BERT tokenizer and encoder used to embed the input sentence.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.eval()
print("bert loaded")
def Tokenize(sentence):
    """Embed text with BERT by mean-pooling the final hidden states.

    Args:
        sentence: A string, or a list of strings to embed as one batch.
            Inputs are truncated to 64 tokens.

    Returns:
        Float tensor of shape ``(batch, hidden)``; a single string gives
        ``(1, hidden)``, as before.
    """
    encoded_input = tokenizer(
        sentence, return_tensors='pt', padding=True, truncation=True, max_length=64
    )
    with torch.no_grad():
        model_output = bert_model(**encoded_input)
    hidden = model_output.last_hidden_state
    # Weight each token by its attention mask so padding added for shorter
    # sentences in a batch does not dilute their mean. A lone sentence has
    # no padding, so its result matches a plain mean over tokens.
    mask = encoded_input['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)
    # The old squeeze()/unsqueeze(0) pair turned a batch of B > 1 sentences
    # into shape (1, B, hidden); pooling over dim 1 already gives (B, hidden).
    return (hidden * mask).sum(dim=1) / token_counts
def generate_images(text_embeddings):
    """Run both generator stages on text embeddings and return the result.

    Args:
        text_embeddings: Tensor of shape ``(batch, embedding_dim)``.

    Returns:
        The Stage-II output with all size-1 dimensions squeezed out, so a
        batch of one yields a ``(channels, height, width)`` tensor.
    """
    # One noise row per embedding. The batch size used to be fixed at 1,
    # which made the concatenation inside Stage-I fail for larger batches.
    # The width 100 must match the ``noise_dim`` the Stage-I generator was
    # built with.
    batch_size = text_embeddings.size(0)
    noise = torch.randn(batch_size, 100, device=text_embeddings.device)
    with torch.no_grad():
        Image_stage1 = stage1_generator(text_embeddings, noise)
        Image_stage2 = stage2_generator(text_embeddings, Image_stage1)
    print(Image_stage2.shape)
    return Image_stage2.squeeze()
# def display_images(image, title="Generated Images"):
# # Display a grid of images using matplotlib
# # fig, axes = plt.subplots(4, 4, figsize=(10, 10))
# # for i, ax in enumerate(axes.flatten()):
# # ax.axis('off')
# image = image.permute(1, 2, 0).to('cpu').detach().numpy()
# plt.imshow(image)
# # ax.imshow(image)
# plt.show()
# # ax.axis('off')
# # plt.imshow
# st.title("Pokémon Image Generator")
# st.markdown('<div class="custom-label">Enter a sentence:</div>', unsafe_allow_html=True)
# input_text = st.text_input("", key='input', help='Type your sentence here', label_visibility='collapsed')
# # sentence = "A cheerful Bulbasaur ready for its next Pokémon adventure."
# # generate_images(Tokenize(sentence))
# # display_images(generate_images(Tokenize(input_text)))
# # print(Tokenize(sentence).shape)
# # Generate images
# st.write("Generating images...")
# # # # Replace with actual text embeddings input
# # # text_embeddings = torch.randn(16, 1024) # Placeholder, use actual text embeddings
# if st.button("Generate Image"):
# if input_text:
# # generated_image = generate_image(input_text)
# generated_image = generate_images(Tokenize(input_text))
# img_bytes = io.BytesIO()
# generated_image.save(img_bytes, format='PNG')
# img_bytes.seek(0)
# st.image(img_bytes, caption="Generated Image", use_column_width=True)
# else:
# st.error("Please enter a sentence.")
# image = generate_images(Tokenize(input_text))
# # # # Display images
# st.write("Displaying images...")
# display_images(image)
# # # if __name__ == '__main__':
# # # st.write("Streamlit app for image generation.")
# # print("hello")