NHZ committed on
Commit
3895252
·
verified ·
1 Parent(s): f26d06b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import os

# Third-party: UI, PDF parsing, sentence embeddings, vector search,
# numerics, and the Groq LLM client.
import streamlit as st
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from groq import Groq

# Initialize Groq Client
# NOTE(review): reads the key from the "groq_Api_key" env var — the unusual
# casing looks accidental; confirm the deployment sets exactly this name.
client = Groq(api_key=os.getenv("groq_Api_key"))

# Load embedding model
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize FAISS vector store
# NOTE: Streamlit re-runs this module on every interaction, so the index is
# rebuilt (and must be re-populated) on each run.
dimension = 384  # Embedding dimension of the model
index = faiss.IndexFlatL2(dimension)
18
+
19
# Function to extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Extract and concatenate the text of every page of *pdf_file*.

    Args:
        pdf_file: A path or binary file-like object accepted by
            ``PdfReader`` (Streamlit's ``UploadedFile`` qualifies).

    Returns:
        str: The text of all pages joined together; an empty string when
        the PDF contains no extractable text.
    """
    reader = PdfReader(pdf_file)
    parts = []
    for page in reader.pages:
        # extract_text() returns None for image-only/scanned pages; the
        # previous `text += page.extract_text()` raised TypeError there.
        page_text = page.extract_text()
        if page_text:
            parts.append(page_text)
    # join() instead of repeated += avoids quadratic string building.
    return "".join(parts)
26
+
27
# Function to split text into fixed-size word chunks
def chunk_text(text, chunk_size=500):
    """Split *text* on whitespace and regroup it into strings of at most
    ``chunk_size`` words each.

    Returns an empty list for empty/whitespace-only input.
    """
    words = text.split()
    chunked = []
    for start in range(0, len(words), chunk_size):
        piece = words[start:start + chunk_size]
        chunked.append(" ".join(piece))
    return chunked
31
+
32
# Function to embed text chunks and store the vectors in the FAISS index
def add_to_vector_db(chunks):
    """Encode *chunks* with the module-level embedding model, add the
    resulting vectors to the module-level FAISS ``index``, and return the
    raw embeddings exactly as the encoder produced them.
    """
    vectors = embedding_model.encode(chunks)
    # FAISS expects contiguous float32 input.
    as_float32 = np.array(vectors, dtype="float32")
    index.add(as_float32)
    return vectors
37
+
38
# Streamlit frontend
st.title("RAG-based PDF Query Application")

# `chunks` must exist before the query branch runs: previously it was only
# bound inside `if uploaded_file:`, so submitting a query before uploading a
# PDF raised NameError on every rerun without a file.
chunks = []

# PDF upload
uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
if uploaded_file:
    st.write("Processing your PDF...")
    text = extract_text_from_pdf(uploaded_file)
    chunks = chunk_text(text)
    add_to_vector_db(chunks)
    st.success("PDF processed and embeddings stored in the vector database!")

# Query input
query = st.text_input("Enter your query:")
if query:
    if not chunks or index.ntotal == 0:
        # Nothing indexed this run — searching an empty FAISS index would
        # yield only -1 placeholder indices.
        st.warning("Please upload and process a PDF before querying.")
    else:
        # Generate embedding for the query
        query_embedding = embedding_model.encode([query])

        # Retrieve relevant chunks from FAISS. Cap k at the number of
        # stored vectors: asking for more returns -1 indices, and
        # `chunks[-1]` would silently pull in the wrong chunk.
        k = min(5, index.ntotal)
        distances, indices = index.search(
            np.array(query_embedding, dtype="float32"), k=k
        )
        context = "\n".join(
            chunks[i] for i in indices[0] if 0 <= i < len(chunks)
        )

        # Interact with Groq API: single user turn carrying the retrieved
        # context plus the raw query.
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": f"Context: {context}\n\nQuery: {query}"
                }
            ],
            model="llama3-8b-8192",
            stream=False,
        )
        response = chat_completion.choices[0].message.content

        # Display response
        st.write("Response:")
        st.write(response)