Humphreykowl committed on
Commit
9cc48ff
·
verified ·
1 Parent(s): 422bb60

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +308 -1
models/model_manager.py CHANGED
@@ -331,4 +331,311 @@ class ModelManager:
331
  result = self.sd_pipeline(
332
  prompt=prompt,
333
  negative_prompt=negative_prompt,
334
- num_inference_steps=num_inference_steps,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  result = self.sd_pipeline(
332
  prompt=prompt,
333
  negative_prompt=negative_prompt,
334
+ num_inference_steps=num_inference_steps,
335
+ guidance_scale=guidance_scale,
336
+ height=height,
337
+ width=width,
338
+ generator=torch.Generator(device=self.device).manual_seed(random.randint(0, 2**32-1))
339
+ )
340
+
341
+ # 清理显存
342
+ if torch.cuda.is_available():
343
+ torch.cuda.empty_cache()
344
+
345
+ return result.images[0]
346
+
347
+ except Exception as e:
348
+ logger.error(f"图像生成失败: {e}")
349
+ return self.create_placeholder_image(width, height)
350
+
351
@torch.no_grad()
def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0, **kwargs):
    """Generate a 3D try-on rendering with the ControlNet pipeline.

    The input image is converted to RGB, resized to 512x768 and turned into
    a control image by ``self.create_pose_control_image`` before it is
    handed to ``self.controlnet_pipeline``. The pipeline is loaded on
    demand through ``self.load_controlnet_pipeline`` when it is missing.

    Args:
        image: PIL image the control image is derived from.
        prompt: Text prompt for the generation.
        reference_image: Optional reference design. When given, only the
            prompt is extended; the image itself is not passed to the
            pipeline.
        negative_prompt: Optional negative prompt; a built-in default is
            used when None.
        num_inference_steps: Number of inference steps passed to the
            pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The first image produced by the pipeline, or a 512x768 placeholder
        image when the pipeline is unavailable or generation fails.
    """
    target_size = (512, 768)  # (width, height)

    if self.controlnet_pipeline is None:
        self.load_controlnet_pipeline()
        if self.controlnet_pipeline is None:
            logger.error("无法生成3D试穿:ControlNet 模型未加载")
            return self.create_placeholder_image(*target_size)

    try:
        # Normalise the colour mode before resizing.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        control_image = image.resize(target_size, Image.Resampling.LANCZOS)

        # Derive the simplified outline used as ControlNet conditioning.
        control_image = self.create_pose_control_image(control_image)

        if negative_prompt is None:
            negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"

        # A reference design only influences the prompt text.
        if reference_image is not None:
            prompt = f"{prompt}, based on reference design"

        # A fresh random seed is drawn for every call.
        seed = random.randint(0, 2**32-1)
        result = self.controlnet_pipeline(
            prompt=prompt,
            image=control_image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=1.0,
            generator=torch.Generator(device=self.device).manual_seed(seed)
        )

        return result.images[0]

    except Exception as e:
        logger.error(f"ControlNet图像生成失败: {e}")
        return self.create_placeholder_image(*target_size)

    finally:
        # Release cached GPU memory on failure as well as on success; a
        # failed run (e.g. out of memory) is when freeing it matters most.
        # The cleanup is best-effort so it can never mask the return value.
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception as cleanup_error:
                logger.warning(f"清理显存失败: {cleanup_error}")
398
+
399
def create_pose_control_image(self, image):
    """Build a simple edge-map control image from a PIL image.

    This is a simplified stand-in for a real pose map: the image is
    reduced to grayscale, the Sobel gradient magnitude is computed, and
    the result is scaled to 0-255 (edges bright, flat areas black).

    Args:
        image: PIL image to analyse.

    Returns:
        An RGB PIL image containing the edge map. If anything fails, a
        plain grayscale copy of ``image`` (converted back to RGB) is
        returned instead.
    """
    try:
        from scipy import ndimage

        pixels = np.asarray(image, dtype=np.float64)
        # Average the colour channels; a 2-D array is already single-channel.
        gray = pixels.mean(axis=2) if pixels.ndim == 3 else pixels

        # A Sobel filter along a single axis only responds to edges that
        # cross that axis, so combine both directions into a magnitude.
        grad_rows = ndimage.sobel(gray, axis=0)
        grad_cols = ndimage.sobel(gray, axis=1)
        edges = np.hypot(grad_rows, grad_cols)

        # Scale to 0-255. A perfectly flat image has no gradient range;
        # emit an all-black map instead of dividing by zero.
        low = edges.min()
        span = edges.max() - low
        if span > 0:
            edges = (edges - low) / span * 255.0
        else:
            edges = np.zeros_like(edges)
        edges = edges.astype(np.uint8)

        # Back to a PIL image in the RGB mode the pipeline receives.
        control_image = Image.fromarray(edges, mode='L').convert('RGB')

        return control_image

    except Exception as e:
        logger.warning(f"创建姿态控制图失败: {e}")
        # Fallback: a grayscale copy of the input (no edge detection).
        return image.convert('L').convert('RGB')
