Humphreykowl commited on
Commit
422bb60
·
verified ·
1 Parent(s): 2c2d2b6

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +213 -154
models/model_manager.py CHANGED
@@ -6,7 +6,9 @@ from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionC
6
  import os
7
  import logging
8
  import time
9
- import random # 补充导入
 
 
10
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
@@ -16,10 +18,10 @@ class ModelManager:
16
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
  logger.info(f"使用设备: {self.device}")
18
 
19
- # 模型配置 - 使用更精细的3D模型
20
  self.model_config = {
21
  "caption_model": "Salesforce/blip-image-captioning-large",
22
- "clip_model": "openai/clip-vit-large-patch14",
23
  "sd_model": "runwayml/stable-diffusion-v1-5",
24
  "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
25
  }
@@ -32,244 +34,301 @@ class ModelManager:
32
  self.sd_pipeline = None
33
  self.controlnet = None
34
  self.controlnet_pipeline = None
35
-
 
 
 
 
 
36
  # 预加载所有模型
37
  self.load_all_models()
38
 
 
 
 
 
 
 
 
 
39
  def load_all_models(self):
40
- self.load_caption_model()
41
- self.load_clip_model()
42
- self.load_sd_pipeline()
43
- self.load_controlnet_pipeline()
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def load_caption_model(self):
 
46
  try:
47
  logger.info("加载 BLIP 图像描述模型...")
 
48
  self.caption_processor = BlipProcessor.from_pretrained(
49
  self.model_config["caption_model"],
50
  cache_dir="/tmp/models"
51
  )
 
52
  self.caption_model = BlipForConditionalGeneration.from_pretrained(
53
  self.model_config["caption_model"],
54
  cache_dir="/tmp/models",
55
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32
 
56
  ).to(self.device)
 
 
 
 
 
 
57
  logger.info("BLIP 模型加载完成")
 
58
  except Exception as e:
59
  logger.error(f"BLIP 模型加载失败: {e}")
 
 
60
 
61
  def load_clip_model(self):
 
62
  try:
63
  logger.info("加载 CLIP 模型...")
 
64
  self.clip_processor = CLIPProcessor.from_pretrained(
65
  self.model_config["clip_model"],
66
  cache_dir="/tmp/models"
67
  )
 
68
  self.clip_model = CLIPModel.from_pretrained(
69
  self.model_config["clip_model"],
70
  cache_dir="/tmp/models",
71
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32
72
  ).to(self.device)
 
 
73
  logger.info("CLIP 模型加载完成")
 
74
  except Exception as e:
75
  logger.error(f"CLIP 模型加载失败: {e}")
 
 
76
 
77
  def load_sd_pipeline(self):
 
78
  try:
79
  logger.info("加载 Stable Diffusion Pipeline...")
 
80
  self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
81
  self.model_config["sd_model"],
82
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
83
  cache_dir="/tmp/models",
84
  safety_checker=None,
85
- use_safetensors=True
 
 
86
  )
 
 
87
  self.sd_pipeline = self.sd_pipeline.to(self.device)
88
- self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  logger.info("Stable Diffusion Pipeline 加载完成")
 
90
  except Exception as e:
91
  logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")
 
92
 
93
  def load_controlnet_pipeline(self):
 
94
  try:
95
  logger.info("加载 ControlNet 模型和 Pipeline...")
 
96
  self.controlnet = ControlNetModel.from_pretrained(
97
  self.model_config["controlnet_model"],
98
  cache_dir="/tmp/models",
99
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32
 
100
  ).to(self.device)
101
 
102
  self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
103
  self.model_config["sd_model"],
104
  controlnet=self.controlnet,
105
  cache_dir="/tmp/models",
106
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
107
- safety_checker=None
 
 
108
  ).to(self.device)
109
 
110
- self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.controlnet_pipeline.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  logger.info("ControlNet Pipeline 加载完成")
 
112
  except Exception as e:
113
  logger.error(f"ControlNet Pipeline 加载失败: {e}")
 
 
114
 
115
- # 下面是真正调用模型的接口
116
-
117
  def generate_caption(self, image):
118
  """使用BLIP模型生成图像描述"""
119
  if self.caption_model is None or self.caption_processor is None:
120
  self.load_caption_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
