pradeep4321 committed on
Commit
2f54431
·
verified ·
1 Parent(s): ec67f77

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +75 -54
src/streamlit_app.py CHANGED
@@ -1,5 +1,5 @@
1
  # =========================================================
2
- # 🌐 WEBSITE RAG + IMAGE UNDERSTANDING (HF SPACES)
3
  # =========================================================
4
 
5
  import streamlit as st
@@ -10,6 +10,7 @@ import faiss
10
  import torch
11
  from PIL import Image
12
  from io import BytesIO
 
13
 
14
  from sentence_transformers import SentenceTransformer
15
  from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
@@ -20,16 +21,18 @@ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
20
  st.set_page_config(page_title="🌐 Website QA System", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODELS
24
  # ==============================
25
  @st.cache_resource
26
  def load_models():
27
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
28
 
 
29
  qa_pipeline = pipeline(
30
- "text2text-generation",
31
  model="google/flan-t5-base",
32
- max_length=256
 
33
  )
34
 
35
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -50,47 +53,50 @@ if "documents" not in st.session_state:
50
  if "index" not in st.session_state:
51
  st.session_state.index = None
52
 
 
 
 
53
  # ==============================
54
  # CRAWL WEBSITE
55
  # ==============================
56
  def crawl_website(url):
57
  try:
58
- res = requests.get(url)
59
  soup = BeautifulSoup(res.text, "html.parser")
60
 
61
- links = []
 
62
  for a in soup.find_all("a", href=True):
63
- link = a["href"]
64
  if link.startswith("http"):
65
- links.append(link)
66
 
67
- return list(set(links))[:20] # limit
68
- except:
 
69
  return []
70
 
71
  # ==============================
72
- # EXTRACT CONTENT
73
  # ==============================
74
  def extract_content(url):
75
  try:
76
- res = requests.get(url)
77
  soup = BeautifulSoup(res.text, "html.parser")
78
 
79
  # TEXT
80
- paragraphs = [p.get_text() for p in soup.find_all("p")]
81
  text = " ".join(paragraphs)
82
 
83
  # IMAGES → CAPTION
84
  image_texts = []
85
  images = soup.find_all("img")
86
 
87
- for img in images[:5]: # limit images
88
  try:
89
- img_url = img.get("src")
90
- if not img_url.startswith("http"):
91
- continue
92
 
93
- img_res = requests.get(img_url)
94
  image = Image.open(BytesIO(img_res.content)).convert("RGB")
95
 
96
  inputs = processor(image, return_tensors="pt")
@@ -102,8 +108,7 @@ def extract_content(url):
102
  except:
103
  continue
104
 
105
- full_text = text + " " + " ".join(image_texts)
106
- return full_text
107
 
108
  except:
109
  return ""
@@ -113,13 +118,10 @@ def extract_content(url):
113
  # ==============================
114
  def chunk_text(text, size=300):
115
  words = text.split()
116
- chunks = []
117
- for i in range(0, len(words), size):
118
- chunks.append(" ".join(words[i:i+size]))
119
- return chunks
120
 
121
  # ==============================
122
- # BUILD INDEX
123
  # ==============================
124
  def build_index(texts):
125
  embeddings = embed_model.encode(texts)
@@ -128,13 +130,24 @@ def build_index(texts):
128
  index = faiss.IndexFlatL2(dim)
129
  index.add(np.array(embeddings))
130
 
131
- return index, embeddings
 
 
 
 
 
 
 
 
132
 
133
  # ==============================
134
  # UI
135
  # ==============================
136
- st.title("🌐 Website QA with Images")
137
 
 
 
 
138
  url = st.text_input("🔗 Enter Website URL")
139
 
140
  if st.button("Crawl Website"):
@@ -144,15 +157,16 @@ if st.button("Crawl Website"):
144
  st.session_state.links = links
145
  st.success(f"Found {len(links)} pages")
146
  else:
147
- st.error("No links found")
148
 
149
  # ==============================
150
- # PAGE SELECTION
151
  # ==============================
152
- if "links" in st.session_state:
153
- st.subheader("Select Pages to Train")
 
 
154
 
155
- selected_links = []
156
  for link in st.session_state.links:
157
  if st.checkbox(link):
158
  selected_links.append(link)
@@ -167,35 +181,37 @@ if "links" in st.session_state:
167
  all_chunks.extend(chunks)
168
 
169
  if all_chunks:
170
- index, embeddings = build_index(all_chunks)
171
-
172
  st.session_state.documents = all_chunks
173
- st.session_state.index = index
174
 
175
- st.success("Training completed!")
 
 
176
 
177
  # ==============================
178
- # ADD MORE PAGES
179
  # ==============================
180
- if "links" in st.session_state:
181
- st.subheader("➕ Add More Pages")
182
-
183
- new_url = st.text_input("Add another URL")
184
 
185
- if st.button("Add & Train"):
186
- content = extract_content(new_url)
187
- chunks = chunk_text(content)
188
 
189
- if chunks:
190
- new_embeddings = embed_model.encode(chunks)
 
191
 
192
- st.session_state.index.add(np.array(new_embeddings))
193
- st.session_state.documents.extend(chunks)
 
 
 
 
194
 
195
- st.success("Added new page!")
 
 
196
 
197
  # ==============================
198
- # ASK QUESTIONS
199
  # ==============================
200
  st.subheader("💬 Ask Questions")
201
 
@@ -203,7 +219,7 @@ query = st.text_input("Ask something from the website")
203
 
204
  if st.button("Get Answer"):
205
  if st.session_state.index is None:
206
- st.warning("Please train pages first")
207
  else:
208
  q_embed = embed_model.encode([query])
209
 
@@ -212,16 +228,21 @@ if st.button("Get Answer"):
212
  context = " ".join([st.session_state.documents[i] for i in I[0]])
213
 
214
  prompt = f"""
215
- Answer the question based on the context.
216
 
217
  Context:
218
  {context}
219
 
220
  Question:
221
  {query}
 
 
222
  """
223
 
224
- answer = qa_pipeline(prompt)[0]["generated_text"]
 
 
 
225
 
226
  st.write("### ✅ Answer")
227
- st.write(answer)
 
1
  # =========================================================
2
+ # 🌐 WEBSITE RAG + IMAGE QA (HF SPACES FIXED VERSION)
3
  # =========================================================
4
 
5
  import streamlit as st
 
10
  import torch
11
  from PIL import Image
12
  from io import BytesIO
13
+ from urllib.parse import urljoin
14
 
15
  from sentence_transformers import SentenceTransformer
16
  from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
 
21
  st.set_page_config(page_title="🌐 Website QA System", layout="wide")
22
 
23
  # ==============================
24
+ # LOAD MODELS (FIXED)
25
  # ==============================
26
  @st.cache_resource
27
  def load_models():
28
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
29
 
30
+ # ✅ FIX: use text-generation instead of text2text-generation
31
  qa_pipeline = pipeline(
32
+ "text-generation",
33
  model="google/flan-t5-base",
34
+ max_length=256,
35
+ do_sample=False
36
  )
37
 
38
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 
53
  if "index" not in st.session_state:
54
  st.session_state.index = None
55
 
56
+ if "links" not in st.session_state:
57
+ st.session_state.links = []
58
+
59
  # ==============================
60
  # CRAWL WEBSITE
61
  # ==============================
62
  def crawl_website(url):
63
  try:
64
+ res = requests.get(url, timeout=10)
65
  soup = BeautifulSoup(res.text, "html.parser")
66
 
67
+ links = set()
68
+
69
  for a in soup.find_all("a", href=True):
70
+ link = urljoin(url, a["href"]) # ✅ FIX relative links
71
  if link.startswith("http"):
72
+ links.add(link)
73
 
74
+ return list(links)[:20]
75
+
76
+ except Exception as e:
77
  return []
78
 
79
  # ==============================
80
+ # EXTRACT CONTENT (TEXT + IMAGES)
81
  # ==============================
82
  def extract_content(url):
83
  try:
84
+ res = requests.get(url, timeout=10)
85
  soup = BeautifulSoup(res.text, "html.parser")
86
 
87
  # TEXT
88
+ paragraphs = [p.get_text().strip() for p in soup.find_all("p")]
89
  text = " ".join(paragraphs)
90
 
91
  # IMAGES → CAPTION
92
  image_texts = []
93
  images = soup.find_all("img")
94
 
95
+ for img in images[:5]: # limit
96
  try:
97
+ img_url = urljoin(url, img.get("src"))
 
 
98
 
99
+ img_res = requests.get(img_url, timeout=5)
100
  image = Image.open(BytesIO(img_res.content)).convert("RGB")
101
 
102
  inputs = processor(image, return_tensors="pt")
 
108
  except:
109
  continue
110
 
111
+ return text + " " + " ".join(image_texts)
 
112
 
113
  except:
114
  return ""
 
118
  # ==============================
119
  def chunk_text(text, size=300):
120
  words = text.split()
121
+ return [" ".join(words[i:i+size]) for i in range(0, len(words), size)]
 
 
 
122
 
123
  # ==============================
124
+ # BUILD FAISS INDEX
125
  # ==============================
126
  def build_index(texts):
127
  embeddings = embed_model.encode(texts)
 
130
  index = faiss.IndexFlatL2(dim)
131
  index.add(np.array(embeddings))
132
 
133
+ return index
134
+
135
+ # ==============================
136
+ # ADD TO EXISTING INDEX
137
+ # ==============================
138
+ def add_to_index(new_chunks):
139
+ new_embeddings = embed_model.encode(new_chunks)
140
+ st.session_state.index.add(np.array(new_embeddings))
141
+ st.session_state.documents.extend(new_chunks)
142
 
143
  # ==============================
144
  # UI
145
  # ==============================
146
+ st.title("🌐 Website QA with Images (Fixed)")
147
 
148
+ # ==============================
149
+ # STEP 1: URL INPUT
150
+ # ==============================
151
  url = st.text_input("🔗 Enter Website URL")
152
 
153
  if st.button("Crawl Website"):
 
157
  st.session_state.links = links
158
  st.success(f"Found {len(links)} pages")
159
  else:
160
+ st.error("No links found or invalid URL")
161
 
162
  # ==============================
163
+ # STEP 2: PAGE SELECTION
164
  # ==============================
165
+ selected_links = []
166
+
167
+ if st.session_state.links:
168
+ st.subheader("📄 Select Pages to Train")
169
 
 
170
  for link in st.session_state.links:
171
  if st.checkbox(link):
172
  selected_links.append(link)
 
181
  all_chunks.extend(chunks)
182
 
183
  if all_chunks:
184
+ st.session_state.index = build_index(all_chunks)
 
185
  st.session_state.documents = all_chunks
 
186
 
187
+ st.success("✅ Training completed!")
188
+ else:
189
+ st.warning("No content extracted")
190
 
191
  # ==============================
192
+ # STEP 3: ADD MORE PAGES
193
  # ==============================
194
+ st.subheader("➕ Add More Pages")
 
 
 
195
 
196
+ new_url = st.text_input("Enter another page URL")
 
 
197
 
198
+ if st.button("Add & Train"):
199
+ content = extract_content(new_url)
200
+ chunks = chunk_text(content)
201
 
202
+ if chunks:
203
+ if st.session_state.index is None:
204
+ st.session_state.index = build_index(chunks)
205
+ st.session_state.documents = chunks
206
+ else:
207
+ add_to_index(chunks)
208
 
209
+ st.success("✅ Page added successfully!")
210
+ else:
211
+ st.error("Failed to extract content")
212
 
213
  # ==============================
214
+ # STEP 4: ASK QUESTIONS
215
  # ==============================
216
  st.subheader("💬 Ask Questions")
217
 
 
219
 
220
  if st.button("Get Answer"):
221
  if st.session_state.index is None:
222
+ st.warning("⚠️ Please train pages first")
223
  else:
224
  q_embed = embed_model.encode([query])
225
 
 
228
  context = " ".join([st.session_state.documents[i] for i in I[0]])
229
 
230
  prompt = f"""
231
+ Answer based only on the context.
232
 
233
  Context:
234
  {context}
235
 
236
  Question:
237
  {query}
238
+
239
+ Answer:
240
  """
241
 
242
+ response = qa_pipeline(prompt)[0]["generated_text"]
243
+
244
+ # ✅ CLEAN OUTPUT
245
+ answer = response.replace(prompt, "").strip()
246
 
247
  st.write("### ✅ Answer")
248
+ st.write(answer if answer else "No relevant answer found")