muddasser committed on
Commit
bb07c26
·
verified ·
1 Parent(s): bf68fe9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -64
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import re
4
  import logging
5
  import torch
6
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  from playwright.sync_api import sync_playwright
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import FAISS
@@ -14,13 +14,12 @@ logging.basicConfig(
14
  filename='/app/cache/app.log',
15
  level=logging.DEBUG,
16
  format='%(asctime)s - %(levelname)s - %(message)s'
17
-
18
  )
19
 
20
- MODEL_NAME = "google/flan-t5-large"
21
 
22
  st.set_page_config(
23
- page_title="RAG · Mistral",
24
  page_icon="πŸ•ΈοΈ",
25
  layout="wide",
26
  initial_sidebar_state="collapsed"
@@ -29,7 +28,6 @@ st.set_page_config(
29
  st.markdown("""
30
  <style>
31
  @import url('https://fonts.googleapis.com/css2?family=Instrument+Serif:ital@0;1&family=JetBrains+Mono:wght@300;400;500&display=swap');
32
-
33
  :root {
34
  --bg: #f5f0e8;
35
  --surface: #ede8df;
@@ -40,7 +38,6 @@ st.markdown("""
40
  --mono: 'JetBrains Mono', monospace;
41
  --serif: 'Instrument Serif', serif;
42
  }
43
-
44
  html, body, [class*="css"] {
45
  font-family: var(--mono);
46
  background: var(--bg);
@@ -49,12 +46,10 @@ html, body, [class*="css"] {
49
  .stApp { background: var(--bg); }
50
  #MainMenu, footer, header { visibility: hidden; }
51
  [data-testid="stDecoration"] { display: none; }
52
-
53
  [data-testid="stSidebar"] {
54
  background: var(--surface);
55
  border-right: 1px solid var(--border);
56
  }
57
-
58
  .stTextInput > div > div > input,
59
  .stTextArea textarea {
60
  background: #fff !important;
@@ -69,7 +64,6 @@ html, body, [class*="css"] {
69
  border-color: var(--accent) !important;
70
  box-shadow: 0 0 0 2px rgba(193,58,30,0.12) !important;
71
  }
72
-
73
  .stButton > button {
74
  background: var(--accent) !important;
75
  color: #fff !important;
@@ -88,7 +82,6 @@ html, body, [class*="css"] {
88
  transform: translateY(-1px);
89
  box-shadow: 0 3px 12px rgba(193,58,30,0.25) !important;
90
  }
91
-
92
  [data-testid="stChatMessage"] {
93
  background: #fff !important;
94
  border: 1px solid var(--border) !important;
@@ -100,9 +93,7 @@ html, body, [class*="css"] {
100
  font-family: var(--mono) !important;
101
  font-size: 0.82rem !important;
102
  }
103
-
104
  hr { border-color: var(--border) !important; }
105
-
106
  .content-box {
107
  background: #fff;
108
  border: 1px solid var(--border);
@@ -120,7 +111,6 @@ hr { border-color: var(--border) !important; }
120
  .content-box::-webkit-scrollbar { width: 6px; }
121
  .content-box::-webkit-scrollbar-track { background: var(--surface); }
122
  .content-box::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
123
-
124
  .meta-pill {
125
  display: inline-flex;
126
  align-items: center;
@@ -134,7 +124,6 @@ hr { border-color: var(--border) !important; }
134
  margin-bottom: 0.6rem;
135
  }
136
  .meta-dot { width:6px; height:6px; border-radius:50%; background:#4caf50; }
137
-
138
  .section-label {
139
  font-size: 0.68rem;
140
  letter-spacing: 0.12em;
@@ -151,7 +140,6 @@ hr { border-color: var(--border) !important; }
151
  height: 1px;
152
  background: var(--border);
153
  }
154
-
155
  .qa-banner {
156
  display: flex;
157
  align-items: center;
@@ -166,8 +154,7 @@ hr { border-color: var(--border) !important; }
166
  color: var(--accent);
167
  white-space: nowrap;
168
  }
169
-
170
- .ollama-badge {
171
  display: inline-flex;
172
  align-items: center;
173
  gap: 5px;
@@ -177,8 +164,7 @@ hr { border-color: var(--border) !important; }
177
  border: 1px solid var(--border);
178
  border-radius: 3px;
179
  }
180
- .ollama-dot { width:6px; height:6px; border-radius:50%; }
181
-
182
  .page-header {
183
  padding: 1.5rem 0 1rem 0;
184
  border-bottom: 2px solid var(--text);
@@ -202,7 +188,6 @@ hr { border-color: var(--border) !important; }
202
  letter-spacing: 0.08em;
203
  text-transform: uppercase;
204
  }
205
-
206
  [data-testid="stAlert"] {
207
  background: var(--surface) !important;
208
  border: 1px solid var(--border) !important;
@@ -228,7 +213,7 @@ for key, default in [
228
  # ── Utilities ──────────────────────────────────────────────────────────────────
229
 
230
  def clean_text(text):
231
- # Only collapse whitespace — preserve Rs. prices, commas, symbols
232
  text = re.sub(r'[ \t]+', ' ', text)
233
  text = re.sub(r'\n{3,}', '\n\n', text)
234
  return text.strip()
@@ -236,14 +221,17 @@ def clean_text(text):
236
  def is_valid_url(url):
237
  return bool(re.match(r'^https?://[\w\-\.]+(?::\d+)?(?:/[\w\-\./]*)*$', url))
238
 
239
- def check_model():
240
- return st.session_state.get('qa_model') is not None
241
 
242
  @st.cache_resource(show_spinner=False)
243
  def load_model():
244
  try:
245
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
246
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32)
 
 
 
 
247
  model.eval()
248
  logging.info(f"Loaded {MODEL_NAME}")
249
  return tokenizer, model
@@ -251,50 +239,47 @@ def load_model():
251
  logging.error(f"Model load error: {e}")
252
  return None, None
253
 
 
 
254
  def scrape_website(url):
255
  with sync_playwright() as p:
256
- browser = p.chromium.launch(headless=True, args=['--no-sandbox','--disable-dev-shm-usage'])
257
  page = browser.new_page()
258
  try:
259
  page.goto(url, wait_until="networkidle", timeout=45000)
260
  title = page.title()
261
 
262
- # Strategy 1: extract structured name+price pairs from <li> elements
263
- # Works well for listing/price pages like whatmobile.com.pk
264
  lines = []
265
- li_elements = page.query_selector_all("li")
266
- for li in li_elements:
267
  try:
268
  text = li.inner_text().strip()
269
- # Keep li items that contain a heading and a price-like pattern
270
- if text and len(text) > 3 and len(text) < 300:
271
  lines.append(text)
272
  except:
273
  continue
274
 
275
- # Strategy 2: grab all headings and paragraphs too
276
  for tag in ["h1", "h2", "h3", "h4", "p", "td"]:
277
- elements = page.query_selector_all(tag)
278
- for e in elements:
279
  try:
280
  text = e.inner_text().strip()
281
- if text and len(text) > 3 and len(text) < 500:
282
  lines.append(text)
283
  except:
284
  continue
285
 
286
- # Deduplicate while preserving order
287
- seen = set()
288
- unique_lines = []
289
  for line in lines:
290
- normalised = re.sub(r'\s+', ' ', line).strip()
291
- if normalised not in seen:
292
- seen.add(normalised)
293
- unique_lines.append(normalised)
294
 
295
  content = "\n".join(unique_lines)
296
 
297
- # Fallback to full body if we got almost nothing
298
  if len(content) < 200:
299
  body = page.query_selector("body")
300
  content = clean_text(body.inner_text()) if body else content
@@ -308,6 +293,8 @@ def scrape_website(url):
308
  finally:
309
  browser.close()
310
 
 
 
311
  @st.cache_resource
312
  def create_vector_store(text):
313
  try:
@@ -322,6 +309,8 @@ def create_vector_store(text):
322
  st.error(f"Indexing failed: {e}")
323
  return None
324
 
 
 
325
  def answer_question(question):
326
  if not st.session_state.vector_store:
327
  return "No content indexed yet."
@@ -329,23 +318,55 @@ def answer_question(question):
329
  if tokenizer is None:
330
  return "Model failed to load. Check logs."
331
  try:
 
332
  docs = st.session_state.vector_store.similarity_search(question, k=3)
333
  context = " ".join(d.page_content for d in docs)
334
- prompt = (
335
- "Answer the question using only the context provided. "
336
- "If the answer is not in the context, say \"I don't know\".\n\n"
337
- f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
- inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
340
  with torch.no_grad():
341
  outputs = model.generate(
342
  **inputs,
343
- max_new_tokens=200,
344
- num_beams=4,
345
- early_stopping=True,
346
- no_repeat_ngram_size=3,
 
 
347
  )
348
- return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
 
 
 
349
  except Exception as e:
350
  logging.error(f"Inference error: {e}")
351
  return f"Error generating answer: {e}"
@@ -359,6 +380,8 @@ model_ok = _tok is not None
359
  with st.sidebar:
360
  st.markdown("**Model**")
361
  st.markdown(f"`{MODEL_NAME}`")
 
 
362
  st.markdown("**Status**")
363
  if model_ok:
364
  st.success("Model loaded ✓")
@@ -375,9 +398,9 @@ st.markdown(f"""
375
  <p class="page-title">Web RAG</p>
376
  <span class="page-sub">Scrape → Index → Ask</span>
377
  </div>
378
- <div class="ollama-badge">
379
- <div class="ollama-dot" style="background:{dot_color};"></div>
380
- {dot_label} &nbsp;·&nbsp; flan-t5-large
381
  </div>
382
  </div>
383
  """, unsafe_allow_html=True)
@@ -411,7 +434,6 @@ if scrape_clicked:
411
  # ── Main content area ──────────────────────────────────────────────────────────
412
  if st.session_state.scraped_content:
413
 
414
- # Meta pill
415
  title_display = st.session_state.scraped_title or ""
416
  url_display = st.session_state.scraped_url or ""
417
  st.markdown(f"""
@@ -421,18 +443,16 @@ if st.session_state.scraped_content:
421
  &nbsp;·&nbsp;
422
  <span>{st.session_state.char_count:,} chars</span>
423
  &nbsp;·&nbsp;
424
- <span style="max-width:300px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap;">{url_display}</span>
425
  </div>
426
  """, unsafe_allow_html=True)
427
 
428
- # Scraped content label + scrollable box
429
  st.markdown('<div class="section-label">Scraped content</div>', unsafe_allow_html=True)
430
  preview = st.session_state.scraped_content[:4000]
431
  if len(st.session_state.scraped_content) > 4000:
432
  preview += "\n\n… (truncated for display)"
433
  st.markdown(f'<div class="content-box">{preview}</div>', unsafe_allow_html=True)
434
 
435
- # ── Q&A section directly below ─────────────────────────────────────────────
436
  st.markdown("""
437
  <div class="qa-banner">
438
  <div class="qa-banner-line"></div>
@@ -441,18 +461,16 @@ if st.session_state.scraped_content:
441
  </div>
442
  """, unsafe_allow_html=True)
443
 
444
- # Render past exchanges
445
  for msg in st.session_state.chat_history:
446
  with st.chat_message(msg["role"]):
447
  st.markdown(msg["content"])
448
 
449
- # Chat input
450
  if prompt := st.chat_input("Ask anything about the content above…"):
451
  st.session_state.chat_history.append({"role": "user", "content": prompt})
452
  with st.chat_message("user"):
453
  st.markdown(prompt)
454
  with st.chat_message("assistant"):
455
- with st.spinner("FLAN-T5 is thinking…"):
456
  answer = answer_question(prompt)
457
  st.markdown(answer)
458
  st.session_state.chat_history.append({"role": "assistant", "content": answer})
@@ -463,7 +481,6 @@ if st.session_state.scraped_content:
463
  st.rerun()
464
 
465
  else:
466
- # Empty state
467
  st.markdown("""
468
  <div style="
469
  text-align:center;
 
3
  import re
4
  import logging
5
  import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from playwright.sync_api import sync_playwright
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import FAISS
 
14
  filename='/app/cache/app.log',
15
  level=logging.DEBUG,
16
  format='%(asctime)s - %(levelname)s - %(message)s'
 
17
  )
18
 
19
+ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
20
 
21
  st.set_page_config(
22
+ page_title="RAG Β· TinyLlama",
23
  page_icon="πŸ•ΈοΈ",
24
  layout="wide",
25
  initial_sidebar_state="collapsed"
 
28
  st.markdown("""
29
  <style>
30
  @import url('https://fonts.googleapis.com/css2?family=Instrument+Serif:ital@0;1&family=JetBrains+Mono:wght@300;400;500&display=swap');
 
31
  :root {
32
  --bg: #f5f0e8;
33
  --surface: #ede8df;
 
38
  --mono: 'JetBrains Mono', monospace;
39
  --serif: 'Instrument Serif', serif;
40
  }
 
41
  html, body, [class*="css"] {
42
  font-family: var(--mono);
43
  background: var(--bg);
 
46
  .stApp { background: var(--bg); }
47
  #MainMenu, footer, header { visibility: hidden; }
48
  [data-testid="stDecoration"] { display: none; }
 
49
  [data-testid="stSidebar"] {
50
  background: var(--surface);
51
  border-right: 1px solid var(--border);
52
  }
 
53
  .stTextInput > div > div > input,
54
  .stTextArea textarea {
55
  background: #fff !important;
 
64
  border-color: var(--accent) !important;
65
  box-shadow: 0 0 0 2px rgba(193,58,30,0.12) !important;
66
  }
 
67
  .stButton > button {
68
  background: var(--accent) !important;
69
  color: #fff !important;
 
82
  transform: translateY(-1px);
83
  box-shadow: 0 3px 12px rgba(193,58,30,0.25) !important;
84
  }
 
85
  [data-testid="stChatMessage"] {
86
  background: #fff !important;
87
  border: 1px solid var(--border) !important;
 
93
  font-family: var(--mono) !important;
94
  font-size: 0.82rem !important;
95
  }
 
96
  hr { border-color: var(--border) !important; }
 
97
  .content-box {
98
  background: #fff;
99
  border: 1px solid var(--border);
 
111
  .content-box::-webkit-scrollbar { width: 6px; }
112
  .content-box::-webkit-scrollbar-track { background: var(--surface); }
113
  .content-box::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
 
114
  .meta-pill {
115
  display: inline-flex;
116
  align-items: center;
 
124
  margin-bottom: 0.6rem;
125
  }
126
  .meta-dot { width:6px; height:6px; border-radius:50%; background:#4caf50; }
 
127
  .section-label {
128
  font-size: 0.68rem;
129
  letter-spacing: 0.12em;
 
140
  height: 1px;
141
  background: var(--border);
142
  }
 
143
  .qa-banner {
144
  display: flex;
145
  align-items: center;
 
154
  color: var(--accent);
155
  white-space: nowrap;
156
  }
157
+ .model-badge {
 
158
  display: inline-flex;
159
  align-items: center;
160
  gap: 5px;
 
164
  border: 1px solid var(--border);
165
  border-radius: 3px;
166
  }
167
+ .model-dot { width:6px; height:6px; border-radius:50%; }
 
168
  .page-header {
169
  padding: 1.5rem 0 1rem 0;
170
  border-bottom: 2px solid var(--text);
 
188
  letter-spacing: 0.08em;
189
  text-transform: uppercase;
190
  }
 
191
  [data-testid="stAlert"] {
192
  background: var(--surface) !important;
193
  border: 1px solid var(--border) !important;
 
213
  # ── Utilities ──────────────────────────────────────────────────────────────────
214
 
215
  def clean_text(text):
216
+ # Only collapse whitespace — preserve prices, commas, symbols
217
  text = re.sub(r'[ \t]+', ' ', text)
218
  text = re.sub(r'\n{3,}', '\n\n', text)
219
  return text.strip()
 
221
  def is_valid_url(url):
222
  return bool(re.match(r'^https?://[\w\-\.]+(?::\d+)?(?:/[\w\-\./]*)*$', url))
223
 
224
+ # ── Model ──────────────────────────────────────────────────────────────────────
 
225
 
226
  @st.cache_resource(show_spinner=False)
227
  def load_model():
228
  try:
229
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
230
+ model = AutoModelForCausalLM.from_pretrained(
231
+ MODEL_NAME,
232
+ torch_dtype=torch.float32,
233
+ low_cpu_mem_usage=True,
234
+ )
235
  model.eval()
236
  logging.info(f"Loaded {MODEL_NAME}")
237
  return tokenizer, model
 
239
  logging.error(f"Model load error: {e}")
240
  return None, None
241
 
242
+ # ── Scraper ────────────────────────────────────────────────────────────────────
243
+
244
  def scrape_website(url):
245
  with sync_playwright() as p:
246
+ browser = p.chromium.launch(headless=True, args=['--no-sandbox', '--disable-dev-shm-usage'])
247
  page = browser.new_page()
248
  try:
249
  page.goto(url, wait_until="networkidle", timeout=45000)
250
  title = page.title()
251
 
252
+ # Strategy 1: extract from <li> elements — good for listing/price pages
 
253
  lines = []
254
+ for li in page.query_selector_all("li"):
 
255
  try:
256
  text = li.inner_text().strip()
257
+ if text and 3 < len(text) < 300:
 
258
  lines.append(text)
259
  except:
260
  continue
261
 
262
+ # Strategy 2: headings, paragraphs, table cells
263
  for tag in ["h1", "h2", "h3", "h4", "p", "td"]:
264
+ for e in page.query_selector_all(tag):
 
265
  try:
266
  text = e.inner_text().strip()
267
+ if text and 3 < len(text) < 500:
268
  lines.append(text)
269
  except:
270
  continue
271
 
272
+ # Deduplicate preserving order
273
+ seen, unique_lines = set(), []
 
274
  for line in lines:
275
+ n = re.sub(r'\s+', ' ', line).strip()
276
+ if n not in seen:
277
+ seen.add(n)
278
+ unique_lines.append(n)
279
 
280
  content = "\n".join(unique_lines)
281
 
282
+ # Fallback to body if nothing found
283
  if len(content) < 200:
284
  body = page.query_selector("body")
285
  content = clean_text(body.inner_text()) if body else content
 
293
  finally:
294
  browser.close()
295
 
296
+ # ── Vector store ───────────────────────────────────────────────────────────────
297
+
298
  @st.cache_resource
299
  def create_vector_store(text):
300
  try:
 
309
  st.error(f"Indexing failed: {e}")
310
  return None
311
 
312
+ # ── Answer ─────────────────────────────────────────────────────────────────────
313
+
314
  def answer_question(question):
315
  if not st.session_state.vector_store:
316
  return "No content indexed yet."
 
318
  if tokenizer is None:
319
  return "Model failed to load. Check logs."
320
  try:
321
+ # Retrieve top 3 relevant chunks from FAISS
322
  docs = st.session_state.vector_store.similarity_search(question, k=3)
323
  context = " ".join(d.page_content for d in docs)
324
+
325
+ # TinyLlama expects the chat template format
326
+ messages = [
327
+ {
328
+ "role": "system",
329
+ "content": (
330
+ "You are a helpful assistant. Answer the user's question using "
331
+ "ONLY the context provided. If the answer is not in the context, "
332
+ "say \"I don't know\"."
333
+ ),
334
+ },
335
+ {
336
+ "role": "user",
337
+ "content": f"Context:\n{context}\n\nQuestion: {question}",
338
+ },
339
+ ]
340
+
341
+ # Apply chat template → produces <|system|>...<|user|>...<|assistant|>
342
+ prompt = tokenizer.apply_chat_template(
343
+ messages,
344
+ tokenize=False,
345
+ add_generation_prompt=True, # appends <|assistant|> so model starts answering
346
+ )
347
+
348
+ inputs = tokenizer(
349
+ prompt,
350
+ return_tensors="pt",
351
+ truncation=True,
352
+ max_length=2048, # TinyLlama's full context window
353
  )
354
+
355
  with torch.no_grad():
356
  outputs = model.generate(
357
  **inputs,
358
+ max_new_tokens=300,
359
+ do_sample=True,
360
+ temperature=0.7,
361
+ top_p=0.95,
362
+ repetition_penalty=1.1,
363
+ pad_token_id=tokenizer.eos_token_id,
364
  )
365
+
366
+ # Slice off the prompt tokens — only decode what the model generated
367
+ generated = outputs[0][inputs["input_ids"].shape[1]:]
368
+ return tokenizer.decode(generated, skip_special_tokens=True).strip()
369
+
370
  except Exception as e:
371
  logging.error(f"Inference error: {e}")
372
  return f"Error generating answer: {e}"
 
380
  with st.sidebar:
381
  st.markdown("**Model**")
382
  st.markdown(f"`{MODEL_NAME}`")
383
+ st.markdown("**Context window**")
384
+ st.markdown("`2048 tokens`")
385
  st.markdown("**Status**")
386
  if model_ok:
387
  st.success("Model loaded ✓")
 
398
  <p class="page-title">Web RAG</p>
399
  <span class="page-sub">Scrape → Index → Ask</span>
400
  </div>
401
+ <div class="model-badge">
402
+ <div class="model-dot" style="background:{dot_color};"></div>
403
+ {dot_label} &nbsp;·&nbsp; TinyLlama-1.1B-Chat
404
  </div>
405
  </div>
406
  """, unsafe_allow_html=True)
 
434
  # ── Main content area ──────────────────────────────────────────────────────────
435
  if st.session_state.scraped_content:
436
 
 
437
  title_display = st.session_state.scraped_title or ""
438
  url_display = st.session_state.scraped_url or ""
439
  st.markdown(f"""
 
443
  &nbsp;·&nbsp;
444
  <span>{st.session_state.char_count:,} chars</span>
445
  &nbsp;·&nbsp;
446
+ <span style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;">{url_display}</span>
447
  </div>
448
  """, unsafe_allow_html=True)
449
 
 
450
  st.markdown('<div class="section-label">Scraped content</div>', unsafe_allow_html=True)
451
  preview = st.session_state.scraped_content[:4000]
452
  if len(st.session_state.scraped_content) > 4000:
453
  preview += "\n\n… (truncated for display)"
454
  st.markdown(f'<div class="content-box">{preview}</div>', unsafe_allow_html=True)
455
 
 
456
  st.markdown("""
457
  <div class="qa-banner">
458
  <div class="qa-banner-line"></div>
 
461
  </div>
462
  """, unsafe_allow_html=True)
463
 
 
464
  for msg in st.session_state.chat_history:
465
  with st.chat_message(msg["role"]):
466
  st.markdown(msg["content"])
467
 
 
468
  if prompt := st.chat_input("Ask anything about the content above…"):
469
  st.session_state.chat_history.append({"role": "user", "content": prompt})
470
  with st.chat_message("user"):
471
  st.markdown(prompt)
472
  with st.chat_message("assistant"):
473
+ with st.spinner("TinyLlama is thinking…"):
474
  answer = answer_question(prompt)
475
  st.markdown(answer)
476
  st.session_state.chat_history.append({"role": "assistant", "content": answer})
 
481
  st.rerun()
482
 
483
  else:
 
484
  st.markdown("""
485
  <div style="
486
  text-align:center;