Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,180 +1,57 @@
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import tensorflow as tf
|
| 3 |
-
from tensorflow_addons.layers import InstanceNormalization
|
| 4 |
import numpy as np
|
|
|
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
-
from huggingface_hub import hf_hub_download
|
| 7 |
-
import requests
|
| 8 |
-
from io import BytesIO
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
.
|
| 33 |
-
color: black;
|
| 34 |
-
}
|
| 35 |
-
.small-input .stTextInput>div>input {
|
| 36 |
-
height: 2rem;
|
| 37 |
-
font-size: 0.9rem;
|
| 38 |
-
}
|
| 39 |
-
.small-file-uploader .stFileUploader>div>div {
|
| 40 |
-
height: 2rem;
|
| 41 |
-
font-size: 0.9rem;
|
| 42 |
-
}
|
| 43 |
-
.custom-text {
|
| 44 |
-
font-size: 1.2rem;
|
| 45 |
-
color: #feb47b;
|
| 46 |
-
text-align: center;
|
| 47 |
-
margin-top: -20px;
|
| 48 |
-
margin-bottom: 20px;
|
| 49 |
-
}
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
st.set_page_config(layout="wide")
|
| 53 |
-
|
| 54 |
-
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
|
| 55 |
-
|
| 56 |
-
st.markdown('<div class="title"><span class="colorful-text">Photo</span> <span class="black-white-text">to Art</span></div>', unsafe_allow_html=True)
|
| 57 |
-
st.markdown('<div class="custom-text">Convert Photos to Art using CycleGAN</div>', unsafe_allow_html=True)
|
| 58 |
-
|
| 59 |
-
# Define CycleGAN model architecture
|
| 60 |
-
class CycleGAN(tf.keras.Model):
|
| 61 |
-
def __init__(self):
|
| 62 |
-
super(CycleGAN, self).__init__()
|
| 63 |
-
self.generatorP = self.build_generator() # Generates Photo from Style
|
| 64 |
-
self.generatorS = self.build_generator() # Generates Style from Photo
|
| 65 |
-
|
| 66 |
-
def build_generator(self):
|
| 67 |
-
def ConvBlock(filters, kernel, strides, x):
|
| 68 |
-
x = tf.keras.layers.Conv2D(filters, kernel, strides, 'same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
|
| 69 |
-
if filters == 3:
|
| 70 |
-
return tf.keras.layers.Activation('tanh')(x)
|
| 71 |
-
x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
|
| 72 |
-
return tf.keras.layers.ReLU()(x)
|
| 73 |
-
|
| 74 |
-
def ResBlock(filters, inputs):
|
| 75 |
-
x = tf.keras.layers.Conv2D(filters, 3, 1, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(inputs)
|
| 76 |
-
x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
|
| 77 |
-
x = tf.keras.layers.ReLU()(x)
|
| 78 |
-
x = tf.keras.layers.Conv2D(filters, 3, 1, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
|
| 79 |
-
x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
|
| 80 |
-
return tf.keras.layers.Add()([x, inputs])
|
| 81 |
-
|
| 82 |
-
def TransBlock(filters, x):
|
| 83 |
-
x = tf.keras.layers.Conv2DTranspose(filters, 3, 2, 'same', kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(x)
|
| 84 |
-
x = InstanceNormalization(gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))(x)
|
| 85 |
-
return tf.keras.layers.ReLU()(x)
|
| 86 |
-
|
| 87 |
-
input = tf.keras.Input(shape=(256, 256, 3))
|
| 88 |
-
x = ConvBlock(64, 7, 1, input)
|
| 89 |
-
x = ConvBlock(128, 3, 2, x)
|
| 90 |
-
x = ConvBlock(256, 3, 2, x)
|
| 91 |
-
for _ in range(9):
|
| 92 |
-
x = ResBlock(256, x)
|
| 93 |
-
x = TransBlock(128, x)
|
| 94 |
-
x = TransBlock(64, x)
|
| 95 |
-
out = ConvBlock(3, 7, 1, x)
|
| 96 |
-
return tf.keras.Model(inputs=input, outputs=out)
|
| 97 |
-
|
| 98 |
-
# Load the model weights from Hugging Face
|
| 99 |
-
try:
|
| 100 |
-
repo_id = "Hammad712/CycleGAN-Model"
|
| 101 |
-
generatorS_path = hf_hub_download(repo_id=repo_id, filename="generatorS.h5")
|
| 102 |
-
generatorP_path = hf_hub_download(repo_id=repo_id, filename="generatorP.h5")
|
| 103 |
-
|
| 104 |
-
model = CycleGAN()
|
| 105 |
-
model.generatorS.load_weights(generatorS_path)
|
| 106 |
-
model.generatorP.load_weights(generatorP_path)
|
| 107 |
-
st.success("Model loaded successfully!")
|
| 108 |
-
except Exception as e:
|
| 109 |
-
st.error(f"An error occurred while loading the model: {e}")
|
| 110 |
-
st.stop()
|
| 111 |
-
|
| 112 |
-
def load_and_preprocess_image(image_path_or_url):
|
| 113 |
-
if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')):
|
| 114 |
-
response = requests.get(image_path_or_url)
|
| 115 |
-
img = Image.open(BytesIO(response.content)).convert("RGB")
|
| 116 |
-
else:
|
| 117 |
-
img = Image.open(image_path_or_url).convert("RGB")
|
| 118 |
-
img = img.resize((256, 256))
|
| 119 |
-
img = np.array(img) / 127.5 - 1 # Normalize to [-1, 1]
|
| 120 |
img = np.expand_dims(img, axis=0) # Add batch dimension
|
| 121 |
return img
|
| 122 |
|
| 123 |
-
def
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
# Load and preprocess the image
|
| 130 |
-
input_img = load_and_preprocess_image(image_path_or_url)
|
| 131 |
-
|
| 132 |
-
# Perform inference
|
| 133 |
-
generated_img = model.generatorS(input_img, training=False)
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
-
|
| 139 |
-
with st.expander("Input Options", expanded=True):
|
| 140 |
-
image_path_or_url = st.text_input("Enter image URL", "", key="image_url", placeholder="Enter image URL", help="Enter the URL of the image to convert")
|
| 141 |
-
uploaded_file = st.file_uploader("Or upload an image", type=["jpg", "jpeg", "png", "webp"], key="upload_file", help="Upload an image file to convert")
|
| 142 |
|
| 143 |
if uploaded_file is not None:
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
if
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
# Display original and generated images side by side
|
| 155 |
-
st.markdown("### Result")
|
| 156 |
-
col1, col2 = st.columns(2)
|
| 157 |
-
|
| 158 |
-
with col1:
|
| 159 |
-
st.image(np.array(original_image[0] * 127.5 + 127.5, dtype=np.uint8), caption='Original Image', use_column_width=True)
|
| 160 |
-
with col2:
|
| 161 |
-
st.image(generated_image, caption='Generated Art Image', use_column_width=True)
|
| 162 |
-
|
| 163 |
-
# Provide a download button for the generated image
|
| 164 |
-
img_byte_arr = BytesIO()
|
| 165 |
-
generated_image.save(img_byte_arr, format='JPEG')
|
| 166 |
-
img_byte_arr = img_byte_arr.getvalue()
|
| 167 |
-
|
| 168 |
-
st.download_button(
|
| 169 |
-
label="Download Art Image",
|
| 170 |
-
data=img_byte_arr,
|
| 171 |
-
file_name="art_image.jpg",
|
| 172 |
-
mime="image/jpeg"
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
st.success("Image processed successfully!")
|
| 176 |
-
|
| 177 |
-
except Exception as e:
|
| 178 |
-
st.error(f"An error occurred: {e}")
|
| 179 |
-
else:
|
| 180 |
-
st.error("Please enter a valid image path or URL.")
|
|
|
|
| 1 |
+
import os
|
| 2 |
import streamlit as st
|
| 3 |
import tensorflow as tf
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 7 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
# Hugging Face Hub client used to enumerate the model repository's files.
api = HfApi()

# Fully-qualified repository id of the pretrained CycleGAN weights.
username = "Hammad712"
repo_name = "CycleGAN-Model"
repo_id = "/".join((username, repo_name))
|
| 16 |
+
|
| 17 |
+
# Download model files from Hugging Face
local_dir = "CycleGAN"  # relative path for the local model cache
os.makedirs(local_dir, exist_ok=True)
# Fetch every file of the model repo. The auth token is read from the
# environment (None for anonymous access to a public repo); the previous
# `token=token` referenced an undefined name and raised NameError on startup.
for file in api.list_repo_files(repo_id=repo_id, repo_type="model"):
    hf_hub_download(
        repo_id=repo_id,
        filename=file,
        local_dir=local_dir,
        token=os.environ.get("HF_TOKEN"),
    )
|
| 22 |
+
|
| 23 |
+
# Load the model from the locally downloaded model directory.
# NOTE(review): mapping 'InstanceNormalization' to a plain tf.keras.layers.Layer
# only satisfies deserialization by name — a bare Layer is an identity op, so
# the loaded generator may not reproduce the trained normalization. Confirm the
# saved model embeds the layer's actual computation, or register the real
# InstanceNormalization implementation here instead.
custom_objects = {'InstanceNormalization': tf.keras.layers.Layer}  # Adjust custom objects as needed
loaded_model = tf.keras.models.load_model(local_dir, custom_objects=custom_objects)
|
| 26 |
+
|
| 27 |
+
# Helper functions
|
| 28 |
+
def load_and_preprocess_image(image):
    """Resize a PIL image to 256x256, normalize to [-1, 1], and add a batch dim.

    Args:
        image: PIL.Image-like object (any mode; forced to RGB).

    Returns:
        Float ndarray of shape (1, 256, 256, 3) with values in [-1, 1].
    """
    # Force 3 channels so grayscale or RGBA uploads don't break the
    # generator's fixed (256, 256, 3) input. No-op for RGB input.
    img = image.convert("RGB").resize((256, 256))
    img = np.asarray(img, dtype=np.float32)
    img = (img - 127.5) / 127.5  # Normalize to [-1, 1]
    img = np.expand_dims(img, axis=0)  # Add batch dimension
    return img
|
| 34 |
|
| 35 |
+
def infer_image(model, image):
    """Run the generator on a PIL image and return a uint8 (256, 256, 3) array.

    Args:
        model: callable Keras model taking a (1, 256, 256, 3) batch.
        image: PIL.Image-like object to translate.

    Returns:
        De-normalized uint8 ndarray of shape (256, 256, 3) in [0, 255].
    """
    preprocessed_img = load_and_preprocess_image(image)
    generated_img = model(preprocessed_img, training=False)
    # np.squeeze coerces the TF tensor to an ndarray, so no `.numpy()` call is
    # valid afterwards (the previous code raised AttributeError on the result
    # of the ndarray arithmetic below).
    generated_img = np.squeeze(np.asarray(generated_img), axis=0)  # drop batch dim
    # De-normalize from [-1, 1] back to [0, 255] for display.
    return (generated_img * 127.5 + 127.5).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# Streamlit UI: upload a photo, run the generator, display the result.
st.title("CycleGAN Inference App")

# Accept both common JPEG extensions (a bare "jpg" rejected ".jpeg" uploads).
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg"])

if uploaded_file is not None:
    # Load and display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button("Run Inference"):
        # Perform inference
        generated_image = infer_image(loaded_model, image)

        # Display the result
        st.image(generated_image, caption='Generated Image', use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|