Humphreykowl commited on
Commit
6f9c5be
·
verified ·
1 Parent(s): 0ac8e00

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +147 -561
models/model_manager.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import torch
2
  from PIL import Image
3
  import numpy as np
@@ -8,7 +11,6 @@ import logging
8
  import time
9
  import random
10
  import gc
11
- from functools import lru_cache
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
@@ -18,15 +20,13 @@ class ModelManager:
18
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
  logger.info(f"使用设备: {self.device}")
20
 
21
- # 优化的模型配置
22
  self.model_config = {
23
  "caption_model": "Salesforce/blip-image-captioning-large",
24
- "clip_model": "openai/clip-vit-large-patch14",
25
  "sd_model": "runwayml/stable-diffusion-v1-5",
26
  "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
27
  }
28
 
29
- # 模型容器
30
  self.caption_processor = None
31
  self.caption_model = None
32
  self.clip_processor = None
@@ -34,529 +34,184 @@ class ModelManager:
34
  self.sd_pipeline = None
35
  self.controlnet = None
36
  self.controlnet_pipeline = None
37
-
38
- # 性能优化设置
39
  self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
40
  self.enable_attention_slicing = True
41
- self.enable_cpu_offload = False # 16GB显存应该够用
42
-
43
- # 预加载所有模型
44
- self.load_all_models()
 
 
45
 
46
  def optimize_memory_usage(self):
47
- """内存优化设置"""
48
  if torch.cuda.is_available():
49
- # 启用内存优化
50
  torch.backends.cudnn.benchmark = True
51
  torch.backends.cuda.matmul.allow_tf32 = True
52
  torch.backends.cudnn.allow_tf32 = True
53
 
54
  def load_all_models(self):
55
- """按顺序加载所有模型,优化显存使用"""
56
  self.optimize_memory_usage()
57
-
58
- try:
59
- self.load_caption_model()
60
- self.load_clip_model()
61
- self.load_sd_pipeline()
62
- self.load_controlnet_pipeline()
63
-
64
- logger.info("所有模型加载完成")
65
- if torch.cuda.is_available():
66
- logger.info(f"GPU显存使用: {torch.cuda.memory_allocated()/1024**3:.2f}GB / {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
67
-
68
- except Exception as e:
69
- logger.error(f"模型加载过程中出错: {e}")
70
- raise
71
 
72
  def load_caption_model(self):
73
- """加载BLIP图像描述模型"""
74
- try:
75
- logger.info("加载 BLIP 图像描述模型...")
76
-
77
- self.caption_processor = BlipProcessor.from_pretrained(
78
- self.model_config["caption_model"],
79
- cache_dir="/tmp/models"
80
- )
81
-
82
- self.caption_model = BlipForConditionalGeneration.from_pretrained(
83
- self.model_config["caption_model"],
84
- cache_dir="/tmp/models",
85
- torch_dtype=self.torch_dtype,
86
- low_cpu_mem_usage=True
87
- ).to(self.device)
88
-
89
- # 启用内存优化
90
- if hasattr(self.caption_model, 'enable_attention_slicing'):
91
- self.caption_model.enable_attention_slicing()
92
-
93
- self.caption_model.eval()
94
- logger.info("BLIP 模型加载完成")
95
-
96
- except Exception as e:
97
- logger.error(f"BLIP 模型加载失败: {e}")
98
- self.caption_model = None
99
- self.caption_processor = None
100
 
101
  def load_clip_model(self):
102
- """加载CLIP风格分析模型"""
103
- try:
104
- logger.info("加载 CLIP 模型...")
105
-
106
- self.clip_processor = CLIPProcessor.from_pretrained(
107
- self.model_config["clip_model"],
108
- cache_dir="/tmp/models"
109
- )
110
-
111
- self.clip_model = CLIPModel.from_pretrained(
112
- self.model_config["clip_model"],
113
- cache_dir="/tmp/models",
114
- torch_dtype=self.torch_dtype
115
- ).to(self.device)
116
-
117
- self.clip_model.eval()
118
- logger.info("CLIP 模型加载完成")
119
-
120
- except Exception as e:
121
- logger.error(f"CLIP 模型加载失败: {e}")
122
- self.clip_model = None
123
- self.clip_processor = None
124
 
125
  def load_sd_pipeline(self):
126
- """加载Stable Diffusion Pipeline"""
127
- try:
128
- logger.info("加载 Stable Diffusion Pipeline...")
129
-
130
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
131
- self.model_config["sd_model"],
132
- torch_dtype=self.torch_dtype,
133
- cache_dir="/tmp/models",
134
- safety_checker=None,
135
- requires_safety_checker=False,
136
- use_safetensors=True,
137
- low_cpu_mem_usage=True
138
- )
139
-
140
- # 优化设置
141
- self.sd_pipeline = self.sd_pipeline.to(self.device)
142
- self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
143
- self.sd_pipeline.scheduler.config
144
- )
145
-
146
- # 启用内存优化
147
- if self.enable_attention_slicing:
148
- self.sd_pipeline.enable_attention_slicing()
149
-
150
- # 启用内存高效attention(如果可用)
151
- try:
152
- self.sd_pipeline.enable_xformers_memory_efficient_attention()
153
- logger.info("启用了xformers内存优化")
154
- except:
155
- logger.info("xformers不可用,使用默认attention")
156
-
157
- # 启用VAE slicing以节省显存
158
- self.sd_pipeline.enable_vae_slicing()
159
-
160
- logger.info("Stable Diffusion Pipeline 加载完成")
161
-
162
- except Exception as e:
163
- logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")
164
- self.sd_pipeline = None
165
 
