Spaces:
Sleeping
Sleeping
from skimage import color
import numpy as np
import json
# Max token length handed to the CLIP tokenizer as `max_length` in
# Text2PaletteModel.generate (copied into self.tokenizer_input_length);
# presumably chosen to match CLIP's 77-token text context.
tokenizer_input_length = 77
import torch
# Module-wide device string: 'cuda' when a GPU is visible, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def rgb_to_hex(rgb_array):
    """Render the first three channel values as a 6-digit lowercase hex string.

    Each channel is written as two zero-padded hex digits and no leading '#'
    is added, e.g. (255, 0, 16) -> "ff0010".
    """
    per_channel = "{:02x}"
    return (per_channel * 3).format(*rgb_array)
def normalized_lab_to_rgb(lab_array):
    """Convert normalized Lab color(s) to 8-bit RGB.

    The input uses the normalization of the model's output heads: channel 0
    (L) in [0, 1] and channels 1-2 (a, b) in [-1, 1]. They are rescaled to
    L * 100 and a, b * 127 before calling ``skimage.color.lab2rgb``.

    Args:
        lab_array: array-like of shape (3,) for one color, or (N, 3) for a
            batch of colors (channels on the last axis).

    Returns:
        For a single color, a tuple of three ``np.uint8`` values (R, G, B).
        For a batch, a tuple of N uint8 arrays of shape (3,).
    """
    # np.array copies by default, so the caller's data is never modified.
    lab = np.array(lab_array, dtype=np.float32)
    # Scale along the channel (last) axis. Indexing lab[0] / lab[1] / lab[2]
    # would scale rows, not channels, for an (N, 3) batch.
    lab[..., 0] *= 100.0
    lab[..., 1:3] *= 127.0
    if lab.ndim == 1:
        lab = lab.reshape(1, 3)
    rgb = color.lab2rgb(lab)
    # Round to nearest instead of truncating (e.g. 0.999 -> 255, not 254).
    # The clip guarantees the uint8 cast cannot wrap around.
    rgb = np.rint(np.clip(rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
    return tuple(rgb.squeeze())
from huggingface_hub import hf_hub_download
# Fetch the trained checkpoint from the Hugging Face Hub; the call returns a
# local file path that torch.load reads further down.
model_path = hf_hub_download(
    filename="epoch_19.pth",
    repo_id="lasercatz/text2palette",
)
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
class AttentionPooling(nn.Module):
    """Pool a sequence into a single vector using learned attention weights.

    A linear layer scores every position. The scores are softmax-normalized
    over the sequence axis and used to take a weighted sum of the inputs.

    Args:
        d_model: size of the last (feature) dimension of the input.
    """

    def __init__(self, d_model):
        super().__init__()
        self.attn = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over dim 1.

        Args:
            x: tensor laid out as [batch, seq, d_model] (the score layer maps
                the last dim and the sum runs over dim 1).
            mask: optional [batch, seq] tensor; positions equal to 0 get zero
                weight.

        Returns:
            Tensor of shape [batch, d_model].
        """
        scores = self.attn(x).squeeze(-1)  # [batch, seq]
        if mask is not None:
            # Use the most negative finite value rather than -inf. Masked
            # entries still underflow to exactly 0 after softmax, so valid
            # rows are unchanged. A row whose every position is masked now
            # gets uniform weights instead of NaN (softmax of all -inf is 0/0).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1).unsqueeze(-1)
        return torch.sum(x * weights, dim=1)
class SequencePriorNet(nn.Module):
    """Map a padded text-feature sequence to a vector of width ``2 * d_z``.

    Pipeline: self-attention over the sequence (padding keys ignored), a
    residual connection plus LayerNorm, dropout, attention pooling to a single
    vector, and a linear head. The caller splits the result in half
    (``chunk(2, -1)``) into prior mean and log-variance.

    Args:
        d_model: feature size of ``text_feats``.
        d_z: latent size; the output width is ``2 * d_z``.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, d_z, n_heads=4):
        super().__init__()
        # Submodules are created in this exact order so that seeded random
        # initialization matches; attribute names are the state_dict keys.
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.pool = AttentionPooling(d_model)
        self.fc = nn.Linear(d_model, d_z * 2)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, text_feats, attention_mask):
        # MultiheadAttention expects True where a key must be ignored, which
        # is the inverse of the tokenizer-style attention mask (1 = real token).
        ignore_keys = ~attention_mask.bool()
        attended = self.attn(
            text_feats, text_feats, text_feats, key_padding_mask=ignore_keys
        )[0]
        hidden = self.dropout(self.norm(attended + text_feats))
        pooled = self.pool(hidden, attention_mask)
        return self.fc(pooled)
