Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from transformers import AutoModelForImageSegmentation | |
| import torch | |
| from torchvision import transforms | |
| import spaces # Import ZeroGPU support | |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Let PyTorch use faster, reduced-precision float32 matmul kernels
# (e.g. TF32 on supporting GPUs) for the segmentation model.
torch.set_float32_matmul_precision("high")

# Load BiRefNet from the Hugging Face Hub. trust_remote_code lets
# transformers execute the model code shipped inside that Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)
# Preprocessing applied to every image before it is fed to BiRefNet:
#   1. resize to a fixed 1024x1024 input,
#   2. convert to a float CHW tensor with values in [0, 1],
#   3. normalize each RGB channel with the standard ImageNet mean/std.
# The 3-element mean/std means the input must have exactly 3 channels.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# On Hugging Face ZeroGPU Spaces a GPU is attached only while a function
# decorated with @spaces.GPU is running, so the inference call must live
# inside one. Outside ZeroGPU the decorator has no effect.
@spaces.GPU
def create_mask(image):
    """Predict a subject (foreground) mask for *image* using BiRefNet.

    Args:
        image: PIL image. It is converted to RGB first so that RGBA or
            grayscale inputs match the 3-channel normalization performed by
            ``transform_image``.

    Returns:
        A single-channel PIL image with the same size as *image*, where
        brighter pixels correspond to higher predicted foreground
        probability (sigmoid of the model output).
    """
    image_size = image.size
    # transform_image normalizes exactly 3 channels; RGBA/L inputs would fail.
    rgb_image = image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model output is indexed with [-1]: presumably a list of
        # predictions whose last entry is the final map — TODO confirm.
        # Results are moved to the CPU for PIL conversion.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # Assumes preds is (1, 1, H, W) so this yields a 2-D (H, W) map — TODO confirm.
    pred = preds[0].squeeze()
    mask_pil = transforms.ToPILImage()(pred)
    # The model works at 1024x1024; scale the mask back to the input size.
    return mask_pil.resize(image_size)
def apply_filter(image, mask=None, apply_to_subject=True):
    """Blend *image* 50/50 with a pink (magenta) tint.

    Args:
        image: PIL image; converted to RGBA before blending.
        mask: Optional mask (normally the PIL image from ``create_mask``).
            Pixels whose mask value is greater than 128 are treated as the
            subject. A PIL mask whose size differs from *image* is resized
            to match.
        apply_to_subject: When True and *mask* is given, only subject pixels
            are tinted. Otherwise every pixel of the image is tinted.

    Returns:
        A new RGBA PIL image.
    """
    image_np = np.array(image.convert("RGBA"))
    # (255, 0, 255) is magenta, used as the "pink" tint. The alpha value 128
    # is blended like the colour channels, so tinted pixels of an opaque
    # input end up with alpha 191 (slightly translucent).
    pink_color = np.array([255, 0, 255, 128])
    half_tint = pink_color * 0.5

    if apply_to_subject and mask is not None:
        # Keep mask and image pixel-aligned even if the caller's mask was
        # produced at another resolution.
        if isinstance(mask, Image.Image) and mask.size != image.size:
            mask = mask.resize(image.size)
        subject = np.asarray(mask) > 128
        # Vectorized replacement for a per-pixel Python loop: same per-pixel
        # arithmetic and uint8 truncation, run in native NumPy code.
        image_np[subject] = (image_np[subject] * 0.5 + half_tint).astype(np.uint8)
    else:
        # No mask requested/available: tint the whole image.
        image_np = (image_np * 0.5 + half_tint).astype(np.uint8)

    return Image.fromarray(image_np)
def process(input_image, subject_choice):
    """Gradio callback: tint the uploaded image pink.

    Args:
        input_image: Image array from the ``gr.Image(type="numpy")`` input,
            or None when nothing was uploaded.
        subject_choice: Radio selection, either "Subject Only" (tint only
            the BiRefNet-detected subject) or any other value (tint the
            whole image).

    Returns:
        The tinted PIL image produced by ``apply_filter``.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if input_image is None:
        raise gr.Error('Please upload an input image')

    original_image = Image.fromarray(input_image)
    wants_subject = subject_choice == "Subject Only"

    # Only run the (expensive) segmentation model when it is needed.
    mask = create_mask(original_image) if wants_subject else None

    return apply_filter(original_image, mask, wants_subject and mask is not None)
# Gradio UI: an input image and a radio choice on the left, the result on
# the right. The Run button sends both inputs to `process`.
with gr.Blocks() as block:
    with gr.Row():
        gr.Markdown("Apply Pink Filter Effect to Subject or Full Image")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image", height=640)
            subject_choice = gr.Radio(
                choices=["Subject Only", "Full Image"],
                value="Subject Only",
                label="Apply Pink Filter to:",
            )
            run_button = gr.Button("Run")
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Event listeners must be registered inside the Blocks context.
    run_button.click(
        fn=process,
        inputs=[input_image, subject_choice],
        outputs=output_image,
    )

block.launch()