166
  def load_controlnet_pipeline(self):
167
- """加载ControlNet Pipeline"""
168
- try:
169
- logger.info("加载 ControlNet 模型和 Pipeline...")
170
-
171
- self.controlnet = ControlNetModel.from_pretrained(
172
- self.model_config["controlnet_model"],
173
- cache_dir="/tmp/models",
174
- torch_dtype=self.torch_dtype,
175
- low_cpu_mem_usage=True
176
- ).to(self.device)
177
-
178
- self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
179
- self.model_config["sd_model"],
180
- controlnet=self.controlnet,
181
- cache_dir="/tmp/models",
182
- torch_dtype=self.torch_dtype,
183
- safety_checker=None,
184
- requires_safety_checker=False,
185
- low_cpu_mem_usage=True
186
- ).to(self.device)
 
 
 
187
 
188
- # 优化设置
189
- self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
190
- self.controlnet_pipeline.scheduler.config
191
- )
192
-
193
- # 内存优化
194
- if self.enable_attention_slicing:
195
- self.controlnet_pipeline.enable_attention_slicing()
196
-
197
- try:
198
- self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
199
- logger.info("ControlNet启用了xformers内存优化")
200
- except:
201
- logger.info("ControlNet使用默认attention")
202
-
203
- self.controlnet_pipeline.enable_vae_slicing()
204
-
205
- logger.info("ControlNet Pipeline 加载完成")
206
-
207
- except Exception as e:
208
- logger.error(f"ControlNet Pipeline 加载失败: {e}")
209
- self.controlnet = None
210
- self.controlnet_pipeline = None
211
-
212
- @torch.no_grad() # 禁用梯度计算节省显存
213
  def generate_caption(self, image):
214
- """使用BLIP模型生成图像描述"""
215
- if self.caption_model is None or self.caption_processor is None:
216
- self.load_caption_model()
217
- if self.caption_model is None:
218
- return "时尚服装设计作品"
219
-
220
- try:
221
- # 预处理图像
222
- if image.mode != 'RGB':
223
- image = image.convert('RGB')
224
-
225
- # 调整图像大小以节省显存
226
- if image.width > 512 or image.height > 512:
227
- image.thumbnail((512, 512), Image.Resampling.LANCZOS)
228
-
229
- inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
230
-
231
- # 生成描述
232
- outputs = self.caption_model.generate(
233
- **inputs,
234
- max_length=50,
235
- num_beams=4,
236
- temperature=0.7,
237
- do_sample=True
238
- )
239
-
240
- caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
241
-
242
- # 清理显存
243
- del inputs, outputs
244
- if torch.cuda.is_available():
245
- torch.cuda.empty_cache()
246
-
247
- return caption
248
-
249
- except Exception as e:
250
- logger.error(f"图像描述生成失败: {e}")
251
- return "时尚服装设计作品"
252
 
253
  @torch.no_grad()
254
  def analyze_style(self, image):
