# NOTE: Hugging Face Spaces page header ("Spaces: / Running") removed — it was
# scrape residue from the hosting UI, not part of the module source.
| import os | |
| from langchain_core.embeddings import Embeddings | |
| from typing import List | |
| import numpy as np | |
| import onnxruntime as ort | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer | |
# Authentication: run `huggingface-cli login` or set an HF token env var.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
class OnnxGemmaWrapper(Embeddings):
    """LangChain-compatible text-embedding wrapper around an ONNX EmbeddingGemma model.

    Downloads the ONNX graph (plus its external weight file, when one exists) from
    the Hugging Face Hub, then runs inference with onnxruntime, preferring the CUDA
    execution provider when available and falling back to CPU.
    """

    def __init__(self, model_id, token=None):
        """Load the tokenizer and create the ONNX inference session.

        Args:
            model_id: Hugging Face Hub repo id of the ONNX model.
            token: Optional HF access token (needed for gated/private repos).
        """
        print(f"Loading ONNX model: {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Download the ONNX graph and its weights. Large exports keep weights in a
        # separate external-data file that must sit next to model.onnx in the cache.
        model_path = hf_hub_download(model_id, subfolder="onnx", filename="model.onnx", token=token)
        try:
            hf_hub_download(model_id, subfolder="onnx", filename="model.onnx_data", token=token)
        except Exception:
            pass  # model.onnx_data may not exist (small models embed their weights)

        # Create the inference session (CUDA provider when available, else CPU).
        available_providers = ort.get_available_providers()
        if 'CUDAExecutionProvider' in available_providers:
            print("CUDA detected. Using GPU.")
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            print("CUDA not detected. Using CPU.")
            providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=providers)

        # Task prefixes used by EmbeddingGemma for asymmetric query/document retrieval.
        self.prefixes = {
            "query": "task: search result | query: ",
            "document": "title: none | text: ",
        }
        print("ONNX Model loaded successfully.")

    def _run_inference(self, texts: List[str]):
        """Tokenize *texts* and return the model's sentence embeddings, shape (batch, dim)."""
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="np")
        # FIX: tokenizers may emit tensors (e.g. token_type_ids) that the ONNX graph
        # does not declare; feeding them makes onnxruntime raise INVALID_ARGUMENT.
        # Only pass the tensors the session actually expects.
        expected = {i.name for i in self.session.get_inputs()}
        ort_inputs = {name: value for name, value in inputs.items() if name in expected}
        outputs = self.session.run(None, ort_inputs)
        # Per this export, outputs[0] is the per-token last_hidden_state and
        # outputs[1] the pooled (batch, 768) sentence embedding — TODO confirm
        # against the model card if the export is updated.
        return outputs[1]

    def encode_document(self, documents: List[str]) -> np.ndarray:
        """Embed documents, applying the document-side task prefix."""
        prefixed_docs = [self.prefixes["document"] + doc for doc in documents]
        return self._run_inference(prefixed_docs)

    def encode_query(self, query: str) -> np.ndarray:
        """Embed a single query, applying the query-side task prefix."""
        # A single query is wrapped in a list for batched inference, then unwrapped.
        prefixed_query = [self.prefixes["query"] + query]
        return self._run_inference(prefixed_query)[0]

    def similarity(self, query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Return dot-product scores between one query embedding and document embeddings."""
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)
        scores = query_emb @ doc_embs.T
        return scores.flatten()

    # --- LangChain Compatibility Methods ---
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """LangChain hook: embed a batch of documents as plain Python lists."""
        return self.encode_document(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        """LangChain hook: embed a single query as a plain Python list."""
        return self.encode_query(text).tolist()
| import torch | |
| import torchvision.transforms as transforms | |
| from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights | |
| from PIL import Image | |
| # ... (existing OnnxGemmaWrapper and get_embedding_model) | |
class EfficientNetV2Embedding:
    """Image-embedding helper backed by a torchvision EfficientNetV2-S backbone.

    The pretrained classifier head is swapped for Identity, so a forward pass
    yields the pooled feature vector rather than class logits.
    """

    def __init__(self):
        print("Loading EfficientNetV2-S model...")
        self.weights = EfficientNet_V2_S_Weights.DEFAULT
        self.model = efficientnet_v2_s(weights=self.weights)
        self.model.eval()
        # Drop the classification head so forward() returns the embedding.
        self.model.classifier = torch.nn.Identity()
        self.preprocess = self.weights.transforms()
        print("EfficientNetV2-S model loaded successfully.")

    def embed_image(self, image: Image.Image) -> List[float]:
        """Return the embedding vector for a single PIL image."""
        batch = self.preprocess(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            return self.model(batch).squeeze(0).tolist()
# Module-level singleton holders, populated lazily by the getters below.
_embedding_model = _image_embedding_model = None
def get_embedding_model() -> OnnxGemmaWrapper:
    """Return the process-wide ONNX embedding model, loading it on first call only."""
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = OnnxGemmaWrapper(
        model_id="onnx-community/embeddinggemma-300m-ONNX",
        token=hf_token,
    )
    return _embedding_model
def get_image_embedding_model() -> EfficientNetV2Embedding:
    """Return the process-wide EfficientNetV2-S model, loading it on first call only."""
    global _image_embedding_model
    if _image_embedding_model is not None:
        return _image_embedding_model
    _image_embedding_model = EfficientNetV2Embedding()
    return _image_embedding_model