muddasser committed on
Commit
02d4635
·
verified ·
1 Parent(s): 0a25fe2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -47
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import re
4
  import logging
5
  import torch
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from playwright.sync_api import sync_playwright
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import FAISS
@@ -16,10 +16,11 @@ logging.basicConfig(
16
  format='%(asctime)s - %(levelname)s - %(message)s'
17
  )
18
 
19
- MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
20
 
21
  st.set_page_config(
22
- page_title="RAG · TinyLlama",
23
  page_icon="🕸️",
24
  layout="wide",
25
  initial_sidebar_state="collapsed"
@@ -213,7 +214,6 @@ for key, default in [
213
  # ── Utilities ──────────────────────────────────────────────────────────────────
214
 
215
  def clean_text(text):
216
- # Only collapse whitespace — preserve prices, commas, symbols
217
  text = re.sub(r'[ \t]+', ' ', text)
218
  text = re.sub(r'\n{3,}', '\n\n', text)
219
  return text.strip()
@@ -227,7 +227,7 @@ def is_valid_url(url):
227
  def load_model():
228
  try:
229
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
230
- model = AutoModelForCausalLM.from_pretrained(
231
  MODEL_NAME,
232
  torch_dtype=torch.float32,
233
  low_cpu_mem_usage=True,
@@ -246,17 +246,15 @@ def scrape_website(url):
246
  browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
247
  page = browser.new_page()
248
  try:
249
- # networkidle times out on ad-heavy sites like whatmobile.com.pk
250
- # domcontentloaded fires as soon as HTML is parsed, then we wait
251
- # a few seconds for JS-rendered content to appear
252
  try:
253
  page.goto(url, wait_until="domcontentloaded", timeout=30000)
254
  except Exception:
255
- pass # even if it times out, content may already be there
256
- page.wait_for_timeout(3000) # give JS 3s to render
257
  title = page.title()
258
 
259
- # Strategy 1: extract from <li> elements — good for listing/price pages
260
  lines = []
261
  for li in page.query_selector_all("li"):
262
  try:
@@ -305,7 +303,8 @@ def scrape_website(url):
305
  @st.cache_resource
306
  def create_vector_store(text):
307
  try:
308
- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
 
309
  docs = [Document(page_content=c) for c in splitter.split_text(text)]
310
  emb = HuggingFaceEmbeddings(
311
  model_name="sentence-transformers/all-MiniLM-L6-v2",
@@ -325,54 +324,36 @@ def answer_question(question):
325
  if tokenizer is None:
326
  return "Model failed to load. Check logs."
327
  try:
328
- # Retrieve top 3 relevant chunks from FAISS
329
- docs = st.session_state.vector_store.similarity_search(question, k=3)
330
  context = " ".join(d.page_content for d in docs)
331
 
332
- # TinyLlama expects the chat template format
333
- messages = [
334
- {
335
- "role": "system",
336
- "content": (
337
- "You are a helpful assistant. Answer the user's question using "
338
- "ONLY the context provided. If the answer is not in the context, "
339
- "say \"I don't know\"."
340
- ),
341
- },
342
- {
343
- "role": "user",
344
- "content": f"Context:\n{context}\n\nQuestion: {question}",
345
- },
346
- ]
347
-
348
- # Apply chat template → produces <|system|>...<|user|>...<|assistant|>
349
- prompt = tokenizer.apply_chat_template(
350
- messages,
351
- tokenize=False,
352
- add_generation_prompt=True, # appends <|assistant|> so model starts answering
353
  )
354
 
355
  inputs = tokenizer(
356
  prompt,
357
  return_tensors="pt",
358
  truncation=True,
359
- max_length=2048, # TinyLlama's full context window
360
  )
361
 
362
  with torch.no_grad():
363
  outputs = model.generate(
364
  **inputs,
365
  max_new_tokens=300,
366
- do_sample=True,
367
- temperature=0.7,
368
- top_p=0.95,
369
- repetition_penalty=1.1,
370
- pad_token_id=tokenizer.eos_token_id,
371
  )
372
 
373
- # Slice off the prompt tokens — only decode what the model generated
374
- generated = outputs[0][inputs["input_ids"].shape[1]:]
375
- return tokenizer.decode(generated, skip_special_tokens=True).strip()
376
 
377
  except Exception as e:
378
  logging.error(f"Inference error: {e}")
@@ -388,7 +369,9 @@ with st.sidebar:
388
  st.markdown("**Model**")
389
  st.markdown(f"`{MODEL_NAME}`")
390
  st.markdown("**Context window**")
391
- st.markdown("`2048 tokens`")
 
 
392
  st.markdown("**Status**")
393
  if model_ok:
394
  st.success("Model loaded ✓")
@@ -407,7 +390,7 @@ st.markdown(f"""
407
  </div>
408
  <div class="model-badge">
409
  <div class="model-dot" style="background:{dot_color};"></div>
410
  {dot_label} &nbsp;·&nbsp; TinyLlama-1.1B-Chat
411
  </div>
412
  </div>
413
  """, unsafe_allow_html=True)
@@ -477,7 +460,7 @@ if st.session_state.scraped_content:
477
  with st.chat_message("user"):
478
  st.markdown(prompt)
479
  with st.chat_message("assistant"):
480
- with st.spinner("TinyLlama is thinking…"):
481
  answer = answer_question(prompt)
482
  st.markdown(answer)
483
  st.session_state.chat_history.append({"role": "assistant", "content": answer})
 
3
  import re
4
  import logging
5
  import torch
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  from playwright.sync_api import sync_playwright
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import FAISS
 
16
  format='%(asctime)s - %(levelname)s - %(message)s'
17
  )
18
 
19
+ MODEL_NAME = "google/long-t5-tglobal-large"
20
+ MAX_INPUT_LEN = 16384 # LongT5's full context window
21
 
22
  st.set_page_config(
23
+ page_title="RAG · LongT5",
24
  page_icon="🕸️",
25
  layout="wide",
26
  initial_sidebar_state="collapsed"
 
214
  # ── Utilities ──────────────────────────────────────────────────────────────────
215
 
216
  def clean_text(text):
 
217
  text = re.sub(r'[ \t]+', ' ', text)
218
  text = re.sub(r'\n{3,}', '\n\n', text)
219
  return text.strip()
 
227
  def load_model():
228
  try:
229
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
230
+ model = AutoModelForSeq2SeqLM.from_pretrained(
231
  MODEL_NAME,
232
  torch_dtype=torch.float32,
233
  low_cpu_mem_usage=True,
 
246
  browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
247
  page = browser.new_page()
248
  try:
249
+ # domcontentloaded avoids timeout on ad-heavy sites
 
 
250
  try:
251
  page.goto(url, wait_until="domcontentloaded", timeout=30000)
252
  except Exception:
253
+ pass # content may already be loaded even on timeout
254
+ page.wait_for_timeout(3000) # allow JS 3s to render
255
  title = page.title()
256
 
257
+ # Strategy 1: <li> items — great for price/listing pages
258
  lines = []
259
  for li in page.query_selector_all("li"):
260
  try:
 
303
  @st.cache_resource
304
  def create_vector_store(text):
305
  try:
306
+ # Larger chunks since LongT5 can handle much more context
307
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
308
  docs = [Document(page_content=c) for c in splitter.split_text(text)]
309
  emb = HuggingFaceEmbeddings(
310
  model_name="sentence-transformers/all-MiniLM-L6-v2",
 
324
  if tokenizer is None:
325
  return "Model failed to load. Check logs."
326
  try:
327
+ # Retrieve more chunks — LongT5 can handle it
328
+ docs = st.session_state.vector_store.similarity_search(question, k=6)
329
  context = " ".join(d.page_content for d in docs)
330
 
331
+ # LongT5 uses plain text prompt like T5 — no chat template needed
332
+ prompt = (
333
+ "Answer the question using only the context provided. "
334
+ "If the answer is not in the context, say \"I don't know\".\n\n"
335
+ f"Context: {context}\n\n"
336
+ f"Question: {question}\n\n"
337
+ "Answer:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
 
340
  inputs = tokenizer(
341
  prompt,
342
  return_tensors="pt",
343
  truncation=True,
344
+ max_length=MAX_INPUT_LEN, # full 16,384 token window
345
  )
346
 
347
  with torch.no_grad():
348
  outputs = model.generate(
349
  **inputs,
350
  max_new_tokens=300,
351
+ num_beams=4,
352
+ early_stopping=True,
353
+ no_repeat_ngram_size=3,
 
 
354
  )
355
 
356
+ return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
 
357
 
358
  except Exception as e:
359
  logging.error(f"Inference error: {e}")
 
369
  st.markdown("**Model**")
370
  st.markdown(f"`{MODEL_NAME}`")
371
  st.markdown("**Context window**")
372
+ st.markdown("`16,384 tokens`")
373
+ st.markdown("**Architecture**")
374
+ st.markdown("`Encoder-Decoder`")
375
  st.markdown("**Status**")
376
  if model_ok:
377
  st.success("Model loaded ✓")
 
390
  </div>
391
  <div class="model-badge">
392
  <div class="model-dot" style="background:{dot_color};"></div>
393
+ {dot_label} &nbsp;·&nbsp; LongT5-16k
394
  </div>
395
  </div>
396
  """, unsafe_allow_html=True)
 
460
  with st.chat_message("user"):
461
  st.markdown(prompt)
462
  with st.chat_message("assistant"):
463
+ with st.spinner("LongT5 is thinking…"):
464
  answer = answer_question(prompt)
465
  st.markdown(answer)
466
  st.session_state.chat_history.append({"role": "assistant", "content": answer})