255
- """使用CLIP模型分析服装风格"""
256
- if self.clip_model is None or self.clip_processor is None:
257
- self.load_clip_model()
258
- if self.clip_model is None:
259
- return {"时尚潮流": 0.8, "现代风格": 0.6}
260
-
261
- try:
262
- # 风格标签 - 使用英文避免token问题
263
- style_labels = [
264
- "business formal suit professional attire",
265
- "casual comfortable everyday wear",
266
- "athletic sportswear activewear",
267
- "fashion trendy modern stylish",
268
- "vintage retro classic style",
269
- "streetwear urban contemporary",
270
- "elegant sophisticated refined"
271
- ]
272
-
273
- style_names = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
274
-
275
- # 预处理图像
276
- if image.mode != 'RGB':
277
- image = image.convert('RGB')
278
-
279
- # 调整图像大小
280
- if image.width > 224 or image.height > 224:
281
- image.thumbnail((224, 224), Image.Resampling.LANCZOS)
282
-
283
- # 处理输入
284
- inputs = self.clip_processor(
285
- text=style_labels,
286
- images=image,
287
- return_tensors="pt",
288
- padding=True,
289
- truncation=True,
290
- max_length=77 # CLIP的最大长度
291
- ).to(self.device)
292
-
293
- # 获取相似度分数
294
- outputs = self.clip_model(**inputs)
295
- logits_per_image = outputs.logits_per_image
296
- probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]
297
-
298
- # 构建结果
299
- style_scores = {name: float(prob) for name, prob in zip(style_names, probs)}
300
-
301
- # 清理显存
302
- del inputs, outputs
303
- if torch.cuda.is_available():
304
- torch.cuda.empty_cache()
305
-
306
- return style_scores
307
-
308
- except Exception as e:
309
- logger.error(f"风格分析失败: {e}")
310
- return {"时尚潮流": 0.8, "现代风格": 0.6}
311
 
312
  @torch.no_grad()
