1plus1 commited on
Commit
148a566
·
verified ·
1 Parent(s): f51ce08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -23,17 +23,14 @@ num_content_layers = len(content_layers)
23
  num_style_layers = len(style_layers)
24
 
25
  # Model for extracting style feature
26
- style_models = [
27
- keras.Model(inputs=model.input, outputs=model.get_layer(layer).output) for layer in style_layers
28
- ]
29
 
30
  # Model for extracting content feature
31
  content_model = keras.Model(inputs=model.input, outputs=model.get_layer(content_layers).output, name='Content_model')
32
 
33
- # Style models summary
34
- for m in style_models:
35
- print(f'Style model: {m.name}')
36
- print(m.summary())
37
 
38
  # Content model summary
39
  print(f'Content model: {content_model.name}')
@@ -54,11 +51,13 @@ def gram_matrix(image):
54
 
55
  # Style loss
56
  def style_loss(gen, style_img):
57
- # Compute Gram matrix of generated image and style image
58
- gen_gram = gram_matrix(gen)
59
- style_gram = gram_matrix(style_img)
60
- # Compute style loss
61
- return tf.reduce_sum(tf.square(gen_gram-style_gram))/4
 
 
62
 
63
  # Training parameter
64
  lr = 10.
@@ -76,11 +75,9 @@ def train_step(image, process_content_image, process_style_image):
76
  content_feature = content_model(process_content_image)
77
  loss_content = content_loss(gen_content_feature, content_feature)
78
  # Calculate style loss
79
- loss_style = 0
80
- for style_model in style_models:
81
- gen_style_feature = style_model(image)
82
- style_feature = style_model(process_style_image)
83
- loss_style += style_loss(gen_style_feature, style_feature)
84
  # Calculate total loss
85
  total_loss = alpha*loss_content + beta*loss_style/style_weight
86
  # Calculate gradient
 
23
  num_style_layers = len(style_layers)
24
 
25
  # Model for extracting style feature
26
+ style_model = keras.Model(inputs=model.input, outputs=[model.get_layer(name).output for name in style_layers], name='Style_model')
 
 
27
 
28
  # Model for extracting content feature
29
  content_model = keras.Model(inputs=model.input, outputs=model.get_layer(content_layers).output, name='Content_model')
30
 
31
+ # Style model summary
32
+ print(f'Style model: {style_model.name}')
33
+ style_model.summary()
 
34
 
35
  # Content model summary
36
  print(f'Content model: {content_model.name}')
 
51
 
52
  # Style loss
53
  def style_loss(gen, style_img):
54
+ total_loss = 0.
55
+ for i in range(len(gen)):
56
+ gen_gram = gram_matrix(gen[i])
57
+ style_gram = gram_matrix(style_img[i])
58
+ # Compute style loss
59
+ total_loss += tf.reduce_sum(tf.square(gen_gram-style_gram))/4
60
+ return total_loss
61
 
62
  # Training parameter
63
  lr = 10.
 
75
  content_feature = content_model(process_content_image)
76
  loss_content = content_loss(gen_content_feature, content_feature)
77
  # Calculate style loss
78
+ gen_style_feature = style_model(image)
79
+ style_feature = style_model(process_style_image)
80
+ loss_style = style_loss(gen_style_feature, style_feature)
 
 
81
  # Calculate total loss
82
  total_loss = alpha*loss_content + beta*loss_style/style_weight
83
  # Calculate gradient