LogicGoInfotechSpaces commited on
Commit
f71e08c
·
1 Parent(s): 95b1715

Auto-download weights from LogicGoInfotechSpaces/Smile_Changer_pre_model if missing

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import logging
4
  from typing import Tuple, Dict
5
 
6
  import gradio as gr
 
7
  from PIL import Image
8
 
9
  from runners.simple_runner import SimpleRunner
@@ -22,9 +23,29 @@ logger = logging.getLogger("sfe-app")
22
  RUNNER: SimpleRunner | None = None
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def get_runner() -> SimpleRunner:
26
  global RUNNER
27
  if RUNNER is None:
 
28
  logger.info("Initializing SimpleRunner with %s", "pretrained_models/sfe_editor_light.pt")
29
  RUNNER = SimpleRunner(
30
  editor_ckpt_pth="pretrained_models/sfe_editor_light.pt",
 
4
  from typing import Tuple, Dict
5
 
6
  import gradio as gr
7
+ from huggingface_hub import snapshot_download
8
  from PIL import Image
9
 
10
  from runners.simple_runner import SimpleRunner
 
23
  RUNNER: SimpleRunner | None = None
24
 
25
 
26
+ def ensure_weights():
27
+ """Make sure pretrained weights exist locally; otherwise fetch from your HF model repo."""
28
+ need = [
29
+ "pretrained_models/sfe_editor_light.pt",
30
+ "pretrained_models/stylegan2-ffhq-config-f.pt",
31
+ ]
32
+ if all(os.path.exists(p) for p in need):
33
+ return
34
+
35
+ repo_id = "LogicGoInfotechSpaces/Smile_Changer_pre_model"
36
+ logger.info("Missing weights; downloading snapshot from %s", repo_id)
37
+ snapshot_download(
38
+ repo_id=repo_id,
39
+ local_dir="pretrained_models",
40
+ local_dir_use_symlinks=False,
41
+ allow_patterns=["**/*"],
42
+ )
43
+
44
+
45
  def get_runner() -> SimpleRunner:
46
  global RUNNER
47
  if RUNNER is None:
48
+ ensure_weights()
49
  logger.info("Initializing SimpleRunner with %s", "pretrained_models/sfe_editor_light.pt")
50
  RUNNER = SimpleRunner(
51
  editor_ckpt_pth="pretrained_models/sfe_editor_light.pt",