NHZ committed on
Commit
daa4a6a
·
verified ·
1 Parent(s): 4617265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -54
app.py CHANGED
@@ -1,69 +1,106 @@
1
- import os
2
- import re
3
- import torch
4
  import numpy as np
5
- from langchain.vectorstores import FAISS
6
- from langchain.embeddings import HuggingFaceEmbeddings
7
- from langchain.document_loaders import PyPDFLoader
8
- from langchain.text_splitter import CharacterTextSplitter
9
- from langchain.chains.question_answering import load_qa_chain
10
- from langchain.prompts import PromptTemplate
11
- from langchain.llms import HuggingFaceHub
12
  import streamlit as st
 
 
13
 
14
- # Environment setup
15
- HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
16
- if not HUGGINGFACEHUB_API_TOKEN:
17
- raise ValueError("HuggingFace API Token is missing.")
18
 
19
- # Initialize HuggingFace embeddings model
20
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
21
 
22
- # Load PDF document from Google Drive
23
- pdf_url = "https://drive.google.com/uc?id=1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0"
24
- loader = PyPDFLoader(pdf_url)
25
- documents = loader.load()
26
 
27
- # Split text into chunks
28
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
29
- texts = text_splitter.split_documents(documents)
 
 
 
 
 
30
 
31
- # Create FAISS vector database
32
- db = FAISS.from_documents(texts, embeddings)
 
 
 
33
 
34
- # Initialize HuggingFace LLM (example model, replace as needed)
35
- llm = HuggingFaceHub(repo_id="bigscience/bloom", model_kwargs={"temperature": 0, "max_length": 512})
 
 
 
 
 
 
 
36
 
37
- # Define custom prompt
38
- prompt_template = """
39
- Use the following pieces of context to answer the question at the end.
40
- If the question cannot be answered based on the context, say "I don't know."
41
 
42
- Context:
43
- {context}
44
 
45
- Question:
46
- {question}
47
 
48
- Answer:
49
- """
50
- prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
 
 
 
 
 
 
51
 
52
- # Load QA chain
53
- qa_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt)
54
 
55
- # Streamlit frontend
56
- st.title("RAG-based Document Q&A")
57
- st.write("Upload a document and ask questions about it.")
58
 
