Lasercatz committed on
Commit
e99de36
·
verified ·
1 Parent(s): cfe8cc6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -102
app.py CHANGED
@@ -1,30 +1,15 @@
1
- max_description_length=100
2
-
3
-
4
-
5
-
6
-
7
-
8
  from skimage import color
9
  import numpy as np
10
 
11
- def hex_to_rgb(hex_color):
12
- hex_color = hex_color.lstrip('#')
13
- return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
 
14
 
15
  def rgb_to_hex(rgb_array):
16
- return "#{:02x}{:02x}{:02x}".format(*rgb_array)
17
 
18
- def rgb_to_normalized_lab(rgb_array):
19
- rgb_array = np.array(rgb_array, dtype=np.float32) / 255.0
20
-
21
- if rgb_array.ndim == 1:
22
- rgb_array = rgb_array.reshape(1, 3)
23
 
24
- lab_array = color.rgb2lab(rgb_array)
25
- lab_array[:,0] /= 100.0
26
- lab_array[:,1:] /= 127.0
27
- return tuple(lab_array.squeeze())
28
 
29
  def normalized_lab_to_rgb(lab_array):
30
  lab_array = np.array(lab_array, dtype=np.float32)
@@ -42,19 +27,19 @@ def normalized_lab_to_rgb(lab_array):
42
  return tuple(rgb_array.squeeze())
43
 
44
 
45
- import torch
46
- import torch.nn as nn
47
- import torch.nn.functional as F
48
- from transformers import RobertaModel, RobertaTokenizer
49
  from huggingface_hub import hf_hub_download
50
 
51
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
52
- print(f'Using {device}')
53
 
54
- model_path = hf_hub_download(repo_id="lasercatz/text2palette", filename="epoch_13.pth")
55
 
56
 
57
 
 
 
 
 
58
 
59
  class AttentionPooling(nn.Module):
60
  def __init__(self, d_model):
@@ -68,6 +53,7 @@ class AttentionPooling(nn.Module):
68
  weights = F.softmax(scores, dim=-1).unsqueeze(-1)
69
  return torch.sum(x * weights, dim=1)
70
 
 
71
  class SequencePriorNet(nn.Module):
72
  def __init__(self, d_model, d_z, n_heads=4):
73
  super().__init__()
@@ -78,53 +64,56 @@ class SequencePriorNet(nn.Module):
78
  self.dropout = nn.Dropout(0.3)
79
 
80
  def forward(self, text_feats, attention_mask):
81
- attn_output, _ = self.attn(text_feats, text_feats, text_feats, key_padding_mask=~attention_mask.bool())
 
82
  x = self.norm(attn_output + text_feats)
83
  x = self.dropout(x)
84
  x = self.pool(x, attention_mask)
85
  x = self.fc(x)
86
  return x
87
 
88
- class TextToPaletteModel(nn.Module):
89
- def __init__(self, d_model=768, d_z=256, max_text_len=max_description_length, max_seq_len=10,
 
90
  n_layers=8, n_heads=8, dim_ff=3072):
91
  super().__init__()
92
  self.d_model = d_model
93
- self.max_text_len= max_text_len
94
  self.max_seq_len = max_seq_len
95
 
96
- # Text encoder
97
- self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
98
- self.roberta = RobertaModel.from_pretrained('roberta-base')
 
 
 
 
 
99
  self.text_proj = nn.Sequential(
100
- nn.Linear(768, d_model*2),
101
  nn.GELU(),
102
  nn.LayerNorm(d_model*2),
103
  nn.Dropout(0.3),
104
  nn.Linear(d_model*2, d_model)
105
  )
106
 
107
- # Color processing
108
  self.color_embed = nn.Sequential(
109
  nn.Linear(3, d_model),
110
  nn.LayerNorm(d_model),
111
  nn.GELU(),
112
  nn.Dropout(0.3)
113
  )
114
-
115
  self.cross_attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
116
-
117
- # Positional embeddings
118
  self.position_embed = nn.Embedding(max_seq_len, d_model)
119
  self.start_embed = nn.Parameter(torch.randn(1, d_model))
120
 
121
- # Palette encoder
122
  self.palette_encoder = nn.TransformerEncoder(
123
- nn.TransformerEncoderLayer(d_model, n_heads, dim_ff, batch_first=True),
 
124
  n_layers
125
  )
126
 
