# NOTE: Hugging Face Spaces page residue ("Spaces: / Sleeping / Sleeping")
# removed -- it was not part of the program and broke the Python file.
import os
import json
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
import chainlit as cl
from typing import AsyncGenerator
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# === Load and prepare data ===
# Each entry in combined_data.json is expected to carry "content" and
# "metadata" keys -- TODO confirm against the data-generation script.
with open("combined_data.json", "r") as f:
    raw_data = json.load(f)

all_docs = [
    Document(page_content=entry["content"], metadata=entry["metadata"])
    for entry in raw_data
]

# Chunk the documents so each piece fits the embedding model comfortably.
# (The duplicate import of RecursiveCharacterTextSplitter from the
# deprecated `langchain.text_splitter` path was removed; the class is
# already imported from `langchain_text_splitters` above.)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100)
split_documents = text_splitter.split_documents(all_docs)

# Fine-tuned sentence-embedding model; the collection below is sized for
# its 768-dimensional vectors.
embeddings = HuggingFaceEmbeddings(model_name="bsmith3715/legal-ft-demo_final")

# === Set up Qdrant vector store (in-memory) ===
# NOTE: the original comment said "FAISS", but the store used here is Qdrant.
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="reformer_docs",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)
vector_store_ft = QdrantVectorStore(
    client=client,
    collection_name="reformer_docs",
    embedding=embeddings,
)
_ = vector_store_ft.add_documents(documents=split_documents)
retriever_finetune = vector_store_ft.as_retriever(search_kwargs={"k": 5})
# === Load LLM ===
# BUGFIX: ChatOpenAI's streaming flag is `streaming=`, not `stream=`;
# `stream=True` was not a valid constructor argument (the handler at the
# bottom of the file already uses `streaming=True` correctly).
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)

from langchain.prompts import ChatPromptTemplate

RAG_PROMPT = """\
You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

# LCEL chain: fan the question out to the retriever ("context") while
# passing the question through unchanged, then render the prompt, call the
# model, and parse to a string. The retrieved context is returned alongside
# the response so callers can inspect or reuse it.
# (The intermediate `RunnablePassthrough.assign(context=itemgetter("context"))`
# step was removed: it reassigned "context" to itself, a no-op.)
finetune_rag_chain = (
    {
        "context": itemgetter("question") | retriever_finetune,
        "question": itemgetter("question"),
    }
    | {
        "response": rag_prompt | llm | StrOutputParser(),
        "context": itemgetter("context"),
    }
)
# === Chainlit start event ===
# BUGFIX: the @cl.on_chat_start registration was missing; without it
# Chainlit never calls this handler and the session chain is never set.
@cl.on_chat_start
async def start():
    """Greet the user and stash the RAG chain in the user session."""
    # Welcome text reproduced byte-for-byte (including the mojibake
    # characters -- presumably garbled emoji/bullets; verify against the
    # original source before "fixing" them).
    await cl.Message(content=
        """π Welcome to your Reformer Pilates AI!
Here's what you can do:
β’ Ask questions about Reformer Pilates
β’ Get individualized workouts based on your level, goals, and equipment
β’ Get instant exercise modifications based on injuries or limitations
Let's get started! π""").send()
    # Stored under "qa_chain" so the message handler can look it up.
    cl.user_session.set("qa_chain", finetune_rag_chain)
# === Chainlit message handler ===
# BUGFIX: the @cl.on_message registration was missing; without it Chainlit
# never routes user messages to this handler.
@cl.on_message
async def main(message):
    """Answer a user message using retrieved context, streaming the reply."""
    # BUGFIX: the previous version ran the ENTIRE RAG chain (including a
    # full, non-streamed LLM call) only to discard the response and keep
    # the retrieved context -- paying for every answer twice. Retrieving
    # the documents directly gives the same context at half the LLM cost.
    docs = await retriever_finetune.ainvoke(message.content)

    # Blank message that streamed tokens are appended to.
    msg = cl.Message(content="")

    # NOTE(review): this streams from gpt-4.1-mini while the session chain
    # uses gpt-4o-mini -- confirm the model mismatch is intentional.
    llm = ChatOpenAI(model_name="gpt-4.1-mini", temperature=0, streaming=True)

    # ChatPromptTemplate.format renders the prompt as one string; the list
    # of Documents is stringified into the {context} slot, matching what
    # the chain previously returned as "context".
    full_prompt = rag_prompt.format(question=message.content, context=docs)
    async for chunk in llm.astream(full_prompt):
        await msg.stream_token(chunk.content)  # only stream the text part
    await msg.send()