# ColorNamer / app.py
# DrHouseFan-315 — "Update app.py" (commit 8dfe603, verified)
import torch # the library
import torch.nn as nn # so we can just write nn.X insted of torch.nn.X
import gradio as gr
from PIL import ImageColor
def hex_to_rgb(value):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    An optional leading ``#`` is accepted, followed by one of the forms
    ``rgb``, ``rgba``, ``rrggbb`` or ``rrggbbaa``. Any alpha component is
    discarded so the result always has exactly three channels.

    Raises:
        ValueError: if the string has an unsupported length or contains
            non-hexadecimal characters.
    """
    digits = value.lstrip('#')
    if len(digits) in (3, 4):
        # Shorthand form: each digit stands for a doubled pair ("f" -> "ff").
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Unsupported hex color: {value!r}")
    # int(..., 16) would also accept signs/whitespace, so validate explicitly.
    if not all(ch in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"Unsupported hex color: {value!r}")
    # Only the first three byte pairs (r, g, b) are used; alpha is ignored.
    return tuple(int(digits[i:i + 2], 16) for i in range(0, 6, 2))
def parse_color_str(clrstr):
    """Convert a colour string from the colour picker into an ``(r, g, b)`` tuple.

    Supported inputs:
      * hex strings starting with ``#`` (delegated to ``hex_to_rgb``);
      * CSS-style ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings. Components
        may be written as floats; they are rounded and clamped to 0-255, and
        any alpha component is ignored.

    Surrounding whitespace is ignored. Unsupported or malformed input shows a
    Gradio warning and yields black, ``(0, 0, 0)``, so callers get an RGB
    tuple rather than a string they cannot turn into a tensor.
    """
    text = clrstr.strip()
    if text.startswith('#'):
        try:
            # Keep only r, g, b in case the hex string carried an alpha byte.
            return tuple(hex_to_rgb(text))[:3]
        except ValueError:
            pass  # malformed hex: fall through to the warning below
    elif text.startswith('rgb'):  # rgb(r,g,b) or rgba(r,g,b,a)
        # Extract the substring inside ( ), then split it on commas.
        values = text.partition('(')[2].partition(')')[0].split(',')
        if len(values) >= 3:
            try:
                return tuple(
                    min(max(round(float(v.strip())), 0), 255) for v in values[:3]
                )
            except (ValueError, OverflowError):
                pass  # non-numeric, nan or inf component
    gr.Warning(f"Unsupported color format: {clrstr}")
    return (0, 0, 0)
class ColorGen(nn.Module):
    """Character-level LSTM that generates a name for an RGB colour.

    The colour is projected to a ``dim``-sized vector that is used as the first
    element of the input sequence; previously generated characters (vocabulary
    indices) follow it. At every position the model emits logits over the next
    character.

    NOTE: the attribute names below and the layer order inside
    ``color_embedding`` determine the ``state_dict`` keys, so they must stay in
    sync with any saved checkpoint.
    """

    def __init__(self, dim, vocab_size):
        """
        Args:
            dim: Width shared by the character embedding, the colour
                projection and the LSTM hidden state.
            vocab_size: Number of characters the model can emit.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        # Small MLP lifting an (r, g, b) triple into the embedding space.
        self.color_embedding = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, dim),
        )
        self.lstm = nn.LSTM(dim, dim, 1, batch_first=True)
        self.debed = nn.Linear(dim, vocab_size)

    def forward(self, col, text):
        """Return next-character logits for every position of the sequence.

        Args:
            col: Float tensor of colours, shape ``(batch, 3)``; ``get_name``
                passes channels scaled to [0, 1].
            text: Long tensor of character indices, shape ``(batch, seq)``,
                or ``None`` when nothing has been generated yet.

        Returns:
            Logits of shape ``(batch, seq + 1, vocab_size)``; ``seq`` is 0 when
            ``text`` is ``None``.
        """
        # (batch, 3) -> (batch, 1, dim): the colour acts as the first "token".
        col = self.color_embedding(col).unsqueeze(1)
        if text is not None:
            text = torch.cat([col, self.embedding(text)], 1)
        else:
            text = col
        pred_seq = self.lstm(text)[0]
        return self.debed(pred_seq)

    def name_color(self, col, n_steps, temp: float = 1.0):
        """Autoregressively generate ``n_steps`` character indices for one colour.

        Args:
            col: Float tensor of shape ``(1, 3)``. The generated prefix is
                built with batch size 1, so only a single colour is supported.
            n_steps: Number of characters to generate (no early stopping).
            temp: Sampling temperature. Values above 0.02 sample from
                ``softmax(logits / temp)``; lower values decode greedily.

        Returns:
            List of ``n_steps`` vocabulary indices.
        """
        text = []
        device = self.embedding.weight.device
        with torch.no_grad():
            for _ in range(n_steps):
                prefix = (
                    torch.tensor(text, dtype=torch.long, device=device).unsqueeze(0)
                    if text
                    else None
                )
                # Logits for the character following the current prefix.
                logits = self.forward(col, prefix)[0, -1]
                if temp > 0.02:
                    # Temperature must scale the logits *before* normalisation:
                    # dividing already-normalised probabilities by a constant is
                    # undone by Categorical's own renormalisation.
                    dist = torch.distributions.Categorical(logits=logits / temp)
                    text.append(dist.sample().item())
                else:
                    # Near-zero temperature: greedy decoding.
                    text.append(logits.argmax().item())
        return text
