File size: 1,996 Bytes
cd5aabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# clip_utils.py
import logging
import os
from typing import Union, List

import cn_clip.clip as clip
import torch
from PIL import Image
from cn_clip.clip import load_from_name

from config import MODELS_PATH

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model name, overridable via environment variable (defaults to Chinese-CLIP ViT-B-16)
MODEL_NAME_CN = os.environ.get('MODEL_NAME_CN', 'ViT-B-16')

# Device selection: prefer CUDA GPU when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily-initialized module-level model state; populated by init_clip_model()
model = None
preprocess = None

def init_clip_model() -> bool:
    """Initialize the module-level CLIP model and preprocessing pipeline.

    Loads the model named by ``MODEL_NAME_CN`` onto the configured ``device``
    (downloading into ``MODELS_PATH`` if needed) and switches it to eval mode.

    Returns:
        True if the model loaded successfully, False otherwise.
    """
    global model, preprocess
    try:
        model, preprocess = load_from_name(MODEL_NAME_CN, device=device, download_root=MODELS_PATH)
        model.eval()
        # Lazy %-style args avoid formatting cost when INFO logging is disabled.
        logger.info("CLIP model initialized successfully, dimension: %s", model.visual.output_dim)
        return True
    except Exception:
        # logger.exception records the full traceback; the original
        # logger.error(f"...{e}") discarded it, hiding the failure cause.
        logger.exception("CLIP model initialization failed")
        return False

def is_clip_available() -> bool:
    """Return True once init_clip_model() has populated both globals."""
    return not (model is None or preprocess is None)

def encode_image(image_path: str) -> torch.Tensor:
    """Encode the image at *image_path* into an L2-normalized feature vector.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A CPU tensor of shape (1, feature_dim) with unit L2 norm.

    Raises:
        RuntimeError: If the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    img = Image.open(image_path).convert("RGB")
    batch = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        raw = model.encode_image(batch)
        normalized = raw / raw.norm(p=2, dim=-1, keepdim=True)
    return normalized.cpu()

def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode text (one string or a list of strings) into L2-normalized vectors.

    Args:
        text: A single string or a list of strings to encode.

    Returns:
        A CPU tensor of shape (batch, feature_dim) with unit L2 norm rows.

    Raises:
        RuntimeError: If the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    if isinstance(text, str):
        batch = [text]
    else:
        batch = text
    tokens = clip.tokenize(batch).to(device)
    with torch.no_grad():
        raw = model.encode_text(tokens)
        normalized = raw / raw.norm(p=2, dim=-1, keepdim=True)
    return normalized.cpu()