johnnydang88 commited on
Commit
d87893b
·
verified ·
1 Parent(s): b5830e8

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. README.md +13 -0
  3. app.py +907 -0
  4. laborcode.pdf +3 -0
  5. requirements.txt +17 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ laborcode.pdf filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Philippine Labor Code RAG Assistant
3
+ emoji: "👷"
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: "1.40.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Philippine Labor Code & Employee Rights Assistant
13
+ Multi-Model RAG Evaluation: Qwen2.5-7B, LLaMA-3.1-8B, Gemma-2-9B
app.py ADDED
@@ -0,0 +1,907 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Philippine Labor Code & Employee Rights Assistant
3
+ Multi-Model RAG Evaluation: Qwen2.5-7B | LLaMA-3.1-8B | Gemma-2-9B
4
+
5
+ Streamlit deployment converts the Colab notebook pipeline into a
6
+ permanent web application. All three models share a single retrieval
7
+ pipeline (BGE-M3 dense + BM25 sparse + RRF fusion + cross-encoder
8
+ reranking + MMR diversity selection).
9
+ """
10
+
11
+ import os
12
+ import re
13
+ import gc
14
+ import time
15
+ import json
16
+ import nltk
17
+ import torch
18
+ import faiss
19
+ import numpy as np
20
+ import streamlit as st
21
+ import matplotlib
22
+ matplotlib.use("Agg")
23
+ import matplotlib.pyplot as plt
24
+
25
+ from collections import Counter
26
+ from pypdf import PdfReader
27
+ from rank_bm25 import BM25Okapi
28
+ from sentence_transformers import SentenceTransformer, CrossEncoder
29
+ from sklearn.metrics.pairwise import cosine_similarity as cos_sim
30
+ from rouge_score import rouge_scorer
31
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
32
+
33
# ---------------------------------------------------------------------------
# Download NLTK data at startup
# ---------------------------------------------------------------------------
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# ---------------------------------------------------------------------------
# Page configuration
# ---------------------------------------------------------------------------
st.set_page_config(
    page_title="Philippine Labor Code RAG Assistant",
    page_icon=None,
    layout="wide",
    initial_sidebar_state="expanded",
)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Source document plus the retrieval-stack model identifiers.
PDF_PATH = "laborcode.pdf"
EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"

# LLMs compared side-by-side. `supports_system` records whether the model's
# chat template accepts a dedicated "system" role (Gemma-2 does not).
MODEL_CONFIGS = {
    "Qwen2.5-7B-Instruct": {
        "hf_id": "Qwen/Qwen2.5-7B-Instruct",
        "supports_system": True,
    },
    "LLaMA-3.1-8B-Instruct": {
        "hf_id": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "supports_system": True,
    },
    "Gemma-2-9B-IT": {
        "hf_id": "google/gemma-2-9b-it",
        "supports_system": False,
    },
}

# Shared system prompt grounding every model in the retrieved context only.
SYSTEM_PROMPT = (
    "You are Lex, a helpful and professional Philippine Labor Law assistant. "
    "Answer the user's question accurately using ONLY the provided legal context "
    "from the Philippine Labor Code (Presidential Decree No. 442, as amended). "
    "Always cite the specific Article number(s) when applicable. "
    "If the context does not contain enough information to answer, say so honestly. "
    "Do not fabricate legal provisions."
)
79
+
80
# ---------------------------------------------------------------------------
# Greeting detection (from notebook Cell 7)
# ---------------------------------------------------------------------------
# Small-talk inputs are answered with a canned reply instead of running the
# full retrieval + generation pipeline.
GREETING_PATTERNS = [
    r"^(hi|hello|hey|good morning|good afternoon|good evening|kamusta|kumusta|musta|helo|oi|uy)[\s!?.]*$",
    r"^(what can you do|what are you|who are you|what is lex|are you a bot|are you ai)[\s?]*$",
    r"^(thanks?|thank you|salamat|maraming salamat|ty)[\s!.]*$",
    r"^(bye|goodbye|see you|paalam|ok|okay|sure|alright|got it|noted)[\s!.]*$",
    r"^(help|tulong|tulungan mo ako)[\s!?]*$",
]

GREETING_RESPONSE = (
    "Hello! I am Lex, your Philippine Labor Law assistant. "
    "Feel free to ask me any questions about labor rights, employment policies, "
    "wages, working hours, leaves, termination, or any workplace concerns under "
    "the Philippine Labor Code (PD 442). How can I help you today?"
)


def is_greeting(text: str) -> bool:
    """Return True when *text* looks like small talk rather than a legal question.

    A message is a greeting when it matches one of GREETING_PATTERNS, or when
    it is very short (<= 3 words) and contains no legal vocabulary.
    """
    normalized = text.strip().lower()
    if any(re.match(pat, normalized, re.IGNORECASE) for pat in GREETING_PATTERNS):
        return True
    legal_keywords = [
        "article", "labor", "wage", "leave", "work", "employ",
        "salary", "pay", "overtime", "holiday", "terminate",
        "strike", "union", "dole", "law", "code", "right",
        "benefit", "retire", "resign", "dismiss",
    ]
    # Very short inputs with no legal keyword are treated as chit-chat too.
    words = normalized.split()
    return len(words) <= 3 and not any(kw in normalized for kw in legal_keywords)
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # PDF processing and chunking (from notebook Cells 4-5)
118
+ # ---------------------------------------------------------------------------
119
+ def extract_text_from_pdf(pdf_path: str) -> str:
120
+ reader = PdfReader(pdf_path)
121
+ text = ""
122
+ for page in reader.pages:
123
+ page_text = page.extract_text()
124
+ if page_text:
125
+ text += page_text + "\n"
126
+ return text
127
+
128
+
129
def clean_text(text: str) -> str:
    """Normalize raw PDF text before chunking.

    Mends article numbers split by extraction, removes page markers, footnote
    lines, bare page-number lines, and collapses excess whitespace.
    """
    # Re-join article numbers that PDF extraction split across a space,
    # e.g. "ART. 8 2." -> "ART. 82." and "Article 8 2" -> "Article 82".
    text = re.sub(r"(ART\.)\s*(\d+)\s+(\d+)\.", r"\1 \2\3.", text)
    text = re.sub(r"(Article\s+)(\d+)\s+(\d+)", r"\1\2\3", text)
    # Strip "--- Page N ---" markers.
    text = re.sub(r"---\s*Page\s*\d+\s*---", "", text)
    # Drop footnote lines: a small leading number followed by citation phrasing.
    text = re.sub(
        r"\n\s*\d{1,3}\s+(?:See|As amended|R\.A\.|P\.D\.|E\.O\.|The |This |Pursuant|Section|Sec\.).*",
        "",
        text,
        flags=re.IGNORECASE,
    )
    # Remove explicit footnote blocks and bare page-number lines.
    text = re.sub(r"\[Footnote\].*?\n", "\n", text, flags=re.DOTALL)
    text = re.sub(r"\n\s*\d{1,4}\s*\n", "\n", text)
    # Collapse long runs of spaces/tabs and excessive blank lines.
    text = re.sub(r"[ \t]{3,}", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
144
+
145
+
146
def is_substantive_chunk(chunk: str, max_footnote_ratio: float = 0.08) -> bool:
    """Return True unless *chunk* is dominated by footnote/citation markers.

    The chunk is rejected when case-insensitive marker occurrences per word
    reach *max_footnote_ratio* (default 8%). Empty chunks are rejected.
    """
    footnote_markers = [
        "[Footnote]", "See DOLE", "As amended by", "superseded by",
        "cross-reference", "R.A. No.", "P.D. No.", "E.O. No.",
        "pursuant to", "inserted in", "renumbered as",
    ]
    word_count = len(chunk.split())
    if not word_count:
        return False
    lowered = chunk.lower()
    marker_hits = sum(lowered.count(marker.lower()) for marker in footnote_markers)
    return marker_hits / word_count < max_footnote_ratio
158
+
159
+
160
def fix_broken_article_header(chunk: str) -> str:
    """Re-join an article header whose number was split by PDF extraction,
    e.g. "ART. 8 2." -> "ART. 82". Non-matching text is returned unchanged."""
    def _merge(match: re.Match) -> str:
        return match.group(1) + match.group(2) + match.group(3)

    return re.sub(
        r"(ART\.?\s*)(\d)\s+(\d+\.)",
        _merge,
        chunk,
        flags=re.IGNORECASE,
    )
167
+
168
+
169
def chunk_text_by_article(
    text: str, max_len: int = 1200, overlap: int = 200, min_len: int = 100
) -> list[str]:
    """Split *text* on "ART. N" boundaries and filter out boilerplate.

    Articles longer than *max_len* characters are sub-split into sentence
    windows; each continuation window re-attaches the article header plus the
    last *overlap* characters of the previous window so it stays attributable.
    Chunks shorter than *min_len*, boilerplate front-matter, and
    footnote-dominated chunks are dropped.
    """
    article_pattern = re.compile(
        r"(?=(?:ART\.|Art\.|ARTICLE)\s+\d+[\.\ ])", re.IGNORECASE
    )
    raw_splits = article_pattern.split(text)
    chunks = []
    for block in raw_splits:
        block = block.strip()
        if not block:
            continue
        if len(block) <= max_len:
            if len(block) >= min_len:
                chunks.append(block)
        else:
            # Article too long: break into overlapping sentence windows.
            header_match = re.match(
                r"((?:ART\.|Art\.|ARTICLE)\s+\d+[^.]*\.)", block
            )
            header = header_match.group(1).strip() if header_match else ""
            sentences = re.split(r"(?<=[.!?;])\s+", block)
            current = ""
            chunk_num = 0
            for sent in sentences:
                if len(current) + len(sent) > max_len:
                    if current:
                        chunks.append(current.strip())
                        chunk_num += 1
                    # Seed the next window with the header (marked "[cont]")
                    # and the tail of the previous window for context.
                    tail = current[-overlap:] if len(current) > overlap else ""
                    current = (
                        (header + " [cont] " + tail)
                        if header and chunk_num > 0
                        else tail
                    )
                current += " " + sent
            if current.strip() and len(current.strip()) >= min_len:
                chunks.append(current.strip())

    # Drop front-matter/boilerplate and footnote-heavy chunks.
    boilerplate_patterns = [
        "NOT FOR SALE", "Copyright", "SILVESTRE H. BELLO",
        "Table of Contents", "FOREWORD", "www.dole.gov.ph",
        "Repealing Clause", "cross-references all superseded",
        "Name of Decree",
    ]
    chunks = [
        c
        for c in chunks
        if not any(bp.lower() in c.lower() for bp in boilerplate_patterns)
        and len(c.strip()) > min_len
        and is_substantive_chunk(c)
    ]
    return [fix_broken_article_header(c) for c in chunks]
222
+
223
+
224
# ---------------------------------------------------------------------------
# Retrieval functions (from notebook Cells 6-7)
# ---------------------------------------------------------------------------
def mmr_select(
    candidates: list[str],
    scores: list[float],
    embeddings: np.ndarray,
    k: int = 5,
    lam: float = 0.6,
):
    """Maximal Marginal Relevance selection of *k* candidates.

    Greedily picks candidates trading off relevance (*scores*) against
    redundancy (max dot product with already-selected embeddings; embeddings
    are expected normalized so this is cosine similarity). *lam* weights
    relevance, ``1 - lam`` weights diversity. Returns the selected candidates
    and their scores in selection order; if there are at most *k* candidates
    the inputs are returned untouched.
    """
    if len(candidates) <= k:
        return candidates, scores
    embs = np.array(embeddings)
    chosen: list[int] = []
    pool = list(range(len(candidates)))
    while pool and len(chosen) < k:
        if chosen:
            picked_embs = embs[chosen]

            # Bind the current selection via default arg so each closure is fixed.
            def marginal(i, _picked=picked_embs):
                redundancy = float(np.max(_picked @ embs[i]))
                return lam * scores[i] - (1.0 - lam) * redundancy

            nxt = max(pool, key=marginal)
        else:
            # First pick: pure relevance.
            nxt = max(pool, key=lambda i: scores[i])
        chosen.append(nxt)
        pool.remove(nxt)
    return [candidates[i] for i in chosen], [scores[i] for i in chosen]
254
+
255
+
256
def deduplicate_by_article(
    ranked_pairs: list, max_per_article: int = 2, final_k: int = 5
) -> list:
    """Cap each article at *max_per_article* chunks, preserving ranked order.

    *ranked_pairs* is a best-first list of (chunk, score). Chunks that do not
    begin with an "ART. N" header share a single "UNK" bucket. At most
    *final_k* pairs are returned.
    """
    per_article: dict = {}
    kept: list = []
    for chunk, score in ranked_pairs:
        header = re.match(r"(ART\.?\s*\d+)", chunk, re.IGNORECASE)
        key = header.group(1).upper().replace(" ", "") if header else "UNK"
        seen = per_article.get(key, 0)
        if seen < max_per_article:
            kept.append((chunk, score))
            per_article[key] = seen + 1
            if len(kept) == final_k:
                break
    return kept
271
+
272
+
273
def hybrid_retrieve_and_rerank(
    question: str,
    embedder: SentenceTransformer,
    index: faiss.IndexFlatIP,
    bm25: BM25Okapi,
    reranker: CrossEncoder,
    chunks: list[str],
    initial_k: int = 20,
    rerank_k: int = 8,
    final_k: int = 5,
):
    """Hybrid retrieval: dense + BM25 fused with RRF, then cross-encoder
    reranking, per-article deduplication, and MMR diversity selection.

    Returns (top_chunks, top_scores): up to *final_k* chunks and their
    cross-encoder scores, in MMR selection order.
    """
    # --- Dense retrieval (BGE-style "query:" prefix, normalized for IP) ---
    query_emb = embedder.encode(
        [f"query: {question}"], convert_to_numpy=True, normalize_embeddings=True
    )
    _dense_scores, dense_indices = index.search(query_emb, initial_k)
    # Fix: FAISS pads results with -1 when the index holds fewer than
    # initial_k vectors; filter those out so chunks[-1] is never selected.
    dense_ranking = [int(i) for i in dense_indices[0] if i >= 0]

    # --- Sparse retrieval (BM25 over lowercase whitespace tokens) ---
    bm25_scores = bm25.get_scores(question.lower().split())
    bm25_ranking = [int(i) for i in np.argsort(bm25_scores)[::-1][:initial_k]]

    # --- Reciprocal Rank Fusion of the two rankings ---
    rrf_k = 60  # standard RRF damping constant
    rrf_scores: dict = {}
    for ranking in (dense_ranking, bm25_ranking):
        for rank, idx in enumerate(ranking):
            rrf_scores[idx] = rrf_scores.get(idx, 0.0) + 1.0 / (rank + rrf_k)

    fused_indices = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[:initial_k]
    candidate_chunks = [chunks[i] for i in fused_indices]

    # --- Cross-encoder reranking of fused candidates ---
    pairs = [[question, chunk] for chunk in candidate_chunks]
    rerank_scores_arr = reranker.predict(pairs)
    ranked_all = sorted(
        zip(candidate_chunks, rerank_scores_arr.tolist()),
        key=lambda x: x[1],
        reverse=True,
    )[:rerank_k]

    # --- Per-article deduplication (at most 2 chunks per article) ---
    deduped = deduplicate_by_article(ranked_all, max_per_article=2, final_k=rerank_k)
    dedup_chunks = [x[0] for x in deduped]
    dedup_scores = [x[1] for x in deduped]

    # --- MMR diversity selection over the deduplicated set ---
    cand_embs = embedder.encode(
        [f"passage: {c}" for c in dedup_chunks],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    top_chunks, top_scores = mmr_select(
        dedup_chunks, dedup_scores, cand_embs, k=final_k, lam=0.6
    )
    return top_chunks, top_scores
331
+
332
+
333
# ---------------------------------------------------------------------------
# Model loading and generation
# ---------------------------------------------------------------------------
def load_model_and_tokenizer(model_name: str):
    """Load a model with 4-bit quantization. Returns (model, tokenizer).

    *model_name* must be a key of MODEL_CONFIGS. The model is loaded with
    NF4 double quantization, fp16 compute, device_map="auto", and put in
    eval mode.
    """
    hf_id = MODEL_CONFIGS[model_name]["hf_id"]

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(hf_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some chat models ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        hf_id,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    model.eval()
    return model, tokenizer
361
+
362
+
363
def unload_model(model, tokenizer):
    """Drop references to a model/tokenizer pair and reclaim memory.

    Runs the garbage collector and, when CUDA is present, empties the CUDA
    allocator cache so the next model can be loaded on the same GPU.
    """
    del model, tokenizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
370
+
371
+
372
def build_prompt(
    model_name: str,
    question: str,
    context_chunks: list[str],
    tokenizer,
) -> str:
    """Build a chat-formatted prompt appropriate for each model.

    Joins *context_chunks* into one context block, then renders the messages
    through the tokenizer's chat template. Models without a system role
    (per MODEL_CONFIGS) get the system prompt folded into the user turn.
    """
    context_block = "\n\n---\n\n".join(context_chunks)
    user_content = (
        f"CONTEXT (from the Philippine Labor Code):\n{context_block}\n\n"
        f"QUESTION: {question}\n\n"
        f"Provide a clear, accurate answer citing specific Article numbers."
    )

    if MODEL_CONFIGS[model_name]["supports_system"]:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
    else:
        # Gemma-2 does not support system role; inject into first user turn
        messages = [
            {"role": "user", "content": f"{SYSTEM_PROMPT}\n\n{user_content}"}
        ]

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
402
+
403
+
404
def generate_answer(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
) -> str:
    """Generate an answer from the model, returning only the new text.

    The prompt is truncated to 4096 tokens; generation samples with
    temperature 0.3 / top-p 0.9 and a mild repetition penalty.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    encoded = {k: v.to(model.device) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Slice off the prompt tokens so only the model's reply is decoded.
    prompt_len = encoded["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
429
+
430
+
431
# ---------------------------------------------------------------------------
# Evaluation metrics (from the notebook)
# ---------------------------------------------------------------------------
def compute_faithfulness(answer: str, context_chunks: list[str], embedder) -> float:
    """Semantic similarity between the answer and the retrieved context.

    Returns 0.0 when either input is empty; otherwise cosine similarity of
    normalized embeddings of the answer and the concatenated context.
    """
    if not answer or not context_chunks:
        return 0.0
    joined_context = " ".join(context_chunks)
    vecs = embedder.encode(
        [answer, joined_context], convert_to_numpy=True, normalize_embeddings=True
    )
    return float(cos_sim([vecs[0]], [vecs[1]])[0][0])
443
+
444
+
445
def compute_semantic_similarity(
    answer: str, ground_truth: str, embedder
) -> float:
    """Cosine similarity between answer and ground truth embeddings.

    Returns 0.0 when either string is empty.
    """
    if not answer or not ground_truth:
        return 0.0
    vecs = embedder.encode(
        [answer, ground_truth], convert_to_numpy=True, normalize_embeddings=True
    )
    return float(cos_sim([vecs[0]], [vecs[1]])[0][0])
455
+
456
+
457
def compute_answer_relevancy(answer: str, question: str, embedder) -> float:
    """Cosine similarity between answer and question embeddings.

    Returns 0.0 when either string is empty.
    """
    if not answer or not question:
        return 0.0
    vecs = embedder.encode(
        [answer, question], convert_to_numpy=True, normalize_embeddings=True
    )
    return float(cos_sim([vecs[0]], [vecs[1]])[0][0])
465
+
466
+
467
def compute_citation_accuracy(
    answer: str, expected_articles: list[str]
) -> float:
    """Fraction of expected article numbers cited in *answer*.

    Accepts the forms "Article N", "Art. N", and "ART. N" (case-insensitive).
    Vacuously 1.0 when no expected articles are supplied.
    """
    if not expected_articles:
        return 1.0
    cited = sum(
        1
        for art in expected_articles
        if re.search(
            rf"(?:Article|Art\.?|ART\.?)\s*{art}\b", answer, re.IGNORECASE
        )
    )
    return cited / len(expected_articles)
483
+
484
+
485
def compute_rouge_l(answer: str, ground_truth: str) -> float:
    """ROUGE-L F1 with *ground_truth* as reference and *answer* as hypothesis.

    Returns 0.0 when either string is empty. Uses stemming.
    """
    if not answer or not ground_truth:
        return 0.0
    rl = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    return rl.score(ground_truth, answer)["rougeL"].fmeasure
492
+
493
+
494
def compute_retrieval_recall(
    chunks: list[str], expected_articles: list[str]
) -> float:
    """Fraction of expected articles found anywhere in the retrieved chunks.

    Vacuously 1.0 when no expected articles are given. Matches "Article N",
    "Art. N", or "ART. N" (case-insensitive) across all chunks combined.
    """
    if not expected_articles:
        return 1.0
    combined = " ".join(chunks)
    hits = sum(
        1
        for art in expected_articles
        if re.search(
            rf"(?:Article|Art\.?|ART\.?)\s*{art}\b", combined, re.IGNORECASE
        )
    )
    return hits / len(expected_articles)
511
+
512
+
513
def compute_retrieval_precision(
    chunks: list[str], expected_articles: list[str]
) -> float:
    """Fraction of retrieved chunks that contain at least one expected article.

    Returns 0.0 when no chunks were retrieved. Fix: an empty
    *expected_articles* list previously made every chunk count as irrelevant
    (precision 0.0) even though the sibling compute_retrieval_recall returns
    a vacuous 1.0 in that case; treat "no expected articles" as vacuously
    precise (1.0) so the two metrics agree when no ground truth is supplied.
    """
    if not chunks:
        return 0.0
    if not expected_articles:
        # Vacuous case — mirrors compute_retrieval_recall.
        return 1.0
    relevant_count = 0
    for chunk in chunks:
        # A chunk is relevant if it cites any expected article in any of the
        # forms "Article N", "Art. N", or "ART. N" (case-insensitive).
        if any(
            re.search(
                rf"(?:Article|Art\.?|ART\.?)\s*{art}\b", chunk, re.IGNORECASE
            )
            for art in expected_articles
        ):
            relevant_count += 1
    return relevant_count / len(chunks)
531
+
532
+
533
def evaluate_single_response(
    question: str,
    answer: str,
    context_chunks: list[str],
    ground_truth: str,
    expected_articles: list[str],
    embedder,
) -> dict:
    """Compute every evaluation metric for one model answer.

    Returns a dict of metric name -> value rounded to 4 decimals. Generation
    metrics compare *answer* against context/ground truth; Recall@5 and
    Precision@5 score the retrieved *context_chunks* themselves.
    """
    scores = {
        "Faithfulness": compute_faithfulness(answer, context_chunks, embedder),
        "Semantic Similarity": compute_semantic_similarity(
            answer, ground_truth, embedder
        ),
        "Answer Relevancy": compute_answer_relevancy(answer, question, embedder),
        "Citation Accuracy": compute_citation_accuracy(answer, expected_articles),
        "ROUGE-L": compute_rouge_l(answer, ground_truth),
        "Recall@5": compute_retrieval_recall(context_chunks, expected_articles),
        "Precision@5": compute_retrieval_precision(
            context_chunks, expected_articles
        ),
    }
    return {name: round(value, 4) for name, value in scores.items()}
563
+
564
+
565
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def render_comparison_chart(all_metrics: dict) -> plt.Figure:
    """Create a grouped bar chart comparing metrics across models.

    all_metrics: { "ModelName": { "MetricName": value, ... }, ... }.
    Missing metrics are plotted as 0.0. Returns the matplotlib Figure
    (caller is responsible for closing it).
    """
    metric_names = [
        "Faithfulness",
        "Semantic Similarity",
        "Answer Relevancy",
        "Citation Accuracy",
        "ROUGE-L",
        "Recall@5",
        "Precision@5",
    ]
    model_names = list(all_metrics.keys())
    n_models = len(model_names)

    positions = np.arange(len(metric_names))
    bar_width = 0.8 / max(n_models, 1)
    palette = ["#2563eb", "#dc2626", "#16a34a"]  # blue / red / green

    fig, ax = plt.subplots(figsize=(14, 6))
    for i, model in enumerate(model_names):
        values = [all_metrics[model].get(m, 0.0) for m in metric_names]
        # Center the group of bars around each metric tick.
        offset = (i - n_models / 2 + 0.5) * bar_width
        bars = ax.bar(
            positions + offset, values, bar_width,
            label=model, color=palette[i % 3],
        )
        # Annotate each bar with its value.
        for bar, val in zip(bars, values):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.01,
                f"{val:.2f}",
                ha="center",
                va="bottom",
                fontsize=7,
            )

    ax.set_ylabel("Score")
    ax.set_title("Multi-Model RAG Evaluation Comparison")
    ax.set_xticks(positions)
    ax.set_xticklabels(metric_names, rotation=30, ha="right")
    ax.set_ylim(0, 1.15)  # headroom for the bar labels
    ax.legend(loc="upper right")
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()
    return fig
614
+
615
+
616
# ---------------------------------------------------------------------------
# Cached resource loaders
# ---------------------------------------------------------------------------
@st.cache_resource(show_spinner="Loading PDF and building document chunks...")
def load_chunks():
    """Extract, clean, and chunk the Labor Code PDF (cached per session).

    Halts the Streamlit app with an error if the PDF is missing.
    """
    if not os.path.exists(PDF_PATH):
        st.error(
            f"PDF file not found at '{PDF_PATH}'. "
            "Please place 'laborcode.pdf' in the application directory."
        )
        st.stop()
    raw_text = extract_text_from_pdf(PDF_PATH)
    return chunk_text_by_article(clean_text(raw_text))
631
+
632
+
633
@st.cache_resource(show_spinner="Loading embedding model and building indices...")
def load_retrieval_infra(_chunks: list[str]):
    """Build the shared retrieval stack (cached per session).

    Returns (embedder, faiss_index, bm25, reranker): a BGE-M3 embedder, a
    FAISS inner-product index over normalized "passage:"-prefixed chunk
    embeddings, a BM25 index over lowercase tokens, and the cross-encoder
    reranker. The leading underscore on *_chunks* tells Streamlit not to
    hash the argument.
    """
    embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    chunk_embeddings = embedder.encode(
        [f"passage: {c}" for c in _chunks],
        batch_size=16,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    # Inner product over normalized vectors == cosine similarity.
    idx = faiss.IndexFlatIP(chunk_embeddings.shape[1])
    idx.add(chunk_embeddings)
    bm25 = BM25Okapi([c.lower().split() for c in _chunks])
    reranker = CrossEncoder(RERANKER_MODEL_NAME)
    return embedder, idx, bm25, reranker
651
+
652
+
653
# ---------------------------------------------------------------------------
# Main application
# ---------------------------------------------------------------------------
def main():
    """Render the Streamlit UI: sidebar configuration, retrieval, three-model
    generation, evaluation metrics, and a per-model summary."""
    # --- Sidebar ---
    st.sidebar.title("Configuration")
    st.sidebar.markdown("---")
    st.sidebar.subheader("About")
    st.sidebar.markdown(
        "**Philippine Labor Code RAG Assistant**\n\n"
        "This application uses Retrieval-Augmented Generation to answer "
        "questions about the Philippine Labor Code (PD 442). Three models "
        "are evaluated side-by-side:\n\n"
        "- Qwen2.5-7B-Instruct\n"
        "- LLaMA-3.1-8B-Instruct\n"
        "- Gemma-2-9B-IT\n\n"
        "All models use 4-bit quantization and share the same hybrid "
        "retrieval pipeline (BGE-M3 + BM25 + RRF + cross-encoder reranking)."
    )

    st.sidebar.markdown("---")
    st.sidebar.subheader("Pipeline Parameters")
    top_k = st.sidebar.slider("Final retrieved chunks (top-k)", 3, 10, 5)
    max_tokens = st.sidebar.slider("Max generation tokens", 128, 1024, 512, step=64)

    st.sidebar.markdown("---")
    st.sidebar.subheader("Ground Truth (optional)")
    ground_truth = st.sidebar.text_area(
        "Expected answer for evaluation",
        placeholder="Paste ground truth here to compute evaluation metrics...",
        height=120,
    )
    expected_articles_raw = st.sidebar.text_input(
        "Expected article numbers (comma-separated)",
        placeholder="e.g. 83, 86, 94",
    )
    expected_articles = [
        a.strip() for a in expected_articles_raw.split(",") if a.strip()
    ]

    # --- Main area ---
    st.title("Philippine Labor Code & Employee Rights Assistant")
    st.markdown(
        "Multi-Model RAG Evaluation: "
        "**Qwen2.5-7B** | **LLaMA-3.1-8B** | **Gemma-2-9B**"
    )
    st.markdown("---")

    # Cached loaders: PDF chunking and index building run once per session.
    chunks = load_chunks()
    embedder, faiss_index, bm25, reranker = load_retrieval_infra(chunks)

    st.success(
        f"Retrieval pipeline ready. {len(chunks)} document chunks indexed."
    )

    # --- Query input ---
    st.subheader("Ask a Question")
    question = st.text_input(
        "Enter your question about the Philippine Labor Code:",
        placeholder="e.g. What are the just causes for termination by employer?",
    )

    if not question:
        st.info(
            "Type a question above and press Enter to query all three models."
        )
        return

    # --- Greeting check: short-circuit small talk before the pipeline ---
    if is_greeting(question):
        st.markdown("### Response")
        st.info(GREETING_RESPONSE)
        return

    # --- Retrieval ---
    st.markdown("---")
    with st.spinner("Retrieving relevant context from the Labor Code..."):
        retrieval_start = time.time()
        top_chunks, top_scores = hybrid_retrieve_and_rerank(
            question=question,
            embedder=embedder,
            index=faiss_index,
            bm25=bm25,
            reranker=reranker,
            chunks=chunks,
            initial_k=20,
            rerank_k=8,
            final_k=top_k,
        )
        retrieval_time = time.time() - retrieval_start

    # --- Display retrieved chunks ---
    st.subheader("Retrieved Context Chunks")
    st.caption(f"Retrieval completed in {retrieval_time:.2f}s")
    for i, (chunk, score) in enumerate(zip(top_chunks, top_scores)):
        with st.expander(
            f"Chunk {i + 1} | Reranker score: {score:.4f}",
            expanded=(i == 0),
        ):
            st.text(chunk)

    # --- Generation across all three models ---
    st.markdown("---")
    st.subheader("Model Responses")

    all_answers = {}
    all_metrics = {}
    all_latencies = {}

    model_names = list(MODEL_CONFIGS.keys())
    cols = st.columns(len(model_names))

    # Models are loaded one at a time and unloaded after generating so all
    # three can share the same GPU.
    for col, model_name in zip(cols, model_names):
        with col:
            st.markdown(f"**{model_name}**")
            status_placeholder = st.empty()
            answer_placeholder = st.empty()
            latency_placeholder = st.empty()

            status_placeholder.warning("Loading model...")

            try:
                gen_start = time.time()
                model, tokenizer = load_model_and_tokenizer(model_name)

                status_placeholder.warning("Generating response...")
                prompt = build_prompt(model_name, question, top_chunks, tokenizer)
                answer = generate_answer(
                    model, tokenizer, prompt, max_new_tokens=max_tokens
                )
                gen_time = time.time() - gen_start

                # Unload to free GPU memory for the next model
                unload_model(model, tokenizer)

                all_answers[model_name] = answer
                all_latencies[model_name] = gen_time

                status_placeholder.empty()
                answer_placeholder.markdown(answer)
                latency_placeholder.caption(
                    f"Generation time: {gen_time:.1f}s"
                )

            except Exception as e:
                # A failed model gets an empty answer so the metric/summary
                # sections below still render for the others.
                status_placeholder.empty()
                answer_placeholder.error(
                    f"Failed to load or run {model_name}: {str(e)}"
                )
                all_answers[model_name] = ""
                all_latencies[model_name] = 0.0

    # --- Evaluation metrics ---
    st.markdown("---")
    st.subheader("Evaluation Metrics")

    if not ground_truth and not expected_articles:
        st.info(
            "To view full evaluation metrics, provide a ground truth answer "
            "and/or expected article numbers in the sidebar."
        )
        # Still compute the metrics that do not require ground truth
        for model_name, answer in all_answers.items():
            if answer:
                all_metrics[model_name] = {
                    "Faithfulness": round(
                        compute_faithfulness(answer, top_chunks, embedder), 4
                    ),
                    "Answer Relevancy": round(
                        compute_answer_relevancy(answer, question, embedder), 4
                    ),
                }
    else:
        gt = ground_truth if ground_truth else ""
        for model_name, answer in all_answers.items():
            if answer:
                metrics = evaluate_single_response(
                    question=question,
                    answer=answer,
                    context_chunks=top_chunks,
                    ground_truth=gt,
                    expected_articles=expected_articles,
                    embedder=embedder,
                )
                metrics["Latency (s)"] = round(all_latencies.get(model_name, 0), 2)
                all_metrics[model_name] = metrics

    if all_metrics:
        # Display as a comparison table
        st.markdown("#### Metric Comparison Table")

        # Union of metric keys across models, preserving first-seen order.
        all_metric_keys = []
        for metrics in all_metrics.values():
            for key in metrics:
                if key not in all_metric_keys:
                    all_metric_keys.append(key)

        table_header = "| Metric | " + " | ".join(all_metrics.keys()) + " |"
        table_sep = "|---|" + "|".join(["---"] * len(all_metrics)) + "|"
        table_rows = []
        for metric_key in all_metric_keys:
            row = f"| {metric_key} |"
            for model_name in all_metrics:
                val = all_metrics[model_name].get(metric_key, "N/A")
                row += f" {val:.4f} |" if isinstance(val, float) else f" {val} |"
            table_rows.append(row)

        st.markdown("\n".join([table_header, table_sep] + table_rows))

        # Render comparison chart (only for 0-1 range metrics)
        displayable = [
            "Faithfulness", "Semantic Similarity", "Answer Relevancy",
            "Citation Accuracy", "ROUGE-L", "Recall@5", "Precision@5",
        ]
        chart_metrics = {
            model_name: {k: v for k, v in metrics.items() if k in displayable}
            for model_name, metrics in all_metrics.items()
        }

        if any(chart_metrics.values()):
            st.markdown("#### Visual Comparison")
            fig = render_comparison_chart(chart_metrics)
            st.pyplot(fig)
            plt.close(fig)

    # --- Summary ---
    st.markdown("---")
    st.subheader("Summary")
    summary_cols = st.columns(len(model_names))
    for col, model_name in zip(summary_cols, model_names):
        with col:
            st.markdown(f"**{model_name}**")
            latency = all_latencies.get(model_name, 0)
            answer = all_answers.get(model_name, "")
            word_count = len(answer.split()) if answer else 0
            st.markdown(f"- Response length: {word_count} words")
            st.markdown(f"- Total latency: {latency:.1f}s")
            if model_name in all_metrics:
                faith = all_metrics[model_name].get("Faithfulness", "N/A")
                if isinstance(faith, float):
                    st.markdown(f"- Faithfulness: {faith:.4f}")
                rel = all_metrics[model_name].get("Answer Relevancy", "N/A")
                if isinstance(rel, float):
                    st.markdown(f"- Answer Relevancy: {rel:.4f}")


if __name__ == "__main__":
    main()
laborcode.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9abd6f8eb3fb329456e926aa46ce05fbfe75f562f8aa6a60e7f421de97295c6
3
+ size 1532508
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.44.0
3
+ bitsandbytes
4
+ accelerate
5
+ sentencepiece
6
+ protobuf
7
+ sentence-transformers
8
+ faiss-cpu
9
+ rank-bm25
10
+ pypdf
11
+ rouge-score
12
+ nltk
13
+ scikit-learn
14
+ numpy
15
+ matplotlib
16
+ streamlit>=1.30.0
17
+ langdetect