|
|
|
|
|
import logging |
|
|
import os |
|
|
from typing import Union, List |
|
|
|
|
|
import cn_clip.clip as clip |
|
|
import torch |
|
|
from PIL import Image |
|
|
from cn_clip.clip import load_from_name |
|
|
|
|
|
from config import MODELS_PATH |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so messages surface by default.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)


# Chinese-CLIP model variant to load; overridable via the MODEL_NAME_CN env var.
MODEL_NAME_CN = os.environ.get('MODEL_NAME_CN', 'ViT-B-16')


# Prefer GPU when available; model weights and tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Lazily-populated globals: set by init_clip_model(), read by the encoders below.
model = None       # loaded CLIP model (None until init_clip_model() succeeds)
preprocess = None  # image preprocessing transform paired with the model
|
|
|
|
|
def init_clip_model() -> bool:
    """Initialize the global CLIP model and preprocessing transform.

    Loads the Chinese-CLIP model named by ``MODEL_NAME_CN`` onto ``device``
    (downloading into ``MODELS_PATH`` if needed) and switches it to eval mode.

    Returns:
        bool: True on success, False if loading failed (the error is logged).
    """
    global model, preprocess
    try:
        model, preprocess = load_from_name(
            MODEL_NAME_CN, device=device, download_root=MODELS_PATH
        )
        # Inference-only usage: disable dropout / batch-norm training behavior.
        model.eval()
        # Lazy %-style args avoid formatting when the level is filtered out.
        logger.info(
            "CLIP model initialized successfully, dimension: %s",
            model.visual.output_dim,
        )
        return True
    except Exception as e:
        # Broad catch is deliberate: callers probe success via the return
        # value / is_clip_available() rather than handling exceptions.
        # logger.exception preserves the traceback, unlike logger.error.
        logger.exception("CLIP model initialization failed: %s", e)
        return False
|
|
|
|
|
def is_clip_available():
    """Report whether the CLIP model is ready for use.

    Returns True only when both the model and its preprocessing transform
    have been loaded by init_clip_model().
    """
    ready = (model is not None) and (preprocess is not None)
    return ready
|
|
|
|
|
def encode_image(image_path: str) -> torch.Tensor:
    """Encode an image file into an L2-normalized CLIP feature vector.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        torch.Tensor: Normalized feature tensor (one row per image), on CPU.

    Raises:
        RuntimeError: If the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    # Image.open is lazy and keeps the file handle open; the context manager
    # guarantees it is closed promptly instead of leaking until GC.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(image_tensor)
        # L2-normalize so downstream similarity is a plain dot product.
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu()
|
|
|
|
|
def encode_text(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode one or more text strings into L2-normalized CLIP features.

    Args:
        text: A single string or a list of strings.

    Returns:
        torch.Tensor: Normalized feature tensor on CPU, one row per input.

    Raises:
        RuntimeError: If the CLIP model has not been initialized.
    """
    if not is_clip_available():
        raise RuntimeError("CLIP模型未初始化")

    # Normalize the input to a batch (list) before tokenizing.
    if isinstance(text, str):
        batch = [text]
    else:
        batch = text
    tokens = clip.tokenize(batch).to(device)
    with torch.no_grad():
        embeddings = model.encode_text(tokens)
        # Unit-norm rows make cosine similarity a plain dot product.
        embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
    return embeddings.cpu()
|
|
|