File size: 11,418 Bytes
8a183f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import os, io, glob, time, hashlib
import pandas as pd
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass
from dotenv import load_dotenv

# Telegram
from telegram import Update
from telegram.constants import ParseMode
from telegram.ext import Application, CommandHandler, MessageHandler, ContextTypes, filters

# OpenAI
from openai import OpenAI

# Files / parsing
from docx import Document as DocxDocument
from pypdf import PdfReader

# Vector store
import faiss

# Pull environment variables from a local .env file before any config is read.
load_dotenv()
import httpx
# NOTE(review): presumably a workaround for an openai/httpx version mismatch
# where the OpenAI client probes `httpx.proxies` — confirm against the pinned
# httpx version before removing.
if not hasattr(httpx, "proxies"):
    httpx.proxies = None
    
# --- Configuration (all overridable via environment / .env) ---
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TELEGRAM_BOT_TOKEN = os.getenv("TELEGRAM_BOT_TOKEN")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
# When true, questions with no retrieved doc context get an honest fallback.
STRICT_DOC_MODE = (os.getenv("STRICT_DOC_MODE", "true").lower() == "true")

DOCS_DIR = os.getenv("DOCS_DIR", "wedding_docs")  # folder of source documents
INDEX_PATH = os.getenv("INDEX_PATH", "wedding.index")  # cached embedding matrix
META_CSV = os.getenv("META_CSV", "wedding_chunks.csv")  # cached chunk metadata

# Shared OpenAI client used for both embeddings and chat completions.
client = OpenAI(api_key=OPENAI_API_KEY)

# -----------------------------
# Utilities to read documents
# -----------------------------

def read_txt_md(path: str) -> str:
    """Return the full contents of a plain-text or Markdown file.

    Bytes that cannot be decoded as UTF-8 are silently dropped rather
    than raising, so odd encodings never abort indexing.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as fh:
        return fh.read()

def read_docx(path: str) -> str:
    """Extract the text of a .docx file, one paragraph per line."""
    document = DocxDocument(path)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraph_texts)

def read_pdf(path: str) -> str:
    """Concatenate the extractable text of every page in a PDF.

    Pages whose text extraction raises are skipped entirely so one bad
    page never aborts indexing of the whole file.
    """
    page_texts = []
    for page in PdfReader(path).pages:
        try:
            extracted = page.extract_text() or ""
        except Exception:
            continue
        page_texts.append(extracted)
    return "\n".join(page_texts)

def read_csv_file(path: str, max_rows: int = 20000, max_chars: int = 400_000) -> str:
    """
    Read a CSV as text that’s friendly for RAG.

    - Coerces everything to string
    - Skips bad lines
    - Truncates very large files by rows and total chars (to avoid blowing up memory)

    Returns a block of the form:
        COLUMNS: col1, col2, ...
        col1=... | col2=... | ...
        ...
    Raises whatever pandas raises for unreadable files (other than the
    UTF-8 decode error, which falls back to latin-1).
    """
    try:
        df = pd.read_csv(
            path,
            dtype=str,
            on_bad_lines="skip",      # pandas>=1.4
            nrows=max_rows,           # safeguard for very large CSVs
            encoding="utf-8",
        )
    except UnicodeDecodeError:
        # Fallback for odd encodings; latin-1 never raises on decode.
        df = pd.read_csv(
            path,
            dtype=str,
            on_bad_lines="skip",
            nrows=max_rows,
            encoding="latin-1",
        )

    df = df.fillna("")

    # 1) schema line so the model knows the column names
    cols = list(df.columns)
    schema_block = "COLUMNS: " + ", ".join(cols)

    # 2) rows as "col=value | col=value" lines (more RAG-friendly than raw CSV).
    #    Track the running character total instead of re-summing every line on
    #    each iteration (the previous version was O(rows^2) in string length).
    lines = []
    total_chars = 0
    for _, row in df.iterrows():
        line = " | ".join(f"{c}={row[c]}" for c in cols)
        lines.append(line)
        total_chars += len(line)
        # Hard cap to avoid giant strings
        if total_chars > max_chars:
            lines.append("…(truncated)")
            break

    data_block = "\n".join(lines)
    return f"{schema_block}\n{data_block}"

def load_all_docs(folder: str) -> List[Tuple[str, str]]:
    """Read every supported document in *folder* into (path, text) pairs.

    Supported extensions: .md, .txt, .docx, .pdf, .csv — anything else
    found by the glob patterns is skipped.
    """
    # Suffix -> reader dispatch; order mirrors the original check sequence.
    readers = {
        ".md": read_txt_md,
        ".txt": read_txt_md,
        ".docx": read_docx,
        ".pdf": read_pdf,
        ".csv": read_csv_file,
    }

    candidate_paths = []
    for pattern in ("*.md", "*.txt", "*.docx", "*.pdf", "*.csv"):
        candidate_paths.extend(glob.glob(os.path.join(folder, pattern)))

    docs = []
    for path in candidate_paths:
        for suffix, reader in readers.items():
            if path.endswith(suffix):
                docs.append((path, reader(path)))
                break
    return docs


# -----------------------------
# Chunk + Embed + Index
# -----------------------------

def chunk_text(text: str, source: str, chunk_size: int = 300, overlap: int = 80) -> List[dict]:
    """Split *text* into overlapping word-window chunks for embedding.

    Each chunk dict carries the originating *source* path, the chunk text,
    and an md5 hash of (source + start offset + chunk) as a stable id.

    Raises ValueError if chunk_size is non-positive or overlap >= chunk_size
    (either would make the window step non-positive and loop forever).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk_words = words[i:i+chunk_size]
        chunk = " ".join(chunk_words)
        chunks.append({
            "source": source,
            "chunk": chunk,
            "hash": hashlib.md5((source + str(i) + chunk).encode("utf-8")).hexdigest()
        })
        i += (chunk_size - overlap)
    return chunks

