import streamlit as st
import os
import re
import logging
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from playwright.sync_api import sync_playwright
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
# ── Logging / constants ────────────────────────────────────────────────────────
logging.basicConfig(
    filename='/app/cache/app.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

MODEL_NAME = "google/flan-t5-large"
MAX_INPUT_LEN = 512  # FLAN-T5-large context window

# Page chrome. NOTE(review): the original title/icon strings were mojibake
# (UTF-8 decoded as ISO-8859-7); restored to the intended "·" and a web emoji.
st.set_page_config(
    page_title="RAG · FLAN-T5",
    page_icon="🕸️",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# Placeholder for injected CSS; currently contributes nothing to the page.
st.markdown("""
""", unsafe_allow_html=True)

# ── Session state ──────────────────────────────────────────────────────────────
# Initialise every key exactly once so later code can read them unconditionally.
for key, default in [
    ('scraped_content', ''),
    ('vector_store', None),
    ('chat_history', []),
    ('scraped_title', None),
    ('scraped_url', None),
    ('char_count', 0),
]:
    if key not in st.session_state:
        st.session_state[key] = default
# ── Utilities ──────────────────────────────────────────────────────────────────
def clean_text(text):
    """Collapse runs of spaces/tabs, squeeze 3+ newlines to one blank line, strip ends."""
    collapsed = re.sub(r'[ \t]+', ' ', text)
    squeezed = re.sub(r'\n{3,}', '\n\n', collapsed)
    return squeezed.strip()
def is_valid_url(url):
    """Return True if *url* looks like a well-formed http(s) URL.

    Accepts host, optional port, and an optional path/query/fragment tail.
    The previous pattern rejected any URL containing '?' or '#', which made
    common links (e.g. Wikipedia pages with query params) fail validation;
    this is a backward-compatible superset of the old pattern.
    """
    return bool(re.match(r'^https?://[\w\-.]+(?::\d+)?(?:[/?#][^\s]*)?$', url))
# ── Model ──────────────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the FLAN-T5 tokenizer and model on CPU.

    Returns (tokenizer, model) on success, or (None, None) when loading
    fails (failure is logged; callers must check for None).
    """
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_NAME)
        net = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
        ).to("cpu")
        net.eval()  # inference only — disable dropout etc.
        logging.info(f"Loaded {MODEL_NAME}")
        return tok, net
    except Exception as e:
        logging.error(f"Model load error: {e}")
        return None, None
# ── Scraper ────────────────────────────────────────────────────────────────────
def scrape_website(url):
    """Scrape *url* with headless Chromium; return {'title', 'content', 'url'} or None.

    Collects <li> items first (useful for price/listing pages), then headings,
    paragraphs and table cells; deduplicates lines preserving order and falls
    back to the full <body> text when too little structured content is found.
    On failure the error is logged and shown via st.error, and None is returned.

    Fixes vs. original: the "Strategy 1" comment had been split across two
    lines, leaving a bare non-comment line (a SyntaxError); bare `except:`
    clauses narrowed to `except Exception:` so KeyboardInterrupt/SystemExit
    are not swallowed.
    """
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
        page = browser.new_page()
        try:
            # domcontentloaded avoids timeout on ad-heavy sites
            try:
                page.goto(url, wait_until="domcontentloaded", timeout=30000)
            except Exception:
                pass  # content may already be loaded even on timeout
            page.wait_for_timeout(3000)  # allow JS 3s to render
            title = page.title()
            # Strategy 1: <li> items — great for price/listing pages
            lines = []
            for li in page.query_selector_all("li"):
                try:
                    text = li.inner_text().strip()
                    if text and 3 < len(text) < 300:
                        lines.append(text)
                except Exception:  # node may detach mid-read — skip it
                    continue
            # Strategy 2: headings, paragraphs, table cells
            for tag in ["h1", "h2", "h3", "h4", "p", "td"]:
                for e in page.query_selector_all(tag):
                    try:
                        text = e.inner_text().strip()
                        if text and 3 < len(text) < 500:
                            lines.append(text)
                    except Exception:
                        continue
            # Deduplicate preserving order
            seen, unique_lines = set(), []
            for line in lines:
                n = re.sub(r'\s+', ' ', line).strip()
                if n not in seen:
                    seen.add(n)
                    unique_lines.append(n)
            content = "\n".join(unique_lines)
            # Fallback to body if nothing found
            if len(content) < 200:
                body = page.query_selector("body")
                content = clean_text(body.inner_text()) if body else content
            logging.info(f"Scraped {len(content)} chars from {url}")
            return {"title": title, "content": content, "url": url}
        except Exception as e:
            logging.error(f"Scrape error: {e}")
            st.error(f"Scraping failed: {e}")
            return None
        finally:
            browser.close()
# ── Vector store ───────────────────────────────────────────────────────────────
@st.cache_resource
def create_vector_store(text):
    """Chunk *text*, embed with MiniLM on CPU, and return a FAISS index (None on error)."""
    try:
        # Small chunks so the single best one fits cleanly in 512 tokens
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=300, chunk_overlap=30
        ).split_text(text)
        documents = [Document(page_content=chunk) for chunk in chunks]
        embedder = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
        )
        return FAISS.from_documents(documents, embedder)
    except Exception as e:
        st.error(f"Indexing failed: {e}")
        return None
# ── Answer ─────────────────────────────────────────────────────────────────────
def answer_question(question):
    """Answer *question* about the indexed page via retrieve-then-generate.

    Returns a plain string: the model's answer, or a human-readable message
    when no index exists, the model failed to load, or inference errors out.
    """
    if not st.session_state.vector_store:
        return "No content indexed yet."
    tokenizer, model = load_model()
    if tokenizer is None:
        return "Model failed to load. Check logs."
    try:
        # k=1 — single most relevant chunk keeps prompt tight within 512 tokens
        hits = st.session_state.vector_store.similarity_search(question, k=1)
        context = hits[0].page_content
        prompt = (
            "Answer the question using only the context provided. "
            "If the answer is not in the context, say \"I don't know\".\n\n"
            f"Context: {context}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,          # clip prompt to the model's window
            max_length=MAX_INPUT_LEN,
        )
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=200,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
            )
        return tokenizer.decode(generated[0], skip_special_tokens=True).strip()
    except Exception as e:
        logging.error(f"Inference error: {e}")
        return f"Error generating answer: {e}"
# ── Preload model ──────────────────────────────────────────────────────────────
# Warm the cached model up front so the first question doesn't pay the load cost.
with st.spinner(f"Loading {MODEL_NAME}…"):
    _tok, _mod = load_model()
model_ok = _tok is not None

# ── Sidebar ────────────────────────────────────────────────────────────────────
with st.sidebar:
    st.markdown("**Model**")
    st.markdown(f"`{MODEL_NAME}`")
    st.markdown("**Context window**")
    st.markdown("`512 tokens`")
    st.markdown("**Architecture**")
    st.markdown("`Encoder-Decoder`")
    st.markdown("**Status**")
    if model_ok:
        # NOTE(review): trailing glyph was mojibake; restored as a check mark.
        st.success("Model loaded ✓")
    else:
        st.error("Model failed to load")

# ── Page header ────────────────────────────────────────────────────────────────
dot_color = "#4caf50" if model_ok else "#e53935"
dot_label = "Model ready" if model_ok else "Model error"
# NOTE(review): the header HTML appears to have been stripped from the file;
# this call currently renders nothing but keeps dot_color/dot_label available.
st.markdown(f"""
""", unsafe_allow_html=True)
# ── URL bar ────────────────────────────────────────────────────────────────────
col_url, col_btn = st.columns([5, 1])
with col_url:
    url_input = st.text_input(
        "url", label_visibility="collapsed",
        placeholder="https://en.wikipedia.org/wiki/Retrieval-augmented_generation"
    )
with col_btn:
    scrape_clicked = st.button("Scrape", use_container_width=True)

if scrape_clicked:
    if not url_input or not is_valid_url(url_input):
        st.warning("Enter a valid URL starting with https://")
    else:
        with st.spinner("Scraping…"):
            result = scrape_website(url_input)
        if result:
            st.session_state.scraped_content = result['content']
            st.session_state.scraped_title = result['title']
            st.session_state.scraped_url = result['url']
            st.session_state.char_count = len(result['content'])
            st.session_state.chat_history = []  # new page invalidates old conversation
            with st.spinner("Building FAISS index…"):
                st.session_state.vector_store = create_vector_store(result['content'])
            st.rerun()
# ── Main content area ──────────────────────────────────────────────────────────
if st.session_state.scraped_content:
    title_display = st.session_state.scraped_title or ""
    url_display = st.session_state.scraped_url or ""
    # NOTE(review): the original header/preview HTML was corrupted in the file
    # (unterminated single-quoted string literals — a SyntaxError); the markup
    # below is a minimal reconstruction of the apparent intent.
    st.markdown(f"""
<div><strong>{title_display}</strong><br><small>{url_display}</small></div>
""", unsafe_allow_html=True)
    st.markdown('<h4>Scraped content</h4>', unsafe_allow_html=True)
    preview = st.session_state.scraped_content[:4000]
    if len(st.session_state.scraped_content) > 4000:
        preview += "\n\n… (truncated for display)"
    # Scraped content is untrusted — render it as plain text instead of
    # interpolating it into unsafe HTML (avoids markup/script injection).
    st.text(preview)
    st.markdown("""
""", unsafe_allow_html=True)

    # Replay prior conversation turns.
    for msg in st.session_state.chat_history:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input("Ask anything about the content above…"):
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("FLAN-T5 is thinking…"):
                answer = answer_question(prompt)
            st.markdown(answer)
        st.session_state.chat_history.append({"role": "assistant", "content": answer})

    if st.session_state.chat_history:
        if st.button("Clear chat"):
            st.session_state.chat_history = []
            st.rerun()
else:
    st.markdown("""
**Nothing scraped yet** — enter a URL above and hit **Scrape** to get started.
""", unsafe_allow_html=True)