"""Gradio demo: generate a color palette from a text prompt.

A CLIP-text-conditioned transformer decoder (VAE-style, with a learned
sequence prior) autoregressively emits normalized Lab colors, which are
converted to RGB hex codes and rendered as HTML swatches.
"""

import json

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from skimage import color
from transformers import CLIPTextModel, CLIPTokenizer

# CLIP's text encoder accepts at most 77 tokens per prompt.
tokenizer_input_length = 77

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def rgb_to_hex(rgb_array):
    """Format an (r, g, b) uint8 triple as a lowercase hex string (no '#')."""
    return "{:02x}{:02x}{:02x}".format(*rgb_array)


def normalized_lab_to_rgb(lab_array):
    """Convert one normalized Lab color to a uint8 (r, g, b) tuple.

    The model emits L in [0, 1] and a/b in [-1, 1]; this scales back to
    standard Lab ranges (L: 0-100, a/b: -127..127) before converting.
    """
    lab_array = np.array(lab_array, dtype=np.float32)
    lab_array = lab_array.copy()
    # Undo the model's output normalization back to real Lab ranges.
    lab_array[0] *= 100.0
    lab_array[1] *= 127.0
    lab_array[2] *= 127.0
    if lab_array.ndim == 1:
        lab_array = lab_array.reshape(1, 3)
    rgb_array = color.lab2rgb(lab_array)
    rgb_array = (rgb_array * 255).astype(np.uint8)
    return tuple(rgb_array.squeeze())


# Pretrained checkpoint; downloaded (and cached) from the HF Hub.
model_path = hf_hub_download(
    repo_id="lasercatz/text2palette", filename="epoch_19.pth")


class AttentionPooling(nn.Module):
    """Pool a (batch, seq, d_model) sequence to (batch, d_model) via learned
    per-token attention scores."""

    def __init__(self, d_model):
        super().__init__()
        self.attn = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        scores = self.attn(x).squeeze(-1)  # (batch, seq)
        if mask is not None:
            # Padding positions (mask == 0) get -inf so softmax ignores them.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1).unsqueeze(-1)
        return torch.sum(x * weights, dim=1)


class SequencePriorNet(nn.Module):
    """Predict the prior distribution parameters (mu, logvar concatenated)
    of the latent z from the projected CLIP text features."""

    def __init__(self, d_model, d_z, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.pool = AttentionPooling(d_model)
        # Outputs mu and logvar stacked along the last dim -> 2 * d_z.
        self.fc = nn.Linear(d_model, d_z * 2)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, text_feats, attention_mask):
        # Self-attention over text tokens; padded positions are excluded.
        attn_output, _ = self.attn(
            text_feats, text_feats, text_feats,
            key_padding_mask=~attention_mask.bool())
        x = self.norm(attn_output + text_feats)  # residual + norm
        x = self.dropout(x)
        x = self.pool(x, attention_mask)
        x = self.fc(x)
        return x


