Humphreykowl committed on
Commit
57e54f6
·
verified ·
1 Parent(s): bba66a9

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +193 -16
models/model_manager.py CHANGED
@@ -1,28 +1,205 @@
1
- # 在ModelManager类中添加新方法
2
- def generate_controlnet_image(self, image, prompt, reference_image=None,
3
- negative_prompt=None, num_inference_steps=30,
4
- guidance_scale=8.0):
5
- """使用ControlNet生成图像,支持参考图像"""
6
- if self.controlnet_pipeline is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  self.load_controlnet_pipeline()
8
-
9
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # 如果有参考图像,将其融入提示词
11
  if reference_image is not None:
12
- # 简化的参考图像描述(实际应用中可用CLIP生成详细描述)
13
- ref_desc = "参考设计风格"
14
- prompt = f"{prompt}, {ref_desc}"
15
 
16
  # 生成图像
17
  result = self.controlnet_pipeline(
18
  prompt=prompt,
19
- image=image,
20
  negative_prompt=negative_prompt,
21
  num_inference_steps=num_inference_steps,
22
  guidance_scale=guidance_scale,
23
  )
24
  return result.images[0]
25
- except Exception as e:
26
- logger.error(f"ControlNet生成失败: {e}")
27
- # 创建占位图像
28
- return Image.new('RGB', (512, 768), color=(180, 180, 180))
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
5
+ from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
6
+ import os
7
+ import logging
8
+ import time
9
+
10
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Load and serve the BLIP, CLIP, Stable Diffusion and ControlNet models.

    All four models are loaded eagerly in ``__init__``. A failed load is
    logged and leaves the corresponding attributes as ``None``; the public
    generation/analysis methods retry the load once and, if the model is
    still unavailable, log an error and return a neutral fallback value
    instead of raising.
    """

    # Directory passed as ``cache_dir`` to every ``from_pretrained`` call.
    CACHE_DIR = "/tmp/models"

    # Attributes that hold loaded models; ``cleanup`` resets each to None.
    _MODEL_ATTRS = (
        "caption_processor",
        "caption_model",
        "clip_processor",
        "clip_model",
        "sd_pipeline",
        "controlnet",
        "controlnet_pipeline",
    )

    def __init__(self):
        """Pick the compute device, record model ids and preload all models."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        # Model configuration: Hugging Face repository ids.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
        }

        # Model containers, filled in by the load_* methods.
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Preload every model.
        self.load_all_models()

    def _torch_dtype(self):
        """Return float16 when running on CUDA, float32 otherwise."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def load_all_models(self):
        """Load the caption, CLIP, Stable Diffusion and ControlNet models."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP processor and model; log (do not raise) on failure."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.CACHE_DIR
            )
            model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.CACHE_DIR,
                torch_dtype=self._torch_dtype()
            ).to(self.device)
            # Assign only after both parts loaded, so a failure never leaves
            # a processor without its model.
            self.caption_processor = processor
            self.caption_model = model
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.error(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor and model; log (do not raise) on failure."""
        try:
            logger.info("加载 CLIP 模型...")
            processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.CACHE_DIR
            )
            model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.CACHE_DIR,
                torch_dtype=self._torch_dtype()
            ).to(self.device)
            self.clip_processor = processor
            self.clip_model = model
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.error(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion pipeline with an Euler-ancestral scheduler."""
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self._torch_dtype(),
                cache_dir=self.CACHE_DIR,
                safety_checker=None,
                use_safetensors=True
            )
            pipeline = pipeline.to(self.device)
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
            # Publish the pipeline only once it is fully configured.
            self.sd_pipeline = pipeline
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the ControlNet model and the pipeline that wraps it."""
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            dtype = self._torch_dtype()
            controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.CACHE_DIR,
                torch_dtype=dtype
            ).to(self.device)

            pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=controlnet,
                cache_dir=self.CACHE_DIR,
                torch_dtype=dtype,
                safety_checker=None
            ).to(self.device)

            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
            self.controlnet = controlnet
            self.controlnet_pipeline = pipeline
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.error(f"ControlNet Pipeline 加载失败: {e}")

    # Public inference entry points.

    def generate_caption(self, image):
        """Describe ``image`` with the BLIP captioning model.

        Returns the decoded caption, or an empty string when the BLIP model
        could not be loaded.
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
        if self.caption_model is None or self.caption_processor is None:
            logger.error("无法生成描述:BLIP 模型未加载")
            return ""

        inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
        # The model is float16 on CUDA; the processor emits float32 pixels.
        inputs["pixel_values"] = inputs["pixel_values"].to(self.caption_model.dtype)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def analyze_style(self, image):
        """Score ``image`` against a fixed list of clothing styles with CLIP.

        Returns a dict mapping each style label to its softmax probability,
        or an empty dict when the CLIP model could not be loaded.
        """
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
        if self.clip_model is None or self.clip_processor is None:
            logger.error("无法分析风格:CLIP 模型未加载")
            return {}

        # Candidate clothing-style labels.
        # NOTE(review): the labels are Chinese while the configured CLIP
        # checkpoint is presumably English-centric - confirm scoring quality.
        styles = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]

        inputs = self.clip_processor(
            text=styles,
            images=image,
            return_tensors="pt",
            padding=True
        ).to(self.device)
        # Match the image tensor to the model dtype (float16 on CUDA).
        inputs["pixel_values"] = inputs["pixel_values"].to(self.clip_model.dtype)

        with torch.no_grad():
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            # Softmax in float32 so half-precision logits do not lose accuracy.
            probs = logits_per_image.float().softmax(dim=1).cpu().numpy()[0]

        style_scores = {style: float(prob) for style, prob in zip(styles, probs)}
        return style_scores

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
        """Generate a design image from ``prompt`` with Stable Diffusion.

        Returns the first generated image, or a placeholder image of
        ``width`` x ``height`` when the pipeline could not be loaded.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
        if self.sd_pipeline is None:
            logger.error("无法生成图像:Stable Diffusion 模型未加载")
            return self.create_placeholder_image(width, height)

        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
        """Generate a try-on image with the ControlNet pipeline.

        ``image`` is the control image passed to the pipeline (e.g. a body
        pose). ``reference_image`` is not fed to the model; when it is given,
        a fixed hint is appended to the prompt. Returns the first generated
        image, or a 512x768 placeholder when the pipeline could not be loaded.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
        if self.controlnet_pipeline is None:
            logger.error("无法生成3D试穿:ControlNet 模型未加载")
            return self.create_placeholder_image(512, 768)

        # A reference image only influences the prompt text for now.
        if reference_image is not None:
            prompt = f"{prompt}, based on reference design"

        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,  # control image (e.g. human pose)
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a solid RGB image of the given size in a random mid-tone colour."""
        import random  # local import: the file's import block does not provide it

        color = (
            random.randint(120, 200),
            random.randint(120, 200),
            random.randint(120, 200),
        )
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self):
        """Drop every loaded model and release cached GPU memory.

        All model attributes are reset to ``None`` so the load_* methods can
        repopulate them later. Errors are logged rather than raised.
        """
        import gc  # local import: the file's import block does not provide it

        logger.info("释放模型占用显存和缓存...")
        try:
            for attr in self._MODEL_ATTRS:
                setattr(self, attr, None)
            # Collect the dropped references before asking CUDA to free blocks.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"清理显存失败: {e}")