# Character vocabulary: model output index i decodes to vocab[i]. The order
# presumably has to match the one used when color_gen.pt was trained — do not
# reorder or edit entries without retraining.
vocab = [' ', '!', '$', '%', '&', '(', ')', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '²', 'ß', 'à', 'á', 'â', 'ä', 'å', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï', 'ñ', 'ò', 'ó', 'ö', 'ù', 'ú', 'û', 'ü', 'ā', 'ē', 'ě',
'ğ', 'ī', 'ı', 'ł', 'ń', 'ō', 'ő', 'œ', 'š', 'ū', 'ż', 'ǎ', 'ǐ', 'ǔ', 'ǜ', 'я', '’', '₂', '№', 'ⅱ']
# Number of characters generated for every name (used as n_steps in get_name).
color_pad_len = 32
model = ColorGen(256, len(vocab))
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, instead of executing arbitrary pickled objects.
model.load_state_dict(
    torch.load(
        "color_gen.pt",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# Inference only. The model has no dropout/batch-norm today, but eval mode
# keeps generation deterministic-per-seed if such layers are ever added.
model.eval()
def get_name(color, temp: float):
    """Generate a title-cased name for a colour coming from the Gradio UI.

    Args:
        color: Colour string from the colour picker (hex, ``rgb(...)`` or
            ``rgba(...)``). ``None`` or an empty string falls back to black.
        temp: Sampling temperature forwarded to ``model.name_color``.

    Returns:
        The generated name with surrounding whitespace stripped, title-cased.
    """
    # Treat both None and "" as "no colour": an empty string is not a colour
    # parse_color_str can interpret.
    rgb = parse_color_str(color) if color else (0, 0, 0)
    # Integer channels -> float tensor of shape (1, 3) scaled to [0, 1].
    rgb_tensor = torch.tensor([rgb]) / 255
    indices = model.name_color(rgb_tensor, color_pad_len, temp)
    name = "".join(vocab[i] for i in indices).strip().title()
    print(f"Generated: {rgb} - {name}")
    return name
# Gradio UI: a colour picker and a temperature slider go in, and the generated
# colour name comes out. Launched at import time, as the app entry point.
color_input = gr.ColorPicker()
temperature_input = gr.Slider(minimum=0, maximum=2, value=1, label="Temperature")
name_output = gr.TextArea(label="Color name")

demo = gr.Interface(
    fn=get_name,
    inputs=[color_input, temperature_input],
    outputs=[name_output],
    api_name="predict",
)
demo.launch()