NHZ committed on
Commit
636755b
·
verified ·
1 Parent(s): b891ccb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -118
app.py CHANGED
@@ -1,122 +1,106 @@
1
  import os
2
- import requests
3
- import torch
4
- from transformers import AutoTokenizer, AutoModel
5
- from PyPDF2 import PdfReader
6
- from langchain.vectorstores import FAISS
7
- from langchain.chains import RetrievalQA
8
- from langchain.prompts import PromptTemplate
9
- from langchain.llms.base import LLM
10
- from pydantic import Field
11
- from typing import Optional, List
12
  import streamlit as st
13
-
14
# Custom LangChain wrapper around the Groq chat-completions HTTP API.
class GroqLLM(LLM):
    """LLM implementation that forwards prompts to the Groq REST API.

    Attributes:
        api_key: Bearer token used to authenticate against Groq.
        model: Groq model identifier to query.
    """

    api_key: str = Field(..., description="API key for Groq")
    model: str = "llama-3.3-70b-versatile"

    @property
    def _llm_type(self) -> str:
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send *prompt* to Groq and return the first choice's message text.

        Raises:
            ValueError: if the HTTP request does not return status 200.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        json_data = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
        }
        # BUG FIX: the stop sequences LangChain passes in were silently
        # dropped; forward them to the API when present.
        if stop:
            json_data["stop"] = stop

        # BUG FIX: Groq serves an OpenAI-compatible API under /openai/v1/,
        # not /v1/ — the old URL failed for every call.
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=json_data,
            timeout=60,  # avoid hanging the Streamlit app on network stalls
        )

        if response.status_code != 200:
            raise ValueError(
                f"Groq API call failed: {response.status_code}, {response.text}"
            )

        data = response.json()
        return data["choices"][0]["message"]["content"]
39
-
40
# Initialize the Groq-backed LLM.
# SECURITY FIX: the API key was hard-coded in the source, which means it is
# leaked and must be revoked. Read it from the environment instead (the
# GROQ_API_KEY convention also used by the Groq SDK).
llm = GroqLLM(api_key=os.getenv("GROQ_API_KEY"))
42
-
43
# Function to extract content from a public Google Drive PDF link.
def extract_pdf_content(drive_url):
    """Download the PDF behind *drive_url* and return its text, or None.

    The URL must be a Google Drive share link of the form
    https://drive.google.com/file/d/<FILE_ID>/view...; the file id is
    extracted and turned into a direct-download URL. Returns None on any
    download or URL-format failure.
    """
    # BUG FIX: a URL without "/d/" previously raised IndexError; treat it as
    # a failed extraction instead.
    try:
        file_id = drive_url.split("/d/")[1].split("/view")[0]
    except IndexError:
        return None
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    response = requests.get(download_url)
    if response.status_code != 200:
        return None

    with open("document.pdf", "wb") as f:
        f.write(response.content)

    reader = PdfReader("document.pdf")
    # BUG FIX: extract_text() can return None for image-only pages, which
    # crashed the old "text += page.extract_text()" with a TypeError. Join
    # once instead of concatenating quadratically.
    text = "".join(page.extract_text() or "" for page in reader.pages)
    return text
59
 
60
# Function to create a FAISS vector store over sentence-level chunks.
def create_vector_store(text):
    """Split *text* into sentences, embed them, and index them with FAISS.

    Returns:
        (vector_store, sentences) — the FAISS store and the sentence list.
    """
    # Split the text into sentences and drop empty fragments.
    sentences = [sentence.strip() for sentence in text.split(". ") if sentence.strip()]

    # Load the embedding model and tokenizer from Hugging Face.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    def embed(sentence):
        # Mean-pool the last hidden state into a single vector.
        tokens = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            return model(**tokens).last_hidden_state.mean(dim=1).squeeze().numpy()

    # BUG FIX: FAISS.from_texts expects an object implementing LangChain's
    # Embeddings interface (embed_documents / embed_query); passing a bare
    # lambda raised AttributeError inside LangChain. Provide a minimal shim.
    class _Embedder:
        def embed_documents(self, texts):
            return [embed(t).tolist() for t in texts]

        def embed_query(self, query):
            return embed(query).tolist()

    vector_store = FAISS.from_texts(texts=sentences, embedding=_Embedder())

    return vector_store, sentences
82
-
83
# Streamlit app: extract the document, index it, and answer queries via RAG.
st.title("RAG-based Application with Focused Context")

drive_url = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"
text = extract_pdf_content(drive_url)

if text:
    st.write("Document extracted successfully!")
    vector_store, sentences = create_vector_store(text)
    st.write("Vector store created!")

    query = st.text_input("Enter your query:")
    if query:
        retriever = vector_store.as_retriever()
        retriever.search_kwargs["k"] = 3  # retrieve the 3 closest sentences

        prompt_template = PromptTemplate(
            template="""
Use the following context to answer the question:

{context}

Question: {question}
Answer:""",
            input_variables=["context", "question"]
        )

        # BUG FIX: prompt_template was built but never wired into the chain,
        # so the default prompt was silently used. Pass it via
        # chain_type_kwargs so the "stuff" chain actually applies it.
        qa_chain = RetrievalQA.from_chain_type(
            retriever=retriever,
            llm=llm,
            chain_type="stuff",
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt_template},
        )

        response = qa_chain({"query": query})
        answer = response["result"]

        st.write("Answer:", answer)
