# app.py — Finalv3.5 (Hugging Face Space "zm-f21", commit 8f996bb)
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import zipfile, os, re, torch
# -----------------------------
# Load Mistral (FP16, GPU if available)
# -----------------------------
# Instruction-tuned 7B model used as the answer generator below in
# generate_with_rag(); device_map="auto" lets accelerate place the weights
# on a GPU when one is available, falling back to CPU otherwise.
llm = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,  # FP16 halves memory vs FP32
    device_map="auto"
)
# -----------------------------
# Load embedding model
# -----------------------------
# Legal-domain BERT used both to index paragraphs (below) and to embed
# incoming queries in retrieve_with_pandas().
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
# -----------------------------
# Extract ZIP with provincial legal texts
# -----------------------------
zip_path = "/app/provinces.zip"
extract_folder = "/app/provinces_texts"
# Wipe any previous extraction so stale files from an earlier run never
# linger in the index.
if os.path.exists(extract_folder):
    import shutil
    shutil.rmtree(extract_folder)
with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_folder)
# NOTE(review): date_pattern is not referenced anywhere in this file's
# visible code — presumably intended for parsing LAST_UPDATED values;
# confirm before removing.
date_pattern = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")
# -----------------------------
# Parse documents
# -----------------------------
def parse_metadata_and_content(raw):
    """Split a raw provincial text file into (metadata dict, body text).

    The file format is a header of "KEY: value" lines (keys are upper-cased),
    optional "- ..." PDF-link lines (joined into metadata["PDF_LINKS"]), then
    a mandatory "CONTENT:" marker followed by the document body.

    Raises ValueError when the "CONTENT:" marker is absent.
    """
    head, marker, body = raw.partition("CONTENT:")
    if not marker:
        raise ValueError("Missing CONTENT block")
    metadata = {}
    pdfs = []
    for row in head.split("\n"):
        stripped = row.strip()
        if ":" in row and not row.startswith("-"):
            key, _, value = row.partition(":")
            metadata[key.strip().upper()] = value.strip()
        elif stripped.startswith("-"):
            pdfs.append(stripped)
    if pdfs:
        metadata["PDF_LINKS"] = "\n".join(pdfs)
    return metadata, body.strip()
# Walk the extracted tree, parse every .txt file, and split each CONTENT
# block into paragraph-level records so retrieval works at paragraph
# granularity. One dict per paragraph, carrying the file's metadata.
documents = []
for root, dirs, files in os.walk(extract_folder):
    for filename in files:
        # Skip non-text files and macOS resource-fork artifacts ("._*").
        if not filename.endswith(".txt") or filename.startswith("._"):
            continue
        path = os.path.join(root, filename)
        try:
            # Fix: the original `open(path).read()` never closed the file
            # handle; a context manager closes it deterministically.
            # latin-1 maps every byte, so decoding itself cannot fail.
            with open(path, "r", encoding="latin-1") as fh:
                raw = fh.read()
            metadata, content = parse_metadata_and_content(raw)
            # Paragraphs are separated by blank lines; drop empty chunks.
            for p in [x.strip() for x in content.split("\n\n") if x.strip()]:
                documents.append({
                    "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                    "province": metadata.get("PROVINCE", "Unknown"),
                    "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                    "url": metadata.get("URL", "N/A"),
                    "pdf_links": metadata.get("PDF_LINKS", ""),
                    "text": p
                })
        except Exception as e:
            # Best-effort load: a single malformed file must not kill startup.
            print("Skipped:", path, e)
