pradeep4321 commited on
Commit
ec67f77
Β·
verified Β·
1 Parent(s): fd59730

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +225 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,227 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # =========================================================
2
+ # 🌐 WEBSITE RAG + IMAGE UNDERSTANDING (HF SPACES)
3
+ # =========================================================
4
+
5
  import streamlit as st
6
+ import requests
7
+ from bs4 import BeautifulSoup
8
+ import numpy as np
9
+ import faiss
10
+ import torch
11
+ from PIL import Image
12
+ from io import BytesIO
13
+
14
+ from sentence_transformers import SentenceTransformer
15
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
16
+
17
+ # ==============================
18
+ # PAGE CONFIG
19
+ # ==============================
20
+ st.set_page_config(page_title="🌐 Website QA System", layout="wide")
21
+
22
+ # ==============================
23
+ # LOAD MODELS
24
+ # ==============================
25
+ @st.cache_resource
26
+ def load_models():
27
+ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
28
+
29
+ qa_pipeline = pipeline(
30
+ "text2text-generation",
31
+ model="google/flan-t5-base",
32
+ max_length=256
33
+ )
34
+
35
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
36
+ image_model = BlipForConditionalGeneration.from_pretrained(
37
+ "Salesforce/blip-image-captioning-base"
38
+ )
39
+
40
+ return embed_model, qa_pipeline, processor, image_model
41
+
42
+ embed_model, qa_pipeline, processor, image_model = load_models()
43
+
44
+ # ==============================
45
+ # SESSION STATE
46
+ # ==============================
47
+ if "documents" not in st.session_state:
48
+ st.session_state.documents = []
49
+
50
+ if "index" not in st.session_state:
51
+ st.session_state.index = None
52
+
53
+ # ==============================
54
+ # CRAWL WEBSITE
55
+ # ==============================
56
+ def crawl_website(url):
57
+ try:
58
+ res = requests.get(url)
59
+ soup = BeautifulSoup(res.text, "html.parser")
60
+
61
+ links = []
62
+ for a in soup.find_all("a", href=True):
63
+ link = a["href"]
64
+ if link.startswith("http"):
65
+ links.append(link)
66
+
67
+ return list(set(links))[:20] # limit
68
+ except:
69
+ return []
70
+
71
+ # ==============================
72
+ # EXTRACT CONTENT
73
+ # ==============================
74
+ def extract_content(url):
75
+ try:
76
+ res = requests.get(url)
77
+ soup = BeautifulSoup(res.text, "html.parser")
78
+
79
+ # TEXT
80
+ paragraphs = [p.get_text() for p in soup.find_all("p")]
81
+ text = " ".join(paragraphs)
82
+
83
+ # IMAGES β†’ CAPTION
84
+ image_texts = []
85
+ images = soup.find_all("img")
86
+
87
+ for img in images[:5]: # limit images
88
+ try:
89
+ img_url = img.get("src")
90
+ if not img_url.startswith("http"):
91
+ continue
92
+
93
+ img_res = requests.get(img_url)
94
+ image = Image.open(BytesIO(img_res.content)).convert("RGB")
95
+
96
+ inputs = processor(image, return_tensors="pt")
97
+ out = image_model.generate(**inputs)
98
+ caption = processor.decode(out[0], skip_special_tokens=True)
99
+
100
+ image_texts.append(caption)
101
+
102
+ except:
103
+ continue
104
+
105
+ full_text = text + " " + " ".join(image_texts)
106
+ return full_text
107
+
108
+ except:
109
+ return ""
110
+
111
+ # ==============================
112
+ # CHUNKING
113
+ # ==============================
114
+ def chunk_text(text, size=300):
115
+ words = text.split()
116
+ chunks = []
117
+ for i in range(0, len(words), size):
118
+ chunks.append(" ".join(words[i:i+size]))
119
+ return chunks
120
+
121
+ # ==============================
122
+ # BUILD INDEX
123
+ # ==============================
124
+ def build_index(texts):
125
+ embeddings = embed_model.encode(texts)
126
+ dim = embeddings.shape[1]
127
+
128
+ index = faiss.IndexFlatL2(dim)
129
+ index.add(np.array(embeddings))
130
+
131
+ return index, embeddings
132
+
133
+ # ==============================
134
+ # UI
135
+ # ==============================
136
+ st.title("🌐 Website QA with Images")
137
+
138
+ url = st.text_input("πŸ”— Enter Website URL")
139
+
140
+ if st.button("Crawl Website"):
141
+ links = crawl_website(url)
142
+
143
+ if links:
144
+ st.session_state.links = links
145
+ st.success(f"Found {len(links)} pages")
146
+ else:
147
+ st.error("No links found")
148
+
149
+ # ==============================
150
+ # PAGE SELECTION
151
+ # ==============================
152
+ if "links" in st.session_state:
153
+ st.subheader("Select Pages to Train")
154
+
155
+ selected_links = []
156
+ for link in st.session_state.links:
157
+ if st.checkbox(link):
158
+ selected_links.append(link)
159
+
160
+ if st.button("Train Selected Pages"):
161
+ all_chunks = []
162
+
163
+ with st.spinner("Processing pages..."):
164
+ for link in selected_links:
165
+ content = extract_content(link)
166
+ chunks = chunk_text(content)
167
+ all_chunks.extend(chunks)
168
+
169
+ if all_chunks:
170
+ index, embeddings = build_index(all_chunks)
171
+
172
+ st.session_state.documents = all_chunks
173
+ st.session_state.index = index
174
+
175
+ st.success("Training completed!")
176
+
177
+ # ==============================
178
+ # ADD MORE PAGES
179
+ # ==============================
180
+ if "links" in st.session_state:
181
+ st.subheader("βž• Add More Pages")
182
+
183
+ new_url = st.text_input("Add another URL")
184
+
185
+ if st.button("Add & Train"):
186
+ content = extract_content(new_url)
187
+ chunks = chunk_text(content)
188
+
189
+ if chunks:
190
+ new_embeddings = embed_model.encode(chunks)
191
+
192
+ st.session_state.index.add(np.array(new_embeddings))
193
+ st.session_state.documents.extend(chunks)
194
+
195
+ st.success("Added new page!")
196
+
197
+ # ==============================
198
+ # ASK QUESTIONS
199
+ # ==============================
200
+ st.subheader("πŸ’¬ Ask Questions")
201
+
202
+ query = st.text_input("Ask something from the website")
203
+
204
+ if st.button("Get Answer"):
205
+ if st.session_state.index is None:
206
+ st.warning("Please train pages first")
207
+ else:
208
+ q_embed = embed_model.encode([query])
209
+
210
+ D, I = st.session_state.index.search(np.array(q_embed), k=5)
211
+
212
+ context = " ".join([st.session_state.documents[i] for i in I[0]])
213
+
214
+ prompt = f"""
215
+ Answer the question based on the context.
216
+
217
+ Context:
218
+ {context}
219
+
220
+ Question:
221
+ {query}
222
+ """
223
+
224
+ answer = qa_pipeline(prompt)[0]["generated_text"]
225
 
226
+ st.write("### βœ… Answer")
227
+ st.write(answer)