# Source: Hugging Face Space file "app.py" — commit f929333 ("Update app.py", verified), author: muddasser.
import html
import logging
import os
import re

import streamlit as st
import torch
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from playwright.sync_api import sync_playwright
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Logging goes to a file when possible. The target directory may not exist (or
# may be read-only) on a given machine, in which case basicConfig() would raise
# at import time and take the whole app down, so create it on demand and fall
# back to the default stderr handler if that fails.
LOG_FILE = '/app/cache/app.log'
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

try:
    os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
    logging.basicConfig(
        filename=LOG_FILE,
        level=logging.DEBUG,
        format=LOG_FORMAT,
    )
except OSError:
    # Covers FileNotFoundError / PermissionError from makedirs or from opening
    # the log file; logging still works, just not to disk.
    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)

MODEL_NAME = "google/flan-t5-large"
MAX_INPUT_LEN = 512  # FLAN-T5-large context window
# Browser-tab title/icon and overall layout. The title separator and the icon
# were stored as mis-decoded UTF-8 (mojibake); they are restored to the
# intended middle dot and spider-web emoji.
st.set_page_config(
    page_title="RAG · FLAN-T5",
    page_icon="🕸️",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Global theme: inject a single <style> block that loads the two web fonts,
# defines the colour/font CSS variables, restyles Streamlit's built-in widgets
# (inputs, buttons, chat messages, alerts) and declares the custom classes
# (.content-box, .meta-pill, .section-label, .qa-banner, .model-badge,
# .page-header, ...) used by the raw HTML snippets rendered further down.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Instrument+Serif:ital@0;1&family=JetBrains+Mono:wght@300;400;500&display=swap');
:root {
--bg: #f5f0e8;
--surface: #ede8df;
--border: #d4cec4;
--text: #1a1814;
--muted: #7a756c;
--accent: #c13a1e;
--mono: 'JetBrains Mono', monospace;
--serif: 'Instrument Serif', serif;
}
html, body, [class*="css"] {
font-family: var(--mono);
background: var(--bg);
color: var(--text);
}
.stApp { background: var(--bg); }
#MainMenu, footer, header { visibility: hidden; }
[data-testid="stDecoration"] { display: none; }
[data-testid="stSidebar"] {
background: var(--surface);
border-right: 1px solid var(--border);
}
.stTextInput > div > div > input,
.stTextArea textarea {
background: #fff !important;
border: 1px solid var(--border) !important;
border-radius: 3px !important;
color: var(--text) !important;
font-family: var(--mono) !important;
font-size: 0.82rem !important;
}
.stTextInput > div > div > input:focus,
.stTextArea textarea:focus {
border-color: var(--accent) !important;
box-shadow: 0 0 0 2px rgba(193,58,30,0.12) !important;
}
.stButton > button {
background: var(--accent) !important;
color: #fff !important;
border: none !important;
border-radius: 3px !important;
font-family: var(--mono) !important;
font-size: 0.78rem !important;
font-weight: 500 !important;
letter-spacing: 0.06em !important;
text-transform: uppercase !important;
padding: 0.45rem 1.2rem !important;
transition: all 0.15s !important;
}
.stButton > button:hover {
background: #a83018 !important;
transform: translateY(-1px);
box-shadow: 0 3px 12px rgba(193,58,30,0.25) !important;
}
[data-testid="stChatMessage"] {
background: #fff !important;
border: 1px solid var(--border) !important;
border-radius: 4px !important;
margin-bottom: 0.4rem !important;
}
[data-testid="stChatInput"] textarea {
background: #fff !important;
font-family: var(--mono) !important;
font-size: 0.82rem !important;
}
hr { border-color: var(--border) !important; }
.content-box {
background: #fff;
border: 1px solid var(--border);
border-radius: 4px;
padding: 1.2rem 1.4rem;
font-family: var(--mono);
font-size: 0.78rem;
line-height: 1.7;
color: var(--text);
max-height: 340px;
overflow-y: auto;
white-space: pre-wrap;
word-break: break-word;
}
.content-box::-webkit-scrollbar { width: 6px; }
.content-box::-webkit-scrollbar-track { background: var(--surface); }
.content-box::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
.meta-pill {
display: inline-flex;
align-items: center;
gap: 6px;
background: var(--surface);
border: 1px solid var(--border);
border-radius: 20px;
padding: 3px 10px;
font-size: 0.72rem;
color: var(--muted);
margin-bottom: 0.6rem;
}
.meta-dot { width:6px; height:6px; border-radius:50%; background:#4caf50; }
.section-label {
font-size: 0.68rem;
letter-spacing: 0.12em;
text-transform: uppercase;
color: var(--muted);
margin-bottom: 0.5rem;
display: flex;
align-items: center;
gap: 8px;
}
.section-label::after {
content: '';
flex: 1;
height: 1px;
background: var(--border);
}
.qa-banner {
display: flex;
align-items: center;
gap: 12px;
margin: 1.8rem 0 1rem 0;
}
.qa-banner-line { flex:1; height:1px; background:var(--border); }
.qa-banner-label {
font-family: var(--serif);
font-style: italic;
font-size: 1.05rem;
color: var(--accent);
white-space: nowrap;
}
.model-badge {
display: inline-flex;
align-items: center;
gap: 5px;
font-size: 0.7rem;
color: var(--muted);
padding: 2px 8px;
border: 1px solid var(--border);
border-radius: 3px;
}
.model-dot { width:6px; height:6px; border-radius:50%; }
.page-header {
padding: 1.5rem 0 1rem 0;
border-bottom: 2px solid var(--text);
margin-bottom: 1.5rem;
display: flex;
align-items: baseline;
justify-content: space-between;
flex-wrap: wrap;
gap: 0.5rem;
}
.page-title {
font-family: var(--serif);
font-size: 2rem;
color: var(--text);
margin: 0;
line-height: 1;
}
.page-sub {
font-size: 0.72rem;
color: var(--muted);
letter-spacing: 0.08em;
text-transform: uppercase;
}
[data-testid="stAlert"] {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: 4px !important;
font-family: var(--mono) !important;
font-size: 0.82rem !important;
}
</style>
""", unsafe_allow_html=True)
# ── Session state ──────────────────────────────────────────────────────────────
# Seed each per-session key with its initial value; keys that already exist
# in st.session_state are left untouched.
_SESSION_DEFAULTS = {
    'scraped_content': '',
    'vector_store': None,
    'chat_history': [],
    'scraped_title': None,
    'scraped_url': None,
    'char_count': 0,
}
for key, default in _SESSION_DEFAULTS.items():
    if key in st.session_state:
        continue
    st.session_state[key] = default
# ── Utilities ──────────────────────────────────────────────────────────────────
def clean_text(text):
    """Normalise whitespace in scraped text.

    Runs of spaces/tabs collapse to one space, trailing spaces are removed from
    every line, line endings are normalised to ``\\n``, and three or more
    consecutive newlines collapse to a single blank line. The result is
    stripped at both ends.

    Args:
        text: Raw text, e.g. the ``inner_text()`` of a page body. ``None`` or
            an empty string yields ``""``.

    Returns:
        The cleaned string.
    """
    if not text:
        return ""
    # Treat Windows / old-Mac line endings like plain newlines.
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    text = re.sub(r'[ \t]+', ' ', text)
    # After the collapse above a line ends in at most one space. Dropping it
    # makes whitespace-only lines truly empty, so the blank-line collapse
    # below is not defeated by lines such as " " or "\t".
    text = text.replace(' \n', '\n')
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()
def is_valid_url(url):
    """Return True when *url* is an absolute http(s) URL with a host.

    The URL is parsed structurally rather than matched against a character
    whitelist, so query strings, fragments, percent-escapes and characters
    such as parentheses in the path are accepted.

    Args:
        url: Candidate URL string.

    Returns:
        bool: ``True`` for ``http://`` / ``https://`` URLs that contain a host
        name and (if present) a numeric port; ``False`` otherwise, including
        for non-strings, empty strings and strings containing whitespace.
    """
    # Local import keeps this function self-contained.
    from urllib.parse import urlsplit

    if not isinstance(url, str) or not url:
        return False
    # Reject embedded/leading/trailing whitespace explicitly; urlsplit would
    # silently tolerate or strip some of it.
    if any(ch.isspace() for ch in url):
        return False
    try:
        parts = urlsplit(url)
        parts.port  # raises ValueError for a non-numeric or out-of-range port
    except ValueError:
        return False
    return parts.scheme in ("http", "https") and bool(parts.hostname)
# ── Model ──────────────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def _load_model_cached():
    """Load the tokenizer and seq2seq model for MODEL_NAME onto the CPU.

    Raises on failure instead of returning a sentinel: st.cache_resource only
    stores return values, so a raised exception leaves the cache empty and
    the next call retries the load.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
    )
    model = model.to("cpu")
    model.eval()  # inference only
    logging.info(f"Loaded {MODEL_NAME}")
    return tokenizer, model


def load_model():
    """Return ``(tokenizer, model)`` for MODEL_NAME, or ``(None, None)`` on failure.

    The successful result is cached by ``_load_model_cached``. Failures are
    logged and reported as ``(None, None)`` but are NOT cached, so a transient
    error (e.g. an interrupted download) does not persist until the process
    restarts.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # Top-level boundary: callers check for None rather than catching.
        logging.exception(f"Model load error: {e}")
        return None, None
# ── Scraper ────────────────────────────────────────────────────────────────────
def _collect_element_text(page, selector, max_len):
    """Return the stripped inner text of elements matching *selector* whose length is in (3, max_len)."""
    collected = []
    for element in page.query_selector_all(selector):
        try:
            text = element.inner_text().strip()
        except Exception:
            # Best effort: skip an element whose text cannot be read and
            # carry on with the rest.
            continue
        if 3 < len(text) < max_len:
            collected.append(text)
    return collected


def scrape_website(url):
    """Render *url* in headless Chromium and extract its readable text.

    Text is gathered from ``<li>`` elements first, then from headings,
    paragraphs and table cells, whitespace-normalised and de-duplicated in
    first-seen order. If that yields fewer than 200 characters, the cleaned
    text of ``<body>`` is used instead.

    Args:
        url: Absolute URL to load.

    Returns:
        dict with keys ``"title"``, ``"content"`` and ``"url"``, or ``None``
        when scraping fails (the error is logged and shown via ``st.error``).
    """
    # Local import: only the navigation timeout should be tolerated below.
    from playwright.sync_api import TimeoutError as PlaywrightTimeoutError

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
        try:
            page = browser.new_page()
            # domcontentloaded avoids timeouts on ad-heavy sites
            try:
                page.goto(url, wait_until="domcontentloaded", timeout=30000)
            except PlaywrightTimeoutError:
                # Content may already be present even though the load event
                # never fired. Any other navigation error (DNS failure,
                # refused connection, ...) propagates to the handler below
                # instead of being scraped as an empty/error page.
                logging.warning(f"Navigation timed out for {url}; scraping what has loaded")
            page.wait_for_timeout(3000)  # allow JS 3s to render
            title = page.title()

            # Strategy 1: <li> items — great for price/listing pages
            lines = _collect_element_text(page, "li", 300)
            # Strategy 2: headings, paragraphs, table cells
            for tag in ("h1", "h2", "h3", "h4", "p", "td"):
                lines.extend(_collect_element_text(page, tag, 500))

            # Deduplicate preserving order (dict keys keep insertion order)
            unique_lines = list(dict.fromkeys(
                re.sub(r'\s+', ' ', line).strip() for line in lines
            ))
            content = "\n".join(unique_lines)

            # Fallback to body if nothing found
            if len(content) < 200:
                body = page.query_selector("body")
                content = clean_text(body.inner_text()) if body else content

            logging.info(f"Scraped {len(content)} chars from {url}")
            return {"title": title, "content": content, "url": url}
        except Exception as e:
            logging.exception(f"Scrape error: {e}")
            st.error(f"Scraping failed: {e}")
            return None
        finally:
            browser.close()
# ── Vector store ───────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def _load_embeddings():
    """Build the sentence-embedding model once and share it across indexes."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'}
    )


