aryrk
[feat] make session folders to separate per-user work
d3aec7f
raw
history blame
3.76 kB
import os
import shutil
import subprocess
import sys
import uuid
from shutil import copyfile

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Local working folders: model outputs, checkpoints, and example images.
RESULTS_DIR = "./results"
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"

os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Pretrained generator weights hosted on the Hugging Face Hub.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Download (or reuse) the weights inside the Hub cache rooted at CHECKPOINTS_DIR,
# then keep a plain copy directly in CHECKPOINTS_DIR. Presumably that is where
# test.py (run with --name SingleImageReflectionRemoval --epoch 310) loads the
# generator from - TODO confirm against test.py's checkpoint resolution.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)

# An existing copy is left untouched; only a missing one is created.
if not os.path.exists(expected_model_path):
    copyfile(model_path, expected_model_path)
def _find_result_image(stem):
    """Return the newest ``<stem>_fake.png`` found under RESULTS_DIR, or None."""
    target = f"{stem}_fake.png"
    matches = [
        os.path.join(root, target)
        for root, _, files in os.walk(RESULTS_DIR)
        if target in files
    ]
    return max(matches, key=os.path.getmtime, default=None)


def _discard_results(result_dir, stem):
    """Best-effort removal of every ``<stem>_*.png`` file inside *result_dir*."""
    prefix = f"{stem}_"
    for name in os.listdir(result_dir):
        if name.startswith(prefix) and name.endswith(".png"):
            try:
                os.remove(os.path.join(result_dir, name))
            except OSError:
                # A failed cleanup must not hide an otherwise successful result.
                pass


def reflection_removal(input_image, preprocess_type="resize_and_crop", session_id=None):
    """Remove reflections from one image by running ``test.py`` in a subprocess.

    The upload is copied into ``./sessions/<session_id>/uploads`` under a
    session-unique file name. ``test.py`` is run on that folder, and the
    generated ``<name>_fake.png`` is looked up under ``RESULTS_DIR``. The
    session folder is deleted afterwards in every case, including failures.

    Args:
        input_image: Path of the uploaded image (Gradio ``type="filepath"``).
        preprocess_type: Value passed to ``test.py --preprocess``; ``None`` or
            empty falls back to ``"resize_and_crop"``.
        session_id: Identifier of the per-request working folder.

    Returns:
        A PIL image with the model output, or a message string when the
        request cannot be processed or no output file was found.

    Raises:
        subprocess.CalledProcessError: if ``test.py`` exits with a non-zero status.
    """
    if not session_id:
        return "Session ID missing. Please try again."
    # Validate before touching the disk so a rejected request leaves no folder behind.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    upload_dir = os.path.join("./sessions", session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)

    # The output is named after the input file's stem and written to the shared
    # RESULTS_DIR. Using the session id as the stem keeps requests whose uploads
    # happen to share a file name from reading or overwriting each other's output.
    input_filename = session_id
    extension = os.path.splitext(input_image)[1]
    file_path = os.path.join(upload_dir, input_filename + extension)

    try:
        shutil.copy(input_image, file_path)

        cmd = [
            # Run test.py with the same interpreter/virtualenv as this app.
            sys.executable, "test.py",
            "--dataroot", upload_dir,
            "--name", "SingleImageReflectionRemoval",
            "--model", "test", "--netG", "unet_256",
            "--direction", "AtoB", "--dataset_mode", "single",
            "--norm", "batch", "--epoch", "310",
            "--num_test", "1",
            "--gpu_ids", "-1",
            # A cleared dropdown yields None, which subprocess cannot accept in argv.
            "--preprocess", preprocess_type or "resize_and_crop",
        ]
        subprocess.run(cmd, check=True)

        result_path = _find_result_image(input_filename)
        if result_path is None:
            return "No results found."

        # Image.open is lazy; copy() forces the pixel data into memory so the
        # file can be closed (and deleted) before the image is returned.
        with Image.open(result_path) as result:
            output_image = result.copy()
        # The stem is unique per request, so these files would never be reused;
        # drop them instead of letting RESULTS_DIR grow without bound.
        _discard_results(os.path.dirname(result_path), input_filename)
        return output_image
    finally:
        # Runs on success, on "No results found.", and when test.py fails.
        clear_session_files(session_id)
def clear_session_files(session_id):
    """Delete the ``./sessions/<session_id>`` working folder and its contents.

    A folder that does not exist is ignored instead of raising. A message is
    printed only when something was actually removed.

    Args:
        session_id: Identifier of the session whose folder should be removed.
    """
    session_dir = os.path.join("./sessions", session_id)
    try:
        # EAFP instead of exists()+rmtree(): another cleanup of the same folder
        # could remove it between the check and the delete.
        shutil.rmtree(session_dir)
    except FileNotFoundError:
        return
    print(f"Session {session_id} files cleared.")
def clear_action(session_id=None):
    """Remove a session's files and report the outcome as a status message.

    Args:
        session_id: Session whose folder should be deleted; a falsy value means
            there is nothing to clear.

    Returns:
        A short user-facing status string.
    """
    if not session_id:
        return "No session found to clear."
    clear_session_files(session_id)
    return "Upload cleared!"
# Example image file names found in SAMPLE_DIR. os.listdir returns entries in
# arbitrary order, so sort for a stable example gallery. Extension matching is
# case-insensitive so files such as "photo.JPG" are included too.
sample_images = sorted(
    file for file in os.listdir(SAMPLE_DIR)
    if file.lower().endswith((".jpg", ".jpeg", ".png"))
)

# Choices for the "Preprocessing Type" dropdown; the selected value is passed
# to test.py as its --preprocess argument.
preprocess_options = [
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
]
def session_interface():
    """Build the Gradio interface for the reflection-removal demo.

    Each request is processed with its own freshly generated session id, so
    concurrent users work in separate ``./sessions/<id>`` folders.

    Returns:
        gr.Interface: the configured interface; the caller is responsible for
        launching it.
    """
    return gr.Interface(
        # Generate the id per call. This Interface object is shared by every
        # visitor, so an id created once at build time would be shared as well.
        # One request's cleanup could then delete another request's upload.
        fn=lambda img, prep: reflection_removal(img, prep, str(uuid.uuid4())),
        inputs=[
            gr.Image(type="filepath", label="Upload Image (JPG/PNG)", interactive=True),
            gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
        ],
        outputs=gr.Image(label="Results after Reflection Removal"),
        # Reuse the module-level listing of SAMPLE_DIR instead of re-scanning a
        # separately hard-coded folder name.
        examples=[
            [os.path.join(SAMPLE_DIR, img), "resize_and_crop"]
            for img in sample_images
        ],
        title="Reflection Remover with Pix2Pix",
        description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
        allow_flagging="never",
        live=False,
    )
# Parent folder that holds one working directory per session.
os.makedirs("./sessions", exist_ok=True)

if __name__ == "__main__":
    demo = session_interface()
    demo.launch()