# NOTE: "Spaces: Sleeping" below was a Hugging Face Spaces page header picked up
# during extraction — it is not part of the application source.
| import os | |
| from typing import TypedDict, Annotated, Union, List, Tuple | |
| from typing_extensions import List, TypedDict | |
| from dotenv import load_dotenv | |
| import chainlit as cl | |
| import nest_asyncio | |
| import getpass | |
| from uuid import uuid4 | |
| from langchain_community.document_loaders import DirectoryLoader | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_qdrant import QdrantVectorStore | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Distance, VectorParams | |
| from langchain.retrievers.contextual_compression import ContextualCompressionRetriever | |
| from langchain_cohere import CohereRerank | |
| from langchain.prompts import ChatPromptTemplate | |
| from langgraph.graph import START, StateGraph | |
| from typing_extensions import List, TypedDict | |
| from langchain_core.runnables import Runnable | |
| from langchain_core.tools import tool | |
| from langgraph.graph.message import add_messages | |
| import operator | |
| from langchain_core.messages import BaseMessage | |
| from langchain_core.documents import Document | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.messages import HumanMessage | |
| from ragas.embeddings import LangchainEmbeddingsWrapper | |
| from langchain_openai import ChatOpenAI | |
| from ragas.testset import TestsetGenerator | |
| from operator import itemgetter | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough, RunnableParallel | |
| from ragas import EvaluationDataset | |
| from ragas.llms import LangchainLLMWrapper | |
| from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, ResponseRelevancy, ContextEntityRecall, NoiseSensitivity | |
| from ragas import evaluate, RunConfig | |
| from langchain_community.document_loaders import BSHTMLLoader | |
| import tqdm | |
| import asyncio | |
| import json | |
| from sentence_transformers import SentenceTransformer | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data import Dataset | |
| from sentence_transformers import InputExample | |
| from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss | |
| from sentence_transformers.evaluation import InformationRetrievalEvaluator | |
| from huggingface_hub import notebook_login | |
| import pandas as pd | |
# ---------------------------------------------------------------------------
# Module-level setup: load the source corpus, build an in-memory Qdrant
# vector store over fine-tuned embeddings, and expose a dense retriever.
# Runs on import (Chainlit app style).
# ---------------------------------------------------------------------------

# Pull API keys (OpenAI / Tavily / Cohere / HF) from a local .env file.
load_dotenv()

# Corpus: pre-combined JSON list of {"content": str, "metadata": dict} entries
# produced by an earlier scraping/cleaning step.
# Explicit encoding added — relying on the platform default is fragile.
with open("./data_new/combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# Web-search tool, bound to the agent later for out-of-corpus questions.
tavily_tool = TavilySearchResults(max_results=5)

# Chunk the documents: 750-character chunks with 100-character overlap.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=750, chunk_overlap=100, length_function=len
)
split_documents = text_splitter.split_documents(all_docs)

# Fine-tuned sentence-embedding model; its output must be 768-dim to match
# the collection's vector size below — TODO confirm for this checkpoint.
embeddings = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

# In-memory Qdrant instance — the collection is rebuilt from scratch on
# every process start.
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="reformer",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)
vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer",
    embedding=embeddings,
)
_ = vector_store_ft.add_documents(documents=split_documents)

# Base dense retriever: top-5 nearest chunks by cosine similarity.
retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})
def retrieve_adjusted(state):
    """Retrieve chunks for state["question"], reranked with Cohere.

    Returns a partial graph-state update: {"context": [Document, ...]}.
    """
    # NOTE(review): the original also passed search_kwargs={"k": 5} to
    # ContextualCompressionRetriever, which is not a field of that class
    # (the k is already configured on retriever_finetune) — removed.
    # top_n=10 exceeds the 5 candidates the base retriever returns, so the
    # reranker effectively reorders all of them — confirm this is intended.
    compressor = CohereRerank(model="rerank-v3.5", top_n=10)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_finetune,
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}
# Prompt that constrains the answer strictly to the retrieved context
# (no model world-knowledge allowed).
RAG_PROMPT = """\
You are a helpful pilates expert who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""

# Template object consumed by generate() below.
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
# LLM used for grounded answer generation in the RAG pipeline.
ft_llm = ChatOpenAI(model="gpt-4o-mini")


def generate(state):
    """Answer state["question"] from the retrieved state["context"] docs.

    Returns a partial graph-state update: {"response": str}.
    """
    context_text = "\n\n".join(d.page_content for d in state["context"])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"], context=context_text
    )
    llm_reply = ft_llm.invoke(prompt_messages)
    return {"response": llm_reply.content}
class State(TypedDict):
    """Shared state flowing through the linear RAG graph."""

    question: str            # the user's query
    context: List[Document]  # chunks produced by retrieve_adjusted
    response: str            # final answer text produced by generate
# Linear RAG pipeline: START -> retrieve_adjusted -> generate.
graph_builder = StateGraph(State).add_sequence([retrieve_adjusted, generate])
graph_builder.add_edge(START, "retrieve_adjusted")
graph = graph_builder.compile()
# Agent LLM: deterministic (temperature=0) gpt-4o with the web-search tool
# bound so the model can emit tool calls.
tool_belt = [tavily_tool]
cert_model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tool_belt)
class AgentState(TypedDict):
    """State for the tool-using agent graph."""

    # add_messages appends new messages to the history instead of replacing it.
    messages: Annotated[list, add_messages]
    context: List[Document]  # carried through unchanged by call_model
def call_model(state):
    """Invoke the tool-bound LLM on the running message history.

    Returns the model's reply (appended via add_messages) and the existing
    context, defaulting to an empty list when none has been set.
    """
    reply = cert_model.invoke(state["messages"])
    return {"messages": [reply], "context": state.get("context", [])}
# Agent graph skeleton: an "agent" node (the LLM) and an "action" node that
# executes whatever tool calls the LLM requested. Entry point is the agent.
tool_node = ToolNode(tool_belt)
uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")
def should_continue(state):
    """Route to the tool node while the latest LLM message requests tools.

    Returns "action" when the last message carries tool calls, END otherwise.
    """
    latest = state["messages"][-1]
    return "action" if latest.tool_calls else END
# Agent loop: agent -> (action -> agent)* until the LLM stops requesting
# tool calls, at which point should_continue routes to END.
uncompiled_graph.add_conditional_edges("agent", should_continue)
uncompiled_graph.add_edge("action", "agent")
compiled_graph_w_fine_tune = uncompiled_graph.compile()
@cl.on_chat_start
async def on_chat_start():
    """Greet the user and stash the compiled agent graph in the session."""
    # NOTE(review): the extracted source had no @cl.on_chat_start decorator;
    # without it Chainlit never invokes this handler. Restored here — confirm
    # against the original app.
    await cl.Message(content=
"""👋 Welcome to your Reformer Pilates AI!
Here's what you can do:
• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations
Let's get started! 🚀""").send()
    # Per-session graph handle, read back by the message handler.
    cl.user_session.set("graph", compiled_graph_w_fine_tune)
@cl.on_message
async def handle(message: cl.Message):
    """Run the incoming user message through the agent graph and reply."""
    # NOTE(review): the extracted source had no @cl.on_message decorator;
    # without it Chainlit never routes messages here. Restored — confirm
    # against the original app.
    graph = cl.user_session.get("graph")
    state = {"messages": [HumanMessage(message.content)]}
    result = await graph.ainvoke(state)
    # The original passed stream=True to the Message constructor, but nothing
    # here streams tokens (no stream_token loop), so the flag was a no-op;
    # the complete answer is sent in a single message.
    await cl.Message(content=result["messages"][-1].content).send()