aryrk commited on
Commit
d3aec7f
·
1 Parent(s): 2d422fa

[feat] make session folder to seperate per user work

Browse files
Files changed (1) hide show
  1. app.py +51 -51
app.py CHANGED
@@ -2,16 +2,15 @@ import gradio as gr
2
  import shutil
3
  import os
4
  import subprocess
 
5
  from PIL import Image
6
  from huggingface_hub import hf_hub_download
7
  from shutil import copyfile
8
 
9
- UPLOAD_DIR = "./uploaded_images"
10
  RESULTS_DIR = "./results"
11
  CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
12
  SAMPLE_DIR = "./sample_images"
13
 
14
- os.makedirs(UPLOAD_DIR, exist_ok=True)
15
  os.makedirs(RESULTS_DIR, exist_ok=True)
16
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
17
  os.makedirs(SAMPLE_DIR, exist_ok=True)
@@ -24,63 +23,57 @@ expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
24
  if not os.path.exists(expected_model_path):
25
  copyfile(model_path, expected_model_path)
26
 
27
- def count_files(directory):
28
- return sum([len(files) for _, _, files in os.walk(directory)])
 
29
 
30
- def clear_uploaded_images():
31
- for filename in os.listdir(UPLOAD_DIR):
32
- file_path = os.path.join(UPLOAD_DIR, filename)
33
- if os.path.isfile(file_path):
34
- os.remove(file_path)
35
 
36
- def clear_results():
37
- for root, _, files in os.walk(RESULTS_DIR):
38
- for file in files:
39
- file_path = os.path.join(root, file)
40
- os.remove(file_path)
41
-
42
- def reflection_removal(input_image, preprocess_type="resize_and_crop"):
43
- if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
44
- return "File is not supported (only .jpg, .jpeg, .png)."
45
-
46
- file_path = os.path.join(UPLOAD_DIR, os.path.basename(input_image))
47
  shutil.copy(input_image, file_path)
48
-
49
  input_filename = os.path.splitext(os.path.basename(file_path))[0]
50
-
51
  cmd = [
52
  "python", "test.py",
53
- "--dataroot", UPLOAD_DIR,
54
  "--name", "SingleImageReflectionRemoval",
55
  "--model", "test", "--netG", "unet_256",
56
  "--direction", "AtoB", "--dataset_mode", "single",
57
  "--norm", "batch", "--epoch", "310",
58
- "--num_test", str(count_files(UPLOAD_DIR)),
59
  "--gpu_ids", "-1",
60
  "--preprocess", preprocess_type
61
  ]
62
  subprocess.run(cmd, check=True)
63
-
64
  for root, _, files in os.walk(RESULTS_DIR):
65
  for file in files:
66
  if file.startswith(input_filename) and file.endswith("_fake.png"):
67
  result_path = os.path.join(root, file)
68
  output_image = Image.open(result_path)
69
-
70
- clear_uploaded_images()
71
- clear_results()
72
-
73
  return output_image
74
 
75
- clear_uploaded_images()
76
- clear_results()
77
  return "No results found."
78
 
79
- def use_sample_image(sample_image_name):
80
- sample_image_path = os.path.join(SAMPLE_DIR, sample_image_name)
81
- if not os.path.exists(sample_image_path):
82
- return "Sample image not found."
83
- return sample_image_path
 
 
 
 
 
 
84
 
