Humphreykowl committed on
Commit 37a8e98 · verified · 1 Parent(s): d5dc109

Update models/model_manager.py

Files changed (1)
  1. models/model_manager.py +561 -147
models/model_manager.py CHANGED
@@ -1,6 +1,3 @@
- # Complete modal_manager.py (the full implementation of the former model_manager.py; change the path to model/modal_manager.py to use it as a drop-in replacement)
- # Includes all features: three-view pattern-making consistency, sketch-style generation, multi-angle 3D try-on support, and VRAM optimization
-
  import torch
  from PIL import Image
  import numpy as np
@@ -11,6 +8,7 @@ import logging
  import time
  import random
  import gc

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
@@ -20,13 +18,15 @@ class ModelManager:
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
  logger.info(f"Using device: {self.device}")

  self.model_config = {
  "caption_model": "Salesforce/blip-image-captioning-large",
- "clip_model": "openai/clip-vit-large-patch14",
  "sd_model": "runwayml/stable-diffusion-v1-5",
  "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
  }

  self.caption_processor = None
  self.caption_model = None
  self.clip_processor = None
@@ -34,184 +34,529 @@ class ModelManager:
  self.sd_pipeline = None
  self.controlnet = None
  self.controlnet_pipeline = None
-
  self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
  self.enable_attention_slicing = True
- self.enable_cpu_offload = False
-
- try:
- self.load_all_models()
- except Exception as e:
- logger.warning(f"Error while loading models: {e}")

  def optimize_memory_usage(self):
  if torch.cuda.is_available():
  torch.backends.cudnn.benchmark = True
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True

  def load_all_models(self):
  self.optimize_memory_usage()
- self.load_caption_model()
- self.load_clip_model()
- self.load_sd_pipeline()
- self.load_controlnet_pipeline()
- logger.info("All models loaded")

  def load_caption_model(self):
- self.caption_processor = BlipProcessor.from_pretrained(self.model_config["caption_model"], cache_dir="/tmp/models")
- self.caption_model = BlipForConditionalGeneration.from_pretrained(
- self.model_config["caption_model"],
- cache_dir="/tmp/models",
- torch_dtype=self.torch_dtype,
- low_cpu_mem_usage=True
- ).to(self.device)
- self.caption_model.enable_attention_slicing()
- self.caption_model.eval()

  def load_clip_model(self):
- self.clip_processor = CLIPProcessor.from_pretrained(self.model_config["clip_model"], cache_dir="/tmp/models")
- self.clip_model = CLIPModel.from_pretrained(self.model_config["clip_model"], cache_dir="/tmp/models", torch_dtype=self.torch_dtype).to(self.device)
- self.clip_model.eval()

  def load_sd_pipeline(self):
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
- self.model_config["sd_model"],
- torch_dtype=self.torch_dtype,
- cache_dir="/tmp/models",
- safety_checker=None,
- requires_safety_checker=False,
- use_safetensors=True,
- low_cpu_mem_usage=True
- ).to(self.device)
- self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
- if self.enable_attention_slicing:
- self.sd_pipeline.enable_attention_slicing()
- try:
- self.sd_pipeline.enable_xformers_memory_efficient_attention()
- except Exception:
- pass
- self.sd_pipeline.enable_vae_slicing()
- self.sd_pipeline.safety_checker = None

  def load_controlnet_pipeline(self):
- self.controlnet = ControlNetModel.from_pretrained(
- self.model_config["controlnet_model"],
- cache_dir="/tmp/models",
- torch_dtype=self.torch_dtype,
- low_cpu_mem_usage=True
- ).to(self.device)
- self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
- self.model_config["sd_model"],
- controlnet=self.controlnet,
- cache_dir="/tmp/models",
- torch_dtype=self.torch_dtype,
- safety_checker=None,
- requires_safety_checker=False,
- low_cpu_mem_usage=True
- ).to(self.device)
- self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.controlnet_pipeline.scheduler.config)
- if self.enable_attention_slicing:
- self.controlnet_pipeline.enable_attention_slicing()
- try:
- self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
- except Exception:
- pass
- self.controlnet_pipeline.enable_vae_slicing()

- @torch.no_grad()
  def generate_caption(self, image):
- if image.mode != 'RGB':
- image = image.convert('RGB')
- if image.width > 512 or image.height > 512:
- image.thumbnail((512, 512), Image.Resampling.LANCZOS)
- inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
- outputs = self.caption_model.generate(**inputs, max_length=50, num_beams=4, temperature=0.7, do_sample=True)
- caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
- del inputs, outputs
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- return caption

  @torch.no_grad()
  def analyze_style(self, image):
- style_labels = [
- "business formal suit professional attire",
- "casual comfortable everyday wear",
- "athletic sportswear activewear",
- "fashion trendy modern stylish",
- "vintage retro classic style",
- "streetwear urban contemporary",
- "elegant sophisticated refined"
- ]
- style_names = ["Business formal", "Casual", "Sporty", "Trendy", "Vintage", "Streetwear", "Elegant"]
- if image.mode != 'RGB':
- image = image.convert('RGB')
- if image.width > 224 or image.height > 224:
- image.thumbnail((224, 224), Image.Resampling.LANCZOS)
- inputs = self.clip_processor(text=style_labels, images=image, return_tensors="pt", padding=True, truncation=True, max_length=77).to(self.device)
- outputs = self.clip_model(**inputs)
- probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
- return {name: float(prob) for name, prob in zip(style_names, probs)}

  @torch.no_grad()
- def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, seed=None):
- if negative_prompt is None:
- negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
- width = (width // 8) * 8
- height = (height // 8) * 8
- gen = torch.Generator(device=self.device).manual_seed(int(seed)) if seed is not None else None
- result = self.sd_pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- num_inference_steps=num_inference_steps,
- guidance_scale=guidance_scale,
- height=height,
- width=width,
- generator=gen
- )
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- return result.images[0]

  @torch.no_grad()
- def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0, angle=0, width=512, height=768):
- if image.mode != 'RGB':
- image = image.convert('RGB')
- control_image = image.resize((512, 768), Image.Resampling.LANCZOS)
- if negative_prompt is None:
- negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
- prompt_with_angle = f"{prompt}, view from {angle} degrees"
- if reference_image is not None:
- prompt_with_angle = f"{prompt_with_angle}, based on provided reference design"
- gen = torch.Generator(device=self.device).manual_seed(int(time.time()) + int(angle))
- result = self.controlnet_pipeline(
- prompt=prompt_with_angle,
- image=control_image,
- negative_prompt=negative_prompt,
- num_inference_steps=num_inference_steps,
- guidance_scale=guidance_scale,
- controlnet_conditioning_scale=1.0,
- generator=gen
- )
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- return result.images[0]

  def create_placeholder_image(self, width, height):
- color = random.choice([(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)])
  return Image.new('RGB', (width, height), color=color)

  def cleanup(self):
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- try:
  torch.cuda.ipc_collect()
- except Exception:
- pass

  def get_model_status(self):
  status = {
  "caption_model": self.caption_model is not None,
  "clip_model": self.clip_model is not None,
@@ -219,9 +564,78 @@ class ModelManager:
  "controlnet_pipeline": self.controlnet_pipeline is not None,
  "device": self.device
  }
  if torch.cuda.is_available():
  status["gpu_memory"] = {
  "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
- "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB"
  }
  return status

  import torch
  from PIL import Image
  import numpy as np

  import time
  import random
  import gc
+ from functools import lru_cache

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
 
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
  logger.info(f"Using device: {self.device}")

+ # Optimized model configuration
  self.model_config = {
  "caption_model": "Salesforce/blip-image-captioning-large",
+ "clip_model": "openai/clip-vit-large-patch14",
  "sd_model": "runwayml/stable-diffusion-v1-5",
  "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
  }

+ # Model containers
  self.caption_processor = None
  self.caption_model = None
  self.clip_processor = None

  self.sd_pipeline = None
  self.controlnet = None
  self.controlnet_pipeline = None
+
+ # Performance optimization settings
  self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
  self.enable_attention_slicing = True
+ self.enable_cpu_offload = False  # 16 GB of VRAM should be enough
+
+ # Preload all models
+ self.load_all_models()

  def optimize_memory_usage(self):
+ """Memory optimization settings"""
  if torch.cuda.is_available():
+ # Enable memory optimizations
  torch.backends.cudnn.benchmark = True
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True

  def load_all_models(self):
+ """Load all models in order, optimizing VRAM usage"""
  self.optimize_memory_usage()
+
+ try:
+ self.load_caption_model()
+ self.load_clip_model()
+ self.load_sd_pipeline()
+ self.load_controlnet_pipeline()
+
+ logger.info("All models loaded")
+ if torch.cuda.is_available():
+ logger.info(f"GPU memory usage: {torch.cuda.memory_allocated()/1024**3:.2f}GB / {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
+
+ except Exception as e:
+ logger.error(f"Error during model loading: {e}")
+ raise

  def load_caption_model(self):
+ """Load the BLIP image captioning model"""
+ try:
+ logger.info("Loading BLIP image captioning model...")
+
+ self.caption_processor = BlipProcessor.from_pretrained(
+ self.model_config["caption_model"],
+ cache_dir="/tmp/models"
+ )
+
+ self.caption_model = BlipForConditionalGeneration.from_pretrained(
+ self.model_config["caption_model"],
+ cache_dir="/tmp/models",
+ torch_dtype=self.torch_dtype,
+ low_cpu_mem_usage=True
+ ).to(self.device)
+
+ # Attention slicing is intentionally not enabled: BLIP does not support it
+
+ self.caption_model.eval()
+ logger.info("BLIP model loaded")
+
+ except Exception as e:
+ logger.error(f"Failed to load BLIP model: {e}")
+ self.caption_model = None
+ self.caption_processor = None

  def load_clip_model(self):
+ """Load the CLIP style-analysis model"""
+ try:
+ logger.info("Loading CLIP model...")
+
+ self.clip_processor = CLIPProcessor.from_pretrained(
+ self.model_config["clip_model"],
+ cache_dir="/tmp/models"
+ )
+
+ self.clip_model = CLIPModel.from_pretrained(
+ self.model_config["clip_model"],
+ cache_dir="/tmp/models",
+ torch_dtype=self.torch_dtype
+ ).to(self.device)
+
+ self.clip_model.eval()
+ logger.info("CLIP model loaded")
+
+ except Exception as e:
+ logger.error(f"Failed to load CLIP model: {e}")
+ self.clip_model = None
+ self.clip_processor = None

  def load_sd_pipeline(self):
126
+ """加载Stable Diffusion Pipeline"""
127
+ try:
128
+ logger.info("加载 Stable Diffusion Pipeline...")
129
+
130
+ self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
131
+ self.model_config["sd_model"],
132
+ torch_dtype=self.torch_dtype,
133
+ cache_dir="/tmp/models",
134
+ safety_checker=None,
135
+ requires_safety_checker=False,
136
+ use_safetensors=True,
137
+ low_cpu_mem_usage=True
138
+ )
139
+
140
+ # 优化设置
141
+ self.sd_pipeline = self.sd_pipeline.to(self.device)
142
+ self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
143
+ self.sd_pipeline.scheduler.config
144
+ )
145
+
146
+ # 启用内存优化
147
+ if self.enable_attention_slicing:
148
+ self.sd_pipeline.enable_attention_slicing()
149
+
150
+ # 启用内存高效attention(如果可用)
151
+ try:
152
+ self.sd_pipeline.enable_xformers_memory_efficient_attention()
153
+ logger.info("启用了xformers内存优化")
154
+ except:
155
+ logger.info("xformers不可用,使用默认attention")
156
+
157
+ # 启用VAE slicing以节省显存
158
+ self.sd_pipeline.enable_vae_slicing()
159
+
160
+ logger.info("Stable Diffusion Pipeline 加载完成")
161
+
162
+ except Exception as e:
163
+ logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")
164
+ self.sd_pipeline = None
165
 
166
  def load_controlnet_pipeline(self):
+ """Load the ControlNet pipeline"""
+ try:
+ logger.info("Loading ControlNet model and pipeline...")
+
+ self.controlnet = ControlNetModel.from_pretrained(
+ self.model_config["controlnet_model"],
+ cache_dir="/tmp/models",
+ torch_dtype=self.torch_dtype,
+ low_cpu_mem_usage=True
+ ).to(self.device)
+
+ self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ self.model_config["sd_model"],
+ controlnet=self.controlnet,
+ cache_dir="/tmp/models",
+ torch_dtype=self.torch_dtype,
+ safety_checker=None,
+ requires_safety_checker=False,
+ low_cpu_mem_usage=True
+ ).to(self.device)
+
+ # Optimization settings
+ self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+ self.controlnet_pipeline.scheduler.config
+ )
+
+ # Memory optimizations
+ if self.enable_attention_slicing:
+ self.controlnet_pipeline.enable_attention_slicing()
+
+ try:
+ self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
+ logger.info("ControlNet: enabled xformers memory optimization")
+ except Exception:
+ logger.info("ControlNet: using default attention")
+
+ self.controlnet_pipeline.enable_vae_slicing()
+
+ logger.info("ControlNet pipeline loaded")
+
+ except Exception as e:
+ logger.error(f"Failed to load ControlNet pipeline: {e}")
+ self.controlnet = None
+ self.controlnet_pipeline = None
+
+ @torch.no_grad()  # disable gradient computation to save VRAM
  def generate_caption(self, image):
+ """Generate an image caption with the BLIP model"""
+ if self.caption_model is None or self.caption_processor is None:
+ self.load_caption_model()
+ if self.caption_model is None:
+ return "Fashion apparel design"
+
+ try:
+ # Preprocess the image
+ if image.mode != 'RGB':
+ image = image.convert('RGB')
+
+ # Downscale the image to save VRAM
+ if image.width > 512 or image.height > 512:
+ image.thumbnail((512, 512), Image.Resampling.LANCZOS)
+
+ inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
+
+ # Generate the caption
+ outputs = self.caption_model.generate(
+ **inputs,
+ max_length=50,
+ num_beams=4,
+ temperature=0.7,
+ do_sample=True
+ )
+
+ caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
+
+ # Free VRAM
+ del inputs, outputs
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return caption
+
+ except Exception as e:
+ logger.error(f"Caption generation failed: {e}")
+ return "Fashion apparel design"

  @torch.no_grad()
255
+ """使用CLIP模型分析服装风格"""
256
+ if self.clip_model is None or self.clip_processor is None:
257
+ self.load_clip_model()
258
+ if self.clip_model is None:
259
+ return {"时尚潮流": 0.8, "现代风格": 0.6}
260
+
261
+ try:
262
+ # 风格标签 - 使用英文避免token问题
263
+ style_labels = [
264
+ "business formal suit professional attire",
265
+ "casual comfortable everyday wear",
266
+ "athletic sportswear activewear",
267
+ "fashion trendy modern stylish",
268
+ "vintage retro classic style",
269
+ "streetwear urban contemporary",
270
+ "elegant sophisticated refined"
271
+ ]
272
+
273
+ style_names = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
274
+
275
+ # 预处理图像
276
+ if image.mode != 'RGB':
277
+ image = image.convert('RGB')
278
+
279
+ # 调整图像大小
280
+ if image.width > 224 or image.height > 224:
281
+ image.thumbnail((224, 224), Image.Resampling.LANCZOS)
282
+
283
+ # 处理输入
284
+ inputs = self.clip_processor(
285
+ text=style_labels,
286
+ images=image,
287
+ return_tensors="pt",
288
+ padding=True,
289
+ truncation=True,
290
+ max_length=77 # CLIP的最大长度
291
+ ).to(self.device)
292
+
293
+ # 获取相似度分数
294
+ outputs = self.clip_model(**inputs)
295
+ logits_per_image = outputs.logits_per_image
296
+ probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]
297
+
298
+ # 构建结果
299
+ style_scores = {name: float(prob) for name, prob in zip(style_names, probs)}
300
+
301
+ # 清理显存
302
+ del inputs, outputs
303
+ if torch.cuda.is_available():
304
+ torch.cuda.empty_cache()
305
+
306
+ return style_scores
307
+
308
+ except Exception as e:
309
+ logger.error(f"风格分析失败: {e}")
310
+ return {"时尚潮流": 0.8, "现代风格": 0.6}
311
 
312
  @torch.no_grad()
313
+ def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, **kwargs):
+ """Generate a design image with Stable Diffusion"""
+ if self.sd_pipeline is None:
+ self.load_sd_pipeline()
+ if self.sd_pipeline is None:
+ logger.error("Cannot generate image: the Stable Diffusion model is not loaded")
+ return self.create_placeholder_image(width, height)
+
+ try:
+ # Tune the parameters
+ if negative_prompt is None:
+ negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
+
+ # Make sure the dimensions are multiples of 8
+ width = (width // 8) * 8
+ height = (height // 8) * 8
+
+ # Generate the image
+ result = self.sd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ height=height,
+ width=width,
+ generator=torch.Generator(device=self.device).manual_seed(random.randint(0, 2**32-1))
+ )
+
+ # Free VRAM
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return result.images[0]
+
+ except Exception as e:
+ logger.error(f"Image generation failed: {e}")
+ return self.create_placeholder_image(width, height)

  @torch.no_grad()
353
+ """使用ControlNet生成3D试穿效果"""
354
+ if self.controlnet_pipeline is None:
355
+ self.load_controlnet_pipeline()
356
+ if self.controlnet_pipeline is None:
357
+ logger.error("无法生成3D试穿:ControlNet 模型未加载")
358
+ return self.create_placeholder_image(512, 768)
359
+
360
+ try:
361
+ # 预处理控制图像
362
+ if image.mode != 'RGB':
363
+ image = image.convert('RGB')
364
+
365
+ # 调整图像尺寸
366
+ control_image = image.resize((512, 768), Image.Resampling.LANCZOS)
367
+
368
+ # 创建简单的姿态控制图(人体轮廓)
369
+ control_image = self.create_pose_control_image(control_image)
370
+
371
+ if negative_prompt is None:
372
+ negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
373
+
374
+ # 如果有参考设计,增强提示词
375
+ if reference_image is not None:
376
+ prompt = f"{prompt}, based on reference design"
377
+
378
+ # 生成3D试穿效果
379
+ result = self.controlnet_pipeline(
380
+ prompt=prompt,
381
+ image=control_image,
382
+ negative_prompt=negative_prompt,
383
+ num_inference_steps=num_inference_steps,
384
+ guidance_scale=guidance_scale,
385
+ controlnet_conditioning_scale=1.0,
386
+ generator=torch.Generator(device=self.device).manual_seed(random.randint(0, 2**32-1))
387
+ )
388
+
389
+ # 清理显存
390
+ if torch.cuda.is_available():
391
+ torch.cuda.empty_cache()
392
+
393
+ return result.images[0]
394
+
395
+ except Exception as e:
396
+ logger.error(f"ControlNet图像生成失败: {e}")
397
+ return self.create_placeholder_image(512, 768)
398
+
399
+ def create_pose_control_image(self, image):
400
+ """创建简单的姿态控制图"""
401
+ try:
402
+ # 转换为numpy数组
403
+ img_array = np.array(image)
404
+
405
+ # 创建简单的人体轮廓控制图
406
+ # 这里使用边缘检测作为简化的姿态控制
407
+ from scipy import ndimage
408
+ gray = np.mean(img_array, axis=2)
409
+ edges = ndimage.sobel(gray)
410
+
411
+ # 归一化到0-255范围
412
+ edges = ((edges - edges.min()) / (edges.max() - edges.min()) * 255).astype(np.uint8)
413
+
414
+ # 转换回PIL图像
415
+ control_image = Image.fromarray(edges, mode='L').convert('RGB')
416
+
417
+ return control_image
418
+
419
+ except Exception as e:
420
+ logger.warning(f"创建姿态控制图失败: {e}")
421
+ # 返回原图的边缘检测版本
422
+ return image.convert('L').convert('RGB')
423
 
424
  def create_placeholder_image(self, width, height):
+ """Create a placeholder image"""
+ colors = [(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)]
+ color = random.choice(colors)
  return Image.new('RGB', (width, height), color=color)

  def cleanup(self):
+ """Clear the GPU memory cache while keeping the models loaded"""
+ logger.info("Clearing GPU memory cache...")
+ try:
+ if torch.cuda.is_available():
+ # Force garbage collection
+ gc.collect()
+ # Clear the CUDA cache
+ torch.cuda.empty_cache()
  torch.cuda.ipc_collect()
+
+ # Report memory usage
+ allocated = torch.cuda.memory_allocated() / 1024**3
+ cached = torch.cuda.memory_reserved() / 1024**3
+ logger.info(f"GPU memory usage: {allocated:.2f}GB (allocated) / {cached:.2f}GB (cached)")
+
+ logger.info("GPU memory cleanup finished")
+
+ except Exception as e:
+ logger.error(f"GPU memory cleanup failed: {e}")
+
+ def move_models_to_cpu(self):
+ """Move the models to the CPU to free GPU memory"""
+ try:
+ logger.info("Moving all models to the CPU...")
+
+ models_to_move = [
+ ('caption_model', self.caption_model),
+ ('clip_model', self.clip_model),
+ ('sd_pipeline', self.sd_pipeline),
+ ('controlnet_pipeline', self.controlnet_pipeline),
+ ('controlnet', self.controlnet)
+ ]
+
+ for model_name, model in models_to_move:
+ if model is not None:
+ try:
+ if hasattr(model, 'to'):
+ model.to('cpu')
+ logger.info(f"{model_name} moved to CPU")
+ except Exception as e:
+ logger.warning(f"Failed to move {model_name} to CPU: {e}")
+
+ # Clear the GPU cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+ allocated = torch.cuda.memory_allocated() / 1024**3
+ logger.info(f"GPU memory usage after moving to CPU: {allocated:.2f}GB")
+
+ logger.info("All models moved to CPU")
+
+ except Exception as e:
+ logger.error(f"Failed to move models to CPU: {e}")
+
+ def move_models_to_gpu(self):
+ """Move the models back to the GPU"""
+ try:
+ logger.info("Moving all models back to the GPU...")
+
+ models_to_move = [
+ ('caption_model', self.caption_model),
+ ('clip_model', self.clip_model),
+ ('sd_pipeline', self.sd_pipeline),
+ ('controlnet_pipeline', self.controlnet_pipeline),
+ ('controlnet', self.controlnet)
+ ]
+
+ for model_name, model in models_to_move:
+ if model is not None:
+ try:
+ if hasattr(model, 'to'):
+ model.to(self.device)
+ logger.info(f"{model_name} moved back to GPU")
+ except Exception as e:
+ logger.warning(f"Failed to move {model_name} to GPU: {e}")
+
+ if torch.cuda.is_available():
+ allocated = torch.cuda.memory_allocated() / 1024**3
+ logger.info(f"GPU memory usage after moving back to GPU: {allocated:.2f}GB")
+
+ logger.info("All models moved back to GPU")
+
+ except Exception as e:
+ logger.error(f"Failed to move models to GPU: {e}")
+
+ def force_reload_all_models(self):
+ """Force-reload all models"""
+ logger.info("Force-reloading all models...")
+ try:
+ # Release the existing models
+ models_to_delete = [
+ 'caption_model', 'caption_processor',
+ 'clip_model', 'clip_processor',
+ 'sd_pipeline', 'controlnet', 'controlnet_pipeline'
+ ]
+
+ for model_name in models_to_delete:
+ if hasattr(self, model_name):
+ model = getattr(self, model_name)
+ if model is not None:
+ try:
+ del model
+ setattr(self, model_name, None)
+ logger.info(f"Released {model_name}")
+ except Exception as e:
+ logger.warning(f"Failed to release {model_name}: {e}")
+
+ # Force garbage collection
+ gc.collect()
+
+ # Clear the GPU cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+ logger.info("Reloading models...")
+
+ # Reload all models
+ self.load_all_models()
+
+ logger.info("All models reloaded")
+
+ except Exception as e:
+ logger.error(f"Force reload of models failed: {e}")
+ raise

  def get_model_status(self):
+ """Get model loading status"""
  status = {
  "caption_model": self.caption_model is not None,
  "clip_model": self.clip_model is not None,

  "controlnet_pipeline": self.controlnet_pipeline is not None,
  "device": self.device
  }
+
  if torch.cuda.is_available():
  status["gpu_memory"] = {
  "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
+ "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB",
+ "max_allocated": f"{torch.cuda.max_memory_allocated() / 1024**3:.2f}GB"
  }
+
  return status
+
+ def optimize_for_inference(self):
+ """Optimize the models for faster inference"""
+ logger.info("Optimizing model inference performance...")
+
+ try:
+ # Compile the models (if the PyTorch version supports it)
+ if hasattr(torch, 'compile'):
+ models_to_compile = ['caption_model', 'clip_model']
+
+ for model_name in models_to_compile:
+ model = getattr(self, model_name)
+ if model is not None:
+ try:
+ # Store the compiled model back so the speedup actually takes effect
+ setattr(self, model_name, torch.compile(model))
+ logger.info(f"{model_name} compiled successfully")
+ except Exception as e:
+ logger.info(f"Compilation of {model_name} skipped: {e}")
+
+ # Put the models in evaluation mode
+ models = [self.caption_model, self.clip_model]
+ for model in models:
+ if model is not None:
+ model.eval()
+
+ logger.info("Model optimization finished")
+
+ except Exception as e:
+ logger.warning(f"Model optimization failed: {e}")
+
608
+ def benchmark_models(self):
609
+ """基准测试模型性能"""
610
+ logger.info("开始模型性能基准测试...")
611
+
612
+ try:
613
+ # 创建测试图像
614
+ test_image = Image.new('RGB', (512, 512), color=(128, 128, 128))
615
+
616
+ results = {}
617
+
618
+ # 测试BLIP
619
+ if self.caption_model is not None:
620
+ start_time = time.time()
621
+ _ = self.generate_caption(test_image)
622
+ results['caption_time'] = time.time() - start_time
623
+
624
+ # 测试CLIP
625
+ if self.clip_model is not None:
626
+ start_time = time.time()
627
+ _ = self.analyze_style(test_image)
628
+ results['clip_time'] = time.time() - start_time
629
+
630
+ # 测试SD
631
+ if self.sd_pipeline is not None:
632
+ start_time = time.time()
633
+ _ = self.generate_image("test fashion design", num_inference_steps=5)
634
+ results['sd_time'] = time.time() - start_time
635
+
636
+ logger.info(f"基准测试结果: {results}")
637
+ return results
638
+
639
+ except Exception as e:
640
+ logger.error(f"基准测试失败: {e}")
641
+ return {}
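
For reference, a minimal usage sketch of the ModelManager API as it stands after this commit (not part of the diff). It assumes the module's transformers/diffusers imports resolve and that a CUDA device is available; the input path and prompts are hypothetical.

# Minimal usage sketch of the updated ModelManager (assumptions noted above).
from PIL import Image
from models.model_manager import ModelManager

manager = ModelManager()                     # __init__ now preloads all models
print(manager.get_model_status())            # per-model flags plus GPU memory figures

sketch = Image.open("design_sketch.png")     # hypothetical input image
caption = manager.generate_caption(sketch)   # BLIP caption, with a text fallback on failure
styles = manager.analyze_style(sketch)       # CLIP style scores keyed by style name

design = manager.generate_image(f"{caption}, fashion design illustration")
try_on = manager.generate_controlnet_image(design, "model wearing the design, studio photo")

manager.cleanup()                            # frees cached VRAM but keeps the models loaded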