# NOTE(review): the three lines below were stray git metadata (author handle,
# commit message, commit hash) pasted into the source; kept as comments so the
# module parses.
# Sai809701
# Added model, dataset and dockerfile
# 17205ab
import os
import hashlib
import time
from datetime import datetime
from typing import List, Dict
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from pymongo import MongoClient, UpdateOne
from sentence_transformers import SentenceTransformer
load_dotenv()
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = os.getenv("MONGO_DB", "legal_chatbot_db")
COLLECTION_NAME = os.getenv("MONGO_COLLECTION", "datasets")
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
CSV_MAIN = os.getenv("CSV_MAIN", "qa_dataset_prepared_for_training.csv")
CSV_PDF = os.getenv("CSV_PDF", "pdf_chunks.csv")
DO_UPSERT = os.getenv("UPSERT", "true").lower() == "true"
# ---------------- Helpers ---------------- #
def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
# make flexible with various capitalizations
cols = {c.lower().strip(): c for c in df.columns}
# try map common names
mapping = {}
for c in df.columns:
cl = c.lower().strip()
if cl in ("question", "queries", "q"):
mapping[c] = "question"
elif cl in ("answer", "answers", "a", "solution", "text"):
mapping[c] = "answer"
elif cl in ("intent", "label", "category"):
mapping[c] = "intent"
elif cl in ("source",):
mapping[c] = "source"
else:
# keep original if not recognized
mapping[c] = c
out = df.rename(columns=mapping)
# ensure columns exist
if "question" not in out.columns:
raise ValueError("CSV is missing a 'Question' column (or alias).")
if "answer" not in out.columns:
out["answer"] = ""
if "intent" not in out.columns:
out["intent"] = ""
if "source" not in out.columns:
out["source"] = ""
return out
def row_key(question: str, answer: str) -> str:
# stable id for upsert/dedupe
m = hashlib.md5()
m.update((question.strip() + "\n" + answer.strip()).encode("utf-8"))
return m.hexdigest()
def load_frames() -> pd.DataFrame:
frames = []
if os.path.exists(CSV_MAIN):
frames.append(pd.read_csv(CSV_MAIN))
if os.path.exists(CSV_PDF):
frames.append(pd.read_csv(CSV_PDF))
if not frames:
raise FileNotFoundError(
f"No CSVs found. Looked for: {CSV_MAIN} and {CSV_PDF}"
)
# normalize and concat
norm = [normalize_columns(f) for f in frames]
df = pd.concat(norm, ignore_index=True)
# clean
df["question"] = df["question"].astype(str).str.strip()
df["answer"] = df["answer"].astype(str).fillna("").str.strip()
df["intent"] = df["intent"].astype(str).fillna("").str.strip()
df["source"] = df["source"].astype(str).fillna("").str.strip()
df = df.dropna(subset=["question"])
df = df[df["question"].str.len() > 0]
# dedupe exact (question+answer)
df["__key"] = df.apply(lambda r: row_key(r["question"], r["answer"]), axis=1)
df = df.drop_duplicates(subset="__key")
return df
def ensure_vector_index(client: MongoClient, db_name: str, col_name: str, dim: int):
"""
NOTE: This uses the $vectorSearch Atlas index name `kb_vector_index`.
You still need to create it in Atlas UI (preferred) OR run an admin command.
Here we print the JSON you should paste if not present.
"""
print("\n⚠️ Make sure you have a Vector Search index in Atlas:")
print(f"""
Database: {db_name}
Collection: {col_name}
Index Name: kb_vector_index
JSON:
{{
"fields": [
{{"type": "vector", "path": "embedding", "numDimensions": {dim}, "similarity": "cosine"}}
]
}}
""".strip())
def chunked(iterable, size):
for i in range(0, len(iterable), size):
yield iterable[i:i+size]
# ---------------- Main ingest ---------------- #
def main():
print("🔄 Loading SentenceTransformer:", EMBED_MODEL_NAME)
embedder = SentenceTransformer(EMBED_MODEL_NAME)
dim = embedder.get_sentence_embedding_dimension()
print(f"✅ Embedding dimension: {dim}")
client = MongoClient(MONGO_URI, tls=True, tlsAllowInvalidCertificates=True)
db = client[DB_NAME]
col = db[COLLECTION_NAME]
ensure_vector_index(client, DB_NAME, COLLECTION_NAME, dim)
df = load_frames()
print(f"🧾 Loaded rows: {len(df)} (unique by question+answer)")
# text to embed: use question + [SEP] + answer to give richer context
texts = (df["question"] + " [SEP] " + df["answer"]).tolist()
# encode in chunks to save memory
BATCH = 512 if dim <= 512 else 128
embeddings: List[np.ndarray] = []
print("🧠 Computing embeddings...")
for idxs in chunked(list(range(len(texts))), BATCH):
batch_texts = [texts[i] for i in idxs]
vecs = embedder.encode(batch_texts, normalize_embeddings=True)
embeddings.append(vecs)
embeddings = np.vstack(embeddings)
print("✅ Embeddings ready.")
now = datetime.utcnow()
ops: List[UpdateOne] = []
df = df.rename(columns={"__key": "doc_key"}).reset_index(drop=True) # ensure clean 0..N index
for i, row in enumerate(df.itertuples(index=False)):
key = row.doc_key
doc = {
"_id": key,
"question": row.question,
"answer": row.answer,
"intent": row.intent,
"source": row.source or ("pdf" if "section" in row.question.lower() else "csv"),
"embedding": embeddings[i].tolist(),
"updated_at": now,
}
if DO_UPSERT:
ops.append(
UpdateOne(
{"_id": key},
{"$setOnInsert": {"created_at": now}, "$set": doc},
upsert=True,
)
)
else:
ops.append(UpdateOne({"_id": key}, {"$set": doc}, upsert=False))
if DO_UPSERT:
ops.append(
UpdateOne({"_id": key},
{"$setOnInsert": {"created_at": now},
"$set": doc},
upsert=True)
)
else:
ops.append(UpdateOne({"_id": key}, {"$set": doc}, upsert=False))
print(f"⬆️ Writing to MongoDB: {len(ops)} docs ...")
result = col.bulk_write(ops, ordered=False)
print("✅ Done.")
print(f"Upserts: {getattr(result, 'upserted_count', 0)}, Modified: {result.modified_count}")
# Quick sample lookup
sample_q = "What are my rights in case of workplace harassment?"
q_vec = embedder.encode([sample_q], normalize_embeddings=True)[0].tolist()
print("🔎 Sample search:", sample_q)
pipe = [
{
"$vectorSearch": {
"index": "kb_vector_index",
"path": "embedding",
"queryVector": q_vec,
"numCandidates": 100,
"limit": 3
}
},
{"$project": {"_id": 0, "intent": 1, "question": 1, "answer": 1, "score": {"$meta": "vectorSearchScore"}}}
]
try:
hits = list(col.aggregate(pipe))
for h in hits:
print(f"• [{h.get('intent','')}] score={round(h['score'],3)} Q: {h['question']}\n A: {h['answer'][:140]}...\n")
if not hits:
print("⚠️ No hits returned. Check your Vector Search index status in Atlas.")
except Exception as e:
print("⚠️ Vector search aggregation failed:", e)
print(" - Is your kb_vector_index active?")
print(" - Does your collection contain 'embedding' arrays with correct dimension?")
if __name__ == "__main__":
main()