# max_entries bounds memory: without it every page ever scraped keeps its
# index alive for the life of the process. The limit of 8 is a judgement call.
@st.cache_resource(max_entries=8, show_spinner=False)
def _build_vector_store(text):
    """Split *text* into chunks and index them with FAISS; raises on failure.

    Raising (rather than returning None) keeps failures out of the cache, so
    the same text can be retried later.
    """
    # Small chunks so the single best one fits cleanly in 512 tokens
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=30)
    docs = [Document(page_content=c) for c in splitter.split_text(text)]
    if not docs:
        raise ValueError("no text chunks to index")
    return FAISS.from_documents(docs, _load_embeddings())


def create_vector_store(text):
    """Return a FAISS vector store for *text*, or ``None`` if indexing fails.

    Args:
        text: Scraped page text to index.

    Returns:
        The FAISS store (cached per distinct text), or ``None`` after showing
        an ``st.error`` when the text is empty or indexing raises.
    """
    if not text or not text.strip():
        st.error("Indexing failed: no text to index")
        return None
    try:
        return _build_vector_store(text)
    except Exception as e:
        logging.exception(f"Indexing error: {e}")
        st.error(f"Indexing failed: {e}")
        return None
# ── Answer ─────────────────────────────────────────────────────────────────────
def answer_question(question):
    """Answer *question* from the indexed page using retrieval + FLAN-T5.

    Retrieves the single most similar chunk from the session's vector store,
    wraps it in an instruction prompt, and decodes a beam-search generation.

    Args:
        question: The user's question text.

    Returns:
        str: The generated answer, or a human-readable message when nothing is
        indexed, the model is unavailable, or generation raises.
    """
    if not st.session_state.vector_store:
        return "No content indexed yet."

    tokenizer, model = load_model()
    if tokenizer is None:
        return "Model failed to load. Check logs."

    generation_options = {
        "max_new_tokens": 200,
        "num_beams": 4,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
    }
    try:
        # k=1 — single most relevant chunk keeps prompt tight within 512 tokens
        hits = st.session_state.vector_store.similarity_search(question, k=1)
        context = hits[0].page_content
        prompt = (
            "Answer the question using only the context provided. "
            "If the answer is not in the context, say \"I don't know\".\n\n"
            f"Context: {context}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_LEN,
        )
        with torch.no_grad():
            generated = model.generate(**encoded, **generation_options)
        answer_text = tokenizer.decode(generated[0], skip_special_tokens=True)
        return answer_text.strip()
    except Exception as e:
        logging.error(f"Inference error: {e}")
        return f"Error generating answer: {e}"