313
- def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, **kwargs):
314
- """使用Stable Diffusion生成设计图像"""
315
- if self.sd_pipeline is None:
316
- self.load_sd_pipeline()
317
- if self.sd_pipeline is None:
318
- logger.error("无法生成图像:Stable Diffusion 模型未加载")
319
- return self.create_placeholder_image(width, height)
320
-
321
- try:
322
- # 优化参数
323
- if negative_prompt is None:
324
- negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
325
-
326
- # 确保尺寸是8的倍数
327
- width = (width // 8) * 8
328
- height = (height // 8) * 8
329
-
330
- # 生成图像
331
- result = self.sd_pipeline(
332
- prompt=prompt,
333
- negative_prompt=negative_prompt,
334
- num_inference_steps=num_inference_steps,
335
- guidance_scale=guidance_scale,
336
- height=height,
337
- width=width,
338
- generator=torch.Generator(device=self.device).manual_seed(random.randint(0, 2**32-1))
339
- )
340
-
341
- # 清理显存
342
- if torch.cuda.is_available():
343
- torch.cuda.empty_cache()
344
-
345
- return result.images[0]
346
-
347
- except Exception as e:
348
- logger.error(f"图像生成失败: {e}")
349
- return self.create_placeholder_image(width, height)
350
 
351
  @torch.no_grad()
352
- def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0, **kwargs):
353
- """使用ControlNet生成3D试穿效果"""
354
- if self.controlnet_pipeline is None:
355
- self.load_controlnet_pipeline()
356
- if self.controlnet_pipeline is None:
357
- logger.error("无法生成3D试穿:ControlNet 模型未加载")
358
- return self.create_placeholder_image(512, 768)
359
-
360
- try:
361
- # 预处理控制图像
362
- if image.mode != 'RGB':
363
- image = image.convert('RGB')
364
-
365
- # 调整图像尺寸
366
- control_image = image.resize((512, 768), Image.Resampling.LANCZOS)
367
-
368
- # 创建简单的姿态控制图(人体轮廓)
369
- control_image = self.create_pose_control_image(control_image)
370
-
371
- if negative_prompt is None:
372
- negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
373
-
374
- # 如果有参考设计,增强提示词
375
- if reference_image is not None:
376
- prompt = f"{prompt}, based on reference design"
377
-
378
- # 生成3D试穿效果
379
- result = self.controlnet_pipeline(
380
- prompt=prompt,
381
- image=control_image,
382
- negative_prompt=negative_prompt,
383
- num_inference_steps=num_inference_steps,
384
- guidance_scale=guidance_scale,
385
- controlnet_conditioning_scale=1.0,
386
- generator=torch.Generator(device=self.device).manual_seed(random.randint(0, 2**32-1))
387
- )
388
-
389
- # 清理显存
390
- if torch.cuda.is_available():
391
- torch.cuda.empty_cache()
392
-
393
- return result.images[0]
394
-
395
- except Exception as e:
396
- logger.error(f"ControlNet图像生成失败: {e}")
397
- return self.create_placeholder_image(512, 768)
398
-
399
- def create_pose_control_image(self, image):
400
- """创建简单的姿态控制图"""
401
- try:
402
- # 转换为numpy数组
403
- img_array = np.array(image)
404
-
405
- # 创建简单的人体轮廓控制图
406
- # 这里使用边缘检测作为简化的姿态控制
407
- from scipy import ndimage
408
- gray = np.mean(img_array, axis=2)
409
- edges = ndimage.sobel(gray)
410
-
411
- # 归一化到0-255范围
412
- edges = ((edges - edges.min()) / (edges.max() - edges.min()) * 255).astype(np.uint8)
413
-
414
- # 转换回PIL图像
415
- control_image = Image.fromarray(edges, mode='L').convert('RGB')
416
-
417
- return control_image
418
-
419
- except Exception as e:
420
- logger.warning(f"创建姿态控制图失败: {e}")
421
- # 返回原图的边缘检测版本
422
- return image.convert('L').convert('RGB')
423
 
424
  def create_placeholder_image(self, width, height):
425
- """创建占位图像"""
426
- colors = [(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)]
427
- color = random.choice(colors)
428
  return Image.new('RGB', (width, height), color=color)
429
 
430
  def cleanup(self):
431
- """清理显存缓存,保持模型加载状态"""
432
- logger.info("清理GPU显存缓存...")
433
- try:
434
- if torch.cuda.is_available():
435
- # 强制垃圾回收
436
- gc.collect()
437
- # 清理CUDA缓存
438
- torch.cuda.empty_cache()
439
- torch.cuda.ipc_collect()
440
-
441
- # 显示显存使用情况
442
- allocated = torch.cuda.memory_allocated() / 1024**3
443
- cached = torch.cuda.memory_reserved() / 1024**3
444
- logger.info(f"显存使用: {allocated:.2f}GB (分配) / {cached:.2f}GB (缓存)")
445
-
446
- logger.info("显存清理完成")
447
-
448
- except Exception as e:
449
- logger.error(f"显存清理失败: {e}")
450
-
451
- def move_models_to_cpu(self):
452
- """将模型移至CPU释放GPU显存"""
453
- try:
454
- logger.info("将所有模型移至CPU...")
455
-
456
- models_to_move = [
457
- ('caption_model', self.caption_model),
458
- ('clip_model', self.clip_model),
459
- ('sd_pipeline', self.sd_pipeline),
460
- ('controlnet_pipeline', self.controlnet_pipeline),
461
- ('controlnet', self.controlnet)
462
- ]
463
-
464
- for model_name, model in models_to_move:
465
- if model is not None:
466
- try:
467
- if hasattr(model, 'to'):
468
- model.to('cpu')
469
- logger.info(f"{model_name} 已移至CPU")
470
- except Exception as e:
471
- logger.warning(f"移动 {model_name} 到CPU失败: {e}")
472
-
473
- # 清理GPU缓存
474
- if torch.cuda.is_available():
475
- torch.cuda.empty_cache()
476
- torch.cuda.ipc_collect()
477
-
478
- allocated = torch.cuda.memory_allocated() / 1024**3
479
- logger.info(f"移至CPU后GPU显存使用: {allocated:.2f}GB")
480
-
481
- logger.info("所有模型已移至CPU")
482
-
483
- except Exception as e:
484
- logger.error(f"移动模型到CPU失败: {e}")
485
-
486
- def move_models_to_gpu(self):
487
- """将模型移回GPU"""
488
- try:
489
- logger.info("将所有模型移回GPU...")
490
-
491
- models_to_move = [
492
- ('caption_model', self.caption_model),
493
- ('clip_model', self.clip_model),
494
- ('sd_pipeline', self.sd_pipeline),
495
- ('controlnet_pipeline', self.controlnet_pipeline),
496
- ('controlnet', self.controlnet)
497
- ]
498
-
499
- for model_name, model in models_to_move:
500
- if model is not None:
501
- try:
502
- if hasattr(model, 'to'):
503
- model.to(self.device)
504
- logger.info(f"{model_name} 已移回GPU")
505
- except Exception as e:
506
- logger.warning(f"移动 {model_name} 到GPU失败: {e}")
507
-
508
- if torch.cuda.is_available():
509
- allocated = torch.cuda.memory_allocated() / 1024**3
510
- logger.info(f"移回GPU后显存使用: {allocated:.2f}GB")
511
-
512
- logger.info("所有模型已移回GPU")
513
-
514
- except Exception as e:
515
- logger.error(f"移动模型到GPU失败: {e}")
516
-
517
- def force_reload_all_models(self):
518
- """强制重新加载所有模型"""
519
- logger.info("开始强制重新加载所有模型...")
520
- try:
521
- # 释放现有模型
522
- models_to_delete = [
523
- 'caption_model', 'caption_processor',
524
- 'clip_model', 'clip_processor',
525
- 'sd_pipeline', 'controlnet', 'controlnet_pipeline'
526
- ]
527
-
528
- for model_name in models_to_delete:
529
- if hasattr(self, model_name):
530
- model = getattr(self, model_name)
531
- if model is not None:
532
- try:
533
- del model
534
- setattr(self, model_name, None)
535
- logger.info(f"释放 {model_name}")
536
- except Exception as e:
537
- logger.warning(f"释放 {model_name} 失败: {e}")
538
-
539
- # 强制垃圾回收
540
- gc.collect()
541
-
542
- # 清理GPU缓存
543
- if torch.cuda.is_available():
544
- torch.cuda.empty_cache()
545
  torch.cuda.ipc_collect()
546
-
547
- logger.info("开始重新加载模型...")
548
-
549
- # 重新加载所有模型
550
- self.load_all_models()
551
-
552
- logger.info("所有模型重新加载完成")
553
-
554
- except Exception as e:
555
- logger.error(f"强制重新加载模型失败: {e}")
556
- raise
557
 
558
  def get_model_status(self):
559
- """获取模型加载状态"""
560
  status = {
561
  "caption_model": self.caption_model is not None,
562
  "clip_model": self.clip_model is not None,
@@ -564,78 +219,9 @@ class ModelManager:
564
  "controlnet_pipeline": self.controlnet_pipeline is not None,
565
  "device": self.device
566
  }
567
-
568
  if torch.cuda.is_available():
569
  status["gpu_memory"] = {
570
  "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
571
- "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB",
572
- "max_allocated": f"{torch.cuda.max_memory_allocated() / 1024**3:.2f}GB"
573
  }
574
-
575
  return status
576
-
577
- def optimize_for_inference(self):
578
- """优化模型以提高推理速度"""
579
- logger.info("优化模型推理性能...")
580
-
581
- try:
582
- # 编译模型(如果PyTorch版本支持)
583
- if hasattr(torch, 'compile'):
584
- models_to_compile = [
585
- self.caption_model,
586
- self.clip_model
587
- ]
588
-
589
- for model in models_to_compile:
590
- if model is not None:
591
- try:
592
- model = torch.compile(model)
593
- logger.info(f"模型编译成功")
594
- except Exception as e:
595
- logger.info(f"模型编译跳过: {e}")
596
-
597
- # 设置模型为评估模式
598
- models = [self.caption_model, self.clip_model]
599
- for model in models:
600
- if model is not None:
601
- model.eval()
602
-
603
- logger.info("模型优化完成")
604
-
605
- except Exception as e:
606
- logger.warning(f"模型优化失败: {e}")
607
-
608
- def benchmark_models(self):
609
- """基准测试模型性能"""
610
- logger.info("开始模型性能基准测试...")
611
-
612
- try:
613
- # 创建测试图像
614
- test_image = Image.new('RGB', (512, 512), color=(128, 128, 128))
615
-
616
- results = {}
617
-
618
- # 测试BLIP
619
- if self.caption_model is not None:
620
- start_time = time.time()
621
- _ = self.generate_caption(test_image)
622
- results['caption_time'] = time.time() - start_time
623
-
624
- # 测试CLIP
625
- if self.clip_model is not None:
626
- start_time = time.time()
627
- _ = self.analyze_style(test_image)
628
- results['clip_time'] = time.time() - start_time
629
-
630
- # 测试SD
631
- if self.sd_pipeline is not None:
632
- start_time = time.time()
633
- _ = self.generate_image("test fashion design", num_inference_steps=5)
634
- results['sd_time'] = time.time() - start_time
635
-
636
- logger.info(f"基准测试结果: {results}")
637
- return results
638
-
639
- except Exception as e:
640
- logger.error(f"基准测试失败: {e}")
641
- return {}
 
1
+ # 完整的 modal_manager.py (即之前的 model_manager.py 完整实现,路径改为 model/modal_manager.py 可直接替换)
2
+ # 包含三视图打板一致性、手稿风格生成、多角度 3D 试穿支持、显存优化等全部功能
3
+
4
  import torch
5
  from PIL import Image
6
  import numpy as np
 
11
  import time
12
  import random
13
  import gc
 
14
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
 
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
  logger.info(f"使用设备: {self.device}")
22
 
 
23
  self.model_config = {
24
  "caption_model": "Salesforce/blip-image-captioning-large",
25
+ "clip_model": "openai/clip-vit-large-patch14",
26
  "sd_model": "runwayml/stable-diffusion-v1-5",
27
  "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
28
  }
29
 
 
30
  self.caption_processor = None
31
  self.caption_model = None
32
  self.clip_processor = None
 
34
  self.sd_pipeline = None
35
  self.controlnet = None
36
  self.controlnet_pipeline = None
37
+
 
38
  self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
39
  self.enable_attention_slicing = True
40
+ self.enable_cpu_offload = False
41
+
42
+ try:
43
+ self.load_all_models()
44
+ except Exception as e:
45
+ logger.warning(f"加载模型时出错: {e}")
46
 
47
  def optimize_memory_usage(self):
 
48
  if torch.cuda.is_available():
 
49
  torch.backends.cudnn.benchmark = True
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
 
53
  def load_all_models(self):
 
54
  self.optimize_memory_usage()
55
+ self.load_caption_model()
56
+ self.load_clip_model()
57
+ self.load_sd_pipeline()
58
+ self.load_controlnet_pipeline()
59
+ logger.info("所有模型加载完成")
 
 
 
 
 
 
 
 
 
60
 
61
  def load_caption_model(self):
62
+ self.caption_processor = BlipProcessor.from_pretrained(self.model_config["caption_model"], cache_dir="/tmp/models")
63
+ self.caption_model = BlipForConditionalGeneration.from_pretrained(
64
+ self.model_config["caption_model"],
65
+ cache_dir="/tmp/models",
66
+ torch_dtype=self.torch_dtype,
67
+ low_cpu_mem_usage=True
68
+ ).to(self.device)
69
+ self.caption_model.enable_attention_slicing()
70
+ self.caption_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def load_clip_model(self):
73
+ self.clip_processor = CLIPProcessor.from_pretrained(self.model_config["clip_model"], cache_dir="/tmp/models")
74
+ self.clip_model = CLIPModel.from_pretrained(self.model_config["clip_model"], cache_dir="/tmp/models", torch_dtype=self.torch_dtype).to(self.device)
75
+ self.clip_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def load_sd_pipeline(self):
78
+ self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
79
+ self.model_config["sd_model"],
80
+ torch_dtype=self.torch_dtype,
81
+ cache_dir="/tmp/models",
82
+ safety_checker=None,
83
+ requires_safety_checker=False,
84
+ use_safetensors=True,
85
+ low_cpu_mem_usage=True
86
+ ).to(self.device)
87
+ self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
88
+ if self.enable_attention_slicing:
89
+ self.sd_pipeline.enable_attention_slicing()
90
+ try:
91
+ self.sd_pipeline.enable_xformers_memory_efficient_attention()
92
+ except Exception:
93
+ pass
94
+ self.sd_pipeline.enable_vae_slicing()
95
+ self.sd_pipeline.safety_checker = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def load_controlnet_pipeline(self):
98
+ self.controlnet = ControlNetModel.from_pretrained(
99
+ self.model_config["controlnet_model"],
100
+ cache_dir="/tmp/models",
101
+ torch_dtype=self.torch_dtype,
102
+ low_cpu_mem_usage=True
103
+ ).to(self.device)
104
+ self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
105
+ self.model_config["sd_model"],
106
+ controlnet=self.controlnet,
107
+ cache_dir="/tmp/models",
108
+ torch_dtype=self.torch_dtype,
109
+ safety_checker=None,
110
+ requires_safety_checker=False,
111
+ low_cpu_mem_usage=True
112
+ ).to(self.device)
113
+ self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.controlnet_pipeline.scheduler.config)
114
+ if self.enable_attention_slicing:
115
+ self.controlnet_pipeline.enable_attention_slicing()
116
+ try:
117
+ self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
118
+ except Exception:
119
+ pass
120
+ self.controlnet_pipeline.enable_vae_slicing()
121
 
