# Lasercatz
# Upload app.py
# af8382f verified
from skimage import color
import numpy as np
import json
# Maximum number of CLIP tokens kept per prompt; Text2PaletteModel.generate passes
# it as the tokenizer's max_length with truncation enabled (77 matches CLIP's context size).
tokenizer_input_length = 77
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def rgb_to_hex(rgb_array):
    """Format an (R, G, B) sequence of 0-255 integers as a 6-digit lowercase hex string.

    No leading '#' is added. Each channel is zero-padded to two hex digits.
    """
    hex_template = "{:02x}{:02x}{:02x}"
    channels = tuple(rgb_array)
    return hex_template.format(*channels)
def normalized_lab_to_rgb(lab_array):
    """Convert a normalized CIELAB colour (or batch of colours) to 8-bit RGB.

    The input is expected in the normalized range produced by the model:
    L in [0, 1] and a, b in [-1, 1]. These are rescaled to standard Lab units
    (L * 100, a/b * 127) before calling ``skimage.color.lab2rgb``.

    Args:
        lab_array: array-like of shape (3,) for one colour, or (N, 3) for a batch.

    Returns:
        For a single colour, a tuple of three ``np.uint8`` values (R, G, B).
        For a batch, ``tuple(rgb.squeeze())`` of the (N, 3) uint8 array.
    """
    # np.array copies by default, so the caller's data is never modified in place.
    lab = np.array(lab_array, dtype=np.float32)
    # Scale along the channel (last) axis. Indexing lab[0] / lab[1] / lab[2] would
    # pick rows, not channels, for an (N, 3) batch.
    lab[..., 0] *= 100.0
    lab[..., 1:] *= 127.0
    if lab.ndim == 1:
        lab = lab.reshape(1, 3)
    rgb = color.lab2rgb(lab)
    # Round instead of truncating, so values just below 1.0 map to 255 rather
    # than 254. Clip defensively before the uint8 cast.
    rgb = np.clip(np.rint(rgb * 255.0), 0, 255).astype(np.uint8)
    return tuple(rgb.squeeze())
from huggingface_hub import hf_hub_download
# Download the trained checkpoint file from the Hugging Face Hub; the returned
# local path is what gets handed to torch.load.
model_path = hf_hub_download(
    filename="epoch_19.pth",
    repo_id="lasercatz/text2palette",
)
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer
class AttentionPooling(nn.Module):
    """Reduce a (batch, seq, d_model) sequence to (batch, d_model) by learned attention.

    Each position gets a scalar score from a linear layer. The scores are
    softmax-normalized over the sequence, and the output is the weighted sum
    of the positions.
    """

    def __init__(self, d_model):
        super().__init__()
        # The attribute must stay named ``attn``: checkpoint keys depend on it.
        self.attn = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        """Pool ``x`` over dim 1.

        Args:
            x: tensor indexed as (batch, seq, d_model).
            mask: optional (batch, seq) tensor; positions where it equals 0 get
                zero weight. NOTE(review): a row that is entirely masked yields
                NaN (softmax over all -inf).
        """
        logits = self.attn(x).squeeze(-1)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(logits, dim=-1)
        return (x * weights.unsqueeze(-1)).sum(dim=1)
class SequencePriorNet(nn.Module):
    """Predict prior parameters for the latent ``z`` from a text feature sequence.

    The pipeline is self-attention over the text tokens, then a residual
    LayerNorm, dropout, attention pooling, and finally a linear layer producing
    ``2 * d_z`` values. Callers split these in half into (mu, logvar).
    """

    def __init__(self, d_model, d_z, n_heads=4):
        super().__init__()
        # Keep these attribute names and creation order: checkpoint keys and
        # parameter-initialisation RNG order depend on them.
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.pool = AttentionPooling(d_model)
        self.fc = nn.Linear(d_model, d_z * 2)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, text_feats, attention_mask):
        """Return a (batch, 2 * d_z) tensor of concatenated prior (mu, logvar).

        Args:
            text_feats: (batch, seq, d_model) features (batch_first attention).
            attention_mask: (batch, seq) mask, nonzero for real tokens; zero
                entries are ignored both as attention keys and in pooling.
        """
        # MultiheadAttention expects True where a key should be ignored.
        padding = attention_mask.bool().logical_not()
        context, _ = self.attn(
            text_feats, text_feats, text_feats, key_padding_mask=padding)
        hidden = self.dropout(self.norm(text_feats + context))
        pooled = self.pool(hidden, attention_mask)
        return self.fc(pooled)
