NHZ committed on
Commit
0ac9077
·
verified ·
1 Parent(s): 0004542

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import PyPDF2
4
+ import faiss
5
+ import numpy as np
6
+ import streamlit as st
7
+ from transformers import AutoTokenizer, AutoModel
8
+ from groq import Groq
9
+
10
+ # Download file from Google Drive link
11
def download_file_from_drive(url, timeout=30):
    """Download a shared Google Drive file and save it as ``document.pdf``.

    Args:
        url: A Drive sharing link that contains ``/d/<file_id>``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The local path the response body was written to (``"document.pdf"``).

    Raises:
        ValueError: If no file id can be found in ``url``.
        requests.HTTPError: If Drive answers with an error status code.
    """
    _, marker, remainder = url.partition("/d/")
    # The id ends at the next path separator, or at the query string when the
    # link has no trailing path segment (e.g. ".../d/<id>?usp=sharing").
    file_id = remainder.split("/", 1)[0].split("?", 1)[0]
    if not marker or not file_id:
        raise ValueError(f"Could not find a Google Drive file id in URL: {url!r}")

    download_url = f"https://drive.google.com/uc?id={file_id}&export=download"
    # Without a timeout a stalled connection would hang the app indefinitely.
    response = requests.get(download_url, timeout=timeout)
    # Fail loudly instead of saving an HTTP error page as if it were the PDF.
    response.raise_for_status()

    pdf_path = "document.pdf"
    with open(pdf_path, "wb") as f:
        f.write(response.content)
    return pdf_path
19
+
20
+ # Extract text from PDF
21
def extract_text_from_pdf(pdf_path):
    """Return the text of every page in the PDF at ``pdf_path``.

    The per-page strings produced by ``extract_text()`` are joined with a
    single space, in page order.
    """
    with open(pdf_path, "rb") as pdf_file:
        pdf_pages = PyPDF2.PdfReader(pdf_file).pages
        # Pull the text while the file is still open; pages are read from it.
        page_texts = [pdf_page.extract_text() for pdf_page in pdf_pages]
    return " ".join(page_texts)
26
+
27
+ # Chunk text
28
def chunk_text(text, chunk_size=500):
    """Split ``text`` into consecutive chunks of at most ``chunk_size`` words.

    Words are whitespace-separated tokens; each chunk is returned as a single
    space-joined string. The final chunk may be shorter than ``chunk_size``.
    """
    tokens = text.split()
    pieces = []
    for start in range(0, len(tokens), chunk_size):
        window = tokens[start:start + chunk_size]
        pieces.append(" ".join(window))
    return pieces
