import glob
import importlib.util
import logging
import os
import subprocess
import sys
import time

import cv2
import matplotlib.pyplot as plt
import psutil
import streamlit as st

logger = logging.getLogger(__name__)

# Fallback location used when basicsr cannot be located via importlib.
DEFAULT_DEGRADATIONS_PATH = '/usr/local/lib/python3.10/site-packages/basicsr/data/degradations.py'
# torchvision removed `functional_tensor`; basicsr still imports from it.
OLD_GRAYSCALE_IMPORT = 'from torchvision.transforms.functional_tensor import rgb_to_grayscale'
NEW_GRAYSCALE_IMPORT = 'from torchvision.transforms.functional import rgb_to_grayscale'


def _locate_degradations_py():
    """Return the path of basicsr's ``data/degradations.py``.

    Uses ``importlib.util.find_spec`` on the top-level ``basicsr`` package
    (which locates it without importing it, so the broken import inside it is
    not triggered). Falls back to DEFAULT_DEGRADATIONS_PATH when the package
    cannot be found.
    """
    try:
        spec = importlib.util.find_spec('basicsr')
    except (ImportError, ValueError):
        spec = None
    if spec is not None and spec.submodule_search_locations:
        package_dir = list(spec.submodule_search_locations)[0]
        return os.path.join(package_dir, 'data', 'degradations.py')
    return DEFAULT_DEGRADATIONS_PATH


def modify_degradations_py(file_path=None):
    """Patch basicsr's degradations.py to import rgb_to_grayscale from its new home.

    Best-effort: never raises on I/O problems, because this runs at import time
    on every Streamlit rerun and must not take the app down.

    Args:
        file_path: Path of the file to patch. When None, the installed basicsr
            package is located automatically.

    Returns:
        True if the file was rewritten, False if nothing needed changing or the
        file could not be read/written.
    """
    if file_path is None:
        file_path = _locate_degradations_py()

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except OSError as exc:
        logger.warning("Could not read %s: %s", file_path, exc)
        return False

    for i, line in enumerate(lines):
        if OLD_GRAYSCALE_IMPORT in line:
            # replace() keeps the line's indentation and line ending intact.
            lines[i] = line.replace(OLD_GRAYSCALE_IMPORT, NEW_GRAYSCALE_IMPORT)
            break
    else:
        # Already patched (or a basicsr version without the bad import):
        # skip the write so read-only installs keep working.
        return False

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)
    except OSError as exc:
        logger.warning("Could not patch %s: %s", file_path, exc)
        return False
    return True


# Call the function to modify the file. It must not use any Streamlit API,
# because st.set_page_config below has to be the first Streamlit command.
modify_degradations_py()

# Page configuration
st.set_page_config(
    page_title="Image Enhancer",
    page_icon="🖼️",
    layout="wide",
    initial_sidebar_state="expanded"
)


def display(img1, img2):
    """Render two images side by side (input on the left, output on the right).

    Errors are reported in the page via st.error rather than raised. The
    matplotlib figure is always closed so repeated reruns do not leak figures.
    """
    fig = None
    try:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(25, 10))
        for ax, img, title in ((ax1, img1, 'Input image'),
                               (ax2, img2, 'Enhanced output')):
            ax.set_title(title, fontsize=16)
            ax.axis('off')
            ax.imshow(img)
        st.pyplot(fig, use_container_width=True)
    except Exception as e:
        st.error(f"Error displaying images: {str(e)}")
    finally:
        if fig is not None:
            plt.close(fig)


def imread(img_path):
    """Load an image from disk and return it as an RGB array.

    Returns None (after reporting via st.error) when the path does not exist,
    cv2 cannot decode it, or any other error occurs.
    """
    try:
        if not os.path.exists(img_path):
            st.error(f"Image not found: {img_path}")
            return None
        img = cv2.imread(img_path)
        if img is None:
            st.error(f"Failed to load image: {img_path}")
            return None
        # cv2 loads channels in BGR order; matplotlib expects RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as e:
        st.error(f"Error reading image: {str(e)}")
        return None


