Humphreykowl committed on
Commit
5946c21
·
verified ·
1 Parent(s): 97c17b5

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +32 -74
models/model_manager.py CHANGED
@@ -1,77 +1,35 @@
1
- # models/model_manager.py - 增强版
2
- import torch
3
- from PIL import Image
4
- from transformers import (
5
- BlipProcessor,
6
- BlipForConditionalGeneration,
7
- CLIPProcessor,
8
- CLIPModel
9
- )
10
- from diffusers import (
11
- StableDiffusionPipeline,
12
- StableDiffusionControlNetPipeline,
13
- ControlNetModel,
14
- EulerAncestralDiscreteScheduler
15
- )
16
- import numpy as np
17
- import gc
18
- import os
19
- import logging
20
- import time
21
- from typing import Optional, Dict, List, Tuple
 
 
22
 
23
# Configure root logging once at import time and grab a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
26
 
27
- class ModelManager:
28
- def __init__(self):
29
- # 自动检测设备
30
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
- logger.info(f"使用设备: {self.device}")
32
-
33
- # 初始化模型为空
34
- self.caption_model = None
35
- self.caption_processor = None
36
- self.clip_model = None
37
- self.clip_processor = None
38
- self.sd_pipeline = None
39
- self.controlnet_pipeline = None
40
- self.controlnet = None
41
-
42
- # 模型配置 - 针对Spaces环境优化
43
- self.model_config = {
44
- "caption_model": "Salesforce/blip-image-captioning-base",
45
- "clip_model": "openai/clip-vit-base-patch32",
46
- "sd_model": "runwayml/stable-diffusion-v1-5", # 使用更稳定的v1.5
47
- "controlnet_model": "lllyasviel/sd-controlnet-openpose"
48
- }
49
-
50
- # 创建缓存目录
51
- self.cache_dir = "/tmp/models"
52
- os.makedirs(self.cache_dir, exist_ok=True)
53
- logger.info(f"模型缓存目录: {self.cache_dir}")
54
-
55
- # 加载统计
56
- self.load_times = {}
57
- self.last_used = {}
58
-
59
- # 预热标志
60
- self.models_warmed = False
61
 
62
- def load_caption_model(self):
63
- """加载图像描述模型"""
64
- if self.caption_model is None:
65
- start_time = time.time()
66
- logger.info("正在加载BLIP图像描述模型...")
67
-
68
- try:
69
- self.caption_processor = BlipProcessor.from_pretrained(
70
- self.model_config["caption_model"],
71
- cache_dir=self.cache_dir
72
- )
73
-
74
- self.caption_model = BlipForConditionalGeneration.from_pretrained(
75
- self.model_config["caption_model"],
76
- cache_dir=self.cache_dir,
77
- torch_dtype=torch.float16 if self.
 
1
+ def load_caption_model(self):
2
+ """加载图像描述模型"""
3
+ if self.caption_model is None:
4
+ start_time = time.time()
5
+ logger.info("正在加载BLIP图像描述模型...")
6
+
7
+ try:
8
+ self.caption_processor = BlipProcessor.from_pretrained(
9
+ self.model_config["caption_model"],
10
+ cache_dir=self.cache_dir
11
+ )
12
+
13
+ self.caption_model = BlipForConditionalGeneration.from_pretrained(
14
+ self.model_config["caption_model"],
15
+ cache_dir=self.cache_dir,
16
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
17
+ ).to(self.device)
18
+
19
+ self.load_times["caption_model"] = time.time() - start_time
20
+ self.last_used["caption_model"] = time.time()
21
+ logger.info(f"BLIP图像描述模型加载完成,用时 {self.load_times['caption_model']:.2f} 秒")
22
+ except Exception as e:
23
+ logger.error(f"加载BLIP图像描述模型失败: {e}")
24
 
 
 
 
25
 
26
+ def caption_image(self, image: Image.Image) -> str:
27
+ """对图像生成描述"""
28
+ if self.caption_model is None or self.caption_processor is None:
29
+ self.load_caption_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
32
+ with torch.no_grad():
33
+ output = self.caption_model.generate(**inputs, max_length=50)
34
+
35
+ return self.caption_processor.decode(output[0], skip_special_tokens=True)