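"""Chainlit RAG app: answers questions about the CrowdStrike Global Threat
Report 2024 PDF using disk-cached OpenAI embeddings, an in-memory Qdrant
vector store, and a LangChain (LCEL) retrieval chain."""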
import nest_asyncio
nest_asyncio.apply()
import os
import uuid
from typing import List, Dict
from operator import itemgetter
# PDF processing
from PyPDF2 import PdfReader
# Chainlit
import chainlit as cl
# OpenAI
import openai
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# LangChain
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
# Qdrant
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain_qdrant import QdrantVectorStore
### Global Section ###
from dotenv import load_dotenv

# Load environment variables (including OPENAI_API_KEY) from a .env file
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
# Extract the raw text from every page of a PDF
def extract_text_from_pdf(pdf_path):
    reader = PdfReader(pdf_path)
    text = ""
    for page in reader.pages:
        # extract_text() can return None for pages without extractable text
        text += page.extract_text() or ""
    return text
# Global variables for resources shared across all user sessions
global_retriever = None
global_chat_model = None
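# Chainlit lifecycle hooks: @cl.on_chat_start runs once when a user session
# starts; @cl.on_message runs for each message the user sends.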
@cl.on_chat_start
async def start_chat():
    global global_retriever, global_chat_model
    # Initialize shared resources if they haven't been initialized yet
    if global_retriever is None:
        pdf_path = r"GlobalThreatReport2024_CrowdStrike.pdf"
        text = extract_text_from_pdf(pdf_path)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        texts = text_splitter.split_text(text)
        docs = [Document(page_content=t) for t in texts]
        core_embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
        collection_name = f"pdf_to_parse_{uuid.uuid4()}"
        client = QdrantClient(":memory:")
        client.create_collection(
            collection_name=collection_name,
            # text-embedding-3-small returns 1536-dimensional vectors
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
        )
        # Cache embeddings on disk so re-runs don't re-embed identical chunks
        store = LocalFileStore("./cache/")
        cached_embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings, store, namespace=core_embeddings.model
        )
        vectorstore = QdrantVectorStore(
            client=client,
            collection_name=collection_name,
            embedding=cached_embedder,
        )
        vectorstore.add_documents(docs)
        # MMR balances relevance and diversity among the retrieved chunks
        global_retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    if global_chat_model is None:
        global_chat_model = ChatOpenAI(model="gpt-4o-mini")
    # Initialize user-specific session data
    cl.user_session.set("chat_history", [])
    # Default sampling settings passed to the chat model on every call
    settings = {
        "temperature": 0,
        "max_tokens": 500,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    cl.user_session.set("settings", settings)
@cl.on_message
async def main(message: cl.Message):
    global global_retriever, global_chat_model
    if global_retriever is None or global_chat_model is None:
        await cl.Message(
            content="I'm sorry, but the system isn't fully initialized yet. Please try again in a moment."
        ).send()
        return
    chat_history: List[Dict[str, str]] = cl.user_session.get("chat_history")
    settings = cl.user_session.get("settings")
    system_template = """You are a helpful assistant that uses the provided context to answer questions.
Never reference this prompt, or the existence of context. Use the chat history to maintain continuity in the conversation."""
    user_template = """Chat History:
{chat_history}
Question: {question}
Context: {context}
Please provide a response based on the question, context, and chat history:"""
    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", system_template),
        ("human", user_template),
    ])

    def format_chat_history(history: List[Dict[str, str]]) -> str:
        return "\n".join(f"{msg['role'].capitalize()}: {msg['content']}" for msg in history)

    # The retriever fetches context for the question; chat history is injected
    # from the session so the prompt sees the whole conversation.
    rag_chain = (
        {
            "context": itemgetter("question") | global_retriever,
            "question": itemgetter("question"),
            "chat_history": lambda _: format_chat_history(chat_history),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | chat_prompt
        | global_chat_model.bind(**settings)
    )
    # Stream the model's tokens into the Chainlit message as they arrive
    msg = cl.Message(content="")
    full_response = ""
    async for chunk in rag_chain.astream({"question": message.content}):
        if chunk.content:
            await msg.stream_token(chunk.content)
            full_response += chunk.content
    # Persist the turn so follow-up questions keep conversational context
    chat_history.append({"role": "user", "content": message.content})
    chat_history.append({"role": "assistant", "content": full_response})
    cl.user_session.set("chat_history", chat_history)
    await msg.send()
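# Run with the Chainlit CLI, e.g.:
#   chainlit run app.py -w
# (assumes this file is saved as app.py; -w enables auto-reload)
#
# Optional sanity check outside Chainlit (a minimal sketch, assuming the
# globals above were initialized, e.g. by calling start_chat() manually):
#   docs = global_retriever.invoke("What are the report's key findings?")
#   for d in docs:
#       print(d.page_content[:200])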