DrHouseFan-315 commited on
Commit
ee9ebd6
·
verified ·
1 Parent(s): 95c19a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -2,6 +2,19 @@ import torch # the library
2
  import torch.nn as nn # so we can just write nn.X insted of torch.nn.X
3
  import gradio as gr
4
  from PIL import ImageColor
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class ColorGen(nn.Module):
6
  def __init__(self,dim,vocab_size):
7
  super().__init__()
@@ -46,7 +59,7 @@ model.load_state_dict(torch.load("color_gen.pt",map_location=torch.device("cpu")
46
 
47
  def get_name(color,temp:float):
48
  if color is not None:
49
- color = ImageColor.getcolor(color,"RGB")
50
  else:
51
  color = (0,0,0)
52
  name = "".join([vocab[_] for _ in model.name_color(torch.tensor([color])/255,color_pad_len,temp)]).strip()
 
2
  import torch.nn as nn # so we can just write nn.X insted of torch.nn.X
3
  import gradio as gr
4
  from PIL import ImageColor
5
+ def parse_color_str(clrstr):
6
+ if clrstr[0]=='#': return clrstr #looks good already
7
+ if not clrstr.startswith('rgb'): # neither rgb(r,g,b) nor rgba(r,g,b,a)
8
+ gr.Warning(f"Unsupported color format: {clrstr}")
9
+ return "#000000"
10
+
11
+ # extract substring inside ( ), then split by comma
12
+ values=clrstr.split('(')[1].split(')')[0].split(',')
13
+
14
+ r = min(max(round(float(values[0].strip())), 0), 255)
15
+ g = min(max(round(float(values[1].strip())), 0), 255)
16
+ b = min(max(round(float(values[2].strip())), 0), 255)
17
+ return (r,g,b)
18
  class ColorGen(nn.Module):
19
  def __init__(self,dim,vocab_size):
20
  super().__init__()
 
59
 
60
  def get_name(color,temp:float):
61
  if color is not None:
62
+ color = parse_color_str(color)
63
  else:
64
  color = (0,0,0)
65
  name = "".join([vocab[_] for _ in model.name_color(torch.tensor([color])/255,color_pad_len,temp)]).strip()