def embed_texts(texts: List[str]) -> np.ndarray:
    """Embed *texts* with the OpenAI embeddings API.

    Returns a float32 array of shape (len(texts), embedding_dim).
    """
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=texts)
    vectors = [entry.embedding for entry in response.data]
    return np.array(vectors, dtype="float32")

@dataclass
class RAGIndex:
    # In-memory retrieval state: a FAISS inner-product index over
    # L2-normalised chunk embeddings, plus the matching chunk metadata.
    index: faiss.IndexFlatIP
    df: pd.DataFrame  # columns: [source, chunk, hash]; row i matches index vector i
    dim: int  # embedding dimensionality

def build_or_load_index(force_rebuild: bool = False) -> RAGIndex:
    """Load the cached index if it is fresh, otherwise (re)build from DOCS_DIR.

    The cache is two files: INDEX_PATH (the normalised embedding matrix as
    a .npy payload) and META_CSV (the chunk metadata). If any source doc
    is newer than either cache file, everything is re-embedded.

    Raises RuntimeError when DOCS_DIR is empty or yields no chunks.
    """
    docs = load_all_docs(DOCS_DIR)
    if not docs:
        raise RuntimeError(f"No docs found in {DOCS_DIR}/. Put your itinerary files there.")

    # Simple staleness check: if any file is newer than index, rebuild.
    def newest_mtime():
        paths = []
        # Must mirror the extension list in load_all_docs — *.csv was
        # previously missing here, so CSV edits never triggered a rebuild.
        for ext in ("*.md", "*.txt", "*.docx", "*.pdf", "*.csv"):
            paths.extend(glob.glob(os.path.join(DOCS_DIR, ext)))
        return max(os.path.getmtime(p) for p in paths)

    index_exists = os.path.exists(INDEX_PATH) and os.path.exists(META_CSV)
    need_rebuild = force_rebuild
    if index_exists:
        idx_mtime = min(os.path.getmtime(INDEX_PATH), os.path.getmtime(META_CSV))
        need_rebuild = need_rebuild or (newest_mtime() > idx_mtime)

    if index_exists and not need_rebuild:
        df = pd.read_csv(META_CSV)
        with open(INDEX_PATH, "rb") as fh:
            vecs = np.load(fh)
        dim = vecs.shape[1]
        index = faiss.IndexFlatIP(dim)
        faiss.normalize_L2(vecs)  # idempotent: vectors were saved normalised
        index.add(vecs)
        return RAGIndex(index=index, df=df, dim=dim)

    # Rebuild from scratch.
    all_chunks = []
    for path, text in docs:
        if not text.strip():
            continue
        all_chunks.extend(chunk_text(text, source=path))

    if not all_chunks:
        raise RuntimeError("Docs were read but produced no chunks. Check formats.")

    df = pd.DataFrame(all_chunks)
    vecs = embed_texts(df["chunk"].tolist())
    # Normalize for cosine similarity via inner product.
    faiss.normalize_L2(vecs)
    # np.save appends ".npy" when handed a bare path without that suffix,
    # so np.save(INDEX_PATH, ...) wrote "wedding.index.npy" and the cache
    # was never found on the next run. Writing through an open handle
    # preserves the exact filename.
    with open(INDEX_PATH, "wb") as fh:
        np.save(fh, vecs)
    df.to_csv(META_CSV, index=False)

    dim = vecs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vecs)
    return RAGIndex(index=index, df=df, dim=dim)

