File size: 1,996 Bytes
cd5aabe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
# clip_utils.py
import logging
import os
from typing import Union, List
import cn_clip.clip as clip
import torch
from PIL import Image
from cn_clip.clip import load_from_name
from config import MODELS_PATH
# Logging: configure the root logger at INFO level on import, then grab
# this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model name; can be overridden through the MODEL_NAME_CN environment variable.
MODEL_NAME_CN = os.getenv('MODEL_NAME_CN', 'ViT-B-16')

# Compute device: CUDA when torch reports it is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model and preprocessing transform; both stay None until init_clip_model()
# assigns them.
model = None
preprocess = None
def init_clip_model():
    """Initialise the CLIP model and its preprocessing transform.

    Calls ``load_from_name(MODEL_NAME_CN, device=device,
    download_root=MODELS_PATH)`` and switches the model to eval mode. The
    result is published to the module-level ``model`` / ``preprocess``
    globals only after every step has succeeded. A failed initialisation
    therefore never leaves a half-set model that ``is_clip_available()``
    would report as usable.

    Returns:
        True if the model was loaded and published, False if any step raised
        (the error is logged together with its traceback).
    """
    global model, preprocess
    try:
        loaded_model, loaded_preprocess = load_from_name(
            MODEL_NAME_CN, device=device, download_root=MODELS_PATH
        )
        loaded_model.eval()
        output_dim = loaded_model.visual.output_dim
    except Exception as e:
        # Top-level boundary: report failure via the return value, but keep
        # the traceback in the log instead of only the message.
        logger.exception("CLIP model initialization failed: %s", e)
        return False
    # Publish both globals together, only once initialisation fully succeeded.
    model, preprocess = loaded_model, loaded_preprocess
    logger.info("CLIP model initialized successfully, dimension: %s", output_dim)
    return True
def is_clip_available():
    """Return True when both the module-level model and preprocess are set.

    Either one still being None means the model is not ready for encoding.
    """
    return not (model is None or preprocess is None)
def encode_image(image_path: Union[str, os.PathLike]) -> torch.Tensor:
    """Encode an image file into an L2-normalised CLIP embedding.

    The image is converted to RGB, passed through ``preprocess``, batched
    (``unsqueeze(0)``), moved to ``device`` and encoded without gradient
    tracking.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The embedding on the CPU, divided by its L2 norm along the last
        dimension. Presumably shaped (1, D); this depends on what
        ``model.encode_image`` returns, so confirm against the model.

    Raises:
        RuntimeError: If the CLIP model has not been initialised.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")
    # Image.open is lazy and keeps the file handle open; close it as soon as
    # the pixel data has been converted (convert() loads the data eagerly).
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    image_tensor = preprocess(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(image_tensor)
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu()
def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode one string or a list of strings into L2-normalised CLIP embeddings.

    Args:
        text: A single string, or a list of strings encoded as one batch.

    Returns:
        The text embeddings on the CPU, each divided by its L2 norm along the
        last dimension. A single string is treated as a batch of one.

    Raises:
        RuntimeError: If the CLIP model has not been initialised.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")
    if isinstance(text, str):
        batch = [text]
    else:
        batch = text
    tokens = clip.tokenize(batch)
    tokens = tokens.to(device)
    with torch.no_grad():
        embeddings = model.encode_text(tokens)
        lengths = embeddings.norm(p=2, dim=-1, keepdim=True)
        embeddings = embeddings / lengths
    return embeddings.cpu()
|