zhoudewei.666 committed on
Commit
f291567
·
1 Parent(s): 12c29f6

fix: adapt for ZeroGPU - move model loading to module level, add missing deps

Browse files

- Move pipeline init to module level so ZeroGPU handles CPU/GPU transfer
- Add torchvision, sentencepiece, protobuf to requirements
- Fix ImageSlider import from gradio_imageslider
- Remove model_dir/device/output_path UI controls (not applicable on Spaces)
- Set hardware: zero-gpu in README

Made-with: Cursor

Files changed (3) hide show
  1. README.md +1 -0
  2. app.py +53 -119
  3. requirements.txt +4 -1
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ hardware: zero-gpu
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -3,21 +3,16 @@ import os
3
  import threading
4
  import time
5
 
 
 
 
 
6
  try:
7
  import spaces
8
  _HAS_SPACES = True
9
  except ImportError:
10
  _HAS_SPACES = False
11
 
12
-
13
- def setup_debug():
14
- import debugpy
15
- rank = int(os.environ.get("RANK", 0))
16
- if rank == 0:
17
- debugpy.listen(5679)
18
- print("wait for debug")
19
- debugpy.wait_for_client()
20
-
21
  def calculate_dimensions(target_area: int, ratio: float):
22
  width = math.sqrt(target_area * ratio)
23
  height = width / ratio
@@ -84,69 +79,60 @@ _HF_LORA_REPO = "limuloo1999/RefineAnything"
84
  _HF_LORA_FILENAME = "Qwen-Image-Edit-2511-RefineAny.safetensors"
85
  _HF_LORA_ADAPTER = "refine_anything"
86
 
87
- _PIPELINE = None
88
- _PIPELINE_KEY = None
89
- _LORA_LOADED = False
90
  _LIGHTNING_LOADED = False
91
  _PIPELINE_LOCK = threading.Lock()
92
 
93
 
94
- def _ensure_hf_lora() -> str:
95
- """Download the LoRA weights from HuggingFace Hub and return the local path."""
96
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- return hf_hub_download(repo_id=_HF_LORA_REPO, filename=_HF_LORA_FILENAME)
 
 
 
 
 
 
99
 
 
 
100
 
101
- def _get_pipeline(model_dir: str, device: str, load_lightning_lora: bool):
102
- global _PIPELINE, _PIPELINE_KEY, _LORA_LOADED, _LIGHTNING_LOADED
103
- base_key = (model_dir, device)
104
 
105
- with _PIPELINE_LOCK:
106
- if _PIPELINE is None or _PIPELINE_KEY != base_key:
107
- import torch
108
- from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
109
-
110
- scheduler_config = {
111
- "base_image_seq_len": 256,
112
- "base_shift": math.log(3),
113
- "invert_sigmas": False,
114
- "max_image_seq_len": 8192,
115
- "max_shift": math.log(3),
116
- "num_train_timesteps": 1000,
117
- "shift": 1.0,
118
- "shift_terminal": None,
119
- "stochastic_sampling": False,
120
- "time_shift_type": "exponential",
121
- "use_beta_sigmas": False,
122
- "use_dynamic_shifting": True,
123
- "use_exponential_sigmas": False,
124
- "use_karras_sigmas": False,
125
- }
126
- scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
127
- pipe = QwenImageEditPlusPipeline.from_pretrained(
128
- model_dir,
129
- torch_dtype=torch.bfloat16,
130
- scheduler=scheduler,
131
- )
132
- pipe.to(device)
133
- pipe.set_progress_bar_config(disable=None)
134
 
135
- _PIPELINE = pipe
136
- _PIPELINE_KEY = base_key
137
- _LORA_LOADED = False
138
- _LIGHTNING_LOADED = False
139
 
