KuunVo commited on
Commit
ca337fc
·
1 Parent(s): c7c808e

Add enhanced feature

Browse files
Files changed (4) hide show
  1. app.py +2 -1
  2. ui/enhancer_ui.py +80 -1
  3. ui/upscaler_ui.py +0 -1
  4. utils.py +35 -0
app.py CHANGED
@@ -3,8 +3,9 @@ from ui import upscaler_ui, enhancer_ui
3
 
4
  st.set_page_config(layout="wide")
5
 
 
6
  # st.title("Image Upscaler and Enhancer")
7
- tab1, tab2 = st.tabs(["Upscaler", "Enhancer"])
8
 
9
  with tab1:
10
  upscaler_ui.ui()
 
3
 
4
  st.set_page_config(layout="wide")
5
 
6
+ st.header('SUPER RESOLUTION')
7
  # st.title("Image Upscaler and Enhancer")
8
+ tab1, tab2 = st.tabs([ "Upscaler", "Enhancer"]) #
9
 
10
  with tab1:
11
  upscaler_ui.ui()
ui/enhancer_ui.py CHANGED
@@ -1,2 +1,81 @@
 
 
 
 
 
 
 
 
1
  def ui():
2
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ from streamlit_image_comparison import image_comparison
6
+ from utils import upscale_image, enhanced_image
7
+
8
+
9
  def ui():
10
+ img_selected = None
11
+ input_text = None
12
+ uploaded_file = None
13
+
14
+ input_image_area = st.columns(2)
15
+
16
+ with input_image_area[0]:
17
+ option = st.selectbox(
18
+ "How do you want to provide the image?",
19
+ ("Fetch from URL", "Upload from local machine"),
20
+ key="option_enhanced"
21
+ )
22
+ if option == "Upload from local machine":
23
+ uploaded_file = st.file_uploader(
24
+ "Choose an image...", type=["jpg", "jpeg", "png"], key='file_enhanced')
25
+ elif option == "Fetch from URL":
26
+ input_text = st.text_input(
27
+ "Enter the image URL", key='input_enhanced')
28
+
29
+ if st.button("Submit", key="submit_enhanced"):
30
+ if option == "Upload from local machine" and uploaded_file is not None:
31
+ try:
32
+ img_selected = Image.open(uploaded_file)
33
+ # st.image(image, caption="Uploaded Image", use_column_width=True)
34
+ except Exception as e:
35
+ st.error(f"Error opening image: {e}")
36
+ elif option == "Fetch from URL" and input_text:
37
+ try:
38
+ response = requests.get(input_text)
39
+ response.raise_for_status()
40
+ img_selected = Image.open(BytesIO(response.content))
41
+ # st.image(image, caption="Image from URL", use_column_width=True)
42
+ except requests.exceptions.RequestException as e:
43
+ st.error(f"Error fetching image: {e}")
44
+
45
+ if img_selected:
46
+ width, height = img_selected.size
47
+ if width > 1000 or height > 1000:
48
+ st.error(
49
+ "Unable to upscale. The size of upscaled image should be less than 1000x1000")
50
+ img_selected = None
51
+
52
+
53
+ with input_image_area[1]:
54
+ option_model = st.selectbox(
55
+ "Which model do you want to use?",
56
+ ('SRUNET_x2', 'SRUNET_x3', 'SRUNET_x4', 'SRUNET_x234'),
57
+ key="option_model_enhanced"
58
+ )
59
+
60
+
61
+ if img_selected:
62
+ st.header('Results')
63
+ st.text(f'Model: {option_model}')
64
+
65
+ col1, col2 = st.columns(2)
66
+ with col1:
67
+ st.image(img_selected, caption="Original",
68
+ use_column_width=True)
69
+ with col2:
70
+ img_enhanced = enhanced_image(img_selected, option_model)
71
+ # img_enhanced = img_selected.resize((64, 64))
72
+ col2.image(img_enhanced, caption="Enhanced",
73
+ use_column_width=True)
74
+
75
+ image_comparison(
76
+ img1=img_selected,
77
+ img2=img_enhanced,
78
+ )
79
+
80
+
81
+
ui/upscaler_ui.py CHANGED
@@ -62,7 +62,6 @@ def ui():
62
  st.error(
63
  "Unable to upscale. The size of upscaled image should be less than 1000x1000")
64
  image = None
65
- # pass
66
 
67
  if image:
68
  st.header('Results')
 
62
  st.error(
63
  "Unable to upscale. The size of upscaled image should be less than 1000x1000")
64
  image = None
 
65
 
66
  if image:
67
  st.header('Results')
utils.py CHANGED
@@ -73,3 +73,38 @@ def upscale_image(img, model_name, scale_factor):
73
  img_scale_pred = img_scale_pred.squeeze(0)
74
  return transforms.ToPILImage()(img_scale_pred).convert(img_mode)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  img_scale_pred = img_scale_pred.squeeze(0)
74
  return transforms.ToPILImage()(img_scale_pred).convert(img_mode)
75
 
76
+ def enhanced_image(img, model_name):
77
+ img_mode = img.mode
78
+ if img.mode != "RGB":
79
+ img = img.convert("RGB")
80
+
81
+ transform = transforms.Compose([
82
+ transforms.ToImage(),
83
+ transforms.ToDtype(torch.float32, scale=True),
84
+ ])
85
+
86
+ #Load Model
87
+ checkpoint = torch.load(get_pretrained_path(
88
+ model_name), map_location=torch.device('cpu'))
89
+ model = UNET()
90
+ model.load_state_dict(checkpoint['best_model_state_dict'])
91
+ model.eval()
92
+
93
+ data = transform(img).clamp(0, 1).unsqueeze(0)
94
+ h_pad, w_pad = find_padding(data)
95
+ data = F.pad(data, (0, w_pad, 0, h_pad), mode='reflect')
96
+
97
+
98
+ with torch.no_grad():
99
+ img_scale_pred = model(data).clamp(0, 1)
100
+ if h_pad > 0 and w_pad > 0:
101
+ img_scale_pred = img_scale_pred[..., :-h_pad, :-w_pad]
102
+ elif h_pad > 0:
103
+ img_scale_pred = img_scale_pred[..., :-h_pad, :]
104
+ elif w_pad > 0:
105
+ img_scale_pred = img_scale_pred[..., :, :-w_pad]
106
+ else:
107
+ img_scale_pred = img_scale_pred
108
+
109
+ img_scale_pred = img_scale_pred.squeeze(0)
110
+ return transforms.ToPILImage()(img_scale_pred).convert(img_mode)