59
- query = st.text_input("Enter your question:")
60
- if query:
61
- # Search vector database
62
- docs = db.similarity_search(query, k=4)
63
-
64
- # Get relevant context
65
- context = "\n\n".join([doc.page_content for doc in docs])
66
-
67
- # Generate answer using LLM
68
- answer = qa_chain.run({"context": context, "question": query})
69
- st.write("**Answer:**", answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
 
 
2
  import numpy as np
3
+ import faiss
4
+ from PyPDF2 import PdfReader
5
+ from transformers import AutoTokenizer, AutoModel
6
+ from groq import Groq
 
 
 
7
  import streamlit as st
8
+ import torch
9
+ import os
10
 
11
# Initialize the Groq client used for chat completions. The API key is read
# from the GROQ_API_KEY environment variable; os.getenv returns None when the
# variable is unset, and that None is passed straight to the constructor.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
13
 
14
+ # Function to download and extract content from a public Google Drive PDF link
15
def extract_pdf_content(drive_url):
    """Download a publicly shared Google Drive PDF and return its text.

    Args:
        drive_url: Share link containing ``/d/<file_id>``, for example
            ``https://drive.google.com/file/d/<file_id>/view?usp=sharing``.

    Returns:
        The concatenated text of every page, or ``None`` when the link
        carries no file id, the HTTP request fails, or the response status
        is not 200.
    """
    # Function-scope import: keeps this block self-contained.
    import io

    # The file id is the path segment right after "/d/". Cutting at the next
    # "/" or "?" also handles links that do not end in "/view".
    _, marker, tail = drive_url.partition("/d/")
    file_id = tail.split("/", 1)[0].split("?", 1)[0]
    if not marker or not file_id:
        return None
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

    # Download the PDF content. Without a timeout a stalled connection would
    # block the app indefinitely; network errors map to the same None result
    # the caller already handles for non-200 responses.
    try:
        response = requests.get(download_url, timeout=30)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None

    # Parse from memory instead of writing a fixed "document.pdf" into the
    # working directory, which concurrent sessions would overwrite.
    reader = PdfReader(io.BytesIO(response.content))

    # A page without extractable text contributes an empty string.
    return "".join(page.extract_text() or "" for page in reader.pages)
33
 
34
+ # Function to chunk and tokenize text
35
def chunk_and_tokenize(text, tokenizer, chunk_size=512):
    """Tokenize ``text`` and split the token ids into fixed-size chunks.

    Args:
        text: The string to tokenize.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
            that returns a sequence of token ids.
        chunk_size: Maximum number of token ids per chunk; must be >= 1.

    Returns:
        A list of consecutive, non-overlapping slices of the token ids. Every
        slice has ``chunk_size`` ids except possibly the last, which holds the
        remainder. Empty when the tokenizer yields no ids.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    # A non-positive size would either make range() fail with an unrelated
    # message (0) or silently produce no chunks at all (negative).
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Special tokens are left out so chunks are plain slices of the text ids.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return [
        tokens[start:start + chunk_size]
        for start in range(0, len(tokens), chunk_size)
    ]
39
 
40
+ # Function to compute embeddings and build FAISS index
41
def build_faiss_index(chunks, model):
    """Embed every token chunk with ``model`` and store the vectors in FAISS.

    Each chunk is fed to ``model`` as a batch of one, and the mean of
    ``last_hidden_state`` over dimension 1 is used as that chunk's vector.

    Args:
        chunks: Non-empty sequence of token-id lists.
        model: Callable that accepts a tensor of token ids and returns an
            object with a ``last_hidden_state`` tensor.

    Returns:
        A ``faiss.IndexFlatL2`` holding one vector per chunk, in chunk order,
        so a search result position maps directly back into ``chunks``.

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if len(chunks) == 0:
        # Fail with a clear message instead of an opaque np.vstack error.
        raise ValueError("Cannot build a FAISS index from zero chunks.")

    vectors = []
    # Inference only: no gradients are needed, so none are tracked.
    with torch.no_grad():
        for chunk in chunks:
            input_ids = torch.tensor([chunk])
            pooled = model(input_ids).last_hidden_state.mean(dim=1)
            vectors.append(pooled.cpu().numpy())

    # Hand FAISS one contiguous float32 matrix of shape (n_chunks, dim).
    embeddings = np.ascontiguousarray(np.vstack(vectors), dtype=np.float32)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
 
53
 
54
# Streamlit app: load one fixed PDF, index it, retrieve the chunk closest to the
# user's query and ask the Groq model to answer the query from that chunk.
st.title("RAG-based Application with Groq API")

# Predefined Google Drive link
drive_url = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"

# NOTE(review): if Streamlit re-runs this script on each interaction, the
# download, model load and index build below repeat for every query — consider
# caching these steps.

# Extract document content
st.write("Extracting content from the document...")
text = extract_pdf_content(drive_url)
# Whitespace-only text would tokenize to nothing and leave no chunks to index,
# so it is treated the same as a failed extraction.
if text and text.strip():
    st.write("Document extracted successfully!")

    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    st.write("Chunking and tokenizing content...")
    chunks = chunk_and_tokenize(text, tokenizer)

    st.write("Building FAISS index...")
    index = build_faiss_index(chunks, model)

    # Query input
    query = st.text_input("Enter your query:")
    # A blank query has no tokens to embed, so it is ignored.
    if query and query.strip():
        st.write("Searching for the most relevant chunk...")
        query_tokens = tokenizer.encode(query, add_special_tokens=False)
        # Embed the query the same way the chunks were embedded (mean of the
        # last hidden state), without tracking gradients.
        with torch.no_grad():
            query_embedding = (
                model(torch.tensor([query_tokens]))
                .last_hidden_state.mean(dim=1)
                .numpy()
            )
        _, indices = index.search(query_embedding, k=1)

        # Retrieve the most relevant chunk
        relevant_chunk = chunks[indices[0][0]]
        relevant_text = tokenizer.decode(relevant_chunk)
        st.write("Relevant chunk found:", relevant_text)

        # Interact with Groq API. The question is sent together with the
        # retrieved context; sending the context alone gives the model
        # nothing to answer.
        st.write("Querying the Groq API...")
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Answer the user's question using only the provided "
                        "context. If the context does not contain the answer, "
                        "say \"I don't know.\""
                    ),
                },
                {
                    "role": "user",
                    "content": f"Context:\n{relevant_text}\n\nQuestion: {query}",
                },
            ],
            model="llama-3.3-70b-versatile",
        )
        st.write("Model Response:", chat_completion.choices[0].message.content)
else:
    st.error("Failed to extract content from the document.")