Humphreykowl committed on
Commit
9888744
·
verified ·
1 Parent(s): 891265c

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +18 -215
models/model_manager.py CHANGED
@@ -1,195 +1,19 @@
1
- import torch
2
- from PIL import Image
3
- import numpy as np # 添加缺失的导入
4
- from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
5
- from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
6
- import os
7
- import logging
8
- import time
9
-
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
- class ModelManager:
14
- def __init__(self):
15
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
- logger.info(f"使用设备: {self.device}")
17
-
18
- # 模型配置(更新了 SD 模型路径)
19
- self.model_config = {
20
- "caption_model": "Salesforce/blip-image-captioning-base",
21
- "clip_model": "openai/clip-vit-base-patch32",
22
- "sd_model": "runwayml/stable-diffusion-v1-5",
23
- "controlnet_model": "lllyasviel/sd-controlnet-openpose"
24
- }
25
-
26
- # 模型容器
27
- self.caption_processor = None
28
- self.caption_model = None
29
-
30
- self.clip_processor = None
31
- self.clip_model = None
32
-
33
- self.sd_pipeline = None
34
- self.controlnet = None
35
- self.controlnet_pipeline = None
36
-
37
- # 预加载所有模型
38
- self.load_all_models()
39
-
40
- def load_all_models(self):
41
- self.load_caption_model()
42
- self.load_clip_model()
43
- self.load_sd_pipeline()
44
  self.load_controlnet_pipeline()
45
-
46
- def load_caption_model(self):
47
- try:
48
- logger.info("加载 BLIP 图像描述模型...")
49
- self.caption_processor = BlipProcessor.from_pretrained(
50
- self.model_config["caption_model"],
51
- cache_dir="/tmp/models"
52
- )
53
- self.caption_model = BlipForConditionalGeneration.from_pretrained(
54
- self.model_config["caption_model"],
55
- cache_dir="/tmp/models",
56
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32
57
- ).to(self.device)
58
- logger.info("BLIP 模型加载完成")
59
- except Exception as e:
60
- logger.error(f"BLIP 模型加载失败: {e}")
61
-
62
- def load_clip_model(self):
63
- try:
64
- logger.info("加载 CLIP 模型...")
65
- self.clip_processor = CLIPProcessor.from_pretrained(
66
- self.model_config["clip_model"],
67
- cache_dir="/tmp/models"
68
- )
69
- self.clip_model = CLIPModel.from_pretrained(
70
- self.model_config["clip_model"],
71
- cache_dir="/tmp/models",
72
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32
73
- ).to(self.device)
74
- logger.info("CLIP 模型加载完成")
75
- except Exception as e:
76
- logger.error(f"CLIP 模型加载失败: {e}")
77
-
78
- def load_sd_pipeline(self):
79
- try:
80
- logger.info("加载 Stable Diffusion Pipeline...")
81
-
82
- # 尝试加载原始模型
83
- try:
84
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
85
- self.model_config["sd_model"],
86
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
87
- cache_dir="/tmp/models",
88
- safety_checker=None,
89
- use_safetensors=True
90
- )
91
- except Exception as e:
92
- logger.warning(f"原始模型加载失败: {e}")
93
- logger.info("尝试加载本地缓存的模型...")
94
-
95
- # 定义本地模型路径
96
- local_model_path = "./local_models/stable-diffusion-v1-5"
97
-
98
- # 检查本地模型是否存在
99
- if os.path.exists(local_model_path):
100
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
101
- local_model_path,
102
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
103
- safety_checker=None
104
- )
105
- logger.info("使用本地缓存的 Stable Diffusion 模型")
106
- else:
107
- logger.error("没有可用的本地模型")
108
- raise
109
-
110
- self.sd_pipeline = self.sd_pipeline.to(self.device)
111
- # 用更高效的调度器
112
- self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
113
- logger.info("Stable Diffusion Pipeline 加载完成")
114
- except Exception as e:
115
- logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")
116
-
117
- def load_controlnet_pipeline(self):
118
- try:
119
- logger.info("加载 ControlNet 模型和 Pipeline...")
120
- self.controlnet = ControlNetModel.from_pretrained(
121
- self.model_config["controlnet_model"],
122
- cache_dir="/tmp/models",
123
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32
124
- ).to(self.device)
125
-
126
- self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
127
- self.model_config["sd_model"],
128
- controlnet=self.controlnet,
129
- cache_dir="/tmp/models",
130
- torch_dtype=torch.float16 if self.device=="cuda" else torch.float32,
131
- safety_checker=None
132
- ).to(self.device)
133
-
134
- self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.controlnet_pipeline.scheduler.config)
135
- logger.info("ControlNet Pipeline 加载完成")
136
- except Exception as e:
137
- logger.error(f"ControlNet Pipeline 加载失败: {e}")
138
-
139
- # 下面是真正调用模型的接口
140
-
141
- def generate_caption(self, image):
142
- if self.caption_model is None or self.caption_processor is None:
143
- self.load_caption_model()
144
-
145
- inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
146
- with torch.no_grad():
147
- outputs = self.caption_model.generate(**inputs, max_length=50)
148
- caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
149
- return caption
150
-
151
- def analyze_style(self, image):
152
- if self.clip_model is None or self.clip_processor is None:
153
- self.load_clip_model()
154
 
