| import os |
| from langchain_core.embeddings import Embeddings |
| from typing import List |
| import numpy as np |
| import onnxruntime as ort |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoTokenizer |
|
|
| |
# Hugging Face auth token for gated/private repos; tries both common env var
# names and is None when neither is set (anonymous access).
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
|
|
class OnnxGemmaWrapper(Embeddings):
    """LangChain-compatible embedding wrapper around an ONNX EmbeddingGemma model.

    Downloads the ONNX graph (and, best-effort, its external weight sidecar)
    from the Hugging Face Hub, then serves embeddings through onnxruntime,
    preferring the CUDA execution provider when it is available.
    """

    def __init__(self, model_id: str, token=None):
        """Load the tokenizer and ONNX inference session for ``model_id``.

        Args:
            model_id: Hugging Face repo id that hosts ``onnx/model.onnx``.
            token: Optional Hugging Face access token for gated/private repos.
        """
        print(f"Loading ONNX model: {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        model_path = hf_hub_download(model_id, subfolder="onnx", filename="model.onnx", token=token)
        # Large ONNX exports keep their weights in a sidecar "model.onnx_data"
        # file that must land in the same cache directory as model.onnx. Not
        # every repo ships one, so this download is best-effort — but the
        # failure reason is logged instead of silently swallowed, so a
        # genuinely missing weight file is diagnosable here rather than as an
        # opaque session-load error later.
        try:
            hf_hub_download(model_id, subfolder="onnx", filename="model.onnx_data", token=token)
        except Exception as exc:
            print(f"No external data file downloaded (ok if weights are embedded): {exc}")

        # Prefer the GPU provider when onnxruntime reports CUDA support.
        available_providers = ort.get_available_providers()
        if 'CUDAExecutionProvider' in available_providers:
            print("CUDA detected. Using GPU.")
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            print("CUDA not detected. Using CPU.")
            providers = ['CPUExecutionProvider']

        self.session = ort.InferenceSession(model_path, providers=providers)

        # EmbeddingGemma task prefixes: queries and documents are embedded
        # with different instruction prefixes for asymmetric retrieval.
        self.prefixes = {
            "query": "task: search result | query: ",
            "document": "title: none | text: ",
        }
        print("ONNX Model loaded successfully.")

    def _run_inference(self, texts: List[str]) -> np.ndarray:
        """Tokenize ``texts`` and run the ONNX session, returning embeddings.

        NOTE(review): assumes the session's second output (index 1) is the
        pooled sentence embedding — confirm against the exported model's
        declared output order (``self.session.get_outputs()``).
        """
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="np")
        outputs = self.session.run(None, dict(inputs))
        return outputs[1]

    def encode_document(self, documents: List[str]) -> np.ndarray:
        """Embed documents with the document task prefix; one row per input."""
        prefixed_docs = [self.prefixes["document"] + doc for doc in documents]
        return self._run_inference(prefixed_docs)

    def encode_query(self, query: str) -> np.ndarray:
        """Embed a single query string with the query task prefix; 1-D vector."""
        prefixed_query = [self.prefixes["query"] + query]
        return self._run_inference(prefixed_query)[0]

    def similarity(self, query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Dot-product scores between one query vector and document vectors.

        Args:
            query_emb: Query embedding, shape ``(dim,)`` or ``(1, dim)``.
            doc_embs: Document embeddings, shape ``(n_docs, dim)``.

        Returns:
            Flat array with one score per document row.
        """
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)
        scores = query_emb @ doc_embs.T
        return scores.flatten()

    # --- LangChain ``Embeddings`` interface ---

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """LangChain hook: embed many documents as plain Python lists."""
        return self.encode_document(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        """LangChain hook: embed one query as a plain Python list."""
        return self.encode_query(text).tolist()
|
|
| import torch |
| import torchvision.transforms as transforms |
| from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights |
| from PIL import Image |
|
|
| |
|
|
class EfficientNetV2Embedding:
    """Image feature extractor built on a pretrained EfficientNetV2-S backbone.

    The classification head is swapped for an identity module, so a forward
    pass yields the pooled backbone feature vector rather than class logits.
    """

    def __init__(self):
        print("Loading EfficientNetV2-S model...")
        self.weights = EfficientNet_V2_S_Weights.DEFAULT
        self.model = efficientnet_v2_s(weights=self.weights)
        self.model.eval()
        # Remove the classifier so forward() emits features, not logits.
        self.model.classifier = torch.nn.Identity()
        # Reuse the exact preprocessing pipeline these weights were trained with.
        self.preprocess = self.weights.transforms()
        print("EfficientNetV2-S model loaded successfully.")

    def embed_image(self, image: Image.Image) -> List[float]:
        """Return the feature-vector embedding of one PIL image as a list."""
        batch = self.preprocess(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            features = self.model(batch)
        return features.squeeze(0).tolist()
|
|
| |
# Lazily-initialized module-level singletons; populated on first call to the
# corresponding get_*() accessor below.
_embedding_model = None
_image_embedding_model = None
|
|
def get_embedding_model() -> OnnxGemmaWrapper:
    """Return the process-wide ONNX embedding model, loading it on first use.

    The model is constructed once and cached in a module global, so repeated
    calls reuse the same instance (singleton pattern).
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = OnnxGemmaWrapper(
        model_id="onnx-community/embeddinggemma-300m-ONNX",
        token=hf_token,
    )
    return _embedding_model
|
|
def get_image_embedding_model() -> EfficientNetV2Embedding:
    """Return the process-wide EfficientNetV2-S model, loading it on first use.

    Constructed once and cached in a module global; later calls reuse the
    same instance (singleton pattern).
    """
    global _image_embedding_model
    if _image_embedding_model is not None:
        return _image_embedding_model
    _image_embedding_model = EfficientNetV2Embedding()
    return _image_embedding_model
|
|