import os
from typing import TypedDict, Annotated, Union, List, Tuple
from typing_extensions import List, TypedDict
from dotenv import load_dotenv
import chainlit as cl
import nest_asyncio
import getpass
from uuid import uuid4
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from langchain_core.runnables import Runnable
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
import operator
from langchain_core.messages import BaseMessage
from langchain_core.documents import Document
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage
from ragas.embeddings import LangchainEmbeddingsWrapper
from langchain_openai import ChatOpenAI
from ragas.testset import TestsetGenerator
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from ragas import EvaluationDataset
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, ResponseRelevancy, ContextEntityRecall, NoiseSensitivity
from ragas import evaluate, RunConfig
from langchain_community.document_loaders import BSHTMLLoader
import tqdm
import asyncio
import json
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sentence_transformers import InputExample
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from huggingface_hub import notebook_login
import pandas as pd
# Load API keys (OpenAI, Cohere, Tavily, HF, ...) from the local .env file.
load_dotenv()

# Read the pre-scraped corpus; each entry carries "content" and "metadata".
with open("./data_new/combined_data.json", "r") as f:
    raw_data = json.load(f)

# Wrap raw entries as LangChain Documents for splitting/indexing.
all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data]

# Web-search tool used by the agent graph further below.
tavily_tool = TavilySearchResults(max_results=5)

# Chunk documents; the 100-char overlap preserves context across chunk borders.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100, length_function = len)
split_documents = text_splitter.split_documents(all_docs)

# Fine-tuned sentence-transformer embeddings — presumably 768-dim output,
# which must match the Qdrant collection's vector size below (TODO confirm).
embeddings = HuggingFaceEmbeddings(model_name = "bsmith3715/legal-ft-demo_final")

# Ephemeral in-memory Qdrant instance; the index is rebuilt on every start.
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="reformer",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)
vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer",
    embedding=embeddings,
)
_ = vector_store_ft.add_documents(documents=split_documents)

# Dense retriever over the fine-tuned index; returns the top-5 chunks per query.
retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})
def retrieve_adjusted(state):
    """Retrieve chunks for ``state["question"]`` and rerank them with Cohere.

    Graph node: returns a partial state update ``{"context": [Document, ...]}``.
    """
    # NOTE(review): top_n=10 exceeds the base retriever's k=5, so the reranker
    # can only reorder the (at most) 5 retrieved chunks — confirm intent.
    compressor = CohereRerank(model="rerank-v3.5", top_n=10)
    # Fix: the original also passed search_kwargs={"k": 5} here, but that is
    # not a field of ContextualCompressionRetriever and had no effect — the
    # retrieval k is already configured on retriever_finetune at module level.
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever_finetune
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}
# Prompt for the RAG chain: restricts the model to the retrieved context
# only (no outside knowledge).
RAG_PROMPT = """\
You are a helpful pilates expert who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
# Answer-generation model for the RAG branch.
ft_llm = ChatOpenAI(model="gpt-4o-mini")
def generate(state):
    """Produce an answer to ``state["question"]`` from the retrieved context.

    Graph node: returns a partial state update ``{"response": str}``.
    """
    # Concatenate the reranked chunks into a single context string.
    context_text = "\n\n".join(d.page_content for d in state["context"])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"], context=context_text
    )
    answer = ft_llm.invoke(prompt_messages)
    return {"response": answer.content}
class State(TypedDict):
    # State flowing through the linear RAG graph (retrieve_adjusted -> generate).
    question: str             # user query
    context: List[Document]   # reranked chunks written by retrieve_adjusted
    response: str             # final answer text written by generate
# Linear RAG pipeline: START -> retrieve_adjusted -> generate.
graph_builder = StateGraph(State).add_sequence([retrieve_adjusted, generate])
graph_builder.add_edge(START, "retrieve_adjusted")
graph = graph_builder.compile()
# Agent model: deterministic (temperature=0) gpt-4o with web-search tooling.
cert_model = ChatOpenAI(
    model="gpt-4o",
    temperature=0
)
# Tools exposed to the agent; executed by the ToolNode defined below.
tool_belt = [
    tavily_tool,
]
# Rebind so the model can emit tool_calls for the tools above.
cert_model = cert_model.bind_tools(tool_belt)
class AgentState(TypedDict):
    # Conversation state for the tool-using agent graph.
    messages: Annotated[list, add_messages]  # add_messages reducer appends rather than overwrites
    context: List[Document]                  # passed through unchanged by call_model
def call_model(state):
    """Invoke the tool-bound chat model on the running message history.

    Graph node: appends the model reply to ``messages`` (via the
    ``add_messages`` reducer) and forwards ``context`` untouched.
    """
    reply = cert_model.invoke(state["messages"])
    return {
        "messages": [reply],
        # Default to an empty list when no context has been set yet.
        "context": state.get("context", []),
    }
# Executes any tool_calls emitted by the agent and returns tool messages.
tool_node = ToolNode(tool_belt)
uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")
def should_continue(state):
    """Route after an agent turn: run tools when the newest message
    requests tool calls, otherwise end the graph."""
    newest = state["messages"][-1]
    return "action" if newest.tool_calls else END
# After each agent turn, either run tools ("action") or finish (END),
# as decided by should_continue.
uncompiled_graph.add_conditional_edges(
    "agent",
    should_continue
)
# Tool results are fed back to the agent for another turn.
uncompiled_graph.add_edge("action", "agent")
compiled_graph_w_fine_tune = uncompiled_graph.compile()
@cl.on_chat_start
async def on_chat_start():
    """Greet the user and stash the compiled agent graph in the Chainlit session."""
    await cl.Message(content=
"""👋 Welcome to your Reformer Pilates AI!
Here's what you can do:
• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations
Let's get started! 🚀""").send()
    # Stored per-session so handle() retrieves the same compiled graph.
    cl.user_session.set("graph", compiled_graph_w_fine_tune)
@cl.on_message
async def handle(message: cl.Message):
    """Run the incoming user message through the agent graph and send the reply."""
    graph = cl.user_session.get("graph")
    # Fresh state per message: no prior conversation history is carried over,
    # so each question is answered independently.
    state = {"messages" : [HumanMessage(message.content)]}
    response = await graph.ainvoke(state)
    # NOTE(review): the content is already complete here, so `stream = True`
    # looks ineffective (Chainlit streams via msg.stream_token) — confirm
    # against the installed Chainlit version.
    await cl.Message(content=response["messages"][-1].content, stream = True).send()
|