140
- if not _LORA_LOADED:
141
- local_path = _ensure_hf_lora()
142
- lora_dir = os.path.dirname(local_path)
143
- weight_name = os.path.basename(local_path)
144
- _PIPELINE.load_lora_weights(lora_dir, weight_name=weight_name, adapter_name=_HF_LORA_ADAPTER)
145
- _LORA_LOADED = True
146
 
 
147
  if load_lightning_lora and not _LIGHTNING_LOADED:
148
- from huggingface_hub import hf_hub_download
149
-
150
  lightning_path = hf_hub_download(
151
  repo_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
152
  filename="Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors",
@@ -171,11 +157,7 @@ def _get_pipeline(model_dir: str, device: str, load_lightning_lora: bool):
171
  return _PIPELINE
172
 
173
 
174
- def build_app(
175
- *,
176
- default_model_dir: str,
177
- default_device: str,
178
- ):
179
  import base64
180
  import gradio as gr
181
  import inspect
@@ -636,14 +618,11 @@ def build_app(
636
  mode,
637
  spatial_source,
638
  spatial_bbox_margin,
639
- model_dir,
640
- device,
641
  seed,
642
  steps,
643
  true_cfg_scale,
644
  guidance_scale,
645
  negative_prompt,
646
- output_path,
647
  load_lightning_lora,
648
  paste_back_bbox,
649
  paste_back_mode,
@@ -653,8 +632,6 @@ def build_app(
653
  paste_blend_kernel,
654
  not_use_spatial_vae,
655
  ):
656
- import torch
657
-
658
  prompt = (prompt or "").strip()
659
  if not prompt:
660
  raise gr.Error("prompt 为空")
@@ -721,25 +698,13 @@ def build_app(
721
  if mode == "仅生成prompt":
722
  return (img_pil, img_pil), prompt_for_model, info, vis, "完成"
723
 
724
- model_dir = (model_dir or "").strip()
725
- if not model_dir:
726
- raise gr.Error("model_dir 不能为空")
727
- if os.path.exists(model_dir) and not os.path.isdir(model_dir):
728
- raise gr.Error(f"model_dir 不是目录: {model_dir}")
729
-
730
- device = (device or "").strip() or "cuda"
731
-
732
  seed = int(seed) if seed is not None and str(seed).strip() else 0
733
  steps = int(steps) if steps is not None and str(steps).strip() else 8
734
  true_cfg_scale = float(true_cfg_scale) if true_cfg_scale is not None and str(true_cfg_scale).strip() else 4.0
735
  guidance_scale = float(guidance_scale) if guidance_scale is not None and str(guidance_scale).strip() else 1.0
736
  negative_prompt = negative_prompt if negative_prompt is not None else " "
737
 
738
- pipe = _get_pipeline(
739
- model_dir=model_dir,
740
- device=device,
741
- load_lightning_lora=bool(load_lightning_lora),
742
- )
743
 
744
  img = img_for_model if image2_for_model is None else [img_for_model, image2_for_model]
745
  if spatial_mask_l is not None:
@@ -748,7 +713,7 @@ def build_app(
748
  img = img + [spatial_rgb]
749
  else:
750
  img = [img, spatial_rgb]
751
- gen = torch.Generator(device=device)
752
  gen.manual_seed(seed)
753
 
754
  t0 = time.time()
@@ -765,10 +730,6 @@ def build_app(
765
  num_images_per_prompt=1,
766
  not_use_spatial_vae=bool(not_use_spatial_vae),
767
  )
768
- # img[0].save('input0.png')
769
- # img[1].save('input1.png')
770
- # print(img[0].size, img[1].size, out.images[0].size)
771
- # out.images[0].save('./zdw_debug.png')
772
  except Exception as e:
773
  raise gr.Error(f"推理失败: {type(e).__name__}: {e}")
774
  dt = time.time() - t0
@@ -794,26 +755,7 @@ def build_app(
794
  else:
795
  out_img = out_img_crop
796
 
797
- output_path = (output_path or "").strip()
798
- saved = ""
799
- if output_path:
800
- if os.path.isdir(output_path):
801
- raise gr.Error(f"output_path 不能是目录: {output_path}")
802
- parent = os.path.dirname(os.path.abspath(output_path)) or "."
803
- if not os.path.isdir(parent):
804
- raise gr.Error(f"output_path 的父目录不存在: {parent}")
805
- out_img.save(output_path)
806
- saved = os.path.abspath(output_path)
807
- base, ext = os.path.splitext(saved)
808
- img_pil.save(base + "_input" + ext)
809
- if image2 is not None:
810
- image2.save(base + "_ref" + ext)
811
- if mask_pil_l is not None:
812
- mask_pil_l.save(base + "_mask.png")
813
-
814
  status = f"完成 ({dt:.2f}s)"
815
- if saved:
816
- status += f" 已保存: {saved}"
817
  return (img_pil, out_img), prompt_for_model, info, vis, status
818
 
819
  if _HAS_SPACES:
@@ -844,14 +786,11 @@ def build_app(
844
  spatial_source = gr.Radio(["mask", "bbox"], value="mask", label="空间提示来源(作为 mask 输入模型)")
845
  spatial_bbox_margin = gr.Number(label="spatial_bbox_margin", value=0, precision=0)
846
 
847
- model_dir = gr.Textbox(label="model_dir", value=default_model_dir)
848
- device = gr.Textbox(label="device", value=default_device)
849
  seed = gr.Number(label="seed", value=0, precision=0)
850
  steps = gr.Number(label="num_inference_steps", value=8, precision=0)
851
  true_cfg_scale = gr.Number(label="true_cfg_scale", value=4.0)
852
  guidance_scale = gr.Number(label="guidance_scale", value=1.0)
853
  negative_prompt = gr.Textbox(label="negative_prompt", value=" ")
854
- output_path = gr.Textbox(label="output_path(可空)", value="")
855
 
856
  load_lightning_lora = gr.Checkbox(label="加载加速 LoRA(Lightning)", value=False)
857
 
@@ -864,7 +803,8 @@ def build_app(
864
 
865
  not_use_spatial_vae = gr.Checkbox(label="不使用 spatial VAE(not_use_spatial_vae)", value=False)
866
 
867
- out_image = gr.ImageSlider(label="对比:原图 vs 输出", show_label=True)
 
868
  replaced_prompt = gr.Textbox(label="实际使用的 prompt", lines=4)
869
  bbox_info = gr.Textbox(label="区域信息", lines=2)
870
  image1_vis = gr.Image(label="model_input(vit384) + 区域可视化", type="pil")
@@ -879,14 +819,11 @@ def build_app(
879
  mode,
880
  spatial_source,
881
  spatial_bbox_margin,
882
- model_dir,
883
- device,
884
  seed,
885
  steps,
886
  true_cfg_scale,
887
  guidance_scale,
888
  negative_prompt,
889
- output_path,
890
  load_lightning_lora,
891
  paste_back_bbox,
892
  paste_back_mode,
@@ -897,15 +834,12 @@ def build_app(
897
  not_use_spatial_vae,
898
  ],
899
  outputs=[out_image, replaced_prompt, bbox_info, image1_vis, status],
900
- title="Qwen-Image-Edit GUI Tester",
901
  )
902
  return demo
903
 
904
 
905
- demo = build_app(
906
- default_model_dir=os.environ.get("MODEL_DIR", "Qwen/Qwen-Image-Edit-2511"),
907
- default_device="cuda",
908
- )
909
 
910
  if __name__ == "__main__":
911
  demo.launch(show_error=True)
 
3
  import threading
4
  import time
5
 
6
+ import torch
7
+ from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
8
+ from huggingface_hub import hf_hub_download
9
+
10
  try:
11
  import spaces
12
  _HAS_SPACES = True
13
  except ImportError:
14
  _HAS_SPACES = False
15
 
 
 
 
 
 
 
 
 
 
16
  def calculate_dimensions(target_area: int, ratio: float):
17
  width = math.sqrt(target_area * ratio)
18
  height = width / ratio
 
79
  _HF_LORA_FILENAME = "Qwen-Image-Edit-2511-RefineAny.safetensors"
80
  _HF_LORA_ADAPTER = "refine_anything"
81
 
 
 
 
82
  _LIGHTNING_LOADED = False
83
  _PIPELINE_LOCK = threading.Lock()
84
 
85
 
86
+ def _build_pipeline(model_dir: str):
87
+ """Build the pipeline at module level. ZeroGPU intercepts .to('cuda')
88
+ and keeps the model on CPU until a @spaces.GPU function runs."""
89
+ scheduler_config = {
90
+ "base_image_seq_len": 256,
91
+ "base_shift": math.log(3),
92
+ "invert_sigmas": False,
93
+ "max_image_seq_len": 8192,
94
+ "max_shift": math.log(3),
95
+ "num_train_timesteps": 1000,
96
+ "shift": 1.0,
97
+ "shift_terminal": None,
98
+ "stochastic_sampling": False,
99
+ "time_shift_type": "exponential",
100
+ "use_beta_sigmas": False,
101
+ "use_dynamic_shifting": True,
102
+ "use_exponential_sigmas": False,
103
+ "use_karras_sigmas": False,
104
+ }
105
+ scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
106
+ pipe = QwenImageEditPlusPipeline.from_pretrained(
107
+ model_dir,
108
+ torch_dtype=torch.bfloat16,
109
+ scheduler=scheduler,
110
+ )
111
+ pipe.set_progress_bar_config(disable=None)
112
 
113
+ local_path = hf_hub_download(
114
+ repo_id=_HF_LORA_REPO,
115
+ filename=_HF_LORA_FILENAME,
116
+ )
117
+ lora_dir = os.path.dirname(local_path)
118
+ weight_name = os.path.basename(local_path)
119
+ pipe.load_lora_weights(lora_dir, weight_name=weight_name, adapter_name=_HF_LORA_ADAPTER)
120
 
121
+ pipe.to("cuda")
122
+ return pipe
123
 
 
 
 
124
 
125
+ _DEFAULT_MODEL_DIR = os.environ.get("MODEL_DIR", "Qwen/Qwen-Image-Edit-2511")
126
+ print(f"[startup] Loading pipeline from {_DEFAULT_MODEL_DIR} ...")
127
+ _PIPELINE = _build_pipeline(_DEFAULT_MODEL_DIR)
128
+ print("[startup] Pipeline ready.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
 
 
130
 
131
+ def _get_pipeline(load_lightning_lora: bool):
132
+ global _LIGHTNING_LOADED
 
 
 
 
133
 
134
+ with _PIPELINE_LOCK:
135
  if load_lightning_lora and not _LIGHTNING_LOADED:
 
 
136
  lightning_path = hf_hub_download(
137
  repo_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
138
  filename="Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors",
 
157
  return _PIPELINE
158
 
159
 
160
+ def build_app():
 
 
 
 
161
  import base64
162
  import gradio as gr
163
  import inspect
 
618
  mode,
619
  spatial_source,
620
  spatial_bbox_margin,
 
 
621
  seed,
622
  steps,
623
  true_cfg_scale,
624
  guidance_scale,
625
  negative_prompt,
 
626
  load_lightning_lora,
627
  paste_back_bbox,
628
  paste_back_mode,
 
632
  paste_blend_kernel,
633
  not_use_spatial_vae,
634
  ):
 
 
635
  prompt = (prompt or "").strip()
636
  if not prompt:
637
  raise gr.Error("prompt 为空")
 
698
  if mode == "仅生成prompt":
699
  return (img_pil, img_pil), prompt_for_model, info, vis, "完成"
700
 
 
 
 
 
 
 
 
 
701
  seed = int(seed) if seed is not None and str(seed).strip() else 0
702
  steps = int(steps) if steps is not None and str(steps).strip() else 8
703
  true_cfg_scale = float(true_cfg_scale) if true_cfg_scale is not None and str(true_cfg_scale).strip() else 4.0
704
  guidance_scale = float(guidance_scale) if guidance_scale is not None and str(guidance_scale).strip() else 1.0
705
  negative_prompt = negative_prompt if negative_prompt is not None else " "
706
 
707
+ pipe = _get_pipeline(load_lightning_lora=bool(load_lightning_lora))
 
 
 
 
708
 
709
  img = img_for_model if image2_for_model is None else [img_for_model, image2_for_model]
710
  if spatial_mask_l is not None:
 
713
  img = img + [spatial_rgb]
714
  else:
715
  img = [img, spatial_rgb]
716
+ gen = torch.Generator(device="cuda")
717
  gen.manual_seed(seed)
718
 
719
  t0 = time.time()
 
730
  num_images_per_prompt=1,
731
  not_use_spatial_vae=bool(not_use_spatial_vae),
732
  )
 
 
 
 
733
  except Exception as e:
734
  raise gr.Error(f"推理失败: {type(e).__name__}: {e}")
735
  dt = time.time() - t0
 
755
  else:
756
  out_img = out_img_crop
757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
  status = f"完成 ({dt:.2f}s)"
 
 
759
  return (img_pil, out_img), prompt_for_model, info, vis, status
760
 
761
  if _HAS_SPACES:
 
786
  spatial_source = gr.Radio(["mask", "bbox"], value="mask", label="空间提示来源(作为 mask 输入模型)")
787
  spatial_bbox_margin = gr.Number(label="spatial_bbox_margin", value=0, precision=0)
788
 
 
 
789
  seed = gr.Number(label="seed", value=0, precision=0)
790
  steps = gr.Number(label="num_inference_steps", value=8, precision=0)
791
  true_cfg_scale = gr.Number(label="true_cfg_scale", value=4.0)
792
  guidance_scale = gr.Number(label="guidance_scale", value=1.0)
793
  negative_prompt = gr.Textbox(label="negative_prompt", value=" ")
 
794
 
795
  load_lightning_lora = gr.Checkbox(label="加载加速 LoRA(Lightning)", value=False)
796
 
 
803
 
804
  not_use_spatial_vae = gr.Checkbox(label="不使用 spatial VAE(not_use_spatial_vae)", value=False)
805
 
806
+ from gradio_imageslider import ImageSlider
807
+ out_image = ImageSlider(label="对比:原图 vs 输出", show_label=True)
808
  replaced_prompt = gr.Textbox(label="实际使用的 prompt", lines=4)
809
  bbox_info = gr.Textbox(label="区域信息", lines=2)
810
  image1_vis = gr.Image(label="model_input(vit384) + 区域可视化", type="pil")
 
819
  mode,
820
  spatial_source,
821
  spatial_bbox_margin,
 
 
822
  seed,
823
  steps,
824
  true_cfg_scale,
825
  guidance_scale,
826
  negative_prompt,
 
827
  load_lightning_lora,
828
  paste_back_bbox,
829
  paste_back_mode,
 
834
  not_use_spatial_vae,
835
  ],
836
  outputs=[out_image, replaced_prompt, bbox_info, image1_vis, status],
837
+ title="RefineAnything - Qwen Image Edit",
838
  )
839
  return demo
840
 
841
 
842
+ demo = build_app()
 
 
 
843
 
844
  if __name__ == "__main__":
845
  demo.launch(show_error=True)
requirements.txt CHANGED
@@ -5,5 +5,8 @@ attrs
5
  gradio_imageslider
6
  git+https://github.com/huggingface/diffusers
7
  torch
 
8
  accelerate
9
- safetensors
 
 
 
5
  gradio_imageslider
6
  git+https://github.com/huggingface/diffusers
7
  torch
8
+ torchvision
9
  accelerate
10
+ safetensors
11
+ sentencepiece
12
+ protobuf