# DEMO_6_3 / app.py
# Author: bsmith3715 — "Update app.py" (commit 350005c, verified)
import os
import json
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
import chainlit as cl
from typing import AsyncGenerator
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
# === Load and prepare data ===
# Load pre-assembled documents (content + metadata) from disk.
with open("combined_data.json", "r") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# Chunk documents for embedding; the overlap preserves continuity across
# chunk boundaries. (RecursiveCharacterTextSplitter is already imported at
# the top of the file — the duplicate mid-file import was removed.)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100)
split_documents = text_splitter.split_documents(all_docs)

# Fine-tuned sentence-embedding model. NOTE(review): assumed to emit
# 768-dim vectors to match the collection config below — confirm.
embeddings = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# === Set up Qdrant vector store (in-memory) ===
# NOTE(review): the original comment said "FAISS", but the store is Qdrant.
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="reformer_docs",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)
vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer_docs",
    embedding=embeddings,
)
_ = vector_store_ft.add_documents(documents=split_documents)
retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})

# === Load LLM ===
# Fix: ChatOpenAI's streaming flag is `streaming=`, not `stream=` (which is
# not a constructor field and ends up in model_kwargs).
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)

from langchain.prompts import ChatPromptTemplate

# Prompt that restricts the model to the retrieved context only.
RAG_PROMPT = """\
You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

# RAG pipeline: retrieve -> prompt -> LLM -> string, passing the retrieved
# context through so callers can inspect the source documents.
finetune_rag_chain = (
    {"context": itemgetter("question") | retriever_finetune, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | llm | StrOutputParser(), "context": itemgetter("context")}
)
# === Chainlit start event ===
@cl.on_chat_start
async def start():
    """Greet the user and stash the pre-built RAG chain in the session."""
    welcome_text = """πŸ‘‹ Welcome to your Reformer Pilates AI!
Here's what you can do:
β€’ Ask questions about Reformer Pilates
β€’ Get individualized workouts based on your level, goals, and equipment
β€’ Get instant exercise modifications based on injuries or limitations
Let's get started! πŸš€"""
    greeting = cl.Message(content=welcome_text)
    await greeting.send()
    # Make the chain retrievable from the per-user session in on_message.
    cl.user_session.set("qa_chain", finetune_rag_chain)
@cl.on_message
async def main(message):
    """Answer a user message with retrieval-augmented, token-streamed output.

    Fix: the original invoked the full RAG chain — which performs a complete,
    non-streamed LLM generation — only to discard its "response" and call a
    second LLM for streaming. We now retrieve the supporting context directly
    and make a single, streamed LLM call.
    """
    # Top-k supporting documents for this question (same retriever the
    # chain uses, so the context is identical to the original's).
    context_docs = await retriever_finetune.ainvoke(message.content)

    # Blank message that tokens are streamed into.
    msg = cl.Message(content="")

    # NOTE(review): the original streamed with "gpt-4.1-mini" while the
    # pre-built chain uses "gpt-4o-mini" — kept as-is; confirm which model
    # is actually intended.
    streaming_llm = ChatOpenAI(model_name="gpt-4.1-mini", temperature=0, streaming=True)

    full_prompt = rag_prompt.format(question=message.content, context=context_docs)
    async for chunk in streaming_llm.astream(full_prompt):
        if chunk.content:  # skip empty deltas
            await msg.stream_token(chunk.content)
    await msg.send()