File size: 4,279 Bytes
94f7c5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b49f9
 
 
 
 
 
 
 
 
 
 
94f7c5c
 
 
 
 
 
 
 
 
 
 
63b49f9
94f7c5c
63b49f9
 
 
 
94f7c5c
63b49f9
 
 
 
94f7c5c
 
63b49f9
94f7c5c
 
 
63b49f9
94f7c5c
63b49f9
 
 
 
 
 
94f7c5c
63b49f9
94f7c5c
63b49f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f7c5c
 
 
 
63b49f9
94f7c5c
 
 
63b49f9
94f7c5c
 
 
63b49f9
 
 
 
 
 
 
94f7c5c
63b49f9
94f7c5c
 
 
 
 
 
 
 
 
 
63b49f9
94f7c5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import random
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image


# -------------------- Device Auto-Selection --------------------
def get_auto_device():
    """Pick the best available torch device string.

    Returns:
        ``"cuda"`` when an NVIDIA GPU is available, ``"mps"`` on Apple
        Silicon with the Metal backend, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds may not ship the MPS backend at all, in which
    # case `torch.backends.mps` does not exist — guard the attribute
    # access so this helper never raises AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


# -------------------- Seed Helper --------------------
def seed_everything(seed):
    """Seed every RNG the pipeline touches for reproducible generation.

    Seeds Python's ``random``, NumPy, and torch, and exports the seed
    via the ``PL_GLOBAL_SEED`` environment variable (read by
    PyTorch Lightning-style tooling).
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    os.environ["PL_GLOBAL_SEED"] = str(seed)


class HunyuanDiTPipeline:
    """Thin wrapper around diffusers' HunyuanDiT text-to-image pipeline.

    Loads the model with PAG (Perturbed-Attention Guidance) enabled,
    auto-selects device and dtype, and appends fixed Chinese prompt
    suffixes tuned for white-background, 3D-style renders.
    """

    def __init__(
        self,
        model_path="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
        device=None
    ):
        """Load the pipeline onto the chosen device.

        Args:
            model_path: HuggingFace repo id or local path of the model.
            device: ``"cuda"`` / ``"mps"`` / ``"cpu"``; auto-detected
                when ``None``.
        """
        if device is None:
            device = get_auto_device()

        self.device = device

        # float16 halves memory on GPU backends (CUDA and, usually, MPS);
        # CPU kernels generally require float32.
        dtype = torch.float16 if device in ["cuda", "mps"] else torch.float32

        self.pipe = AutoPipelineForText2Image.from_pretrained(
            model_path,
            torch_dtype=dtype,
            enable_pag=True,
            pag_applied_layers=["blocks.(16|17|18|19)"]
        ).to(device)

        # Fixed prompt suffixes (Chinese, as the model was tuned on
        # Chinese text). Positive: "white background, 3D style, best
        # quality". Negative: standard quality/anatomy blacklist.
        self.pos_txt = ",白色背景,3D风格,最佳质量"
        self.neg_txt = (
            "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,"
            "残缺,多余的手指,变异的手,画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,"
            "糟糕的比例,多余的肢体,克隆的脸,毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,"
            "额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
        )

    # -------------------- Internal helpers --------------------
    def _make_generator(self):
        """Return a ``torch.Generator`` bound to the pipeline's device,
        falling back to a default (CPU) generator when the backend
        rejects the ``device`` kwarg (seen on some MPS builds).
        """
        try:
            return torch.Generator(device=self.pipe.device)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt /
            # SystemExit still propagate.
            return torch.Generator()

    # -------------------- Compile (optional) --------------------
    def compile(self):
        """Optionally torch.compile the transformer and run one warmup
        inference so later calls don't pay compilation latency.
        """
        # torch.compile of the transformer is only worthwhile (and
        # reliable) on CUDA; skip on other devices.
        if self.device == "cuda":
            torch.set_float32_matmul_precision("high")
            self.pipe.transformer = torch.compile(self.pipe.transformer, fullgraph=True)

        generator = self._make_generator()

        # Warmup inference triggers compilation / kernel autotuning.
        _ = self.pipe(
            prompt="美少女战士",
            negative_prompt="模糊",
            num_inference_steps=25,
            pag_scale=1.3,
            width=1024,
            height=1024,
            generator=generator.manual_seed(42),
            return_dict=False
        )[0][0]

    # -------------------- Generate Image --------------------
    @torch.no_grad()
    def __call__(self, prompt, seed=0):
        """Generate one 1024x1024 image for *prompt*.

        Args:
            prompt: free-text prompt; truncated to 60 characters before
                the fixed positive suffix is appended.
            seed: RNG seed for reproducible output.

        Returns:
            The generated PIL image (first image of the pipeline output).
        """
        seed_everything(seed)

        generator = self._make_generator().manual_seed(int(seed))

        out_img = self.pipe(
            prompt=prompt[:60] + self.pos_txt,
            negative_prompt=self.neg_txt,
            num_inference_steps=25,
            pag_scale=1.3,
            width=1024,
            height=1024,
            generator=generator,
            return_dict=False
        )[0][0]

        return out_img