|
|
import os |
|
|
from enum import Enum |
|
|
from typing import Dict, List |
|
|
|
|
|
from llama_index.core.embeddings import BaseEmbedding |
|
|
|
|
|
|
|
|
|
|
|
# Embedding models accepted by EmbeddingProvider.validate_model, keyed by
# provider value. Keep a trailing comma after every entry: Python silently
# concatenates adjacent string literals, so a missing comma merges two model
# names into one bogus entry.
# NOTE(review): EmbeddingProvider.VOYAGE has no key here, so validating any
# "voyage" model raises ValueError — confirm whether that is intended.
SUPPORTED_MODELS: Dict[str, List[str]] = {
    "openai": [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    ],
    "azure_openai": [
        "text-embedding-3-small",
        "text-embedding-3-large",
    ],
    "huggingface": [
        "sentence-transformers/all-MiniLM-L6-v2",
        "sentence-transformers/all-mpnet-base-v2",
        "sentence-transformers/multi-qa-mpnet-base-dot-v1",
        "BAAI/bge-small-en-v1.5",
        "BAAI/bge-large-en-v1.5",
    ],
    "ollama": [
        "nomic-embed-text",
        "mxbai-embed-large",
        "bge-m3",
        "all-minilm",
        "snowflake-arctic-embed",
    ],
}
|
|
|
|
|
|
|
|
class EmbeddingProvider(str, Enum):
    """Embedding backends known to this module.

    Members are ``str`` instances, so each compares equal to its value
    (e.g. ``EmbeddingProvider.OPENAI == "openai"``). That lets a member be
    passed wherever a plain provider string is expected, including as a key
    into ``SUPPORTED_MODELS``.
    """

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    HUGGINGFACE = "huggingface"
    OLLAMA = "ollama"
    # NOTE(review): "voyage" has no entry in SUPPORTED_MODELS, so
    # validate_model() raises ValueError for this provider — confirm intent.
    VOYAGE = "voyage"

    @classmethod
    def validate_model(cls, provider: str, model_name: str) -> bool:
        """Validate if the model is supported for the given provider.

        Args:
            provider (str): The embedding provider, as a plain string or an
                ``EmbeddingProvider`` member (e.g. 'openai', 'huggingface',
                'ollama').
            model_name (str): The name of the embedding model to validate.
                For the 'huggingface' provider this may also be a filesystem
                path.

        Returns:
            bool: True if ``model_name`` is listed for ``provider`` in
            ``SUPPORTED_MODELS``, or, for 'huggingface' only, if
            ``model_name`` is an existing filesystem path. False otherwise.

        Raises:
            ValueError: If ``provider`` has no entry in ``SUPPORTED_MODELS``.
        """
        if provider not in SUPPORTED_MODELS:
            raise ValueError(f"Unsupported provider: {provider}")

        # Any existing path is accepted for HuggingFace (presumably so that
        # locally stored models can be used), in addition to the listed names.
        if provider == "huggingface" and os.path.exists(model_name):
            return True

        # Membership was checked above, so direct indexing cannot KeyError.
        return model_name in SUPPORTED_MODELS[provider]
|
|
|
|
|
|
|
|
class BaseEmbeddingWrapper:
    """Common interface shared by the embedding wrappers.

    The base implementations of ``get_embedding_model`` and ``dimensions``
    only raise ``NotImplementedError``; concrete wrappers are expected to
    override them.
    """

    def get_embedding_model(self) -> BaseEmbedding:
        """Return the LlamaIndex-compatible embedding model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def validate_model(self, provider: EmbeddingProvider, model_name: str) -> bool:
        """Check whether ``model_name`` is supported for ``provider``.

        This is a thin delegate to ``EmbeddingProvider.validate_model`` and
        reads no instance state.

        Args:
            provider (EmbeddingProvider): The embedding provider.
            model_name (str): The name of the embedding model to check.

        Returns:
            bool: The result of ``EmbeddingProvider.validate_model`` for the
            same arguments (True if supported, False otherwise).

        Raises:
            ValueError: Propagated from ``EmbeddingProvider.validate_model``
                when the provider is not recognised.
        """
        provider_check = EmbeddingProvider.validate_model
        return provider_check(provider, model_name)

    @property
    def dimensions(self) -> int:
        """Dimensionality of the wrapped embedding model.

        Presumably the length of each embedding vector; subclasses define
        the actual value.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()