class Text2PaletteModel(nn.Module):
    """Text-conditioned palette generator.

    CLIP text features are projected to ``d_model``. A sequence prior network
    turns them into a latent ``z``, and a Transformer decoder emits colors one
    at a time, attending to ``[z; text features]``. Colors live in a
    normalized Lab space: channel 0 comes from a sigmoid head (range [0, 1])
    and channels 1-2 from a tanh head (range [-1, 1]).

    Args:
        d_model: hidden size of the projection, decoder and encoder stacks.
        d_z: latent size.
        max_seq_len: size of the position-embedding table; also caps the
            number of colors ``generate`` produces.
        n_layers, n_heads, dim_ff: Transformer encoder/decoder settings.
    """

    def __init__(self, d_model=768, d_z=256, max_seq_len=64,
                 n_layers=8, n_heads=8, dim_ff=3072):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.clip_text = CLIPTextModel.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.tokenizer_input_length = tokenizer_input_length
        self.text_proj = nn.Sequential(
            nn.Linear(512, d_model*2),
            nn.GELU(),
            nn.LayerNorm(d_model*2),
            nn.Dropout(0.3),
            nn.Linear(d_model*2, d_model)
        )
        self.color_embed = nn.Sequential(
            nn.Linear(3, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        # NOTE(review): cross_attn, palette_encoder, z_proj, z_mu, z_logvar,
        # text_pool and palette_pool are not used by `generate`. Presumably a
        # training pass not shown here uses them; they must stay registered
        # so the checkpoint's keys load.
        self.cross_attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
        self.position_embed = nn.Embedding(max_seq_len, d_model)
        self.start_embed = nn.Parameter(torch.randn(1, d_model))
        self.palette_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, n_heads, dim_ff, batch_first=True),
            n_layers
        )
        self.z_proj = nn.Sequential(
            nn.Linear(d_model*2, d_z),
            nn.LayerNorm(d_z),
            nn.GELU()
        )
        self.z_expand = nn.Linear(d_z, d_model)
        self.z_mu = nn.Linear(d_z, d_z)
        self.z_logvar = nn.Linear(d_z, d_z)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, n_heads, dim_ff, batch_first=True),
            n_layers
        )
        self.out_mu_L = nn.Sequential(
            nn.Linear(d_model, 1),
            nn.Sigmoid()
        )
        self.out_mu_ab = nn.Sequential(
            nn.Linear(d_model, 2),
            nn.Tanh()
        )
        self.out_logvar = nn.Linear(d_model, 3)
        self.prior_net = SequencePriorNet(d_model, d_z, n_heads=4)
        self.text_pool = AttentionPooling(d_model)
        self.palette_pool = AttentionPooling(d_model)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def generate(self, text, palette_size, temp=1.0):
        """Autoregressively sample palette colors for ``text``.

        Puts the module in eval mode as a side effect.

        Args:
            text: a prompt string, or a list of prompt strings (one palette each).
            palette_size: requested number of colors; capped at ``max_seq_len``.
                Values <= 0 yield an empty palette.
            temp: scales the noise for both the latent prior sample and each
                color sample (0 gives the means).

        Returns:
            Tensor of shape [batch, n_colors, 3] in normalized Lab. Batch is 1
            for a single string. Channel 0 is clamped to [0, 1] and channels
            1-2 to [-1, 1].
        """
        self.eval()
        # Build every tensor on the model's own device rather than the
        # module-level `device` global, so generation still works after the
        # model is moved elsewhere.
        dev = next(self.parameters()).device
        tokenized = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True,
                                   max_length=self.tokenizer_input_length).to(dev)
        attention_mask = tokenized['attention_mask']
        batch_size = attention_mask.size(0)

        text_feats = self.text_proj(self.clip_text(**tokenized).last_hidden_state)

        # Sample the latent from the text-conditioned prior.
        prior_mu, prior_logvar = self.prior_net(text_feats, attention_mask).chunk(2, -1)
        z = prior_mu + torch.exp(0.5 * prior_logvar) * \
            torch.randn_like(prior_mu) * temp

        # Decoder memory: the latent token followed by the text tokens.
        memory = torch.cat([self.z_expand(z).unsqueeze(1), text_feats],
                           dim=1)  # [B, T+1, d_model]
        # The latent slot is never padding; text padding comes from the tokenizer.
        memory_key_padding_mask = torch.cat([
            torch.zeros((batch_size, 1), dtype=torch.bool, device=dev),
            ~attention_mask.bool()
        ], dim=1)  # [B, T+1]

        n_colors = min(palette_size, self.max_seq_len)
        if n_colors <= 0:
            return memory.new_zeros((batch_size, 0, 3))

        current_emb = self.start_embed.unsqueeze(0).expand(
            batch_size, -1, -1)  # [B, 1, d_model]
        colors = []
        for _ in range(n_colors):
            seq_len = current_emb.size(1)
            pos = self.position_embed(
                torch.arange(seq_len, device=dev)).unsqueeze(0)  # [1, seq_len, d_model]
            output = self.decoder(
                current_emb + pos,
                memory,
                tgt_mask=nn.Transformer.generate_square_subsequent_mask(
                    seq_len, device=dev),
                memory_key_padding_mask=memory_key_padding_mask
            )  # [B, seq_len, d_model]
            last = output[:, -1]
            mu = torch.cat([self.out_mu_L(last), self.out_mu_ab(last)], dim=-1)  # [B, 3]
            logvar = self.out_logvar(last)  # [B, 3]
            color = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu) * temp
            # Keep each channel inside its head's output range.
            color = torch.cat([color[:, :1].clamp(0, 1), color[:, 1:].clamp(-1, 1)], dim=-1)
            colors.append(color)
            # Feed the sampled color back in as the next decoder input token.
            color_emb = self.color_embed(color.unsqueeze(1))  # [B, 1, d_model]
            current_emb = torch.cat([current_emb, color_emb], dim=1)
        return torch.stack(colors, dim=1)  # [B, n_colors, 3]
