Neil Saunders commited on
Commit
0b8d9f0
·
1 Parent(s): 144333e

Add contour model

Browse files
model.pth → anime.pth RENAMED
File without changes
app.py CHANGED
@@ -83,13 +83,17 @@ class Generator(nn.Module):
83
 
84
  return out
85
 
86
- model1 = Generator(3, 1, 3)
87
- model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
88
- model1.eval()
89
 
90
- model2 = Generator(3, 1, 3)
91
- model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
92
- model2.eval()
 
 
 
 
93
 
94
  def predict(input_img, ver):
95
  input_img = Image.open(input_img)
@@ -100,10 +104,12 @@ def predict(input_img, ver):
100
 
101
  drawing = 0
102
  with torch.no_grad():
103
- if ver == 'style 2':
104
- drawing = model2(input_img)[0].detach()
 
 
105
  else:
106
- drawing = model1(input_img)[0].detach()
107
 
108
  drawing = transforms.ToPILImage()(drawing)
109
  return drawing
@@ -111,11 +117,11 @@ def predict(input_img, ver):
111
  title="informative-drawings"
112
  description="Gradio Demo for line drawing generation. "
113
  # article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
114
- examples=[['cat.png', 'style 1'], ['bridge.png', 'style 1'], ['lizard.png', 'style 2'],]
115
 
116
 
117
  iface = gr.Interface(predict, [gr.Image(type='filepath'),
118
- gr.Radio(['style 1','style 2'], type="value", value='style 1', label='version')],
119
  gr.Image(type="pil"), title=title,description=description,examples=examples)
120
 
121
  iface.launch()
 
83
 
84
  return out
85
 
86
+ anime = Generator(3, 1, 3)
87
+ anime.load_state_dict(torch.load('anime.pth', map_location=torch.device('cpu')))
88
+ anime.eval()
89
 
90
+ contour = Generator(3, 1, 3)
91
+ contour.load_state_dict(torch.load('contour.pth', map_location=torch.device('cpu')))
92
+ contour.eval()
93
+
94
+ opensketch = Generator(3, 1, 3)
95
+ opensketch.load_state_dict(torch.load('opensketch.pth', map_location=torch.device('cpu')))
96
+ opensketch.eval()
97
 
98
  def predict(input_img, ver):
99
  input_img = Image.open(input_img)
 
104
 
105
  drawing = 0
106
  with torch.no_grad():
107
+ if ver == 'anime':
108
+ drawing = anime(input_img)[0].detach()
109
+ elif ver == 'contour':
110
+ drawing = contour(input_img)[0].detach()
111
  else:
112
+ drawing = opensketch(input_img)[0].detach()
113
 
114
  drawing = transforms.ToPILImage()(drawing)
115
  return drawing
 
117
  title="informative-drawings"
118
  description="Gradio Demo for line drawing generation. "
119
  # article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
120
+ examples=[['cat.png', 'anime'], ['bridge.png', 'contour'], ['lizard.png', 'opensketch'],]
121
 
122
 
123
  iface = gr.Interface(predict, [gr.Image(type='filepath'),
124
+ gr.Radio(['anime','opensketch','contour'], type="value", value='contour', label='version')],
125
  gr.Image(type="pil"), title=title,description=description,examples=examples)
126
 
127
  iface.launch()
contour.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8b6ec6db973dce9e9a455115514be72f569081651c561a1a8d96d78bb3ece5f
3
+ size 17173511
model2.pth → opensketch.pth RENAMED
File without changes