File size: 1,094 Bytes
22dcdfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from functools import cache
from typing import TypeAlias

from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings

from core.settings import settings
from schema.models import (
    AllEmbeddingModelEnum,
    GoogleEmbeddingModelName,
    OllamaEmbeddingModelName,
    OpenAIEmbeddingModelName,
)

# Every concrete embeddings client type that ``get_embeddings`` can return.
EmbeddingT: TypeAlias = (
    OpenAIEmbeddings | GoogleGenerativeAIEmbeddings | OllamaEmbeddings
)


@cache
def get_embeddings(model_name: AllEmbeddingModelEnum, /) -> EmbeddingT:
    """Return an embeddings client for ``model_name``.

    The provider is selected from the enum class that ``model_name`` is a
    member of. Results are memoized with ``functools.cache``, so repeated
    calls with the same model return the same client instance. Any
    ``settings`` values are read only on the first call for a given model.

    Args:
        model_name: An embedding-model enum member (positional-only).

    Returns:
        An ``OpenAIEmbeddings``, ``GoogleGenerativeAIEmbeddings`` or
        ``OllamaEmbeddings`` instance.

    Raises:
        ValueError: If ``model_name`` is not a member of any supported
            provider enum.
    """
    # Dispatch on the member's enum class instead of ``model_name in Enum``:
    # since Python 3.12 ``in`` on an Enum class also matches by *value*, so a
    # member of one provider's enum (if these are str-valued enums) whose
    # value equals a value in another provider's enum could be routed to the
    # wrong client. ``isinstance`` is exact.
    if isinstance(model_name, OpenAIEmbeddingModelName):
        return OpenAIEmbeddings(model=model_name.value)

    if isinstance(model_name, GoogleEmbeddingModelName):
        return GoogleGenerativeAIEmbeddings(model=model_name.value)

    if isinstance(model_name, OllamaEmbeddingModelName):
        # A non-empty OLLAMA_EMBEDDING_MODEL setting overrides the enum value.
        return OllamaEmbeddings(
            model=settings.OLLAMA_EMBEDDING_MODEL or model_name.value,
            base_url=settings.OLLAMA_BASE_URL,
        )

    raise ValueError(f"Unsupported embedding model: {model_name}")