AaSiKu committed
Commit 844e1a2 · verified · 1 Parent(s): 481822c

Upload 4 files

HelperFunctions.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import torch.nn as nn
+ from transformers import BertTokenizer, BertModel
+
+
+ class ConditionalAugmentation(nn.Module):
+     """Conditioning augmentation: sample a latent text code with the
+     reparameterization trick (mu + eps * sigma), as in StackGAN."""
+     def __init__(self, text_dim, projected_dim):
+         super(ConditionalAugmentation, self).__init__()
+         self.proj = nn.Linear(text_dim, projected_dim * 2)
+
+     def forward(self, text_embedding):
+         mu_logvar = self.proj(text_embedding)
+         mu, logvar = mu_logvar.chunk(2, dim=1)  # split into mean and log-variance
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
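+
+ # Shape sketch (added illustration, not part of the training code): with
+ # text_dim=768 and projected_dim=768, proj maps a (1, 768) embedding to
+ # (1, 1536); chunk(2, dim=1) splits it into mu and logvar of shape (1, 768)
+ # each, and the returned sample mu + eps * exp(0.5 * logvar) is also (1, 768).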
+
+
+ class Stage1Generator(nn.Module):
+     """Maps an augmented text embedding plus noise to a coarse 64x64 image."""
+     def __init__(self, text_embedding_dim, noise_dim, img_size):
+         super(Stage1Generator, self).__init__()
+         self.fc1 = nn.Linear(text_embedding_dim + noise_dim, 128 * 8 * 8)
+         # Unused in forward(), but kept so the saved state_dict still loads.
+         self.reduced_embeddings = nn.Linear(text_embedding_dim, 128)
+         self.bn1 = nn.BatchNorm1d(128 * 8 * 8)
+         self.relu = nn.ReLU(inplace=True)
+         self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
+         self.bn2 = nn.BatchNorm2d(64)
+         self.upsample2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
+         self.bn3 = nn.BatchNorm2d(32)
+         self.upsample3 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
+         self.tanh = nn.Tanh()
+         self.augment = ConditionalAugmentation(768, 768)
+         self.img_size = img_size
+
+     def forward(self, text_embedding, noise):
+         x = self.augment(text_embedding)
+         x = torch.cat((x, noise), dim=1)
+         x = self.relu(self.bn1(self.fc1(x)))
+         x = x.view(-1, 128, 8, 8)                    # 8x8 seed feature map
+         x = self.relu(self.bn2(self.upsample1(x)))   # -> 16x16
+         x = self.relu(self.bn3(self.upsample2(x)))   # -> 32x32
+         x = self.tanh(self.upsample3(x))             # -> 64x64, values in [-1, 1]
+         return x
+
+
+ stage1_generator = Stage1Generator(text_embedding_dim=768, noise_dim=100, img_size=64)
+
+
+ class Stage2Generator(nn.Module):
+     """Refines the stage-1 image into a 128x128 image, conditioned on the same text embedding."""
+     def __init__(self, text_embedding_dim, img_size):
+         super(Stage2Generator, self).__init__()
+         self.fc1 = nn.Linear(text_embedding_dim + 3 * img_size * img_size, 128 * 16 * 16)
+         self.bn1 = nn.BatchNorm1d(128 * 16 * 16)
+         self.relu = nn.ReLU(inplace=True)
+         self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
+         self.bn2 = nn.BatchNorm2d(64)
+         self.upsample2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
+         self.bn3 = nn.BatchNorm2d(32)
+         self.upsample3 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
+         self.tanh = nn.Tanh()
+         self.augment = ConditionalAugmentation(768, 768)
+         self.img_size = img_size
+
+     def forward(self, text_embedding, stage1_img):
+         stage1_img_flat = stage1_img.view(stage1_img.size(0), -1)  # flatten the 64x64 RGB image
+         text_embedding = self.augment(text_embedding)
+         x = torch.cat((text_embedding, stage1_img_flat), dim=1)
+         x = self.relu(self.bn1(self.fc1(x)))
+         x = x.view(-1, 128, 16, 16)                  # 16x16 seed feature map
+         x = self.relu(self.bn2(self.upsample1(x)))   # -> 32x32
+         x = self.relu(self.bn3(self.upsample2(x)))   # -> 64x64
+         x = self.tanh(self.upsample3(x))             # -> 128x128, values in [-1, 1]
+         return x
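+
+ # Size check (added illustration): flattening a 64x64 RGB stage-1 image gives
+ # 3 * 64 * 64 = 12288 features, so fc1 sees 768 + 12288 = 13056 inputs, and the
+ # 16x16 seed doubles three times to reach 128x128.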
+
+
+ stage2_generator = Stage2Generator(text_embedding_dim=768, img_size=64)
+
+ # Load the trained weights, then switch both generators to evaluation mode.
+ device = 'cpu'
+ stage1_generator.load_state_dict(torch.load('Weights/stage1Generator_weights.pth', map_location=device))
+ stage2_generator.load_state_dict(torch.load('Weights/stage2Generator_weights_UPDATED.pth', map_location=device))
+ stage1_generator.eval()
+ stage2_generator.eval()
+ print("Models loaded successfully")
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ bert_model = BertModel.from_pretrained('bert-base-uncased').eval()
+ print("BERT loaded")
+
+
+ def Tokenize(sentence):
+     """Encode a sentence into a (1, 768) embedding by mean-pooling BERT's last hidden state."""
+     encoded_input = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True, max_length=64)
+     with torch.no_grad():
+         model_output = bert_model(**encoded_input)
+         text_embedding = model_output.last_hidden_state.mean(dim=1).squeeze()
+     return text_embedding.unsqueeze(0)
+
+
+ def generate_images(text_embeddings):
+     """Run both generator stages and return the refined image as a (3, 128, 128) tensor."""
+     noise = torch.randn(1, 100)
+     with torch.no_grad():
+         Image_stage1 = stage1_generator(text_embeddings, noise)
+         Image_stage2 = stage2_generator(text_embeddings, Image_stage1)
+     return Image_stage2.squeeze()
+
+ # An earlier commented-out Streamlit prototype (a display_images helper plus a
+ # text-input UI) previously lived here; the working version of that UI is in app.py.
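+
+
+ # Minimal smoke test (an added sketch; it assumes the Weights/ files above are
+ # present). Run `python HelperFunctions.py` to exercise both stages end to end.
+ if __name__ == "__main__":
+     demo = generate_images(Tokenize("A cheerful Bulbasaur ready for its next adventure."))
+     print(demo.shape)  # expected: torch.Size([3, 128, 128])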
app.py ADDED
@@ -0,0 +1,76 @@
+ from HelperFunctions import *
+ import streamlit as st
+ from PIL import Image
+ import io
+
+ def Generate_image(text):
+     """Generate a PIL image from a text prompt using the two-stage generator."""
+     embeddings = Tokenize(text)
+     tensor = generate_images(embeddings)
+     tensor = tensor.squeeze().permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
+     # The generator's tanh output lies in [-1, 1]; rescale to [0, 255].
+     tensor = ((tensor + 1) / 2 * 255).clamp(0, 255).byte()
+     array = tensor.cpu().numpy()
+     return Image.fromarray(array)
+
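+ # Worked example of the rescaling above (illustrative): a tanh value of -1
+ # maps to 0, 0 maps to 127, and 1 maps to 255 under (x + 1) / 2 * 255.
+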
+ st.markdown(
+     """
+     <style>
+     @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
+
+     .custom-text-input > div > input {
+         font-size: 25px !important;
+         height: 50px !important;
+         font-family: 'Roboto', sans-serif;
+     }
+     .custom-label {
+         font-size: 25px !important;
+         font-weight: bold;
+         font-family: 'Roboto', sans-serif;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ st.title("Pokémon Image Generator")
+
+ st.markdown('<div class="custom-label">Enter a sentence:</div>', unsafe_allow_html=True)
+ input_text = st.text_input("", key='input', help='Type your sentence here', label_visibility='collapsed')
+
+ st.markdown(
+     """
+     <style>
+     .custom-text-input {
+         display: flex;
+         flex-direction: column;
+     }
+     .custom-text-input label {
+         font-size: 20px;
+     }
+     /* Note: this re-declares the input rule above and drops the size to 20px. */
+     .custom-text-input > div > input {
+         font-size: 20px !important;
+         height: 50px !important;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ if st.button("Generate Image"):
+     if input_text:
+         generated_image = Generate_image(input_text)
+
+         # Serialize the PIL image to an in-memory PNG for st.image.
+         img_bytes = io.BytesIO()
+         generated_image.save(img_bytes, format='PNG')
+         img_bytes.seek(0)
+
+         st.image(img_bytes, caption="Generated Image", use_column_width=True)
+     else:
+         st.error("Please enter a sentence.")
stage1Generator_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eb9697568b187e7ff0719e6fb1cc1422b0ae66693a08aefbcf8de1d24fa1a2c7
+ size 34397621
stage2Generator_weights_UPDATED.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c60ff6e5b578e563f291f95b581db8e79cde6abd88cf6cf7ad829b5ea5963118
+ size 1717328387