"""Streamlit RAG demo.

Scrapes a web page with headless Chromium (Playwright), chunks and embeds
the text into a FAISS index (MiniLM sentence embeddings), then answers
questions about the page with FLAN-T5-large running on CPU.
"""

import logging
import os
import re
from urllib.parse import urlparse

import streamlit as st
import torch
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from playwright.sync_api import sync_playwright
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# NOTE(review): hard-coded log path — assumes /app/cache exists (container).
logging.basicConfig(
    filename='/app/cache/app.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

MODEL_NAME = "google/flan-t5-large"
MAX_INPUT_LEN = 512  # FLAN-T5-large context window

st.set_page_config(
    page_title="RAG · FLAN-T5",
    page_icon="🕸️",
    layout="wide",
    initial_sidebar_state="collapsed",
)

st.markdown(""" """, unsafe_allow_html=True)

# ── Session state ──────────────────────────────────────────────────────────────
for key, default in [
    ('scraped_content', ''),
    ('vector_store', None),
    ('chat_history', []),
    ('scraped_title', None),
    ('scraped_url', None),
    ('char_count', 0),
]:
    if key not in st.session_state:
        st.session_state[key] = default


# ── Utilities ──────────────────────────────────────────────────────────────────
def clean_text(text):
    """Collapse runs of spaces/tabs and excess blank lines; strip the ends."""
    text = re.sub(r'[ \t]+', ' ', text)
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()


def is_valid_url(url):
    """Return True for an http(s) URL with a non-empty host.

    Uses urllib.parse rather than a hand-rolled regex so URLs containing
    query strings, fragments, ports, or percent-encoding are accepted
    (the previous pattern rejected e.g. "?page=2").
    """
    try:
        parts = urlparse(url)
    except ValueError:
        return False
    return parts.scheme in ('http', 'https') and bool(parts.netloc)


# ── Model ──────────────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the FLAN-T5 tokenizer and model onto CPU (cached per session).

    Returns (tokenizer, model) on success, (None, None) on failure — the
    error is logged rather than raised so the UI can degrade gracefully.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
        )
        model = model.to("cpu")
        model.eval()
        logging.info(f"Loaded {MODEL_NAME}")
        return tokenizer, model
    except Exception as e:
        logging.error(f"Model load error: {e}")
        return None, None


# ── Scraper ────────────────────────────────────────────────────────────────────
def scrape_website(url):
    """Scrape visible text from *url* with headless Chromium.

    Returns {'title', 'content', 'url'} on success, or None on failure
    (the error is logged and shown via st.error). The browser is always
    closed, even if page creation itself fails.
    """
    with sync_playwright() as p:
        browser = p.chromium.launch(
            headless=True,
            args=['--no-sandbox', '--disable-dev-shm-usage'],
        )
        try:
            # Inside the try so a new_page() failure cannot leak the browser.
            page = browser.new_page()

            # domcontentloaded avoids timeout on ad-heavy sites
            try:
                page.goto(url, wait_until="domcontentloaded", timeout=30000)
            except Exception:
                pass  # content may already be loaded even on timeout
            page.wait_for_timeout(3000)  # allow JS 3s to render
            title = page.title()

            # Strategy 1: <li> items — great for price/listing pages
            lines = []
            for li in page.query_selector_all("li"):
                try:
                    text = li.inner_text().strip()
                    if text and 3 < len(text) < 300:
                        lines.append(text)
                except Exception:  # detached/hidden nodes — skip them
                    continue

            # Strategy 2: headings, paragraphs, table cells
            for tag in ["h1", "h2", "h3", "h4", "p", "td"]:
                for e in page.query_selector_all(tag):
                    try:
                        text = e.inner_text().strip()
                        if text and 3 < len(text) < 500:
                            lines.append(text)
                    except Exception:
                        continue

            # Deduplicate preserving order
            seen, unique_lines = set(), []
            for line in lines:
                n = re.sub(r'\s+', ' ', line).strip()
                if n not in seen:
                    seen.add(n)
                    unique_lines.append(n)
            content = "\n".join(unique_lines)

            # Fallback to body if nothing found
            if len(content) < 200:
                body = page.query_selector("body")
                content = clean_text(body.inner_text()) if body else content

            logging.info(f"Scraped {len(content)} chars from {url}")
            return {"title": title, "content": content, "url": url}
        except Exception as e:
            logging.error(f"Scrape error: {e}")
            st.error(f"Scraping failed: {e}")
            return None
        finally:
            browser.close()


