lukeafullard commited on
Commit
e6c83b5
·
verified ·
1 Parent(s): 046baa7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +149 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,151 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ from PIL import Image, ImageEnhance
3
+ from rembg import remove
4
+ import io
5
+ import torch
6
+ import numpy as np
7
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, pipeline
8
+
9
+ # Page Configuration
10
+ st.set_page_config(layout="wide", page_title="AI Image Lab")
11
+
12
+ # --- Caching AI Models ---
13
+ # We use separate functions for 2x and 4x to avoid loading both into memory if not needed.
14
+
15
+ @st.cache_resource
16
+ def load_upscaler_x2():
17
+ """Loads the Swin2SR model for 2x upscale."""
18
+ model_id = "caidas/swin2SR-classical-sr-x2-64"
19
+ processor = AutoImageProcessor.from_pretrained(model_id)
20
+ model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
21
+ return processor, model
22
+
23
+ @st.cache_resource
24
+ def load_upscaler_x4():
25
+ """Loads the Swin2SR model for 4x upscale."""
26
+ # This model is heavier and takes longer to run
27
+ model_id = "caidas/swin2SR-classical-sr-x4-63"
28
+ processor = AutoImageProcessor.from_pretrained(model_id)
29
+ model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
30
+ return processor, model
31
+
32
+ @st.cache_resource
33
+ def load_depth_pipeline():
34
+ """Loads a lightweight Depth Estimation pipeline."""
35
+ pipe = pipeline(task="depth-estimation", model="vinvino02/glpn-nyu")
36
+ return pipe
37
+
38
+ def ai_upscale(image, processor, model):
39
+ """Runs the super-resolution model."""
40
+ inputs = processor(image, return_tensors="pt")
41
+ with torch.no_grad():
42
+ outputs = model(**inputs)
43
+
44
+ output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
45
+ output = np.moveaxis(output, 0, -1)
46
+ output = (output * 255.0).round().astype(np.uint8)
47
+ return Image.fromarray(output)
48
+
49
+ def convert_image_to_bytes(img):
50
+ buf = io.BytesIO()
51
+ img.save(buf, format="PNG")
52
+ return buf.getvalue()
53
+
54
+ def main():
55
+ st.title("✨ AI Image Lab: Transformers Edition")
56
+ st.markdown("Processing pipeline: **Background Removal** → **AI Modifiers** → **Geometry**")
57
+
58
+ # --- Sidebar Controls ---
59
+ st.sidebar.header("Processing Pipeline")
60
+
61
+ # 1. Background
62
+ st.sidebar.subheader("1. Cleanup")
63
+ remove_bg = st.sidebar.checkbox("Remove Background (rembg)", value=False)
64
+
65
+ # 2. AI Enhancements
66
+ st.sidebar.subheader("2. AI Enhancements")
67
+ ai_mode = st.sidebar.radio(
68
+ "Choose AI Modification:",
69
+ ["None", "AI Super-Resolution (2x)", "AI Super-Resolution (4x)", "Depth Estimation"]
70
+ )
71
+
72
+ # 3. Geometry & Color
73
+ st.sidebar.subheader("3. Final Adjustments")
74
+ rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
75
+ contrast_val = st.sidebar.slider("Contrast", 0.5, 1.5, 1.0, 0.1)
76
+
77
+ # --- Main Content ---
78
+ uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png", "webp"])
79
+
80
+ if uploaded_file is not None:
81
+ image = Image.open(uploaded_file).convert("RGB")
82
+ processed_image = image.copy()
83
+
84
+ # --- STEP 1: Background Removal ---
85
+ if remove_bg:
86
+ with st.spinner("Removing background..."):
87
+ processed_image = remove(processed_image)
88
+
89
+ # --- STEP 2: AI Enhancements ---
90
+ if ai_mode == "AI Super-Resolution (2x)":
91
+ st.info("Loading Swin2SR (2x) model... (Fast)")
92
+ try:
93
+ processor, model = load_upscaler_x2()
94
+ with st.spinner("Upscaling (2x)..."):
95
+ processed_image = ai_upscale(processed_image, processor, model)
96
+ except Exception as e:
97
+ st.error(f"Error loading Upscaler: {e}")
98
+
99
+ elif ai_mode == "AI Super-Resolution (4x)":
100
+ st.warning("Loading Swin2SR (4x) model... (This is computationally heavy!)")
101
+ # Added a warning because 4x on CPU can be quite slow for large images
102
+ try:
103
+ processor, model = load_upscaler_x4()
104
+ with st.spinner("Upscaling (4x)... please wait"):
105
+ processed_image = ai_upscale(processed_image, processor, model)
106
+ except Exception as e:
107
+ st.error(f"Error loading Upscaler: {e}")
108
+
109
+ elif ai_mode == "Depth Estimation":
110
+ st.info("Generating Depth Map...")
111
+ try:
112
+ depth_pipe = load_depth_pipeline()
113
+ with st.spinner("Estimating depth..."):
114
+ result = depth_pipe(processed_image)
115
+ processed_image = result["depth"]
116
+ except Exception as e:
117
+ st.error(f"Error loading Depth Model: {e}")
118
+
119
+ # --- STEP 3: Geometry/Color ---
120
+ # Rotation
121
+ if rotate_angle != 0:
122
+ processed_image = processed_image.rotate(rotate_angle, expand=True)
123
+
124
+ # Contrast
125
+ if contrast_val != 1.0:
126
+ enhancer = ImageEnhance.Contrast(processed_image)
127
+ processed_image = enhancer.enhance(contrast_val)
128
+
129
+ # --- Display ---
130
+ col1, col2 = st.columns(2)
131
+ with col1:
132
+ st.subheader("Original")
133
+ st.image(image, use_container_width=True)
134
+ st.caption(f"Size: {image.size}")
135
+
136
+ with col2:
137
+ st.subheader("Result")
138
+ st.image(processed_image, use_container_width=True)
139
+ st.caption(f"Size: {processed_image.size}")
140
+
141
+ # --- Download ---
142
+ st.markdown("---")
143
+ btn = st.download_button(
144
+ label="💾 Download Result",
145
+ data=convert_image_to_bytes(processed_image),
146
+ file_name="ai_enhanced_image.png",
147
+ mime="image/png",
148
+ )
149
 
150
+ if __name__ == "__main__":
151
+ main()