def cleanup_directories():
    """Delete every file (recursively) under inputs/upload and results.

    Directories themselves are left in place. Failures are reported as a
    sidebar warning instead of being raised.
    """
    directories = ['inputs/upload', 'results']
    for directory in directories:
        if not os.path.exists(directory):
            continue
        try:
            for file in glob.glob(os.path.join(directory, '**/*'), recursive=True):
                if os.path.isfile(file):
                    os.remove(file)
        except Exception as e:
            st.sidebar.warning(f"Cleanup warning: {str(e)}")


def run_shell_commands():
    """Run GFPGAN inference on inputs/upload, writing into results/.

    Returns:
        True on success. False if the process failed, timed out after 300 s,
        or could not be started; the reason is shown via st.error.
    """
    try:
        directories = [
            "results/cropped_faces",
            "results/restored_faces",
            "results/restored_imgs",
            "results/cmp",
        ]
        for directory in directories:
            os.makedirs(directory, exist_ok=True)

        # Use the interpreter running this app (not whatever `python` is on
        # PATH) and avoid an intermediate shell, so the timeout applies to
        # the inference process itself.
        command = [
            sys.executable, "inference_gfpgan.py",
            "-i", "inputs/upload",
            "-o", "results",
            "-v", "1.3",
            "-s", "2",
            "--bg_upsampler", "realesrgan",
        ]
        process = subprocess.run(command, capture_output=True, text=True, timeout=300)

        if process.returncode != 0:
            st.error(f"Enhancement failed: {process.stderr}")
            return False
        return True
    except subprocess.TimeoutExpired:
        st.error("Process timed out after 5 minutes")
        return False
    except Exception as e:
        st.error(f"Process error: {str(e)}")
        return False


def check_memory():
    """Show this process's resident set size (in MB) in the sidebar."""
    memory = psutil.Process().memory_info().rss / 1024 / 1024
    st.sidebar.text(f"Memory usage: {memory:.2f} MB")


def main():
    """Render the page, enhance an uploaded image and show cropped vs restored faces."""
    st.title('Image Enhancer')
    st.write('Upload an image to enhance its quality')
    st.write('Please wait 30-40 seconds after uploading 🙂')

    # Sidebar information
    st.sidebar.title("App Info")
    st.sidebar.write("This app enhances image quality using AI")
    check_memory()

    # Remove leftovers from a previous run before starting.
    cleanup_directories()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Created before the try block so the finally clause can always clear them.
    status = st.empty()
    progress_bar = st.progress(0)
    try:
        status.info("Starting process...")

        upload_dir = os.path.join('inputs', 'upload')
        os.makedirs(upload_dir, exist_ok=True)

        # The file name is client-supplied: basename() strips any directory
        # components so the file cannot be written outside upload_dir.
        file_path = os.path.join(upload_dir, os.path.basename(uploaded_file.name))
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())

        progress_bar.progress(25)
        status.info("File uploaded successfully. Processing image...")

        if run_shell_commands():
            progress_bar.progress(75)
            status.success("Processing complete!")

            # Sorted so the i-th cropped face is paired with the i-th restored face.
            cropped_faces = sorted(glob.glob(os.path.join('results/cropped_faces', '*')))
            restored_faces = sorted(glob.glob(os.path.join('results/restored_faces', '*')))

            if not cropped_faces or not restored_faces:
                st.warning("No faces detected in the image.")
            else:
                for cropped_path, restored_path in zip(cropped_faces, restored_faces):
                    img_input = imread(cropped_path)
                    img_output = imread(restored_path)
                    if img_input is not None and img_output is not None:
                        display(img_input, img_output)
            progress_bar.progress(100)
        else:
            status.error("Failed to process image.")
    except Exception as e:
        st.error(f"Error: {str(e)}")
    finally:
        # Results have already been rendered into the page, so the files can go.
        cleanup_directories()
        # Leave the final status visible briefly before clearing it.
        time.sleep(2)
        status.empty()
        progress_bar.empty()
        check_memory()


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        st.error(f"Application error: {str(e)}")