import os import torch from qdrant_client import QdrantClient, models from sentence_transformers import SentenceTransformer, CrossEncoder from pymongo import MongoClient from bson import ObjectId from typing import List, Dict import google.generativeai as genai from groq import Groq from embedding_model_instance import embedding_model, embedding_dim, reranker from qdrant_instance import qdrant from llm import gemini, groq from mongo_instance import db import json from bson import ObjectId def build_content(doc: dict, entity_type: str) -> str: """Convert MongoDB document into natural text for embeddings.""" parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"] for k, v in doc.items(): if k in ["_id"]: # skip ObjectId continue if isinstance(v, list): parts.append(f"{k}: {', '.join(map(str, v))}") elif isinstance(v, dict): nested = "; ".join([f"{nk}: {nv}" for nk, nv in v.items() if nv]) parts.append(f"{k}: {nested}") else: if v: parts.append(f"{k}: {v}") return "\n".join(parts) class ErrorBot: """Chatbot using RAG (Qdrant + Gemini API).""" def __init__(self, embedding_model_name: str, llm_model_name: str, google_api_key: str = None, groq_api_key: str = None, llm_provider: str = "gemini", last_context: list = None): print("šŸš€ Initializing ErrorBot...") self.last_context = last_context print("last_context", last_context) # --- Embedding model # self.device = "cuda" if torch.cuda.is_available() else "cpu" self.embedding_model = embedding_model self.embedding_dim = embedding_dim self.db = db # --- Qdrant client self.qdrant = qdrant self.collection_name = "technical_errors" #self._setup_collection() # --- LLM setup self.llm_provider = llm_provider.lower() self.llm_model_name = llm_model_name if self.llm_provider == "gemini": self.llm = gemini elif self.llm_provider == "groq": self.llm = groq else: raise ValueError(f"Unsupported LLM provider: {self.llm_provider}") # --- Cross encoder reranker self.reranker = reranker print(f"āœ… ErrorBot ready with {self.llm_provider.upper()}") def _setup_collection(self): if not self.qdrant.collection_exists(self.collection_name): self.qdrant.create_collection( collection_name=self.collection_name, vectors_config=models.VectorParams( size=self.embedding_dim, distance=models.Distance.COSINE, ), ) def ingest_from_mongodb(self, mongo_uri: str, db_name: str, batch_size: int = 32): client = MongoClient(mongo_uri) db = client[db_name] collections = { "ProblemReport": db["problemReports"], "FaultAnalysis": db["faultanalysis"], "Correction": db["corrections"], } docs = [] for entity_type, coll in collections.items(): for doc in coll.find(): if "_id" in doc and isinstance(doc["_id"], ObjectId): doc["_id"] = str(doc["_id"]) docs.append({"entity_type": entity_type, "data": doc}) contents = [build_content(d["data"], d["entity_type"]) for d in docs] all_embeddings = [] for i in range(0, len(contents), batch_size): batch_contents = contents[i:i + batch_size] embeddings = self.embedding_model.encode(batch_contents, show_progress_bar=True).tolist() all_embeddings.extend(embeddings) self.qdrant.upsert( collection_name=self.collection_name, points=[ models.PointStruct( id=i, vector=emb, payload={ "id": d["data"].get("id", str(d["data"].get("_id", i))), "entity_type": d["entity_type"], "raw": d["data"], "content": c, }, ) for i, (d, emb, c) in enumerate(zip(docs, all_embeddings, contents)) ], wait=True, ) print(f"āœ… Ingested {len(docs)} documents into '{self.collection_name}'") def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.3, rerank: bool = True): query_embedding = self.embedding_model.encode(query).tolist() hits = self.qdrant.query_points( collection_name=self.collection_name, query=query_embedding, limit=top_k * 3 if rerank else top_k, with_payload=True, score_threshold=score_threshold, ).points candidates = [ { "id": hit.payload.get("id"), "entity_type": hit.payload.get("entity_type", ""), "content": hit.payload.get("content", ""), "score": hit.score, } for hit in hits ] if rerank and candidates: pairs = [(query, c["content"]) for c in candidates] scores = self.reranker.predict(pairs) for i, score in enumerate(scores): candidates[i]["rerank_score"] = float(score) candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True) return candidates[:top_k] def generate_answer(self, query: str, context: List[Dict], history: list = None, is_followup: bool = False ): """ Generates an answer using the LLM, guiding it to identify which context is useful. """ context_str="" if(is_followup): pass # Aggregation pipeline pipeline = [ # Start with problemReports {"$match": {"_id": {"$in": self.last_context}}}, # Add faultAnalysis {"$unionWith": { "coll": "faultanalysis", "pipeline": [{"$match": {"id": {"$in": self.last_context}}}] }}, # Add corrections {"$unionWith": { "coll": "corrections", "pipeline": [{"$match": {"id": {"$in": self.last_context}}}] }} ] # Run aggregation on problemReports context_docs = list(db.problemReports.aggregate(pipeline)) # Serialize full documents as text for LLM #print(context_docs) context_str = "\n---\n".join( [f"{c.get('entity_type', 'Unknown')} (ID: {c['_id']}):\n{json.dumps(c, default=str)}" for c in context_docs] ) print("Context String in Follow Up:") #print(context_str) else: context_str = "\n---\n".join( [f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context] ) # --- System prompt # system_prompt = f""" # You are a technical assistant. You have access to Problem Reports (PR), Fault Analyses (FA), and Corrections (CR). # Use the provided context and conversation history to answer the question clearly and concisely. # If context is not relevant, say you do not have enough information. # ### Context # {context_str} # """ system_prompt = f""" You are a technical assistant. A user may ask questions about Problem Reports (PR), Fault Analyses (FA), and Corrections (CR). Your task is to: 1. Identify which information (PR, FA, CR) is relevant to answering the user's question. 2. Explain the solution in simple, clear, actionable language. 3. Do not just repeat the content; summarize and explain. ### User Question: ### Context: {context_str} Provide a concise, step-by-step explanation if applicable. """ # --- Conversation history in list-of-dicts format convo = [] if history: for msg in history: convo.append({ "role": "user" if msg["role"] == "user" else "assistant", "content": msg["content"], }) convo.append({"role": "user", "content": query}) # --- Gemini flow if self.llm_provider == "gemini": convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:" response = self.llm.generate_content(prompt) return response.text.strip() # --- Groq flow elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=[{"role": "system", "content": system_prompt}] + convo ) return completion.choices[0].message.content.strip() def fetch_problem_report_with_links(self, pr_id: str): # --- Fetch Problem Report pr_doc = db["problemReports"].find_one({"id": pr_id}) if not pr_doc: return None, [], [], [], [] if "_id" in pr_doc and isinstance(pr_doc["_id"], ObjectId): pr_doc["_id"] = str(pr_doc["_id"]) # --- Extract linked IDs cr_ids = pr_doc.get("correctionIds", []) fa_ids = pr_doc.get("faultAnalysisId", []) # ensure both are lists if isinstance(cr_ids, str): cr_ids = [cr_ids] elif cr_ids is None: cr_ids = [] if isinstance(fa_ids, str): fa_ids = [fa_ids] elif fa_ids is None: fa_ids = [] # --- Fetch Correction Reports cr_docs = list(db["corrections"].find({"id": {"$in": cr_ids}})) if cr_ids else [] for doc in cr_docs: if "_id" in doc and isinstance(doc["_id"], ObjectId): doc["_id"] = str(doc["_id"]) # --- Fetch Fault Analysis Reports fa_docs = list(db["faultanalysis"].find({"id": {"$in": fa_ids}})) if fa_ids else [] for doc in fa_docs: if "_id" in doc and isinstance(doc["_id"], ObjectId): doc["_id"] = str(doc["_id"]) return pr_doc, cr_ids, fa_ids, cr_docs, fa_docs def is_technical_query(self, query: str) -> bool: """ Classify query as TECHNICAL or NON-TECHNICAL. """ classification_prompt = f""" You are a classifier. Determine if the following query is TECHNICAL (related to software, debugging, errors, troubleshooting, fault analysis, corrections, technical problem reports) or NON-TECHNICAL (general questions, greetings, chit-chat, unrelated topics). Query: "{query}" Respond with exactly one word: "TECHNICAL" or "NON-TECHNICAL". """ if self.llm_provider == "gemini": response = self.llm.generate_content(classification_prompt) result = response.text.strip().upper() elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=[{"role": "system", "content": classification_prompt}] ) result = completion.choices[0].message.content.strip().upper() return result == "TECHNICAL" def is_followup_query(self, query: str, history: list = None) -> bool: """ Detect if query is a follow-up based on conversation history. """ if not history: return False classification_prompt = f""" You are a classifier. Determine if the following user query is a FOLLOW-UP (depends on the previous conversation) or a NEW QUERY (can be answered independently). Previous conversation: { [msg['content'] for msg in history][-3:] } Current query: "{query}" Respond with exactly one word: "FOLLOW-UP" or "NEW". """ if self.llm_provider == "gemini": response = self.llm.generate_content(classification_prompt) result = response.text.strip().upper() elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=[{"role": "system", "content": classification_prompt}] ) result = completion.choices[0].message.content.strip().upper() print("Follow up: ", result) return result == "FOLLOW-UP" def ask(self, query: str, history: list = None): print(f"\nā“ Query: {query}") # Step 1: Classify is_technical = self.is_technical_query(query) is_followup = self.is_followup_query(query, history) # Step 2: Non-technical standalone if not is_technical and not is_followup: print("āš ļø Non-technical standalone query → skipping Qdrant.") system_prompt = "You are a helpful assistant. Answer clearly and concisely." convo = [{"role": "system", "content": system_prompt}, {"role": "user", "content": query}] if self.llm_provider == "gemini": convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) response = self.llm.generate_content(convo_str) return response.text.strip(), [] elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=convo ) return completion.choices[0].message.content.strip(), [] # Step 3: Technical or follow-up print("is_followup", is_followup) print("last_context", self.last_context) if is_followup and self.last_context: print("šŸ”„ Follow-up query → reusing previous context.") retrieved_context = self.last_context else: print("šŸ“„ New technical query → retrieving from Qdrant.") retrieved_context = self.retrieve(query) last_context = [] for i, doc in enumerate(retrieved_context): last_context.append(doc['id']) print(f" - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})") if(len(last_context)>0): self.last_context = last_context # save for future follow-ups if not retrieved_context: print("šŸ’¬ No relevant context found.") return "I could not find any relevant information.", [] print(f"āœ… Using {len(retrieved_context)} documents as context.") answer = self.generate_answer(query, retrieved_context, history, is_followup) last_context = self.last_context print(f"\nšŸ¤– Answer: {answer}") return (answer, last_context)