BiliSakura committed on
Commit
418ab4a
·
verified ·
1 Parent(s): 4c58b48

Add files using upload-large-folder tool

Browse files
__pycache__/run_jit_diffusers_inference.cpython-312.pyc ADDED
Binary file (7.05 kB). View file
 
demo.png CHANGED

Git LFS Details

  • SHA256: f5fdbd0300f895de7642229d1294aff74facd75c0bb4c4a01efa8c75b14b6fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB

Git LFS Details

  • SHA256: d595ae2a4d665119949ee1c3930fd7a24befd51d4d4b1932a1a4c7e9e180f899
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB
demo_images/jit_h32_test_inference.png CHANGED

Git LFS Details

  • SHA256: f5fdbd0300f895de7642229d1294aff74facd75c0bb4c4a01efa8c75b14b6fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB

Git LFS Details

  • SHA256: d595ae2a4d665119949ee1c3930fd7a24befd51d4d4b1932a1a4c7e9e180f899
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB
jit_diffusers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (448 Bytes). View file
 
jit_diffusers/__pycache__/modeling_jit_backbone.cpython-312.pyc ADDED
Binary file (22.9 kB). View file
 
jit_diffusers/__pycache__/modeling_jit_transformer_2d.cpython-312.pyc ADDED
Binary file (9.74 kB). View file
 
jit_diffusers/__pycache__/modeling_jit_utils.cpython-312.pyc ADDED
Binary file (10 kB). View file
 
jit_diffusers/__pycache__/pipeline_jit.cpython-312.pyc ADDED
Binary file (9.6 kB). View file
 
jit_diffusers/__pycache__/scheduling_jit.cpython-312.pyc ADDED
Binary file (3.33 kB). View file
 
jit_diffusers/pipeline_jit.py CHANGED
@@ -13,6 +13,21 @@ from .modeling_jit_transformer_2d import JiTTransformer2DModel
13
  from .scheduling_jit import JiTScheduler
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @dataclass
17
  class JiTPipelineOutput(BaseOutput):
18
  images: List["PIL.Image.Image"] | np.ndarray | torch.Tensor
@@ -51,10 +66,10 @@ class JiTPipeline(DiffusionPipeline):
51
  self,
52
  class_labels: int | List[int] | torch.Tensor,
53
  num_inference_steps: int = 50,
54
- guidance_scale: float = 2.9,
55
  guidance_interval_min: float = 0.1,
56
  guidance_interval_max: float = 1.0,
57
- noise_scale: float = 2.0,
58
  t_eps: float = 5e-2,
59
  sampling_method: str | None = None,
60
  generator: torch.Generator | List[torch.Generator] | None = None,
@@ -81,6 +96,12 @@ class JiTPipeline(DiffusionPipeline):
81
  latent_size = int(self.transformer.config.sample_size)
82
  latent_channels = int(getattr(self.transformer.config, "in_channels", 3))
83
  num_classes = int(self.transformer.config.num_class_embeds)
 
 
 
 
 
 
84
 
85
  class_labels = class_labels.clamp(0, num_classes - 1)
86
  class_null = torch.full_like(class_labels, num_classes)
@@ -102,7 +123,9 @@ class JiTPipeline(DiffusionPipeline):
102
  x_uncond = self.transformer(sample=z_value, timestep=t.flatten(), class_labels=class_null).sample
103
  v_uncond = (x_uncond - z_value) / (1.0 - t).clamp_min(t_eps)
104
 
