Hammad712 commited on
Commit
05c9258
·
verified ·
1 Parent(s): 8aec97c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -41
app.py CHANGED
@@ -1,18 +1,13 @@
1
- import os
2
- os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
3
-
4
  import streamlit as st
5
  import tensorflow as tf
6
  from tensorflow_addons.layers import InstanceNormalization
 
 
 
7
  from PIL import Image
8
  import requests
9
  from io import BytesIO
10
- import logging
11
- from huggingface_hub import HfApi, hf_hub_download
12
- import numpy as np
13
-
14
- # Setup logging
15
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
 
17
  # Custom CSS
18
  def set_css(style):
@@ -52,31 +47,28 @@ combined_css = """
52
  color: #feb47b;
53
  text-align: center;
54
  margin-top: -20px;
55
- margin-bottom: 20px.
 
56
  """
57
 
58
- # Streamlit application
59
  st.set_page_config(layout="wide")
60
 
61
  st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
62
 
63
- st.markdown('<div class="title"><span class="colorful-text">Photo</span> <span class="black-white-text">to Van Gogh</span></div>', unsafe_allow_html=True)
64
- st.markdown('<div class="custom-text">Convert photos to Van Gogh style using AI</div>', unsafe_allow_html=True)
65
 
66
- # Download and load the model
67
- MODEL_REPO = "Hammad712/CycleGAN-Model"
68
- MODEL_DIR = "/content/CycleGAN"
69
- if not os.path.exists(MODEL_DIR):
70
- os.makedirs(MODEL_DIR)
71
 
72
- # Load the model if not already loaded
73
- if 'model' not in st.session_state:
74
- api = HfApi()
75
- for file in api.list_repo_files(repo_id=MODEL_REPO, repo_type="model"):
76
- hf_hub_download(repo_id=MODEL_REPO, filename=file, local_dir=MODEL_DIR)
77
- st.session_state['model'] = tf.keras.models.load_model(MODEL_DIR, custom_objects={'InstanceNormalization': InstanceNormalization})
78
 
79
- model = st.session_state['model']
 
80
 
81
  def load_and_preprocess_image(image_path_or_url):
82
  if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')):
@@ -85,16 +77,24 @@ def load_and_preprocess_image(image_path_or_url):
85
  else:
86
  img = Image.open(image_path_or_url).convert("RGB")
87
  img = img.resize((256, 256))
88
- img = (np.array(img) - 127.5) / 127.5 # Normalize to [-1, 1]
89
- img = tf.expand_dims(img, axis=0) # Add batch dimension
90
  return img
91
 
92
- def infer_image(model, image_path_or_url):
93
- preprocessed_img = load_and_preprocess_image(image_path_or_url)
94
- generated_img = model(preprocessed_img, training=False)
95
- generated_img = tf.squeeze(generated_img, axis=0) # Remove batch dimension
96
- generated_img = (generated_img * 127.5 + 127.5).numpy().astype(np.uint8) # De-normalize to [0, 255]
97
- return Image.fromarray(generated_img)
 
 
 
 
 
 
 
 
98
 
99
  # Input for image URL or path
100
  with st.expander("Input Options", expanded=True):
@@ -109,19 +109,17 @@ if st.button("Convert"):
109
  if image_path_or_url:
110
  with st.spinner('Processing...'):
111
  try:
112
- generated_image = infer_image(model, image_path_or_url)
113
  original_image = load_and_preprocess_image(image_path_or_url)
114
- original_image = (tf.squeeze(original_image, axis=0) * 127.5 + 127.5).numpy().astype(np.uint8) # De-normalize to [0, 255]
115
- original_image = Image.fromarray(original_image)
116
 
117
  # Display original and generated images side by side
118
  st.markdown("### Result")
119
  col1, col2 = st.columns(2)
120
 
121
  with col1:
122
- st.image(original_image, caption='Original Image', use_column_width=True)
123
  with col2:
124
- st.image(generated_image, caption='Van Gogh Styled Image', use_column_width=True)
125
 
126
  # Provide a download button for the generated image
127
  img_byte_arr = BytesIO()
@@ -129,9 +127,9 @@ if st.button("Convert"):
129
  img_byte_arr = img_byte_arr.getvalue()
