| |
| import os |
| import chainlit as cl |
| from typing import Annotated, List |
| from dotenv import load_dotenv |
| from typing_extensions import List, TypedDict |
|
|
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain.prompts import ChatPromptTemplate |
| from langchain_openai import ChatOpenAI |
| from langchain_core.documents import Document |
| from langchain.retrievers.contextual_compression import ContextualCompressionRetriever |
| from langchain_cohere import CohereRerank |
| from langgraph.graph import START, StateGraph, END |
| from langchain_core.messages import HumanMessage |
| from langchain_core.tools import tool |
| from langchain_community.tools import TavilySearchResults |
| from langgraph.prebuilt.tool_node import ToolNode |
| from langgraph.graph.message import add_messages |
| from langchain_community.vectorstores import FAISS |
| from vectorstore import VectorStore |
|
|
| load_dotenv() |
|
|
| |
| """ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") |
| os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY |
| |
| COHERE_API_KEY = os.getenv("COHERE_API_KEY") |
| os.environ["COHERE_API_KEY"] = COHERE_API_KEY """ |
|
|
|
|
| |
# Embedding model: CPU-hosted, with normalized output vectors (suitable for
# cosine-similarity search).
# NOTE(review): embed_model is not referenced anywhere else in this file —
# presumably VectorStore configures its own embedder; confirm, or pass this in.
embed_model = HuggingFaceEmbeddings(
    model_name="Snowflake/snowflake-arctic-embed-l",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)


# Deterministic (temperature=0) chat model, shared by the RAG "generate" node
# and the agent loop below.
llm_sml = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)
|
|
| |
# Prompt for the RAG "generate" node: answer strictly from the retrieved
# context, admitting ignorance rather than falling back to model knowledge.
# (Template text is runtime behavior — left byte-identical.)
rag_prompt = ChatPromptTemplate.from_template("""\
You are a helpful assistant who answers questions based on provided context. You must only use the provided context. Do NOT use your own knowledge.
if you don't know the answer, say so.
### Question
{question}
### Context
{context}
""")
|
|
| |
# Build the vector store, ingest the pre-chunked documents, and expose a
# top-5 retriever over the collection.
vectorstore = VectorStore(
    collection_name="mg_alloy_collection_snowflake",
)
# NOTE(review): load_chunks_as_documents is invoked on the CLASS, not on the
# `vectorstore` instance — this only works if it is a static/class method;
# confirm against vectorstore.py.
documents = VectorStore.load_chunks_as_documents("data/contextual_chunks")
vectorstore.add_documents(documents)
retriever = vectorstore.as_retriever(k=5)
|
|
| |
class State(TypedDict):
    """State flowing through the linear retrieve -> generate RAG graph."""
    question: str  # user question fed to retrieval and the prompt
    context: List[Document]  # documents returned by the reranked retriever
    response: str  # final LLM answer text
|
|
|
|
| |
|
|
def generate(state):
    """Produce the final answer from the question plus retrieved context.

    Joins the retrieved documents into a single context string, renders the
    RAG prompt, and invokes the LLM. Returns a partial state update holding
    only the answer text.
    """
    context_text = "\n\n".join([doc.page_content for doc in state["context"]])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"],
        context=context_text,
    )
    llm_reply = llm_sml.invoke(prompt_messages)
    return {"response": llm_reply.content}
|
|
|
|
def retrieve_adjusted(state: State):
    """Retrieve candidate documents and rerank them with Cohere.

    The base retriever (top-5, configured at module level) supplies
    candidates; CohereRerank then reorders/filters them by relevance to the
    question. Returns a partial state update with the reranked documents.
    """
    compressor = CohereRerank(model="rerank-v3.5")
    # Fix: the original passed search_kwargs={"k": 5} to
    # ContextualCompressionRetriever, which is not a field of that class —
    # k is already set on the base retriever via as_retriever(k=5).
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}
|
|
|
|
def should_continue(state):
    """Route the agent: go to the tool node if the last message requested tools.

    NOTE(review): this function reads a ``messages`` key, which the RAG
    ``State`` defined above does not have, and an identical copy is defined
    again later in this file (shadowing this one) — this first definition
    appears to be dead code; confirm and remove.
    """
    last_message = state["messages"][-1]

    # Pending tool calls mean the graph should visit the "action" node next.
    if last_message.tool_calls:
        return "action"

    return END
|
|
| |
|
|
| |
# Linear RAG pipeline: START -> retrieve (rerank) -> generate -> END.
graph_builder = StateGraph(State)
graph_builder.add_node("retrieve", retrieve_adjusted)
graph_builder.add_node("generate", generate)
graph_builder.add_edge(START, "retrieve")
graph_builder.add_edge("retrieve", "generate")
graph_builder.add_edge("generate", END)
graph = graph_builder.compile()
|
|
|
|
@tool
def ai_rag_tool(question: str) -> dict:
    """Useful for when you need to answer questions about magnesium alloys. Input should be a fully formed question."""
    # Run the compiled retrieve -> generate RAG graph on the question.
    response = graph.invoke({"question": question})
    # Fix: return annotation was ``-> str`` but the function has always
    # returned this state-update-shaped dict; annotation corrected to dict.
    return {
        "messages": [HumanMessage(content=response["response"])],
        "context": response["context"]
    }
|
|
|
|
| |
# Tools available to the agent; currently just the magnesium-alloy RAG tool.
tool_belt = [
    ai_rag_tool
]
|
|
|
|
class AgentState(TypedDict):
    """State of the tool-calling agent loop."""
    messages: Annotated[list, add_messages]  # chat history, merged via add_messages reducer
    context: List[Document]  # documents surfaced by the RAG tool
|
|
# Node that executes whichever tool calls the agent's last message requested.
tool_node = ToolNode(tool_belt)


# Agent graph: model <-> tools loop over AgentState.
uncompiled_graph = StateGraph(AgentState)
|
|
def call_model(state):
    """Invoke the chat model on the running message history.

    Returns a state update appending the model's reply and carrying the
    previously retrieved context through unchanged (empty list if absent).
    """
    reply = llm_sml.invoke(state["messages"])
    carried_context = state.get("context", [])
    return {"messages": [reply], "context": carried_context}
|
|
| uncompiled_graph.add_node("agent", call_model) |
| uncompiled_graph.add_node("action", tool_node) |
| uncompiled_graph.set_entry_point("agent") |
|
|
def should_continue(state):
    """Route after the agent node.

    Returns "action" when the newest message carries tool calls (so the tool
    node runs next), otherwise END to finish the graph.
    """
    newest = state["messages"][-1]
    return "action" if newest.tool_calls else END
|
|
# After the agent node, branch on should_continue: "action" (tool node) or END.
uncompiled_graph.add_conditional_edges(
    "agent",
    should_continue
)


# Tool results always flow back to the agent for another model turn.
uncompiled_graph.add_edge("action", "agent")


compiled_graph = uncompiled_graph.compile()
|
|
|
|
| |
@cl.on_chat_start
async def start():
    """Stash the compiled agent graph in the Chainlit user session."""
    cl.user_session.set(
        "graph", compiled_graph)
|
|
@cl.on_message
async def handle(message: cl.Message):
    """Run the agent graph on each incoming chat message and send the reply."""
    graph = cl.user_session.get("graph")
    # Fresh state per message: only the new human message (the add_messages
    # reducer merges it into the graph's message list).
    state = {"messages" : [HumanMessage(content=message.content)]}
    response = await graph.ainvoke(state)
    await cl.Message(content=response["messages"][-1].content).send()