# (HuggingFace Spaces page header captured by the scrape — Space status: "Sleeping")
import torch                 # core tensor library
import torch.nn as nn        # shorthand so we can write nn.X instead of torch.nn.X
import gradio as gr          # web UI framework
from PIL import ImageColor   # NOTE(review): unused in this file — kept to avoid breaking anything
def hex_to_rgb(value):
    """Convert a hex color string like '#rrggbb' (or '#rgb') to an (r, g, b) int tuple."""
    digits = value.lstrip('#')
    width = len(digits) // 3  # characters per channel (2 for #rrggbb, 1 for #rgb)
    channels = []
    for start in range(0, len(digits), width):
        channels.append(int(digits[start:start + width], 16))
    return tuple(channels)
def parse_color_str(clrstr):
    """Parse a CSS color string — '#rrggbb', 'rgb(r,g,b)', or 'rgba(r,g,b,a)' —
    into an (r, g, b) tuple of ints clamped to [0, 255].

    Unsupported formats emit a gr.Warning and fall back to black.
    """
    if clrstr[0] == '#':
        return hex_to_rgb(clrstr)  # already hex — delegate
    if not clrstr.startswith('rgb'):  # neither rgb(r,g,b) nor rgba(r,g,b,a)
        gr.Warning(f"Unsupported color format: {clrstr}")
        # BUG FIX: the original returned the string "#000000" here, which crashed
        # the caller's torch.tensor([color]) / 255; return an RGB tuple instead.
        return (0, 0, 0)
    # extract the substring inside the parentheses, then split on commas
    values = clrstr.split('(')[1].split(')')[0].split(',')
    r = min(max(round(float(values[0].strip())), 0), 255)
    g = min(max(round(float(values[1].strip())), 0), 255)
    b = min(max(round(float(values[2].strip())), 0), 255)
    return (r, g, b)
class ColorGen(nn.Module):
    """Autoregressive character-level LSTM that names an RGB color.

    The color (3 floats, expected in [0, 1]) is embedded into the model
    dimension and used as the first step of the LSTM input sequence;
    previously generated character tokens follow it.
    """

    def __init__(self, dim, vocab_size):
        super().__init__()
        # token embedding for generated characters
        self.embedding = nn.Embedding(vocab_size, dim)
        # small MLP that lifts the 3-value RGB input to the model dimension
        self.color_embedding = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, dim),
        )
        self.lstm = nn.LSTM(dim, dim, 1, batch_first=True)
        # projects LSTM states back to vocabulary logits
        self.debed = nn.Linear(dim, vocab_size)

    def forward(self, col, text):
        """Return per-step vocabulary logits of shape (batch, 1 + len(text), vocab).

        col: (batch, 3) color tensor; text: (batch, steps) LongTensor of token
        ids, or None at the very first generation step.
        """
        col = self.color_embedding(col).unsqueeze(1)  # (batch, 1, dim)
        if text is not None:
            text = torch.cat([col, self.embedding(text)], 1)
        else:
            text = col
        pred_seq = self.lstm(text)[0]
        return self.debed(pred_seq)

    def name_color(self, col, n_steps, temp: float = 1.0):
        """Greedily/sample-generate `n_steps` token ids for color `col`.

        temp <= 0.02 means deterministic argmax decoding; otherwise sample
        with the given softmax temperature.
        """
        text = []
        with torch.no_grad():
            for _ in range(n_steps):
                if text:
                    tokens = torch.tensor(
                        text, dtype=torch.long, device=self.embedding.weight.device
                    ).unsqueeze(0)
                else:
                    tokens = None  # first step: condition on the color alone
                logits = self.forward(col, tokens)[0, -1]  # last-step logits
                if temp > 0.02:
                    # BUG FIX: the original divided the already-softmaxed
                    # probabilities by temp; Categorical renormalizes `probs`,
                    # so temperature had no effect. Apply it to the logits.
                    dist = torch.distributions.Categorical(logits=logits / temp)
                    text.append(dist.sample().item())
                else:
                    # argmax over logits == argmax over softmax(logits)
                    text.append(logits.argmax().item())
        return text
# Character-level vocabulary: generated token ids index into this list.
vocab = [' ', '!', '$', '%', '&', '(', ')', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '²', 'ß', 'à', 'á', 'â', 'ä', 'å', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï', 'ñ', 'ò', 'ó', 'ö', 'ù', 'ú', 'û', 'ü', 'ā', 'ē', 'ě',
'ğ', 'ī', 'ı', 'ł', 'ń', 'ō', 'ő', 'œ', 'š', 'ū', 'ż', 'ǎ', 'ǐ', 'ǔ', 'ǜ', 'я', '’', '₂', '№', 'ⅱ']
# Number of autoregressive generation steps per color name.
# (presumably the training-time max name length — TODO confirm)
color_pad_len = 32
# Model dimension 256 must match the checkpoint below.
model = ColorGen(256,len(vocab))
# Load pretrained weights onto CPU (no GPU assumed at inference time).
model.load_state_dict(torch.load("color_gen.pt",map_location=torch.device("cpu")))
def get_name(color, temp: float):
    """Generate a display name for `color` (a CSS color string from the picker).

    Falls back to black when no color is provided. `temp` is the sampling
    temperature passed through to the model.
    """
    rgb = (0, 0, 0) if color is None else parse_color_str(color)
    # Normalize to [0, 1] floats and decode `color_pad_len` character tokens.
    token_ids = model.name_color(torch.tensor([rgb]) / 255, color_pad_len, temp)
    label = "".join(vocab[t] for t in token_ids).strip().title()
    print(f"Generated: {rgb} - {label}")
    return label
# Gradio UI: a color picker plus a temperature slider (0–2, default 1),
# wired to get_name; the generated color name appears in a text area.
demo = gr.Interface(
fn=get_name,
inputs=[gr.ColorPicker(),gr.Slider(0,2,1,label="Temperature")],
outputs=[gr.TextArea(label="Color name")],
api_name="predict"  # exposed REST endpoint name
)
# Start the web server (blocks until the app is stopped).
demo.launch()