Spaces:
Sleeping
Sleeping
| import os | |
| import hashlib | |
| import time | |
| from datetime import datetime | |
| from typing import List, Dict | |
| import numpy as np | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from pymongo import MongoClient, UpdateOne | |
| from sentence_transformers import SentenceTransformer | |
# Load variables from a local .env file into the process environment.
load_dotenv()

# ---- Configuration (all overridable via environment variables) ---- #
MONGO_URI = os.getenv("MONGO_URI")  # Atlas connection string; None if unset
DB_NAME = os.getenv("MONGO_DB", "legal_chatbot_db")
COLLECTION_NAME = os.getenv("MONGO_COLLECTION", "datasets")
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
CSV_MAIN = os.getenv("CSV_MAIN", "qa_dataset_prepared_for_training.csv")
CSV_PDF = os.getenv("CSV_PDF", "pdf_chunks.csv")
# Only the exact string "true" (case-insensitive) enables upsert mode.
DO_UPSERT = os.getenv("UPSERT", "true").lower() == "true"
| # ---------------- Helpers ---------------- # | |
def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Rename columns to the canonical question/answer/intent/source schema.

    Accepts common aliases regardless of capitalization (e.g. "Question",
    "ANSWERS", "label"). Unrecognized columns keep their original names.
    Optional canonical columns that are absent are created empty so
    downstream code can rely on their presence.

    Raises:
        ValueError: if no question-like column is present.
    """
    # Alias table: any lowercase name on the right maps to the canonical
    # column name on the left. (The original also built an unused
    # lowercase->original lookup dict; removed.)
    aliases = {
        "question": ("question", "queries", "q"),
        "answer": ("answer", "answers", "a", "solution", "text"),
        "intent": ("intent", "label", "category"),
        "source": ("source",),
    }
    mapping = {}
    for col in df.columns:
        key = col.lower().strip()
        for canonical, names in aliases.items():
            if key in names:
                mapping[col] = canonical
                break
        else:
            # Keep original name if not recognized.
            mapping[col] = col
    out = df.rename(columns=mapping)
    if "question" not in out.columns:
        raise ValueError("CSV is missing a 'Question' column (or alias).")
    for optional in ("answer", "intent", "source"):
        if optional not in out.columns:
            out[optional] = ""
    return out
def row_key(question: str, answer: str) -> str:
    """Return a deterministic hex id for a (question, answer) pair.

    Whitespace at the ends of both strings is ignored, so cosmetically
    different rows dedupe to the same key.
    """
    payload = "\n".join((question.strip(), answer.strip()))
    return hashlib.md5(payload.encode("utf-8")).hexdigest()
def load_frames() -> pd.DataFrame:
    """Load CSV_MAIN and/or CSV_PDF into one cleaned, deduplicated frame.

    Returns a DataFrame with canonical columns (question/answer/intent/source)
    plus a `__key` content-hash column used for dedupe/upserts.

    Raises:
        FileNotFoundError: if neither CSV exists on disk.
    """
    frames = []
    if os.path.exists(CSV_MAIN):
        frames.append(pd.read_csv(CSV_MAIN))
    if os.path.exists(CSV_PDF):
        frames.append(pd.read_csv(CSV_PDF))
    if not frames:
        raise FileNotFoundError(
            f"No CSVs found. Looked for: {CSV_MAIN} and {CSV_PDF}"
        )
    # Normalize column names, then concatenate.
    norm = [normalize_columns(f) for f in frames]
    df = pd.concat(norm, ignore_index=True)
    # BUG FIX: fillna must run BEFORE astype(str). The original called
    # astype(str) first, which turns NaN into the literal string "nan",
    # so the subsequent fillna("") (and dropna) never caught anything.
    for col in ("question", "answer", "intent", "source"):
        df[col] = df[col].fillna("").astype(str).str.strip()
    # Drop rows with empty questions (now also covers former-NaN rows).
    df = df[df["question"].str.len() > 0]
    # Dedupe exact (question + answer) pairs via a stable content hash.
    df["__key"] = df.apply(lambda r: row_key(r["question"], r["answer"]), axis=1)
    df = df.drop_duplicates(subset="__key")
    return df
def ensure_vector_index(client: MongoClient, db_name: str, col_name: str, dim: int) -> None:
    """
    NOTE: This uses the $vectorSearch Atlas index name `kb_vector_index`.
    You still need to create it in Atlas UI (preferred) OR run an admin command.
    Here we print the JSON you should paste if not present.
    """
    # NOTE(review): `client` is accepted but unused — this helper only prints
    # guidance; it does not query Atlas to check whether the index exists.
    print("\n⚠️ Make sure you have a Vector Search index in Atlas:")
    # The JSON below matches Atlas's vector-index definition shape:
    # one "vector" field on path "embedding" with the model's dimension.
    print(f"""