123
- with torch.no_grad():
124
- outputs = self.caption_model.generate(**inputs, max_length=50)
125
- caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
126
- return caption
127
-
128
  def analyze_style(self, image):
129
  """使用CLIP模型分析服装风格"""
130
  if self.clip_model is None or self.clip_processor is None:
131
  self.load_clip_model()
 
 
132
 
133
- styles = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
134
-
135
- inputs = self.clip_processor(
136
- text=styles,
137
- images=image,
138
- return_tensors="pt",
139
- padding=True
140
- ).to(self.device)
141
-
142
- with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  outputs = self.clip_model(**inputs)
144
  logits_per_image = outputs.logits_per_image
145
  probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]
146
-
147
- style_scores = {style: float(prob) for style, prob in zip(styles, probs)}
148
- return style_scores
149
-
150
- def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
151
- """使用Stable Diffusion生成设计图像"""
152
- if self.sd_pipeline is None:
153
- self.load_sd_pipeline()
154
- if self.sd_pipeline is None:
155
- logger.error("无法生成图像:Stable Diffusion 模型未加载")
156
- return self.create_placeholder_image(width, height)
157
-
158
- result = self.sd_pipeline(
159
- prompt=prompt,
160
- negative_prompt=negative_prompt,
161
- num_inference_steps=num_inference_steps,
162
- guidance_scale=guidance_scale,
163
- height=height,
164
- width=width
165
- )
166
- return result.images[0]
167
-
168
- def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
169
- """使用ControlNet生成3D试穿效果 - 更精细的模型"""
170
- if self.controlnet_pipeline is None:
171
- self.load_controlnet_pipeline()
172
- if self.controlnet_pipeline is None:
173
- logger.error("无法生成3D试穿:ControlNet 模型未加载")
174
- return self.create_placeholder_image(512, 768)
175
-
176
- if reference_image is not None:
177
- prompt = f"{prompt}, based on reference design"
178
-
179
- result = self.controlnet_pipeline(
180
- prompt=prompt,
181
- image=image,
182
- negative_prompt=negative_prompt,
183
- num_inference_steps=num_inference_steps,
184
- guidance_scale=guidance_scale,
185
- )
186
- return result.images[0]
187
-
188
- def create_placeholder_image(self, width, height):
189
- """创建占位图像"""
190
- color = (random.randint(120, 200), random.randint(120, 200), random.randint(120, 200))
191
- return Image.new('RGB', (width, height), color=color)
192
-
193
- def cleanup(self):
194
- """仅清理显存缓存,保留模型以避免重新加载"""
195
- logger.info("清理显存缓存...")
196
- try:
197
- # 只清理缓存,不删除模型
198
- if torch.cuda.is_available():
199
- torch.cuda.empty_cache()
200
 
201
- logger.info("显存缓存清理完成")
202
-
203
- except Exception as e:
204
- logger.error(f"清理显存失败: {e}")
205
-
206
- def move_models_to_cpu(self):
207
- """将模型移动到CPU以释放显存"""
208
- try:
209
- logger.info("将模型移动到CPU...")
210
- if self.caption_model is not None:
211
- self.caption_model = self.caption_model.to('cpu')
212
- if self.clip_model is not None:
213
- self.clip_model = self.clip_model.to('cpu')
214
- if self.sd_pipeline is not None:
215
- self.sd_pipeline = self.sd_pipeline.to('cpu')
216
- if self.controlnet_pipeline is not None:
217
- self.controlnet_pipeline = self.controlnet_pipeline.to('cpu')
218
 
 
 
219
  if torch.cuda.is_available():
220
  torch.cuda.empty_cache()
221
 
222
- logger.info("模型已移动到CPU")
 
223
  except Exception as e:
224
- logger.error(f"移动模型到CPU失败: {e}")
 
225
 
226
- def move_models_to_gpu(self):
227
- """将模型移回GPU"""
228
- try:
229
- logger.info("将模型移动到GPU...")
230
- if self.caption_model is not None:
231
- self.caption_model = self.caption_model.to(self.device)
232
- if self.clip_model is not None:
233
- self.clip_model = self.clip_model.to(self.device)
234
- if self.sd_pipeline is not None:
235
- self.sd_pipeline = self.sd_pipeline.to(self.device)
236
- if self.controlnet_pipeline is not None:
237
- self.controlnet_pipeline = self.controlnet_pipeline.to(self.device)
238
-
239
- logger.info("模型已移回GPU")
240
- except Exception as e:
241
- logger.error(f"移动模型到GPU失败: {e}")
242
 
