aryrk commited on
Commit
9489bae
·
1 Parent(s): ff089f0

[update] only allowed 1 image at a time and add sample image

Browse files
app.py CHANGED
@@ -9,52 +9,68 @@ from shutil import copyfile
9
  UPLOAD_DIR = "./uploaded_images"
10
  RESULTS_DIR = "./results"
11
  CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
 
12
 
13
  os.makedirs(UPLOAD_DIR, exist_ok=True)
14
  os.makedirs(RESULTS_DIR, exist_ok=True)
15
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
 
16
 
17
  REPO_ID = "hasnafk/SingleImageReflectionRemoval"
18
  MODEL_FILE = "310_net_G.pth"
19
-
20
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
21
 
22
  expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
23
  if not os.path.exists(expected_model_path):
24
  copyfile(model_path, expected_model_path)
25
 
26
- def reflection_removal(input_images):
27
- for input_image in input_images:
28
- file_path = os.path.join(UPLOAD_DIR, os.path.basename(input_image.name))
29
- shutil.copy(input_image.name, file_path)
30
-
 
31
 
32
  cmd = [
33
- "python", "test.py",
34
  "--dataroot", UPLOAD_DIR,
35
  "--name", "SingleImageReflectionRemoval",
36
  "--model", "test", "--netG", "unet_256",
37
  "--direction", "AtoB", "--dataset_mode", "single",
38
  "--norm", "batch", "--epoch", "310",
39
- "--num_test", str(len(input_images)),
40
  "--gpu_ids", "-1"
41
  ]
42
  subprocess.run(cmd, check=True)
43
 
44
- output_images = []
45
  for root, _, files in os.walk(RESULTS_DIR):
46
  for file in files:
47
  if file.endswith("_fake.png"):
48
- output_images.append(Image.open(os.path.join(root, file)))
 
 
 
 
 
 
 
 
 
49
 
50
- return output_images
 
 
 
51
 
52
  iface = gr.Interface(
53
  fn=reflection_removal,
54
- inputs=gr.File(file_types=["image"], file_count="multiple"),
55
- outputs=gr.Gallery(label="Results after Reflection Removal"),
56
  title="Reflection Remover with Pix2Pix",
57
- description="Upload images to remove reflections using a Pix2Pix model."
 
 
 
58
  )
59
 
60
  if __name__ == "__main__":
 
9
  UPLOAD_DIR = "./uploaded_images"
10
  RESULTS_DIR = "./results"
11
  CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
12
+ SAMPLE_DIR = "./sample_images"
13
 
14
  os.makedirs(UPLOAD_DIR, exist_ok=True)
15
  os.makedirs(RESULTS_DIR, exist_ok=True)
16
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
17
+ os.makedirs(SAMPLE_DIR, exist_ok=True)
18
 
19
  REPO_ID = "hasnafk/SingleImageReflectionRemoval"
20
  MODEL_FILE = "310_net_G.pth"
 
21
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
22
 
23
  expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
24
  if not os.path.exists(expected_model_path):
25
  copyfile(model_path, expected_model_path)
26
 
27
+ def reflection_removal(input_image):
28
+ if not input_image.name.endswith((".jpg", ".jpeg", ".png")):
29
+ return "FIle is not supported (only .jpg, .jpeg, .png)."
30
+
31
+ file_path = os.path.join(UPLOAD_DIR, os.path.basename(input_image.name))
32
+ shutil.copy(input_image.name, file_path)
33
 
34
  cmd = [
35
+ "python", "test.py",
36
  "--dataroot", UPLOAD_DIR,
37
  "--name", "SingleImageReflectionRemoval",
38
  "--model", "test", "--netG", "unet_256",
39
  "--direction", "AtoB", "--dataset_mode", "single",
40
  "--norm", "batch", "--epoch", "310",
41
+ "--num_test", "1",
42
  "--gpu_ids", "-1"
43
  ]
44
  subprocess.run(cmd, check=True)
45
 
 
46
  for root, _, files in os.walk(RESULTS_DIR):
47
  for file in files:
48
  if file.endswith("_fake.png"):
49
+ result_path = os.path.join(root, file)
50
+ return Image.open(result_path)
51
+
52
+ return "No results found."
53
+
54
+ def use_sample_image(sample_image_name):
55
+ sample_image_path = os.path.join(SAMPLE_DIR, sample_image_name)
56
+ if not os.path.exists(sample_image_path):
57
+ return "Sample image not found."
58
+ return sample_image_path
59
 
60
+ sample_images = [
61
+ file for file in os.listdir(SAMPLE_DIR)
62
+ if file.endswith((".jpg", ".jpeg", ".png"))
63
+ ]
64
 
65
  iface = gr.Interface(
66
  fn=reflection_removal,
67
+ inputs=gr.File(file_types=["image"], label="Upload Image (JPG/PNG)", live=True),
68
+ outputs=gr.Image(type="pil", label="Processed Image"),
69
  title="Reflection Remover with Pix2Pix",
70
+ description="Upload a single image to remove reflections. No GIFs allowed.",
71
+ examples=[
72
+ os.path.join(SAMPLE_DIR, img) for img in sample_images
73
+ ],
74
  )
75
 
76
  if __name__ == "__main__":
sample_images/70-16.jpeg ADDED
sample_images/85-4.jpeg ADDED
sample_images/89-1.jpeg ADDED