class Text2PaletteModel(nn.Module):
    """CLIP-conditioned autoregressive palette decoder.

    Text is encoded with a frozen-architecture CLIP text model, projected to
    d_model, and used (together with a sampled latent z) as decoder memory.
    Colors are generated one at a time as normalized Lab triples.
    """

    def __init__(self, d_model=768, d_z=256, max_seq_len=64,
                 n_layers=8, n_heads=8, dim_ff=3072):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.clip_text = CLIPTextModel.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.tokenizer_input_length = tokenizer_input_length
        # CLIP base hidden size is 512; project up to the decoder width.
        self.text_proj = nn.Sequential(
            nn.Linear(512, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Dropout(0.3),
            nn.Linear(d_model * 2, d_model)
        )
        # Embed a normalized Lab color (3 values) into the decoder space.
        self.color_embed = nn.Sequential(
            nn.Linear(3, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        self.cross_attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
        self.position_embed = nn.Embedding(max_seq_len, d_model)
        # Learned <start-of-palette> token.
        self.start_embed = nn.Parameter(torch.randn(1, d_model))
        self.palette_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, n_heads, dim_ff, batch_first=True),
            n_layers
        )
        self.z_proj = nn.Sequential(
            nn.Linear(d_model * 2, d_z),
            nn.LayerNorm(d_z),
            nn.GELU()
        )
        self.z_expand = nn.Linear(d_z, d_model)
        self.z_mu = nn.Linear(d_z, d_z)
        self.z_logvar = nn.Linear(d_z, d_z)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, n_heads, dim_ff, batch_first=True),
            n_layers
        )
        # Per-color output heads: L in [0, 1] via sigmoid, a/b in [-1, 1]
        # via tanh, plus a diagonal-Gaussian log-variance for sampling.
        self.out_mu_L = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())
        self.out_mu_ab = nn.Sequential(nn.Linear(d_model, 2), nn.Tanh())
        self.out_logvar = nn.Linear(d_model, 3)
        self.prior_net = SequencePriorNet(d_model, d_z, n_heads=4)
        self.text_pool = AttentionPooling(d_model)
        self.palette_pool = AttentionPooling(d_model)

    def reparameterize(self, mu, logvar):
        """VAE reparameterization trick; deterministic (mu) at eval time."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    @torch.no_grad()
    def generate(self, text, palette_size, temp=1.0):
        """Autoregressively sample `palette_size` normalized Lab colors for
        `text`. Returns a (1, palette_size, 3) tensor. `temp` scales the
        sampling noise (0 -> deterministic means)."""
        self.eval()
        # Use the model's own device throughout so generation works no
        # matter where the module was moved (previously mixed with the
        # global `device`).
        dev = next(self.parameters()).device
        tokenized = self.tokenizer(
            text, return_tensors='pt', padding=True, truncation=True,
            max_length=self.tokenizer_input_length).to(dev)
        text_feats = self.clip_text(**tokenized).last_hidden_state
        text_feats = self.text_proj(text_feats)

        # Sample latent z from the text-conditioned prior.
        prior_params = self.prior_net(text_feats, tokenized['attention_mask'])
        prior_mu, prior_logvar = prior_params.chunk(2, -1)
        z = prior_mu + torch.exp(0.5 * prior_logvar) * \
            torch.randn_like(prior_mu) * temp

        z_expanded = self.z_expand(z).unsqueeze(1)
        memory = torch.cat([z_expanded, text_feats], dim=1)  # [1, T+1, d_model]
        # z slot is never masked; text padding is.
        memory_key_padding_mask = torch.cat([
            torch.zeros((1, 1), dtype=torch.bool, device=dev),
            ~tokenized['attention_mask'].bool()
        ], dim=1)  # [1, T+1]

        colors = []
        batch_size = 1
        current_emb = self.start_embed.unsqueeze(0).expand(
            batch_size, -1, -1)  # [1, 1, d_model]
        for _ in range(min(palette_size, self.max_seq_len)):
            pos = self.position_embed(torch.arange(
                0, current_emb.size(1), device=dev)).unsqueeze(0)
            decoder_in = current_emb + pos  # [1, i+1, d_model]
            output = self.decoder(
                decoder_in, memory,
                tgt_mask=nn.Transformer.generate_square_subsequent_mask(
                    decoder_in.size(1), device=dev),
                memory_key_padding_mask=memory_key_padding_mask
            )  # [1, i+1, d_model]
            # Heads read only the newest position.
            mu = torch.cat([self.out_mu_L(output[:, -1]),
                            self.out_mu_ab(output[:, -1])], dim=-1)  # [1, 3]
            logvar = self.out_logvar(output[:, -1])  # [1, 3]
            col = mu + torch.exp(0.5 * logvar) * \
                torch.randn_like(mu) * temp  # [1, 3]
            # Clamp back into the normalized Lab box after noise.
            col[:, 0].clamp_(0, 1)
            col[:, 1:].clamp_(-1, 1)
            colors.append(col)
            # Feed the sampled color back in as the next decoder input.
            color_emb = self.color_embed(col.unsqueeze(1))  # [1, 1, d_model]
            current_emb = torch.cat([current_emb, color_emb], dim=1)

        return torch.cat(colors, dim=0).unsqueeze(0)


model = Text2PaletteModel().to(device)
state_dict = torch.load(model_path, map_location=torch.device(device))
model.load_state_dict(state_dict['model'])
model.to(device)
model.eval()


def generate(text, palette_size=5, temp=0.5):
    """Gradio callback: sample one palette and render it as HTML swatches
    followed by a JSON dump of the hex codes."""
    all_hex_palettes = []
    with torch.no_grad():
        generated_palette = model.generate(
            text, palette_size=int(palette_size), temp=temp)
    lab = generated_palette[0].cpu().numpy()
    hex_palette = [rgb_to_hex(normalized_lab_to_rgb(lab_color))
                   for lab_color in lab]
    all_hex_palettes.append(hex_palette)

    # NOTE(review): the original markup was lost when this file was
    # mangled (the HTML inside the string literals was stripped); this is
    # a faithful-in-spirit reconstruction: one row of equal-width swatches
    # labeled with their hex code.
    html = "<div style='display:flex;width:100%;'>"
    hex_codes = []
    for hex_color in hex_palette:
        hex_color = "#" + hex_color.upper()
        hex_codes.append(hex_color)
        html += (
            f"<div style='background-color:{hex_color};flex:1;height:90px;"
            f"display:flex;align-items:center;justify-content:center;"
            f"font-family:monospace;'>{hex_color}</div>"
        )
    html += "</div>"
    json_output = json.dumps({"palettes": all_hex_palettes}, indent=2)
    html += f"<pre>{json_output}</pre>"
    return html


with gr.Blocks() as demo:
    # NOTE(review): original Markdown markup was stripped; only the title
    # text "Palette Generator" survived — reconstructed as a heading.
    gr.Markdown("<h1 style='text-align:center'>Palette Generator</h1>")
    # Renamed from `input` to avoid shadowing the builtin.
    prompt_box = gr.Textbox(
        label="Input text",
        placeholder="Describe the palette in your mind")
    with gr.Row():
        palette_size = gr.Slider(2, 10, value=5, step=1, label="Colors")
        temp = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Temperature")
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[["fries in ketchup"],
                          ["blueberry milkshake"],
                          ["Oreo McFlurry"]],
                inputs=[prompt_box],
                label="Food & Drinks"
            )
        with gr.Column():
            gr.Examples(
                examples=[["bonfire"],
                          ["sheep on grass"],
                          ["North Arctic"]],
                inputs=[prompt_box],
                label="Objects & Places"
            )
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[["rock climbing"],
                          ["scuba-diving"],
                          ["Halloween pumpkin party"]],
                inputs=[prompt_box],
                label="Activities"
            )
        with gr.Column():
            gr.Examples(
                examples=[["sweetheart"], ["sorrow"], ["murder"]],
                inputs=[prompt_box],
                label="Abstract"
            )
    generate_button = gr.Button("🎨 Generate")
    output = gr.HTML("")
    generate_button.click(
        generate,
        inputs=[prompt_box, palette_size, temp],
        outputs=output
    )

demo.launch()