# File size: 1,094 Bytes
# 22dcdfd
from functools import cache
from typing import TypeAlias
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from core.settings import settings
from schema.models import (
AllEmbeddingModelEnum,
GoogleEmbeddingModelName,
OllamaEmbeddingModelName,
OpenAIEmbeddingModelName,
)
# Every embeddings client class that get_embeddings() may construct.
EmbeddingT: TypeAlias = (
    OpenAIEmbeddings | GoogleGenerativeAIEmbeddings | OllamaEmbeddings
)
@cache
def get_embeddings(model_name: AllEmbeddingModelEnum, /) -> EmbeddingT:
    """Return a cached embeddings client for ``model_name``.

    The provider is selected by which enum class ``model_name`` belongs to:
    OpenAI, Google Generative AI, or Ollama. Results are memoized per
    argument by ``functools.cache``, so repeated calls with the same member
    return the same client object.

    Args:
        model_name: A member of one of the provider embedding-model enums
            (positional-only).

    Returns:
        An ``OpenAIEmbeddings``, ``GoogleGenerativeAIEmbeddings`` or
        ``OllamaEmbeddings`` instance.

    Raises:
        ValueError: If ``model_name`` is not a member of any supported
            provider enum.
    """
    # Dispatch on the enum type rather than with ``in``. ``Enum.__contains__``
    # raises TypeError for non-members before Python 3.12, and from 3.12 it
    # falls back to a value lookup. That lets a plain string (or a value shared
    # across providers) reach the wrong branch. ``isinstance`` behaves the same
    # on every version and lets type checkers narrow ``model_name``.
    if isinstance(model_name, OpenAIEmbeddingModelName):
        return OpenAIEmbeddings(model=model_name.value)
    if isinstance(model_name, GoogleEmbeddingModelName):
        return GoogleGenerativeAIEmbeddings(model=model_name.value)
    if isinstance(model_name, OllamaEmbeddingModelName):
        # A truthy settings.OLLAMA_EMBEDDING_MODEL overrides the requested
        # member's value; the cache is still keyed on the member itself.
        return OllamaEmbeddings(
            model=settings.OLLAMA_EMBEDDING_MODEL or model_name.value,
            base_url=settings.OLLAMA_BASE_URL,
        )
    raise ValueError(f"Unsupported embedding model: {model_name}")
|