import os
import uuid
from typing import Dict, List, Optional

import google.generativeai as genai
import torch
from bson import ObjectId
from groq import Groq
from pymongo import MongoClient
from qdrant_client import QdrantClient, models
from sentence_transformers import CrossEncoder, SentenceTransformer


def build_content(doc: dict, entity_type: str) -> str:
    """Convert a MongoDB document into natural-language text for embedding.

    The first line identifies the document as ``"<entity_type> ID: <id>"``.
    It uses the document's ``id`` field when it is present and not ``None``,
    otherwise ``_id``. Every other field except ``_id`` follows as a
    ``"key: value"`` line:

    * lists are rendered as comma-separated values;
    * dicts are rendered as ``"subkey: value"`` pairs joined by ``"; "``;
    * missing values (``None``, empty strings, empty containers) are skipped,
      but meaningful falsy values such as ``0`` and ``False`` are kept.

    Args:
        doc: The MongoDB document as a plain dict.
        entity_type: Label used on the ID line, e.g. ``"ProblemReport"``.

    Returns:
        The newline-joined text representation of ``doc``.
    """

    def is_missing(value) -> bool:
        # A plain truthiness test would also drop 0 and False, which are real
        # data (a zero count, a ``resolved: False`` flag), so only None and
        # empty strings/containers count as "missing".
        if value is None:
            return True
        if isinstance(value, (str, list, tuple, set, dict)):
            return len(value) == 0
        return False

    doc_id = doc.get("id")
    if doc_id is None:
        doc_id = str(doc.get("_id", ""))
    parts = [f"{entity_type} ID: {doc_id}"]

    for key, value in doc.items():
        if key == "_id" or is_missing(value):
            continue
        if isinstance(value, list):
            items = [str(item) for item in value if not is_missing(item)]
            if items:
                parts.append(f"{key}: {', '.join(items)}")
        elif isinstance(value, dict):
            nested = "; ".join(
                f"{nested_key}: {nested_value}"
                for nested_key, nested_value in value.items()
                if not is_missing(nested_value)
            )
            # Skip dicts whose values were all missing instead of emitting "key: ".
            if nested:
                parts.append(f"{key}: {nested}")
        else:
            parts.append(f"{key}: {value}")
    return "\n".join(parts)


