aryrk commited on
Commit
a0c04c5
·
1 Parent(s): a86af59

[feat] Check for model entry

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import subprocess
5
  from PIL import Image
6
  from huggingface_hub import hf_hub_download
 
7
 
8
  UPLOAD_DIR = "./uploaded_images"
9
  RESULTS_DIR = "./results"
@@ -18,6 +19,10 @@ MODEL_FILE = "310_net_G.pth"
18
 
19
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
20
 
 
 
 
 
21
  def reflection_removal(input_images):
22
  for input_image in input_images:
23
  file_path = os.path.join(UPLOAD_DIR, input_image.name)
 
4
  import subprocess
5
  from PIL import Image
6
  from huggingface_hub import hf_hub_download
7
+ from shutil import copyfile
8
 
9
  UPLOAD_DIR = "./uploaded_images"
10
  RESULTS_DIR = "./results"
 
19
 
20
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
21
 
22
+ expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
23
+ if not os.path.exists(expected_model_path):
24
+ copyfile(model_path, expected_model_path)
25
+
26
  def reflection_removal(input_images):
27
  for input_image in input_images:
28
  file_path = os.path.join(UPLOAD_DIR, input_image.name)