File size: 2,544 Bytes
2d79a6a
 
 
 
 
 
9998a26
2d79a6a
 
b46c6e1
2d79a6a
 
 
 
9998a26
 
2d79a6a
9998a26
2d79a6a
 
 
9998a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d79a6a
 
 
 
 
 
 
 
 
 
 
 
9998a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d79a6a
 
 
9998a26
2d79a6a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import pipeline
import os

# Page setup — renders the page header at the top of the Streamlit app
st.title("Ai-Buddy Chatbot")

# Load and process PDF
@st.cache_resource
def initialize_system(pdf_path="ai_buddy.pdf",
                      persist_directory="chroma_db",
                      chunk_size=750,
                      chunk_overlap=150):
    """Build (or reload) the Chroma vector store and the extractive QA model.

    On first run the PDF at *pdf_path* is loaded, split into overlapping
    chunks, embedded, and persisted to *persist_directory*.  On subsequent
    runs the persisted store is reloaded instead of re-indexing.

    Args:
        pdf_path: Path of the source PDF to index (default: "ai_buddy.pdf").
        persist_directory: On-disk location for the Chroma database.
        chunk_size: Character length of each text chunk.
        chunk_overlap: Character overlap between consecutive chunks.

    Returns:
        (vector_db, qa_model): the Chroma store and a HuggingFace
        question-answering pipeline.

    Note:
        @st.cache_resource memoizes per argument combination, so the heavy
        model/index setup runs once per distinct configuration.
    """
    # Embedding model must match between index build and reload, otherwise
    # similarity search against the persisted store is meaningless.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # Presence of the directory is used as the "already indexed" marker.
    if not os.path.exists(persist_directory):
        data = PyPDFLoader(pdf_path).load()

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        splits = splitter.split_documents(data)

        # Create and persist vector store for reuse across app restarts.
        vector_db = Chroma.from_documents(
            documents=splits,
            embedding=embeddings,
            persist_directory=persist_directory,
        )
        vector_db.persist()
    else:
        # Load the existing database instead of re-indexing the PDF.
        vector_db = Chroma(
            persist_directory=persist_directory,
            embedding_function=embeddings,
        )

    # Extractive QA model: answers are spans copied out of the context.
    qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")

    return vector_db, qa_model

# Initialize the system once per session.
# NOTE(review): initialize_system is already memoized by @st.cache_resource,
# so the session_state guard is belt-and-braces — it avoids even the cached
# call on every script rerun; confirm both layers are intended.
if 'vector_db' not in st.session_state:
    st.session_state.vector_db, st.session_state.qa_model = initialize_system()

# Function to answer questions
def get_answer(question):
    """Answer *question* from the indexed PDF.

    Retrieves the most relevant chunks from the vector store, concatenates
    them into a single context, and runs the extractive QA model over it.
    Any failure is reported as a user-facing string rather than raised.
    """
    try:
        retriever = st.session_state.vector_db.as_retriever()
        matches = retriever.get_relevant_documents(question)

        # Nothing retrieved — tell the user instead of querying the model.
        if not matches:
            return "Sorry, I couldn't find any relevant information."

        # Merge all retrieved chunks into one context string.
        combined_context = " ".join(doc.page_content for doc in matches)

        result = st.session_state.qa_model(
            question=question,
            context=combined_context,
        )
        return result['answer']
    except Exception as e:
        # Best-effort UI: surface the error text rather than crashing the app.
        return f"An error occurred: {str(e)}"

# Simple input/output interface: one text box, one answer line.
user_query = st.text_input("Ask your question:")

if user_query:
    # Show a spinner while retrieval + QA inference run.
    with st.spinner("Finding answer..."):
        st.write("Answer:", get_answer(user_query))