"""AI Image Lab: Streamlit app that chains optional background removal, an
optional transformers model (Swin2SR super-resolution or depth estimation) and
final rotation/contrast adjustments on an uploaded image."""

import io

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageEnhance, ImageOps
from rembg import remove
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, pipeline

# Page Configuration (module level, so it runs before any other Streamlit call)
st.set_page_config(layout="wide", page_title="AI Image Lab")

# Labels of the "Choose AI Modification" radio; also used to dispatch the AI step.
MODE_NONE = "None"
MODE_SR_X2 = "AI Super-Resolution (2x)"
MODE_SR_X4 = "AI Super-Resolution (4x)"
MODE_DEPTH = "Depth Estimation"


# --- Caching AI Models ---
# One cached loader keyed by model id, so only the variant actually selected
# (2x or 4x) is ever loaded into memory.
@st.cache_resource
def _load_swin2sr(model_id):
    """Load and cache a Swin2SR processor/model pair for *model_id*."""
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
    return processor, model


def load_upscaler_x2():
    """Load the Swin2SR model for 2x upscale; returns ``(processor, model)``."""
    return _load_swin2sr("caidas/swin2SR-classical-sr-x2-64")


def load_upscaler_x4():
    """Load the Swin2SR model for 4x upscale; returns ``(processor, model)``.

    This model is heavier and takes longer to run.
    """
    return _load_swin2sr("caidas/swin2SR-classical-sr-x4-63")


@st.cache_resource
def load_depth_pipeline():
    """Load (and cache) a lightweight depth-estimation pipeline."""
    return pipeline(task="depth-estimation", model="vinvino02/glpn-nyu")


def ai_upscale(image, processor, model):
    """Run the Swin2SR super-resolution model on *image*.

    The model only accepts 3-channel input. For an image with an alpha
    channel (e.g. the RGBA output of background removal), the RGB part goes
    through the model. The alpha channel is resized with Lanczos resampling
    to the same output size and re-attached.

    Args:
        image: PIL image of any mode.
        processor: the image processor paired with *model*.
        model: a ``Swin2SRForImageSuperResolution`` instance.

    Returns:
        A PIL image of size ``image.size * model.config.upscale``. It is RGBA
        when the input had an "A" band, RGB otherwise.
    """
    alpha = image.getchannel("A") if "A" in image.getbands() else None
    rgb = image.convert("RGB")

    inputs = processor(rgb, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, 0, -1)  # (C, H, W) -> (H, W, C) for PIL

    # The processor pads the input on the bottom/right before inference, so
    # the reconstruction is larger than the true upscaled size; crop it off.
    scale = model.config.upscale
    width, height = rgb.size
    output = output[: height * scale, : width * scale]

    output = (output * 255.0).round().astype(np.uint8)
    result = Image.fromarray(output)

    if alpha is not None:
        result.putalpha(alpha.resize(result.size, Image.LANCZOS))
    return result


def convert_image_to_bytes(img):
    """Encode PIL image *img* as PNG and return the raw bytes."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _sidebar_controls():
    """Render the sidebar widgets; return (remove_bg, ai_mode, rotate_angle, contrast_val)."""
    st.sidebar.header("Processing Pipeline")

    # 1. Background
    st.sidebar.subheader("1. Cleanup")
    remove_bg = st.sidebar.checkbox("Remove Background (rembg)", value=False)

    # 2. AI Enhancements
    st.sidebar.subheader("2. AI Enhancements")
    ai_mode = st.sidebar.radio(
        "Choose AI Modification:",
        [MODE_NONE, MODE_SR_X2, MODE_SR_X4, MODE_DEPTH],
    )

    # 3. Geometry & Color
    st.sidebar.subheader("3. Final Adjustments")
    rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
    contrast_val = st.sidebar.slider("Contrast", 0.5, 1.5, 1.0, 0.1)

    return remove_bg, ai_mode, rotate_angle, contrast_val


def _run_upscaler(image, loader, spinner_text):
    """Upscale *image* with the model returned by *loader*.

    On any failure, show the error and return *image* unchanged.
    """
    try:
        processor, model = loader()
        with st.spinner(spinner_text):
            return ai_upscale(image, processor, model)
    except Exception as e:  # UI boundary: report and keep the pipeline going
        st.error(f"Error running Upscaler: {e}")
        return image


def _apply_ai_mode(image, ai_mode):
    """Apply the selected AI modification; return *image* unchanged for "None" or on error."""
    if ai_mode == MODE_SR_X2:
        st.info("Loading Swin2SR (2x) model... (Fast)")
        return _run_upscaler(image, load_upscaler_x2, "Upscaling (2x)...")

    if ai_mode == MODE_SR_X4:
        # Warn up front: 4x on CPU can be quite slow for large images.
        st.warning("Loading Swin2SR (4x) model... (This is computationally heavy!)")
        return _run_upscaler(image, load_upscaler_x4, "Upscaling (4x)... please wait")

    if ai_mode == MODE_DEPTH:
        st.info("Generating Depth Map...")
        try:
            depth_pipe = load_depth_pipeline()
            with st.spinner("Estimating depth..."):
                return depth_pipe(image)["depth"]
        except Exception as e:  # UI boundary: report and keep the pipeline going
            st.error(f"Error running Depth Model: {e}")

    return image


def _apply_adjustments(image, rotate_angle, contrast_val):
    """Apply rotation (canvas expanded to fit) and contrast; neutral values are skipped."""
    if rotate_angle != 0:
        image = image.rotate(rotate_angle, expand=True)
    if contrast_val != 1.0:
        image = ImageEnhance.Contrast(image).enhance(contrast_val)
    return image


def _render_results(original, result):
    """Show the original and processed images side by side plus a PNG download button."""
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original")
        st.image(original, use_container_width=True)
        st.caption(f"Size: {original.size}")
    with col2:
        st.subheader("Result")
        st.image(result, use_container_width=True)
        st.caption(f"Size: {result.size}")

    st.markdown("---")
    st.download_button(
        label="💾 Download Result",
        data=convert_image_to_bytes(result),
        file_name="ai_enhanced_image.png",
        mime="image/png",
    )


def main():
    """Streamlit entry point: build the UI and run the processing pipeline."""
    st.title("✨ AI Image Lab: Transformers Edition")
    st.markdown("Processing pipeline: **Background Removal** → **AI Modifiers** → **Geometry**")

    remove_bg, ai_mode, rotate_angle, contrast_val = _sidebar_controls()

    uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png", "webp"])
    if uploaded_file is None:
        return

    try:
        # Honour the EXIF orientation tag so camera/phone photos are not shown sideways.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError as e:  # also covers PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded file as an image: {e}")
        return

    processed_image = image.copy()

    # --- STEP 1: Background Removal (produces an RGBA image) ---
    if remove_bg:
        with st.spinner("Removing background..."):
            processed_image = remove(processed_image)

    # --- STEP 2: AI Enhancements ---
    processed_image = _apply_ai_mode(processed_image, ai_mode)

    # --- STEP 3: Geometry/Color ---
    processed_image = _apply_adjustments(processed_image, rotate_angle, contrast_val)

    # --- Display & Download ---
    _render_results(image, processed_image)


if __name__ == "__main__":
    main()