import os
from langchain_core.embeddings import Embeddings
from typing import List
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
# Requires `huggingface-cli login` or the HF_TOKEN environment variable.
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
class OnnxGemmaWrapper(Embeddings):
    """EmbeddingGemma text embedder backed by ONNX Runtime.

    Implements LangChain's ``Embeddings`` interface (``embed_documents`` /
    ``embed_query``) on top of a downloaded ONNX graph.
    """

    def __init__(self, model_id, token=None):
        print(f"Loading ONNX model: {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Fetch the ONNX graph; the external weight file is optional.
        model_path = hf_hub_download(
            model_id, subfolder="onnx", filename="model.onnx", token=token
        )
        try:
            hf_hub_download(
                model_id, subfolder="onnx", filename="model.onnx_data", token=token
            )
        except Exception:
            # Best-effort: small models ship without model.onnx_data.
            pass

        # Build the inference session — prefer CUDA when available, else CPU.
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            print("CUDA detected. Using GPU.")
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            print("CUDA not detected. Using CPU.")
            providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=providers)

        # Task prefixes EmbeddingGemma expects in front of each input text.
        self.prefixes = {
            "query": "task: search result | query: ",
            "document": "title: none | text: ",
        }
        print("ONNX Model loaded successfully.")

    def _run_inference(self, texts: List[str]):
        """Tokenize *texts* and return the model's sentence embeddings."""
        encoded = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="np"
        )
        # The EmbeddingGemma ONNX graph returns (last_hidden_state,
        # sentence_embedding); the second output is the (batch, 768) embedding.
        return self.session.run(None, dict(encoded))[1]

    def encode_document(self, documents: List[str]) -> np.ndarray:
        """Embed a batch of documents, prepending the document prefix."""
        return self._run_inference(
            [self.prefixes["document"] + doc for doc in documents]
        )

    def encode_query(self, query: str) -> np.ndarray:
        """Embed a single query string, prepending the query prefix."""
        return self._run_inference([self.prefixes["query"] + query])[0]

    def similarity(self, query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Dot-product scores between one query embedding and document embeddings."""
        q = query_emb.reshape(1, -1) if query_emb.ndim == 1 else query_emb
        return (q @ doc_embs.T).flatten()

    # --- LangChain Compatibility Methods ---
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.encode_document(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        return self.encode_query(text).tolist()
import torch
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from PIL import Image
# --- Image embedding (in addition to the text OnnxGemmaWrapper above) ---
class EfficientNetV2Embedding:
    """Image embedder built on a headless EfficientNetV2-S backbone."""

    def __init__(self):
        print("Loading EfficientNetV2-S model...")
        self.weights = EfficientNet_V2_S_Weights.DEFAULT
        self.model = efficientnet_v2_s(weights=self.weights)
        self.model.eval()
        # Replace the classifier with an identity so the forward pass
        # yields the pooled feature vector instead of class logits.
        self.model.classifier = torch.nn.Identity()
        self.preprocess = self.weights.transforms()
        print("EfficientNetV2-S model loaded successfully.")

    def embed_image(self, image: Image.Image) -> List[float]:
        """Return the embedding of a PIL image as a flat list of floats."""
        batch = self.preprocess(image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            features = self.model(batch)
        return features.squeeze(0).tolist()
# Global singleton instance storage (lazily populated by the getter functions).
_embedding_model = None
_image_embedding_model = None
def get_embedding_model() -> OnnxGemmaWrapper:
    """Return the process-wide ONNX embedding model, loading it on first use."""
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = OnnxGemmaWrapper(
        model_id="onnx-community/embeddinggemma-300m-ONNX",
        token=hf_token,
    )
    return _embedding_model
def get_image_embedding_model() -> EfficientNetV2Embedding:
    """Return the process-wide EfficientNetV2-S embedder, loading it on first use."""
    global _image_embedding_model
    if _image_embedding_model is not None:
        return _image_embedding_model
    _image_embedding_model = EfficientNetV2Embedding()
    return _image_embedding_model
|