LejobuildYT commited on
Commit
63b49f9
·
verified ·
1 Parent(s): 158b5f4

Update hy3dgen/text2image.py

Browse files
Files changed (1) hide show
  1. hy3dgen/text2image.py +55 -17
hy3dgen/text2image.py CHANGED
@@ -14,12 +14,22 @@
14
 
15
  import os
16
  import random
17
-
18
  import numpy as np
19
  import torch
20
  from diffusers import AutoPipelineForText2Image
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  def seed_everything(seed):
24
  random.seed(seed)
25
  np.random.seed(seed)
@@ -31,43 +41,70 @@ class HunyuanDiTPipeline:
31
  def __init__(
32
  self,
33
  model_path="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
34
- device='cuda'
35
  ):
 
 
 
 
36
  self.device = device
 
 
 
 
37
  self.pipe = AutoPipelineForText2Image.from_pretrained(
38
  model_path,
39
- torch_dtype=torch.float16,
40
  enable_pag=True,
41
  pag_applied_layers=["blocks.(16|17|18|19)"]
42
  ).to(device)
 
43
  self.pos_txt = ",白色背景,3D风格,最佳质量"
44
- self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态," \
45
- "残缺,多余的手指,变异的手,画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学," \
46
- "糟糕的比例,多余的肢体,克隆的脸,毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿," \
47
- "额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
 
 
48
 
 
49
  def compile(self):
50
- # accelarate hunyuan-dit transformer,first inference will cost long time
51
- torch.set_float32_matmul_precision('high')
52
- self.pipe.transformer = torch.compile(self.pipe.transformer, fullgraph=True)
53
- # self.pipe.vae.decode = torch.compile(self.pipe.vae.decode, fullgraph=True)
54
- generator = torch.Generator(device=self.pipe.device) # infer once for hot-start
55
- out_img = self.pipe(
56
- prompt='美少女战士',
57
- negative_prompt='模糊',
 
 
 
 
 
 
 
58
  num_inference_steps=25,
59
  pag_scale=1.3,
60
  width=1024,
61
  height=1024,
62
- generator=generator,
63
  return_dict=False
64
  )[0][0]
65
 
 
66
  @torch.no_grad()
67
  def __call__(self, prompt, seed=0):
68
  seed_everything(seed)
69
- generator = torch.Generator(device=self.pipe.device)
 
 
 
 
 
 
70
  generator = generator.manual_seed(int(seed))
 
71
  out_img = self.pipe(
72
  prompt=prompt[:60] + self.pos_txt,
73
  negative_prompt=self.neg_txt,
@@ -78,4 +115,5 @@ class HunyuanDiTPipeline:
78
  generator=generator,
79
  return_dict=False
80
  )[0][0]
 
81
  return out_img
 
14
 
15
  import os
16
  import random
 
17
  import numpy as np
18
  import torch
19
  from diffusers import AutoPipelineForText2Image
20
 
21
 
22
+ # -------------------- Device Auto-Selection --------------------
23
+ def get_auto_device():
24
+ if torch.cuda.is_available():
25
+ return "cuda"
26
+ elif torch.backends.mps.is_available(): # macOS GPU
27
+ return "mps"
28
+ else:
29
+ return "cpu"
30
+
31
+
32
+ # -------------------- Seed Helper --------------------
33
  def seed_everything(seed):
34
  random.seed(seed)
35
  np.random.seed(seed)
 
41
  def __init__(
42
  self,
43
  model_path="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
44
+ device=None
45
  ):
46
+ # Device auto-detect
47
+ if device is None:
48
+ device = get_auto_device()
49
+
50
  self.device = device
51
+
52
+ # float16 only works on CUDA / sometimes MPS, but not CPU
53
+ dtype = torch.float16 if device in ["cuda", "mps"] else torch.float32
54
+
55
  self.pipe = AutoPipelineForText2Image.from_pretrained(
56
  model_path,
57
+ torch_dtype=dtype,
58
  enable_pag=True,
59
  pag_applied_layers=["blocks.(16|17|18|19)"]
60
  ).to(device)
61
+
62
  self.pos_txt = ",白色背景,3D风格,最佳质量"
63
+ self.neg_txt = (
64
+ "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,"
65
+ "残缺,多余的手指,变异的手,画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,"
66
+ "糟糕的比例,多余的肢体,克隆的脸,毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,"
67
+ "额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
68
+ )
69
 
70
+ # -------------------- Compile (optional) --------------------
71
  def compile(self):
72
+ # accelerate transformer works only on CUDA; skip otherwise
73
+ if self.device == "cuda":
74
+ torch.set_float32_matmul_precision("high")
75
+ self.pipe.transformer = torch.compile(self.pipe.transformer, fullgraph=True)
76
+
77
+ # Safe generator creation (mps can't use device=...)
78
+ try:
79
+ generator = torch.Generator(device=self.pipe.device)
80
+ except:
81
+ generator = torch.Generator()
82
+
83
+ # warmup inference
84
+ _ = self.pipe(
85
+ prompt="美少女战士",
86
+ negative_prompt="模糊",
87
  num_inference_steps=25,
88
  pag_scale=1.3,
89
  width=1024,
90
  height=1024,
91
+ generator=generator.manual_seed(42),
92
  return_dict=False
93
  )[0][0]
94
 
95
+ # -------------------- Generate Image --------------------
96
  @torch.no_grad()
97
  def __call__(self, prompt, seed=0):
98
  seed_everything(seed)
99
+
100
+ # Generator fix (no device for mps)
101
+ try:
102
+ generator = torch.Generator(device=self.pipe.device)
103
+ except:
104
+ generator = torch.Generator()
105
+
106
  generator = generator.manual_seed(int(seed))
107
+
108
  out_img = self.pipe(
109
  prompt=prompt[:60] + self.pos_txt,
110
  negative_prompt=self.neg_txt,
 
115
  generator=generator,
116
  return_dict=False
117
  )[0][0]
118
+
119
  return out_img