import os
import operator
from typing import TypedDict, Annotated, List

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# Load environment variables (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY) from a
# local .env file, if one is present.
load_dotenv()

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: The list of messages comprising the conversation.
            operator.add indicates that new messages should be appended,
            not overwritten.
    """
    messages: Annotated[List[BaseMessage], operator.add]
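
# Note on the reducer: when a node returns {"messages": [new_msg]}, LangGraph
# applies operator.add, so new_msg is appended to the state's message list;
# returning {"messages": []} leaves the state unchanged.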
def initialize_llm(provider: str, model_name: str, temperature: float, api_key: str):
    """Initializes the appropriate LangChain Chat Model."""
    if provider == "OpenAI":
        if not api_key:
            raise ValueError("OpenAI API key is missing. Please set OPENAI_API_KEY.")
        return ChatOpenAI(api_key=api_key, model_name=model_name, temperature=temperature)
    elif provider == "Anthropic":
        if not api_key:
            raise ValueError("Anthropic API key is missing. Please set ANTHROPIC_API_KEY.")
        return ChatAnthropic(api_key=api_key, model_name=model_name, temperature=temperature)
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")
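
# Usage sketch (the model names here are illustrative assumptions, not values
# prescribed by this module):
#
#   llm = initialize_llm("OpenAI", "gpt-4o-mini", 0.7, os.getenv("OPENAI_API_KEY"))
#   llm = initialize_llm("Anthropic", "claude-3-5-haiku-20241022", 0.7,
#                        os.getenv("ANTHROPIC_API_KEY"))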
def create_chat_graph(llm):
    """
    Builds and compiles the LangGraph conversational graph.

    Args:
        llm: An initialized LangChain Chat Model instance.

    Returns:
        A compiled LangGraph application.
    """
    def call_model(state: GraphState) -> dict:
        """Invokes the provided LLM with the current conversation state."""
        messages = state["messages"]
        response = llm.invoke(messages)
        # Return the response in a list so the operator.add reducer appends it.
        return {"messages": [response]}

    # A single-node graph: the entry point calls the model, then the run ends.
    workflow = StateGraph(GraphState)
    workflow.add_node("llm_node", call_model)
    workflow.set_entry_point("llm_node")
    workflow.add_edge("llm_node", END)

    graph = workflow.compile()
    return graph
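
# A minimal end-to-end sketch of how the pieces fit together. This block is
# illustrative only: the provider, model name, and message handling here are
# assumptions, normally supplied by the surrounding application.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    llm = initialize_llm(
        provider="OpenAI",
        model_name="gpt-4o-mini",  # illustrative model name
        temperature=0.7,
        api_key=os.getenv("OPENAI_API_KEY", ""),
    )
    app = create_chat_graph(llm)

    # Each invocation passes the full history; the reducer appends the reply.
    result = app.invoke({"messages": [HumanMessage(content="Hello!")]})
    print(result["messages"][-1].content)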