File size: 4,824 Bytes
da82a70
 
5ea8ce3
da82a70
 
 
 
 
 
 
 
31f0801
da82a70
 
 
 
 
be864aa
31f0801
be864aa
da82a70
 
 
 
 
 
4972280
be864aa
ce04382
 
 
 
 
 
 
 
da82a70
31f0801
 
da82a70
 
31f0801
da82a70
31f0801
da82a70
31f0801
da82a70
31f0801
da82a70
 
 
 
 
 
 
 
 
 
 
31f0801
 
 
 
da82a70
 
 
 
 
 
 
31f0801
 
 
 
 
 
da82a70
 
 
 
 
 
 
 
 
 
 
 
 
31f0801
da82a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be864aa
 
 
 
 
 
dc601f5
be864aa
dc601f5
be864aa
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import json
from typing import AsyncGenerator
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import chainlit as cl
#from rag_processor import RAGProcessor
from openai import OpenAI

# Initialize OpenAI client (reads OPENAI_API_KEY from the environment)
client = OpenAI()

# === Load and prepare data ===
# Encoding is pinned to UTF-8 so the load does not depend on the platform
# default codec (e.g. cp1252 on Windows would fail on non-ASCII content).
with open("combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Wrap each raw entry in a LangChain Document, preserving its metadata.
# Assumes every entry carries "content" and "metadata" keys; a malformed
# entry fails here (KeyError) rather than deeper in the pipeline.
all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# === Split documents into chunks ===
# 800-char chunks with 50-char overlap keeps retrieval granular while
# preserving sentence continuity across chunk boundaries.
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=50)
chunked_docs = splitter.split_documents(all_docs)

# === Use your fine-tuned Hugging Face embeddings ===
embedding_model = HuggingFaceEmbeddings(
    model_name="bsmith3715/legal-ft-demo_final"
)

# === Set up FAISS vector store ===
# Built in-memory at import time; rebuilt on every process start.
vectorstore = FAISS.from_documents(chunked_docs, embedding_model)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# === Define prompt templates ===
# Prompt for the RAG chain: {context} receives the retrieved chunks,
# {question} receives the user's message.
RAG_PROMPT_TEMPLATE = """You are a helpful AI assistant specializing in reformer pilates. Use the following context to answer the user's question, provide a workout with the level of difficulty, length and focus provided, or a step by step description of the exercise provided. If you don't know the answer, just say that you don't know.

Context: {context}

Question: {question}

Answer:"""

# Prompt for DALL-E image generation: {query} receives the raw user
# request.  (Typos fixed: "seperate" -> "separate", "and or" -> "and/or".)
IMAGE_PROMPT_TEMPLATE = """Create a detailed and professional image that represents the following reformer pilates exercise: {query}

The image should be:
- Professional and appropriate for a reformer pilates context
- Clear and easy to understand
- Visually appealing
- Suitable for use in professional settings and/or presentations
- Provide a separate visual for each step in the exercise with numbering of steps"""

# === Create prompt templates ===
rag_prompt = PromptTemplate(
    template=RAG_PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)

image_prompt = PromptTemplate(
    template=IMAGE_PROMPT_TEMPLATE,
    input_variables=["query"]
)

# === Load LLM ===
# BUG FIX: ChatOpenAI's keyword is `streaming`, not `stream` — the
# original `stream=True` was not a recognized constructor argument, so
# token streaming was never actually enabled on the model.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)

# Retrieval-augmented QA chain: retrieves top-k chunks via `retriever`
# and answers with `rag_prompt` filled from context + question.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": rag_prompt}
)

# === Chainlit start event ===
@cl.on_chat_start
async def start():
    """Greet the user and stash the shared QA chain in the session."""
    welcome_text = (
        "👋 Welcome to your Reformer Pilates AI!\n"
        "\n"
        "Here's what you can do:\n"
        "• Ask questions about Reformer Pilates\n"
        "• Get individualized workouts based on your level, goals, and equipment\n"
        "• Get instant exercise modifications based on injuries or limitations\n"
        "\n"
        "Let's get started! 🚀"
    )
    greeting = cl.Message(content=welcome_text)
    await greeting.send()
    # Store the module-level chain so the message handler can fetch it
    # per-session instead of reaching for the global.
    cl.user_session.set("qa_chain", qa_chain)

# === Chainlit message handler ===
@cl.on_message
async def handle_message(message: cl.Message):
    """Route each incoming message.

    Messages starting with "create an image" go to DALL-E; any other
    non-empty message goes to the RAG QA chain; empty messages get a
    gentle prompt to type something.
    """
    # --- Image-generation branch ---
    if message.content.lower().startswith("create an image"):
        # BUG FIX: message said "legal visualization" — a leftover from a
        # legal-domain app; this assistant is about reformer pilates.
        msg = cl.Message(content="🎨 Creating your Pilates visualization...")
        await msg.send()

        try:
            # Wrap the raw request in the image prompt template.
            formatted_prompt = image_prompt.format(query=message.content)

            # Generate one standard-quality 1024x1024 image with DALL-E 3.
            response = client.images.generate(
                model="dall-e-3",
                prompt=formatted_prompt,
                size="1024x1024",
                quality="standard",
                n=1,
            )

            image_url = response.data[0].url

            await cl.Message(
                content="Here's your generated image:",
                elements=[cl.Image(url=image_url, name="generated_image")]
            ).send()
        except Exception as e:
            await cl.Message(content=f"⚠️ Error generating image: {str(e)}").send()
        # BUG FIX: the original fell through to the RAG branch after
        # handling the image request, producing a second, unrelated
        # streamed answer for every "create an image..." message.
        return

    # --- RAG question-answering branch ---
    if message.content:
        try:
            # Placeholder message that tokens are streamed into.
            msg = cl.Message(content="")
            await msg.send()

            qa_chain = cl.user_session.get("qa_chain")
            # BUG FIX: RetrievalQA.astream yields dict chunks (e.g.
            # {"result": ...}), not raw tokens — streaming the chunk
            # directly would emit a dict repr. Extract the text payload.
            async for chunk in qa_chain.astream(message.content):
                if isinstance(chunk, dict):
                    token = chunk.get("result", "")
                else:
                    token = str(chunk)
                if token:
                    await msg.stream_token(token)
            # Finalize the streamed message in the UI.
            await msg.update()
        except Exception as e:
            await cl.Message(content=f"Error processing your message: {e}").send()
        return

    await cl.Message(content="Please send a message.").send()