Spaces:
Running
on
Zero
Running
on
Zero
File size: 11,151 Bytes
0274afd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 |
"""
DreamRenderer implementation module: wraps a diffusers FluxPipeline with
bounding-box-aware prompting and a placeholder image fallback.
"""
import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from PIL import Image, ImageDraw
import numpy as np
from typing import List, Dict, Optional, Tuple
import spaces
class DreamRendererPipeline:
    """
    DreamRenderer pipeline built on top of a FLUX text-to-image model.

    The model is loaded lazily on the first call to ``generate_image``. If
    loading or generation fails, a placeholder "demo" image is returned
    instead of raising.

    Bounding boxes are dicts with ``x``, ``y``, ``width`` and ``height`` keys
    that are multiplied by the image's pixel width/height (i.e. fractions of
    the image size), plus an optional ``label``.
    """

    def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev"):
        """
        Initialize the pipeline wrapper without loading any weights.

        Args:
            model_id: Model id passed to ``FluxPipeline.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = model_id
        self.pipe = None  # FluxPipeline instance once load_model() succeeds
        self.loaded = False  # True only after a successful load_model()

    def load_model(self) -> bool:
        """
        Load the FLUX weights and move them to ``self.device``.

        Returns:
            True on success. False if loading failed; the error is printed
            and ``self.pipe`` is reset to None.
        """
        try:
            print(f"正在加载模型: {self.model_id}")
            dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
            pipe = FluxPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                use_safetensors=True,
            )
            self.pipe = pipe.to(self.device)
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            self.pipe = None
            self.loaded = False
            return False

        # xformers attention is an optional speed-up. Enabling it can raise
        # (e.g. when xformers is not installed or CUDA is unavailable); that
        # must not discard a model that has already loaded successfully.
        if hasattr(self.pipe, 'enable_xformers_memory_efficient_attention'):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as e:
                print(f"xformers不可用,使用默认注意力机制: {str(e)}")

        self.loaded = True
        print("模型加载完成!")
        return True

    @staticmethod
    def _bbox_to_pixels(bbox: Dict, width: int, height: int) -> Tuple[int, int, int, int]:
        """
        Convert a fractional bbox dict into clamped pixel bounds.

        Returns ``(x0, y0, x1, y1)`` with ``0 <= x0 <= x1 <= width`` and
        ``0 <= y0 <= y1 <= height``. These bounds work both for tensor
        slicing (``[y0:y1, x0:x1]``) and for PIL drawing. Clamping stops
        negative coordinates from wrapping around in slices, and stops
        negative sizes from producing an inverted PIL rectangle.
        """
        # Same rounding as before: origin and size are truncated separately.
        x = int(bbox['x'] * width)
        y = int(bbox['y'] * height)
        w = int(bbox['width'] * width)
        h = int(bbox['height'] * height)
        x0 = min(max(x, 0), width)
        y0 = min(max(y, 0), height)
        x1 = min(max(x + w, x0), width)
        y1 = min(max(y + h, y0), height)
        return x0, y0, x1, y1

    def create_layout_mask(self, bbox_data: List[Dict], width: int, height: int) -> torch.Tensor:
        """
        Build a single label map covering all bounding boxes.

        Args:
            bbox_data: List of bbox dicts (fractional x/y/width/height).
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            A ``(height, width)`` float32 tensor. Pixels inside the i-th box
            hold ``i + 1`` and all other pixels hold 0. Where boxes overlap,
            later boxes overwrite earlier ones.
        """
        mask = torch.zeros((height, width), dtype=torch.float32)
        for index, bbox in enumerate(bbox_data):
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            mask[y0:y1, x0:x1] = index + 1
        return mask

    def create_attention_mask(self, bbox_data: List[Dict], width: int, height: int) -> List[torch.Tensor]:
        """
        Build one soft-edged mask per bounding box.

        Args:
            bbox_data: List of bbox dicts (fractional x/y/width/height).
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            A list with one ``(height, width)`` float32 CPU tensor per box.
            Each tensor is 1.0 inside the box and 0 outside, smoothed with a
            3x3 average (box) filter so the border fades. Smoothing always
            runs on CPU, so the result does not depend on whether CUDA is
            available.
        """
        masks = []
        for bbox in bbox_data:
            mask = torch.zeros((height, width), dtype=torch.float32)
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            mask[y0:y1, x0:x1] = 1.0
            # avg_pool2d expects (N, C, H, W). With kernel 3, stride 1 and
            # padding 1 the spatial size is preserved.
            smoothed = F.avg_pool2d(mask[None, None], kernel_size=3, stride=1, padding=1)
            # Index rather than squeeze() so a 1-pixel-tall/wide image stays 2-D.
            masks.append(smoothed[0, 0])
        return masks

    def modify_attention_weights(self, attention_weights: torch.Tensor,
                                 attention_masks: List[torch.Tensor],
                                 current_token_idx: int,
                                 strength: float = 0.5) -> torch.Tensor:
        """
        Boost attention inside the region mask that belongs to a token.

        Args:
            attention_weights: Original attention weights. NOTE(review): the
                mask must broadcast against this tensor; confirm shapes with
                the eventual caller.
            attention_masks: Per-instance masks, e.g. from ``create_attention_mask``.
            current_token_idx: Index of the mask to apply. Indices outside
                ``[0, len(attention_masks))`` leave the weights unchanged.
            strength: Multiplicative boost at mask value 1.0 (default 0.5,
                i.e. up to 1.5x).

        Returns:
            The (possibly) rescaled attention weights, with the input's
            device and dtype.
        """
        # Reject negative indices explicitly; Python indexing would otherwise
        # silently pick a mask from the end of the list.
        if 0 <= current_token_idx < len(attention_masks):
            mask = attention_masks[current_token_idx].to(
                device=attention_weights.device, dtype=attention_weights.dtype
            )
            attention_weights = attention_weights * (1 + mask * strength)
        return attention_weights

    @spaces.GPU
    def generate_image(self,
                       prompt: str,
                       bbox_data: List[Dict],
                       negative_prompt: str = "",
                       num_inference_steps: int = 20,
                       guidance_scale: float = 7.5,
                       width: int = 512,
                       height: int = 512,
                       seed: Optional[int] = None) -> Image.Image:
        """
        Generate an image, falling back to a placeholder on any failure.

        Args:
            prompt: Main text prompt.
            bbox_data: Bounding boxes (may be empty or None).
            negative_prompt: Negative prompt forwarded to the pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance strength forwarded to the pipeline.
            width: Output width in pixels.
            height: Output height in pixels.
            seed: Optional seed for a deterministic ``torch.Generator``.

        Returns:
            The generated image. If loading or generation fails, a demo image
            from ``_create_demo_image`` is returned instead.
        """
        # Treat None like "no boxes" so every downstream helper can iterate it.
        bbox_data = bbox_data or []

        if not self.loaded and not self.load_model():
            # Model could not be loaded: return a placeholder image.
            return self._create_demo_image(prompt, bbox_data, width, height)

        generator = (
            torch.Generator(device=self.device).manual_seed(seed)
            if seed is not None
            else None
        )

        try:
            full_prompt = self._build_full_prompt(prompt, bbox_data)
            if not bbox_data:
                # No regions: plain text-to-image generation.
                # NOTE(review): confirm the installed diffusers FluxPipeline
                # accepts `negative_prompt`; if not, this raises and the demo
                # image is returned.
                image = self.pipe(
                    prompt=full_prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    width=width,
                    height=height,
                    generator=generator,
                ).images[0]
            else:
                image = self._generate_with_bbox_control(
                    full_prompt, bbox_data, negative_prompt,
                    num_inference_steps, guidance_scale,
                    width, height, generator,
                )
            return image
        except Exception as e:
            print(f"生成图像时出错: {str(e)}")
            return self._create_demo_image(prompt, bbox_data, width, height)

    def _build_full_prompt(self, main_prompt: str, bbox_data: List[Dict]) -> str:
        """
        Append every non-empty bbox label to the main prompt.

        Example: ``"a park"`` with labels ``"dog"`` and ``"cat"`` becomes
        ``"a park, dog, cat"``. Boxes without a truthy ``label`` are skipped.
        """
        labels = [str(bbox.get('label')) for bbox in (bbox_data or []) if bbox.get('label')]
        if not labels:
            return main_prompt
        return main_prompt + ", " + ", ".join(labels)

    def _generate_with_bbox_control(self, prompt: str, bbox_data: List[Dict],
                                    negative_prompt: str, num_inference_steps: int,
                                    guidance_scale: float, width: int, height: int,
                                    generator: Optional[torch.Generator]) -> Image.Image:
        """
        Generate an image for a prompt that comes with bounding boxes.

        NOTE: region control is not implemented yet. This currently runs the
        standard pipeline and then draws the boxes on the result for
        demonstration. TODO: build masks with ``create_attention_mask`` and
        apply ``modify_attention_weights`` inside the model's attention once
        the DreamRenderer algorithm is wired in.
        """
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        return self._add_bbox_overlay(image, bbox_data)

    def _add_bbox_overlay(self, image: Image.Image, bbox_data: List[Dict]) -> Image.Image:
        """
        Draw bounding boxes and their labels onto ``image`` (for demonstration).

        The image is modified in place and also returned.
        """
        if not bbox_data:
            return image
        draw = ImageDraw.Draw(image)
        colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']
        for i, bbox in enumerate(bbox_data):
            color = colors[i % len(colors)]
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, image.width, image.height)
            draw.rectangle([x0, y0, x1, y1], outline=color, width=2)
            label = bbox.get('label')
            if label:
                # Keep labels of boxes touching the top edge on the canvas.
                draw.text((x0, max(0, y0 - 15)), str(label), fill=color)
        return image

    def _create_demo_image(self, prompt: str, bbox_data: List[Dict],
                           width: int, height: int) -> Image.Image:
        """
        Render a placeholder image, used when the model cannot produce one.

        The image has a vertical gradient background, the prompt text and the
        bounding boxes with their labels.
        """
        image = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(image)

        # Vertical gradient: one horizontal line per row.
        for y in range(height):
            color_value = int(255 * (y / height))
            color = (100 + color_value // 3, 150 + color_value // 4, 200 + color_value // 5)
            draw.line([(0, y), (width, y)], fill=color)

        draw.text((10, 10), f"Prompt: {prompt}", fill='white')
        draw.text((10, 30), "DreamRenderer Demo", fill='white')

        colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']
        for i, bbox in enumerate(bbox_data or []):
            color = colors[i % len(colors)]
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            draw.rectangle([x0, y0, x1, y1], outline=color, width=3)
            label = bbox.get('label')
            if label:
                # Keep labels of boxes touching the top edge on the canvas.
                draw.text((x0, max(0, y0 - 20)), str(label), fill=color)
        return image