MohammadYaseen committed on
Commit
e6fb287
·
verified ·
1 Parent(s): 03b3412

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import PyPDF2
4
+ import docx
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+ from groq import Groq
8
+ import streamlit as st
9
+
10
# Initialize Groq API Client.
# The API key is read from the GROQ_API_KEY environment variable so the secret
# never lives in source control. Any key that was previously committed in this
# file must be treated as leaked: revoke it and issue a new one.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# SentenceTransformer model for embeddings
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# FAISS index for retrieval
dimension = 384  # Dimension of 'all-MiniLM-L6-v2' embeddings
index = faiss.IndexFlatL2(dimension)
document_texts = []  # Store text corresponding to embeddings (same order as the index rows)
20
+
21
# Helper function: Extract text from different file types
def extract_text_from_file(file):
    """Extract plain text from an uploaded file, chosen by its extension.

    Args:
        file: A binary file-like object exposing a ``name`` attribute and
            ``read()`` (for example a Streamlit uploaded file).

    Returns:
        The extracted text as a string (possibly empty), or ``None`` when the
        file extension is not one of pdf, csv, xlsx, xls, txt, docx.
    """
    # Compare extensions case-insensitively so "REPORT.PDF" is handled like
    # "report.pdf" instead of being rejected as unsupported.
    name = file.name.lower()

    if name.endswith(".pdf"):
        pdf_reader = PyPDF2.PdfReader(file)
        # A page without a text layer may yield nothing; fall back to "" so
        # the concatenation cannot fail, and join once instead of using +=.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)

    if name.endswith(".csv"):
        df = pd.read_csv(file)
        return "\n".join(" ".join(map(str, row)) for row in df.values)

    if name.endswith((".xlsx", ".xls")):
        df = pd.read_excel(file)
        return "\n".join(" ".join(map(str, row)) for row in df.values)

    if name.endswith(".txt"):
        # Replace undecodable bytes rather than crashing on a stray
        # non-UTF-8 character in an otherwise readable text file.
        return file.read().decode("utf-8", errors="replace")

    if name.endswith(".docx"):
        doc = docx.Document(file)
        return "\n".join(p.text for p in doc.paragraphs)

    # Unsupported extension.
    return None
42
+
43
# Add document embeddings to FAISS
def add_to_index(text, index, document_texts):
    """Split ``text`` into lines, embed them, and append them to the index.

    Args:
        text: Document text; each non-blank line becomes one retrievable entry.
        index: Index object whose ``add`` method receives the embeddings.
        document_texts: List extended in place with the indexed lines, in the
            same order the embeddings are added.

    Returns:
        None. Nothing is added when ``text`` contains no non-blank lines.
    """
    # Skip blank/whitespace-only lines: embedding them would fill the index
    # with empty entries that can be retrieved as "context".
    sentences = [line.strip() for line in text.split("\n") if line.strip()]
    if not sentences:
        return

    embeddings = embedding_model.encode(sentences, convert_to_numpy=True)
    index.add(embeddings)
    document_texts.extend(sentences)
49
+
50
# Perform RAG Query
def rag_query(query, index, document_texts, top_k=3):
    """
    Perform a RAG query: Retrieve relevant documents and generate a response.

    Args:
        query: The user's question.
        index: Index object searched with the query embedding.
        document_texts: Texts aligned with the rows stored in ``index``.
        top_k: Maximum number of passages to retrieve as context.

    Returns:
        The model's answer as a string, or a short notice when no documents
        have been indexed yet.
    """
    # Nothing to retrieve from: answer directly instead of searching an
    # empty index and indexing into an empty list.
    if not document_texts:
        return "No documents have been indexed yet. Please upload a document first."

    # Generate query embedding and retrieve closest matches. Never ask for
    # more neighbours than there are stored passages.
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    k = min(top_k, len(document_texts))
    _, indices = index.search(query_embedding, k)

    # Build the context from retrieved documents. FAISS pads missing results
    # with -1, which would otherwise silently pick the last stored passage.
    retrieved_context = " ".join(
        document_texts[idx]
        for idx in indices[0]
        if 0 <= idx < len(document_texts)
    )

    # Construct the prompt for the Groq model
    prompt = f"Context: {retrieved_context}\n\nQuestion: {query}"

    # Generate a response using Groq API
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "user", "content": prompt}
        ],
        model="gemma2-9b-it",
    )
    return chat_completion.choices[0].message.content
73
+
74
# Streamlit UI
st.title("RAG-Based Document Q&A")
st.write("Upload your documents and ask questions based on the content.")

uploaded_files = st.file_uploader(
    "Upload PDFs, CSVs, Excel, or Text files",
    type=["pdf", "csv", "xlsx", "xls", "txt", "docx"],
    accept_multiple_files=True,
)

if uploaded_files:
    for file in uploaded_files:
        with st.spinner(f"Processing {file.name}..."):
            text = extract_text_from_file(file)
            if text is None:
                # None is the extractor's signal for an unsupported extension.
                st.error(f"Could not process {file.name}. Unsupported file format.")
            elif not text.strip():
                # Supported format but nothing to index (e.g. an empty file);
                # report that instead of blaming the file format.
                st.warning(f"No text could be extracted from {file.name}.")
            else:
                add_to_index(text, index, document_texts)
                st.success(f"Processed {file.name}")

query = st.text_input("Enter your question:")
if query:
    if not document_texts:
        # Without indexed content there is nothing to retrieve from.
        st.warning("Please upload at least one document before asking a question.")
    else:
        with st.spinner("Generating response..."):
            response = rag_query(query, index, document_texts)
            st.write("### Answer:")
            st.write(response)