|
|
from functools import cache |
|
|
from typing import TypeAlias |
|
|
|
|
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings |
|
|
from langchain_ollama import OllamaEmbeddings |
|
|
from langchain_openai import OpenAIEmbeddings |
|
|
|
|
|
from core.settings import settings |
|
|
from schema.models import ( |
|
|
AllEmbeddingModelEnum, |
|
|
GoogleEmbeddingModelName, |
|
|
OllamaEmbeddingModelName, |
|
|
OpenAIEmbeddingModelName, |
|
|
) |
|
|
|
|
|
# Union of every embedding-client type this module's factory can return.
EmbeddingT: TypeAlias = OpenAIEmbeddings | GoogleGenerativeAIEmbeddings | OllamaEmbeddings
|
|
|
|
|
|
|
|
@cache
def get_embeddings(model_name: AllEmbeddingModelEnum, /) -> EmbeddingT:
    """Return an embeddings client for the given model enum member.

    The parameter is positional-only and hashable, so ``@cache`` memoizes
    one client instance per model for the lifetime of the process.

    Args:
        model_name: A member of one of the provider-specific embedding
            model enums.

    Returns:
        A provider-appropriate embeddings client configured for the model.

    Raises:
        ValueError: If ``model_name`` does not belong to any supported
            provider enum.
    """
    # Dispatch on which provider enum the member belongs to; enum
    # membership tests keep semantics identical across Python versions.
    if model_name in OpenAIEmbeddingModelName:
        client: EmbeddingT = OpenAIEmbeddings(model=model_name.value)
    elif model_name in GoogleEmbeddingModelName:
        client = GoogleGenerativeAIEmbeddings(model=model_name.value)
    elif model_name in OllamaEmbeddingModelName:
        # An explicit settings override takes precedence over the enum value.
        client = OllamaEmbeddings(
            model=settings.OLLAMA_EMBEDDING_MODEL or model_name.value,
            base_url=settings.OLLAMA_BASE_URL,
        )
    else:
        raise ValueError(f"Unsupported embedding model: {model_name}")
    return client
|
|
|