File size: 11,996 Bytes
244baf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel
from contextlib import asynccontextmanager
import numpy as np
from PIL import Image
import io
import uuid
from typing import List, Union

import axengine
import torch

from transformers import CLIPTokenizer, PreTrainedTokenizer
import time
import argparse

import os
import traceback
from diffusers import DPMSolverMultistepScheduler
# Logging configuration: toggle debug output and timestamp prefixing.
DEBUG_MODE = True
LOG_TIMESTAMP = True

def debug_log(msg):
    """Print *msg* tagged with ``[DEBUG]`` (plus an optional timestamp) when DEBUG_MODE is on."""
    if not DEBUG_MODE:
        return
    prefix = ""
    if LOG_TIMESTAMP:
        prefix = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] "
    print(f"{prefix}[DEBUG] {msg}")
        
# Service configuration: filesystem locations of every model artifact
# loaded at startup by DiffusionModels.load_models().
MODEL_PATHS = {
    "tokenizer": "./models/tokenizer",
    "text_encoder": "./models/text_encoder/sd15_text_encoder_sim.axmodel",
    "unet": "./models/unet.axmodel",
    "vae": "./models/vae_decoder.axmodel",
    "time_embeddings": "./models/time_input_dpmpp_20steps.npy"  # still the 20-step data; only every other step (10 total) is used
}

class DiffusionModels:
    """Process-wide container owning every inference artifact used by the service."""

    def __init__(self):
        # Nothing is loaded until load_models() runs at service startup.
        self.models_loaded = False
        self.tokenizer = None
        self.text_encoder = None
        self.unet = None
        self.vae = None
        self.time_embeddings = None

    def load_models(self):
        """Load the tokenizer, the three axengine sessions, and the timestep embeddings into memory."""
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS["tokenizer"])
            # The three neural-network stages all run on the same NPU runtime.
            for attr in ("text_encoder", "unet", "vae"):
                setattr(self, attr, axengine.InferenceSession(MODEL_PATHS[attr]))

            # The .npy file holds 20 timestep embeddings; keeping every other
            # entry (indices 0, 2, ..., 18) yields the 10-step schedule the
            # sampler actually runs.
            all_steps = np.load(MODEL_PATHS["time_embeddings"])
            self.time_embeddings = all_steps[::2]
            debug_log(f"时间嵌入已从20步采样为10步,形状: {self.time_embeddings.shape}")

            self.models_loaded = True
            print("所有模型已成功加载到内存")
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            raise