127
- # Latent projection
128
  self.z_proj = nn.Sequential(
129
  nn.Linear(d_model*2, d_z),
130
  nn.LayerNorm(d_z),
@@ -134,13 +123,12 @@ class TextToPaletteModel(nn.Module):
134
  self.z_mu = nn.Linear(d_z, d_z)
135
  self.z_logvar = nn.Linear(d_z, d_z)
136
 
137
- # Decoder
138
  self.decoder = nn.TransformerDecoder(
139
- nn.TransformerDecoderLayer(d_model, n_heads, dim_ff, batch_first=True),
 
140
  n_layers
141
  )
142
 
143
- # Output layers
144
  self.out_mu_L = nn.Sequential(
145
  nn.Linear(d_model, 1),
146
  nn.Sigmoid()
@@ -150,79 +138,82 @@ class TextToPaletteModel(nn.Module):
150
  nn.Tanh()
151
  )
152
  self.out_logvar = nn.Linear(d_model, 3)
153
-
154
  self.prior_net = SequencePriorNet(d_model, d_z, n_heads=4)
155
 
156
- # Pooling
157
  self.text_pool = AttentionPooling(d_model)
158
  self.palette_pool = AttentionPooling(d_model)
159
 
 
 
 
 
 
 
 
 
 
160
 
161
  @torch.no_grad()
162
- def generate(self, text, palette_size,temp=1.0):
163
  self.eval()
164
- device = next(self.parameters()).device
165
- tokenized = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True,
166
- max_length=self.max_text_len).to(device)
167
-
168
- # Text features
169
- text_feats = self.roberta(**tokenized).last_hidden_state
170
  text_feats = self.text_proj(text_feats)
171
- text_pooled = self.text_pool(text_feats, tokenized['attention_mask'])
172
-
173
  # Sample from prior
174
  prior_params = self.prior_net(text_feats, tokenized['attention_mask'])
175
  prior_mu, prior_logvar = prior_params.chunk(2, -1)
176
- z = prior_mu + torch.exp(0.5 * prior_logvar) * torch.randn_like(prior_mu) * temp
 
177
  z_expanded = self.z_expand(z).unsqueeze(1)
178
 
179
- memory = torch.cat([z_expanded, text_feats], dim=1) # [1, 1 + seq_len, d_model]
 
180
  memory_key_padding_mask = torch.cat([
181
  torch.zeros((1, 1), dtype=torch.bool, device=device),
182
  ~tokenized['attention_mask'].bool()
183
- ], dim=1) # [1, 1 + seq_len]
184
 
185
- generate_size = min(palette_size, self.max_seq_len)
186
-
187
-
188
- # Generation loop
189
  colors = []
190
- current_emb = self.start_embed.unsqueeze(0)
191
-
192
- for i in range(generate_size):
193
- # Positional update
194
- pos = self.position_embed(torch.arange(0, current_emb.size(1), device=device)).unsqueeze(0) # [1, i+1, d_model]
 
 
 
195
  decoder_in = current_emb + pos # [1, i+1, d_model]
196
-
197
- # Decode
198
  output = self.decoder(
199
  decoder_in,
200
  memory,
201
- tgt_mask=nn.Transformer.generate_square_subsequent_mask(decoder_in.size(1), device=device),
 
202
  memory_key_padding_mask=memory_key_padding_mask
203
  ) # [1, i+1, d_model]
204
-
205
- # Predict color
206
- mu = torch.cat([self.out_mu_L(output[:, -1]), self.out_mu_ab(output[:, -1])], dim=-1) # [1, 3]
207
- logvar = self.out_logvar(output[:, -1]) # [1, 3]
208
-
209
 
210
- color = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu) * temp # [1, 3]
211
-
 
 
 
212
  color[:, 0].clamp_(0, 1)
213
  color[:, 1:].clamp_(-1, 1)
214
-
215
  colors.append(color)
216
-
217
- # Embed the color and update the sequence
218
  color_emb = self.color_embed(color.unsqueeze(1)) # [1, 1, d_model]
 
 
219
 
220
- current_emb = torch.cat([current_emb, color_emb ], dim=1)
221
-
222
  return torch.cat(colors, dim=0).unsqueeze(0)
223
 
224
 
225
- model = TextToPaletteModel().to(device)
226
  state_dict = torch.load(model_path, map_location=torch.device(device))
227
  model.load_state_dict(state_dict['model'])
228
  model.to(device)
