Hammad712 commited on
Commit
a412d39
·
verified ·
1 Parent(s): 032818c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -8
app.py CHANGED
@@ -1,9 +1,10 @@
 
1
  import streamlit as st
2
  import tensorflow as tf
3
  from tensorflow_addons.layers import InstanceNormalization
4
  import numpy as np
5
  from PIL import Image
6
- from huggingface_hub import from_pretrained_keras
7
  import requests
8
  from io import BytesIO
9
 
@@ -56,14 +57,54 @@ st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
56
  st.markdown('<div class="title"><span class="colorful-text">Photo</span> <span class="black-white-text">to Art</span></div>', unsafe_allow_html=True)
57
  st.markdown('<div class="custom-text">Convert Photos to Art using CycleGAN</div>', unsafe_allow_html=True)
58
 
59
- # Define your Hugging Face repository details
60
- username = "Hammad712" # Replace with your Hugging Face username
61
- repo_name = "CycleGAN-Model"
62
- model_id = f"{username}/{repo_name}"
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  try:
65
- # Load the model
66
- model = from_pretrained_keras(model_id)
 
 
 
 
 
67
  st.success("Model loaded successfully!")
68
  except Exception as e:
69
  st.error(f"An error occurred while loading the model: {e}")
 
1
+ %%writefile st.py
2
  import streamlit as st
3
  import tensorflow as tf
4
  from tensorflow_addons.layers import InstanceNormalization
5
  import numpy as np
6
  from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
  import requests
9
  from io import BytesIO
10
 
 
57
  st.markdown('<div class="title"><span class="colorful-text">Photo</span> <span class="black-white-text">to Art</span></div>', unsafe_allow_html=True)
58
  st.markdown('<div class="custom-text">Convert Photos to Art using CycleGAN</div>', unsafe_allow_html=True)
59
 
60
+ # Define CycleGAN model architecture
61
+ class CycleGAN(tf.keras.Model):
62
+ def __init__(self):
63
+ super(CycleGAN, self).__init__()
64
+ self.generatorP = self.build_generator() # Generates Photo from Style
65
+ self.generatorS = self.build_generator() # Generates Style from Photo
66
+
67
+ def build_generator(self):
68
+ def ConvBlock(filters, kernel, strides, x):
69
+ x = tf.keras.layers.Conv2D(filters, kernel, strides, 'same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
70
+ if filters == 3:
71
+ return tf.keras.layers.Activation('tanh')(x)
72
+ x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
73
+ return tf.keras.layers.ReLU()(x)
74
+
75
+ def ResBlock(filters, inputs):
76
+ x = tf.keras.layers.Conv2D(filters, 3, 1, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(inputs)
77
+ x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
78
+ x = tf.keras.layers.ReLU()(x)
79
+ x = tf.keras.layers.Conv2D(filters, 3, 1, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
80
+ x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
81
+ return tf.keras.layers.Add()([x, inputs])
82
+
83
+ def TransBlock(filters, x):
84
+ x = tf.keras.layers.Conv2DTranspose(filters, 3, 2, 'same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
85
+ x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
86
+ return tf.keras.layers.ReLU()(x)
87
+
88
+ input = tf.keras.Input(shape=(256, 256, 3))
89
+ x = ConvBlock(64, 7, 1, input)
90
+ x = ConvBlock(128, 3, 2, x)
91
+ x = ConvBlock(256, 3, 2, x)
92
+ for _ in range(9):
93
+ x = ResBlock(256, x)
94
+ x = TransBlock(128, x)
95
+ x = TransBlock(64, x)
96
+ out = ConvBlock(3, 7, 1, x)
97
+ return tf.keras.Model(inputs=input, outputs=out)
98
+
99
+ # Load the model weights from Hugging Face
100
  try:
101
+ repo_id = "Hammad712/CycleGAN-Model"
102
+ generatorS_path = hf_hub_download(repo_id=repo_id, filename="generatorS.h5")
103
+ generatorP_path = hf_hub_download(repo_id=repo_id, filename="generatorP.h5")
104
+
105
+ model = CycleGAN()
106
+ model.generatorS.load_weights(generatorS_path)
107
+ model.generatorP.load_weights(generatorP_path)
108
  st.success("Model loaded successfully!")
109
  except Exception as e:
110
  st.error(f"An error occurred while loading the model: {e}")