# ── Vector store ───────────────────────────────────────────────────────────────
@st.cache_resource
def create_vector_store(text):
    """Split *text* into chunks, embed with MiniLM, and build a FAISS index.

    Returns the FAISS store, or None on failure (error shown via st.error).
    """
    try:
        # Small chunks so the single best one fits cleanly in 512 tokens
        splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=30)
        docs = [Document(page_content=c) for c in splitter.split_text(text)]
        emb = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'}
        )
        return FAISS.from_documents(docs, emb)
    except Exception as e:
        st.error(f"Indexing failed: {e}")
        return None


# ── Answer ─────────────────────────────────────────────────────────────────────
def answer_question(question):
    """Answer *question* from the indexed page content.

    Retrieves the single most similar chunk and runs beam-search generation
    on CPU. Returns the answer string, or a human-readable error message.
    """
    if not st.session_state.vector_store:
        return "No content indexed yet."
    tokenizer, model = load_model()
    if tokenizer is None:
        return "Model failed to load. Check logs."
    try:
        # k=1 — single most relevant chunk keeps prompt tight within 512 tokens
        docs = st.session_state.vector_store.similarity_search(question, k=1)
        if not docs:  # guard: empty index/result would raise IndexError
            return "I don't know"
        context = docs[0].page_content
        prompt = (
            "Answer the question using only the context provided. "
            "If the answer is not in the context, say \"I don't know\".\n\n"
            f"Context: {context}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_LEN,
        )
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    except Exception as e:
        logging.error(f"Inference error: {e}")
        return f"Error generating answer: {e}"


# ── Preload model ──────────────────────────────────────────────────────────────
with st.spinner(f"Loading {MODEL_NAME}…"):
    _tok, _mod = load_model()
model_ok = _tok is not None

# ── Sidebar ────────────────────────────────────────────────────────────────────
with st.sidebar:
    st.markdown("**Model**")
    st.markdown(f"`{MODEL_NAME}`")
    st.markdown("**Context window**")
    st.markdown("`512 tokens`")
    st.markdown("**Architecture**")
    st.markdown("`Encoder-Decoder`")
    st.markdown("**Status**")
    if model_ok:
        st.success("Model loaded ✓")
    else:
        st.error("Model failed to load")

# ── Page header ────────────────────────────────────────────────────────────────
dot_color = "#4caf50" if model_ok else "#e53935"
dot_label = "Model ready" if model_ok else "Model error"
st.markdown(f""" """, unsafe_allow_html=True)

# ── URL bar ────────────────────────────────────────────────────────────────────
col_url, col_btn = st.columns([5, 1])
with col_url:
    url_input = st.text_input(
        "url",
        label_visibility="collapsed",
        placeholder="https://en.wikipedia.org/wiki/Retrieval-augmented_generation"
    )
with col_btn:
    scrape_clicked = st.button("Scrape", use_container_width=True)

if scrape_clicked:
    if not url_input or not is_valid_url(url_input):
        st.warning("Enter a valid URL starting with https://")
    else:
        with st.spinner("Scraping…"):
            result = scrape_website(url_input)
        if result:
            st.session_state.scraped_content = result['content']
            st.session_state.scraped_title = result['title']
            st.session_state.scraped_url = result['url']
            st.session_state.char_count = len(result['content'])
            st.session_state.chat_history = []
            with st.spinner("Building FAISS index…"):
                st.session_state.vector_store = create_vector_store(result['content'])
            st.rerun()

# ── Main content area ──────────────────────────────────────────────────────────
if st.session_state.scraped_content:
    title_display = st.session_state.scraped_title or ""
    url_display = st.session_state.scraped_url or ""
    st.markdown(f"""
    {title_display}  ·  {st.session_state.char_count:,} chars  ·  {url_display}
    """, unsafe_allow_html=True)

    st.markdown("""
    Scraped content
    """, unsafe_allow_html=True)

    preview = st.session_state.scraped_content[:4000]
    if len(st.session_state.scraped_content) > 4000:
        preview += "\n\n… (truncated for display)"
    st.markdown(f"""
    {preview}
    """, unsafe_allow_html=True)

    st.markdown("""
    Ask a question
    """, unsafe_allow_html=True)

    # Replay chat history, then handle a new question if one is submitted.
    for msg in st.session_state.chat_history:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input("Ask anything about the content above…"):
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("FLAN-T5 is thinking…"):
                answer = answer_question(prompt)
            st.markdown(answer)
        st.session_state.chat_history.append({"role": "assistant", "content": answer})

    if st.session_state.chat_history:
        if st.button("Clear chat"):
            st.session_state.chat_history = []
            st.rerun()
else:
    st.markdown("""
    Nothing scraped yet
    Enter a URL above and hit Scrape to get started.
    """, unsafe_allow_html=True)