print("Loaded paragraphs:", len(documents))
# -----------------------------
# Build embeddings dataframe
# -----------------------------
df = pd.DataFrame(documents)
texts = df["text"].tolist()
# float16 halves the memory held by the stored vectors; cosine-similarity
# ranking tolerates the reduced precision.
embeddings = embedding_model.encode(texts).astype("float16")
# One embedding vector per paragraph row, stored alongside its metadata.
df["Embedding"] = list(embeddings)
print("Embedding index ready:", len(df))
# -----------------------------
# Retrieval
# -----------------------------
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the top_k paragraph rows most cosine-similar to `query`.

    Parameters
    ----------
    query : str
        Natural-language question; embedded with the module-level model.
    province : str or None
        When given, restrict the search to rows whose "province" matches.
    top_k : int
        Number of best-scoring rows to return.

    Returns
    -------
    pandas.DataFrame
        Copy of the matching rows with an added "Similarity" column,
        sorted best-first and truncated to top_k.
    """
    query_emb = embedding_model.encode([query])[0]
    subset = df if province is None else df[df["province"] == province]
    subset = subset.copy()  # never mutate the global index
    # Fix: the query norm is loop-invariant — compute it once instead of
    # once per row inside the lambda, as the original did.
    query_norm = np.linalg.norm(query_emb)
    subset["Similarity"] = subset["Embedding"].apply(
        lambda emb: np.dot(query_emb, emb) / (query_norm * np.linalg.norm(emb))
    )
    return subset.sort_values("Similarity", ascending=False).head(top_k)
# -----------------------------
# Province detection
# -----------------------------
def detect_province(q):
    """Return the canonical province/territory name mentioned in `q`, or None.

    Matching is case-insensitive substring search over a fixed alias table;
    the first alias found (in table order) wins.
    """
    alias_table = {
        "yukon": "Yukon", "alberta": "Alberta", "bc": "British Columbia",
        "british columbia": "British Columbia", "manitoba": "Manitoba",
        "newfoundland": "Newfoundland and Labrador",
        "saskatchewan": "Saskatchewan", "sask": "Saskatchewan",
        "ontario": "Ontario", "pei": "Prince Edward Island",
        "quebec": "Quebec", "new brunswick": "New Brunswick",
        "nova scotia": "Nova Scotia", "nunavut": "Nunavut",
        "northwest territories": "Northwest Territories"
    }
    lowered = q.lower()
    return next(
        (full for alias, full in alias_table.items() if alias in lowered),
        None,
    )
# -----------------------------
# Filters
# -----------------------------
def is_disallowed(q):
    """True when the query mentions any hard-banned violence/self-harm term."""
    lowered = q.lower()
    for term in ("kill", "suicide", "bomb", "weapon", "harm yourself"):
        if term in lowered:
            return True
    return False
def is_off_topic(q):
    """True when the query contains none of the tenancy-related keywords."""
    lowered = q.lower()
    for term in ("tenant", "landlord", "rent", "evict", "lease",
                 "repair", "notice", "unit"):
        if term in lowered:
            return False
    return True
# -----------------------------
# Intro (sent once)
# -----------------------------
# Shown as the chatbot's first message when the UI loads (via start_chat);
# carries the mandatory not-legal-advice disclaimer.
INTRO = (
    "Hi! I'm a Canadian rental housing assistant. I help summarize and explain "
    "information from Residential Tenancies Acts across Canada.\n\n"
    "**Note:** I'm not a lawyer — this is not legal advice.\n\n"
)
# -----------------------------
# RAG Generation
# -----------------------------
def generate_with_rag(query):
    """Answer a tenancy question via retrieval-augmented generation.

    Pipeline order matters: safety filter first, then topicality filter,
    then retrieval (optionally narrowed to a detected province), then LLM
    generation grounded in the retrieved paragraphs.

    Returns a plain-text answer or a refusal/fallback message.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful topics."
    if is_off_topic(query):
        return "Sorry — I only answer questions about Canadian tenancy law."
    prov = detect_province(query)
    docs = retrieve_with_pandas(query, province=prov, top_k=2)
    if len(docs) == 0:
        return "I couldn’t find anything relevant in the tenancy database."
    context = " ".join(docs["text"].tolist())
    # NOTE(review): the prompt's exact bytes matter — the answer is sliced
    # out below by splitting on the final marker line, so the marker in the
    # prompt and in the split() call must stay identical.
    prompt = f"""
Use only the context below. Do NOT invent laws.
Context:
{context}
Question:
{query}
Answer conversationally:
"""
    # The text-generation pipeline echoes the prompt in generated_text, so
    # keep only what follows the marker.
    out = llm(prompt, max_new_tokens=150)[0]["generated_text"]
    answer = out.split("Answer conversationally:", 1)[-1].strip()
    return answer
# -----------------------------
# Gradio Chat (Intro only once)
# -----------------------------
def start_chat():
    """Build the initial chat history: a single bot-only intro turn."""
    intro_turn = (None, INTRO)
    return [intro_turn]
def respond(msg, history):
    """Append the (user message, generated answer) turn and return history."""
    history.append((msg, generate_with_rag(msg)))
    return history
with gr.Blocks() as demo:
    # Pre-seed the chat history with the intro so it appears exactly once.
    chatbot = gr.Chatbot(value=start_chat())
    inp = gr.Textbox(label="Ask a question:")
    # Enter in the textbox sends (message, current history) to respond(),
    # whose return value replaces the chatbot's history.
    inp.submit(respond, [inp, chatbot], chatbot)
demo.launch(share=True)