# Build the model on the selected device and load the trained weights.
model = Text2PaletteModel().to(device)
# NOTE(review): torch.load unpickles the checkpoint downloaded from the Hub
# repo above. Only load checkpoints from a trusted source, and consider
# weights_only=True if the checkpoint format allows it.
state_dict = torch.load(model_path, map_location=torch.device(device))
# load_state_dict copies values into parameters that already live on
# `device`, so a second model.to(device) is unnecessary.
model.load_state_dict(state_dict['model'])
model.eval()
| import gradio as gr | |
def generate(text, palette_size=5, temp=0.5):
    """Generate one palette for ``text`` and render it as HTML for the UI.

    Args:
        text: prompt describing the palette.
        palette_size: number of colors to request (cast to int before calling
            the model, which caps it at its max_seq_len).
        temp: sampling temperature forwarded to ``model.generate``.

    Returns:
        An HTML string with two parts: a row of color swatches, each with its
        uppercase ``#RRGGBB`` label, followed by the palette as pretty-printed
        JSON ``{"palettes": [[hex, ...]]}`` (lowercase hex, no '#').
    """
    with torch.no_grad():
        generated_palette = model.generate(
            text,
            palette_size=int(palette_size),
            temp=temp,
        )
    # First (and only) palette in the batch: one normalized-Lab row per color.
    lab = generated_palette[0].cpu().numpy()
    hex_palette = [rgb_to_hex(normalized_lab_to_rgb(lab_color)) for lab_color in lab]
    all_hex_palettes = [hex_palette]

    last_index = len(hex_palette) - 1
    parts = ["<div style='display: flex; flex-direction: row;align-items: center; width:100%;'>"]
    for i, hex_value in enumerate(hex_palette):
        hex_color = "#" + hex_value.upper()
        # Round only the outer ends of the strip; a lone swatch gets both ends.
        if last_index == 0:
            radius = "1em"
        elif i == 0:
            radius = "1em 0 0 1em"
        elif i == last_index:
            radius = "0 1em 1em 0"
        else:
            radius = "0"
        parts.append(
            "<div style='margin:0;flex: 1; text-align: center;'>"
            f"<div style='background-color: {hex_color}; width: 100%; "
            f"height: 100px;border-radius:{radius}'></div>"
            f"<p style='font-size: 14px; margin-top: 5px;'>{hex_color}</p>"
            "</div>"
        )
    parts.append("</div>")
    json_output = json.dumps({"palettes": all_hex_palettes}, indent=2)
    # Wrap in <pre> so the indent=2 layout is not collapsed by HTML
    # whitespace rules. The content is only model-generated hex strings, so
    # no escaping is needed.
    parts.append(f"<pre>{json_output}</pre>")
    return "".join(parts)
with gr.Blocks() as demo:
    gr.Markdown("<h1>Palette Generator</h1>")
    # Named `text_input` rather than `input` so the builtin is not shadowed.
    text_input = gr.Textbox(label="Input text", placeholder="Describe the palette in your mind")
    with gr.Row():
        palette_size = gr.Slider(2, 10, value=5, step=1, label="Colors")
        temp = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Temperature")
    # Example prompts laid out as two rows of two labelled columns.
    example_rows = [
        [
            ("Food & Drinks", ["fries in ketchup", "blueberry milkshake", "Oreo McFlurry"]),
            ("Objects & Places", ["bonfire", "sheep on grass", "North Arctic"]),
        ],
        [
            ("Activities", ["rock climbing", "scuba-diving", "Halloween pumpkin party"]),
            ("Abstract", ["sweetheart", "sorrow", "murder"]),
        ],
    ]
    for example_row in example_rows:
        with gr.Row():
            for example_label, prompts in example_row:
                with gr.Column():
                    gr.Examples(
                        examples=[[prompt] for prompt in prompts],
                        inputs=[text_input],
                        label=example_label,
                    )
    generate_button = gr.Button("🎨 Generate")
    output = gr.HTML("<div style=\"height: 100px\"></div>")
    generate_button.click(
        generate,
        inputs=[text_input, palette_size, temp],
        outputs=output,
    )
demo.launch()