85
  sample_images = [
86
  file for file in os.listdir(SAMPLE_DIR)
@@ -91,20 +84,27 @@ preprocess_options = [
91
  "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"
92
  ]
93
 
94
- iface = gr.Interface(
95
- fn=reflection_removal,
96
- inputs=[
97
- gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
98
- gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop")
99
- ],
100
- outputs=gr.Image(label="Result after Reflection Removal"),
101
- examples=[
102
- [os.path.join("sample_images", img), "resize_and_crop"]
103
- for img in os.listdir("sample_images") if img.endswith((".jpg", ".jpeg", ".png"))
104
- ],
105
- title="Reflection Remover with Pix2Pix",
106
- description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below."
107
- )
 
 
 
 
 
 
 
108
 
109
  if __name__ == "__main__":
110
- iface.launch()
 
2
  import shutil
3
  import os
4
  import subprocess
5
+ import uuid
6
  from PIL import Image
7
  from huggingface_hub import hf_hub_download
8
  from shutil import copyfile
9
 
 
10
  RESULTS_DIR = "./results"
11
  CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
12
  SAMPLE_DIR = "./sample_images"
13
 
 
14
  os.makedirs(RESULTS_DIR, exist_ok=True)
15
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
16
  os.makedirs(SAMPLE_DIR, exist_ok=True)
 
23
  if not os.path.exists(expected_model_path):
24
  copyfile(model_path, expected_model_path)
25
 
26
+ def reflection_removal(input_image, preprocess_type="resize_and_crop", session_id=None):
27
+ if not session_id:
28
+ return "Session ID missing. Please try again."
29
 
30
+ upload_dir = os.path.join("./sessions", session_id, "uploads")
31
+ os.makedirs(upload_dir, exist_ok=True)
 
 
 
32
 
33
+ if not input_image or not os.path.exists(input_image):
34
+ return "No image was provided or file was cleared. Please upload a valid image."
35
+
36
+ file_path = os.path.join(upload_dir, os.path.basename(input_image))
 
 
 
 
 
 
 
37
  shutil.copy(input_image, file_path)
38
+
39
  input_filename = os.path.splitext(os.path.basename(file_path))[0]
40
+
41
  cmd = [
42
  "python", "test.py",
43
+ "--dataroot", upload_dir,
44
  "--name", "SingleImageReflectionRemoval",
45
  "--model", "test", "--netG", "unet_256",
46
  "--direction", "AtoB", "--dataset_mode", "single",
47
  "--norm", "batch", "--epoch", "310",
48
+ "--num_test", "1",
49
  "--gpu_ids", "-1",
50
  "--preprocess", preprocess_type
51
  ]
52
  subprocess.run(cmd, check=True)
53
+
54
  for root, _, files in os.walk(RESULTS_DIR):
55
  for file in files:
56
  if file.startswith(input_filename) and file.endswith("_fake.png"):
57
  result_path = os.path.join(root, file)
58
  output_image = Image.open(result_path)
59
+
60
+ clear_session_files(session_id)
 
 
61
  return output_image
62
 
63
+ clear_session_files(session_id)
 
64
  return "No results found."
65
 
66
+ def clear_session_files(session_id):
67
+ session_dir = os.path.join("./sessions", session_id)
68
+ if os.path.exists(session_dir):
69
+ shutil.rmtree(session_dir)
70
+ print(f"Session {session_id} files cleared.")
71
+
72
+ def clear_action(session_id=None):
73
+ if session_id:
74
+ clear_session_files(session_id)
75
+ return "Upload cleared!"
76
+ return "No session found to clear."
77
 
78
  sample_images = [
79
  file for file in os.listdir(SAMPLE_DIR)
 
84
  "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"
85
  ]
86
 
87
+ def session_interface():
88
+ session_id = str(uuid.uuid4())
89
+
90
+ return gr.Interface(
91
+ fn=lambda img, prep: reflection_removal(img, prep, session_id),
92
+ inputs=[
93
+ gr.Image(type="filepath", label="Upload Image (JPG/PNG)", interactive=True),
94
+ gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop")
95
+ ],
96
+ outputs=gr.Image(label="Results after Reflection Removal"),
97
+ examples=[
98
+ [os.path.join("sample_images", img), "resize_and_crop"]
99
+ for img in os.listdir("sample_images") if img.endswith((".jpg", ".jpeg", ".png"))
100
+ ],
101
+ title="Reflection Remover with Pix2Pix",
102
+ description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
103
+ allow_flagging="never",
104
+ live=False
105
+ )
106
+
107
+ os.makedirs("./sessions", exist_ok=True)
108
 
109
  if __name__ == "__main__":
110
+ session_interface().launch()