File size: 6,188 Bytes
acbc1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c77b7b
acbc1d8
 
fb171e7
acbc1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb171e7
acbc1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb171e7
 
 
 
 
 
 
acbc1d8
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
from typing import TypedDict, Annotated, Union, List, Tuple
from typing_extensions import List, TypedDict
from dotenv import load_dotenv
import chainlit as cl
import nest_asyncio
import getpass
from uuid import uuid4
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from langchain_core.runnables import Runnable
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
import operator
from langchain_core.messages import BaseMessage
from langchain_core.documents import Document
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage
from ragas.embeddings import LangchainEmbeddingsWrapper
from langchain_openai import ChatOpenAI
from ragas.testset import TestsetGenerator
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from ragas import EvaluationDataset
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, ResponseRelevancy, ContextEntityRecall, NoiseSensitivity
from ragas import evaluate, RunConfig
from langchain_community.document_loaders import BSHTMLLoader
import tqdm
import asyncio
import json
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sentence_transformers import InputExample
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from huggingface_hub import notebook_login
import pandas as pd

# Pull API credentials (OpenAI, Cohere, Tavily, etc.) from a local .env
# file into the process environment before any client is constructed.
load_dotenv()


# Load the pre-scraped corpus and wrap each record as a LangChain Document.
# Each JSON entry is expected to carry "content" (str) and "metadata" (dict)
# keys — TODO confirm against whatever produces combined_data.json.
# Fix: pass an explicit encoding; the platform-default codec can mis-decode
# a UTF-8 JSON file on some systems.
with open("./data_new/combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]


# Web-search tool the agent may call for questions the local corpus
# cannot answer; returns up to 5 Tavily results per query.
tavily_tool = TavilySearchResults(max_results=5)


# Chunk the corpus for embedding: 750-character windows with 100-character
# overlap, lengths measured with plain len() (characters, not tokens).
text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100, length_function = len)
split_documents = text_splitter.split_documents(all_docs)


# Fine-tuned sentence-embedding model pulled from the Hugging Face Hub.
embeddings = HuggingFaceEmbeddings(model_name = "bsmith3715/legal-ft-demo_final")


# In-memory (non-persistent) Qdrant instance: the vector index is rebuilt
# from scratch on every process start.
client = QdrantClient(":memory:")

# Vector size 768 must match the embedding dimension of the fine-tuned
# model above — presumably a 768-dim sentence-transformer; TODO confirm
# against the model card.
client.create_collection(
    collection_name="reformer",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)

vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer",
    embedding=embeddings,
)

# Embed and index every chunk; the returned document ids are unused.
_ = vector_store_ft.add_documents(documents=split_documents)

# Dense retriever over the fine-tuned embeddings, top-5 chunks per query.
retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})

def retrieve_adjusted(state):
  """Retrieve chunks for the question and rerank them with Cohere.

  Args:
    state: graph state dict; reads state["question"] (str).

  Returns:
    Partial state update {"context": [...]} with the reranked Documents.
    top_n=10 is a ceiling — the base retriever supplies only 5 chunks,
    so at most 5 documents come back.
  """
  # Fix: the original also passed search_kwargs={"k": 5} to
  # ContextualCompressionRetriever, which is not a field of that class —
  # k is already configured on retriever_finetune above.
  compressor = CohereRerank(model="rerank-v3.5", top_n=10)
  compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever_finetune
  )
  retrieved_docs = compression_retriever.invoke(state["question"])
  return {"context": retrieved_docs}

# Prompt that constrains the generation LLM to answer strictly from the
# retrieved context (no outside knowledge).
RAG_PROMPT = """\
You are a helpful pilates expert who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

# LLM used for the RAG answer-generation step.
ft_llm = ChatOpenAI(model="gpt-4o-mini")

def generate(state):
  """Answer the question in *state* using only the retrieved context.

  Joins the page contents of state["context"] into one context string,
  formats the RAG prompt, and invokes the chat model.

  Returns:
    Partial state update {"response": <answer text>}.
  """
  context_blocks = [doc.page_content for doc in state["context"]]
  prompt_messages = rag_prompt.format_messages(
      question=state["question"],
      context="\n\n".join(context_blocks),
  )
  llm_result = ft_llm.invoke(prompt_messages)
  return {"response" : llm_result.content}

class State(TypedDict):
  """State flowing through the linear RAG graph (retrieve_adjusted -> generate)."""
  question: str  # incoming user question
  context: List[Document]  # reranked chunks produced by retrieve_adjusted
  response: str  # final answer text produced by generate

# Linear RAG pipeline: START -> retrieve_adjusted -> generate.
# NOTE(review): this compiled `graph` does not appear to be wired into the
# Chainlit handlers below (which use the agent graph) — confirm it is still
# needed, e.g. for evaluation.
graph_builder = StateGraph(State).add_sequence([retrieve_adjusted, generate])
graph_builder.add_edge(START, "retrieve_adjusted")
graph = graph_builder.compile()

# Agent LLM: gpt-4o at temperature 0 for deterministic tool-routing.
cert_model = ChatOpenAI(
    model="gpt-4o",
    temperature=0
)

# Tools the agent may call (currently web search only).
tool_belt = [
    tavily_tool,
]

# Rebind so the model can emit tool calls for the belt above.
cert_model = cert_model.bind_tools(tool_belt)

class AgentState(TypedDict):
    """State for the tool-calling agent graph."""
    messages: Annotated[list, add_messages]  # chat history; add_messages appends rather than overwrites
    context: List[Document]  # carried through by call_model; not populated by any node here

def call_model(state):
  """Invoke the tool-calling LLM on the conversation so far.

  Returns a partial state update appending the model's reply to
  "messages" and passing "context" through unchanged (empty list when
  the key is absent).
  """
  reply = cert_model.invoke(state["messages"])
  return {
      "messages" : [reply],
      "context" : state.get("context", []),
  }

# Node that executes any tool calls emitted by the agent node.
tool_node = ToolNode(tool_belt)

uncompiled_graph = StateGraph(AgentState)

uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)

# Every turn starts at the agent node.
uncompiled_graph.set_entry_point("agent")

def should_continue(state):
  """Route after the agent node.

  Returns "action" when the newest message carries tool calls (so the
  tool node runs next), otherwise END to finish the turn.
  """
  newest = state["messages"][-1]
  return "action" if newest.tool_calls else END

# After "agent", branch via should_continue: "action" (run tools) or END.
uncompiled_graph.add_conditional_edges(
    "agent",
    should_continue
)

# Tool output feeds back to the agent for another reasoning step.
uncompiled_graph.add_edge("action", "agent")

compiled_graph_w_fine_tune = uncompiled_graph.compile()

@cl.on_chat_start
async def on_chat_start():
    """Greet the user, then stash the compiled agent graph in the session."""
    welcome_text = """👋 Welcome to your Reformer Pilates AI!
Here's what you can do:
• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations
Let's get started! 🚀"""
    welcome = cl.Message(content=welcome_text)
    await welcome.send()
    cl.user_session.set("graph", compiled_graph_w_fine_tune)
    
@cl.on_message
async def handle(message: cl.Message):
    """Run the agent graph on the incoming message and send back its reply."""
    agent_graph = cl.user_session.get("graph")
    initial_state = {"messages" : [HumanMessage(message.content)]}
    result = await agent_graph.ainvoke(initial_state)
    final_text = result["messages"][-1].content
    await cl.Message(content=final_text, stream = True).send()