AlbeRota commited on
Commit
af9d144
·
1 Parent(s): a741286
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -23,11 +23,11 @@ def _ensure_weights():
23
  """Download weights to cache if not present."""
24
  weights_path = hf_hub_download(
25
  repo_id="AlbeRota/UnReflectAnything",
26
- filename="weights/full_model_weights.pt"
27
  )
28
  config_path = hf_hub_download(
29
  repo_id="AlbeRota/UnReflectAnything",
30
- filename="configs/pretrained_config.yaml"
31
  )
32
  return weights_path, config_path
33
 
@@ -80,8 +80,8 @@ def _get_model(device: str):
80
  pretrained=True,
81
  # weights_path=os.path.join(os.path.dirname(__file__), ".cache", "weights", "full_model_weights.pt"),
82
  # config_path=os.path.join(os.path.dirname(__file__), ".cache", "configs", "pretrained_config.yaml"),
83
- weights_path=weights_path,
84
- config_path=config_path,
85
  device=device,
86
  verbose=False,
87
  )
@@ -133,7 +133,8 @@ def build_ui():
133
  with gr.Column(scale=0, min_width=100):
134
  LOGO_PATH = hf_hub_download(
135
  repo_id="AlbeRota/UnReflectAnything",
136
- filename="assets/logo.png"
 
137
  )
138
  if Path(LOGO_PATH).is_file():
139
  gr.Image(
 
23
  """Download weights to cache if not present."""
24
  weights_path = hf_hub_download(
25
  repo_id="AlbeRota/UnReflectAnything",
26
+ filename="weights/full_model_weights.pt",
27
  )
28
  config_path = hf_hub_download(
29
  repo_id="AlbeRota/UnReflectAnything",
30
+ filename="configs/pretrained_config.yaml",
31
  )
32
  return weights_path, config_path
33
 
 
80
  pretrained=True,
81
  # weights_path=os.path.join(os.path.dirname(__file__), ".cache", "weights", "full_model_weights.pt"),
82
  # config_path=os.path.join(os.path.dirname(__file__), ".cache", "configs", "pretrained_config.yaml"),
83
+ weights_path="weights/full_model_weights.pt",
84
+ config_path="configs/pretrained_config.yaml",
85
  device=device,
86
  verbose=False,
87
  )
 
133
  with gr.Column(scale=0, min_width=100):
134
  LOGO_PATH = hf_hub_download(
135
  repo_id="AlbeRota/UnReflectAnything",
136
+ filename="assets/logo.png",
137
+ cache_dir="assets"
138
  )
139
  if Path(LOGO_PATH).is_file():
140
  gr.Image(