|
|
import gradio as gr |
|
|
from transformers import pipeline |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import zipfile, os, re, torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instruction-tuned generator used to phrase the final answers.
# float16 halves GPU memory vs. float32; device_map="auto" lets
# accelerate place model shards on whatever devices are available.
llm = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    device_map="auto"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Legal-domain BERT encoder; embeds both corpus paragraphs and user queries
# into the same vector space for cosine-similarity retrieval.
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Locations of the bundled statute archive and its extraction target.
zip_path = "/app/provinces.zip"
extract_folder = "/app/provinces_texts"

# Start from a clean slate: drop any previous extraction before unpacking,
# so stale files from an older archive cannot linger in the index.
if os.path.exists(extract_folder):
    import shutil
    shutil.rmtree(extract_folder)

# Unpack every province text file into the extraction folder.
with zipfile.ZipFile(zip_path, "r") as archive:
    archive.extractall(extract_folder)

# Matches ISO-ish dates such as 2023-01-15 or 2023_01_15.
date_pattern = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_metadata_and_content(raw):
    """Split a province text file into (metadata dict, body text).

    The expected layout is a header of ``KEY: value`` lines, optionally
    interleaved with ``- ...`` PDF-link lines, followed by a literal
    ``CONTENT:`` marker and the document body. Header keys are
    upper-cased; PDF-link lines are joined under the ``PDF_LINKS`` key.

    Raises:
        ValueError: if the ``CONTENT:`` marker is absent.
    """
    if "CONTENT:" not in raw:
        raise ValueError("Missing CONTENT block")

    header, body = raw.split("CONTENT:", 1)

    metadata = {}
    pdf_lines = []
    for header_line in header.split("\n"):
        # Key/value pairs take priority; a line that begins with "-" is a
        # PDF link even if it also happens to contain a colon.
        if ":" in header_line and not header_line.startswith("-"):
            key, _, value = header_line.partition(":")
            metadata[key.strip().upper()] = value.strip()
        elif header_line.strip().startswith("-"):
            pdf_lines.append(header_line.strip())

    if pdf_lines:
        metadata["PDF_LINKS"] = "\n".join(pdf_lines)

    return metadata, body.strip()
|
|
|
|
|
# Walk the extracted archive and split every statute text file into
# paragraph-level records, each tagged with its source metadata.
documents = []
for root, dirs, files in os.walk(extract_folder):
    for filename in files:
        # Skip non-text files and macOS resource-fork artifacts ("._*").
        if not filename.endswith(".txt") or filename.startswith("._"):
            continue

        path = os.path.join(root, filename)
        try:
            # BUG FIX: the handle was previously opened via
            # open(path).read() and never closed; the context manager
            # guarantees it is released even if parsing fails.
            with open(path, "r", encoding="latin-1") as fh:
                raw = fh.read()
            metadata, content = parse_metadata_and_content(raw)

            # One record per non-empty, blank-line-separated paragraph.
            for p in [x.strip() for x in content.split("\n\n") if x.strip()]:
                documents.append({
                    "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                    "province": metadata.get("PROVINCE", "Unknown"),
                    "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                    "url": metadata.get("URL", "N/A"),
                    "pdf_links": metadata.get("PDF_LINKS", ""),
                    "text": p
                })

        except Exception as e:
            # Best-effort ingestion: one malformed file must not abort the
            # whole index build, so report it and keep going.
            print("Skipped:", path, e)

print("Loaded paragraphs:", len(documents))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Assemble the retrieval index: one DataFrame row per corpus paragraph,
# with its embedding stored alongside the metadata columns.
df = pd.DataFrame(documents)