else:
    st.error("Failed to extract content from the document.")
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
 
 
 
 
 
 
2
  import streamlit as st
3
+ import requests
4
+ import PyPDF2
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+ import nltk
8
+ from groq import Groq
9
+
10
# Fetch the sentence tokenizer data once; quiet=True keeps Streamlit reruns
# from spamming the log (nltk skips the download when already present).
nltk.download('punkt', quiet=True)

# Initialize Groq client from the environment; never hard-code the key.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
14
+
15
# Function to extract text from a PDF shared via a Google Drive link.
def extract_text_from_pdf(pdf_url):
    """Download the PDF behind *pdf_url* and return its concatenated text.

    Raises:
        requests.HTTPError: if the download does not succeed.
    """
    from io import BytesIO

    # Convert a Google Drive shareable link to a direct download link.
    direct_url = pdf_url.replace("/view?usp=sharing", "").replace("file/d/", "uc?id=")
    response = requests.get(direct_url)
    # BUG FIX: the old code ignored HTTP failures and then tried to parse an
    # HTML error page as a PDF; fail loudly instead.
    response.raise_for_status()

    # BUG FIX: parse in memory instead of round-tripping through "temp.pdf",
    # which leaked on disk whenever parsing raised and raced when two
    # Streamlit sessions ran concurrently.
    reader = PyPDF2.PdfReader(BytesIO(response.content))
    # extract_text() may return None for image-only pages.
    text = "".join(page.extract_text() or "" for page in reader.pages)
    return text
32
 
33
# Function to chunk text into roughly chunk_size-word pieces on sentence
# boundaries.
def chunk_text(text, chunk_size=300):
    """Split *text* into chunks of at most ~chunk_size words.

    Sentences are never split; a single sentence longer than chunk_size
    becomes its own chunk.

    Args:
        text: the document text to split.
        chunk_size: soft word-count budget per chunk (default 300).

    Returns:
        List of non-empty chunk strings.
    """
    sentences = nltk.sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        sentence_length = len(sentence.split())
        # BUG FIX: the old version flushed even when current_chunk was empty,
        # so a first sentence longer than chunk_size produced a spurious ""
        # chunk at the front of the result.
        if current_chunk and current_length + sentence_length > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sentence)
        current_length += sentence_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
52
+
53
# Function to create embeddings for the chunks and index them in FAISS.
def create_faiss_index(chunks, model=None):
    """Embed *chunks* and build an exact L2 FAISS index over them.

    Args:
        chunks: list of text chunks to index.
        model: optional pre-loaded SentenceTransformer. IMPROVEMENT: the old
            version always re-downloaded/loaded the model even though main()
            already keeps one in session state; the default keeps the old
            behaviour, so callers are unaffected.

    Returns:
        (index, embeddings) — the FAISS index and the raw embedding matrix.
    """
    if model is None:
        model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(chunks)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, embeddings
61
+
62
# Function to retrieve the chunks most similar to a query from a FAISS index.
def query_faiss(index, query, chunks, model, k=3):
    """Return up to *k* chunks ranked by L2 distance to *query*.

    Args:
        index: FAISS index built over the embeddings of *chunks*.
        query: user question to embed and search for.
        chunks: original text chunks, parallel to the indexed vectors.
        model: encoder exposing .encode(list[str]) -> ndarray.
        k: number of neighbours to retrieve (default 3, as before).
    """
    query_vector = model.encode([query])
    distances, indices = index.search(query_vector, k=k)
    # BUG FIX: when the index holds fewer than k vectors, FAISS pads the id
    # array with -1; the old code then silently returned chunks[-1] (the
    # last chunk) for each padding slot. Filter the padding out.
    return [chunks[i] for i in indices[0] if i != -1]
68
+
69
+ # Main Streamlit App
70
+ def main():
71
+ st.title("RAG-based Application")
72
+ st.write("Interact with your document using Groq-powered model.")
73
+
74
+ # Pre-defined document link
75
+ doc_link = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"
76
+
77
+ # Extract Document Content
78
+ if "document_text" not in st.session_state:
79
+ st.write("Extracting document content...")
80
+ text = extract_text_from_pdf(doc_link)
81
+ st.session_state['document_text'] = text
82
+ st.success("Document content extracted!")
83
+
84
+ # Process Document and Create FAISS Index
85
+ if 'document_text' in st.session_state and "faiss_index" not in st.session_state:
86
+ st.write("Processing document...")
87
+ chunks = chunk_text(st.session_state['document_text'])
88
+ index, embeddings = create_faiss_index(chunks)
89
+ st.session_state['faiss_index'] = index
90
+ st.session_state['chunks'] = chunks
91
+ st.session_state['model'] = SentenceTransformer("all-MiniLM-L6-v2")
92
+ st.success(f"Document processed into {len(chunks)} chunks!")
93
+
94
+ # Query the Document
95
+ if 'faiss_index' in st.session_state:
96
+ st.header("Ask Questions")
97
+ query = st.text_input("Enter your question here")
98
+ if st.button("Query Document"):
99
+ results = query_faiss(st.session_state['faiss_index'], query, st.session_state['chunks'], st.session_state['model'])
100
+ st.write("### Results from Document:")
101
+ for i, result in enumerate(results):
102
+ st.write(f"**Result {i+1}:** {result}")
103
+
104
+ # Use Groq API for additional insights
105
+ chat_completion = client.chat.completions.create(
106
+ messa