# bsmith3715's picture
# Update app.py
# 3c77b7b verified
import os
from typing import TypedDict, Annotated, Union, List, Tuple
from typing_extensions import List, TypedDict
from dotenv import load_dotenv
import chainlit as cl
import nest_asyncio
import getpass
from uuid import uuid4
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from langchain_core.runnables import Runnable
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
import operator
from langchain_core.messages import BaseMessage
from langchain_core.documents import Document
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage
from ragas.embeddings import LangchainEmbeddingsWrapper
from langchain_openai import ChatOpenAI
from ragas.testset import TestsetGenerator
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from ragas import EvaluationDataset
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, ResponseRelevancy, ContextEntityRecall, NoiseSensitivity
from ragas import evaluate, RunConfig
from langchain_community.document_loaders import BSHTMLLoader
import tqdm
import asyncio
import json
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sentence_transformers import InputExample
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from huggingface_hub import notebook_login
import pandas as pd
# Pull configuration (API keys etc.) from a .env file via python-dotenv.
load_dotenv()

# Load the pre-combined corpus. JSON text is UTF-8 by specification, so the
# encoding is pinned instead of relying on the platform's locale default.
with open("./data_new/combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Each entry is expected to expose "content" and "metadata" keys; they become
# the page content and metadata of one LangChain Document.
all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]
# Web-search tool later handed to the agent; capped at five results per query.
tavily_tool = TavilySearchResults(max_results=5)

# Chunk the corpus: chunk_size/chunk_overlap are measured with len(), i.e. in
# characters.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=750,
    chunk_overlap=100,
    length_function=len,
)
split_documents = text_splitter.split_documents(all_docs)

# Embedding model used by the vector store below.
embeddings = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

# Qdrant client created with the ":memory:" location; the collection is
# created and filled every time this module is imported.
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="reformer",
    # NOTE(review): 768 presumably matches the embedding model's output
    # width — confirm against the model card.
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)

vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer",
    embedding=embeddings,
)
_ = vector_store_ft.add_documents(documents=split_documents)

# Base retriever used by retrieve_adjusted; asks the store for 5 hits.
retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})
def retrieve_adjusted(state):
    """Fetch context for ``state["question"]`` through a Cohere reranker.

    A ``CohereRerank`` compressor is wrapped around the module-level
    ``retriever_finetune`` on every call, the question is run through the
    resulting retriever, and the documents it returns become the new context.

    Args:
        state: Mapping holding the user's question under ``"question"``.

    Returns:
        A partial state update of the form ``{"context": <documents>}``.
    """
    # NOTE(review): retriever_finetune is configured with k=5, so top_n=10
    # can presumably never be reached — confirm the intended cut-off.
    reranker = CohereRerank(model="rerank-v3.5", top_n=10)
    # NOTE(review): search_kwargs is forwarded exactly as before; verify that
    # ContextualCompressionRetriever actually honours this argument.
    reranking_retriever = ContextualCompressionRetriever(
        base_compressor=reranker,
        base_retriever=retriever_finetune,
        search_kwargs={"k": 5},
    )
    return {"context": reranking_retriever.invoke(state["question"])}