# Single process-wide model holder shared by all request handlers.
diffusion_models = DiffusionModels()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load all models before serving, clean up on shutdown."""
    # Load the models when the service starts.
    diffusion_models.load_models()
    yield
    # Release resources when the service shuts down.
    # (Add whatever cleanup logic axengine requires here.)

app = FastAPI(lifespan=lifespan)

class GenerationRequest(BaseModel):
    """Request payload for the /generate endpoint.

    The number of inference steps (10) and the guidance scale (5.4) are
    fixed server-side, so clients only supply prompts and an optional seed.
    """

    positive_prompt: str
    negative_prompt: str = ""
    # None means "derive a seed from the current time" inside the generator.
    # BUG FIX: was annotated `int = None`, a default that violates its own
    # annotation; the field is genuinely optional, so the type must say so.
    seed: Union[int, None] = None

@app.post("/generate")
async def generate_image(request: GenerationRequest):
    """Generate a PNG from the request prompts and return it as image/png bytes.

    Steps and guidance scale are fixed server-side (10 steps, CFG 5.4).
    Returns 400 for invalid input and 500 (with a correlation id) for
    internal generation failures.
    """
    # BUG FIX: input validation used to raise ValueError inside the try
    # block, which the catch-all turned into an opaque 500. A too-long
    # prompt is a client error and should be a 400.
    if len(request.positive_prompt) > 1000:
        raise HTTPException(status_code=400, detail="提示词过长")

    try:
        # Run the fixed-parameter inference pipeline.
        image = generate_diffusion_image(
            positive_prompt=request.positive_prompt,
            negative_prompt=request.negative_prompt,
            num_steps=10,        # fixed at 10 steps
            guidance_scale=5.4,  # fixed at CFG=5.4
            seed=request.seed
        )

        # Serialize the PIL image to an in-memory PNG byte stream.
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')

        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    except HTTPException:
        # BUG FIX: deliberate HTTP errors must pass through unchanged
        # instead of being swallowed and re-wrapped as a 500.
        raise
    except Exception as e:
        # Tag the failure with a correlation id so server logs can be
        # matched to the opaque message returned to the client.
        error_id = str(uuid.uuid4())
        print(f"Error [{error_id}]: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"生成失败,错误ID:{error_id}"
        ) from e
        
        
        
def get_embeds(prompt, negative_prompt):
    """Encode the positive and negative prompts into CLIP text embeddings.

    Each prompt is tokenized to CLIP's fixed 77-token context window and
    run through the preloaded text-encoder session; output shapes are
    validated before returning.

    Parameters:
        prompt: positive prompt text.
        negative_prompt: negative prompt text (may be empty).

    Returns:
        Tuple (neg_embeds, pos_embeds) of numpy arrays, each (1, 77, 768).

    Raises:
        RuntimeError: if tokenization/encoding fails or an embedding has
            an unexpected shape.
    """
    try:
        debug_log(f"开始处理提示词: {prompt}")
        start_time = time.time()


        def process_prompt(prompt_text):
            # Pad/truncate to the fixed 77-token CLIP context length.
            inputs = diffusion_models.tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

            # The axengine session expects int32 token ids.
            outputs = diffusion_models.text_encoder.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Shape validation: the UNet expects (1, 77, 768) conditioning.
        if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        # BUG FIX: the original called exit(1) here, which would kill the
        # entire FastAPI worker process on a single failing request. Log
        # and re-raise instead so the endpoint can map it to an HTTP 500.
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        raise RuntimeError(f"获取嵌入失败: {str(e)}") from e


def generate_diffusion_image(
    positive_prompt: str,
    negative_prompt: str,
    num_steps: int = 10,       # fixed default: 10 steps
    guidance_scale: float = 5.4, # fixed default: CFG 5.4
    seed: int = None
) -> Image.Image:
    """
    Optimized diffusion image generation (fixed 10-step inference, CFG=5.4).

    Args:
        positive_prompt (str): positive prompt text; must be non-empty.
        negative_prompt (str): negative prompt text.
        num_steps (int): inference step count (forced to 10 internally).
        guidance_scale (float): classifier-free guidance scale (forced to 5.4).
        seed (int): optional random seed; derived from the clock when None.

    Returns:
        PIL.Image.Image: the generated image (also written to ./api.png).

    Raises:
        RuntimeError: on any failure; the original cause (including
            ValueError for invalid input) is chained and re-wrapped.
    """
    try:
        if not positive_prompt:
            raise ValueError("正向提示词不能为空")

        # Force the tuned fixed parameters regardless of what was passed in:
        # the preloaded time embeddings cover exactly 10 steps.
        num_steps = 10
        guidance_scale = 5.4

        debug_log(f"开始生成流程 (固定参数: 10步, CFG=5.4)...")
        start_time = time.time()

        # =====================================================================
        # 1. Initialization: derive a seed from the clock when none is given,
        #    and seed both torch and numpy for reproducibility.
        # =====================================================================
        seed = seed if seed is not None else int(time.time() * 1000) % 0xFFFFFFFF
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"初始随机种子: {seed}")

        # =====================================================================
        # 2. Text encoding (embeddings keep shape [1, 77, 768]).
        # =====================================================================
        embed_start = time.time()
        neg_emb, pos_emb = get_embeds(
            positive_prompt,
            negative_prompt,
        )
        debug_log(f"文本编码完成 | 耗时: {time.time()-embed_start:.2f}s")

        # =====================================================================
        # 3. Latent initialization (fixed shape [1, 4, 60, 40]).
        # =====================================================================
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(num_steps)  # set to 10 steps

        latents_shape = (1, 4, 60, 40)
        latent = torch.randn(latents_shape, generator=torch.Generator().manual_seed(seed))
        latent = latent * scheduler.init_noise_sigma
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} sigma:{scheduler.init_noise_sigma:.3f}")

        # =====================================================================
        # 4. Time embeddings (10-step data preprocessed at startup).
        # =====================================================================
        if len(diffusion_models.time_embeddings) != num_steps:
            raise ValueError(f"时间嵌入步数不匹配: 需要{num_steps}步 当前{len(diffusion_models.time_embeddings)}步")
        time_steps = diffusion_models.time_embeddings
        debug_log(f"使用预处理的10步时间嵌入,形状: {time_steps.shape}")

        # =====================================================================
        # 5. Main sampling loop (10-step optimized version).
        # =====================================================================
        debug_log("开始10步采样循环...")
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()

            # Time embedding for this step, batched to leading dim 1.
            time_emb = np.expand_dims(time_steps[step_idx], axis=0)

            # -----------------------------------------
            # Two UNet passes for classifier-free guidance.
            # -----------------------------------------
            # Unconditional (negative-prompt) pass.
            noise_pred_neg = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": neg_emb
            })[0]

            # Conditional (positive-prompt) pass.
            noise_pred_pos = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": pos_emb
            })[0]

            # CFG fusion. BUG FIX: use the guidance_scale variable instead of
            # a hard-coded 5.4 literal so the code stays consistent with the
            # parameter it documents (the value is identical today).
            noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)

            # Hand off to the scheduler as torch tensors.
            latent_tensor = torch.from_numpy(latent)
            noise_pred_tensor = torch.from_numpy(noise_pred)

            scheduler_start = time.time()
            latent_tensor = scheduler.step(
                model_output=noise_pred_tensor,
                timestep=timestep,
                sample=latent_tensor
            ).prev_sample
            debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

            # Back to numpy for the next UNet pass.
            latent = latent_tensor.numpy().astype(np.float32)
            debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

            debug_log(f"步骤 {step_idx+1}/{num_steps} | 耗时: {time.time()-step_start:.2f}s")

        # =====================================================================
        # 6. VAE decode (output resolution 768x512).
        # =====================================================================
        debug_log("开始VAE解码...")
        vae_start = time.time()
        # Undo the SD v1.x latent scaling factor before decoding.
        latent = latent / 0.18215
        image = diffusion_models.vae.run(None, {"latent": latent})[0]

        # CHW -> HWC, then map [-1, 1] to [0, 255] uint8.
        image = np.transpose(image.squeeze(), (1, 2, 0))
        image = np.clip((image / 2 + 0.5) * 255, 0, 255).astype(np.uint8)
        pil_image = Image.fromarray(image[..., :3])  # drop any alpha channel
        pil_image.save("./api.png")
        debug_log(f"VAE解码完成 | 耗时: {time.time()-vae_start:.2f}s")
        debug_log(f"总耗时: {time.time()-start_time:.2f}s (10步优化版)")
        return pil_image

    except Exception as e:
        error_msg = f"生成失败: {str(e)}"
        debug_log(error_msg)
        traceback.print_exc()
        # BUG FIX: chain the original exception so the real cause survives
        # into the logs and any upstream handler.
        raise RuntimeError(error_msg) from e