Spaces:
Paused
Paused
using openai
Browse files- app.py +50 -625
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,546 +1,62 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import
|
| 3 |
-
import numpy as np
|
| 4 |
-
import time
|
| 5 |
-
import json
|
| 6 |
-
import os
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
from utils.retrieve_n_rerank import retrieve_and_rerank
|
| 12 |
-
from utils.sentiment_analysis import get_sentiment
|
| 13 |
-
from utils.coherence_bbscore import coherence_report
|
| 14 |
-
from utils.loading_embeddings import get_vectorstore
|
| 15 |
-
from utils.model_generation import build_messages
|
| 16 |
-
from utils.query_constraints import parse_query_constraints, page_matches, doc_matches
|
| 17 |
-
from utils.conversation_logging import load_history, log_exchange
|
| 18 |
-
from langchain.schema import Document
|
| 19 |
-
from utils.hybrid_retrieval import HybridRetriever, consolidate_page
|
| 20 |
-
except ImportError as e:
|
| 21 |
-
print(f"Import error: {e}")
|
| 22 |
-
print("Make sure you're running from the correct directory and all dependencies are installed.")
|
| 23 |
|
| 24 |
-
|
| 25 |
-
MODEL = "llama3.3-70b-instruct"
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
ENABLE_COHERENCE = True
|
| 30 |
|
| 31 |
-
|
| 32 |
-
PERSISTED_HISTORY = load_history()
|
| 33 |
-
|
| 34 |
-
# Default verbatim mode flag (quotes only, no generative summarization)
|
| 35 |
-
VERBATIM_MODE_DEFAULT = True
|
| 36 |
-
|
| 37 |
-
def _citation(meta, alias_map=None):
|
| 38 |
-
"""Return a concise citation token for a metadata dict.
|
| 39 |
-
Example: [S1 p.17] where S1 alias maps to full doc name in Sources section.
|
| 40 |
-
Falls back to base filename when no alias_map provided.
|
| 41 |
-
"""
|
| 42 |
-
src_raw = os.path.basename(meta.get('source', 'Unknown'))
|
| 43 |
-
base = os.path.splitext(src_raw)[0]
|
| 44 |
-
label = alias_map.get(base, base) if alias_map else base
|
| 45 |
-
page = meta.get('page_label') or meta.get('page') or 'unknown'
|
| 46 |
-
return f"[{label} p.{page}]"
|
| 47 |
-
|
| 48 |
-
def _extract_quotes(query: str, docs, max_quotes: int = 12):
|
| 49 |
-
import re, math
|
| 50 |
-
terms = [t.lower() for t in re.findall(r"[A-Za-z0-9]+", query) if len(t) > 2]
|
| 51 |
-
term_set = set(terms)
|
| 52 |
-
scored = []
|
| 53 |
-
for d in docs:
|
| 54 |
-
meta = getattr(d,'metadata',{})
|
| 55 |
-
# split on sentence end punctuation heuristically
|
| 56 |
-
sentences = re.split(r"(?<=[\.!?])\s+", d.page_content)
|
| 57 |
-
for sent in sentences:
|
| 58 |
-
s = sent.strip()
|
| 59 |
-
if not s:
|
| 60 |
-
continue
|
| 61 |
-
toks = [w.lower() for w in re.findall(r"[A-Za-z0-9]+", s)]
|
| 62 |
-
if not toks:
|
| 63 |
-
continue
|
| 64 |
-
overlap = len(term_set.intersection(toks))
|
| 65 |
-
if overlap == 0:
|
| 66 |
-
continue
|
| 67 |
-
score = overlap / math.log(len(toks)+1, 2)
|
| 68 |
-
scored.append((score, s, meta))
|
| 69 |
-
scored.sort(key=lambda x: x[0], reverse=True)
|
| 70 |
-
out = []
|
| 71 |
-
seen = set()
|
| 72 |
-
for score, s, meta in scored:
|
| 73 |
-
key = (s, meta.get('source'), meta.get('page_label'))
|
| 74 |
-
if key in seen:
|
| 75 |
-
continue
|
| 76 |
-
seen.add(key)
|
| 77 |
-
out.append(f"- \"{s}\" {_citation(meta)}")
|
| 78 |
-
if len(out) >= max_quotes:
|
| 79 |
-
break
|
| 80 |
-
return out
|
| 81 |
-
|
| 82 |
-
def _extract_quote_records(query: str, docs, max_quotes: int = 14):
|
| 83 |
-
"""Structured variant of _extract_quotes returning list of dict rows."""
|
| 84 |
-
import re, math
|
| 85 |
-
terms = [t.lower() for t in re.findall(r"[A-Za-z0-9]+", query) if len(t) > 2]
|
| 86 |
-
term_set = set(terms)
|
| 87 |
-
scored = []
|
| 88 |
-
for d in docs:
|
| 89 |
-
meta = getattr(d,'metadata',{})
|
| 90 |
-
sentences = re.split(r"(?<=[\.!?])\s+", d.page_content)
|
| 91 |
-
for sent in sentences:
|
| 92 |
-
s = sent.strip()
|
| 93 |
-
if not s:
|
| 94 |
-
continue
|
| 95 |
-
toks = [w.lower() for w in re.findall(r"[A-Za-z0-9]+", s)]
|
| 96 |
-
if not toks:
|
| 97 |
-
continue
|
| 98 |
-
overlap = len(term_set.intersection(toks))
|
| 99 |
-
if overlap == 0:
|
| 100 |
-
continue
|
| 101 |
-
score = overlap / math.log(len(toks)+1, 2)
|
| 102 |
-
scored.append((score, s, meta))
|
| 103 |
-
scored.sort(key=lambda x: x[0], reverse=True)
|
| 104 |
-
rows = []
|
| 105 |
-
seen = set()
|
| 106 |
-
for score, s, meta in scored:
|
| 107 |
-
key = (s, meta.get('source'), meta.get('page_label'))
|
| 108 |
-
if key in seen:
|
| 109 |
-
continue
|
| 110 |
-
seen.add(key)
|
| 111 |
-
src_raw = os.path.basename(meta.get('source','Unknown'))
|
| 112 |
-
src = os.path.splitext(src_raw)[0]
|
| 113 |
-
page = meta.get('page_label') or meta.get('page') or '—'
|
| 114 |
-
rows.append({"Document": src, "Page": page, "Excerpt": s, "Citation": _citation(meta)})
|
| 115 |
-
if len(rows) >= max_quotes:
|
| 116 |
-
break
|
| 117 |
-
return rows
|
| 118 |
-
|
| 119 |
-
def _detect_comparison_intent(message: str) -> bool:
|
| 120 |
-
m = message.lower()
|
| 121 |
-
keywords = [" compare ", " comparison", " vs ", " versus ", "difference between", "differences between", "contrast "]
|
| 122 |
-
if any(k in m for k in keywords):
|
| 123 |
-
return True
|
| 124 |
-
# pattern like "between X and Y"
|
| 125 |
-
return ("between" in m and " and " in m)
|
| 126 |
-
|
| 127 |
-
def _quote_records_to_table(rows):
|
| 128 |
-
if not rows:
|
| 129 |
-
return "Not found in sources."
|
| 130 |
-
# Limit excerpt length for readability
|
| 131 |
-
def shorten(txt, limit=260):
|
| 132 |
-
return txt if len(txt) <= limit else txt[:limit].rstrip() + "…"
|
| 133 |
-
headers = ["Document", "Page", "Excerpt", "Citation"]
|
| 134 |
-
table_lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"]*len(headers)) + " |"]
|
| 135 |
-
for r in rows:
|
| 136 |
-
table_lines.append("| " + " | ".join([
|
| 137 |
-
str(r['Document']),
|
| 138 |
-
str(r['Page']),
|
| 139 |
-
shorten(r['Excerpt']).replace("|","\\|"),
|
| 140 |
-
r['Citation']
|
| 141 |
-
]) + " |")
|
| 142 |
-
return "\n".join(table_lines)
|
| 143 |
-
|
| 144 |
-
def _extract_enumerated_objectives(text: str):
|
| 145 |
-
"""Extract enumerated policy objectives (a) .. (q) from a block of text.
|
| 146 |
-
Returns list of tuples (label, objective_text)."""
|
| 147 |
-
import re
|
| 148 |
-
lowered = text.lower()
|
| 149 |
-
if "objective" not in lowered:
|
| 150 |
-
return []
|
| 151 |
-
# Heuristic anchor phrases.
|
| 152 |
-
anchor_idx = None
|
| 153 |
-
for anchor in ["specifically these objectives are", "specifically, these objectives are", "1.2 energy policy objectives", "energy policy objectives", "policy objectives are:"]:
|
| 154 |
-
idx = lowered.find(anchor)
|
| 155 |
-
if idx != -1:
|
| 156 |
-
anchor_idx = idx
|
| 157 |
-
break
|
| 158 |
-
if anchor_idx is not None:
|
| 159 |
-
segment = text[anchor_idx: anchor_idx + 6000] # limit to reasonable window
|
| 160 |
-
else:
|
| 161 |
-
segment = text
|
| 162 |
-
# Normalize line breaks inside enumerations: join lines that don't start a new bullet.
|
| 163 |
-
lines = [l.strip() for l in segment.splitlines() if l.strip()]
|
| 164 |
-
rebuilt = []
|
| 165 |
-
current = ""
|
| 166 |
-
bullet_pat = re.compile(r"^\(?([a-q])\)\s+", re.IGNORECASE)
|
| 167 |
-
for ln in lines:
|
| 168 |
-
if bullet_pat.match(ln):
|
| 169 |
-
if current:
|
| 170 |
-
rebuilt.append(current.strip())
|
| 171 |
-
current = ln
|
| 172 |
-
else:
|
| 173 |
-
if current:
|
| 174 |
-
# continuation of current bullet
|
| 175 |
-
if not ln.endswith(('.', ';', ':')) and not current.endswith((';', ':')):
|
| 176 |
-
current += ' ' + ln
|
| 177 |
-
else:
|
| 178 |
-
current += ' ' + ln
|
| 179 |
-
# else ignore preamble lines
|
| 180 |
-
if current:
|
| 181 |
-
rebuilt.append(current.strip())
|
| 182 |
-
objectives = []
|
| 183 |
-
for item in rebuilt:
|
| 184 |
-
m = bullet_pat.match(item)
|
| 185 |
-
if not m:
|
| 186 |
-
continue
|
| 187 |
-
label = m.group(1).lower()
|
| 188 |
-
body = bullet_pat.sub('', item).strip()
|
| 189 |
-
# Clean stray spaces/hyphen splits
|
| 190 |
-
body = re.sub(r"\s+", " ", body)
|
| 191 |
-
body = body.replace(" - ", " – ")
|
| 192 |
-
objectives.append((label, body))
|
| 193 |
-
# Deduplicate by label keeping first occurrence
|
| 194 |
-
seen = set()
|
| 195 |
-
ordered = []
|
| 196 |
-
for lab, body in objectives:
|
| 197 |
-
if lab in seen:
|
| 198 |
-
continue
|
| 199 |
-
seen.add(lab)
|
| 200 |
-
ordered.append((lab, body))
|
| 201 |
-
return ordered
|
| 202 |
-
|
| 203 |
-
def _format_objectives_markdown(objs, meta_docs, alias_map=None):
|
| 204 |
-
if not objs:
|
| 205 |
-
return None
|
| 206 |
-
hdr = f"**Policy Objectives** ({len(objs)})\n"
|
| 207 |
-
bullets = [f"{i+1}. ({lab}) {txt}" for i,(lab, txt) in enumerate(objs)]
|
| 208 |
-
src_note = "\n\nSources:\n" + "\n".join([f"- {alias_map.get(d,d)}" for d in sorted(meta_docs)]) if meta_docs else ""
|
| 209 |
-
return hdr + "\n".join(bullets) + src_note
|
| 210 |
-
|
| 211 |
-
def _build_alias_map(docs):
|
| 212 |
-
bases = []
|
| 213 |
-
for d in docs:
|
| 214 |
-
meta = getattr(d,'metadata',{})
|
| 215 |
-
base = os.path.splitext(os.path.basename(meta.get('source','Unknown')))[0]
|
| 216 |
-
if base not in bases:
|
| 217 |
-
bases.append(base)
|
| 218 |
-
alias_map = {}
|
| 219 |
-
for idx, b in enumerate(bases, start=1):
|
| 220 |
-
alias_map[b] = f"S{idx}"
|
| 221 |
-
return alias_map
|
| 222 |
-
|
| 223 |
-
def chat_response(message, history, verbatim_mode=True):
|
| 224 |
"""
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
message: Current user message
|
| 229 |
-
history: List of [user_message, bot_response] pairs
|
| 230 |
"""
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
vectorstore = get_vectorstore()
|
| 235 |
|
| 236 |
-
|
| 237 |
-
want_page = constraints.get("page")
|
| 238 |
-
doc_tokens = constraints.get("doc_tokens", [])
|
| 239 |
|
| 240 |
-
|
| 241 |
-
reranked_results = []
|
| 242 |
try:
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
reranked_results = retrieve_and_rerank(
|
| 254 |
-
query_text=message,
|
| 255 |
-
vectorstore=vectorstore,
|
| 256 |
-
k=base_k,
|
| 257 |
-
rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 258 |
-
top_m=40 if want_page is not None else 20,
|
| 259 |
-
min_score=0.4 if want_page is not None else 0.5,
|
| 260 |
-
only_docs=False
|
| 261 |
)
|
| 262 |
-
|
| 263 |
-
if not reranked_results:
|
| 264 |
-
return "I'm sorry, I couldn't find any relevant information in the policy documents to answer your question. Could you try rephrasing your question or asking about a different topic?"
|
| 265 |
-
|
| 266 |
-
# Enforce page constraint if present
|
| 267 |
-
# Document filtering (title tokens)
|
| 268 |
-
if doc_tokens:
|
| 269 |
-
reranked_results = [(d,s) for d,s in reranked_results if doc_matches(getattr(d,'metadata',{}), doc_tokens)]
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
doc = vectorstore.docstore.search(vectorstore.index_to_docstore_id[i])
|
| 280 |
-
meta = getattr(doc,'metadata',{})
|
| 281 |
-
if doc_tokens and not doc_matches(meta, doc_tokens):
|
| 282 |
-
continue
|
| 283 |
-
if page_matches(meta, want_page):
|
| 284 |
-
all_docs.append(doc)
|
| 285 |
-
except Exception:
|
| 286 |
-
pass
|
| 287 |
-
if all_docs:
|
| 288 |
-
# treat as retrieved with neutral score
|
| 289 |
-
reranked_results = [(d, 0.0) for d in all_docs]
|
| 290 |
-
page_filtered = reranked_results
|
| 291 |
-
else:
|
| 292 |
-
reranked_results = page_filtered
|
| 293 |
|
| 294 |
-
|
| 295 |
-
if want_page is not None and (not reranked_results or (doc_tokens and not any(page_matches(getattr(d,'metadata',{}), want_page) for d,_ in reranked_results))):
|
| 296 |
-
return "Not found in sources."
|
| 297 |
|
| 298 |
-
|
|
|
|
| 299 |
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
top_docs = consolidated
|
| 305 |
|
| 306 |
-
if verbatim_mode:
|
| 307 |
-
alias_map = _build_alias_map(top_docs)
|
| 308 |
-
# Specialized extraction for enumerated objectives if user asks for objectives
|
| 309 |
-
wants_objectives = 'objective' in message.lower()
|
| 310 |
-
objectives_output = None
|
| 311 |
-
if wants_objectives:
|
| 312 |
-
collected = []
|
| 313 |
-
doc_names = set()
|
| 314 |
-
for d in top_docs:
|
| 315 |
-
meta = getattr(d, 'metadata', {})
|
| 316 |
-
raw = d.page_content
|
| 317 |
-
objs = _extract_enumerated_objectives(raw)
|
| 318 |
-
if objs:
|
| 319 |
-
src = os.path.splitext(os.path.basename(meta.get('source','Unknown')))[0]
|
| 320 |
-
doc_names.add(src)
|
| 321 |
-
for o in objs:
|
| 322 |
-
collected.append(o)
|
| 323 |
-
# Collapse duplicates by label keeping first
|
| 324 |
-
dedup = []
|
| 325 |
-
seen_lab = set()
|
| 326 |
-
for lab, body in collected:
|
| 327 |
-
if lab in seen_lab:
|
| 328 |
-
continue
|
| 329 |
-
seen_lab.add(lab)
|
| 330 |
-
dedup.append((lab, body))
|
| 331 |
-
if len(dedup) >= 3: # threshold to treat as valid objective list
|
| 332 |
-
md = _format_objectives_markdown(dedup, doc_names, alias_map=alias_map)
|
| 333 |
-
if md:
|
| 334 |
-
objectives_output = md
|
| 335 |
-
yield md
|
| 336 |
-
try:
|
| 337 |
-
log_exchange(message, md, meta={"mode":"verbatim_objectives","count":len(dedup)})
|
| 338 |
-
except Exception:
|
| 339 |
-
pass
|
| 340 |
-
return
|
| 341 |
-
is_comparison = _detect_comparison_intent(message)
|
| 342 |
-
if is_comparison:
|
| 343 |
-
rows = _extract_quote_records(message, top_docs)
|
| 344 |
-
if not rows:
|
| 345 |
-
return "Not found in sources."
|
| 346 |
-
# Replace Document column with alias
|
| 347 |
-
alias_map_rows = _build_alias_map(top_docs)
|
| 348 |
-
for r in rows:
|
| 349 |
-
r['Document'] = alias_map_rows.get(r['Document'], r['Document'])
|
| 350 |
-
# Rebuild citation using alias
|
| 351 |
-
# Extract meta again not stored; citation already present keep as is for now
|
| 352 |
-
table_md = _quote_records_to_table(rows)
|
| 353 |
-
doc_set = sorted({r['Document'] for r in rows})
|
| 354 |
-
header = f"**Comparative Excerpts** ({len(rows)} sentences)\n"
|
| 355 |
-
guidance = "Columns: Document alias, Page, Excerpt, Citation."
|
| 356 |
-
sources_section = "\n\nSources:\n" + "\n".join([f"- {alias_map_rows.get(k,k)}: {k}" for k in sorted(alias_map_rows)])
|
| 357 |
-
answer = header + guidance + "\n\n" + table_md + sources_section
|
| 358 |
-
yield answer
|
| 359 |
-
try:
|
| 360 |
-
log_exchange(message, answer, meta={"mode": "verbatim_compare", "docs": doc_set})
|
| 361 |
-
except Exception:
|
| 362 |
-
pass
|
| 363 |
-
return
|
| 364 |
-
else:
|
| 365 |
-
# Rebuild quotes with alias citation for cleanliness
|
| 366 |
-
quotes_raw = []
|
| 367 |
-
import re, math
|
| 368 |
-
terms = [t.lower() for t in re.findall(r"[A-Za-z0-9]+", message) if len(t)>2]
|
| 369 |
-
term_set = set(terms)
|
| 370 |
-
scored=[]
|
| 371 |
-
for d in top_docs:
|
| 372 |
-
meta = getattr(d,'metadata',{})
|
| 373 |
-
sentences = re.split(r"(?<=[\.!?])\s+", d.page_content)
|
| 374 |
-
for sent in sentences:
|
| 375 |
-
s = sent.strip()
|
| 376 |
-
if not s:
|
| 377 |
-
continue
|
| 378 |
-
toks=[w.lower() for w in re.findall(r"[A-Za-z0-9]+", s)]
|
| 379 |
-
if not toks:
|
| 380 |
-
continue
|
| 381 |
-
overlap = len(term_set.intersection(toks))
|
| 382 |
-
if overlap==0:
|
| 383 |
-
continue
|
| 384 |
-
score = overlap / math.log(len(toks)+1,2)
|
| 385 |
-
scored.append((score,s,meta))
|
| 386 |
-
scored.sort(key=lambda x:x[0], reverse=True)
|
| 387 |
-
seen=set(); quotes=[]
|
| 388 |
-
for score, s, meta in scored:
|
| 389 |
-
key=(s, meta.get('source'), meta.get('page_label'))
|
| 390 |
-
if key in seen: continue
|
| 391 |
-
seen.add(key)
|
| 392 |
-
quotes.append(f"- \"{s}\" {_citation(meta, alias_map)}")
|
| 393 |
-
if len(quotes)>=12: break
|
| 394 |
-
quotes = quotes
|
| 395 |
-
if not quotes:
|
| 396 |
-
return "Not found in sources."
|
| 397 |
-
# Summarize document + page coverage
|
| 398 |
-
doc_ids = []
|
| 399 |
-
pages = set()
|
| 400 |
-
for d in top_docs:
|
| 401 |
-
m = getattr(d, 'metadata', {})
|
| 402 |
-
sid = os.path.splitext(os.path.basename(m.get('source', 'Unknown')))[0]
|
| 403 |
-
if sid not in doc_ids:
|
| 404 |
-
doc_ids.append(sid)
|
| 405 |
-
pg = m.get('page_label') or m.get('page')
|
| 406 |
-
if pg is not None:
|
| 407 |
-
pages.add(str(pg))
|
| 408 |
-
coverage = f"{len(quotes)} excerpt(s) from {len(doc_ids)} document(s)"
|
| 409 |
-
if want_page is not None:
|
| 410 |
-
coverage += f" (page {want_page})"
|
| 411 |
-
elif pages:
|
| 412 |
-
coverage += f" across pages {', '.join(sorted(pages))}"
|
| 413 |
-
header = f"**Verbatim Excerpts** ({coverage})\n"
|
| 414 |
-
sources_section = "\n\nSources:\n" + "\n".join([f"- {_citation({'source': sid}, alias_map).split()[0][1:-1]}: {sid}" for sid in doc_ids])
|
| 415 |
-
answer = header + "\n".join(quotes) + sources_section
|
| 416 |
-
yield answer
|
| 417 |
-
try:
|
| 418 |
-
log_exchange(message, answer, meta={"mode": "verbatim", "page": want_page, "docs": doc_ids})
|
| 419 |
-
except Exception:
|
| 420 |
-
pass
|
| 421 |
-
return
|
| 422 |
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
coherence_report_ = coherence_report(reranked_results=top_docs, input_text=message) if ENABLE_COHERENCE else ""
|
| 426 |
|
| 427 |
-
# Build base messages from strict template
|
| 428 |
-
allow_meta = None
|
| 429 |
-
if want_page is not None and doc_tokens:
|
| 430 |
-
# simple doc_id alias from tokens joined
|
| 431 |
-
allow_meta = {"doc_id": "_".join(doc_tokens), "pages": [want_page]}
|
| 432 |
-
base_messages = build_messages(
|
| 433 |
-
query=message,
|
| 434 |
-
top_docs=top_docs,
|
| 435 |
-
task_mode="verbatim_sentiment",
|
| 436 |
-
sentiment_rollup=sentiment_rollup if ENABLE_SENTIMENT else {},
|
| 437 |
-
coherence_report=coherence_report_ if ENABLE_COHERENCE else "",
|
| 438 |
-
allowlist_meta=allow_meta
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
# Insert recent history (excluding system + final user already in base) after system message
|
| 442 |
-
messages = [base_messages[0]] # system
|
| 443 |
-
# Combine persisted history (only at first call when provided history empty)
|
| 444 |
-
if not history and PERSISTED_HISTORY:
|
| 445 |
-
history.extend(PERSISTED_HISTORY[-6:]) # seed last 6 past exchanges
|
| 446 |
-
recent_history = history[-6:] if len(history) > 6 else history
|
| 447 |
-
for u, a in recent_history:
|
| 448 |
-
messages.append({"role": "user", "content": u})
|
| 449 |
-
messages.append({"role": "assistant", "content": a})
|
| 450 |
-
messages.append(base_messages[1]) # current user prompt (template)
|
| 451 |
-
|
| 452 |
-
# Stream response from the API
|
| 453 |
-
response = ""
|
| 454 |
-
heading_added = False
|
| 455 |
-
for chunk in stream_llm_response(messages):
|
| 456 |
-
if not heading_added:
|
| 457 |
-
chunk = "**Answer**\n" + chunk.lstrip()
|
| 458 |
-
heading_added = True
|
| 459 |
-
response += chunk
|
| 460 |
-
yield response
|
| 461 |
-
# Append sources block (non-streamed) for clarity
|
| 462 |
-
alias_map_final = _build_alias_map(top_docs)
|
| 463 |
-
if alias_map_final:
|
| 464 |
-
sources_block = "\n\nSources:\n" + "\n".join([f"- {a}: {doc}" for doc,a in {v:k for k,v in alias_map_final.items()}.items()])
|
| 465 |
-
response += sources_block
|
| 466 |
-
yield response
|
| 467 |
-
# After final response, log exchange persistently
|
| 468 |
-
try:
|
| 469 |
-
log_exchange(message, response, meta={"pages": [getattr(d.metadata,'page_label', None) if hasattr(d,'metadata') else None for d in top_docs]})
|
| 470 |
-
except Exception as log_err:
|
| 471 |
-
print(f"Logging error: {log_err}")
|
| 472 |
-
|
| 473 |
-
except Exception as e:
|
| 474 |
-
error_msg = f"I encountered an error while processing your request: {str(e)}"
|
| 475 |
-
yield error_msg
|
| 476 |
-
|
| 477 |
-
## Removed custom prompt builder in favor of strict template usage
|
| 478 |
-
|
| 479 |
-
def stream_llm_response(messages):
|
| 480 |
-
"""Stream response from the LLM API."""
|
| 481 |
-
headers = {
|
| 482 |
-
"Authorization": f"Bearer {API_KEY}",
|
| 483 |
-
"Content-Type": "application/json"
|
| 484 |
-
}
|
| 485 |
-
|
| 486 |
-
data = {
|
| 487 |
-
"model": MODEL,
|
| 488 |
-
"messages": messages,
|
| 489 |
-
"temperature": 0.2,
|
| 490 |
-
"stream": True,
|
| 491 |
-
"max_tokens": 2000
|
| 492 |
-
}
|
| 493 |
-
|
| 494 |
-
try:
|
| 495 |
-
with requests.post("https://inference.do-ai.run/v1/chat/completions",
|
| 496 |
-
headers=headers, json=data, stream=True, timeout=30) as r:
|
| 497 |
-
if r.status_code != 200:
|
| 498 |
-
yield f"[ERROR] API returned status {r.status_code}: {r.text}"
|
| 499 |
-
return
|
| 500 |
-
|
| 501 |
-
for line in r.iter_lines(decode_unicode=True):
|
| 502 |
-
if not line or line.strip() == "data: [DONE]":
|
| 503 |
-
continue
|
| 504 |
-
if line.startswith("data: "):
|
| 505 |
-
line = line[len("data: "):]
|
| 506 |
-
|
| 507 |
-
try:
|
| 508 |
-
chunk = json.loads(line)
|
| 509 |
-
delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
| 510 |
-
if delta:
|
| 511 |
-
yield delta
|
| 512 |
-
time.sleep(0.01) # Small delay for smooth streaming
|
| 513 |
-
except json.JSONDecodeError:
|
| 514 |
-
continue
|
| 515 |
-
except Exception as e:
|
| 516 |
-
print(f"Streaming error: {e}")
|
| 517 |
-
continue
|
| 518 |
-
|
| 519 |
-
except requests.exceptions.RequestException as e:
|
| 520 |
-
yield f"[ERROR] Network error: {str(e)}"
|
| 521 |
-
except Exception as e:
|
| 522 |
-
yield f"[ERROR] Unexpected error: {str(e)}"
|
| 523 |
-
|
| 524 |
-
def update_sentiment_setting(enable):
|
| 525 |
-
"""Update global sentiment analysis setting."""
|
| 526 |
-
global ENABLE_SENTIMENT
|
| 527 |
-
ENABLE_SENTIMENT = enable
|
| 528 |
-
return f"✅ Sentiment analysis {'enabled' if enable else 'disabled'}"
|
| 529 |
-
|
| 530 |
-
def update_coherence_setting(enable):
|
| 531 |
-
"""Update global coherence analysis setting."""
|
| 532 |
-
global ENABLE_COHERENCE
|
| 533 |
-
ENABLE_COHERENCE = enable
|
| 534 |
-
return f"✅ Coherence analysis {'enabled' if enable else 'disabled'}"
|
| 535 |
-
|
| 536 |
-
# Create the chat interface
|
| 537 |
-
with gr.Blocks(title="Kenya Policy Assistant - Chat", theme=gr.themes.Soft()) as demo:
|
| 538 |
-
gr.Markdown("""
|
| 539 |
-
# 🏛️ Kenya Policy Assistant - Interactive Chat
|
| 540 |
-
Ask questions about Kenya's policies and have a conversation! I can help you understand policy documents with sentiment and coherence analysis.
|
| 541 |
-
""")
|
| 542 |
-
# Floating popup embed for external policy-agent widget (bottom-right)
|
| 543 |
-
# Assumption: vendor script auto-renders a launcher; we just position container.
|
| 544 |
popup_widget_html = '''<style>
|
| 545 |
.policy-agent-popup-container { position:fixed; bottom:16px; right:16px; z-index:9999; }
|
| 546 |
</style>
|
|
@@ -558,107 +74,16 @@ with gr.Blocks(title="Kenya Policy Assistant - Chat", theme=gr.themes.Soft()) as
|
|
| 558 |
</script>
|
| 559 |
</div>'''
|
| 560 |
gr.HTML(popup_widget_html)
|
| 561 |
-
|
| 562 |
-
with gr.Row():
|
| 563 |
-
with gr.Column(scale=3):
|
| 564 |
-
# Settings row at the top
|
| 565 |
-
with gr.Row():
|
| 566 |
-
sentiment_toggle = gr.Checkbox(
|
| 567 |
-
label="📊 Sentiment Analysis",
|
| 568 |
-
value=True,
|
| 569 |
-
info="Analyze tone and sentiment of policy documents"
|
| 570 |
-
)
|
| 571 |
-
coherence_toggle = gr.Checkbox(
|
| 572 |
-
label="🔍 Coherence Analysis",
|
| 573 |
-
value=True,
|
| 574 |
-
info="Check coherence and consistency of retrieved documents"
|
| 575 |
-
)
|
| 576 |
-
|
| 577 |
-
# Main chat interface
|
| 578 |
-
chatbot = gr.Chatbot(
|
| 579 |
-
height=500,
|
| 580 |
-
bubble_full_width=False,
|
| 581 |
-
show_copy_button=True,
|
| 582 |
-
show_share_button=True,
|
| 583 |
-
avatar_images=("👤", "🤖"),
|
| 584 |
-
value=PERSISTED_HISTORY # seed prior memory
|
| 585 |
-
)
|
| 586 |
-
|
| 587 |
-
msg = gr.Textbox(
|
| 588 |
-
placeholder="Ask me about Kenya's policies... (e.g., 'What are the renewable energy regulations?')",
|
| 589 |
-
label="Your Question",
|
| 590 |
-
lines=2
|
| 591 |
-
)
|
| 592 |
-
|
| 593 |
-
with gr.Row():
|
| 594 |
-
submit_btn = gr.Button("📤 Send", variant="primary")
|
| 595 |
-
|
| 596 |
-
with gr.Column(scale=1):
|
| 597 |
-
gr.Markdown("""
|
| 598 |
-
### 💡 Chat Tips
|
| 599 |
-
- Ask specific questions about Kenya policies
|
| 600 |
-
- Ask follow-up questions based on responses
|
| 601 |
-
- Reference previous answers: *"What does this mean?"*
|
| 602 |
-
- Request elaboration: *"Can you explain more?"*
|
| 603 |
-
|
| 604 |
-
### 📝 Example Questions
|
| 605 |
-
- *"What are Kenya's renewable energy policies?"*
|
| 606 |
-
- *"Tell me about water management regulations"*
|
| 607 |
-
- *"What penalties exist for environmental violations?"*
|
| 608 |
-
- *"How does this relate to what you mentioned earlier?"*
|
| 609 |
-
|
| 610 |
-
### ⚙️ Analysis Features
|
| 611 |
-
**Sentiment Analysis**: Understands the tone and intent of policy text
|
| 612 |
-
|
| 613 |
-
**Coherence Analysis**: Checks if retrieved documents are relevant and consistent
|
| 614 |
-
""")
|
| 615 |
-
|
| 616 |
-
with gr.Accordion("📊 Analysis Status", open=False):
|
| 617 |
-
sentiment_status = gr.Textbox(
|
| 618 |
-
value="✅ Sentiment analysis enabled",
|
| 619 |
-
label="Sentiment Status",
|
| 620 |
-
interactive=False
|
| 621 |
-
)
|
| 622 |
-
coherence_status = gr.Textbox(
|
| 623 |
-
value="✅ Coherence analysis enabled",
|
| 624 |
-
label="Coherence Status",
|
| 625 |
-
interactive=False
|
| 626 |
-
)
|
| 627 |
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
history.append([message, ""])
|
| 633 |
-
|
| 634 |
-
for partial_response in bot_message:
|
| 635 |
-
history[-1][1] = partial_response
|
| 636 |
-
yield history, ""
|
| 637 |
-
else:
|
| 638 |
-
yield history, ""
|
| 639 |
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
# Update settings when toggles change
|
| 645 |
-
sentiment_toggle.change(
|
| 646 |
-
fn=update_sentiment_setting,
|
| 647 |
-
inputs=[sentiment_toggle],
|
| 648 |
-
outputs=[sentiment_status]
|
| 649 |
-
)
|
| 650 |
-
|
| 651 |
-
coherence_toggle.change(
|
| 652 |
-
fn=update_coherence_setting,
|
| 653 |
-
inputs=[coherence_toggle],
|
| 654 |
-
outputs=[coherence_status]
|
| 655 |
-
)
|
| 656 |
|
| 657 |
if __name__ == "__main__":
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
share=True,
|
| 661 |
-
debug=True,
|
| 662 |
-
server_name="0.0.0.0",
|
| 663 |
-
server_port=7860
|
| 664 |
-
)
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from openai import OpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
# 🔹 Configure your agent
|
| 5 |
+
agent_endpoint = "https://q77iuwf7ncfemoonbzon2iyd.agents.do-ai.run/api/v1/"
|
| 6 |
+
agent_access_key = "CzIwmTIDFNWRRIHvxVNzKWztq8rn5S5w"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
client = OpenAI(base_url=agent_endpoint, api_key=agent_access_key)
|
|
|
|
| 9 |
|
| 10 |
+
# Parameters
|
| 11 |
+
DEFAULT_RETRIEVAL_RUNS = 3 # adjustable in UI
|
|
|
|
| 12 |
|
| 13 |
+
def policy_chat(message, history, retrieval_runs=DEFAULT_RETRIEVAL_RUNS):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
+
Chatbot with streaming + multiple retrieval runs.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
+
# Show "processing" placeholder
|
| 18 |
+
history = history + [[message, "Processing..."]]
|
| 19 |
+
yield history, history
|
|
|
|
| 20 |
|
| 21 |
+
aggregated_responses = []
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
for run in range(retrieval_runs):
|
|
|
|
| 24 |
try:
|
| 25 |
+
stream = client.chat.completions.create(
|
| 26 |
+
model="n/a", # agent handles routing
|
| 27 |
+
messages=[
|
| 28 |
+
{"role": "system", "content": "The data must be returned verbatim. Please be quite detailed and include all information. You are new to the analysis of policy documents, hence you need to be objective in retrieving information, and it is not expected that you will analyse and interpret the information."},
|
| 29 |
+
*[{"role": "user", "content": u} if i % 2 == 0 else {"role": "assistant", "content": b}
|
| 30 |
+
for i, (u, b) in enumerate(history[:-1])], # exclude placeholder
|
| 31 |
+
{"role": "user", "content": message},
|
| 32 |
+
],
|
| 33 |
+
extra_body={"include_retrieval_info": True},
|
| 34 |
+
stream=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
response_text = ""
|
| 38 |
+
for chunk in stream:
|
| 39 |
+
delta = chunk.choices[0].delta
|
| 40 |
+
if delta and delta.content: # delta.content is a string or None
|
| 41 |
+
response_text += delta.content
|
| 42 |
+
# Stream update to Gradio UI
|
| 43 |
+
history[-1][1] = response_text
|
| 44 |
+
yield history, history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
aggregated_responses.append(response_text or "⚠️ Empty response")
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
except Exception as e:
|
| 49 |
+
aggregated_responses.append(f"⚠️ Error during run {run+1}: {str(e)}")
|
| 50 |
|
| 51 |
+
# 🔹 Choose the “best” response (longest for now)
|
| 52 |
+
best_response = max(aggregated_responses, key=len, default="⚠️ No response")
|
| 53 |
+
history[-1][1] = best_response
|
| 54 |
+
yield history, history
|
|
|
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
|
| 58 |
+
gr.Markdown("# 🤖 Policy-Agent Chatbot\nAsk me about policies. I’ll query the knowledge base multiple times to retrieve the best answer for you!")
|
|
|
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
popup_widget_html = '''<style>
|
| 61 |
.policy-agent-popup-container { position:fixed; bottom:16px; right:16px; z-index:9999; }
|
| 62 |
</style>
|
|
|
|
| 74 |
</script>
|
| 75 |
</div>'''
|
| 76 |
gr.HTML(popup_widget_html)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
chatbot = gr.Chatbot(height=500)
|
| 79 |
+
msg = gr.Textbox(placeholder="Type your question...")
|
| 80 |
+
retrieval_slider = gr.Slider(1, 10, value=DEFAULT_RETRIEVAL_RUNS, step=1, label="Number of retrieval runs")
|
| 81 |
+
clear = gr.Button("Clear Chat")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
msg.submit(policy_chat, [msg, chatbot, retrieval_slider], [chatbot, chatbot])
|
| 84 |
+
msg.submit(lambda: "", None, msg) # clear textbox
|
| 85 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
if __name__ == "__main__":
|
| 88 |
+
demo.launch(debug=True)
|
| 89 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -15,4 +15,5 @@ rank-bm25
|
|
| 15 |
pypdf
|
| 16 |
Pillow
|
| 17 |
pytesseract
|
|
|
|
| 18 |
|
|
|
|
| 15 |
pypdf
|
| 16 |
Pillow
|
| 17 |
pytesseract
|
| 18 |
+
openai
|
| 19 |
|