Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,19 @@ import torch # the library
|
|
| 2 |
import torch.nn as nn # so we can just write nn.X insted of torch.nn.X
|
| 3 |
import gradio as gr
|
| 4 |
from PIL import ImageColor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class ColorGen(nn.Module):
|
| 6 |
def __init__(self,dim,vocab_size):
|
| 7 |
super().__init__()
|
|
@@ -46,7 +59,7 @@ model.load_state_dict(torch.load("color_gen.pt",map_location=torch.device("cpu")
|
|
| 46 |
|
| 47 |
def get_name(color,temp:float):
|
| 48 |
if color is not None:
|
| 49 |
-
color =
|
| 50 |
else:
|
| 51 |
color = (0,0,0)
|
| 52 |
name = "".join([vocab[_] for _ in model.name_color(torch.tensor([color])/255,color_pad_len,temp)]).strip()
|
|
|
|
| 2 |
import torch.nn as nn # so we can just write nn.X insted of torch.nn.X
|
| 3 |
import gradio as gr
|
| 4 |
from PIL import ImageColor
|
| 5 |
+
def parse_color_str(clrstr):
|
| 6 |
+
if clrstr[0]=='#': return clrstr #looks good already
|
| 7 |
+
if not clrstr.startswith('rgb'): # neither rgb(r,g,b) nor rgba(r,g,b,a)
|
| 8 |
+
gr.Warning(f"Unsupported color format: {clrstr}")
|
| 9 |
+
return "#000000"
|
| 10 |
+
|
| 11 |
+
# extract substring inside ( ), then split by comma
|
| 12 |
+
values=clrstr.split('(')[1].split(')')[0].split(',')
|
| 13 |
+
|
| 14 |
+
r = min(max(round(float(values[0].strip())), 0), 255)
|
| 15 |
+
g = min(max(round(float(values[1].strip())), 0), 255)
|
| 16 |
+
b = min(max(round(float(values[2].strip())), 0), 255)
|
| 17 |
+
return (r,g,b)
|
| 18 |
class ColorGen(nn.Module):
|
| 19 |
def __init__(self,dim,vocab_size):
|
| 20 |
super().__init__()
|
|
|
|
| 59 |
|
| 60 |
def get_name(color,temp:float):
|
| 61 |
if color is not None:
|
| 62 |
+
color = parse_color_str(color)
|
| 63 |
else:
|
| 64 |
color = (0,0,0)
|
| 65 |
name = "".join([vocab[_] for _ in model.name_color(torch.tensor([color])/255,color_pad_len,temp)]).strip()
|