# REFORMER_AI / app.py
# bsmith3715 — "Update app.py" (commit dc601f5, verified)
import os
import json
from typing import AsyncGenerator
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import chainlit as cl
#from rag_processor import RAGProcessor
from openai import OpenAI
# Module-wide OpenAI SDK client; the image-generation path of the chat
# handler calls client.images.generate on it.
client = OpenAI()
# === Load and prepare data ===
# combined_data.json is read as a JSON array whose entries each provide a
# "content" string and a "metadata" mapping (the only keys used here).
# JSON is UTF-8 by spec, so the encoding is pinned instead of relying on the
# platform's locale default (e.g. cp1252 on Windows would mis-decode it).
with open("combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Wrap every entry in a LangChain Document so it can be split and embedded.
all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]
# === Chunk, embed and index the documents ===
# Chunks of up to 800 units with 50 units of overlap between neighbours
# (units are characters under the splitter's default length function).
splitter = RecursiveCharacterTextSplitter(chunk_overlap=50, chunk_size=800)
chunked_docs = splitter.split_documents(all_docs)

# Fine-tuned embedding model, loaded by its Hugging Face model id.
# NOTE(review): the id says "legal" — confirm it is the intended model for
# reformer pilates content.
embedding_model = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

# Build the FAISS index over every chunk at import time; the retriever
# returns the 5 most similar chunks for each query.
vectorstore = FAISS.from_documents(chunked_docs, embedding_model)
retriever = vectorstore.as_retriever(search_kwargs=dict(k=5))
# === Define prompt templates ===
# Prompt for the RetrievalQA chain: {context} receives the retrieved chunks,
# {question} the user's message.
RAG_PROMPT_TEMPLATE = """You are a helpful AI assistant specializing in reformer pilates. Use the following context to answer the user's question, provide a workout with the level of difficulty, length and focus provided, or a step by step description of the exercise provided. If you don't know the answer, just say that you don't know.
Context: {context}
Question: {question}
Answer:"""

# Prompt sent to DALL-E: {query} receives the user's full "create an image..."
# message.
IMAGE_PROMPT_TEMPLATE = """Create a detailed and professional image that represents the following reformer pilates exercise: {query}
The image should be:
- Professional and appropriate for a reformer pilates context
- Clear and easy to understand
- Visually appealing
- Suitable for use in professional settings and/or presentations
- Provide a separate visual for each step in the exercise with numbering of steps"""

# === Create prompt templates ===
rag_prompt = PromptTemplate(
    template=RAG_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)
image_prompt = PromptTemplate(
    template=IMAGE_PROMPT_TEMPLATE,
    input_variables=["query"],
)
# === Load LLM ===
# temperature=0 keeps answers as repeatable as the API allows.
# `streaming` is ChatOpenAI's declared flag for token streaming; an unknown
# keyword such as `stream` would instead be forwarded as a raw model kwarg.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)

# RetrievalQA fetches chunks with `retriever` and fills rag_prompt's
# {context} / {question} slots before calling the LLM.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": rag_prompt},
)
# === Chainlit start event ===
@cl.on_chat_start
async def start():
    """Send the welcome message and store the shared QA chain in the session.

    The chain is saved under the "qa_chain" key so the message handler can
    retrieve it with cl.user_session.get("qa_chain").
    """
    # Emoji/bullets are written as real Unicode; the previous text contained
    # mis-decoded UTF-8 (mojibake) that rendered as garbage in the UI.
    welcome = "\n".join(
        [
            "👋 Welcome to your Reformer Pilates AI!",
            "Here's what you can do:",
            "• Ask questions about Reformer Pilates",
            "• Get individualized workouts based on your level, goals, and equipment",
            "• Get instant exercise modifications based on injuries or limitations",
            "Let's get started! 🚀",
        ]
    )
    await cl.Message(content=welcome).send()
    cl.user_session.set("qa_chain", qa_chain)
# === Chainlit message handler ===
# Case-insensitive prefix that routes a message to image generation.
_IMAGE_TRIGGER = "create an image"


async def _generate_image(text: str) -> None:
    """Generate a DALL-E 3 image for *text* and post it to the chat.

    Any failure is reported back to the user as a chat message rather than
    propagated, so one bad request does not break the session.
    """
    status = cl.Message(content="🎨 Creating your reformer pilates visualization...")
    await status.send()
    try:
        formatted_prompt = image_prompt.format(query=text)
        # images.generate is a blocking HTTP call; run it in a worker thread
        # so the Chainlit event loop keeps serving other sessions meanwhile.
        response = await cl.make_async(client.images.generate)(
            model="dall-e-3",
            prompt=formatted_prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
    except Exception as e:
        await cl.Message(content=f"⚠️ Error generating image: {e}").send()
        return
    await cl.Message(
        content="Here's your generated image:",
        elements=[cl.Image(url=image_url, name="generated_image")],
    ).send()


async def _answer_question(text: str) -> None:
    """Answer *text* with the RetrievalQA chain, streaming into one message.

    Errors are reported to the user as a separate chat message.
    """
    reply = cl.Message(content="")
    await reply.send()
    # Fall back to the module-level chain if the session was never seeded.
    chain = cl.user_session.get("qa_chain") or qa_chain
    try:
        async for chunk in chain.astream(text):
            # RetrievalQA's astream yields its output dict (answer under
            # "result") rather than bare string tokens; stream only the text.
            token = chunk.get("result", "") if isinstance(chunk, dict) else str(chunk)
            if token:
                await reply.stream_token(token)
    except Exception as e:
        await cl.Message(content=f"Error processing your message: {e}").send()
        return
    # Finalise the streamed message so the UI stops treating it as in progress.
    await reply.update()


@cl.on_message
async def handle_message(message: cl.Message):
    """Route an incoming chat message.

    Messages starting with "create an image" (case-insensitive) go to
    DALL-E 3; any other non-empty message is answered by the QA chain; an
    empty message gets a short prompt asking the user to type something.
    """
    text = message.content or ""
    if text.lower().startswith(_IMAGE_TRIGGER):
        await _generate_image(text)
        # Stop here so an image request is not also answered by the QA chain.
        return
    if text:
        await _answer_question(text)
        return
    await cl.Message(content="Please send a message.").send()