155
- inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
156
- with torch.no_grad():
157
- outputs = self.clip_model.get_image_features(**inputs)
158
- features = outputs.cpu().numpy()[0]
159
- # 简单归一化(范例)
160
- norm = features / (np.linalg.norm(features) + 1e-10)
161
- style_score = { "clip_feature_vector": norm }
162
- return style_score
163
-
164
- def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
165
- if self.sd_pipeline is None:
166
- self.load_sd_pipeline()
167
- if self.sd_pipeline is None:
168
- logger.error("无法生成图像:Stable Diffusion 模型未加载")
169
- # 创建占位图像
170
- color = (180, 180, 180)
171
- return Image.new('RGB', (width, height), color=color)
172
-
173
- # Stable Diffusion 生成图像
174
- result = self.sd_pipeline(
175
- prompt=prompt,
176
- negative_prompt=negative_prompt,
177
- num_inference_steps=num_inference_steps,
178
- guidance_scale=guidance_scale,
179
- height=height,
180
- width=width
181
- )
182
- return result.images[0]
183
-
184
- def generate_controlnet_image(self, image, prompt, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
185
- if self.controlnet_pipeline is None:
186
- self.load_controlnet_pipeline()
187
- if self.controlnet_pipeline is None:
188
- logger.error("无法生成图像:ControlNet 模型未加载")
189
- # 创建占位图像
190
- return Image.new('RGB', (512, 768), color=(180, 180, 180))
191
-
192
- # 输入的 image 应该是 PIL Image 格式的控制图(比如人体姿态图)
193
  result = self.controlnet_pipeline(
194
  prompt=prompt,
195
  image=image,
@@ -198,28 +22,7 @@ class ModelManager:
198
  guidance_scale=guidance_scale,
199
  )
200
  return result.images[0]
201
-
202
- def cleanup(self):
203
- logger.info("释放模型用显存和缓存...")
204
- try:
205
- if hasattr(self, 'caption_model') and self.caption_model is not None:
206
- del self.caption_model
207
- if hasattr(self, 'caption_processor') and self.caption_processor is not None:
208
- del self.caption_processor
209
- if hasattr(self, 'clip_model') and self.clip_model is not None:
210
- del self.clip_model
211
- if hasattr(self, 'clip_processor') and self.clip_processor is not None:
212
- del self.clip_processor
213
- if hasattr(self, 'sd_pipeline') and self.sd_pipeline is not None:
214
- del self.sd_pipeline
215
- if hasattr(self, 'controlnet') and self.controlnet is not None:
216
- del self.controlnet
217
- if hasattr(self, 'controlnet_pipeline') and self.controlnet_pipeline is not None:
218
- del self.controlnet_pipeline
219
-
220
- torch.cuda.empty_cache()
221
- import gc
222
- gc.collect()
223
- logger.info("显存清理完成")
224
- except Exception as e:
225
- logger.error(f"清理显存失败: {e}")
 
1
+ # 在ModelManager类中添加新方法
2
+ def generate_controlnet_image(self, image, prompt, reference_image=None,
3
+ negative_prompt=None, num_inference_steps=30,
4
+ guidance_scale=8.0):
5
+ """使用ControlNet生成图像,支持参考图像"""
6
+ if self.controlnet_pipeline is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  self.load_controlnet_pipeline()
8
+
9
+ try:
10
+ # 如果有参考图像,将其融入提示词
11
+ if reference_image is not None:
12
+ # 简化的参考图像描述(实际应用中可用CLIP生成详细描述)
13
+ ref_desc = "参考设计风格"
14
+ prompt = f"{prompt}, {ref_desc}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # 生成图像
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  result = self.controlnet_pipeline(
18
  prompt=prompt,
19
  image=image,
 
22
  guidance_scale=guidance_scale,
23
  )
24
  return result.images[0]
25
+ except Exception as e:
26
+ logger.error(f"ControlNet生成失败: {e}")
27
+ # 创建占位图像
28
+ return Image.new('RGB', (512, 768), color=(180, 180, 180))