# --- Hugging Face Space page residue (not code); commented out so the file parses ---
# rishabh5752's picture
# Update app.py
# 88335c8 verified
# app.py – PolicyGPT 🇮🇳 (error-free)
# stdlib
import pathlib
import re
import tempfile
import textwrap
import traceback
from functools import lru_cache

# third-party
import gradio as gr
import pypdf
import requests
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings  # new import
from langchain_community.vectorstores import FAISS  # new import
from transformers import pipeline
# ---------- 1. Policy corpus ----------
# Display name -> public PDF URL for each policy document in the corpus.
POLICY_URLS = {
    "DPDP Act 2023":
        "https://www.meity.gov.in/static/uploads/2024/06/2bf1f0e9f04e6fb4f8fef35e82c42aa5.pdf",
    "Responsible AI (NITI Aayog)":
        "https://www.niti.gov.in/sites/default/files/2021-08/Part2-Responsible-AI-12082021.pdf",
    # … keep the rest …
}
# Industry -> the subset of POLICY_URLS keys relevant to it.
# "All" deliberately maps to every policy; rag() relies on this entry.
INDUSTRY_MAP = {
    "Health Care": ["DPDP Act 2023", "Responsible AI (NITI Aayog)"],
    "All": list(POLICY_URLS.keys()),
}
# ---------- 2. Helpers ----------
def download(url: str, path: pathlib.Path) -> pathlib.Path:
    """Fetch *url* into *path* unless the file is already cached on disk.

    The original wrote directly to *path*, so an interrupted download left a
    truncated file that the ``path.exists()`` check then treated as a valid
    cache hit forever. We now write to a sibling ``.part`` file and promote it
    with an atomic rename: the cache entry is either absent or complete.

    Returns *path*; raises requests.HTTPError on a bad HTTP status.
    """
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        r = requests.get(url, timeout=120)
        r.raise_for_status()
        tmp = path.with_suffix(path.suffix + ".part")
        tmp.write_bytes(r.content)
        tmp.replace(path)  # atomic on POSIX — no torn cache files
    return path
def pdf_text(path: pathlib.Path) -> str:
    """Return the concatenated text of every page in the PDF at *path*.

    Pages with no extractable text contribute an empty string, so the
    result always has one segment per page, newline-joined.
    """
    with path.open("rb") as fh:
        reader = pypdf.PdfReader(fh)
        pages = [page.extract_text() or "" for page in reader.pages]
    return "\n".join(pages)
# maxsize=None instead of 1: with maxsize=1, toggling between industries evicted
# and rebuilt the entire FAISS index (downloads + embeddings) on every switch.
# The key space is bounded by the handful of tuples INDUSTRY_MAP can produce.
@lru_cache(maxsize=None)
def store(srcs=tuple(POLICY_URLS.keys())):
    """Build and memoize a FAISS vector index over the chunked policies in *srcs*.

    Each source PDF is downloaded (cached on disk), extracted, split into
    1024-char chunks with 128-char overlap, and tagged with its source name.
    A single failing source is logged and skipped (best-effort), but an
    entirely empty corpus now raises a clear RuntimeError instead of the
    opaque error FAISS.from_documents emits on an empty list.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    docs = []
    for name in srcs:
        path = pathlib.Path(tempfile.gettempdir()) / "policygpt" / f"{name}.pdf"
        try:
            text = pdf_text(download(POLICY_URLS[name], path))
            docs.extend(
                Document(page_content=chunk, metadata={"src": name})
                for chunk in splitter.split_text(text)
            )
        except Exception as e:  # best-effort: one broken source must not kill the app
            print("❌", name, e)
    if not docs:
        raise RuntimeError(
            "No policy documents could be loaded — check network access and POLICY_URLS."
        )
    embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(docs, embed)
# Shared generator: small Flan-T5 keeps CPU latency tolerable on a free Space.
GEN = pipeline( # ✅ use text2text-generation
    "text2text-generation",
    model="google/flan-t5-small",
    max_new_tokens=200,  # hard cap on answer length
    do_sample=False,  # greedy decoding -> deterministic, reproducible answers
)
def rag(q: str, industry: str) -> str:
    """Retrieval-augmented answer for question *q*, scoped to *industry*.

    The source list comes from INDUSTRY_MAP via ``.get``: an unknown industry
    now falls back to the full corpus instead of raising KeyError, and the
    old ``industry == "All"`` special case is subsumed by INDUSTRY_MAP["All"].

    Returns the generated answer, "I don’t know." for an empty generation,
    or an error string if the generator raises.
    """
    srcs = tuple(INDUSTRY_MAP.get(industry, list(POLICY_URLS.keys())))
    db = store(srcs)
    # 4 nearest chunks, capped at 3500 chars to stay inside the model window.
    ctx = "\n\n".join(d.page_content for d in db.similarity_search(q, k=4))[:3500]
    # Template is intentionally left-aligned (dedent is a no-op on it) so the
    # interpolated ctx/q never distort the prompt's margins.
    prompt = textwrap.dedent(f"""
You are PolicyGPT. Using CONTEXT, answer QUESTION (≤150 words)
and cite source names in brackets. If unsure, say I don’t know.
CONTEXT:
{ctx}
QUESTION: {q}
ANSWER:
""")
    try:
        answer = GEN(prompt)[0]["generated_text"].strip()
        return answer or "I don’t know."
    except Exception as e:
        return f"⚠️ Generation error: {e}"
def risk(text: str) -> str:
    """Heuristically grade the compliance risk implied by an answer string.

    Keyword tiers: "violation"/"prohibited"/"penalty" -> "High";
    "must"/"should"/"shall" -> "Medium"; otherwise "Low".

    Matching is case-insensitive and on whole words only (``\\b`` anchors):
    the original plain-substring test fired on inner fragments, e.g.
    "marshall" contains "shall" and was wrongly graded Medium.
    """
    low = text.lower()

    def _hits(words):
        # Whole-word search; words are plain text so re.escape is belt-and-braces.
        return any(re.search(rf"\b{re.escape(w)}\b", low) for w in words)

    if _hits(("violation", "prohibited", "penalty")):
        return "High"
    if _hits(("must", "should", "shall")):
        return "Medium"
    return "Low"
# ---------- 3. Gradio UI ----------
def answer_fn(q, ind):
    """Gradio callback: answer markdown, risk badge, and re-enable the Ask button."""
    reply = rag(q, ind)
    badge = f"**Estimated compliance risk:** {risk(reply)}"
    return reply, badge, gr.update(interactive=True)
with gr.Blocks(title="PolicyGPT 🇮🇳") as demo:
    gr.Markdown("# PolicyGPT 🇮🇳 — ask about AI & Data-governance laws")
    ind = gr.Dropdown(list(INDUSTRY_MAP.keys()), label="Select industry", value="All")
    qbox = gr.Textbox(
        lines=2,
        label="Your question",
        placeholder="e.g. What PII rules apply to hospitals?",
    )
    ask = gr.Button("Ask")
    ans = gr.Markdown()
    rsk = gr.Markdown()

    # Button click and textbox Enter share the same two-step wiring:
    # 1) immediately grey out the button (unqueued, so it runs right away),
    # 2) run the RAG pipeline; answer_fn's third output re-enables the button.
    for trigger in (ask.click, qbox.submit):
        trigger(lambda: gr.update(interactive=False), None, ask, queue=False)
        trigger(answer_fn, [qbox, ind], [ans, rsk, ask])
# Gradio 4+: no concurrency_count param
# Entry point: enable the request queue, then serve the UI.
if __name__ == "__main__":
    demo.queue().launch()