105
- interval_mask = (t < guidance_interval_max) & (t > guidance_interval_min)
 
 
106
  scale = torch.where(
107
  interval_mask,
108
  torch.tensor(guidance_scale, device=self._execution_device, dtype=latents.dtype),
 
13
  from .scheduling_jit import JiTScheduler
14
 
15
 
16
+ RECOMMENDED_CFG_BY_MODEL = {
17
+ "JiT-B/16": 3.0,
18
+ "JiT-L/16": 2.4,
19
+ "JiT-H/16": 2.2,
20
+ "JiT-B/32": 3.0,
21
+ "JiT-L/32": 2.5,
22
+ "JiT-H/32": 2.3,
23
+ }
24
+
25
+ RECOMMENDED_NOISE_BY_RESOLUTION = {
26
+ 256: 1.0,
27
+ 512: 2.0,
28
+ }
29
+
30
+
31
  @dataclass
32
  class JiTPipelineOutput(BaseOutput):
33
  images: List["PIL.Image.Image"] | np.ndarray | torch.Tensor
 
66
  self,
67
  class_labels: int | List[int] | torch.Tensor,
68
  num_inference_steps: int = 50,
69
+ guidance_scale: float | None = None,
70
  guidance_interval_min: float = 0.1,
71
  guidance_interval_max: float = 1.0,
72
+ noise_scale: float | None = None,
73
  t_eps: float = 5e-2,
74
  sampling_method: str | None = None,
75
  generator: torch.Generator | List[torch.Generator] | None = None,
 
96
  latent_size = int(self.transformer.config.sample_size)
97
  latent_channels = int(getattr(self.transformer.config, "in_channels", 3))
98
  num_classes = int(self.transformer.config.num_class_embeds)
99
+ model_type = str(getattr(self.transformer.config, "model_type", ""))
100
+
101
+ if guidance_scale is None:
102
+ guidance_scale = RECOMMENDED_CFG_BY_MODEL.get(model_type, 2.9)
103
+ if noise_scale is None:
104
+ noise_scale = RECOMMENDED_NOISE_BY_RESOLUTION.get(latent_size, 1.0)
105
 
106
  class_labels = class_labels.clamp(0, num_classes - 1)
107
  class_null = torch.full_like(class_labels, num_classes)
 
123
  x_uncond = self.transformer(sample=z_value, timestep=t.flatten(), class_labels=class_null).sample
124
  v_uncond = (x_uncond - z_value) / (1.0 - t).clamp_min(t_eps)
125
 
126
+ interval_mask = t < guidance_interval_max
127
+ if guidance_interval_min != 0.0:
128
+ interval_mask = interval_mask & (t > guidance_interval_min)
129
  scale = torch.where(
130
  interval_mask,
131
  torch.tensor(guidance_scale, device=self._execution_device, dtype=latents.dtype),
run_jit_diffusers_inference.py CHANGED
@@ -11,6 +11,21 @@ if str(SCRIPT_DIR) not in sys.path:
11
  from jit_diffusers import JiTPipeline
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def parse_args() -> argparse.Namespace:
15
  parser = argparse.ArgumentParser(description="Run single-image JiT diffusers inference.")
16
  parser.add_argument("--model_path", type=str, required=True, help="Path to converted diffusers model directory.")
@@ -18,10 +33,20 @@ def parse_args() -> argparse.Namespace:
18
  parser.add_argument("--class_label", type=int, default=207, help="ImageNet class id for conditional generation.")
19
  parser.add_argument("--seed", type=int, default=42, help="Random seed.")
20
  parser.add_argument("--steps", type=int, default=50, help="Number of ODE sampling steps.")
21
- parser.add_argument("--cfg", type=float, default=2.9, help="Classifier-free guidance scale.")
 
 
 
 
 
22
  parser.add_argument("--interval_min", type=float, default=0.1, help="CFG interval min.")
23
  parser.add_argument("--interval_max", type=float, default=1.0, help="CFG interval max.")
24
- parser.add_argument("--noise_scale", type=float, default=2.0, help="Initial Gaussian noise scale.")
 
 
 
 
 
25
  parser.add_argument("--t_eps", type=float, default=5e-2, help="Small epsilon for timestep denominator.")
26
  parser.add_argument(
27
  "--device",
@@ -59,6 +84,14 @@ def resolve_dtype(name: str, device: torch.device) -> torch.dtype:
59
  return torch.float32
60
 
61
 
 
 
 
 
 
 
 
 
62
  def main() -> None:
63
  args = parse_args()
64
  device = resolve_device(args.device)
@@ -70,15 +103,16 @@ def main() -> None:
70
  pipe.transformer = pipe.transformer.to(device=device, dtype=dtype)
71
  pipe.transformer.eval()
72
  sampling_method = None if args.solver == "scheduler" else args.solver
 
73
 
74
  generator = torch.Generator(device=device).manual_seed(args.seed)
75
  output = pipe(
76
  class_labels=[args.class_label],
77
  num_inference_steps=args.steps,
78
- guidance_scale=args.cfg,
79
  guidance_interval_min=args.interval_min,
80
  guidance_interval_max=args.interval_max,
81
- noise_scale=args.noise_scale,
82
  t_eps=args.t_eps,
83
  sampling_method=sampling_method,
84
  generator=generator,
@@ -89,6 +123,7 @@ def main() -> None:
89
  output_path = Path(args.output_path)
90
  output_path.parent.mkdir(parents=True, exist_ok=True)
91
  image.save(output_path)
 
92
  print(f"Saved image to: {output_path}")
93
 
94
 
 
11
  from jit_diffusers import JiTPipeline
12
 
13
 
14
+ RECOMMENDED_CFG_BY_MODEL = {
15
+ "JiT-B/16": 3.0,
16
+ "JiT-L/16": 2.4,
17
+ "JiT-H/16": 2.2,
18
+ "JiT-B/32": 3.0,
19
+ "JiT-L/32": 2.5,
20
+ "JiT-H/32": 2.3,
21
+ }
22
+
23
+ RECOMMENDED_NOISE_BY_RESOLUTION = {
24
+ 256: 1.0,
25
+ 512: 2.0,
26
+ }
27
+
28
+
29
  def parse_args() -> argparse.Namespace:
30
  parser = argparse.ArgumentParser(description="Run single-image JiT diffusers inference.")
31
  parser.add_argument("--model_path", type=str, required=True, help="Path to converted diffusers model directory.")
 
33
  parser.add_argument("--class_label", type=int, default=207, help="ImageNet class id for conditional generation.")
34
  parser.add_argument("--seed", type=int, default=42, help="Random seed.")
35
  parser.add_argument("--steps", type=int, default=50, help="Number of ODE sampling steps.")
36
+ parser.add_argument(
37
+ "--cfg",
38
+ type=float,
39
+ default=None,
40
+ help="Classifier-free guidance scale. Defaults to paper recommendation for the loaded model.",
41
+ )
42
  parser.add_argument("--interval_min", type=float, default=0.1, help="CFG interval min.")
43
  parser.add_argument("--interval_max", type=float, default=1.0, help="CFG interval max.")
44
+ parser.add_argument(
45
+ "--noise_scale",
46
+ type=float,
47
+ default=None,
48
+ help="Initial Gaussian noise scale. Defaults to paper recommendation for the loaded resolution.",
49
+ )
50
  parser.add_argument("--t_eps", type=float, default=5e-2, help="Small epsilon for timestep denominator.")
51
  parser.add_argument(
52
  "--device",
 
84
  return torch.float32
85
 
86
 
87
+ def resolve_generation_defaults(pipe: JiTPipeline, cfg: float | None, noise_scale: float | None) -> tuple[float, float]:
88
+ model_type = str(getattr(pipe.transformer.config, "model_type", ""))
89
+ sample_size = int(getattr(pipe.transformer.config, "sample_size", 256))
90
+ resolved_cfg = cfg if cfg is not None else RECOMMENDED_CFG_BY_MODEL.get(model_type, 2.9)
91
+ resolved_noise_scale = noise_scale if noise_scale is not None else RECOMMENDED_NOISE_BY_RESOLUTION.get(sample_size, 1.0)
92
+ return resolved_cfg, resolved_noise_scale
93
+
94
+
95
  def main() -> None:
96
  args = parse_args()
97
  device = resolve_device(args.device)
 
103
  pipe.transformer = pipe.transformer.to(device=device, dtype=dtype)
104
  pipe.transformer.eval()
105
  sampling_method = None if args.solver == "scheduler" else args.solver
106
+ cfg, noise_scale = resolve_generation_defaults(pipe, args.cfg, args.noise_scale)
107
 
108
  generator = torch.Generator(device=device).manual_seed(args.seed)
109
  output = pipe(
110
  class_labels=[args.class_label],
111
  num_inference_steps=args.steps,
112
+ guidance_scale=cfg,
113
  guidance_interval_min=args.interval_min,
114
  guidance_interval_max=args.interval_max,
115
+ noise_scale=noise_scale,
116
  t_eps=args.t_eps,
117
  sampling_method=sampling_method,
118
  generator=generator,
 
123
  output_path = Path(args.output_path)
124
  output_path.parent.mkdir(parents=True, exist_ok=True)
125
  image.save(output_path)
126
+ print(f"Used sampling hyperparameters: cfg={cfg}, noise_scale={noise_scale}")
127
  print(f"Saved image to: {output_path}")
128
 
129