# ── Preload model ──────────────────────────────────────────────────────────────
# Load the model (or fetch it from the resource cache) under a spinner;
# model_ok records whether a tokenizer actually came back.
with st.spinner(f"Loading {MODEL_NAME}…"):
    _preloaded = load_model()
_tok, _mod = _preloaded
model_ok = _tok is not None
# ── Sidebar ────────────────────────────────────────────────────────────────────
# Static model facts plus the load status. The context-window figure is
# derived from MAX_INPUT_LEN so it cannot drift from the value used at
# tokenisation time, and the check mark is restored from mojibake.
with st.sidebar:
    st.markdown("**Model**")
    st.markdown(f"`{MODEL_NAME}`")
    st.markdown("**Context window**")
    st.markdown(f"`{MAX_INPUT_LEN} tokens`")
    st.markdown("**Architecture**")
    st.markdown("`Encoder-Decoder`")
    st.markdown("**Status**")
    if model_ok:
        st.success("Model loaded ✓")
    else:
        st.error("Model failed to load")
# ── Page header ────────────────────────────────────────────────────────────────
# Title bar with a status badge: green dot when the model loaded, red when not.
# The arrows and middle dots were mis-decoded UTF-8 and are restored here.
# The HTML stays flush-left inside the string on purpose: Markdown would treat
# indented lines as a code block.
dot_color = "#4caf50" if model_ok else "#e53935"
dot_label = "Model ready" if model_ok else "Model error"
st.markdown(f"""
<div class="page-header">
<div>
<p class="page-title">Web RAG</p>
<span class="page-sub">Scrape → Index → Ask</span>
</div>
<div class="model-badge">
<div class="model-dot" style="background:{dot_color};"></div>
{dot_label} &nbsp;·&nbsp; FLAN-T5-large
</div>
</div>
""", unsafe_allow_html=True)
# ── URL bar ────────────────────────────────────────────────────────────────────
# URL input and Scrape button side by side. A successful scrape replaces the
# session's page data, resets the chat and rebuilds the vector index.
col_url, col_btn = st.columns([5, 1])
with col_url:
    url_input = st.text_input(
        "url", label_visibility="collapsed",
        placeholder="https://en.wikipedia.org/wiki/Retrieval-augmented_generation"
    )
