Humphreykowl committed on
Commit
97c17b5
·
verified ·
1 Parent(s): 880ef76

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +12 -374
models/model_manager.py CHANGED
@@ -1,4 +1,4 @@
1
- # models/model_manager.py
2
  import torch
3
  from PIL import Image
4
  from transformers import (
@@ -39,15 +39,15 @@ class ModelManager:
39
  self.controlnet_pipeline = None
40
  self.controlnet = None
41
 
42
- # 模型配置 - 使用较小的模型变体以适应 Space 环境
43
  self.model_config = {
44
- "caption_model": "Salesforce/blip-image-captioning-base", # 基础版节省内存
45
- "clip_model": "openai/clip-vit-base-patch32", # 基础版CLIP
46
- "sd_model": "stabilityai/stable-diffusion-2-1-base", # SD 2.1基础版
47
- "controlnet_model": "lllyasviel/sd-controlnet-openpose" # 姿势控制模型
48
  }
49
 
50
- # 创建缓存目录 - 使用Space的临时目录
51
  self.cache_dir = "/tmp/models"
52
  os.makedirs(self.cache_dir, exist_ok=True)
53
  logger.info(f"模型缓存目录: {self.cache_dir}")
@@ -55,12 +55,15 @@ class ModelManager:
55
  # 加载统计
56
  self.load_times = {}
57
  self.last_used = {}
 
 
 
58
 
59
  def load_caption_model(self):
60
  """加载图像描述模型"""
61
  if self.caption_model is None:
62
  start_time = time.time()
63
- logger.info("正在加载图像描述模型...")
64
 
65
  try:
66
  self.caption_processor = BlipProcessor.from_pretrained(
@@ -71,369 +74,4 @@ class ModelManager:
71
  self.caption_model = BlipForConditionalGeneration.from_pretrained(
72
  self.model_config["caption_model"],
73
  cache_dir=self.cache_dir,
74
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
75
- ).to(self.device)
76
-
77
- # 模型优化
78
- if self.device == "cuda":
79
- self.caption_model = self.caption_model.half()
80
-
81
- logger.info("图像描述模型加载完成")
82
- self.load_times["caption"] = time.time() - start_time
83
- self.last_used["caption"] = time.time()
84
- except Exception as e:
85
- logger.error(f"加载描述模型失败: {str(e)}")
86
- # 尝试回退到更小的模型
87
- self.model_config["caption_model"] = "Salesforce/blip-image-captioning-base"
88
- self.load_caption_model()
89
-
90
- def load_clip_model(self):
91
- """加载CLIP模型用于风格分析"""
92
- if self.clip_model is None:
93
- start_time = time.time()
94
- logger.info("正在加载CLIP模型...")
95
-
96
- try:
97
- self.clip_processor = CLIPProcessor.from_pretrained(
98
- self.model_config["clip_model"],
99
- cache_dir=self.cache_dir
100
- )
101
-
102
- self.clip_model = CLIPModel.from_pretrained(
103
- self.model_config["clip_model"],
104
- cache_dir=self.cache_dir,
105
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
106
- ).to(self.device)
107
-
108
- # 模型优化
109
- if self.device == "cuda":
110
- self.clip_model = self.clip_model.half()
111
-
112
- logger.info("CLIP模型加载完成")
113
- self.load_times["clip"] = time.time() - start_time
114
- self.last_used["clip"] = time.time()
115
- except Exception as e:
116
- logger.error(f"加载CLIP模型失败: {str(e)}")
117
-
118
- def load_sd_pipeline(self):
119
- """加载Stable Diffusion生成管道"""
120
- if self.sd_pipeline is None:
121
- start_time = time.time()
122
- logger.info("正在加载Stable Diffusion模型...")
123
-
124
- # 根据可用内存选择模型变体
125
- if self.device == "cuda" and torch.cuda.get_device_properties(0).total_memory < 10 * 1024**3:
126
- logger.info("检测到有限GPU内存,使用更小的SD模型")
127
- self.model_config["sd_model"] = "runwayml/stable-diffusion-v1-5"
128
-
129
- try:
130
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
131
- self.model_config["sd_model"],
132
- cache_dir=self.cache_dir,
133
- safety_checker=None, # 禁用安全检查以节省内存
134
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
135
- ).to(self.device)
136
-
137
- # 优化性能
138
- if self.device == "cuda":
139
- try:
140
- # 启用内存高效注意力
141
- self.sd_pipeline.enable_xformers_memory_efficient_attention()
142
- except:
143
- logger.warning("无法启用xformers,使用回退方案")
144
-
145
- # 启用注意力切片
146
- self.sd_pipeline.enable_attention_slicing()
147
-
148
- logger.info("Stable Diffusion模型加载完成")
149
- self.load_times["sd"] = time.time() - start_time
150
- self.last_used["sd"] = time.time()
151
- except Exception as e:
152
- logger.error(f"加载SD模型失败: {str(e)}")
153
- # 尝试回退到更小的模型
154
- self.model_config["sd_model"] = "runwayml/stable-diffusion-v1-5"
155
- self.load_sd_pipeline()
156
-
157
- def load_controlnet_pipeline(self):
158
- """加载ControlNet管道用于3D试穿"""
159
- if self.controlnet_pipeline is None:
160
- start_time = time.time()
161
- logger.info("正在加载ControlNet模型...")
162
-
163
- try:
164
- # 先加载ControlNet模型
165
- self.controlnet = ControlNetModel.from_pretrained(
166
- self.model_config["controlnet_model"],
167
- cache_dir=self.cache_dir,
168
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
169
- )
170
-
171
- # 然后创建ControlNet管道
172
- self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
173
- self.model_config["sd_model"],
174
- controlnet=self.controlnet,
175
- cache_dir=self.cache_dir,
176
- safety_checker=None,
177
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
178
- ).to(self.device)
179
-
180
- # 设置调度器
181
- self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
182
- self.controlnet_pipeline.scheduler.config
183
- )
184
-
185
- # 优化性能
186
- if self.device == "cuda":
187
- try:
188
- self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
189
- except:
190
- logger.warning("无法为ControlNet启用xformers")
191
-
192
- self.controlnet_pipeline.enable_attention_slicing()
193
-
194
- logger.info("ControlNet模型加载完成")
195
- self.load_times["controlnet"] = time.time() - start_time
196
- self.last_used["controlnet"] = time.time()
197
- except Exception as e:
198
- logger.error(f"加载ControlNet模型失败: {str(e)}")
199
-
200
- def generate_caption(self, image: Image.Image) -> str:
201
- """为图像生成描述性标题"""
202
- try:
203
- self.load_caption_model()
204
- self.last_used["caption"] = time.time()
205
-
206
- # 准备输入
207
- inputs = self.caption_processor(
208
- images=image,
209
- return_tensors="pt"
210
- ).to(self.device, torch.float16 if self.device == "cuda" else torch.float32)
211
-
212
- # 生成标题
213
- output = self.caption_model.generate(**inputs, max_length=50)
214
- caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
215
-
216
- logger.info(f"生成的标题: {caption}")
217
- return caption
218
-
219
- except Exception as e:
220
- logger.error(f"生成标题失败: {str(e)}")
221
- # 返回默认标题
222
- return "时尚服装设计"
223
-
224
- def analyze_style(self, image: Image.Image) -> Dict[str, float]:
225
- """使用CLIP分析图像风格"""
226
- try:
227
- self.load_clip_model()
228
- self.last_used["clip"] = time.time()
229
-
230
- # 定义风格类别
231
- style_labels = [
232
- "商务正装", "休闲风", "运动风", "时尚潮流",
233
- "复古风", "街头风", "优雅风", "民族风"
234
- ]
235
-
236
- # 准备输入
237
- inputs = self.clip_processor(
238
- text=style_labels,
239
- images=image,
240
- return_tensors="pt",
241
- padding=True
242
- ).to(self.device)
243
-
244
- # 获取预测
245
- outputs = self.clip_model(**inputs)
246
- logits_per_image = outputs.logits_per_image
247
- probs = logits_per_image.softmax(dim=1).detach().cpu().numpy()[0]
248
-
249
- # 获取前3个风格
250
- top3_idx = np.argsort(probs)[-3:][::-1]
251
- top_styles = {
252
- style_labels[i]: float(probs[i]) for i in top3_idx
253
- }
254
-
255
- logger.info(f"风格分析结果: {top_styles}")
256
- return top_styles
257
-
258
- except Exception as e:
259
- logger.error(f"风格分析失败: {str(e)}")
260
- # 返回默认风格
261
- return {"休闲风": 0.8, "时尚潮流": 0.7}
262
-
263
- def generate_image(
264
- self,
265
- prompt: str,
266
- negative_prompt: str = "",
267
- num_inference_steps: int = 30,
268
- guidance_scale: float = 7.5,
269
- height: int = 512,
270
- width: int = 512
271
- ) -> Image.Image:
272
- """根据提示生成设计图像"""
273
- try:
274
- self.load_sd_pipeline()
275
- self.last_used["sd"] = time.time()
276
-
277
- # 生成图像
278
- with torch.autocast("cuda" if self.device == "cuda" else "cpu"):
279
- image = self.sd_pipeline(
280
- prompt=prompt,
281
- negative_prompt=negative_prompt,
282
- num_inference_steps=num_inference_steps,
283
- guidance_scale=guidance_scale,
284
- height=height,
285
- width=width
286
- ).images[0]
287
-
288
- logger.info(f"成功生成设计图像: {prompt[:50]}...")
289
- return image
290
-
291
- except Exception as e:
292
- logger.error(f"生成设计图像失败: {str(e)}")
293
- # 创建占位图像
294
- return Image.new('RGB', (512, 512), color=(220, 220, 220))
295
-
296
- def generate_controlnet_image(
297
- self,
298
- image: Image.Image,
299
- prompt: str,
300
- negative_prompt: str = "",
301
- num_inference_steps: int = 35,
302
- guidance_scale: float = 8.0
303
- ) -> Image.Image:
304
- """使用ControlNet生成3D试穿图像"""
305
- try:
306
- self.load_controlnet_pipeline()
307
- self.last_used["controlnet"] = time.time()
308
-
309
- # 生成图像
310
- with torch.autocast("cuda" if self.device == "cuda" else "cpu"):
311
- image = self.controlnet_pipeline(
312
- prompt=prompt,
313
- image=image,
314
- negative_prompt=negative_prompt,
315
- num_inference_steps=num_inference_steps,
316
- guidance_scale=guidance_scale,
317
- controlnet_conditioning_scale=0.8
318
- ).images[0]
319
-
320
- logger.info(f"成功生成3D试穿图像")
321
- return image
322
-
323
- except Exception as e:
324
- logger.error(f"生成3D试穿图像失败: {str(e)}")
325
- # 回退到普通SD模型
326
- return self.generate_image(
327
- prompt,
328
- negative_prompt,
329
- num_inference_steps
330
- )
331
-
332
- def unload_model(self, model_type: str):
333
- """卸载指定类型的模型以释放内存"""
334
- logger.info(f"卸载模型: {model_type}")
335
-
336
- if model_type == "caption" and self.caption_model is not None:
337
- del self.caption_model
338
- del self.caption_processor
339
- self.caption_model = None
340
- self.caption_processor = None
341
- logger.info("卸载图像描述模型")
342
-
343
- elif model_type == "clip" and self.clip_model is not None:
344
- del self.clip_model
345
- del self.clip_processor
346
- self.clip_model = None
347
- self.clip_processor = None
348
- logger.info("卸载CLIP模型")
349
-
350
- elif model_type == "sd" and self.sd_pipeline is not None:
351
- del self.sd_pipeline
352
- self.sd_pipeline = None
353
- logger.info("卸载Stable Diffusion模型")
354
-
355
- elif model_type == "controlnet" and self.controlnet_pipeline is not None:
356
- del self.controlnet_pipeline
357
- del self.controlnet
358
- self.controlnet_pipeline = None
359
- self.controlnet = None
360
- logger.info("卸载ControlNet模型")
361
-
362
- # 清理内存
363
- self.cleanup_memory()
364
-
365
- def cleanup(self):
366
- """清理所有模型释放内存"""
367
- logger.info("清理所有模型释放内存...")
368
-
369
- # 释放所有模型
370
- if self.caption_model is not None:
371
- del self.caption_model
372
- if self.caption_processor is not None:
373
- del self.caption_processor
374
- if self.clip_model is not None:
375
- del self.clip_model
376
- if self.clip_processor is not None:
377
- del self.clip_processor
378
- if self.sd_pipeline is not None:
379
- del self.sd_pipeline
380
- if self.controlnet_pipeline is not None:
381
- del self.controlnet_pipeline
382
- if self.controlnet is not None:
383
- del self.controlnet
384
-
385
- # 重置引用
386
- self.caption_model = None
387
- self.caption_processor = None
388
- self.clip_model = None
389
- self.clip_processor = None
390
- self.sd_pipeline = None
391
- self.controlnet_pipeline = None
392
- self.controlnet = None
393
-
394
- # 清理内存
395
- self.cleanup_memory()
396
- logger.info("内存清理完成")
397
-
398
- def cleanup_memory(self):
399
- """执行内存清理操作"""
400
- # 清理CUDA缓存
401
- if torch.cuda.is_available():
402
- torch.cuda.empty_cache()
403
-
404
- # 执行垃圾回收
405
- gc.collect()
406
-
407
- def get_memory_usage(self) -> Dict[str, float]:
408
- """获取当前内存使用情况"""
409
- mem_info = {}
410
-
411
- if torch.cuda.is_available():
412
- mem_info["gpu_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
413
- mem_info["gpu_used"] = torch.cuda.memory_allocated() / (1024**3)
414
- mem_info["gpu_free"] = mem_info["gpu_total"] - mem_info["gpu_used"]
415
-
416
- return mem_info
417
-
418
- def get_model_status(self) -> Dict[str, str]:
419
- """获取模型加载状态"""
420
- status = {
421
- "caption_model": "已加载" if self.caption_model else "未加载",
422
- "clip_model": "已加载" if self.clip_model else "未加载",
423
- "sd_model": "已加载" if self.sd_pipeline else "未加载",
424
- "controlnet_model": "已加载" if self.controlnet_pipeline else "未加载"
425
- }
426
-
427
- # 添加加载时间信息
428
- for model in ["caption", "clip", "sd", "controlnet"]:
429
- if model in self.load_times:
430
- status[f"{model}_load_time"] = f"{self.load_times[model]:.2f}秒"
431
- if model in self.last_used:
432
- mins_ago = (time.time() - self.last_used[model]) / 60
433
- status[f"{model}_last_used"] = f"{mins_ago:.1f}分钟前"
434
-
435
- return status
436
-
437
- def __del__(self):
438
- """析构函数确保资源释放"""
439
- self.cleanup()
 
1
+ # models/model_manager.py - 增强版
2
  import torch
3
  from PIL import Image
4
  from transformers import (
 
39
  self.controlnet_pipeline = None
40
  self.controlnet = None
41
 
42
+ # 模型配置 - 针对Spaces环境优化
43
  self.model_config = {
44
+ "caption_model": "Salesforce/blip-image-captioning-base",
45
+ "clip_model": "openai/clip-vit-base-patch32",
46
+ "sd_model": "runwayml/stable-diffusion-v1-5", # 使用更稳定的v1.5
47
+ "controlnet_model": "lllyasviel/sd-controlnet-openpose"
48
  }
49
 
50
+ # 创建缓存目录
51
  self.cache_dir = "/tmp/models"
52
  os.makedirs(self.cache_dir, exist_ok=True)
53
  logger.info(f"模型缓存目录: {self.cache_dir}")
 
55
  # 加载统计
56
  self.load_times = {}
57
  self.last_used = {}
58
+
59
+ # 预热标志
60
+ self.models_warmed = False
61
 
62
  def load_caption_model(self):
63
  """加载图像描述模型"""
64
  if self.caption_model is None:
65
  start_time = time.time()
66
+ logger.info("正在加载BLIP图像描述模型...")
67
 
68
  try:
69
  self.caption_processor = BlipProcessor.from_pretrained(
 
74
  self.caption_model = BlipForConditionalGeneration.from_pretrained(
75
  self.model_config["caption_model"],
76
  cache_dir=self.cache_dir,
77
+ torch_dtype=torch.float16 if self.