Hammad712 commited on
Commit
da69390
·
verified ·
1 Parent(s): 07249fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -168
app.py CHANGED
@@ -1,180 +1,57 @@
 
1
  import streamlit as st
2
  import tensorflow as tf
3
- from tensorflow_addons.layers import InstanceNormalization
4
  import numpy as np
 
 
5
  from PIL import Image
6
- from huggingface_hub import hf_hub_download
7
- import requests
8
- from io import BytesIO
9
 
10
- # Custom CSS
11
- def set_css(style):
12
- st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
13
-
14
- # Combined dark mode styles
15
- combined_css = """
16
- .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
17
- .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
18
- .stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
19
- .stSpinner { color: #4CAF50; }
20
- .title {
21
- font-size: 3rem;
22
- font-weight: bold;
23
- display: flex;
24
- align-items: center;
25
- justify-content: center;
26
- }
27
- .colorful-text {
28
- background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
29
- -webkit-background-clip: text;
30
- -webkit-text-fill-color: transparent;
31
- }
32
- .black-white-text {
33
- color: black;
34
- }
35
- .small-input .stTextInput>div>input {
36
- height: 2rem;
37
- font-size: 0.9rem;
38
- }
39
- .small-file-uploader .stFileUploader>div>div {
40
- height: 2rem;
41
- font-size: 0.9rem;
42
- }
43
- .custom-text {
44
- font-size: 1.2rem;
45
- color: #feb47b;
46
- text-align: center;
47
- margin-top: -20px;
48
- margin-bottom: 20px;
49
- }
50
- """
51
-
52
- st.set_page_config(layout="wide")
53
-
54
- st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
55
-
56
- st.markdown('<div class="title"><span class="colorful-text">Photo</span> <span class="black-white-text">to Art</span></div>', unsafe_allow_html=True)
57
- st.markdown('<div class="custom-text">Convert Photos to Art using CycleGAN</div>', unsafe_allow_html=True)
58
-
59
- # Define CycleGAN model architecture
60
- class CycleGAN(tf.keras.Model):
61
- def __init__(self):
62
- super(CycleGAN, self).__init__()
63
- self.generatorP = self.build_generator() # Generates Photo from Style
64
- self.generatorS = self.build_generator() # Generates Style from Photo
65
-
66
- def build_generator(self):
67
- def ConvBlock(filters, kernel, strides, x):
68
- x = tf.keras.layers.Conv2D(filters, kernel, strides, 'same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
69
- if filters == 3:
70
- return tf.keras.layers.Activation('tanh')(x)
71
- x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
72
- return tf.keras.layers.ReLU()(x)
73
-
74
- def ResBlock(filters, inputs):
75
- x = tf.keras.layers.Conv2D(filters, 3, 1, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(inputs)
76
- x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
77
- x = tf.keras.layers.ReLU()(x)
78
- x = tf.keras.layers.Conv2D(filters, 3, 1, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
79
- x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
80
- return tf.keras.layers.Add()([x, inputs])
81
-
82
- def TransBlock(filters, x):
83
- x = tf.keras.layers.Conv2DTranspose(filters, 3, 2, 'same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
84
- x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
85
- return tf.keras.layers.ReLU()(x)
86
-
87
- input = tf.keras.Input(shape=(256, 256, 3))
88
- x = ConvBlock(64, 7, 1, input)
89
- x = ConvBlock(128, 3, 2, x)
90
- x = ConvBlock(256, 3, 2, x)
91
- for _ in range(9):
92
- x = ResBlock(256, x)
93
- x = TransBlock(128, x)
94
- x = TransBlock(64, x)
95
- out = ConvBlock(3, 7, 1, x)
96
- return tf.keras.Model(inputs=input, outputs=out)
97
-
98
- # Load the model weights from Hugging Face
99
- try:
100
- repo_id = "Hammad712/CycleGAN-Model"
101
- generatorS_path = hf_hub_download(repo_id=repo_id, filename="generatorS.h5")
102
- generatorP_path = hf_hub_download(repo_id=repo_id, filename="generatorP.h5")
103
-
104
- model = CycleGAN()
105
- model.generatorS.load_weights(generatorS_path)
106
- model.generatorP.load_weights(generatorP_path)
107
- st.success("Model loaded successfully!")
108
- except Exception as e:
109
- st.error(f"An error occurred while loading the model: {e}")
110
- st.stop()
111
-
112
- def load_and_preprocess_image(image_path_or_url):
113
- if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')):
114
- response = requests.get(image_path_or_url)
115
- img = Image.open(BytesIO(response.content)).convert("RGB")
116
- else:
117
- img = Image.open(image_path_or_url).convert("RGB")
118
- img = img.resize((256, 256))
119
- img = np.array(img) / 127.5 - 1 # Normalize to [-1, 1]
120
  img = np.expand_dims(img, axis=0) # Add batch dimension
121
  return img
122
 
123
- def postprocess_and_display_image(img_tensor):
124
- img = tf.squeeze(img_tensor, axis=0) # Remove batch dimension
125
- img = (img * 127.5 + 127.5).numpy().astype(np.uint8) # De-normalize to [0, 255]
126
- return Image.fromarray(img)
127
-
128
- def perform_inference(model, image_path_or_url):
129
- # Load and preprocess the image
130
- input_img = load_and_preprocess_image(image_path_or_url)
131
-
132
- # Perform inference
133
- generated_img = model.generatorS(input_img, training=False)
134
 
135
- # Postprocess and return the generated image
136
- return postprocess_and_display_image(generated_img)
137
 
138
- # Input for image URL or path
139
- with st.expander("Input Options", expanded=True):
140
- image_path_or_url = st.text_input("Enter image URL", "", key="image_url", placeholder="Enter image URL", help="Enter the URL of the image to convert")
141
- uploaded_file = st.file_uploader("Or upload an image", type=["jpg", "jpeg", "png", "webp"], key="upload_file", help="Upload an image file to convert")
142
 
143
  if uploaded_file is not None:
144
- image_path_or_url = uploaded_file
145
-
146
- # Run inference button
147
- if st.button("Convert"):
148
- if image_path_or_url:
149
- with st.spinner('Processing...'):
150
- try:
151
- generated_image = perform_inference(model, image_path_or_url)
152
- original_image = load_and_preprocess_image(image_path_or_url)
153
-
154
- # Display original and generated images side by side
155
- st.markdown("### Result")
156
- col1, col2 = st.columns(2)
157
-
158
- with col1:
159
- st.image(np.array(original_image[0] * 127.5 + 127.5, dtype=np.uint8), caption='Original Image', use_column_width=True)
160
- with col2:
161
- st.image(generated_image, caption='Generated Art Image', use_column_width=True)
162
-
163
- # Provide a download button for the generated image
164
- img_byte_arr = BytesIO()
165
- generated_image.save(img_byte_arr, format='JPEG')
166
- img_byte_arr = img_byte_arr.getvalue()
167
-
168
- st.download_button(
169
- label="Download Art Image",
170
- data=img_byte_arr,
171
- file_name="art_image.jpg",
172
- mime="image/jpeg"
173
- )
174
-
175
- st.success("Image processed successfully!")
176
-
177
- except Exception as e:
178
- st.error(f"An error occurred: {e}")
179
- else:
180
- st.error("Please enter a valid image path or URL.")
 
1
+ import os
2
  import streamlit as st
3
  import tensorflow as tf
 
4
  import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from huggingface_hub import HfApi, hf_hub_download
7
  from PIL import Image
 
 
 
8
 
9
+ # Hugging Face credentials
10
+ api = HfApi()
11
+
12
+ # Set your Hugging Face username and model repository name
13
+ username = "Hammad712"
14
+ repo_name = "CycleGAN-Model"
15
+ repo_id = f"{username}/{repo_name}"
16
+
17
+ # Download model files from Hugging Face
18
+ local_dir = "CycleGAN" # Changed to a relative path
19
+ os.makedirs(local_dir, exist_ok=True)
20
+ for file in api.list_repo_files(repo_id=repo_id, repo_type="model"):
21
+ hf_hub_download(repo_id=repo_id, filename=file, local_dir=local_dir, token=token)
22
+
23
+ # Load the model
24
+ custom_objects = {'InstanceNormalization': tf.keras.layers.Layer} # Adjust custom objects as needed
25
+ loaded_model = tf.keras.models.load_model(local_dir, custom_objects=custom_objects)
26
+
27
+ # Helper functions
28
+ def load_and_preprocess_image(image):
29
+ img = image.resize((256, 256))
30
+ img = np.array(img)
31
+ img = (img - 127.5) / 127.5 # Normalize to [-1, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  img = np.expand_dims(img, axis=0) # Add batch dimension
33
  return img
34
 
35
+ def infer_image(model, image):
36
+ preprocessed_img = load_and_preprocess_image(image)
37
+ generated_img = model(preprocessed_img, training=False)
38
+ generated_img = np.squeeze(generated_img, axis=0) # Remove batch dimension
39
+ generated_img = (generated_img * 127.5 + 127.5).numpy().astype(np.uint8) # De-normalize to [0, 255]
40
+ return generated_img
 
 
 
 
 
41
 
42
+ # Streamlit UI
43
+ st.title("CycleGAN Inference App")
44
 
45
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
 
 
 
46
 
47
  if uploaded_file is not None:
48
+ # Load and display the uploaded image
49
+ image = Image.open(uploaded_file)
50
+ st.image(image, caption='Uploaded Image', use_column_width=True)
51
+
52
+ if st.button("Run Inference"):
53
+ # Perform inference
54
+ generated_image = infer_image(loaded_model, image)
55
+
56
+ # Display the result
57
+ st.image(generated_image, caption='Generated Image', use_column_width=True)