243
- def force_reload_all_models(self):
244
- """强制重新加载所有模型"""
245
- logger.info("强制重新加载所有模型...")
246
  try:
247
- # 先清理
248
- if hasattr(self, 'caption_model') and self.caption_model is not None:
249
- del self.caption_model
250
- del self.caption_processor
251
- self.caption_model = None
252
- self.caption_processor = None
253
- if hasattr(self, 'clip_model') and self.clip_model is not None:
254
- del self.clip_model
255
- del self.clip_processor
256
- self.clip_model = None
257
- self.clip_processor = None
258
- if hasattr(self, 'sd_pipeline') and self.sd_pipeline is not None:
259
- del self.sd_pipeline
260
- self.sd_pipeline = None
261
- if hasattr(self, 'controlnet_pipeline') and self.controlnet_pipeline is not None:
262
- del self.controlnet
263
- del self.controlnet_pipeline
264
- self.controlnet = None
265
- self.controlnet_pipeline = None
266
 
267
- if torch.cuda.is_available():
268
- torch.cuda.empty_cache()
 
269
 
270
- # 重新加载
271
- self.load_all_models()
272
- logger.info("所有模型重新加载完成")
273
-
274
- except Exception as e:
275
- logger.error(f"强制重新加载模型失败: {e}")
 
6
  import os
7
  import logging
8
  import time
9
+ import random
10
+ import gc
11
+ from functools import lru_cache
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
 
18
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
  logger.info(f"使用设备: {self.device}")
20
 
21
+ # 优化的模型配置
22
  self.model_config = {
23
  "caption_model": "Salesforce/blip-image-captioning-large",
24
+ "clip_model": "openai/clip-vit-large-patch14",
25
  "sd_model": "runwayml/stable-diffusion-v1-5",
26
  "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
27
  }
 
34
  self.sd_pipeline = None
35
  self.controlnet = None
36
  self.controlnet_pipeline = None
37
+
38
+ # 性能优化设置
39
+ self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
40
+ self.enable_attention_slicing = True
41
+ self.enable_cpu_offload = False # 16GB显存应该够用
42
+
43
  # 预加载所有模型
44
  self.load_all_models()
45
 
46
+ def optimize_memory_usage(self):
47
+ """内存优化设置"""
48
+ if torch.cuda.is_available():
49
+ # 启用内存优化
50
+ torch.backends.cudnn.benchmark = True
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+
54
  def load_all_models(self):