# Encode all paragraphs in a single batch; float16 halves the memory
# footprint compared with the encoder's native float32 output.
texts = df["text"].tolist()
vectors = embedding_model.encode(texts).astype("float16")
df["Embedding"] = list(vectors)
print("Embedding index ready:", len(df))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the ``top_k`` corpus paragraphs most similar to *query*.

    Args:
        query: free-text user question to embed and match.
        province: optional exact value of the ``province`` column used to
            filter the corpus; ``None`` searches all provinces.
        top_k: number of rows to return.

    Returns:
        A copy of the matching ``df`` rows with an added ``Similarity``
        column (cosine similarity), sorted descending and truncated to
        ``top_k`` rows.
    """
    query_emb = embedding_model.encode([query])[0]
    # PERF: the query-vector norm is loop-invariant — compute it once
    # instead of once per row inside the lambda.
    query_norm = np.linalg.norm(query_emb)

    subset = df if province is None else df[df["province"] == province]

    # Work on a copy so the global index DataFrame is never mutated.
    subset = subset.copy()
    subset["Similarity"] = subset["Embedding"].apply(
        lambda emb: np.dot(query_emb, emb) / (query_norm * np.linalg.norm(emb))
    )

    return subset.sort_values("Similarity", ascending=False).head(top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_province(q):
    """Scan *q* for a known province alias and return its canonical name.

    Aliases are matched case-insensitively as substrings, in declaration
    order — the first alias found wins. Returns None when no province
    is mentioned.
    """
    alias_to_name = {
        "yukon": "Yukon", "alberta": "Alberta", "bc": "British Columbia",
        "british columbia": "British Columbia", "manitoba": "Manitoba",
        "newfoundland": "Newfoundland and Labrador",
        "saskatchewan": "Saskatchewan", "sask": "Saskatchewan",
        "ontario": "Ontario", "pei": "Prince Edward Island",
        "quebec": "Quebec", "new brunswick": "New Brunswick",
        "nova scotia": "Nova Scotia", "nunavut": "Nunavut",
        "northwest territories": "Northwest Territories"
    }
    lowered = q.lower()
    return next(
        (name for alias, name in alias_to_name.items() if alias in lowered),
        None,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_disallowed(q):
    """Return True when *q* mentions a harmful topic we refuse to answer."""
    lowered = q.lower()
    for term in ("kill", "suicide", "bomb", "weapon", "harm yourself"):
        if term in lowered:
            return True
    return False
|
|
|
|
|
def is_off_topic(q):
    """Return True when *q* contains no tenancy-related keyword at all."""
    tenancy_terms = ("tenant", "landlord", "rent", "evict",
                     "lease", "repair", "notice", "unit")
    lowered = q.lower()
    # Equivalent (by De Morgan) to: not any(term in lowered ...)
    return all(term not in lowered for term in tenancy_terms)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Greeting shown as the bot's first message in the chat window; includes
# the not-legal-advice disclaimer (rendered as Markdown by gr.Chatbot).
INTRO = (
    "Hi! I'm a Canadian rental housing assistant. I help summarize and explain "
    "information from Residential Tenancies Acts across Canada.\n\n"
    "**Note:** I'm not a lawyer — this is not legal advice.\n\n"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_with_rag(query):
    """Answer *query* via retrieval-augmented generation.

    Order of operations: refuse harmful questions, refuse off-topic ones,
    retrieve the most similar statute paragraphs (province-filtered when
    one is named in the query), then ask the LLM to answer using only
    that retrieved context.
    """
    # Safety and scope guards come first.
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful topics."

    if is_off_topic(query):
        return "Sorry — I only answer questions about Canadian tenancy law."

    # Retrieve the two best-matching paragraphs.
    matches = retrieve_with_pandas(
        query, province=detect_province(query), top_k=2
    )

    if len(matches) == 0:
        return "I couldn’t find anything relevant in the tenancy database."

    grounding = " ".join(matches["text"].tolist())

    prompt = f"""


Use only the context below. Do NOT invent laws.





Context:


{grounding}





Question:


{query}





Answer conversationally:


"""

    # The pipeline echoes the prompt in generated_text; keep only what
    # the model produced after the final cue line.
    generation = llm(prompt, max_new_tokens=150)[0]["generated_text"]
    return generation.split("Answer conversationally:", 1)[-1].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start_chat():
    """Return the initial chat history: one bot-only greeting turn."""
    greeting_turn = (None, INTRO)
    return [greeting_turn]
|
|
|
|
|
def respond(msg, history):
    """Answer *msg* via the RAG pipeline and extend the chat history.

    Mutates *history* (the current Chatbot state passed in by Gradio)
    in place and returns it so the widget re-renders.
    """
    history.append((msg, generate_with_rag(msg)))
    return history
|
|
|
|
|
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    # Chat window pre-seeded with the greeting turn.
    chat_window = gr.Chatbot(value=start_chat())
    question_box = gr.Textbox(label="Ask a question:")

    # Pressing Enter routes the question through the RAG pipeline and
    # writes the updated history back into the chat window.
    question_box.submit(respond, [question_box, chat_window], chat_window)

# share=True also exposes a temporary public gradio.live URL.
demo.launch(share=True)