"""Reformer Pilates AI — Chainlit app.

Builds two LangGraph graphs over a fine-tuned-embedding Qdrant store:

1. ``graph`` — a plain RAG pipeline (retrieve → Cohere rerank → generate).
2. ``compiled_graph_w_fine_tune`` — a tool-calling agent (gpt-4o + Tavily
   search) that loops agent → tools until no more tool calls are emitted.
   This is the graph actually served to Chainlit users.

Module import has side effects: it reads ``./data_new/combined_data.json``,
builds an in-memory Qdrant collection, and instantiates OpenAI/HF models.
"""

# --- stdlib ----------------------------------------------------------------
import asyncio
import getpass
import json
import operator
import os
from operator import itemgetter
from typing import Annotated, Tuple, Union
from uuid import uuid4

# --- third party (deduplicated: List/TypedDict, StateGraph, langchain_openai
# and langchain_core.messages were each imported more than once) ------------
import chainlit as cl
import nest_asyncio
import pandas as pd
import tqdm
from dotenv import load_dotenv
from huggingface_hub import notebook_login
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_cohere import CohereRerank
from langchain_community.document_loaders import BSHTMLLoader, DirectoryLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
from langchain_core.tools import tool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from ragas import EvaluationDataset, RunConfig, evaluate
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import (
    ContextEntityRecall,
    FactualCorrectness,
    Faithfulness,
    LLMContextRecall,
    NoiseSensitivity,
    ResponseRelevancy,
)
from ragas.testset import TestsetGenerator
from sentence_transformers import InputExample, SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from torch.utils.data import DataLoader, Dataset
from typing_extensions import List, TypedDict

load_dotenv()

# --- Corpus loading & chunking ---------------------------------------------
with open("./data_new/combined_data.json", "r") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# Web-search tool used by the agent graph below.
tavily_tool = TavilySearchResults(max_results=5)

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=750,
    chunk_overlap=100,
    length_function=len,
)
split_documents = text_splitter.split_documents(all_docs)

# --- Vector store: in-memory Qdrant over fine-tuned embeddings -------------
# NOTE(review): vector size 768 must match the embedding model's output
# dimension — confirm for "bsmith3715/legal-ft-demo_final".
embeddings = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="reformer",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)

vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer",
    embedding=embeddings,
)
_ = vector_store_ft.add_documents(documents=split_documents)

retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})


def retrieve_adjusted(state):
    """RAG node: fetch docs for ``state['question']`` and rerank with Cohere.

    Returns a partial state update: ``{"context": [Document, ...]}``.
    """
    compressor = CohereRerank(model="rerank-v3.5", top_n=10)
    # NOTE(review): ``search_kwargs`` is not a declared field of
    # ContextualCompressionRetriever — retrieval size is controlled by the
    # base retriever's own search_kwargs (k=5 above). Kept for behavioral
    # parity; verify it isn't rejected by the installed langchain version.
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_finetune,
        search_kwargs={"k": 5},
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}


RAG_PROMPT = """\
You are a helpful pilates expert who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
ft_llm = ChatOpenAI(model="gpt-4o-mini")


def generate(state):
    """RAG node: answer ``state['question']`` strictly from ``state['context']``.

    Returns a partial state update: ``{"response": str}``.
    """
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = rag_prompt.format_messages(
        question=state["question"], context=docs_content
    )
    response = ft_llm.invoke(messages)
    return {"response": response.content}


class State(TypedDict):
    """State carried through the plain RAG graph."""

    question: str
    context: List[Document]
    response: str


graph_builder = StateGraph(State).add_sequence([retrieve_adjusted, generate])
graph_builder.add_edge(START, "retrieve_adjusted")
graph = graph_builder.compile()

# --- Tool-calling agent graph (served to Chainlit) -------------------------
cert_model = ChatOpenAI(model="gpt-4o", temperature=0)

tool_belt = [
    tavily_tool,
]
cert_model = cert_model.bind_tools(tool_belt)


class AgentState(TypedDict):
    """State carried through the agent graph.

    ``messages`` uses the ``add_messages`` reducer so node outputs append
    rather than overwrite.
    """

    messages: Annotated[list, add_messages]
    context: List[Document]


def call_model(state):
    """Agent node: run the tool-bound model over the conversation so far."""
    response = cert_model.invoke(state["messages"])
    # Preserve any context accumulated earlier in the run (empty by default).
    return {"messages": [response], "context": state.get("context", [])}


tool_node = ToolNode(tool_belt)

uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")


def should_continue(state):
    """Route to the tool node while the last AI message requests tool calls."""
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "action"
    return END


uncompiled_graph.add_conditional_edges("agent", should_continue)
uncompiled_graph.add_edge("action", "agent")
compiled_graph_w_fine_tune = uncompiled_graph.compile()


@cl.on_chat_start
async def on_chat_start():
    """Greet the user and stash the compiled agent graph in the session."""
    await cl.Message(content="""👋 Welcome to your Reformer Pilates AI!

Here's what you can do:

• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations

Let's get started! 🚀""").send()
    cl.user_session.set("graph", compiled_graph_w_fine_tune)


@cl.on_message
async def handle(message: cl.Message):
    """Run one user turn through the agent graph and reply with the answer."""
    graph = cl.user_session.get("graph")
    state = {"messages": [HumanMessage(message.content)]}
    response = await graph.ainvoke(state)
    # Fix: dropped ``stream=True`` — cl.Message does not stream a completed
    # string; token streaming in Chainlit is done via ``msg.stream_token``.
    await cl.Message(content=response["messages"][-1].content).send()