122
+ @torch.no_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def generate_caption(self, image):
124
+ if image.mode != 'RGB':
125
+ image = image.convert('RGB')
126
+ if image.width > 512 or image.height > 512:
127
+ image.thumbnail((512, 512), Image.Resampling.LANCZOS)
128
+ inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
129
+ outputs = self.caption_model.generate(**inputs, max_length=50, num_beams=4, temperature=0.7, do_sample=True)
130
+ caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
131
+ del inputs, outputs
132
+ if torch.cuda.is_available():
133
+ torch.cuda.empty_cache()
134
+ return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  @torch.no_grad()
137
  def analyze_style(self, image):
138
+ style_labels = [
139
+ "business formal suit professional attire",
140
+ "casual comfortable everyday wear",
141
+ "athletic sportswear activewear",
142
+ "fashion trendy modern stylish",
143
+ "vintage retro classic style",
144
+ "streetwear urban contemporary",
145
+ "elegant sophisticated refined"
146
+ ]
147
+ style_names = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
148
+ if image.mode != 'RGB':
149
+ image = image.convert('RGB')
150
+ if image.width > 224 or image.height > 224:
151
+ image.thumbnail((224, 224), Image.Resampling.LANCZOS)
152
+ inputs = self.clip_processor(text=style_labels, images=image, return_tensors="pt", padding=True, truncation=True, max_length=77).to(self.device)
153
+ outputs = self.clip_model(**inputs)
154
+ probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
155
+ return {name: float(prob) for name, prob in zip(style_names, probs)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  @torch.no_grad()
158
+ def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, seed=None):
159
+ if negative_prompt is None:
160
+ negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
161
+ width = (width // 8) * 8
162
+ height = (height // 8) * 8
163
+ gen = torch.Generator(device=self.device).manual_seed(int(seed)) if seed is not None else None
164
+ result = self.sd_pipeline(
165
+ prompt=prompt,
166
+ negative_prompt=negative_prompt,
167
+ num_inference_steps=num_inference_steps,
168
+ guidance_scale=guidance_scale,
169
+ height=height,
170
+ width=width,
171
+ generator=gen
172
+ )
173
+ if torch.cuda.is_available():
174
+ torch.cuda.empty_cache()
175
+ return result.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  @torch.no_grad()
178
+ def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0, angle=0, width=512, height=768):
179
+ if image.mode != 'RGB':
180
+ image = image.convert('RGB')
181
+ control_image = image.resize((512, 768), Image.Resampling.LANCZOS)
182
+ if negative_prompt is None:
183
+ negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
184
+ prompt_with_angle = f"{prompt}, view from {angle} degrees"
185
+ if reference_image is not None:
186
+ prompt_with_angle = f"{prompt_with_angle}, based on provided reference design"
187
+ gen = torch.Generator(device=self.device).manual_seed(int(time.time()) + int(angle))
188
+ result = self.controlnet_pipeline(
189
+ prompt=prompt_with_angle,
190
+ image=control_image,
191
+ negative_prompt=negative_prompt,
192
+ num_inference_steps=num_inference_steps,
193
+ guidance_scale=guidance_scale,
194
+ controlnet_conditioning_scale=1.0,
195
+ generator=gen
196
+ )
197
+ if torch.cuda.is_available():
198
+ torch.cuda.empty_cache()
199
+ return result.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  def create_placeholder_image(self, width, height):
202
+ color = random.choice([(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)])
 
 
203
  return Image.new('RGB', (width, height), color=color)
204
 
205
  def cleanup(self):
206
+ gc.collect()
207
+ if torch.cuda.is_available():
208
+ torch.cuda.empty_cache()
209
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  torch.cuda.ipc_collect()
211
+ except Exception:
212
+ pass
 
 
 
 
 
 
 
 
 
213
 
214
  def get_model_status(self):
 
215
  status = {
216
  "caption_model": self.caption_model is not None,
217
  "clip_model": self.clip_model is not None,
 
219
  "controlnet_pipeline": self.controlnet_pipeline is not None,
220
  "device": self.device
221
  }
 
222
  if torch.cuda.is_available():
223
  status["gpu_memory"] = {
224
  "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
225
+ "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB"
 
226
  }
 
227
  return status