sayshara commited on
Commit
769c814
·
1 Parent(s): 8a3a4b3

Updated wrapper to use huggingface models.

Browse files
Files changed (1) hide show
  1. diffqrcoder_wrapper.py +29 -25
diffqrcoder_wrapper.py CHANGED
@@ -3,29 +3,29 @@ import torch
3
  from diffusers import ControlNetModel, DDIMScheduler
4
  from PIL import Image
5
  import qrcode
 
6
 
7
  from diffqrcoder import DiffQRCoderPipeline
8
 
9
  # ---- Defaults taken from run_diffqrcoder.py ---- #
 
 
10
  CONTROLNET_CKPT = "monster-labs/control_v1p_sd15_qrcode_monster"
11
- # Original used a direct file URL; we can keep that:
12
- PIPE_CKPT = (
13
- "https://huggingface.co/fp16-guy/Cetus-Mix_Whalefall_fp16_cleaned/"
14
- "resolve/main/cetusMix_Whalefall2_fp16.safetensors"
15
- )
16
- # You can also upload that file to the Space and use a local path.
17
 
18
- DEVICE = "cuda" # ZeroGPU will give us a CUDA device during @spaces.GPU calls
 
 
 
 
19
 
20
- # Cache
21
  _controlnet = None
22
  _pipe = None
23
 
24
 
25
  def _make_qr_image(
26
  data: str,
27
- box_size: int = 20, # aligns with qrcode_module_size default
28
- border: int = 4, # typical QR quiet zone in modules
29
  ) -> Image.Image:
30
  qr = qrcode.QRCode(
31
  version=None,
@@ -42,32 +42,42 @@ def _make_qr_image(
42
  def load_pipeline():
43
  """
44
  Lazily load ControlNet + DiffQRCoderPipeline.
45
- Mirrors run_diffqrcoder.py, but only once.
 
 
 
46
  """
47
  global _controlnet, _pipe
48
 
49
  if _pipe is not None:
50
  return _pipe
51
 
52
- # 1. ControlNet
53
  if _controlnet is None:
54
  _controlnet = ControlNetModel.from_pretrained(
55
  CONTROLNET_CKPT,
56
  torch_dtype=torch.float16,
57
  )
58
 
59
- # 2. DiffQRCoderPipeline (from single safetensors file)
 
 
 
 
 
 
60
  pipe = DiffQRCoderPipeline.from_single_file(
61
- PIPE_CKPT,
62
  controlnet=_controlnet,
63
  torch_dtype=torch.float16,
64
- use_auth_token=True, # uses the Space's HF token
65
  )
66
 
67
- # 3. Use DDIM scheduler as in original script
68
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
69
 
70
- # Don't call .to("cuda") yet; do it inside the @spaces.GPU function
 
71
  _pipe = pipe
72
  return _pipe
73
 
@@ -86,20 +96,14 @@ def generate_qr_art(
86
  srmpgd_lr: float = 0.1,
87
  seed: int = 1,
88
  ) -> Image.Image:
89
- """
90
- Directly mirrors the call at the bottom of run_diffqrcoder.py,
91
- but takes the QR content + prompt as arguments and returns a PIL image.
92
- """
93
  pipe = load_pipeline()
94
 
95
- # ZeroGPU will ensure DEVICE exists as "cuda" when we call this
96
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
97
 
98
- # Create QR image in-memory instead of loading from disk
99
  qrcode_img = _make_qr_image(
100
  data=url_or_text,
101
- box_size=qrcode_module_size, # roughly aligned
102
- border=4, # module-based border; padding param handles extra
103
  )
104
 
105
  pipe = pipe.to(DEVICE)
 
3
  from diffusers import ControlNetModel, DDIMScheduler
4
  from PIL import Image
5
  import qrcode
6
+ from huggingface_hub import hf_hub_download
7
 
8
  from diffqrcoder import DiffQRCoderPipeline
9
 
10
  # ---- Defaults taken from run_diffqrcoder.py ---- #
11
+
12
+ # ControlNet is already a proper HF repo id:
13
  CONTROLNET_CKPT = "monster-labs/control_v1p_sd15_qrcode_monster"
 
 
 
 
 
 
14
 
15
+ # For the base SD model (Cetus-Mix), use repo + filename rather than raw URL
16
+ PIPE_REPO_ID = "fp16-guy/Cetus-Mix_Whalefall_fp16_cleaned"
17
+ PIPE_FILENAME = "cetusMix_Whalefall2_fp16.safetensors"
18
+
19
+ DEVICE = "cuda"
20
 
 
21
  _controlnet = None
22
  _pipe = None
23
 
24
 
25
  def _make_qr_image(
26
  data: str,
27
+ box_size: int = 20,
28
+ border: int = 4,
29
  ) -> Image.Image:
30
  qr = qrcode.QRCode(
31
  version=None,
 
42
  def load_pipeline():
43
  """
44
  Lazily load ControlNet + DiffQRCoderPipeline.
45
+
46
+ This now:
47
+ - pulls the ControlNet weights from HF by repo id
48
+ - downloads the Cetus-Mix safetensors file via hf_hub_download
49
  """
50
  global _controlnet, _pipe
51
 
52
  if _pipe is not None:
53
  return _pipe
54
 
55
+ # 1. Load ControlNet from its HF repo
56
  if _controlnet is None:
57
  _controlnet = ControlNetModel.from_pretrained(
58
  CONTROLNET_CKPT,
59
  torch_dtype=torch.float16,
60
  )
61
 
62
+ # 2. Download the base model safetensors from Hugging Face Hub
63
+ ckpt_path = hf_hub_download(
64
+ repo_id=PIPE_REPO_ID,
65
+ filename=PIPE_FILENAME,
66
+ )
67
+
68
+ # 3. Build DiffQRCoder pipeline from the local safetensors file
69
  pipe = DiffQRCoderPipeline.from_single_file(
70
+ ckpt_path,
71
  controlnet=_controlnet,
72
  torch_dtype=torch.float16,
73
+ use_auth_token=True, # uses the Space's HF token
74
  )
75
 
76
+ # 4. Same scheduler as in run_diffqrcoder.py
77
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
78
 
79
+ # NOTE: we call .to("cuda") inside the @spaces.GPU function so that
80
+ # it only happens when a GPU is actually attached.
81
  _pipe = pipe
82
  return _pipe
83
 
 
96
  srmpgd_lr: float = 0.1,
97
  seed: int = 1,
98
  ) -> Image.Image:
 
 
 
 
99
  pipe = load_pipeline()
100
 
 
101
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
102
 
 
103
  qrcode_img = _make_qr_image(
104
  data=url_or_text,
105
+ box_size=qrcode_module_size,
106
+ border=4,
107
  )
108
 
109
  pipe = pipe.to(DEVICE)