class ErrorBot:
    """Chatbot using RAG (Qdrant + Gemini API).

    Documents from three MongoDB collections (problem reports, fault
    analyses, corrections) are embedded with a SentenceTransformer model and
    stored in a Qdrant collection. A question is answered by retrieving
    similar documents, optionally reranking them with a cross-encoder, and
    passing them as context to a Gemini generative model.
    """

    def __init__(
        self,
        embedding_model_name: str,
        llm_model_name: str,
        google_api_key: str,
        collection_name: str = "technical_errors",
        reranker_model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
    ):
        """Load the models, connect to Qdrant and ensure the collection exists.

        Args:
            embedding_model_name: SentenceTransformer model used for documents and queries.
            llm_model_name: Gemini model name passed to ``genai.GenerativeModel``.
            google_api_key: API key passed to ``genai.configure``.
            collection_name: Qdrant collection to read from and write to.
            reranker_model_name: CrossEncoder model used when reranking.

        Qdrant connection details are read from the ``QDRANT_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        print("🚀 Initializing ErrorBot...")

        # Embedding model, placed on the GPU when torch reports one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.embedding_model = SentenceTransformer(embedding_model_name, device=self.device)
        self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()

        # Vector store.
        print("Connecting to Qdrant...")
        self.qdrant = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.collection_name = collection_name
        self._setup_collection()

        # Generative model.
        genai.configure(api_key=google_api_key)
        self.llm_model_name = llm_model_name
        self.llm = genai.GenerativeModel(llm_model_name)

        # Cross-encoder used by `retrieve(..., rerank=True)`.
        print("Loading cross-encoder reranker...")
        self.reranker = CrossEncoder(reranker_model_name)

        print("✅ ErrorBot ready.")

    def _setup_collection(self):
        """Create the Qdrant collection (cosine distance) if it does not exist yet."""
        # NOTE(review): an existing collection is reused as-is; if it was built
        # with a different embedding size, upserts/queries will fail — confirm
        # whether a size check is wanted here.
        if not self.qdrant.collection_exists(self.collection_name):
            self.qdrant.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )

    @staticmethod
    def _stringify_object_ids(value):
        """Return a copy of ``value`` with every ObjectId (at any depth) replaced by its string form."""
        # Only converting the top-level `_id` leaves nested references
        # (e.g. an ObjectId pointing at another document) in the payload,
        # which cannot be serialized when upserting to Qdrant.
        if isinstance(value, ObjectId):
            return str(value)
        if isinstance(value, dict):
            return {key: ErrorBot._stringify_object_ids(item) for key, item in value.items()}
        if isinstance(value, list):
            return [ErrorBot._stringify_object_ids(item) for item in value]
        return value

    @staticmethod
    def _point_id(entity_type: str, source_id) -> str:
        """Return a deterministic Qdrant point ID (UUID string) for a source document.

        The same (entity_type, source_id) pair always maps to the same ID, so
        re-ingesting updates points in place instead of shifting them.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_URL, f"errorbot/{entity_type}/{source_id}"))

    @staticmethod
    def _load_documents(mongo_uri: str, db_name: str) -> List[Dict]:
        """Read every document from the source collections, tagged with its entity type.

        Returns:
            A list of ``{"entity_type": str, "data": dict}`` entries, where ``data``
            has all ObjectIds converted to strings.
        """
        collection_names = {
            "ProblemReport": "problemReports",
            "FaultAnalysis": "faultanalysis",
            "Correction": "corrections",
        }
        docs = []
        # The context manager closes the client (and its connection pool) when done.
        with MongoClient(mongo_uri) as client:
            db = client[db_name]
            for entity_type, coll_name in collection_names.items():
                for doc in db[coll_name].find():
                    docs.append(
                        {"entity_type": entity_type, "data": ErrorBot._stringify_object_ids(doc)}
                    )
        return docs

    def ingest_from_mongodb(self, mongo_uri: str, db_name: str, batch_size: int = 32):
        """Embed all documents of the known MongoDB collections and upsert them into Qdrant.

        Args:
            mongo_uri: MongoDB connection string.
            db_name: Database containing the ``problemReports``, ``faultanalysis``
                and ``corrections`` collections.
            batch_size: Documents per embedding batch and per Qdrant upsert request.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        docs = self._load_documents(mongo_uri, db_name)
        if not docs:
            print(f"✅ Ingested 0 documents into '{self.collection_name}'")
            return

        contents = [build_content(d["data"], d["entity_type"]) for d in docs]
        # Let the model batch internally; one progress bar for the whole run.
        embeddings = self.embedding_model.encode(
            contents, batch_size=batch_size, show_progress_bar=True
        ).tolist()

        # NOTE(review): point IDs were previously positional integers; a
        # collection filled by the old code will keep those points alongside
        # the new UUID-keyed ones until it is recreated.
        points = [
            models.PointStruct(
                id=self._point_id(d["entity_type"], d["data"].get("_id", i)),
                vector=emb,
                payload={
                    "id": d["data"].get("id", str(d["data"].get("_id", i))),
                    "entity_type": d["entity_type"],
                    "raw": d["data"],
                    "content": c,
                },
            )
            for i, (d, emb, c) in enumerate(zip(docs, embeddings, contents))
        ]

        # Chunked upserts keep each request bounded in size.
        for start in range(0, len(points), batch_size):
            self.qdrant.upsert(
                collection_name=self.collection_name,
                points=points[start:start + batch_size],
                wait=True,
            )
        print(f"✅ Ingested {len(docs)} documents into '{self.collection_name}'")

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.3,
        rerank: bool = True,
        candidate_multiplier: int = 3,
    ):
        """Return up to ``top_k`` stored documents most relevant to ``query``.

        Args:
            query: The user question.
            top_k: Maximum number of results; ``<= 0`` returns an empty list.
            score_threshold: Minimum vector-similarity score passed to Qdrant.
            rerank: When true, fetch ``top_k * candidate_multiplier`` candidates
                and reorder them by cross-encoder score.
            candidate_multiplier: Over-fetch factor used when reranking (at least 1).

        Returns:
            A list of dicts with ``id``, ``entity_type``, ``content`` and ``score``
            (plus ``rerank_score`` when reranked).
        """
        if top_k <= 0:
            return []

        query_embedding = self.embedding_model.encode(query).tolist()
        limit = top_k * max(1, candidate_multiplier) if rerank else top_k
        hits = self.qdrant.query_points(
            collection_name=self.collection_name,
            query=query_embedding,
            limit=limit,
            with_payload=True,
            score_threshold=score_threshold,
        ).points

        candidates = []
        for hit in hits:
            payload = hit.payload or {}
            candidates.append(
                {
                    "id": payload.get("id"),
                    "entity_type": payload.get("entity_type", ""),
                    "content": payload.get("content", ""),
                    "score": hit.score,
                }
            )

        if rerank and candidates:
            scores = self.reranker.predict([(query, c["content"]) for c in candidates])
            for candidate, score in zip(candidates, scores):
                candidate["rerank_score"] = float(score)
            candidates.sort(key=lambda c: c["rerank_score"], reverse=True)

        return candidates[:top_k]

    @staticmethod
    def _format_context(context: List[Dict]) -> str:
        """Render retrieved documents as prompt context, separated by ``---`` lines."""
        return "\n---\n".join(
            f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context
        )

    @staticmethod
    def _format_conversation(query: str, history: Optional[list] = None) -> str:
        """Render prior turns plus the new query as ``User:``/``Assistant:`` lines.

        Messages whose ``role`` is not ``"user"`` are labelled ``Assistant``.
        """
        lines = []
        for msg in history or []:
            role = "User" if msg["role"] == "user" else "Assistant"
            lines.append(f"{role}: {msg['content']}\n")
        lines.append(f"User: {query}\nAssistant:")
        return "".join(lines)

    @staticmethod
    def _response_text(response) -> str:
        """Return the stripped text of a Gemini response, or a fallback message if it has none."""
        try:
            text = response.text
        except ValueError:
            # The `.text` accessor raises when the response has no valid part,
            # e.g. when generation was blocked; don't crash the caller.
            feedback = getattr(response, "prompt_feedback", None)
            print(f"⚠️ LLM returned no text (prompt_feedback: {feedback})")
            return "I could not generate an answer for this question."
        return text.strip()

    def generate_answer(self, query: str, context: List[Dict], history: Optional[list] = None):
        """Ask the LLM to answer ``query`` using the retrieved ``context`` and chat ``history``.

        Args:
            query: The user question.
            context: Documents as returned by :meth:`retrieve`.
            history: Optional prior messages, each a dict with ``role`` and ``content``.

        Returns:
            The model's answer text, or a fallback message if the model returned no text.
        """
        context_str = self._format_context(context)
        convo_str = self._format_conversation(query, history)

        prompt = (
            "\n"
            "You are a technical assistant. You have access to Problem Reports (PR), "
            "Fault Analyses (FA), and Corrections (CR).\n"
            "Use the provided context and conversation history to answer the question "
            "clearly and concisely.\n"
            "If context is not relevant, say you do not have enough information.\n"
            "\n"
            "### Context\n"
            f"{context_str}\n"
            "\n"
            "### Conversation\n"
            f"{convo_str}\n"
        )

        response = self.llm.generate_content(prompt)
        return self._response_text(response)

    def ask(self, query: str, history: Optional[list] = None):
        """Retrieve context for ``query``, generate an answer, print and return it.

        Returns a fixed "could not find" message when retrieval yields nothing.
        """
        print(f"\n❓ Query: {query}")
        retrieved_context = self.retrieve(query)

        if not retrieved_context:
            print("💬 No relevant context found.")
            return "I could not find any relevant information."

        print(f"✅ Retrieved {len(retrieved_context)} documents.")
        for i, doc in enumerate(retrieved_context):
            print(f" - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})")

        answer = self.generate_answer(query, retrieved_context, history)
        print(f"\n🤖 Answer: {answer}")
        return answer