130
 
131
  st.download_button(
132
- label="Download Styled Image",
133
  data=img_byte_arr,
134
- file_name="vangogh_styled_image.jpg",
135
  mime="image/jpeg"
136
  )
137
 
@@ -139,6 +137,5 @@ if st.button("Convert"):
139
 
140
  except Exception as e:
141
  st.error(f"An error occurred: {e}")
142
- logging.error("Error during inference", exc_info=True)
143
  else:
144
  st.error("Please enter a valid image path or URL.")
 
 
 
 
1
  import streamlit as st
2
  import tensorflow as tf
3
  from tensorflow_addons.layers import InstanceNormalization
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
9
  from io import BytesIO
10
+ import os
 
 
 
 
 
11
 
12
  # Custom CSS
13
  def set_css(style):
 
47
  color: #feb47b;
48
  text-align: center;
49
  margin-top: -20px;
50
+ margin-bottom: 20px;
51
+ }
52
  """
53
 
 
54
  st.set_page_config(layout="wide")
55
 
56
  st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
57
 
58
+ st.markdown('<div class="title"><span class="colorful-text">Photo</span> <span class="black-white-text">to Art</span></div>', unsafe_allow_html=True)
59
+ st.markdown('<div class="custom-text">Convert Photos to Art using CycleGAN</div>', unsafe_allow_html=True)
60
 
61
+ # Define your Hugging Face repository details
62
+ username = "Hammad712" # Replace with your Hugging Face username
63
+ repo_name = "CycleGAN-Model"
64
+ repo_id = f"{username}/{repo_name}"
65
+ model_filename = "CycleGAN.h5"
66
 
67
+ # Download the model file from the repository
68
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
 
 
 
 
69
 
70
+ # Load the model
71
+ model = tf.keras.models.load_model(model_path, custom_objects={"InstanceNormalization": InstanceNormalization})
72
 
73
  def load_and_preprocess_image(image_path_or_url):
74
  if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')):
 
77
  else:
78
  img = Image.open(image_path_or_url).convert("RGB")
79
  img = img.resize((256, 256))
80
+ img = np.array(img) / 127.5 - 1 # Normalize to [-1, 1]
81
+ img = np.expand_dims(img, axis=0) # Add batch dimension
82
  return img
83
 
84
+ def postprocess_and_display_image(img_tensor):
85
+ img = tf.squeeze(img_tensor, axis=0) # Remove batch dimension
86
+ img = (img * 127.5 + 127.5).numpy().astype(np.uint8) # De-normalize to [0, 255]
87
+ return Image.fromarray(img)
88
+
89
+ def perform_inference(model, image_path_or_url):
90
+ # Load and preprocess the image
91
+ input_img = load_and_preprocess_image(image_path_or_url)
92
+
93
+ # Perform inference
94
+ generated_img = model.generatorS(input_img, training=False)
95
+
96
+ # Postprocess and return the generated image
97
+ return postprocess_and_display_image(generated_img)
98
 
99
  # Input for image URL or path
100
  with st.expander("Input Options", expanded=True):
 
109
  if image_path_or_url:
110
  with st.spinner('Processing...'):
111
  try:
112
+ generated_image = perform_inference(model, image_path_or_url)
113
  original_image = load_and_preprocess_image(image_path_or_url)
 
 
114
 
115
  # Display original and generated images side by side
116
  st.markdown("### Result")
117
  col1, col2 = st.columns(2)
118
 
119
  with col1:
120
+ st.image(np.array(original_image[0] * 127.5 + 127.5, dtype=np.uint8), caption='Original Image', use_column_width=True)
121
  with col2:
122
+ st.image(generated_image, caption='Generated Art Image', use_column_width=True)
123
 
124
  # Provide a download button for the generated image
125
  img_byte_arr = BytesIO()
 
127
  img_byte_arr = img_byte_arr.getvalue()
128
 
129
  st.download_button(
130
+ label="Download Art Image",
131
  data=img_byte_arr,
132
+ file_name="art_image.jpg",
133
  mime="image/jpeg"
134
  )
135
 
 
137
 
138
  except Exception as e:
139
  st.error(f"An error occurred: {e}")
 
140
  else:
141
  st.error("Please enter a valid image path or URL.")