File size: 1,984 Bytes
5f45c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
feab02f
5f45c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
import logging
import os

import chainlit as cl
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

# === Load and prepare data ===
# JSON text is UTF-8; pin the encoding so loading does not depend on the
# platform's locale-dependent default for open().
with open("combined_data.json", encoding="utf-8") as f:
    raw_data = json.load(f)

# Each entry must provide "content" (the document text). "metadata" is
# optional: a missing or null value becomes an empty dict instead of
# aborting startup.
all_docs = [
    Document(page_content=entry["content"], metadata=entry.get("metadata") or {})
    for entry in raw_data
]

# === Chunk the documents for retrieval ===
# Sizes are measured with the splitter's default length function (len):
# chunks of up to 800 characters, with 50 characters shared between neighbours.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=50,
)
chunked_docs = splitter.split_documents(all_docs)

# === Embeddings: fine-tuned Hugging Face model ===
embedding_model = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

# === Vector index over the chunks; the retriever returns the top 5 per query ===
vectorstore = FAISS.from_documents(
    documents=chunked_docs,
    embedding=embedding_model,
)
retriever = vectorstore.as_retriever(search_kwargs=dict(k=5))

# === Deterministic (temperature 0) chat model wired into a RetrievalQA chain ===
llm = ChatOpenAI(
    model_name="gpt-4o",
    temperature=0,
    streaming=True,
)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
)

# === Chainlit start event ===
@cl.on_chat_start
async def start():
    """Send the welcome message, then store the QA chain in the user session.

    The chain is saved under the ``"qa_chain"`` session key so the message
    handler can retrieve it.
    """
    welcome_text = """👋 Welcome to your Reformer Pilates AI!

Here’s what you can do:
• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations

Let’s get started! 🚀"""
    greeting = cl.Message(content=welcome_text)
    await greeting.send()
    cl.user_session.set("qa_chain", qa_chain)

# === Chainlit message handler ===
@cl.on_message
async def handle_message(message: cl.Message):
    """Answer one user message with the RetrievalQA chain from the session.

    The chain is read from the ``"qa_chain"`` session key. If no chain is
    stored, the message is ignored. An exception raised while running the
    chain is logged with its traceback and reported to the user as an error
    reply instead of propagating.
    """
    chain = cl.user_session.get("qa_chain")
    if chain is None:
        return

    try:
        # ainvoke runs retrieval and the LLM call without blocking the event
        # loop; a synchronous call here would stall every other chat session
        # for the whole round trip. RetrievalQA takes "query", returns "result".
        output = await chain.ainvoke({"query": message.content})
        response = output["result"]
    except Exception as e:
        # Top-level handler boundary: keep the traceback in the server log,
        # but still give the user a reply.
        logging.getLogger(__name__).exception("QA chain failed for user message")
        response = f"⚠️ Error: {e}"
    await cl.Message(content=response).send()