# Author: ChBysk (Hugging Face Space)
# Commit: 78bfdbd "Update app.py" (verified)
import os
import tempfile
import zipfile

import gradio as gr
import numpy as np
import torch
from PIL import Image
# Model setup, done once at import so every request reuses the same weights.
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Fetch the pretrained "isnetis" entry point straight from the upstream
# GitHub repository via torch.hub, so no model source files need to be
# vendored next to this script.
# NOTE(review): trust_repo=True executes the repo's hubconf.py without a
# prompt, and no ":ref" is given, so whatever the default branch holds at
# startup is what runs. Consider pinning "SkyTNT/anime-segmentation:<tag>".
model = torch.hub.load(
    "SkyTNT/anime-segmentation",
    "isnetis",
    pretrained=True,
    trust_repo=True,
)
model.to(device)
# Inference only: switch layers such as dropout/batch-norm to eval behaviour.
model.eval()
def _predict_mask(img):
    """Run the module-level ``model`` on *img* and return an "L" mask of img's size.

    The image is resized to a 1024x1024 square before inference (aspect ratio
    is not preserved), and the predicted mask is resized back to the original
    dimensions with bilinear resampling.
    """
    original_size = img.size
    input_img = img.resize((1024, 1024))
    img_np = np.array(input_img).astype(np.float32) / 255.0
    # HWC image -> 1xCxHxW float tensor with values in [0, 1].
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_tensor)
    # NOTE(review): the hub model's exact output structure is not visible here.
    # The original code assumed a list whose first element is a
    # (batch, 1, H, W) tensor. Unwrapping any list/tuple nesting and flattening
    # the first batch item to (H, W) gives identical results in that case and
    # also tolerates a bare tensor or a tuple of lists.
    pred = outputs
    while isinstance(pred, (list, tuple)):
        pred = pred[0]
    first = pred[0]
    mask = first.reshape(first.shape[-2:]).float().cpu().numpy()
    # Clip before the uint8 cast: numpy's float->uint8 conversion does not
    # saturate, so values slightly outside [0, 1] would otherwise wrap.
    mask_u8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(mask_u8).resize(original_size, resample=Image.BILINEAR)


def _unique_stem(path, used):
    """Return path's filename stem, suffixed with _1, _2, ... if already in *used*; record it."""
    stem = os.path.splitext(os.path.basename(path))[0]
    candidate, counter = stem, 1
    while candidate in used:
        candidate = f"{stem}_{counter}"
        counter += 1
    used.add(candidate)
    return candidate


def _write_zip(zip_path, files):
    """Write *files* into a new archive at *zip_path*, stored under their basenames."""
    with zipfile.ZipFile(zip_path, "w") as zf:
        for f in files:
            zf.write(f, os.path.basename(f))


def process_images(file_paths):
    """Remove the background from each uploaded image and bundle the outputs.

    Args:
        file_paths: Uploaded files from ``gr.File(file_count="multiple")`` —
            path strings, or tempfile-like objects exposing ``.name``
            (older Gradio versions).

    Returns:
        ``(res_zip, mask_zip)``: paths to a ZIP of RGBA PNGs whose alpha is the
        predicted mask, and a ZIP of the grayscale mask PNGs. Returns
        ``(None, None)`` when nothing was uploaded.
    """
    if not file_paths:
        return None, None

    # A fresh working directory per call keeps concurrent requests from
    # overwriting (or downloading) each other's files. It also means results
    # of earlier runs never mix with the current batch.
    work_dir = tempfile.mkdtemp(prefix="anime_bg_")
    res_dir = os.path.join(work_dir, "results_out")
    mask_dir = os.path.join(work_dir, "masks_out")
    os.makedirs(res_dir)
    os.makedirs(mask_dir)

    res_list, mask_list = [], []
    used_stems = set()
    for item in file_paths:
        path = item if isinstance(item, (str, os.PathLike)) else item.name
        # Close the source file handle as soon as the pixels are decoded.
        with Image.open(path) as src:
            img = src.convert("RGB")
        mask_img = _predict_mask(img)

        # putalpha converts the RGB copy to RGBA using the mask as alpha.
        result_img = img.copy()
        result_img.putalpha(mask_img)

        # Deduplicate stems so "a.png" and "a.jpg" don't clobber each other
        # or produce duplicate entries inside the ZIP archives.
        stem = _unique_stem(path, used_stems)
        res_path = os.path.join(res_dir, f"{stem}_no_bg.png")
        mask_path = os.path.join(mask_dir, f"{stem}_mask.png")
        result_img.save(res_path)
        mask_img.save(mask_path)
        res_list.append(res_path)
        mask_list.append(mask_path)

    res_zip = os.path.join(work_dir, "transparent_results.zip")
    mask_zip = os.path.join(work_dir, "grayscale_masks.zip")
    _write_zip(res_zip, res_list)
    _write_zip(mask_zip, mask_list)
    return res_zip, mask_zip
# Build the Gradio interface: a multi-file upload, one button, and two ZIP
# download slots wired to process_images.
with gr.Blocks(title="Anime Background Remover (Bulk)") as demo:
    gr.Markdown("## 🏮 Bulk Anime Background Remover")
    # The model runs on a 1024x1024 resize; only the outputs are restored to
    # each image's original size, so describe it that way.
    gr.Markdown(
        "Upload multiple images. Masks are predicted at 1024×1024 and scaled back "
        "to each image's **original resolution**; results are provided as separate "
        "ZIP downloads."
    )
    input_files = gr.File(label="Select Images", file_count="multiple", file_types=["image"])
    btn = gr.Button("🚀 Start Processing", variant="primary")
    with gr.Row():
        out_res = gr.File(label="1. Download Transparent Images (ZIP)")
        out_mask = gr.File(label="2. Download Grayscale Masks (ZIP)")
    btn.click(process_images, inputs=input_files, outputs=[out_res, out_mask])

# Launch only when run as a script (`python app.py`), so importing this module
# (e.g. `gradio app.py` reload mode, tests) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()