55
+ """按顺序加载所有模型,优化显存使用"""
56
+ self.optimize_memory_usage()
57
+
58
+ try:
59
+ self.load_caption_model()
60
+ self.load_clip_model()
61
+ self.load_sd_pipeline()
62
+ self.load_controlnet_pipeline()
63
+
64
+ logger.info("所有模型加载完成")
65
+ if torch.cuda.is_available():
66
+ logger.info(f"GPU显存使用: {torch.cuda.memory_allocated()/1024**3:.2f}GB / {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
67
+
68
+ except Exception as e:
69
+ logger.error(f"模型加载过程中出错: {e}")
70
+ raise
71
 
72
  def load_caption_model(self):
73
+ """加载BLIP图像描述模型"""
74
  try:
75
  logger.info("加载 BLIP 图像描述模型...")
76
+
77
  self.caption_processor = BlipProcessor.from_pretrained(
78
  self.model_config["caption_model"],
79
  cache_dir="/tmp/models"
80
  )
81
+
82
  self.caption_model = BlipForConditionalGeneration.from_pretrained(
83
  self.model_config["caption_model"],
84
  cache_dir="/tmp/models",
85
+ torch_dtype=self.torch_dtype,
86
+ low_cpu_mem_usage=True
87
  ).to(self.device)
88
+
89
+ # 启用内存优化
90
+ if hasattr(self.caption_model, 'enable_attention_slicing'):
91
+ self.caption_model.enable_attention_slicing()
92
+
93
+ self.caption_model.eval()
94
  logger.info("BLIP 模型加载完成")
95
+
96
  except Exception as e:
97
  logger.error(f"BLIP 模型加载失败: {e}")
98
+ self.caption_model = None
99
+ self.caption_processor = None
100
 
101
  def load_clip_model(self):
102
+ """加载CLIP风格分析模型"""
103
  try:
104
  logger.info("加载 CLIP 模型...")
105
+
106
  self.clip_processor = CLIPProcessor.from_pretrained(
107
  self.model_config["clip_model"],
108
  cache_dir="/tmp/models"
109
  )
110
+
111
  self.clip_model = CLIPModel.from_pretrained(
112
  self.model_config["clip_model"],
113
  cache_dir="/tmp/models",
114
+ torch_dtype=self.torch_dtype
115
  ).to(self.device)
116
+
117
+ self.clip_model.eval()
118
  logger.info("CLIP 模型加载完成")
119
+
120
  except Exception as e:
121
  logger.error(f"CLIP 模型加载失败: {e}")
122
+ self.clip_model = None
123
+ self.clip_processor = None
124
 
125
  def load_sd_pipeline(self):
126
+ """加载Stable Diffusion Pipeline"""
127
  try:
128
  logger.info("加载 Stable Diffusion Pipeline...")
129
+
130
  self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
131
  self.model_config["sd_model"],
132
+ torch_dtype=self.torch_dtype,
133
  cache_dir="/tmp/models",
134
  safety_checker=None,
135
+ requires_safety_checker=False,
136
+ use_safetensors=True,
137
+ low_cpu_mem_usage=True
138
  )
139
+
140
+ # 优化设置
141
  self.sd_pipeline = self.sd_pipeline.to(self.device)
142
+ self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
143
+ self.sd_pipeline.scheduler.config
144
+ )
145
+
146
+ # 启用内存优化
147
+ if self.enable_attention_slicing:
148
+ self.sd_pipeline.enable_attention_slicing()
149
+
150
+ # 启用内存高效attention(如果可用)
151
+ try:
152
+ self.sd_pipeline.enable_xformers_memory_efficient_attention()
153
+ logger.info("启用了xformers内存优化")
154
+ except:
155
+ logger.info("xformers不可用,使用默认attention")
156
+
157
+ # 启用VAE slicing以节省显存
158
+ self.sd_pipeline.enable_vae_slicing()
159
+
160
  logger.info("Stable Diffusion Pipeline 加载完成")
161
+
162
  except Exception as e:
163
  logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")
164
+ self.sd_pipeline = None
165
 
166
  def load_controlnet_pipeline(self):
167
+ """加载ControlNet Pipeline"""
168
  try:
169
  logger.info("加载 ControlNet 模型和 Pipeline...")
170
+
171
  self.controlnet = ControlNetModel.from_pretrained(
172
  self.model_config["controlnet_model"],
173
  cache_dir="/tmp/models",
174
+ torch_dtype=self.torch_dtype,
175
+ low_cpu_mem_usage=True
176
  ).to(self.device)
177
 
178
  self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
179
  self.model_config["sd_model"],
180
  controlnet=self.controlnet,
181
  cache_dir="/tmp/models",
182
+ torch_dtype=self.torch_dtype,
183
+ safety_checker=None,
184
+ requires_safety_checker=False,
185
+ low_cpu_mem_usage=True
186
  ).to(self.device)
187
 
188
+ # 优化设置
189
+ self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
190
+ self.controlnet_pipeline.scheduler.config
191
+ )
192
+
193
+ # 内存优化
194
+ if self.enable_attention_slicing:
195
+ self.controlnet_pipeline.enable_attention_slicing()
196
+
197
+ try:
198
+ self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
199
+ logger.info("ControlNet启用了xformers内存优化")
200
+ except:
201
+ logger.info("ControlNet使用默认attention")
202
+
203
+ self.controlnet_pipeline.enable_vae_slicing()
204
+
205
  logger.info("ControlNet Pipeline 加载完成")
206
+
207
  except Exception as e:
208
  logger.error(f"ControlNet Pipeline 加载失败: {e}")
209
+ self.controlnet = None
210
+ self.controlnet_pipeline = None
211
 
212
+ @torch.no_grad() # 禁用梯度计算节省显存
 