with col_btn:
    scrape_clicked = st.button("Scrape", use_container_width=True)

if scrape_clicked:
    # Pasted URLs often carry stray surrounding whitespace.
    target_url = (url_input or "").strip()
    if not target_url or not is_valid_url(target_url):
        # http:// is accepted by the validator too, so say so.
        st.warning("Enter a valid URL starting with http:// or https://")
    else:
        with st.spinner("Scraping…"):
            result = scrape_website(target_url)
        if result:
            if not result['content'].strip():
                # Nothing to index; keep whatever was scraped previously.
                st.warning("No readable text was found on that page.")
            else:
                st.session_state.scraped_content = result['content']
                st.session_state.scraped_title = result['title']
                st.session_state.scraped_url = result['url']
                st.session_state.char_count = len(result['content'])
                st.session_state.chat_history = []
                with st.spinner("Building FAISS index…"):
                    st.session_state.vector_store = create_vector_store(result['content'])
                # Rerun only on success: a rerun after a failed index build
                # would immediately wipe the error message just shown.
                if st.session_state.vector_store is not None:
                    st.rerun()
# ── Main content area ──────────────────────────────────────────────────────────
# With scraped content: show a metadata pill, a preview of the text and the
# chat interface. Without it: show an empty-state placeholder.
# The HTML snippets stay flush-left inside their strings on purpose: Markdown
# would treat indented lines as a code block.
if st.session_state.scraped_content:
    # Title, URL and page text come from an arbitrary web page and are rendered
    # with unsafe_allow_html=True, so they must be HTML-escaped first;
    # otherwise the scraped page could inject markup or script into this app.
    title_display = html.escape(st.session_state.scraped_title or "")
    url_display = html.escape(st.session_state.scraped_url or "")
    st.markdown(f"""
<div class="meta-pill">
<div class="meta-dot"></div>
<span>{title_display}</span>
&nbsp;·&nbsp;
<span>{st.session_state.char_count:,} chars</span>
&nbsp;·&nbsp;
<span style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;">{url_display}</span>
</div>
""", unsafe_allow_html=True)

    st.markdown('<div class="section-label">Scraped content</div>', unsafe_allow_html=True)
    # Truncate the raw text first, then escape, so the 4000-character limit
    # applies to page characters rather than to HTML entities.
    full_text = st.session_state.scraped_content
    preview = html.escape(full_text[:4000])
    if len(full_text) > 4000:
        preview += "\n\n… (truncated for display)"
    st.markdown(f'<div class="content-box">{preview}</div>', unsafe_allow_html=True)

    st.markdown("""
<div class="qa-banner">
<div class="qa-banner-line"></div>
<div class="qa-banner-label">Ask a question</div>
<div class="qa-banner-line"></div>
</div>
""", unsafe_allow_html=True)

    # Replay the conversation so far (plain Markdown, no raw HTML).
    for msg in st.session_state.chat_history:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input("Ask anything about the content above…"):
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("FLAN-T5 is thinking…"):
                answer = answer_question(prompt)
            st.markdown(answer)
        st.session_state.chat_history.append({"role": "assistant", "content": answer})

    if st.session_state.chat_history:
        if st.button("Clear chat"):
            st.session_state.chat_history = []
            st.rerun()
else:
    st.markdown("""
<div style="
text-align:center;
padding: 4rem 2rem;
color: #7a756c;
font-size: 0.82rem;
border: 1px dashed #d4cec4;
border-radius: 4px;
margin-top: 1rem;
">
<div style="font-family:'Instrument Serif',serif; font-style:italic;
font-size:1.4rem; margin-bottom:0.5rem; color:#1a1814;">
Nothing scraped yet
</div>
Enter a URL above and hit <strong>Scrape</strong> to get started.
</div>
""", unsafe_allow_html=True)