423
+
424
def create_placeholder_image(self, width, height):
    """Return a solid-colour RGB image of the requested size.

    The fill colour is drawn at random from a small fixed palette of
    light tones.
    """
    palette = (
        (220, 220, 220),
        (200, 220, 240),
        (240, 220, 200),
        (220, 240, 200),
    )
    fill = random.choice(palette)
    size = (width, height)
    return Image.new('RGB', size, color=fill)
429
+
430
def cleanup(self):
    """Release cached GPU memory without unloading any model.

    When CUDA is available this runs the garbage collector, empties the
    CUDA caching allocator, collects CUDA IPC memory and logs the current
    allocated/reserved figures. Failures are logged and never raised.
    """
    logger.info("清理GPU显存缓存...")
    try:
        if torch.cuda.is_available():
            # Collect unreachable Python objects first so the tensors they
            # hold can be returned to the allocator.
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

            # Report usage in GiB.
            allocated = torch.cuda.memory_allocated() / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3
            logger.info(f"显存使用: {allocated:.2f}GB (分配) / {cached:.2f}GB (缓存)")

        logger.info("显存清理完成")

    except Exception as e:
        # The message previously contained a garbled character; restored.
        logger.error(f"显存清理失败: {e}")
450
+
451
def move_models_to_cpu(self):
    """Move every loaded model to the CPU to free GPU memory.

    Each model is moved independently; a failure for one is logged as a
    warning and the others are still processed. Afterwards the CUDA cache
    is emptied (when CUDA is available) and the remaining allocation is
    logged. Errors are logged, never raised.
    """
    try:
        logger.info("将所有模型移至CPU...")

        attr_names = (
            'caption_model',
            'clip_model',
            'sd_pipeline',
            'controlnet_pipeline',
            'controlnet',
        )
        # Resolve all attributes up front, before any model is moved.
        loaded = [(attr, getattr(self, attr)) for attr in attr_names]

        for attr, obj in loaded:
            if obj is None:
                continue
            try:
                if hasattr(obj, 'to'):
                    obj.to('cpu')
                    logger.info(f"{attr} 已移至CPU")
            except Exception as move_error:
                logger.warning(f"移动 {attr} 到CPU失败: {move_error}")

        # Hand the freed blocks back to the driver and report what is left.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

            bytes_per_gb = 1024**3
            still_allocated = torch.cuda.memory_allocated() / bytes_per_gb
            logger.info(f"移至CPU后GPU显存使用: {still_allocated:.2f}GB")

        logger.info("所有模型已移至CPU")

    except Exception as outer_error:
        logger.error(f"移动模型到CPU失败: {outer_error}")
485
+
486
def move_models_to_gpu(self):
    """Move every loaded model back to ``self.device``.

    Each model is moved independently; a failure for one is logged as a
    warning and the others are still processed. When CUDA is available the
    resulting allocation is logged. Errors are logged, never raised.
    """
    try:
        logger.info("将所有模型移回GPU...")

        attr_names = (
            'caption_model',
            'clip_model',
            'sd_pipeline',
            'controlnet_pipeline',
            'controlnet',
        )
        # Resolve all attributes up front, before any model is moved.
        loaded = [(attr, getattr(self, attr)) for attr in attr_names]

        for attr, obj in loaded:
            if obj is None:
                continue
            try:
                if hasattr(obj, 'to'):
                    obj.to(self.device)
                    logger.info(f"{attr} 已移回GPU")
            except Exception as move_error:
                logger.warning(f"移动 {attr} 到GPU失败: {move_error}")

        if torch.cuda.is_available():
            bytes_per_gb = 1024**3
            now_allocated = torch.cuda.memory_allocated() / bytes_per_gb
            logger.info(f"移回GPU后显存使用: {now_allocated:.2f}GB")

        logger.info("所有模型已移回GPU")

    except Exception as outer_error:
        logger.error(f"移动模型到GPU失败: {outer_error}")
