File size: 11,151 Bytes
0274afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""
DreamRenderer实现模块
"""

import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from PIL import Image, ImageDraw
import numpy as np
from typing import List, Dict, Optional, Tuple
import spaces

class DreamRendererPipeline:
    """
    DreamRenderer管道实现
    """
    
    def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev"):
        """
        初始化DreamRenderer管道
        
        Args:
            model_id: 使用的模型ID
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = model_id
        self.pipe = None
        self.loaded = False
        
    def load_model(self):
        """
        Load the FLUX pipeline and move it to ``self.device``.

        The weights are requested in bfloat16 on CUDA and float32 otherwise.
        Enabling xformers attention is treated as an optional optimisation:
        if it fails (for example because xformers is not installed) the
        failure is reported and the already-loaded pipeline is still used.

        Returns:
            True if the pipeline was loaded, False if loading raised. On
            failure ``self.pipe`` is reset to ``None`` and ``self.loaded``
            to ``False``.
        """
        try:
            print(f"正在加载模型: {self.model_id}")
            dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
            pipe = FluxPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                use_safetensors=True,
            )
            pipe = pipe.to(self.device)
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            # Do not keep a half-initialised pipeline around.
            self.pipe = None
            self.loaded = False
            return False

        # Memory-efficient attention is best-effort only: a failure here must
        # not turn a successful model load into a reported load failure.
        # NOTE(review): confirm xformers attention is actually compatible with
        # the FLUX attention processors before relying on it.
        enable_xformers = getattr(
            pipe, 'enable_xformers_memory_efficient_attention', None
        )
        if callable(enable_xformers):
            try:
                enable_xformers()
            except Exception as e:
                print(f"无法启用xformers注意力优化,已跳过: {str(e)}")

        self.pipe = pipe
        self.loaded = True
        print("模型加载完成!")
        return True
    
    def create_layout_mask(self, bbox_data: List[Dict], width: int, height: int) -> torch.Tensor:
        """
        根据边界框数据创建布局掩码
        
        Args:
            bbox_data: 边界框数据列表
            width: 图像宽度
            height: 图像高度
            
        Returns:
            布局掩码张量
        """
        mask = torch.zeros((height, width), dtype=torch.float32)
        
        for i, bbox in enumerate(bbox_data):
            x = int(bbox['x'] * width)
            y = int(bbox['y'] * height)
            w = int(bbox['width'] * width)
            h = int(bbox['height'] * height)
            
            # 在掩码中标记区域
            mask[y:y+h, x:x+w] = i + 1
            
        return mask
    
    def create_attention_mask(self, bbox_data: List[Dict], width: int, height: int) -> List[torch.Tensor]:
        """
        为每个实例创建注意力掩码
        
        Args:
            bbox_data: 边界框数据列表
            width: 图像宽度
            height: 图像高度
            
        Returns:
            注意力掩码列表
        """
        masks = []
        
        for bbox in bbox_data:
            mask = torch.zeros((height, width), dtype=torch.float32)
            
            x = int(bbox['x'] * width)
            y = int(bbox['y'] * height)
            w = int(bbox['width'] * width)
            h = int(bbox['height'] * height)
            
            # 创建软边界的掩码
            mask[y:y+h, x:x+w] = 1.0
            
            # 应用高斯模糊以创建软边界
            if torch.cuda.is_available():
                mask = mask.unsqueeze(0).unsqueeze(0).cuda()
                mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1)
                mask = mask.squeeze().cpu()
            
            masks.append(mask)
            
        return masks
    
    def modify_attention_weights(self, attention_weights: torch.Tensor, 
                               attention_masks: List[torch.Tensor],
                               current_token_idx: int) -> torch.Tensor:
        """
        修改注意力权重以实现区域控制
        
        Args:
            attention_weights: 原始注意力权重
            attention_masks: 注意力掩码列表
            current_token_idx: 当前token索引
            
        Returns:
            修改后的注意力权重
        """
        # 这里实现DreamRenderer的核心注意力修改逻辑
        # 根据当前token和对应的区域掩码调整注意力权重
        
        if current_token_idx < len(attention_masks):
            mask = attention_masks[current_token_idx]
            
            # 将掩码应用到注意力权重
            if mask.device != attention_weights.device:
                mask = mask.to(attention_weights.device)
            
            # 增强对应区域的注意力
            attention_weights = attention_weights * (1 + mask * 0.5)
            
        return attention_weights
    
    @spaces.GPU
    def generate_image(self,
                      prompt: str,
                      bbox_data: List[Dict],
                      negative_prompt: str = "",
                      num_inference_steps: int = 20,
                      guidance_scale: float = 7.5,
                      width: int = 512,
                      height: int = 512,
                      seed: Optional[int] = None) -> Image.Image:
        """
        Generate an image, lazily loading the model on first use.

        If the model cannot be loaded, or anything raises during generation,
        a placeholder image from ``_create_demo_image`` is returned instead
        of propagating the error.

        Args:
            prompt: Main prompt.
            bbox_data: Bounding boxes; when empty the pipeline is called
                directly, otherwise ``_generate_with_bbox_control`` is used.
            negative_prompt: Negative prompt.
            num_inference_steps: Number of inference steps.
            guidance_scale: Guidance strength.
            width: Image width.
            height: Image height.
            seed: Random seed; ``None`` leaves the generator unset.

        Returns:
            The generated image, or the placeholder image on failure.
        """
        # Load on demand; fall back to the placeholder if that fails.
        model_ready = self.loaded or self.load_model()
        if not model_ready:
            return self._create_demo_image(prompt, bbox_data, width, height)

        # Only build a seeded generator when a seed was actually supplied.
        generator = (
            None
            if seed is None
            else torch.Generator(device=self.device).manual_seed(seed)
        )

        try:
            # Main prompt plus the region labels.
            full_prompt = self._build_full_prompt(prompt, bbox_data)

            if bbox_data:
                # Region-controlled path.
                return self._generate_with_bbox_control(
                    full_prompt, bbox_data, negative_prompt,
                    num_inference_steps, guidance_scale,
                    width, height, generator
                )

            # No boxes: plain generation.
            result = self.pipe(
                prompt=full_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator
            )
            return result.images[0]

        except Exception as e:
            print(f"生成图像时出错: {str(e)}")
            # Best-effort: hand back the placeholder image.
            return self._create_demo_image(prompt, bbox_data, width, height)
    
    def _build_full_prompt(self, main_prompt: str, bbox_data: List[Dict]) -> str:
        """构建包含区域描述的完整提示词"""
        full_prompt = main_prompt
        
        if bbox_data:
            region_descriptions = []
            for i, bbox in enumerate(bbox_data):
                if bbox['label']:
                    region_descriptions.append(f"{bbox['label']}")
            
            if region_descriptions:
                full_prompt += ", " + ", ".join(region_descriptions)
        
        return full_prompt
    
    def _generate_with_bbox_control(self, prompt: str, bbox_data: List[Dict],
                                   negative_prompt: str, num_inference_steps: int,
                                   guidance_scale: float, width: int, height: int,
                                   generator: Optional[torch.Generator]) -> Image.Image:
        """
        Generate an image for a prompt with bounding boxes.

        Placeholder implementation: the image comes from a standard pipeline
        call and the boxes are then drawn on top of it for demonstration.
        Per-box attention masks are built but not yet fed into the sampler.
        """
        # Built for the region-control algorithm that is meant to live here
        # (attention editing, cross-attention control); currently unused.
        _region_masks = self.create_attention_mask(bbox_data, width, height)

        # Standard generation for now, to be replaced by the real
        # DreamRenderer implementation.
        pipe_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=generator,
        )
        output = self.pipe(**pipe_kwargs)
        rendered = output.images[0]

        # Draw the boxes on the result as a visual aid.
        return self._add_bbox_overlay(rendered, bbox_data)
    
    def _add_bbox_overlay(self, image: Image.Image, bbox_data: List[Dict]) -> Image.Image:
        """
        Draw the bounding boxes and their labels onto an image (demo aid).

        The image is drawn on in place and the same object is returned.
        Each box gets a 2-pixel outline in a colour cycled from a fixed
        palette. A non-empty label is written 15 pixels above the box, or
        at the top edge of the image when the box sits too high for that.

        Args:
            image: Image to draw on; modified in place.
            bbox_data: Boxes as dicts with ``x``, ``y``, ``width`` and
                ``height`` keys (multiplied by the image size here) and an
                optional ``label``.

        Returns:
            ``image``, with the overlay drawn on it.
        """
        if not bbox_data:
            return image

        draw = ImageDraw.Draw(image)
        colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']

        for i, bbox in enumerate(bbox_data):
            color = colors[i % len(colors)]

            x = int(bbox['x'] * image.width)
            y = int(bbox['y'] * image.height)
            w = int(bbox['width'] * image.width)
            h = int(bbox['height'] * image.height)

            # Order the corners: recent Pillow versions reject rectangles
            # whose second corner lies before the first.
            x0, x1 = sorted((x, x + w))
            y0, y1 = sorted((y, y + h))
            draw.rectangle([x0, y0, x1, y1], outline=color, width=2)

            # A missing label is treated like an empty one. Clamp the text
            # position so labels of boxes near the top edge stay visible.
            label = bbox.get('label')
            if label:
                draw.text((x0, max(y0 - 15, 0)), str(label), fill=color)

        return image
    
    def _create_demo_image(self, prompt: str, bbox_data: List[Dict],
                          width: int, height: int) -> Image.Image:
        """
        Render a placeholder image (used when loading or generation fails).

        The image has a vertical colour gradient, the prompt text, a
        "DreamRenderer Demo" caption, and the bounding boxes with their
        labels. Because this is the fallback after an error, box entries
        that cannot be interpreted are skipped instead of raising.

        Args:
            prompt: Prompt text to print on the image.
            bbox_data: Boxes as dicts with ``x``, ``y``, ``width`` and
                ``height`` keys (multiplied by the image size here) and an
                optional ``label``. May be empty or ``None``.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            A new RGB image of size ``(width, height)``.
        """
        image = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(image)

        # Vertical gradient, one horizontal line per row.
        for row in range(height):
            shade = int(255 * (row / height))
            fill = (100 + shade // 3, 150 + shade // 4, 200 + shade // 5)
            draw.line([(0, row), (width, row)], fill=fill)

        # Prompt and caption.
        draw.text((10, 10), f"Prompt: {prompt}", fill='white')
        draw.text((10, 30), "DreamRenderer Demo", fill='white')

        colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']
        for i, bbox in enumerate(bbox_data or []):
            color = colors[i % len(colors)]

            try:
                x = int(bbox['x'] * width)
                y = int(bbox['y'] * height)
                w = int(bbox['width'] * width)
                h = int(bbox['height'] * height)
                label = bbox.get('label')
            except (AttributeError, KeyError, TypeError, ValueError):
                # Malformed box data may be the very reason this fallback is
                # running; it must not make the fallback fail as well.
                continue

            # Order the corners: recent Pillow versions reject rectangles
            # whose second corner lies before the first.
            x0, x1 = sorted((x, x + w))
            y0, y1 = sorted((y, y + h))
            draw.rectangle([x0, y0, x1, y1], outline=color, width=3)

            # Keep labels of boxes near the top edge on the canvas.
            if label:
                draw.text((x0, max(y0 - 20, 0)), str(label), fill=color)

        return image