# RAG_HF / core / dependencies.py
import os
from langchain_core.embeddings import Embeddings
from typing import List
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
# Requires `huggingface-cli login` or an HF_TOKEN environment variable.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
class OnnxGemmaWrapper(Embeddings):
    """ONNX-backed sentence-embedding model with LangChain compatibility.

    Downloads an EmbeddingGemma ONNX export from the Hugging Face Hub,
    runs it through onnxruntime (CUDA when available, CPU otherwise),
    and prepends the task prefixes the model expects for queries vs.
    documents.
    """

    def __init__(self, model_id, token=None):
        print(f"Loading ONNX model: {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Fetch the ONNX graph, plus its external weight file when one exists.
        model_path = hf_hub_download(model_id, subfolder="onnx", filename="model.onnx", token=token)
        try:
            hf_hub_download(model_id, subfolder="onnx", filename="model.onnx_data", token=token)
        except Exception:
            # Small models ship without a separate model.onnx_data file.
            pass

        # Prefer the CUDA provider when onnxruntime reports it as available.
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            print("CUDA detected. Using GPU.")
            execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            print("CUDA not detected. Using CPU.")
            execution_providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=execution_providers)

        # Task prefixes the EmbeddingGemma family expects on its inputs.
        self.prefixes = {
            "query": "task: search result | query: ",
            "document": "title: none | text: ",
        }
        print("ONNX Model loaded successfully.")

    def _run_inference(self, texts: List[str]):
        """Tokenize *texts* and run the ONNX session, returning embeddings."""
        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors="np")
        # Output 0 is the last hidden state; output 1 is the pooled
        # (batch, 768) sentence embedding for this EmbeddingGemma export.
        return self.session.run(None, dict(encoded))[1]

    def encode_document(self, documents: List[str]) -> np.ndarray:
        """Embed a batch of documents, applying the document-side prefix."""
        return self._run_inference([self.prefixes["document"] + text for text in documents])

    def encode_query(self, query: str) -> np.ndarray:
        """Embed a single query string, applying the query-side prefix."""
        return self._run_inference([self.prefixes["query"] + query])[0]

    def similarity(self, query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Dot-product similarity between one query and stacked document embeddings."""
        q = query_emb.reshape(1, -1) if query_emb.ndim == 1 else query_emb
        return (q @ doc_embs.T).flatten()

    # --- LangChain Compatibility Methods ---
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.encode_document(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        return self.encode_query(text).tolist()
import torch
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from PIL import Image
# --- Image embedding model (EfficientNetV2-S) ---
class EfficientNetV2Embedding:
    """Image embedder built on a headless EfficientNetV2-S backbone."""

    def __init__(self):
        print("Loading EfficientNetV2-S model...")
        self.weights = EfficientNet_V2_S_Weights.DEFAULT
        self.model = efficientnet_v2_s(weights=self.weights)
        self.model.eval()
        # Swap the classifier for an identity so the forward pass
        # yields the penultimate feature vector instead of class logits.
        self.model.classifier = torch.nn.Identity()
        self.preprocess = self.weights.transforms()
        print("EfficientNetV2-S model loaded successfully.")

    def embed_image(self, image: Image.Image) -> List[float]:
        """Return the backbone's feature vector for *image* as a flat list."""
        batch = self.preprocess(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            features = self.model(batch)
        return features.squeeze(0).tolist()
# Module-level singleton holders, lazily populated on first access.
_embedding_model = None
_image_embedding_model = None
def get_embedding_model() -> OnnxGemmaWrapper:
    """Return the process-wide ONNX embedding model, loading it on first call.

    Subsequent calls reuse the cached instance, so the (expensive) model
    download and session creation happen at most once per process.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = OnnxGemmaWrapper(
        model_id="onnx-community/embeddinggemma-300m-ONNX",
        token=hf_token,
    )
    return _embedding_model
def get_image_embedding_model() -> EfficientNetV2Embedding:
    """Return the process-wide EfficientNetV2-S embedder, loading it on first call.

    Subsequent calls reuse the cached instance so the torchvision weights
    are fetched at most once per process.
    """
    global _image_embedding_model
    if _image_embedding_model is not None:
        return _image_embedding_model
    _image_embedding_model = EfficientNetV2Embedding()
    return _image_embedding_model