# -----------------------------
# Retrieval + Answering
# -----------------------------

def retrieve(query: str, rag: RAGIndex, k: int = 7) -> List[dict]:
    """Return the top-k chunks most similar to *query*.

    Similarity is cosine, implemented as inner product over L2-normalised
    vectors. Each hit is {"score", "source", "chunk"}.
    """
    query_vec = embed_texts([query])
    faiss.normalize_L2(query_vec)
    scores, positions = rag.index.search(query_vec, k)

    hits = []
    for score, position in zip(scores[0], positions[0]):
        if position == -1:
            # FAISS pads with -1 when fewer than k vectors exist.
            continue
        row = rag.df.iloc[int(position)]
        hits.append({
            "score": float(score),
            "source": row["source"],
            "chunk": row["chunk"],
        })
    return hits

# Grounding prompt: restricts the model to the retrieved document context,
# tells it how to fail honestly, and keeps replies short. (The empty-context
# case is additionally enforced in code via STRICT_DOC_MODE.)
SYSTEM_PROMPT = (
    "You are a helpful, concise wedding assistant for Samuel's wedding. "
    "Answer ONLY using the provided context from the wedding documents. "
    "If the answer isn’t in the docs, say you don’t have that info and suggest who to contact (e.g., Overall IC). "
    "Keep answers under 6 bullets or 150 words when possible. Use SGT times."
)

async def answer_with_rag(question: str, rag: RAGIndex) -> str:
    """Answer *question* from retrieved document context via the chat model.

    Retrieves the 6 best chunks, trims each to 800 chars, and prompts the
    model with them. In STRICT_DOC_MODE, an empty retrieval returns the
    honest fallback immediately — previously the model was queried first
    and its reply discarded, wasting an API call; the old
    '"I don’t have that info" in answer' clause was dead code because the
    inner branch only fired on empty context anyway.
    """
    ctx = retrieve(question, rag, k=6)
    context_blocks = []
    for r in ctx:
        # Keep brief context slices
        text = r["chunk"]
        if len(text) > 800:
            text = text[:800] + "…"
        context_blocks.append(f"[Source: {os.path.basename(r['source'])}]\n{text}")

    if STRICT_DOC_MODE and not context_blocks:
        return (
            "I couldn’t find this in the wedding docs. Please check the Family Playbook or ask the Overall IC. "
            "You can also /refresh to make sure I have the latest files."
        )

    context_text = "\n\n".join(context_blocks)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Context from docs:\n\n{context_text}\n\nQuestion: {question}"}
    ]

    completion = client.chat.completions.create(
        model=OPENAI_MODEL,
        messages=messages,
        temperature=0.2,
    )
    return completion.choices[0].message.content.strip()