213
  def generate_caption(self, image):
214
  """使用BLIP模型生成图像描述"""
215
  if self.caption_model is None or self.caption_processor is None:
216
  self.load_caption_model()
217
+ if self.caption_model is None:
218
+ return "时尚服装设计作品"
219
+
220
+ try:
221
+ # 预处理图像
222
+ if image.mode != 'RGB':
223
+ image = image.convert('RGB')
224
+
225
+ # 调整图像大小以节省显存
226
+ if image.width > 512 or image.height > 512:
227
+ image.thumbnail((512, 512), Image.Resampling.LANCZOS)
228
+
229
+ inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
230
+
231
+ # 生成描述
232
+ outputs = self.caption_model.generate(
233
+ **inputs,
234
+ max_length=50,
235
+ num_beams=4,
236
+ temperature=0.7,
237
+ do_sample=True
238
+ )
239
+
240
+ caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
241
+
242
+ # 清理显存
243
+ del inputs, outputs
244
+ if torch.cuda.is_available():
245
+ torch.cuda.empty_cache()
246
+
247
+ return caption
248
+
249
+ except Exception as e:
250
+ logger.error(f"图像描述生成失败: {e}")
251
+ return "时尚服装设计作品"
252
 
253
+ @torch.no_grad()
 
 
 
 
 
254
  def analyze_style(self, image):
255
  """使用CLIP模型分析服装风格"""
256
  if self.clip_model is None or self.clip_processor is None:
257
  self.load_clip_model()
258
+ if self.clip_model is None:
259
+ return {"时尚潮流": 0.8, "现代风格": 0.6}
260
 
261
+ try:
262
+ # 风格标签 - 使用英文避免token问题
263
+ style_labels = [
264
+ "business formal suit professional attire",
265
+ "casual comfortable everyday wear",
266
+ "athletic sportswear activewear",
267
+ "fashion trendy modern stylish",
268
+ "vintage retro classic style",
269
+ "streetwear urban contemporary",
270
+ "elegant sophisticated refined"
271
+ ]
272
+
273
+ style_names = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
274
+
275
+ # 预处理图像
276
+ if image.mode != 'RGB':
277
+ image = image.convert('RGB')
278
+
279
+ # 调整图像大小
280
+ if image.width > 224 or image.height > 224:
281
+ image.thumbnail((224, 224), Image.Resampling.LANCZOS)
282
+
283
+ # 处理输入
284
+ inputs = self.clip_processor(
285
+ text=style_labels,
286
+ images=image,
287
+ return_tensors="pt",
288
+ padding=True,
289
+ truncation=True,
290
+ max_length=77 # CLIP的最大长度
291
+ ).to(self.device)
292
+
293
+ # 获取相似度分数
294
  outputs = self.clip_model(**inputs)
295
  logits_per_image = outputs.logits_per_image
296
  probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
+ # 构建结果
299
+ style_scores = {name: float(prob) for name, prob in zip(style_names, probs)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
+ # 清理显存
302
+ del inputs, outputs
303
  if torch.cuda.is_available():
304
  torch.cuda.empty_cache()
305
 
306
+ return style_scores
307
+
308
  except Exception as e:
309
+ logger.error(f"风格分析失败: {e}")
310
+ return {"时尚潮流": 0.8, "现代风格": 0.6}
311
 
312
+ @torch.no_grad()
313
+ def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, **kwargs):
314
+ """使用Stable Diffusion生成设计图像"""
315
+ if self.sd_pipeline is None:
316
+ self.load_sd_pipeline()
317
+ if self.sd_pipeline is None:
318
+ logger.error("无法生成图像:Stable Diffusion 模型未加载")
319
+ return self.create_placeholder_image(width, height)
 
 
 
 
 
 
 
 
320
 
 
 
 
321
  try:
322
+ # 优化参数
323
+ if negative_prompt is None:
324
+ negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
+ # 确保尺寸是8的倍数
327
+ width = (width // 8) * 8
328
+ height = (height // 8) * 8
329
 
330
+ # 生成图像
331
+ result = self.sd_pipeline(
332
+ prompt=prompt,
333
+ negative_prompt=negative_prompt,
334
+ num_inference_steps=num_inference_steps,