wsntxxn commited on
Commit
9ea2c42
·
verified ·
1 Parent(s): d64def4

fix noise_scheduler error (#1)

Browse files

- change noise_scheduler to local (feed6af65c2c468f57b5e09cf793b69b77142319)

app.py CHANGED
@@ -1,3 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from pathlib import Path
3
 
@@ -56,12 +72,14 @@ ckpt_path = hf_hub_download(
56
  hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
57
  hubert_model = hubert_models[0].eval().to(device)
58
 
 
59
  scheduler = getattr(
60
  noise_schedulers,
61
  config["noise_scheduler"]["type"],
62
  ).from_pretrained(
63
- config["noise_scheduler"]["name"],
64
  subfolder="scheduler",
 
65
  )
66
 
67
  @torch.no_grad()
 
1
+ import subprocess
2
+ import sys
3
+
4
+ def force_fix_huggingface_hub():
5
+ subprocess.check_call([
6
+ sys.executable,
7
+ "-m",
8
+ "pip",
9
+ "install",
10
+ "--no-deps",
11
+ "--force-reinstall",
12
+ "huggingface-hub==0.30.2"
13
+ ])
14
+
15
+ force_fix_huggingface_hub()
16
+
17
  import gradio as gr
18
  from pathlib import Path
19
 
 
72
  hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
73
  hubert_model = hubert_models[0].eval().to(device)
74
 
75
+ scheduler_path = "./configs/models--stabilityai--stable-diffusion-2-1"
76
  scheduler = getattr(
77
  noise_schedulers,
78
  config["noise_scheduler"]["type"],
79
  ).from_pretrained(
80
+ scheduler_path,
81
  subfolder="scheduler",
82
+ local_files_only=True
83
  )
84
 
85
  @torch.no_grad()
configs/models--stabilityai--stable-diffusion-2-1/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.8.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "v_prediction",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "trained_betas": null
14
+ }