File size: 4,972 Bytes
a52c5c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2100e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a52c5c7
 
d2100e7
a52c5c7
 
 
 
 
 
 
 
 
 
 
 
d2100e7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from langchain_core.embeddings import Embeddings
from typing import List
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Requires `huggingface-cli login` or the HF_TOKEN environment variable.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

class OnnxGemmaWrapper(Embeddings):
    """ONNX Runtime wrapper for an EmbeddingGemma-style text embedding model.

    Exposes both a direct encode API (``encode_document`` / ``encode_query`` /
    ``similarity``) and the LangChain ``Embeddings`` interface
    (``embed_documents`` / ``embed_query``).
    """

    def __init__(self, model_id, token=None):
        """Download the tokenizer and ONNX weights, then build a session.

        Args:
            model_id: Hugging Face repo id that hosts ``onnx/model.onnx``.
            token: Optional Hugging Face access token (needed for gated or
                private repos; ``huggingface-cli login`` also works).
        """
        print(f"Loading ONNX model: {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Download the ONNX graph; large exports store weights in a separate
        # external-data file alongside it.
        model_path = hf_hub_download(model_id, subfolder="onnx", filename="model.onnx", token=token)
        try:
            hf_hub_download(model_id, subfolder="onnx", filename="model.onnx_data", token=token)
        except Exception:
            # Deliberate best-effort: small models embed all weights inside
            # model.onnx, so a missing model.onnx_data file is expected.
            pass

        # Prefer the GPU when a CUDA-enabled onnxruntime build is available;
        # always keep CPU as a fallback provider.
        available_providers = ort.get_available_providers()
        if 'CUDAExecutionProvider' in available_providers:
            print("CUDA detected. Using GPU.")
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            print("CUDA not detected. Using CPU.")
            providers = ['CPUExecutionProvider']

        self.session = ort.InferenceSession(model_path, providers=providers)
        # Input names the graph actually declares. Used to filter the
        # tokenizer output: tokenizers may emit tensors (e.g. token_type_ids)
        # that the exported graph has no input for, and ONNX Runtime raises
        # on unexpected feed keys.
        self._input_names = {inp.name for inp in self.session.get_inputs()}

        # Task prefixes expected by EmbeddingGemma-style models.
        self.prefixes = {
            "query": "task: search result | query: ",
            "document": "title: none | text: ",
        }
        print("ONNX Model loaded successfully.")

    def _run_inference(self, texts: List[str]) -> np.ndarray:
        """Tokenize *texts* and return the model's sentence embeddings."""
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="np")
        # Feed only tensors the graph declares as inputs (see __init__).
        feed = {name: arr for name, arr in inputs.items() if name in self._input_names}
        # output[0]: last_hidden_state, output[1]: pooled sentence embedding.
        # EmbeddingGemma ONNX exports return the sentence embedding second.
        outputs = self.session.run(None, feed)
        # outputs[1] has shape (batch, hidden_dim), e.g. (B, 768).
        return outputs[1]

    def encode_document(self, documents: List[str]) -> np.ndarray:
        """Embed a batch of documents with the document task prefix."""
        prefixed_docs = [self.prefixes["document"] + doc for doc in documents]
        return self._run_inference(prefixed_docs)

    def encode_query(self, query: str) -> np.ndarray:
        """Embed a single query with the query task prefix; returns a 1-D vector."""
        prefixed_query = [self.prefixes["query"] + query]
        return self._run_inference(prefixed_query)[0]

    def similarity(self, query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Return dot-product scores between one query and N document embeddings.

        Args:
            query_emb: Query vector, shape (dim,) or (1, dim).
            doc_embs: Document matrix, shape (N, dim).

        Returns:
            1-D array of N scores.
        """
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)
        scores = query_emb @ doc_embs.T
        return scores.flatten()

    # --- LangChain Compatibility Methods ---
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """LangChain hook: embed many documents, returning plain lists."""
        return self.encode_document(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        """LangChain hook: embed one query, returning a plain list."""
        return self.encode_query(text).tolist()

import torch
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from PIL import Image

# ... (existing OnnxGemmaWrapper and get_embedding_model)

class EfficientNetV2Embedding:
    """Image embedding model backed by torchvision's EfficientNetV2-S.

    The pretrained classifier head is replaced with ``Identity`` so the
    forward pass returns the pooled feature vector instead of class logits.
    """

    def __init__(self):
        print("Loading EfficientNetV2-S model...")
        self.weights = EfficientNet_V2_S_Weights.DEFAULT
        self.model = efficientnet_v2_s(weights=self.weights)
        self.model.eval()

        # Drop the classification head so forward() yields embeddings.
        self.model.classifier = torch.nn.Identity()

        # Use the preprocessing pipeline bundled with the pretrained weights.
        self.preprocess = self.weights.transforms()
        print("EfficientNetV2-S model loaded successfully.")

    def embed_image(self, image: Image.Image) -> List[float]:
        """Return the embedding vector for a single PIL image.

        Args:
            image: Input image. Non-RGB modes (grayscale, RGBA, palette)
                are converted to RGB first, since the pretrained transforms
                expect 3-channel input.

        Returns:
            The embedding as a flat list of floats.
        """
        # Robustness: the weight transforms assume 3 channels.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess and add a batch dimension: (C, H, W) -> (1, C, H, W).
        img_tensor = self.preprocess(image).unsqueeze(0)

        with torch.no_grad():  # inference only; skip autograd bookkeeping
            embedding = self.model(img_tensor)

        return embedding.squeeze(0).tolist()

# ์ „์—ญ ์‹ฑ๊ธ€ํ†ค ์ธ์Šคํ„ด์Šค ์ €์žฅ์†Œ
_embedding_model = None
_image_embedding_model = None

def get_embedding_model() -> OnnxGemmaWrapper:
    """Return the shared ONNX embedding model, loading it on first use.

    The heavyweight model is constructed exactly once and cached in a
    module-level singleton so all callers reuse the same instance.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = OnnxGemmaWrapper(
        model_id="onnx-community/embeddinggemma-300m-ONNX",
        token=hf_token,
    )
    return _embedding_model

def get_image_embedding_model() -> EfficientNetV2Embedding:
    """Return the shared EfficientNetV2-S model, loading it on first use.

    Lazily instantiates the image embedding model and caches it in a
    module-level singleton so repeated callers share one instance.
    """
    global _image_embedding_model
    if _image_embedding_model is not None:
        return _image_embedding_model
    _image_embedding_model = EfficientNetV2Embedding()
    return _image_embedding_model