# Wedding Q&A Telegram bot: answers guest questions from local wedding
# documents via retrieval-augmented generation (OpenAI embeddings + FAISS).
| import os, io, glob, time, hashlib | |
| import pandas as pd | |
| import numpy as np | |
| from typing import List, Tuple | |
| from dataclasses import dataclass | |
| from dotenv import load_dotenv | |
| # Telegram | |
| from telegram import Update | |
| from telegram.constants import ParseMode | |
| from telegram.ext import Application, CommandHandler, MessageHandler, ContextTypes, filters | |
| # OpenAI | |
| from openai import OpenAI | |
| # Files / parsing | |
| from docx import Document as DocxDocument | |
| from pypdf import PdfReader | |
| # Vector store | |
| import faiss | |
| load_dotenv() | |
| import httpx | |
| if not hasattr(httpx, "proxies"): | |
| httpx.proxies = None | |
| OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
| TELEGRAM_BOT_TOKEN = os.getenv("TELEGRAM_BOT_TOKEN") | |
| OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") | |
| EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") | |
| STRICT_DOC_MODE = (os.getenv("STRICT_DOC_MODE", "true").lower() == "true") | |
| DOCS_DIR = os.getenv("DOCS_DIR", "wedding_docs") | |
| INDEX_PATH = os.getenv("INDEX_PATH", "wedding.index") | |
| META_CSV = os.getenv("META_CSV", "wedding_chunks.csv") | |
| client = OpenAI(api_key=OPENAI_API_KEY) | |
| # ----------------------------- | |
| # Utilities to read documents | |
| # ----------------------------- | |
| def read_txt_md(path: str) -> str: | |
| with open(path, "r", encoding="utf-8", errors="ignore") as f: | |
| return f.read() | |
| def read_docx(path: str) -> str: | |
| doc = DocxDocument(path) | |
| return "\n".join(p.text for p in doc.paragraphs) | |
| def read_pdf(path: str) -> str: | |
| reader = PdfReader(path) | |
| texts = [] | |
| for page in reader.pages: | |
| try: | |
| texts.append(page.extract_text() or "") | |
| except Exception: | |
| pass | |
| return "\n".join(texts) | |
| def read_csv_file(path: str, max_rows: int = 20000, max_chars: int = 400_000) -> str: | |
| """ | |
| Read a CSV as text that’s friendly for RAG. | |
| - Coerces everything to string | |
| - Skips bad lines | |
| - Truncates very large files by rows and total chars (to avoid blowing up memory) | |
| """ | |
| try: | |
| df = pd.read_csv( | |
| path, | |
| dtype=str, | |
| on_bad_lines="skip", # pandas>=1.4 | |
| nrows=max_rows, # safeguard for very large CSVs | |
| encoding="utf-8" | |
| ) | |
| except UnicodeDecodeError: | |
| # Fallback for odd encodings | |
| df = pd.read_csv( | |
| path, | |
| dtype=str, | |
| on_bad_lines="skip", | |
| nrows=max_rows, | |
| encoding="latin-1" | |
| ) | |
| df = df.fillna("") | |
| # Build a compact, readable text block: | |
| # 1) schema | |
| cols = list(df.columns) | |
| schema_block = "COLUMNS: " + ", ".join(cols) | |
| # 2) first N rows as pipe-separated lines (more RAG-friendly than raw CSV) | |
| # Example line: "col1=..., col2=..., col3=..." | |
| lines = [] | |
| for _, row in df.iterrows(): | |
| kv = [f"{c}={row[c]}" for c in cols] | |
| lines.append(" | ".join(kv)) | |
| # Hard cap to avoid giant strings | |
| if sum(len(x) for x in lines) > max_chars: | |
| lines.append("…(truncated)") | |
| break | |
| data_block = "\n".join(lines) | |
| return f"{schema_block}\n{data_block}" | |
| def load_all_docs(folder: str) -> List[Tuple[str, str]]: | |
| paths = [] | |
| # ADD "*.csv" here | |
| for ext in ("*.md", "*.txt", "*.docx", "*.pdf", "*.csv"): | |
| paths.extend(glob.glob(os.path.join(folder, ext))) | |
| docs = [] | |
| for p in paths: | |
| if p.endswith((".md", ".txt")): | |
| text = read_txt_md(p) | |
| elif p.endswith(".docx"): | |
| text = read_docx(p) | |
| elif p.endswith(".pdf"): | |
| text = read_pdf(p) | |
| elif p.endswith(".csv"): # <--- NEW | |
| text = read_csv_file(p) | |
| else: | |
| continue | |
| docs.append((p, text)) | |
| return docs | |
| # ----------------------------- | |
| # Chunk + Embed + Index | |
| # ----------------------------- | |
| def chunk_text(text: str, source: str, chunk_size: int = 300, overlap: int = 80) -> List[dict]: | |
| words = text.split() | |
| chunks = [] | |
| i = 0 | |
| while i < len(words): | |
| chunk_words = words[i:i+chunk_size] | |
| chunk = " ".join(chunk_words) | |
| chunks.append({ | |
| "source": source, | |
| "chunk": chunk, | |
| "hash": hashlib.md5((source + str(i) + chunk).encode("utf-8")).hexdigest() | |
| }) | |
| i += (chunk_size - overlap) | |
| return chunks | |
| def embed_texts(texts: List[str]) -> np.ndarray: | |
| # Returns an array of shape (n, d) | |
| # Uses OpenAI embeddings | |
| resp = client.embeddings.create(model=EMBEDDING_MODEL, input=texts) | |
| vecs = [item.embedding for item in resp.data] | |
| return np.array(vecs).astype("float32") | |
| class RAGIndex: | |
| index: faiss.IndexFlatIP | |
| df: pd.DataFrame # columns: [source, chunk, hash, vector] | |
| dim: int | |
| def build_or_load_index(force_rebuild: bool = False) -> RAGIndex: | |
| docs = load_all_docs(DOCS_DIR) | |
| if not docs: | |
| raise RuntimeError(f"No docs found in {DOCS_DIR}/. Put your itinerary files there.") | |
| # Simple staleness check: if any file is newer than index, rebuild | |
| def newest_mtime(): | |
| paths = [] | |
| for ext in ("*.md", "*.txt", "*.docx", "*.pdf"): | |
| paths.extend(glob.glob(os.path.join(DOCS_DIR, ext))) | |
| return max(os.path.getmtime(p) for p in paths) | |
| index_exists = os.path.exists(INDEX_PATH) and os.path.exists(META_CSV) | |
| need_rebuild = force_rebuild | |
| if index_exists: | |
| idx_mtime = min(os.path.getmtime(INDEX_PATH), os.path.getmtime(META_CSV)) | |
| need_rebuild = need_rebuild or (newest_mtime() > idx_mtime) | |
| if index_exists and not need_rebuild: | |
| df = pd.read_csv(META_CSV) | |
| vecs = np.load(INDEX_PATH) | |
| dim = vecs.shape[1] | |
| index = faiss.IndexFlatIP(dim) | |
| faiss.normalize_L2(vecs) | |
| index.add(vecs) | |
| return RAGIndex(index=index, df=df, dim=dim) | |
| # Rebuild | |
| all_chunks = [] | |
| for path, text in docs: | |
| if not text.strip(): | |
| continue | |
| all_chunks.extend(chunk_text(text, source=path)) | |
| if not all_chunks: | |
| raise RuntimeError("Docs were read but produced no chunks. Check formats.") | |
| df = pd.DataFrame(all_chunks) | |
| vecs = embed_texts(df["chunk"].tolist()) | |
| # Normalize for cosine similarity via inner product | |
| faiss.normalize_L2(vecs) | |
| np.save(INDEX_PATH, vecs) | |
| df.to_csv(META_CSV, index=False) | |
| dim = vecs.shape[1] | |
| index = faiss.IndexFlatIP(dim) | |
| index.add(vecs) | |
| return RAGIndex(index=index, df=df, dim=dim) | |
| # ----------------------------- | |
| # Retrieval + Answering | |
| # ----------------------------- | |
| def retrieve(query: str, rag: RAGIndex, k: int = 7) -> List[dict]: | |
| qvec = embed_texts([query]) | |
| faiss.normalize_L2(qvec) | |
| D, I = rag.index.search(qvec, k) | |
| results = [] | |
| for score, idx in zip(D[0], I[0]): | |
| if idx == -1: | |
| continue | |
| row = rag.df.iloc[int(idx)] | |
| results.append({ | |
| "score": float(score), | |
| "source": row["source"], | |
| "chunk": row["chunk"] | |
| }) | |
| return results | |
| SYSTEM_PROMPT = ( | |
| "You are a helpful, concise wedding assistant for Samuel's wedding. " | |
| "Answer ONLY using the provided context from the wedding documents. " | |
| "If the answer isn’t in the docs, say you don’t have that info and suggest who to contact (e.g., Overall IC). " | |
| "Keep answers under 6 bullets or 150 words when possible. Use SGT times." | |
| ) | |
| async def answer_with_rag(question: str, rag: RAGIndex) -> str: | |
| ctx = retrieve(question, rag, k=6) | |
| context_blocks = [] | |
| for r in ctx: | |
| # Keep brief context slices | |
| text = r["chunk"] | |
| if len(text) > 800: | |
| text = text[:800] + "…" | |
| context_blocks.append(f"[Source: {os.path.basename(r['source'])}]\n{text}") | |
| context_text = "\n\n".join(context_blocks) | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": f"Context from docs:\n\n{context_text}\n\nQuestion: {question}"} | |
| ] | |
| completion = client.chat.completions.create( | |
| model=OPENAI_MODEL, | |
| messages=messages, | |
| temperature=0.2, | |
| ) | |
| answer = completion.choices[0].message.content.strip() | |
| if STRICT_DOC_MODE and (not context_blocks or "I don’t have that info" in answer): | |
| # If no context matched strongly, enforce honesty | |
| if len(context_blocks) == 0: | |
| return ( | |
| "I couldn’t find this in the wedding docs. Please check the Family Playbook or ask the Overall IC. " | |
| "You can also /refresh to make sure I have the latest files." | |
| ) | |
| return answer | |
| # ----------------------------- | |
| # Telegram Handlers | |
| # ----------------------------- | |
| RAG = None # lazy loaded | |
| async def ensure_rag(force: bool = False): | |
| global RAG | |
| if RAG is None or force: | |
| RAG = build_or_load_index(force_rebuild=force) | |
| async def start(update: Update, context: ContextTypes.DEFAULT_TYPE): | |
| await ensure_rag(False) | |
| msg = ( | |
| "👋 Hello! I’m the Wedding Q&A Bot. Ask me anything about roles, timings, addresses, and logistics.\n\n" | |
| "Examples:\n" | |
| "• What time is the solemnisation?\n" | |
| "• What’s Mum’s role during tea ceremony?\n" | |
| "• Where to park at the hotel?\n" | |
| "• Who holds the ang bao box?\n\n" | |
| "Admins can /refresh after updating the docs." | |
| ) | |
| await update.message.reply_text(msg) | |
| async def help_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE): | |
| await update.message.reply_text( | |
| "Send a normal question, or use:\n" | |
| "/role <name> — quick role lookup\n" | |
| "/refresh — rebuild knowledge from latest docs (admin only, but not enforced)" | |
| ) | |
| async def role_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE): | |
| await ensure_rag(False) | |
| name = " ".join(context.args).strip() | |
| if not name: | |
| await update.message.reply_text("Usage: /role <name>") | |
| return | |
| q = f"What is the role and responsibilities of {name}? Include timings and contact if available." | |
| ans = await answer_with_rag(q, RAG) | |
| await update.message.reply_text(ans, parse_mode=ParseMode.MARKDOWN) | |
| async def refresh_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE): | |
| await ensure_rag(True) | |
| await update.message.reply_text("✅ Refreshed. I’m now using the latest documents in wedding_docs/.") | |
| async def on_message(update: Update, context: ContextTypes.DEFAULT_TYPE): | |
| await ensure_rag(False) | |
| text = (update.message.text or "").strip() | |
| if not text: | |
| return | |
| ans = await answer_with_rag(text, RAG) | |
| # Telegram has 4096 char limit per message; be safe | |
| if len(ans) > 3500: | |
| ans = ans[:3500] + "…" | |
| await update.message.reply_text(ans, parse_mode=ParseMode.MARKDOWN) | |
| def main(): | |
| if not TELEGRAM_BOT_TOKEN: | |
| raise RuntimeError("TELEGRAM_BOT_TOKEN missing") | |
| if not OPENAI_API_KEY: | |
| raise RuntimeError("OPENAI_API_KEY missing") | |
| app = Application.builder().token(TELEGRAM_BOT_TOKEN).build() | |
| app.add_handler(CommandHandler("start", start)) | |
| app.add_handler(CommandHandler("help", help_cmd)) | |
| app.add_handler(CommandHandler("role", role_cmd)) | |
| app.add_handler(CommandHandler("refresh", refresh_cmd)) | |
| app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, on_message)) | |
| print("Bot running… Press Ctrl+C to stop.") | |
| app.run_polling(drop_pending_updates=True) | |
| if __name__ == "__main__": | |
| main() |