# NOTE(review): the lines below were Hugging Face Space page-scrape artifacts
# (status badges "Spaces: Sleeping", "File size: 4,824 Bytes", a row of git
# blame hashes, and the editor's line-number gutter). They are not Python and
# made the file unimportable; they are preserved here as a comment.
#   Spaces: Sleeping | File size: 4,824 Bytes
#   git blame hashes: da82a70 5ea8ce3 31f0801 be864aa 4972280 ce04382 dc601f5
import os
import json
from typing import AsyncGenerator
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import chainlit as cl
#from rag_processor import RAGProcessor
from openai import OpenAI
# === OpenAI client (used directly for DALL-E image generation) ===
client = OpenAI()

# === Load and prepare data ===
# combined_data.json is expected to be a list of {"content": str, "metadata": dict}
# entries — TODO confirm against the data-preparation script that produces it.
with open("combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# === Split documents into chunks ===
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=50)
chunked_docs = splitter.split_documents(all_docs)

# === Fine-tuned Hugging Face embeddings ===
embedding_model = HuggingFaceEmbeddings(
    model_name="bsmith3715/legal-ft-demo_final"
)

# === FAISS vector store: retrieve the top-5 most similar chunks per query ===
vectorstore = FAISS.from_documents(chunked_docs, embedding_model)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# === Prompt templates ===
RAG_PROMPT_TEMPLATE = """You are a helpful AI assistant specializing in reformer pilates. Use the following context to answer the user's question, provide a workout with the level of difficulty, length and focus provided, or a step by step description of the exercise provided. If you don't know the answer, just say that you don't know.
Context: {context}
Question: {question}
Answer:"""

# Note: "seperate" typo corrected to "separate" so DALL-E receives a clean instruction.
IMAGE_PROMPT_TEMPLATE = """Create a detailed and professional image that represents the following reformer pilates exercise: {query}
The image should be:
- Professional and appropriate for a reformer pilates context
- Clear and easy to understand
- Visually appealing
- Suitable for use in professional settings and or presentations
- Provide a separate visual for each step in the exercise with numbering of steps"""

rag_prompt = PromptTemplate(
    template=RAG_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)
image_prompt = PromptTemplate(
    template=IMAGE_PROMPT_TEMPLATE,
    input_variables=["query"],
)

# === LLM + retrieval chain ===
# BUG FIX: the ChatOpenAI kwarg is `streaming`, not `stream` — the original
# `stream=True` is not a recognized constructor parameter, so token streaming
# was never actually enabled.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": rag_prompt},
)
# === Chainlit start event ===
@cl.on_chat_start
async def start():
    """Greet the user and stash the shared QA chain in the session."""
    welcome_text = """👋 Welcome to your Reformer Pilates AI!
Here's what you can do:
• Ask questions about Reformer Pilates
• Get individualized workouts based on your level, goals, and equipment
• Get instant exercise modifications based on injuries or limitations
Let's get started! 🚀"""
    await cl.Message(content=welcome_text).send()
    cl.user_session.set("qa_chain", qa_chain)
# === Chainlit message handler ===
@cl.on_message
async def handle_message(message: cl.Message):
    """Route an incoming chat message.

    Messages starting with "create an image" are sent to DALL-E 3; any
    other non-empty message is answered by the RAG QA chain with token
    streaming. Empty messages get a prompt to type something.
    """
    # --- Image-generation branch ---
    if message.content.lower().startswith("create an image"):
        # BUG FIX: "legal visualization" was leftover copy from a
        # legal-domain app; this assistant is about reformer pilates.
        msg = cl.Message(content="🎨 Creating your Pilates visualization...")
        await msg.send()
        try:
            formatted_prompt = image_prompt.format(query=message.content)
            response = client.images.generate(
                model="dall-e-3",
                prompt=formatted_prompt,
                size="1024x1024",
                quality="standard",
                n=1,
            )
            image_url = response.data[0].url
            await cl.Message(
                content="Here's your generated image:",
                elements=[cl.Image(url=image_url, name="generated_image")],
            ).send()
        except Exception as e:
            await cl.Message(content=f"⚠️ Error generating image: {str(e)}").send()
        # BUG FIX: without this return the handler fell through and ALSO
        # ran the QA chain on the image request, sending a duplicate
        # text answer after every image.
        return

    # --- RAG question-answering branch ---
    if message.content:
        try:
            msg = cl.Message(content="")
            await msg.send()
            qa_chain = cl.user_session.get("qa_chain")
            # BUG FIX: RetrievalQA.astream yields dict chunks (e.g.
            # {"result": ...}), not raw token strings — extract the text
            # before streaming so the UI doesn't render dict reprs.
            async for chunk in qa_chain.astream(message.content):
                if isinstance(chunk, dict):
                    token = chunk.get("result", "")
                else:
                    token = str(chunk)
                if token:
                    await msg.stream_token(token)
            # Finalize the streamed message so the client renders it as complete.
            await msg.update()
        except Exception as e:
            await cl.Message(content=f"Error processing your message: {e}").send()
        return

    await cl.Message(content="Please send a message.").send()