# samuel_wedding / app.py
# Uploaded to Hugging Face Spaces by samsonleegh ("Upload 5 files", commit 8a183f5).
import os, io, glob, time, hashlib
import pandas as pd
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass
from dotenv import load_dotenv
# Telegram
from telegram import Update
from telegram.constants import ParseMode
from telegram.ext import Application, CommandHandler, MessageHandler, ContextTypes, filters
# OpenAI
from openai import OpenAI
# Files / parsing
from docx import Document as DocxDocument
from pypdf import PdfReader
# Vector store
import faiss
load_dotenv()
import httpx
if not hasattr(httpx, "proxies"):
httpx.proxies = None
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TELEGRAM_BOT_TOKEN = os.getenv("TELEGRAM_BOT_TOKEN")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
STRICT_DOC_MODE = (os.getenv("STRICT_DOC_MODE", "true").lower() == "true")
DOCS_DIR = os.getenv("DOCS_DIR", "wedding_docs")
INDEX_PATH = os.getenv("INDEX_PATH", "wedding.index")
META_CSV = os.getenv("META_CSV", "wedding_chunks.csv")
client = OpenAI(api_key=OPENAI_API_KEY)
# -----------------------------
# Utilities to read documents
# -----------------------------
def read_txt_md(path: str) -> str:
with open(path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
def read_docx(path: str) -> str:
doc = DocxDocument(path)
return "\n".join(p.text for p in doc.paragraphs)
def read_pdf(path: str) -> str:
reader = PdfReader(path)
texts = []
for page in reader.pages:
try:
texts.append(page.extract_text() or "")
except Exception:
pass
return "\n".join(texts)
def read_csv_file(path: str, max_rows: int = 20000, max_chars: int = 400_000) -> str:
"""
Read a CSV as text that’s friendly for RAG.
- Coerces everything to string
- Skips bad lines
- Truncates very large files by rows and total chars (to avoid blowing up memory)
"""
try:
df = pd.read_csv(
path,
dtype=str,
on_bad_lines="skip", # pandas>=1.4
nrows=max_rows, # safeguard for very large CSVs
encoding="utf-8"
)
except UnicodeDecodeError:
# Fallback for odd encodings
df = pd.read_csv(
path,
dtype=str,
on_bad_lines="skip",
nrows=max_rows,
encoding="latin-1"
)
df = df.fillna("")
# Build a compact, readable text block:
# 1) schema
cols = list(df.columns)
schema_block = "COLUMNS: " + ", ".join(cols)
# 2) first N rows as pipe-separated lines (more RAG-friendly than raw CSV)
# Example line: "col1=..., col2=..., col3=..."
lines = []
for _, row in df.iterrows():
kv = [f"{c}={row[c]}" for c in cols]
lines.append(" | ".join(kv))
# Hard cap to avoid giant strings
if sum(len(x) for x in lines) > max_chars:
lines.append("…(truncated)")
break
data_block = "\n".join(lines)
return f"{schema_block}\n{data_block}"
def load_all_docs(folder: str) -> List[Tuple[str, str]]:
paths = []
# ADD "*.csv" here
for ext in ("*.md", "*.txt", "*.docx", "*.pdf", "*.csv"):
paths.extend(glob.glob(os.path.join(folder, ext)))
docs = []
for p in paths:
if p.endswith((".md", ".txt")):
text = read_txt_md(p)
elif p.endswith(".docx"):
text = read_docx(p)
elif p.endswith(".pdf"):
text = read_pdf(p)
elif p.endswith(".csv"): # <--- NEW
text = read_csv_file(p)
else:
continue
docs.append((p, text))
return docs
# -----------------------------
# Chunk + Embed + Index
# -----------------------------
def chunk_text(text: str, source: str, chunk_size: int = 300, overlap: int = 80) -> List[dict]:
words = text.split()
chunks = []
i = 0
while i < len(words):
chunk_words = words[i:i+chunk_size]
chunk = " ".join(chunk_words)
chunks.append({
"source": source,
"chunk": chunk,
"hash": hashlib.md5((source + str(i) + chunk).encode("utf-8")).hexdigest()
})
i += (chunk_size - overlap)
return chunks
def embed_texts(texts: List[str]) -> np.ndarray:
# Returns an array of shape (n, d)
# Uses OpenAI embeddings
resp = client.embeddings.create(model=EMBEDDING_MODEL, input=texts)
vecs = [item.embedding for item in resp.data]
return np.array(vecs).astype("float32")
@dataclass
class RAGIndex:
index: faiss.IndexFlatIP
df: pd.DataFrame # columns: [source, chunk, hash, vector]
dim: int
def build_or_load_index(force_rebuild: bool = False) -> RAGIndex:
docs = load_all_docs(DOCS_DIR)
if not docs:
raise RuntimeError(f"No docs found in {DOCS_DIR}/. Put your itinerary files there.")
# Simple staleness check: if any file is newer than index, rebuild
def newest_mtime():
paths = []
for ext in ("*.md", "*.txt", "*.docx", "*.pdf"):
paths.extend(glob.glob(os.path.join(DOCS_DIR, ext)))
return max(os.path.getmtime(p) for p in paths)
index_exists = os.path.exists(INDEX_PATH) and os.path.exists(META_CSV)
need_rebuild = force_rebuild
if index_exists:
idx_mtime = min(os.path.getmtime(INDEX_PATH), os.path.getmtime(META_CSV))
need_rebuild = need_rebuild or (newest_mtime() > idx_mtime)
if index_exists and not need_rebuild:
df = pd.read_csv(META_CSV)
vecs = np.load(INDEX_PATH)
dim = vecs.shape[1]
index = faiss.IndexFlatIP(dim)
faiss.normalize_L2(vecs)
index.add(vecs)
return RAGIndex(index=index, df=df, dim=dim)
# Rebuild
all_chunks = []
for path, text in docs:
if not text.strip():
continue
all_chunks.extend(chunk_text(text, source=path))
if not all_chunks:
raise RuntimeError("Docs were read but produced no chunks. Check formats.")
df = pd.DataFrame(all_chunks)
vecs = embed_texts(df["chunk"].tolist())
# Normalize for cosine similarity via inner product
faiss.normalize_L2(vecs)
np.save(INDEX_PATH, vecs)
df.to_csv(META_CSV, index=False)
dim = vecs.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(vecs)
return RAGIndex(index=index, df=df, dim=dim)
# -----------------------------
# Retrieval + Answering
# -----------------------------
def retrieve(query: str, rag: RAGIndex, k: int = 7) -> List[dict]:
qvec = embed_texts([query])
faiss.normalize_L2(qvec)
D, I = rag.index.search(qvec, k)
results = []
for score, idx in zip(D[0], I[0]):
if idx == -1:
continue
row = rag.df.iloc[int(idx)]
results.append({
"score": float(score),
"source": row["source"],
"chunk": row["chunk"]
})
return results
SYSTEM_PROMPT = (
"You are a helpful, concise wedding assistant for Samuel's wedding. "
"Answer ONLY using the provided context from the wedding documents. "
"If the answer isn’t in the docs, say you don’t have that info and suggest who to contact (e.g., Overall IC). "
"Keep answers under 6 bullets or 150 words when possible. Use SGT times."
)
async def answer_with_rag(question: str, rag: RAGIndex) -> str:
ctx = retrieve(question, rag, k=6)
context_blocks = []
for r in ctx:
# Keep brief context slices
text = r["chunk"]
if len(text) > 800:
text = text[:800] + "…"
context_blocks.append(f"[Source: {os.path.basename(r['source'])}]\n{text}")
context_text = "\n\n".join(context_blocks)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": f"Context from docs:\n\n{context_text}\n\nQuestion: {question}"}
]
completion = client.chat.completions.create(
model=OPENAI_MODEL,
messages=messages,
temperature=0.2,
)
answer = completion.choices[0].message.content.strip()
if STRICT_DOC_MODE and (not context_blocks or "I don’t have that info" in answer):
# If no context matched strongly, enforce honesty
if len(context_blocks) == 0:
return (
"I couldn’t find this in the wedding docs. Please check the Family Playbook or ask the Overall IC. "
"You can also /refresh to make sure I have the latest files."
)
return answer
# -----------------------------
# Telegram Handlers
# -----------------------------
RAG = None # lazy loaded
async def ensure_rag(force: bool = False):
global RAG
if RAG is None or force:
RAG = build_or_load_index(force_rebuild=force)
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
await ensure_rag(False)
msg = (
"👋 Hello! I’m the Wedding Q&A Bot. Ask me anything about roles, timings, addresses, and logistics.\n\n"
"Examples:\n"
"• What time is the solemnisation?\n"
"• What’s Mum’s role during tea ceremony?\n"
"• Where to park at the hotel?\n"
"• Who holds the ang bao box?\n\n"
"Admins can /refresh after updating the docs."
)
await update.message.reply_text(msg)
async def help_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
await update.message.reply_text(
"Send a normal question, or use:\n"
"/role <name> — quick role lookup\n"
"/refresh — rebuild knowledge from latest docs (admin only, but not enforced)"
)
async def role_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
await ensure_rag(False)
name = " ".join(context.args).strip()
if not name:
await update.message.reply_text("Usage: /role <name>")
return
q = f"What is the role and responsibilities of {name}? Include timings and contact if available."
ans = await answer_with_rag(q, RAG)
await update.message.reply_text(ans, parse_mode=ParseMode.MARKDOWN)
async def refresh_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
await ensure_rag(True)
await update.message.reply_text("✅ Refreshed. I’m now using the latest documents in wedding_docs/.")
async def on_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
await ensure_rag(False)
text = (update.message.text or "").strip()
if not text:
return
ans = await answer_with_rag(text, RAG)
# Telegram has 4096 char limit per message; be safe
if len(ans) > 3500:
ans = ans[:3500] + "…"
await update.message.reply_text(ans, parse_mode=ParseMode.MARKDOWN)
def main():
if not TELEGRAM_BOT_TOKEN:
raise RuntimeError("TELEGRAM_BOT_TOKEN missing")
if not OPENAI_API_KEY:
raise RuntimeError("OPENAI_API_KEY missing")
app = Application.builder().token(TELEGRAM_BOT_TOKEN).build()
app.add_handler(CommandHandler("start", start))
app.add_handler(CommandHandler("help", help_cmd))
app.add_handler(CommandHandler("role", role_cmd))
app.add_handler(CommandHandler("refresh", refresh_cmd))
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, on_message))
print("Bot running… Press Ctrl+C to stop.")
app.run_polling(drop_pending_updates=True)
if __name__ == "__main__":
main()