BillyAggarwal commited on
Commit
dd0d62a
·
verified ·
1 Parent(s): 51e1870

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -12,11 +12,13 @@ CONTENT_LAYER = [...] # e.g. [('block5_conv4', 1)]
12
 
13
 
14
  def gram_matrix(A):
15
- """Compute Gram matrix for style representation."""
16
- A = tf.transpose(A, (0, 3, 1, 2)) # (batch, channels, height, width)
17
- features = tf.reshape(A, (A.shape[0], A.shape[1], -1))
18
- gram = tf.matmul(features, features, transpose_b=True)
19
- return gram / tf.cast(tf.shape(features)[-1], tf.float32)
 
 
20
 
21
  def compute_content_cost(a_C, a_G):
22
  """Content cost between content and generated image features."""
 
12
 
13
 
14
  def gram_matrix(A):
15
+ """Compute Gram matrix for style representation.
16
+ A has shape (batch, height, width, channels).
17
+ """
18
+ # Flatten height & width into one dimension
19
+ A = tf.reshape(A, (A.shape[0], -1, A.shape[3])) # (batch, H*W, C)
20
+ gram = tf.matmul(A, A, transpose_a=True) # (batch, C, C)
21
+ return gram / tf.cast(tf.shape(A)[1], tf.float32)
22
 
23
  def compute_content_cost(a_C, a_G):
24
  """Content cost between content and generated image features."""