class Text2PaletteModel(nn.Module):
    """Text-conditioned autoregressive colour-palette generator.

    A prompt is encoded with a CLIP text model and projected to ``d_model``. A
    latent ``z`` is sampled from a text-conditioned prior (``prior_net``).
    Colours are then decoded one at a time by a Transformer decoder that attends
    to [z; text features].

    Each colour has three channels: channel 0 comes from a sigmoid head and
    channels 1-2 from a tanh head. These are presumably normalized CIELAB
    (L in [0, 1], a/b in [-1, 1]).

    Submodule attribute names must not change: the checkpoint is loaded with a
    strict ``load_state_dict`` and its keys depend on them.
    """

    def __init__(self, d_model=768, d_z=256, max_seq_len=64,
                 n_layers=8, n_heads=8, dim_ff=3072):
        super().__init__()
        self.d_model = d_model
        # Upper bound on the number of decoded positions (size of position_embed).
        self.max_seq_len = max_seq_len
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.clip_text = CLIPTextModel.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.tokenizer_input_length = tokenizer_input_length
        # Projects CLIP's last_hidden_state (width 512 for this checkpoint) to d_model.
        self.text_proj = nn.Sequential(
            nn.Linear(512, d_model*2),
            nn.GELU(),
            nn.LayerNorm(d_model*2),
            nn.Dropout(0.3),
            nn.Linear(d_model*2, d_model)
        )
        # Embeds a 3-channel colour so it can be fed back into the decoder.
        self.color_embed = nn.Sequential(
            nn.Linear(3, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        # The following modules are not used by generate(). They are kept so
        # the checkpoint's state_dict keys still match.
        self.cross_attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
        self.position_embed = nn.Embedding(max_seq_len, d_model)
        # Learned "beginning of palette" token for the decoder.
        self.start_embed = nn.Parameter(torch.randn(1, d_model))
        self.palette_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, n_heads, dim_ff, batch_first=True),
            n_layers
        )
        self.z_proj = nn.Sequential(
            nn.Linear(d_model*2, d_z),
            nn.LayerNorm(d_z),
            nn.GELU()
        )
        self.z_expand = nn.Linear(d_z, d_model)
        self.z_mu = nn.Linear(d_z, d_z)
        self.z_logvar = nn.Linear(d_z, d_z)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, n_heads, dim_ff, batch_first=True),
            n_layers
        )
        # Output heads: channel 0 squashed to [0, 1], channels 1-2 to [-1, 1].
        self.out_mu_L = nn.Sequential(
            nn.Linear(d_model, 1),
            nn.Sigmoid()
        )
        self.out_mu_ab = nn.Sequential(
            nn.Linear(d_model, 2),
            nn.Tanh()
        )
        self.out_logvar = nn.Linear(d_model, 3)
        self.prior_net = SequencePriorNet(d_model, d_z, n_heads=4)
        self.text_pool = AttentionPooling(d_model)
        self.palette_pool = AttentionPooling(d_model)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    @torch.no_grad()
    def generate(self, text, palette_size, temp=1.0):
        """Sample a palette for each prompt in ``text``.

        Switches the model to eval mode as a side effect.

        Args:
            text: one prompt string, or a list of prompt strings (one palette each).
            palette_size: number of colours to decode; capped at ``max_seq_len``.
                Values <= 0 produce an empty palette.
            temp: multiplies the Gaussian noise of both the latent sample and each
                per-colour sample; ``0`` returns the predicted means.

        Returns:
            Tensor of shape (batch, n_colors, 3) on the model's device, where
            batch == 1 for a single string. Channel 0 is clamped to [0, 1] and
            channels 1-2 to [-1, 1].
        """
        self.eval()
        # Use the device the weights actually live on, not the module-level
        # ``device`` global, so generation still works after model.to(...).
        dev = next(self.parameters()).device
        tokenized = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True,
                                   max_length=self.tokenizer_input_length).to(dev)
        text_feats = self.clip_text(**tokenized).last_hidden_state
        text_feats = self.text_proj(text_feats)
        batch_size = text_feats.size(0)
        # Sample from prior
        prior_params = self.prior_net(text_feats, tokenized['attention_mask'])
        prior_mu, prior_logvar = prior_params.chunk(2, -1)
        z = prior_mu + torch.exp(0.5 * prior_logvar) * \
            torch.randn_like(prior_mu) * temp
        z_expanded = self.z_expand(z).unsqueeze(1)  # [B, 1, d_model]
        memory = torch.cat([z_expanded, text_feats], dim=1)  # [B, T+1, d_model]
        # The latent slot is never padding; text positions are padding where
        # the tokenizer's attention_mask is 0.
        memory_key_padding_mask = torch.cat([
            torch.zeros((batch_size, 1), dtype=torch.bool, device=dev),
            ~tokenized['attention_mask'].bool()
        ], dim=1)  # [B, T+1]

        n_colors = min(palette_size, self.max_seq_len)
        colors = []
        current_emb = self.start_embed.unsqueeze(0).expand(
            batch_size, -1, -1)  # [B, 1, d_model]
        for _ in range(n_colors):
            seq_len = current_emb.size(1)
            pos = self.position_embed(
                torch.arange(0, seq_len, device=dev)).unsqueeze(0)  # [1, i+1, d_model]
            decoder_in = current_emb + pos  # [B, i+1, d_model]
            output = self.decoder(
                decoder_in,
                memory,
                tgt_mask=nn.Transformer.generate_square_subsequent_mask(
                    seq_len, device=dev),
                memory_key_padding_mask=memory_key_padding_mask
            )  # [B, i+1, d_model]
            last = output[:, -1]
            mu = torch.cat([self.out_mu_L(last),
                            self.out_mu_ab(last)], dim=-1)  # [B, 3]
            logvar = self.out_logvar(last)  # [B, 3]
            color = mu + torch.exp(0.5 * logvar) * \
                torch.randn_like(mu) * temp  # [B, 3]
            # Noise can push samples outside the heads' ranges; clamp back in place.
            color[:, 0].clamp_(0, 1)
            color[:, 1:].clamp_(-1, 1)
            colors.append(color)
            # Feed the sampled colour back as the next decoder input token.
            color_emb = self.color_embed(color.unsqueeze(1))  # [B, 1, d_model]
            current_emb = torch.cat([current_emb, color_emb], dim=1)  # [B, i+2, d_model]

        if not colors:
            # torch.stack([]) would raise; return an empty palette instead.
            return text_feats.new_zeros((batch_size, 0, 3))
        # Stack along a new colour axis: [B, n_colors, 3] (identical to the
        # previous cat(...).unsqueeze(0) result when B == 1).
        return torch.stack(colors, dim=1)