# Prompt text for the RAG answer step. The {question} and {context}
# placeholders are filled in by generate().
RAG_PROMPT = """\
You are a helpful pilates expert who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""

# Chat prompt template built from the text above.
rag_prompt = ChatPromptTemplate.from_template(template=RAG_PROMPT)

# Chat model that writes the RAG answer in generate().
ft_llm = ChatOpenAI(model="gpt-4o-mini")
def generate(state):
    """Answer the question using only the retrieved context.

    Args:
        state: Mapping with ``"question"`` (the user's question) and
            ``"context"`` (an iterable of objects exposing ``page_content``).

    Returns:
        A partial state update ``{"response": <text of the model's reply>}``.
    """
    # Join the retrieved passages with blank lines between them so they fill
    # the single {context} slot of the prompt.
    context_text = "\n\n".join([doc.page_content for doc in state["context"]])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"],
        context=context_text,
    )
    reply = ft_llm.invoke(prompt_messages)
    return {"response": reply.content}
class State(TypedDict):
    """State passed between the steps of the retrieve-then-generate graph."""

    # The user's question; read by retrieve_adjusted and generate.
    question: str
    # Documents written by retrieve_adjusted and read by generate.
    context: List[Document]
    # Answer text written by generate.
    response: str
# Linear RAG pipeline: START -> retrieve_adjusted -> generate.
graph_builder = StateGraph(State)
# add_sequence registers both functions as nodes (named after the functions)
# and chains them in the given order.
graph_builder.add_sequence([retrieve_adjusted, generate])
graph_builder.add_edge(START, "retrieve_adjusted")

graph = graph_builder.compile()
# Tools the agent model may call; currently only the Tavily web search.
tool_belt = [tavily_tool]

# GPT-4o chat model at temperature 0, with the tool belt bound to it so its
# replies can contain tool calls.
cert_model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tool_belt)
class AgentState(TypedDict):
    """State shared by the nodes of the tool-calling agent graph."""

    # Conversation history, annotated with the add_messages reducer.
    messages: Annotated[list, add_messages]
    # Documents carried alongside the conversation; call_model passes them
    # through unchanged.
    context: List[Document]
def call_model(state):
    """Run the tool-enabled chat model on the conversation so far.

    Args:
        state: Mapping with the message history under ``"messages"`` and,
            optionally, documents under ``"context"``.

    Returns:
        A state update holding the model's reply as a one-element list under
        ``"messages"`` and the incoming ``"context"`` (or an empty list when
        the key is absent) under ``"context"``.
    """
    reply = cert_model.invoke(state["messages"])
    return {
        "messages": [reply],
        "context": state.get("context", []),
    }
# Node that executes whichever tool from the belt the model asked for.
tool_node = ToolNode(tool_belt)

# Agent graph: "agent" calls the model, "action" runs tools; execution
# begins at "agent".
uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")
def should_continue(state):
    """Choose the next hop after the agent node has run.

    Args:
        state: Mapping whose ``"messages"`` list ends with the latest message.

    Returns:
        ``"action"`` when the latest message carries tool calls, otherwise
        ``END``.
    """
    newest = state["messages"][-1]
    return "action" if newest.tool_calls else END
# After every agent turn, should_continue decides between running tools
# ("action") and finishing (END).
uncompiled_graph.add_conditional_edges("agent", should_continue)
# Tool results always go back to the agent for another turn.
uncompiled_graph.add_edge("action", "agent")

compiled_graph_w_fine_tune = uncompiled_graph.compile()
@cl.on_chat_start
async def on_chat_start():
    """Send the welcome message, then store the agent graph in the session."""
    greeting = cl.Message(
        content="""👋 Welcome to your Reformer Pilates AI!
Here's what you can do:
• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations
Let's get started! 🚀"""
    )
    await greeting.send()
    # The message handler reads the graph back under the same "graph" key.
    cl.user_session.set("graph", compiled_graph_w_fine_tune)
@cl.on_message
async def handle(message: cl.Message):
    """Answer one user message with the agent graph stored in the session.

    The incoming text is wrapped in a ``HumanMessage``, run through the graph
    saved under the ``"graph"`` session key, and the content of the last
    message in the resulting state is sent back to the user.

    Args:
        message: The Chainlit message the user just sent.
    """
    agent_graph = cl.user_session.get("graph")
    # Each turn starts from a fresh state that holds only the new message.
    state = {"messages": [HumanMessage(content=message.content)]}
    response = await agent_graph.ainvoke(state)
    # cl.Message's constructor does not declare a `stream` keyword (token
    # streaming goes through `stream_token`), so the finished reply is sent
    # as one complete message.
    await cl.Message(content=response["messages"][-1].content).send()