Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import os
|
|
| 3 |
import re
|
| 4 |
import logging
|
| 5 |
import torch
|
| 6 |
-
from transformers import AutoTokenizer,
|
| 7 |
from playwright.sync_api import sync_playwright
|
| 8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 9 |
from langchain_community.vectorstores import FAISS
|
|
@@ -16,10 +16,11 @@ logging.basicConfig(
|
|
| 16 |
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 17 |
)
|
| 18 |
|
| 19 |
-
MODEL_NAME
|
|
|
|
| 20 |
|
| 21 |
st.set_page_config(
|
| 22 |
-
page_title="RAG ·
|
| 23 |
page_icon="🕸️",
|
| 24 |
layout="wide",
|
| 25 |
initial_sidebar_state="collapsed"
|
|
@@ -213,7 +214,6 @@ for key, default in [
|
|
| 213 |
# ── Utilities ──────────────────────────────────────────────────────────────────
|
| 214 |
|
| 215 |
def clean_text(text):
|
| 216 |
-
# Only collapse whitespace — preserve prices, commas, symbols
|
| 217 |
text = re.sub(r'[ \t]+', ' ', text)
|
| 218 |
text = re.sub(r'\n{3,}', '\n\n', text)
|
| 219 |
return text.strip()
|
|
@@ -227,7 +227,7 @@ def is_valid_url(url):
|
|
| 227 |
def load_model():
|
| 228 |
try:
|
| 229 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 230 |
-
model =
|
| 231 |
MODEL_NAME,
|
| 232 |
torch_dtype=torch.float32,
|
| 233 |
low_cpu_mem_usage=True,
|
|
@@ -246,17 +246,15 @@ def scrape_website(url):
|
|
| 246 |
browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
|
| 247 |
page = browser.new_page()
|
| 248 |
try:
|
| 249 |
-
#
|
| 250 |
-
# domcontentloaded fires as soon as HTML is parsed, then we wait
|
| 251 |
-
# a few seconds for JS-rendered content to appear
|
| 252 |
try:
|
| 253 |
page.goto(url, wait_until="domcontentloaded", timeout=30000)
|
| 254 |
except Exception:
|
| 255 |
-
pass #
|
| 256 |
-
page.wait_for_timeout(3000) #
|
| 257 |
title = page.title()
|
| 258 |
|
| 259 |
-
# Strategy 1:
|
| 260 |
lines = []
|
| 261 |
for li in page.query_selector_all("li"):
|
| 262 |
try:
|
|
@@ -305,7 +303,8 @@ def scrape_website(url):
|
|
| 305 |
@st.cache_resource
|
| 306 |
def create_vector_store(text):
|
| 307 |
try:
|
| 308 |
-
|
|
|
|
| 309 |
docs = [Document(page_content=c) for c in splitter.split_text(text)]
|
| 310 |
emb = HuggingFaceEmbeddings(
|
| 311 |
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
|
@@ -325,54 +324,36 @@ def answer_question(question):
|
|
| 325 |
if tokenizer is None:
|
| 326 |
return "Model failed to load. Check logs."
|
| 327 |
try:
|
| 328 |
-
# Retrieve
|
| 329 |
-
docs = st.session_state.vector_store.similarity_search(question, k=
|
| 330 |
context = " ".join(d.page_content for d in docs)
|
| 331 |
|
| 332 |
-
#
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
"say \"I don't know\"."
|
| 340 |
-
),
|
| 341 |
-
},
|
| 342 |
-
{
|
| 343 |
-
"role": "user",
|
| 344 |
-
"content": f"Context:\n{context}\n\nQuestion: {question}",
|
| 345 |
-
},
|
| 346 |
-
]
|
| 347 |
-
|
| 348 |
-
# Apply chat template — produces <|system|>...<|user|>...<|assistant|>
|
| 349 |
-
prompt = tokenizer.apply_chat_template(
|
| 350 |
-
messages,
|
| 351 |
-
tokenize=False,
|
| 352 |
-
add_generation_prompt=True, # appends <|assistant|> so model starts answering
|
| 353 |
)
|
| 354 |
|
| 355 |
inputs = tokenizer(
|
| 356 |
prompt,
|
| 357 |
return_tensors="pt",
|
| 358 |
truncation=True,
|
| 359 |
-
max_length=
|
| 360 |
)
|
| 361 |
|
| 362 |
with torch.no_grad():
|
| 363 |
outputs = model.generate(
|
| 364 |
**inputs,
|
| 365 |
max_new_tokens=300,
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
repetition_penalty=1.1,
|
| 370 |
-
pad_token_id=tokenizer.eos_token_id,
|
| 371 |
)
|
| 372 |
|
| 373 |
-
|
| 374 |
-
generated = outputs[0][inputs["input_ids"].shape[1]:]
|
| 375 |
-
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 376 |
|
| 377 |
except Exception as e:
|
| 378 |
logging.error(f"Inference error: {e}")
|
|
@@ -388,7 +369,9 @@ with st.sidebar:
|
|
| 388 |
st.markdown("**Model**")
|
| 389 |
st.markdown(f"`{MODEL_NAME}`")
|
| 390 |
st.markdown("**Context window**")
|
| 391 |
-
st.markdown("`
|
|
|
|
|
|
|
| 392 |
st.markdown("**Status**")
|
| 393 |
if model_ok:
|
| 394 |
st.success("Model loaded β")
|
|
@@ -407,7 +390,7 @@ st.markdown(f"""
|
|
| 407 |
</div>
|
| 408 |
<div class="model-badge">
|
| 409 |
<div class="model-dot" style="background:{dot_color};"></div>
|
| 410 |
-
{dot_label} ·
|
| 411 |
</div>
|
| 412 |
</div>
|
| 413 |
""", unsafe_allow_html=True)
|
|
@@ -477,7 +460,7 @@ if st.session_state.scraped_content:
|
|
| 477 |
with st.chat_message("user"):
|
| 478 |
st.markdown(prompt)
|
| 479 |
with st.chat_message("assistant"):
|
| 480 |
-
with st.spinner("
|
| 481 |
answer = answer_question(prompt)
|
| 482 |
st.markdown(answer)
|
| 483 |
st.session_state.chat_history.append({"role": "assistant", "content": answer})
|
|
|
|
| 3 |
import re
|
| 4 |
import logging
|
| 5 |
import torch
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 7 |
from playwright.sync_api import sync_playwright
|
| 8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 9 |
from langchain_community.vectorstores import FAISS
|
|
|
|
| 16 |
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 17 |
)
|
| 18 |
|
| 19 |
+
MODEL_NAME = "google/long-t5-tglobal-large"
|
| 20 |
+
MAX_INPUT_LEN = 16384 # LongT5's full context window
|
| 21 |
|
| 22 |
st.set_page_config(
|
| 23 |
+
page_title="RAG · LongT5",
|
| 24 |
page_icon="🕸️",
|
| 25 |
layout="wide",
|
| 26 |
initial_sidebar_state="collapsed"
|
|
|
|
| 214 |
# ── Utilities ──────────────────────────────────────────────────────────────────
|
| 215 |
|
| 216 |
def clean_text(text):
    """Normalize whitespace in scraped text without altering its content.

    Only whitespace is touched, so prices, commas and other symbols pass
    through unchanged.

    Args:
        text: Raw text to tidy. ``None`` or an empty string yields ``""``.

    Returns:
        The text with runs of spaces/tabs collapsed to a single space,
        trailing spaces removed from each line, runs of three or more
        newlines reduced to exactly two, and leading/trailing whitespace
        stripped.
    """
    if not text:
        return ""
    # Collapse horizontal whitespace first so every run is a single space.
    text = re.sub(r'[ \t]+', ' ', text)
    # Drop the (now single) space sitting before a newline; otherwise
    # whitespace-only lines such as "\n \n \n" hide from the blank-line
    # collapse below because the newlines are never adjacent.
    text = text.replace(' \n', '\n')
    # Keep at most one blank line between paragraphs.
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()
|
|
|
|
| 227 |
def load_model():
|
| 228 |
try:
|
| 229 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 230 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 231 |
MODEL_NAME,
|
| 232 |
torch_dtype=torch.float32,
|
| 233 |
low_cpu_mem_usage=True,
|
|
|
|
| 246 |
browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
|
| 247 |
page = browser.new_page()
|
| 248 |
try:
|
| 249 |
+
# domcontentloaded avoids timeout on ad-heavy sites
|
|
|
|
|
|
|
| 250 |
try:
|
| 251 |
page.goto(url, wait_until="domcontentloaded", timeout=30000)
|
| 252 |
except Exception:
|
| 253 |
+
pass # content may already be loaded even on timeout
|
| 254 |
+
page.wait_for_timeout(3000) # allow JS 3s to render
|
| 255 |
title = page.title()
|
| 256 |
|
| 257 |
+
# Strategy 1: <li> items — great for price/listing pages
|
| 258 |
lines = []
|
| 259 |
for li in page.query_selector_all("li"):
|
| 260 |
try:
|
|
|
|
| 303 |
@st.cache_resource
|
| 304 |
def create_vector_store(text):
|
| 305 |
try:
|
| 306 |
+
# Larger chunks since LongT5 can handle much more context
|
| 307 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
|
| 308 |
docs = [Document(page_content=c) for c in splitter.split_text(text)]
|
| 309 |
emb = HuggingFaceEmbeddings(
|
| 310 |
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
|
|
|
| 324 |
if tokenizer is None:
|
| 325 |
return "Model failed to load. Check logs."
|
| 326 |
try:
|
| 327 |
+
# Retrieve more chunks β LongT5 can handle it
|
| 328 |
+
docs = st.session_state.vector_store.similarity_search(question, k=6)
|
| 329 |
context = " ".join(d.page_content for d in docs)
|
| 330 |
|
| 331 |
+
# LongT5 uses plain text prompt like T5 β no chat template needed
|
| 332 |
+
prompt = (
|
| 333 |
+
"Answer the question using only the context provided. "
|
| 334 |
+
"If the answer is not in the context, say \"I don't know\".\n\n"
|
| 335 |
+
f"Context: {context}\n\n"
|
| 336 |
+
f"Question: {question}\n\n"
|
| 337 |
+
"Answer:"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
)
|
| 339 |
|
| 340 |
inputs = tokenizer(
|
| 341 |
prompt,
|
| 342 |
return_tensors="pt",
|
| 343 |
truncation=True,
|
| 344 |
+
max_length=MAX_INPUT_LEN, # full 16,384 token window
|
| 345 |
)
|
| 346 |
|
| 347 |
with torch.no_grad():
|
| 348 |
outputs = model.generate(
|
| 349 |
**inputs,
|
| 350 |
max_new_tokens=300,
|
| 351 |
+
num_beams=4,
|
| 352 |
+
early_stopping=True,
|
| 353 |
+
no_repeat_ngram_size=3,
|
|
|
|
|
|
|
| 354 |
)
|
| 355 |
|
| 356 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
|
|
|
|
|
|
| 357 |
|
| 358 |
except Exception as e:
|
| 359 |
logging.error(f"Inference error: {e}")
|
|
|
|
| 369 |
st.markdown("**Model**")
|
| 370 |
st.markdown(f"`{MODEL_NAME}`")
|
| 371 |
st.markdown("**Context window**")
|
| 372 |
+
st.markdown("`16,384 tokens`")
|
| 373 |
+
st.markdown("**Architecture**")
|
| 374 |
+
st.markdown("`Encoder-Decoder`")
|
| 375 |
st.markdown("**Status**")
|
| 376 |
if model_ok:
|
| 377 |
st.success("Model loaded β")
|
|
|
|
| 390 |
</div>
|
| 391 |
<div class="model-badge">
|
| 392 |
<div class="model-dot" style="background:{dot_color};"></div>
|
| 393 |
+
{dot_label} · LongT5-16k
|
| 394 |
</div>
|
| 395 |
</div>
|
| 396 |
""", unsafe_allow_html=True)
|
|
|
|
| 460 |
with st.chat_message("user"):
|
| 461 |
st.markdown(prompt)
|
| 462 |
with st.chat_message("assistant"):
|
| 463 |
+
with st.spinner("LongT5 is thinking…"):
|
| 464 |
answer = answer_question(prompt)
|
| 465 |
st.markdown(answer)
|
| 466 |
st.session_state.chat_history.append({"role": "assistant", "content": answer})
|