Database: {db_name}
Collection: {col_name}
Index Name: kb_vector_index
JSON:
{{
"fields": [
{{"type": "vector", "path": "embedding", "numDimensions": {dim}, "similarity": "cosine"}}
]
}}
""".strip())
def chunked(iterable, size):
    """Yield successive slices of *iterable* of length *size* (last may be shorter)."""
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start:start + size]
        start += size
# ---------------- Main ingest ---------------- #
def main():
    """End-to-end ingest: load CSVs, embed rows, and bulk-write to MongoDB."""
    print("🔄 Loading SentenceTransformer:", EMBED_MODEL_NAME)
    embedder = SentenceTransformer(EMBED_MODEL_NAME)
    dim = embedder.get_sentence_embedding_dimension()
    print(f"✅ Embedding dimension: {dim}")

    # SECURITY NOTE(review): tlsAllowInvalidCertificates=True disables TLS
    # certificate verification — acceptable only for local testing; remove
    # before any production deployment.
    client = MongoClient(MONGO_URI, tls=True, tlsAllowInvalidCertificates=True)
    db = client[DB_NAME]
    col = db[COLLECTION_NAME]
    ensure_vector_index(client, DB_NAME, COLLECTION_NAME, dim)

    df = load_frames()
    print(f"🧾 Loaded rows: {len(df)} (unique by question+answer)")

    # Text to embed: question + [SEP] + answer gives richer context.
    texts = (df["question"] + " [SEP] " + df["answer"]).tolist()

    # Encode in batches to bound memory use.
    BATCH = 512 if dim <= 512 else 128
    embeddings: List[np.ndarray] = []
    print("🧠 Computing embeddings...")
    for start in range(0, len(texts), BATCH):
        batch = texts[start:start + BATCH]
        embeddings.append(embedder.encode(batch, normalize_embeddings=True))
    embeddings = np.vstack(embeddings)
    print("✅ Embeddings ready.")

    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
    # prefer datetime.now(timezone.utc) when the runtime allows.
    now = datetime.utcnow()
    ops: List[UpdateOne] = []
    # Reset to a clean 0..N index so row i aligns with embeddings[i].
    df = df.rename(columns={"__key": "doc_key"}).reset_index(drop=True)
    for i, row in enumerate(df.itertuples(index=False)):
        key = row.doc_key
        # BUG FIX: _id must not appear inside $set — MongoDB rejects update
        # modifiers on the immutable _id field for existing documents. The
        # filter supplies _id, and upserted documents inherit it from there.
        doc = {
            "question": row.question,
            "answer": row.answer,
            "intent": row.intent,
            # Heuristic fallback when source is empty: PDF chunks mention
            # "section" in the question text; everything else came from CSV.
            "source": row.source or ("pdf" if "section" in row.question.lower() else "csv"),
            "embedding": embeddings[i].tolist(),
            "updated_at": now,
        }
        # BUG FIX: the original appended each operation TWICE (the if/else
        # block was duplicated verbatim), doubling the bulk payload and the
        # reported op count. Exactly one op per document now.
        if DO_UPSERT:
            ops.append(
                UpdateOne(
                    {"_id": key},
                    {"$setOnInsert": {"created_at": now}, "$set": doc},
                    upsert=True,
                )
            )
        else:
            ops.append(UpdateOne({"_id": key}, {"$set": doc}, upsert=False))

    print(f"⬆️ Writing to MongoDB: {len(ops)} docs ...")
    result = col.bulk_write(ops, ordered=False)
    print("✅ Done.")
    print(f"Upserts: {getattr(result, 'upserted_count', 0)}, Modified: {result.modified_count}")

    # Quick smoke test: one vector search against the freshly written data.
    sample_q = "What are my rights in case of workplace harassment?"
    q_vec = embedder.encode([sample_q], normalize_embeddings=True)[0].tolist()
    print("🔎 Sample search:", sample_q)
    pipe = [
        {
            "$vectorSearch": {
                "index": "kb_vector_index",
                "path": "embedding",
                "queryVector": q_vec,
                "numCandidates": 100,
                "limit": 3,
            }
        },
        {"$project": {"_id": 0, "intent": 1, "question": 1, "answer": 1, "score": {"$meta": "vectorSearchScore"}}},
    ]
    try:
        hits = list(col.aggregate(pipe))
        for h in hits:
            print(f"• [{h.get('intent','')}] score={round(h['score'],3)} Q: {h['question']}\n A: {h['answer'][:140]}...\n")
        if not hits:
            print("⚠️ No hits returned. Check your Vector Search index status in Atlas.")
    except Exception as e:
        # Best-effort diagnostics only: ingest already succeeded by this point.
        print("⚠️ Vector search aggregation failed:", e)
        print(" - Is your kb_vector_index active?")
        print(" - Does your collection contain 'embedding' arrays with correct dimension?")


if __name__ == "__main__":
    main()