# backend/src/core/llm.py
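"""Factory for LangChain chat models, keyed by the model enums in schema.models."""
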
from functools import cache
from typing import TypeAlias
from langchain_anthropic import ChatAnthropic
from langchain_aws import ChatBedrock
from langchain_community.chat_models import FakeListChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from core.settings import settings
from schema.models import (
AllModelEnum,
AnthropicModelName,
AWSModelName,
AzureOpenAIModelName,
DeepseekModelName,
FakeModelName,
GoogleModelName,
GroqModelName,
OllamaModelName,
OpenAICompatibleName,
OpenAIModelName,
OpenRouterModelName,
VertexAIModelName,
)
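
# Map each supported enum member to the model name string its provider's API expects.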
_MODEL_TABLE = (
{m: m.value for m in OpenAIModelName}
| {m: m.value for m in OpenAICompatibleName}
| {m: m.value for m in AzureOpenAIModelName}
| {m: m.value for m in DeepseekModelName}
| {m: m.value for m in AnthropicModelName}
| {m: m.value for m in GoogleModelName}
| {m: m.value for m in VertexAIModelName}
| {m: m.value for m in GroqModelName}
| {m: m.value for m in AWSModelName}
| {m: m.value for m in OllamaModelName}
| {m: m.value for m in OpenRouterModelName}
| {m: m.value for m in FakeModelName}
)


class FakeToolModel(FakeListChatModel):
    """Fake chat model for tests: replays its canned responses in order."""

    def __init__(self, responses: list[str]):
        super().__init__(responses=responses)

    def bind_tools(self, tools, **kwargs):
        # Accept and ignore tools so agents that call bind_tools run unchanged.
        return self
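

# Union of every concrete chat model type that get_model can return.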
ModelT: TypeAlias = (
AzureChatOpenAI
| ChatOpenAI
| ChatAnthropic
| ChatGoogleGenerativeAI
| ChatVertexAI
| ChatGroq
| ChatBedrock
| ChatOllama
| FakeToolModel
)


@cache
def get_model(model_name: AllModelEnum, /) -> ModelT:
    """Return a chat model instance for the given enum member, cached per member.

    NOTE: models with streaming=True will send tokens as they are generated
    if the /stream endpoint is called with stream_tokens=True (the default).
    """
api_model_name = _MODEL_TABLE.get(model_name)
if not api_model_name:
raise ValueError(f"Unsupported model: {model_name}")
if model_name in OpenAIModelName:
return ChatOpenAI(model=api_model_name, streaming=True)
if model_name in OpenAICompatibleName:
if not settings.COMPATIBLE_BASE_URL or not settings.COMPATIBLE_MODEL:
raise ValueError("OpenAICompatible base url and endpoint must be configured")
return ChatOpenAI(
model=settings.COMPATIBLE_MODEL,
temperature=0.5,
streaming=True,
            base_url=settings.COMPATIBLE_BASE_URL,
            api_key=settings.COMPATIBLE_API_KEY,
)
if model_name in AzureOpenAIModelName:
if not settings.AZURE_OPENAI_API_KEY or not settings.AZURE_OPENAI_ENDPOINT:
raise ValueError("Azure OpenAI API key and endpoint must be configured")
return AzureChatOpenAI(
azure_endpoint=settings.AZURE_OPENAI_ENDPOINT,
deployment_name=api_model_name,
api_version=settings.AZURE_OPENAI_API_VERSION,
temperature=0.5,
streaming=True,
timeout=60,
max_retries=3,
)
    if model_name in DeepseekModelName:
        # DeepSeek exposes an OpenAI-compatible API, so ChatOpenAI is reused
        # with a custom base URL.
return ChatOpenAI(
model=api_model_name,
temperature=0.5,
streaming=True,
openai_api_base="https://api.deepseek.com",
openai_api_key=settings.DEEPSEEK_API_KEY,
)
if model_name in AnthropicModelName:
return ChatAnthropic(model=api_model_name, temperature=0.5, streaming=True)
if model_name in GoogleModelName:
return ChatGoogleGenerativeAI(model=api_model_name, temperature=0.5, streaming=True)
if model_name in VertexAIModelName:
return ChatVertexAI(model=api_model_name, temperature=0.5, streaming=True)
if model_name in GroqModelName:
        if model_name == GroqModelName.LLAMA_GUARD_4_12B:
            # Llama Guard is a moderation model; keep its output deterministic.
            return ChatGroq(model=api_model_name, temperature=0.0)  # type: ignore[call-arg]
        return ChatGroq(model=api_model_name, temperature=0.5)  # type: ignore[call-arg]
if model_name in AWSModelName:
return ChatBedrock(model_id=api_model_name, temperature=0.5)
    if model_name in OllamaModelName:
        # Ollama serves the locally configured model, so settings.OLLAMA_MODEL
        # is used instead of the enum's API name.
        if settings.OLLAMA_BASE_URL:
chat_ollama = ChatOllama(
model=settings.OLLAMA_MODEL, temperature=0.5, base_url=settings.OLLAMA_BASE_URL
)
else:
chat_ollama = ChatOllama(model=settings.OLLAMA_MODEL, temperature=0.5)
return chat_ollama
if model_name in OpenRouterModelName:
return ChatOpenAI(
model=api_model_name,
temperature=0.5,
streaming=True,
base_url="https://openrouter.ai/api/v1/",
api_key=settings.OPENROUTER_API_KEY,
)
if model_name in FakeModelName:
return FakeToolModel(responses=["This is a test response from the fake model."])
raise ValueError(f"Unsupported model: {model_name}")