516
+
517
def force_reload_all_models(self):
    """Drop every loaded model/processor and load them all again.

    Each populated attribute is reset to None, garbage collection and a
    CUDA cache flush follow, and ``self.load_all_models`` is then called.

    Raises:
        Exception: Any error raised while releasing or reloading is
            logged and re-raised to the caller.
    """
    logger.info("开始强制重新加载所有模型...")
    try:
        attr_names = (
            'caption_model', 'caption_processor',
            'clip_model', 'clip_processor',
            'sd_pipeline', 'controlnet', 'controlnet_pipeline',
        )

        for attr in attr_names:
            # Missing and already-empty attributes are both skipped. No
            # local reference to the model is kept, so clearing the
            # attribute is enough to make the object collectable.
            if getattr(self, attr, None) is None:
                continue
            try:
                setattr(self, attr, None)
                logger.info(f"释放 {attr}")
            except Exception as release_error:
                logger.warning(f"释放 {attr} 失败: {release_error}")

        # Reclaim the released objects before asking CUDA for its cache.
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        logger.info("开始重新加载模型...")

        self.load_all_models()

        logger.info("所有模型重新加载完成")

    except Exception as reload_error:
        logger.error(f"强制重新加载模型失败: {reload_error}")
        raise
557
+
558
def get_model_status(self):
    """Report which models are loaded and, with CUDA, GPU memory usage.

    Returns:
        A dict with one boolean per model attribute (True when the
        attribute is not None), the configured ``device``, and - only when
        CUDA is available - a ``gpu_memory`` dict of formatted GB strings.
    """
    tracked = ("caption_model", "clip_model", "sd_pipeline", "controlnet_pipeline")
    status = {name: getattr(self, name) is not None for name in tracked}
    status["device"] = self.device

    if torch.cuda.is_available():
        def as_gb(num_bytes):
            # Bytes -> GiB, two decimals, with the "GB" suffix.
            return f"{num_bytes / 1024**3:.2f}GB"

        status["gpu_memory"] = {
            "allocated": as_gb(torch.cuda.memory_allocated()),
            "cached": as_gb(torch.cuda.memory_reserved()),
            "max_allocated": as_gb(torch.cuda.max_memory_allocated()),
        }

    return status
576
+
577
def optimize_for_inference(self):
    """Put the caption and CLIP models in eval mode and compile them.

    Every loaded model is switched to evaluation mode. When the installed
    PyTorch provides ``torch.compile``, the compiled module is stored back
    on ``self`` so later inference actually uses it (previously the result
    was bound to a loop variable and thrown away).

    NOTE(review): ``torch.compile`` is lazy, so a backend problem may only
    surface on the first forward pass rather than here - confirm this is
    acceptable on the deployment platform.

    Failures are logged and never raised.
    """
    logger.info("优化模型推理性能...")

    try:
        can_compile = hasattr(torch, 'compile')

        for attr_name in ('caption_model', 'clip_model'):
            model = getattr(self, attr_name)
            if model is None:
                continue

            # Evaluation mode first, so the compiled module captures it.
            model.eval()

            if not can_compile:
                continue
            # Modules returned by torch.compile expose the wrapped module
            # as `_orig_mod`; skip them so repeated calls do not stack
            # wrappers.
            if hasattr(model, '_orig_mod'):
                continue

            try:
                setattr(self, attr_name, torch.compile(model))
                logger.info("模型编译成功")
            except Exception as e:
                logger.info(f"模型编译跳过: {e}")

        logger.info("模型优化完成")

    except Exception as e:
        logger.warning(f"模型优化失败: {e}")
607
+
608
def benchmark_models(self):
    """Time one run of each loaded model on a synthetic input.

    A 512x512 gray test image is passed to ``self.generate_caption`` and
    ``self.analyze_style``; ``self.generate_image`` is run with a short
    prompt and 5 inference steps. A stage is skipped when its model
    attribute is None.

    Returns:
        A dict mapping ``caption_time`` / ``clip_time`` / ``sd_time`` to
        elapsed seconds for the stages that ran, or an empty dict if the
        benchmark failed.
    """
    logger.info("开始模型性能基准测试...")

    try:
        test_image = Image.new('RGB', (512, 512), color=(128, 128, 128))

        # (result key, attribute that must be loaded, callable to time)
        stages = (
            ('caption_time', 'caption_model',
             lambda: self.generate_caption(test_image)),
            ('clip_time', 'clip_model',
             lambda: self.analyze_style(test_image)),
            ('sd_time', 'sd_pipeline',
             lambda: self.generate_image("test fashion design", num_inference_steps=5)),
        )

        results = {}
        for result_key, attr_name, run_stage in stages:
            # Looked up just before the stage runs, as the original did.
            if getattr(self, attr_name) is None:
                continue
            # perf_counter is monotonic and high-resolution; time.time()
            # can jump when the system clock is adjusted.
            started = time.perf_counter()
            run_stage()
            results[result_key] = time.perf_counter() - started

        logger.info(f"基准测试结果: {results}")
        return results

    except Exception as e:
        logger.error(f"基准测试失败: {e}")
        return {}