Humphreykowl committed on
Commit
2efd7f8
·
verified ·
1 Parent(s): 2bf52c2

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +185 -9
models/model_manager.py CHANGED
@@ -1,10 +1,186 @@
1
- def caption_image(self, image) -> str:
2
- """对图像生成描述"""
3
- if self.caption_model is None or self.caption_processor is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  self.load_caption_model()
5
-
6
- inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
7
- with torch.no_grad():
8
- output = self.caption_model.generate(**inputs, max_length=50)
9
-
10
- return self.caption_processor.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# models/model_manager.py
import logging
import os
import time

import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
10
# Configure root logging once at import time and use the conventional
# module-level logger for everything in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
12
+
13
class ModelManager:
    """Loads and serves the vision/generation models used by this app.

    Wraps four models: BLIP (image captioning), CLIP (image feature
    extraction), Stable Diffusion (text-to-image) and a pose ControlNet
    pipeline.  All models are eagerly loaded in ``__init__``; each public
    entry point also lazily re-loads its model if a previous load failed
    or ``cleanup()`` released it.
    """

    def __init__(self):
        # Prefer GPU; fp16 weights are only used on CUDA (see _torch_dtype).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        # Model configuration (the SD repo id can be swapped for a mirror).
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-base",
            "clip_model": "openai/clip-vit-base-patch32",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/sd-controlnet-openpose"
        }

        # Model containers: None means "not loaded" and triggers the
        # lazy-reload guards in the generate_*/analyze_* methods.
        self.caption_processor = None
        self.caption_model = None

        self.clip_processor = None
        self.clip_model = None

        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Pre-load every model up front.
        self.load_all_models()

    def _torch_dtype(self):
        """Return the weight dtype for the current device (fp16 on CUDA)."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def load_all_models(self):
        """Load all models. Individual failures are logged, not raised."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP captioning processor and model onto ``self.device``."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir="/tmp/models"
            )
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir="/tmp/models",
                torch_dtype=self._torch_dtype()
            ).to(self.device)
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            # Best-effort: leave the containers as None so generate_caption()
            # can retry the load later.
            logger.error(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor and model onto ``self.device``."""
        try:
            logger.info("加载 CLIP 模型...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir="/tmp/models"
            )
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir="/tmp/models",
                torch_dtype=self._torch_dtype()
            ).to(self.device)
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.error(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline."""
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                # NOTE(review): ``revision="fp16"`` is deprecated in newer
                # diffusers releases (use ``variant="fp16"``) — confirm the
                # installed diffusers version before changing.
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=self._torch_dtype(),
                cache_dir="/tmp/models",
                safety_checker=None  # configure a safety checker if required
            )
            self.sd_pipeline = self.sd_pipeline.to(self.device)
            # Swap in a more efficient scheduler.
            self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.sd_pipeline.scheduler.config
            )
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the openpose ControlNet and its SD pipeline."""
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            self.controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir="/tmp/models",
                torch_dtype=self._torch_dtype()
            ).to(self.device)

            self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=self.controlnet,
                cache_dir="/tmp/models",
                torch_dtype=self._torch_dtype(),
                safety_checker=None
            ).to(self.device)

            self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.controlnet_pipeline.scheduler.config
            )
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.error(f"ControlNet Pipeline 加载失败: {e}")

    # ---- Inference entry points -------------------------------------------

    def generate_caption(self, image):
        """Return a short English caption (<= 50 tokens) for a PIL image."""
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()

        inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def analyze_style(self, image):
        """Return ``{"clip_feature_vector": v}`` where ``v`` is the
        L2-normalised CLIP image-feature vector of *image*."""
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()

        inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.clip_model.get_image_features(**inputs)
        features = outputs.cpu().numpy()[0]
        # Simple L2 normalisation; the epsilon guards against a zero vector.
        # (Requires the module-level ``import numpy as np``.)
        norm = features / (np.linalg.norm(features) + 1e-10)
        style_score = {"clip_feature_vector": norm}
        return style_score

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
        """Generate one image from *prompt* with Stable Diffusion.

        Returns the first PIL image of the pipeline result.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()

        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
        """Generate one image conditioned on a control image.

        *image* must be a PIL control image (e.g. an openpose pose map).
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()

        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def cleanup(self):
        """Release model references and free GPU memory.

        Attributes are reset to None (not ``del``-eted) so the lazy-reload
        guards in the inference methods keep working after cleanup; the
        original ``del`` left the attributes missing and made every later
        call raise AttributeError instead of re-loading.
        """
        logger.info("释放模型占用显存和缓存...")
        try:
            self.caption_model = None
            self.caption_processor = None
            self.clip_model = None
            self.clip_processor = None
            self.sd_pipeline = None
            self.controlnet = None
            self.controlnet_pipeline = None
            import gc
            gc.collect()
            # empty_cache() is only meaningful when CUDA is present.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("显存清理完成")
        except Exception as e:
            logger.error(f"清理显存失败: {e}")