muddasser committed on
Commit
b785aa3
·
verified ·
1 Parent(s): e5a6972

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -84
app.py CHANGED
@@ -1,84 +1,99 @@
1
- import streamlit as st
2
- from selenium import webdriver
3
- from selenium.webdriver.chrome.service import Service
4
- from webdriver_manager.chrome import ChromeDriverManager
5
- from selenium.webdriver.chrome.options import Options
6
- import time
7
-
8
- from sentence_transformers import SentenceTransformer
9
- import faiss
10
- import numpy as np
11
- from transformers import pipeline
12
-
13
# -------------------------------
# 1. Setup Selenium (Headless Chrome for Hugging Face/Streamlit)
# -------------------------------
def init_driver():
    """Create a headless Chrome WebDriver suitable for containerized hosts."""
    opts = Options()
    # Flags required for Chrome inside containers/CI sandboxes (no GPU,
    # no sandbox, small /dev/shm).
    for flag in ("--headless", "--disable-gpu", "--no-sandbox", "--disable-dev-shm-usage"):
        opts.add_argument(flag)

    # webdriver_manager fetches a chromedriver matching the installed Chrome.
    service = Service(ChromeDriverManager().install())
    return webdriver.Chrome(service=service, options=opts)
26
-
27
# -------------------------------
# 2. Scrape website text with Selenium
# -------------------------------
def scrape_website(url):
    """Fetch *url* with headless Chrome and return the rendered page source.

    Note: this returns raw HTML (driver.page_source), not extracted text.
    """
    driver = init_driver()
    try:
        driver.get(url)
        time.sleep(3)  # crude wait for JS-rendered content to load
        return driver.page_source
    finally:
        # Always release the browser, even when navigation raises —
        # otherwise each failed request leaks a headless Chrome process.
        driver.quit()
-
38
# -------------------------------
# 3. Embed and store in FAISS
# -------------------------------
# Sentence embedder; all-MiniLM-L6-v2 emits 384-dimensional vectors,
# matching the index dimension below.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
dimension = 384
# Flat (exact) L2-distance index — fine for small document counts.
index = faiss.IndexFlatL2(dimension)

# Raw texts kept in insertion order, so FAISS row i maps to documents[i].
documents = []
46
-
47
def add_to_faiss(text):
    """Embed *text* and append it to the shared FAISS index and document list."""
    # list.append mutates in place, so no `global` declaration is needed.
    vec = np.asarray(embedder.encode([text]), dtype="float32")
    index.add(vec)
    documents.append(text)
52
-
53
def retrieve(query, k=1):
    """Return up to *k* stored documents most similar to *query*.

    Returns an empty list when nothing has been indexed yet. FAISS pads
    missing results with id -1, which would otherwise alias documents[-1],
    so ids are clamped and filtered.
    """
    if index.ntotal == 0:
        return []
    q_emb = np.asarray(embedder.encode([query]), dtype="float32")
    _, ids = index.search(q_emb, min(k, index.ntotal))
    return [documents[i] for i in ids[0] if i >= 0]
57
-
58
# -------------------------------
# 4. QA Model (FLAN-T5-small)
# -------------------------------
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-small")

def answer_query(query):
    """Retrieve the best-matching document and generate an answer with FLAN-T5."""
    context = " ".join(retrieve(query, k=1))
    prompt = f"Answer the question based on context:\nContext: {context}\nQuestion: {query}"
    # Greedy decoding (do_sample=False) keeps answers deterministic.
    outputs = qa_pipeline(prompt, max_length=256, do_sample=False)
    return outputs[0]['generated_text']
69
-
70
# -------------------------------
# 5. Streamlit App
# -------------------------------
st.title("🌐 Web Scraping + RAG (Selenium + FLAN-T5-small)")

url = st.text_input("Enter website URL:")
if url and st.button("Scrape & Index"):
    add_to_faiss(scrape_website(url))
    st.success("✅ Website scraped and indexed successfully!")

query = st.text_input("Ask a question:")
if query and st.button("Get Answer"):
    st.write("**Answer:**", answer_query(query))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from selenium import webdriver
3
+ from selenium.webdriver.chrome.options import Options
4
+ import time
5
+
6
# -------------------------
# FIX for huggingface_hub cached_download issue
# -------------------------
# Older sentence-transformers releases call huggingface_hub.cached_download,
# which newer huggingface_hub versions removed; alias the modern helper so
# the legacy call path keeps working.
import huggingface_hub

if not hasattr(huggingface_hub, "cached_download"):
    from huggingface_hub import hf_hub_download

    huggingface_hub.cached_download = hf_hub_download
