# backend/src/core/embeddings.py
# Commit 22dcdfd by anujjoshi3105: "initial"
from functools import cache
from typing import TypeAlias
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from core.settings import settings
from schema.models import (
AllEmbeddingModelEnum,
GoogleEmbeddingModelName,
OllamaEmbeddingModelName,
OpenAIEmbeddingModelName,
)
# Union of every concrete embeddings client type that get_embeddings() can return.
EmbeddingT: TypeAlias = (
    OpenAIEmbeddings | GoogleGenerativeAIEmbeddings | OllamaEmbeddings
)
@cache
def get_embeddings(model_name: AllEmbeddingModelEnum, /) -> EmbeddingT:
    """Return the embeddings client for ``model_name``, memoized per model.

    The provider is picked by checking which provider enum ``model_name``
    belongs to, in this order: OpenAI, Google, then Ollama. The first match
    wins. Because of ``@cache``, later calls with the same model name return
    the same client object.

    The OpenAI and Google clients receive only the model name. Any
    credentials are presumably resolved by the client libraries themselves;
    that is not visible here.

    For Ollama models, ``settings.OLLAMA_EMBEDDING_MODEL`` overrides the enum
    value when it is truthy. ``settings.OLLAMA_BASE_URL`` is passed as the
    server URL. NOTE(review): both settings are read only on the first call
    per model, because the result is cached.

    Args:
        model_name: Embedding model enum member (positional-only).

    Returns:
        An ``OpenAIEmbeddings``, ``GoogleGenerativeAIEmbeddings`` or
        ``OllamaEmbeddings`` instance.

    Raises:
        ValueError: If ``model_name`` is not in any supported provider enum.
    """

    def _openai() -> EmbeddingT:
        return OpenAIEmbeddings(model=model_name.value)

    def _google() -> EmbeddingT:
        return GoogleGenerativeAIEmbeddings(model=model_name.value)

    def _ollama() -> EmbeddingT:
        # A configured model name takes precedence over the enum's own value.
        return OllamaEmbeddings(
            model=settings.OLLAMA_EMBEDDING_MODEL or model_name.value,
            base_url=settings.OLLAMA_BASE_URL,
        )

    # Ordered (provider enum, factory) pairs. The factories are lazy, so
    # ``.value`` and settings are only touched once a provider has matched.
    registry = (
        (OpenAIEmbeddingModelName, _openai),
        (GoogleEmbeddingModelName, _google),
        (OllamaEmbeddingModelName, _ollama),
    )
    for provider_enum, build_client in registry:
        if model_name in provider_enum:
            return build_client()

    raise ValueError(f"Unsupported embedding model: {model_name}")