Humphreykowl commited on
Commit
16b9cb1
·
verified ·
1 Parent(s): 2efd7f8

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +56 -17
models/model_manager.py CHANGED
@@ -1,6 +1,6 @@
1
- # models/model_manager.py
2
  import torch
3
  from PIL import Image
 
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
5
  from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
6
  import os
@@ -19,7 +19,7 @@ class ModelManager:
19
  self.model_config = {
20
  "caption_model": "Salesforce/blip-image-captioning-base",
21
  "clip_model": "openai/clip-vit-base-patch32",
22
- "sd_model": "runwayml/stable-diffusion-v1-5", # 这里用原版,可替换为镜像
23
  "controlnet_model": "lllyasviel/sd-controlnet-openpose"
24
  }
25
 
@@ -78,13 +78,35 @@ class ModelManager:
78
  def load_sd_pipeline(self):
79
  try:
80
  logger.info("加载 Stable Diffusion Pipeline...")
81
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
82
- self.model_config["sd_model"],
83
- revision="fp16" if self.device=="cuda" else None,
84
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
85
- cache_dir="/tmp/models",
86
- safety_checker=None # 可按需配置安全检查器
87
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  self.sd_pipeline = self.sd_pipeline.to(self.device)
89
  # 用更高效的调度器
90
  self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
@@ -142,6 +164,11 @@ class ModelManager:
142
  def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
143
  if self.sd_pipeline is None:
144
  self.load_sd_pipeline()
 
 
 
 
 
145
 
146
  # Stable Diffusion 生成图像
147
  result = self.sd_pipeline(
@@ -157,6 +184,10 @@ class ModelManager:
157
  def generate_controlnet_image(self, image, prompt, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
158
  if self.controlnet_pipeline is None:
159
  self.load_controlnet_pipeline()
 
 
 
 
160
 
161
  # 输入的 image 应该是 PIL Image 格式的控制图(比如人体姿态图)
162
  result = self.controlnet_pipeline(
@@ -171,16 +202,24 @@ class ModelManager:
171
  def cleanup(self):
172
  logger.info("释放模型占用显存和缓存...")
173
  try:
174
- del self.caption_model
175
- del self.caption_processor
176
- del self.clip_model
177
- del self.clip_processor
178
- del self.sd_pipeline
179
- del self.controlnet
180
- del self.controlnet_pipeline
 
 
 
 
 
 
 
 
181
  torch.cuda.empty_cache()
182
  import gc
183
  gc.collect()
184
  logger.info("显存清理完成")
185
  except Exception as e:
186
- logger.error(f"清理显存失败: {e}")
 
 
1
  import torch
2
  from PIL import Image
3
+ import numpy as np # 添加缺失的导入
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
5
  from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
6
  import os
 
19
  self.model_config = {
20
  "caption_model": "Salesforce/blip-image-captioning-base",
21
  "clip_model": "openai/clip-vit-base-patch32",
22
+ "sd_model": "runwayml/stable-diffusion-v1-5",
23
  "controlnet_model": "lllyasviel/sd-controlnet-openpose"
24
  }
25
 
 
78
  def load_sd_pipeline(self):
79
  try:
80
  logger.info("加载 Stable Diffusion Pipeline...")
81
+
82
+ # 尝试加载原始模型
83
+ try:
84
+ self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
85
+ self.model_config["sd_model"],
86
+ torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
87
+ cache_dir="/tmp/models",
88
+ safety_checker=None,
89
+ use_safetensors=True
90
+ )
91
+ except Exception as e:
92
+ logger.warning(f"原始模型加载失败: {e}")
93
+ logger.info("尝试加载本地缓存的模型...")
94
+
95
+ # 定义本地模型路径
96
+ local_model_path = "./local_models/stable-diffusion-v1-5"
97
+
98
+ # 检查本地模型是否存在
99
+ if os.path.exists(local_model_path):
100
+ self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
101
+ local_model_path,
102
+ torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
103
+ safety_checker=None
104
+ )
105
+ logger.info("使用本地缓存的 Stable Diffusion 模型")
106
+ else:
107
+ logger.error("没有可用的本地模型")
108
+ raise
109
+
110
  self.sd_pipeline = self.sd_pipeline.to(self.device)
111
  # 用更高效的调度器
112
  self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
 
164
  def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
165
  if self.sd_pipeline is None:
166
  self.load_sd_pipeline()
167
+ if self.sd_pipeline is None:
168
+ logger.error("无法生成图像:Stable Diffusion 模型未加载")
169
+ # 创建占位图像
170
+ color = (180, 180, 180)
171
+ return Image.new('RGB', (width, height), color=color)
172
 
173
  # Stable Diffusion 生成图像
174
  result = self.sd_pipeline(
 
184
  def generate_controlnet_image(self, image, prompt, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
185
  if self.controlnet_pipeline is None:
186
  self.load_controlnet_pipeline()
187
+ if self.controlnet_pipeline is None:
188
+ logger.error("无法生成图像:ControlNet 模型未加载")
189
+ # 创建占位图像
190
+ return Image.new('RGB', (512, 768), color=(180, 180, 180))
191
 
192
  # 输入的 image 应该是 PIL Image 格式的控制图(比如人体姿态图)
193
  result = self.controlnet_pipeline(
 
202
  def cleanup(self):
203
  logger.info("释放模型占用显存和缓存...")
204
  try:
205
+ if hasattr(self, 'caption_model') and self.caption_model is not None:
206
+ del self.caption_model
207
+ if hasattr(self, 'caption_processor') and self.caption_processor is not None:
208
+ del self.caption_processor
209
+ if hasattr(self, 'clip_model') and self.clip_model is not None:
210
+ del self.clip_model
211
+ if hasattr(self, 'clip_processor') and self.clip_processor is not None:
212
+ del self.clip_processor
213
+ if hasattr(self, 'sd_pipeline') and self.sd_pipeline is not None:
214
+ del self.sd_pipeline
215
+ if hasattr(self, 'controlnet') and self.controlnet is not None:
216
+ del self.controlnet
217
+ if hasattr(self, 'controlnet_pipeline') and self.controlnet_pipeline is not None:
218
+ del self.controlnet_pipeline
219
+
220
  torch.cuda.empty_cache()
221
  import gc
222
  gc.collect()
223
  logger.info("显存清理完成")
224
  except Exception as e:
225
+ logger.error(f"清理显存失败: {e}")