# -----------------------------
# Telegram Handlers
# -----------------------------
RAG = None  # lazily-built global RAGIndex; populated by ensure_rag()

async def ensure_rag(force: bool = False):
    """Build the global RAG index on first use; force=True rebuilds from disk.

    NOTE(review): build_or_load_index is synchronous (embedding API calls,
    file I/O) and will block the event loop while it runs — confirm that's
    acceptable for this bot's traffic.
    """
    global RAG
    if RAG is None or force:
        RAG = build_or_load_index(force_rebuild=force)

async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle /start: warm the RAG index and send the welcome message."""
    await ensure_rag(False)
    msg = (
        "👋 Hello! I’m the Wedding Q&A Bot. Ask me anything about roles, timings, addresses, and logistics.\n\n"
        "Examples:\n"
        "• What time is the solemnisation?\n"
        "• What’s Mum’s role during tea ceremony?\n"
        "• Where to park at the hotel?\n"
        "• Who holds the ang bao box?\n\n"
        "Admins can /refresh after updating the docs."
    )
    await update.message.reply_text(msg)

async def help_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle /help: list the available commands."""
    await update.message.reply_text(
        "Send a normal question, or use:\n"
        "/role <name> — quick role lookup\n"
        "/refresh — rebuild knowledge from latest docs (admin only, but not enforced)"
    )

async def role_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle /role <name>: summarise a person's role from the docs.

    Truncates like on_message does (the cap was missing here) and falls
    back to plain text when Telegram rejects the model's Markdown.
    """
    await ensure_rag(False)
    name = " ".join(context.args).strip()
    if not name:
        await update.message.reply_text("Usage: /role <name>")
        return
    q = f"What is the role and responsibilities of {name}? Include timings and contact if available."
    ans = await answer_with_rag(q, RAG)
    # Telegram has 4096 char limit per message; be safe (same cap as on_message)
    if len(ans) > 3500:
        ans = ans[:3500] + "…"
    try:
        await update.message.reply_text(ans, parse_mode=ParseMode.MARKDOWN)
    except Exception:
        # LLM output isn't guaranteed to be valid Markdown; don't drop the reply.
        await update.message.reply_text(ans)

async def refresh_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle /refresh: force a rebuild of the index from the docs folder.

    NOTE(review): despite the /help text, no admin check is enforced here —
    any chat member can trigger a (blocking, API-billed) rebuild.
    """
    await ensure_rag(True)
    await update.message.reply_text("✅ Refreshed. I’m now using the latest documents in wedding_docs/.")

async def on_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Answer any plain text message via RAG.

    Falls back to a plain-text send when Telegram rejects the model's
    Markdown (unbalanced `*`/`_` in LLM output raises BadRequest, which
    previously dropped the reply entirely).
    """
    await ensure_rag(False)
    text = (update.message.text or "").strip()
    if not text:
        return
    ans = await answer_with_rag(text, RAG)
    # Telegram has 4096 char limit per message; be safe
    if len(ans) > 3500:
        ans = ans[:3500] + "…"
    try:
        await update.message.reply_text(ans, parse_mode=ParseMode.MARKDOWN)
    except Exception:
        await update.message.reply_text(ans)

def main():
    """Entry point: validate credentials, register handlers, run long-polling."""
    if not TELEGRAM_BOT_TOKEN:
        raise RuntimeError("TELEGRAM_BOT_TOKEN missing")
    if not OPENAI_API_KEY:
        raise RuntimeError("OPENAI_API_KEY missing")

    app = Application.builder().token(TELEGRAM_BOT_TOKEN).build()

    # Commands first, then the catch-all text handler.
    command_handlers = (
        ("start", start),
        ("help", help_cmd),
        ("role", role_cmd),
        ("refresh", refresh_cmd),
    )
    for command_name, handler in command_handlers:
        app.add_handler(CommandHandler(command_name, handler))
    app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, on_message))

    print("Bot running… Press Ctrl+C to stop.")
    app.run_polling(drop_pending_updates=True)

if __name__ == "__main__":
    main()