@@ -231,7 +222,78 @@ model.eval()
231
  import gradio as gr
232
 
233
 
234
- def generate_palettes(text, palette_size=5, temp=1.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  temps=[temp]*4
236
  hex_palettes = []
237
 
@@ -249,21 +311,4 @@ def generate_palettes(text, palette_size=5, temp=1.0):
249
  hex_palettes.append(hex_palette)
250
  return hex_palettes
251
 
252
-
253
-
254
-
255
- with gr.Blocks() as demo:
256
- default_input = gr.Textbox(label="Input text", placeholder="")
257
- palette_size = gr.Slider(1, 10, value=5, step=1, label="Palette size")
258
- temp = gr.Slider(0.0, 0.1, value=0.03, step=0.01, label="Temperature")
259
- default_button = gr.Button("🎨 Generate")
260
- default_output = gr.HTML()
261
-
262
- default_button.click(
263
- generate_palettes,
264
- inputs=[default_input,palette_size , temp],
265
- outputs=default_output
266
- )
267
-
268
-
269
  demo.launch()
 
 
 
 
 
 
 
 
1
  from skimage import color
2
  import numpy as np
3
 
4
+ tokenizer_input_length = 77
5
+
6
+ import torch
7
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
 
9
  def rgb_to_hex(rgb_array):
10
+ return "{:02x}{:02x}{:02x}".format(*rgb_array)
11
 
 
 
 
 
 
12
 
 
 
 
 
13
 
14
  def normalized_lab_to_rgb(lab_array):
15
  lab_array = np.array(lab_array, dtype=np.float32)
 
27
  return tuple(rgb_array.squeeze())
28
 
29
 
30
+
 
 
 
31
  from huggingface_hub import hf_hub_download
32
 
33
+ model_path = hf_hub_download(repo_id="lasercatz/text2palette", filename="epoch_19.pth")
34
+
35
 
 
36
 
37
 
38
 
39
+ import torch.nn as nn
40
+ import torch.nn.functional as F
41
+ from transformers import CLIPTextModel, CLIPTokenizer
42
+
43
 
44
  class AttentionPooling(nn.Module):
45
  def __init__(self, d_model):
 
53
  weights = F.softmax(scores, dim=-1).unsqueeze(-1)
54
  return torch.sum(x * weights, dim=1)
55
 
56
+
57
  class SequencePriorNet(nn.Module):
58
  def __init__(self, d_model, d_z, n_heads=4):
59
  super().__init__()
 
64
  self.dropout = nn.Dropout(0.3)
65
 
66
  def forward(self, text_feats, attention_mask):
67
+ attn_output, _ = self.attn(
68
+ text_feats, text_feats, text_feats, key_padding_mask=~attention_mask.bool())
69
  x = self.norm(attn_output + text_feats)
70
  x = self.dropout(x)
71
  x = self.pool(x, attention_mask)
72
  x = self.fc(x)
73
  return x
74
 
75
+
76
+ class Text2PaletteModel(nn.Module):
77
+ def __init__(self, d_model=768, d_z=256, max_seq_len=64,
78
  n_layers=8, n_heads=8, dim_ff=3072):
79
  super().__init__()
80
  self.d_model = d_model
 
81
  self.max_seq_len = max_seq_len
82
 
83
+ self.tokenizer = CLIPTokenizer.from_pretrained(
84
+ 'openai/clip-vit-base-patch32')
85
+ self.clip_text = CLIPTextModel.from_pretrained(
86
+ 'openai/clip-vit-base-patch32')
87
+
88
+ self.tokenizer_input_length = tokenizer_input_length
89
+
90
+
91
  self.text_proj = nn.Sequential(
92
+ nn.Linear(512, d_model*2),
93
  nn.GELU(),
94
  nn.LayerNorm(d_model*2),
95
  nn.Dropout(0.3),
96
  nn.Linear(d_model*2, d_model)
97
  )
98
 
 
99
  self.color_embed = nn.Sequential(
100
  nn.Linear(3, d_model),
101
  nn.LayerNorm(d_model),
102
  nn.GELU(),
103
  nn.Dropout(0.3)
104
  )
105
+
106
  self.cross_attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
107
+
 
108
  self.position_embed = nn.Embedding(max_seq_len, d_model)
109
  self.start_embed = nn.Parameter(torch.randn(1, d_model))
110
 
 
111
  self.palette_encoder = nn.TransformerEncoder(
112
+ nn.TransformerEncoderLayer(
113
+ d_model, n_heads, dim_ff, batch_first=True),
114
  n_layers
115
  )
116
 
 
117
  self.z_proj = nn.Sequential(
118
  nn.Linear(d_model*2, d_z),
119
  nn.LayerNorm(d_z),
 
123
  self.z_mu = nn.Linear(d_z, d_z)
124
  self.z_logvar = nn.Linear(d_z, d_z)
125
 
 
126
  self.decoder = nn.TransformerDecoder(
127
+ nn.TransformerDecoderLayer(
128
+ d_model, n_heads, dim_ff, batch_first=True),
129
  n_layers
130
  )
131
 
 
132
  self.out_mu_L = nn.Sequential(
133
  nn.Linear(d_model, 1),
134
  nn.Sigmoid()
 
138
  nn.Tanh()
139
  )
140
  self.out_logvar = nn.Linear(d_model, 3)
141
+
142
  self.prior_net = SequencePriorNet(d_model, d_z, n_heads=4)
143
 
 
144
  self.text_pool = AttentionPooling(d_model)
145
  self.palette_pool = AttentionPooling(d_model)
146
 
147
def reparameterize(self, mu, logvar):
    """Draw a latent from ``N(mu, exp(logvar))`` in training mode.

    Outside training mode no noise is added and ``mu`` is returned as is.

    Args:
        mu: Mean of the Gaussian.
        logvar: Log-variance of the Gaussian, same shape as ``mu``.

    Returns:
        ``mu + eps * std`` with ``eps`` drawn from a standard normal while
        ``self.training`` is true, otherwise ``mu`` itself.
    """
    if not self.training:
        return mu
    # std = exp(logvar / 2); the noise tensor takes std's shape and dtype.
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
155
+
156
 
157
  @torch.no_grad()
158
+ def generate(self, text, palette_size, temp=1.0):
159
  self.eval()
160
+ tokenized = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True,
161
+ max_length=self.tokenizer_input_length).to(next(self.parameters()).device)
162
+
163
+ text_feats = self.clip_text(**tokenized).last_hidden_state
 
 
164
  text_feats = self.text_proj(text_feats)
165
+
 
166
  # Sample from prior
167
  prior_params = self.prior_net(text_feats, tokenized['attention_mask'])
168
  prior_mu, prior_logvar = prior_params.chunk(2, -1)
169
+ z = prior_mu + torch.exp(0.5 * prior_logvar) * \
170
+ torch.randn_like(prior_mu) * temp
171
  z_expanded = self.z_expand(z).unsqueeze(1)
172
 
173
+ memory = torch.cat([z_expanded, text_feats],
174
+ dim=1) # [1, T+1, d_model]
175
  memory_key_padding_mask = torch.cat([
176
  torch.zeros((1, 1), dtype=torch.bool, device=device),
177
  ~tokenized['attention_mask'].bool()
178
+ ], dim=1) # [1, T+1]
179
 
 
 
 
 
180
  colors = []
181
+ batch_size = 1
182
+ current_emb = self.start_embed.unsqueeze(0).expand(
183
+ batch_size, -1, -1) # [1, 1, d_model]
184
+
185
+ for i in range(min(palette_size, self.max_seq_len)):
186
+
187
+ pos = self.position_embed(torch.arange(0, current_emb.size(
188
+ 1), device=device)).unsqueeze(0) # [1, i+1, d_model]
189
  decoder_in = current_emb + pos # [1, i+1, d_model]
190
+
 
191
  output = self.decoder(
192
  decoder_in,
193
  memory,
194
+ tgt_mask=nn.Transformer.generate_square_subsequent_mask(
195
+ decoder_in.size(1), device=device),
196
  memory_key_padding_mask=memory_key_padding_mask
197
  ) # [1, i+1, d_model]
 
 
 
 
 
198
 
199
+ mu = torch.cat([self.out_mu_L(output[:, -1]),
200
+ self.out_mu_ab(output[:, -1])], dim=-1) # [1, 3]
201
+ logvar = self.out_logvar(output[:, -1]) # [1, 3]
202
+ color = mu + torch.exp(0.5 * logvar) * \
203
+ torch.randn_like(mu) * temp # [1, 3]
204
  color[:, 0].clamp_(0, 1)
205
  color[:, 1:].clamp_(-1, 1)
206
+
207
  colors.append(color)
208
+
 
209
  color_emb = self.color_embed(color.unsqueeze(1)) # [1, 1, d_model]
210
+ current_emb = torch.cat(
211
+ [current_emb, color_emb], dim=1) # [1, i+2, d_model]
212
 
 
 
213
  return torch.cat(colors, dim=0).unsqueeze(0)
214
 
215
 
216
+ model = Text2PaletteModel().to(device)
217
  state_dict = torch.load(model_path, map_location=torch.device(device))
218
  model.load_state_dict(state_dict['model'])
219
  model.to(device)
 
222
  import gradio as gr
223
 
224
 
225
+
226
def generate_palette(text, palette_size=5, temp=1.0):
    """Generate one palette for ``text`` and render it as an HTML swatch strip.

    Args:
        text: Prompt describing the palette.
        palette_size: Number of colours to request; converted to ``int``
            before being passed to the model.
        temp: Sampling temperature forwarded to ``model.generate``.

    Returns:
        An HTML string: a flex row with one coloured swatch per colour, each
        captioned with its upper-case ``#RRGGBB`` code.
    """
    # model.generate is already decorated with torch.no_grad(), so no extra
    # no_grad context is needed here.
    generated_palette = model.generate(
        text,
        palette_size=int(palette_size),
        temp=temp,
    )
    lab = generated_palette[0].cpu().numpy()
    hex_palette = [
        "#" + rgb_to_hex(normalized_lab_to_rgb(lab_color)).upper()
        for lab_color in lab
    ]

    last_index = len(hex_palette) - 1
    parts = ["<div style='display: flex; flex-direction: row;align-items: center; width:100%;'>"]
    for i, hex_color in enumerate(hex_palette):
        # Round only the outer corners of the strip; a lone swatch is both
        # the first and the last, so it gets all four corners rounded.
        if last_index == 0:
            radius = "1em"
        elif i == 0:
            radius = "1em 0 0 1em"
        elif i == last_index:
            radius = "0 1em 1em 0"
        else:
            radius = "0"
        parts.append(
            "<div style='margin:0;flex: 1; text-align: center;'>"
            f"<div style='background-color: {hex_color}; width: 100%; height: 100px;border-radius:{radius}'></div>"
            f"<p style='font-size: 14px; margin-top: 5px;'>{hex_color}</p></div>"
        )
    parts.append("</div>")

    return "".join(parts)
250
+
251
+
252
+
253
# Example prompts shown under the inputs: one inner list per UI row, and one
# (label, prompts) pair per column within that row.
EXAMPLE_ROWS = [
    [
        ("Food & Drinks", ["fries in ketchup", "blueberry milkshake", "Oreo McFlurry"]),
        ("Objects & Places", ["bonfire", "sheep on grass", "North Arctic"]),
    ],
    [
        ("Activities", ["rock climbing", "scuba-diving", "Halloween pumpkin party"]),
        ("Abstract", ["sweetheart", "sorrow", "murder"]),
    ],
]

with gr.Blocks() as demo:
    gr.Markdown("<h1>Palette Generator</h1>")

    # Named text_input rather than `input` so the builtin is not shadowed.
    text_input = gr.Textbox(label="Input text", placeholder="Describe the palette in your mind")

    with gr.Row():
        palette_size = gr.Slider(2, 10, value=5, step=1, label="Colors")
        temp = gr.Slider(0.0, 0.1, value=0.03, step=0.01, label="Temperature")

    for example_row in EXAMPLE_ROWS:
        with gr.Row():
            for example_label, example_prompts in example_row:
                with gr.Column():
                    gr.Examples(
                        # gr.Examples expects one list of input values per example.
                        examples=[[prompt] for prompt in example_prompts],
                        inputs=[text_input],
                        label=example_label,
                    )

    generate_button = gr.Button("🎨 Generate")
    # Placeholder of the swatch height, presumably to reserve space before
    # the first palette is rendered.
    output = gr.HTML("<div style=\"height: 100px\"></div>")

    generate_button.click(
        generate_palette,
        inputs=[text_input, palette_size, temp],
        outputs=output,
    )
295
+
296
+ def generate_palettes_api(text, palette_size=5, temp=1.0):
297
  temps=[temp]*4
298
  hex_palettes = []
299
 
 
311
  hex_palettes.append(hex_palette)
312
  return hex_palettes
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  demo.launch()