32
+
33
+ # Generate embeddings
34
def _load_embedding_model(model_name):
    """Load and return the ``(tokenizer, model)`` pair for ``model_name``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model


def generate_embeddings(chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed each text chunk as the mean of the model's last hidden states.

    Args:
        chunks: Non-empty list of strings to embed.
        model_name: Hugging Face model id used for both tokenizer and model.

    Returns:
        A 2-D NumPy array with one row per chunk, in input order.

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    # Reject empty input up front: np.vstack([]) would raise a far less
    # helpful ValueError, and only after the model had already been loaded.
    if not chunks:
        raise ValueError("generate_embeddings requires at least one text chunk")

    # Streamlit re-executes the script on every interaction, so the loader is
    # wrapped in st.cache_resource to keep the model in memory across reruns
    # rather than reloading it for every question.
    tokenizer, model = st.cache_resource(_load_embedding_model)(model_name)

    embeddings = []
    for chunk in chunks:
        # truncation=True cuts chunks longer than the model's maximum length.
        inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
        outputs = model(**inputs)
        # Mean-pool over the token axis to get one vector per chunk.
        embeddings.append(outputs.last_hidden_state.mean(dim=1).detach().numpy())
    return np.vstack(embeddings)
43
+
44
+ # Store embeddings in FAISS
45
def create_faiss_index(embeddings):
    """Build a flat L2 FAISS index holding every row of ``embeddings``.

    The index dimension is taken from the second axis of ``embeddings``.
    """
    flat_index = faiss.IndexFlatL2(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index
50
+
51
+ # Groq API Integration
52
def query_groq_api(query, api_key):
    """Send ``query`` as a single user message to Groq and return the reply text.

    A new ``Groq`` client is created with ``api_key`` for each call, and the
    content of the first returned choice is handed back to the caller.
    """
    user_message = {"role": "user", "content": query}
    groq_client = Groq(api_key=api_key)
    completion = groq_client.chat.completions.create(
        messages=[user_message],
        model="llama-3.3-70b-versatile",
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
64
+
65
+ # Streamlit App
66
def _process_document(google_drive_url):
    """Download, chunk, embed and index the document; store results in session state."""
    if not google_drive_url.strip():
        st.error("Please enter a Google Drive file link.")
        return

    st.info("Downloading document...")
    try:
        pdf_path = download_file_from_drive(google_drive_url)
    except (IndexError, ValueError, requests.RequestException) as exc:
        # A malformed link or a network/HTTP failure should not crash the app.
        st.error(f"Could not download the document: {exc}")
        return
    st.success("Document downloaded successfully!")

    st.info("Extracting text...")
    text = extract_text_from_pdf(pdf_path)
    st.success("Text extracted successfully!")

    st.info("Chunking text...")
    chunks = chunk_text(text)
    if not chunks:
        # Nothing to embed or index, e.g. the PDF has no extractable text.
        st.error("No text could be extracted from the document.")
        return
    st.success(f"Document chunked into {len(chunks)} chunks.")

    st.info("Generating embeddings...")
    embeddings = generate_embeddings(chunks)
    st.success("Embeddings generated successfully!")

    st.info("Creating FAISS index...")
    index = create_faiss_index(embeddings)
    st.success("FAISS index created successfully!")

    st.session_state.index = index
    st.session_state.chunks = chunks


def _answer_question(query, groq_api_key, top_k=5):
    """Retrieve the chunks closest to ``query`` and show the Groq answer."""
    if not query.strip():
        st.warning("Please enter a question.")
        return
    if not groq_api_key:
        st.warning("Please enter your Groq API key in the sidebar.")
        return

    chunks = st.session_state.chunks
    if not chunks:
        st.error("No document chunks are available. Please process a document first.")
        return

    st.info("Querying FAISS index...")
    query_embeddings = generate_embeddings([query])
    # Never ask for more neighbours than there are chunks: FAISS pads missing
    # results with -1, which Python would read as "the last chunk".
    k = min(top_k, len(chunks))
    _distances, indices = st.session_state.index.search(query_embeddings, k=k)
    relevant_chunks = [chunks[i] for i in indices[0] if i >= 0]
    st.success("Relevant chunks retrieved!")

    st.info("Generating answer via Groq API...")
    context = " ".join(relevant_chunks)
    answer = query_groq_api(context + "\n" + query, api_key=groq_api_key)
    st.success("Answer generated!")
    st.write(answer)


def main():
    """Render the Streamlit UI for the RAG application.

    The sidebar collects the Groq API key and a Google Drive link. Pressing
    "Process Document" downloads the file, extracts and chunks its text,
    embeds the chunks and stores the FAISS index and chunks in
    ``st.session_state``. Once an index exists, a question box is shown and
    "Search" retrieves the nearest chunks and asks Groq for an answer.
    """
    st.title("RAG-based Application")
    st.sidebar.title("Settings")

    groq_api_key = st.sidebar.text_input("Enter your Groq API Key", type="password")
    google_drive_url = st.sidebar.text_input("Enter Google Drive File Link")

    if st.sidebar.button("Process Document"):
        _process_document(google_drive_url)

    if "index" in st.session_state:
        query = st.text_input("Ask a question:")
        if st.button("Search"):
            _answer_question(query, groq_api_key)


if __name__ == "__main__":
    main()