13
+
14
+ # -------------------------
15
+ # RAG + NLP libraries
16
+ # -------------------------
17
+ from sentence_transformers import SentenceTransformer
18
+ import faiss
19
+ import numpy as np
20
+ from transformers import pipeline
21
+
22
+
23
# -------------------------
# 1️⃣ Function: Scrape website using Selenium
# -------------------------
def scrape_with_selenium(url: str):
    """Load *url* in headless Chrome and return the visible paragraph texts.

    Returns:
        list[str]: text of each <p> element, whitespace-only entries dropped.
    """
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")

    driver = webdriver.Chrome(options=chrome_options)
    try:
        driver.get(url)
        time.sleep(2)  # crude wait for JS-rendered content

        # "tag name" is the raw locator string equivalent to By.TAG_NAME.
        paragraphs = driver.find_elements("tag name", "p")
        return [p.text for p in paragraphs if p.text.strip()]
    finally:
        # Quit even when navigation/scraping raises — the original code
        # leaked a headless Chrome process on every failed request.
        driver.quit()
42
+
43
+
44
# -------------------------
# 2️⃣ Function: Build FAISS Index
# -------------------------
# Loading a SentenceTransformer is expensive (disk, possibly a network
# download); cache one instance per model name instead of reloading on
# every call (the original reloaded it for every question asked).
_EMBEDDER_CACHE = {}

def _get_embedder(name="all-MiniLM-L6-v2"):
    """Return a cached SentenceTransformer, loading it on first use."""
    if name not in _EMBEDDER_CACHE:
        _EMBEDDER_CACHE[name] = SentenceTransformer(name)
    return _EMBEDDER_CACHE[name]

def build_faiss_index(text_data):
    """Embed *text_data* and build an exact-L2 FAISS index over it.

    Args:
        text_data: non-empty list of strings to index.

    Returns:
        (model, index, text_data): the embedder (for query encoding), the
        FAISS index, and the texts so row ids map back to their source.

    Raises:
        ValueError: if *text_data* is empty — the original code died here
        with an opaque IndexError on embeddings.shape[1].
    """
    if not text_data:
        raise ValueError("text_data is empty — nothing to index.")

    model = _get_embedder()
    embeddings = model.encode(text_data, convert_to_numpy=True)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    return model, index, text_data
56
+
57
+
58
# -------------------------
# 3️⃣ Function: Query RAG
# -------------------------
# Building the pipeline reloads FLAN-T5 from disk; cache it so repeated
# questions don't pay the model-load cost every time.
_GENERATOR_CACHE = {}

def _get_generator(name="google/flan-t5-small"):
    """Return a cached text2text-generation pipeline for *name*."""
    if name not in _GENERATOR_CACHE:
        _GENERATOR_CACHE[name] = pipeline("text2text-generation", model=name)
    return _GENERATOR_CACHE[name]

def query_rag(question, model, index, text_data):
    """Answer *question* from the indexed texts.

    Retrieves the nearest chunks (k clamped to the index size; FAISS pads
    missing results with id -1, which would have aliased text_data[-1] in
    the original), then generates an answer with FLAN-T5.

    Returns:
        (answer, retrieved): generated answer string and the context chunks
        it was conditioned on.
    """
    q_embedding = model.encode([question], convert_to_numpy=True)

    retrieved = []
    k = min(3, index.ntotal)
    if k > 0:
        _, ids = index.search(q_embedding, k=k)
        retrieved = [text_data[i] for i in ids[0] if i >= 0]

    generator = _get_generator()
    context = " ".join(retrieved)
    prompt = f"Answer the question using the context:\nContext: {context}\nQuestion: {question}"
    # Greedy decoding: do_sample=True made answers vary run to run for the
    # same question and context.
    answer = generator(prompt, max_length=150, do_sample=False)[0]["generated_text"]

    return answer, retrieved
73
+
74
+
75
# -------------------------
# 4️⃣ Streamlit UI
# -------------------------
st.title("🚀 Web Scraping + RAG with Selenium")

url = st.text_input("Enter a website URL:", "https://quotes.toscrape.com/")
if st.button("Scrape Website"):
    with st.spinner("Scraping website..."):
        scraped_text = scrape_with_selenium(url)

    if scraped_text:
        st.success(f"Scraped {len(scraped_text)} paragraphs!")
        st.session_state["scraped_text"] = scraped_text
        # Invalidate any index built from a previously scraped page.
        st.session_state.pop("rag_index", None)
    else:
        # The original reported success and cached an empty list, which
        # later crashed build_faiss_index.
        st.warning("No paragraph text found on that page.")

if "scraped_text" in st.session_state:
    question = st.text_input("Ask a question based on scraped data:")
    # Skip generation for blank questions instead of prompting the model
    # with an empty string.
    if st.button("Get Answer") and question.strip():
        # Build the embedder + FAISS index once per scraped page rather
        # than on every question (model loading dominated latency).
        if "rag_index" not in st.session_state:
            st.session_state["rag_index"] = build_faiss_index(
                st.session_state["scraped_text"]
            )
        model, index, text_data = st.session_state["rag_index"]

        answer, retrieved = query_rag(question, model, index, text_data)

        st.subheader("🔍 Retrieved Context")
        for chunk in retrieved:
            st.write("-", chunk)

        st.subheader("💡 Answer")
        st.write(answer)