Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from typing import AsyncGenerator | |
| from langchain_core.documents import Document | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_openai import ChatOpenAI | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| import chainlit as cl | |
| #from rag_processor import RAGProcessor | |
| from openai import OpenAI | |
# Initialize OpenAI client (reads OPENAI_API_KEY from the environment).
client = OpenAI()

# === Load and prepare data ===
# combined_data.json is expected to be a list of objects, each with a
# "content" string and a "metadata" dict — TODO confirm against the data file.
# Explicit encoding avoids platform-dependent default-codec surprises.
with open("combined_data.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# === Split documents into chunks ===
# 800-char chunks with 50-char overlap keep retrieval granular while
# preserving some cross-chunk context.
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=50)
chunked_docs = splitter.split_documents(all_docs)

# === Use your fine-tuned Hugging Face embeddings ===
embedding_model = HuggingFaceEmbeddings(
    model_name="bsmith3715/legal-ft-demo_final"
)

# === Set up FAISS vector store ===
vectorstore = FAISS.from_documents(chunked_docs, embedding_model)
# Retrieve the 5 nearest chunks per query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
# === Define prompt templates ===
# Grounding prompt for the RetrievalQA chain: answer only from retrieved
# context, or admit ignorance.
RAG_PROMPT_TEMPLATE = """You are a helpful AI assistant specializing in reformer pilates. Use the following context to answer the user's question, provide a workout with the level of difficulty, length and focus provided, or a step by step description of the exercise provided. If you don't know the answer, just say that you don't know.
Context: {context}
Question: {question}
Answer:"""

# Prompt sent to DALL-E 3 when the user asks for an image.
# (Fixed typo: "seperate" -> "separate".)
IMAGE_PROMPT_TEMPLATE = """Create a detailed and professional image that represents the following reformer pilates exercise: {query}
The image should be:
- Professional and appropriate for a reformer pilates context
- Clear and easy to understand
- Visually appealing
- Suitable for use in professional settings and or presentations
- Provide a separate visual for each step in the exercise with numbering of steps"""

# === Create prompt templates ===
rag_prompt = PromptTemplate(
    template=RAG_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)

image_prompt = PromptTemplate(
    template=IMAGE_PROMPT_TEMPLATE,
    input_variables=["query"],
)
# === Load LLM ===
# Fix: ChatOpenAI has no `stream` keyword — token streaming is enabled with
# `streaming=True`; the previous `stream=True` was passed through as an
# unrecognized model kwarg. `model` replaces the deprecated `model_name` alias.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, streaming=True)

# RetrievalQA wires the retriever and the grounding prompt around the LLM.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": rag_prompt},
)
# === Chainlit start event ===
# Fix: without the @cl.on_chat_start decorator Chainlit never invoked this
# handler, so the welcome message was never shown and "qa_chain" was never
# stored in the user session for handle_message to read.
# NOTE(review): the glyphs in the welcome text look mojibake'd (likely a
# mis-decoded emoji and bullet points) — confirm the intended characters.
@cl.on_chat_start
async def start():
    """Greet the user and stash the shared QA chain in the session."""
    await cl.Message(content=
        """π Welcome to your Reformer Pilates AI!
Here's what you can do:
β’ Ask questions about Reformer Pilates
β’ Get individualized workouts based on your level, goals, and equipment
β’ Get instant exercise modifications based on injuries or limitations
Let's get started! π""").send()
    cl.user_session.set("qa_chain", qa_chain)
# === Chainlit message handler ===
# Fix: the @cl.on_message decorator was missing, so Chainlit never routed
# incoming messages here.
@cl.on_message
async def handle_message(message: cl.Message):
    """Route an incoming chat message.

    Messages starting with "create an image" are sent to DALL-E 3; any
    other non-empty message is answered by the RetrievalQA chain with the
    result streamed into the UI; empty messages get a short nudge.
    """
    # --- Image generation branch ---
    if message.content.lower().startswith("create an image"):
        # Loading indicator while DALL-E works.
        # (Fixed leftover copy from a legal-domain template: this app is
        # about pilates, not "legal visualization".)
        msg = cl.Message(content="π¨ Creating your visualization...")
        await msg.send()
        try:
            formatted_prompt = image_prompt.format(query=message.content)
            # Generate one standard-quality 1024x1024 image with DALL-E 3.
            response = client.images.generate(
                model="dall-e-3",
                prompt=formatted_prompt,
                size="1024x1024",
                quality="standard",
                n=1,
            )
            image_url = response.data[0].url
            await cl.Message(
                content="Here's your generated image:",
                elements=[cl.Image(url=image_url, name="generated_image")],
            ).send()
        except Exception as e:
            await cl.Message(content=f"β οΈ Error generating image: {str(e)}").send()
        # Fix: without this return, an image request fell through and was
        # ALSO fed to the QA chain below, producing a second unrelated answer.
        return

    # --- RAG question-answering branch ---
    if message.content:
        try:
            # Empty placeholder message that tokens are streamed into.
            msg = cl.Message(content="")
            await msg.send()
            qa_chain = cl.user_session.get("qa_chain")
            # Fix: RetrievalQA.astream yields dict chunks (with a "result"
            # key), not raw token strings — stream_token must receive str.
            async for chunk in qa_chain.astream(message.content):
                if isinstance(chunk, dict):
                    token = chunk.get("result", "")
                else:
                    token = str(chunk)
                if token:
                    await msg.stream_token(token)
            # Finalize the streamed message in the UI.
            await msg.update()
        except Exception as e:
            await cl.Message(content=f"Error processing your message: {e}").send()
        return

    await cl.Message(content="Please send a message.").send()