# Build the network directly on the target device, then restore the trained weights.
model = Text2PaletteModel().to(device)
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint
# only holds tensors, consider passing weights_only=True to avoid executing
# arbitrary pickled code from a remote source.
checkpoint = torch.load(model_path, map_location=device)
# load_state_dict copies values into the existing (already on-device)
# parameters, so no second .to(device) is needed afterwards.
model.load_state_dict(checkpoint['model'])
model.eval()
import gradio as gr
def generate(text, palette_size=5, temp=0.5):
    """Gradio callback: generate one palette for ``text`` and render it as HTML.

    Args:
        text: prompt describing the palette.
        palette_size: number of colours (cast to int; Gradio sliders may send floats).
        temp: sampling temperature forwarded to ``model.generate``.

    Returns:
        An HTML string containing a flex row of colour swatches, each labelled
        with its uppercase '#RRGGBB' code. The row is followed by a JSON dump
        ``{"palettes": [[...lowercase hex codes without '#']]}``.
    """
    with torch.no_grad():
        generated_palette = model.generate(
            text,
            palette_size=int(palette_size),
            temp=temp
        )
    lab = generated_palette[0].cpu().numpy()
    hex_palette = [rgb_to_hex(normalized_lab_to_rgb(lab_color)) for lab_color in lab]
    all_hex_palettes = [hex_palette]

    last_index = len(hex_palette) - 1
    swatches = []
    for i, hex_code in enumerate(hex_palette):
        hex_color = "#" + hex_code.upper()
        # Round only the outer corners of the strip. A lone swatch is both
        # first and last, so it gets all four corners rounded.
        if i == 0 and i == last_index:
            radius = "1em"
        elif i == 0:
            radius = "1em 0 0 1em"
        elif i == last_index:
            radius = "0 1em 1em 0"
        else:
            radius = "0"
        swatches.append(
            "<div style='margin:0;flex: 1; text-align: center;'>"
            f"<div style='background-color: {hex_color}; width: 100%; height: 100px;border-radius:{radius}'></div>"
            f"<p style='font-size: 14px; margin-top: 5px;'>{hex_color}</p>"
            "</div>"
        )
    html = (
        "<div style='display: flex; flex-direction: row;align-items: center; width:100%;'>"
        + "".join(swatches)
        + "</div>"
    )
    json_output = json.dumps({"palettes": all_hex_palettes}, indent=2)
    return html + json_output
# Example prompts shown under the input box: each inner tuple is one gr.Row,
# and each (label, prompts) pair is one gr.Column with a gr.Examples widget.
_EXAMPLE_ROWS = (
    (
        ("Food & Drinks", ["fries in ketchup", "blueberry milkshake", "Oreo McFlurry"]),
        ("Objects & Places", ["bonfire", "sheep on grass", "North Arctic"]),
    ),
    (
        ("Activities", ["rock climbing", "scuba-diving", "Halloween pumpkin party"]),
        ("Abstract", ["sweetheart", "sorrow", "murder"]),
    ),
)

with gr.Blocks() as demo:
    gr.Markdown("<h1>Palette Generator</h1>")
    # Named text_input rather than ``input`` so the builtin input() is not
    # shadowed at module level.
    text_input = gr.Textbox(label="Input text", placeholder="Describe the palette in your mind")
    with gr.Row():
        palette_size = gr.Slider(2, 10, value=5, step=1, label="Colors")
        temp = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Temperature")
    for example_row in _EXAMPLE_ROWS:
        with gr.Row():
            for example_label, prompts in example_row:
                with gr.Column():
                    gr.Examples(
                        examples=[[prompt] for prompt in prompts],
                        inputs=[text_input],
                        label=example_label
                    )
    generate_button = gr.Button("🎨 Generate")
    # Placeholder keeps the layout height stable before the first generation.
    output = gr.HTML("<div style=\"height: 100px\"></div>")
    generate_button.click(
        generate,
        inputs=[text_input, palette_size, temp],
        outputs=output
    )
demo.launch()