import os import torch from qdrant_client import QdrantClient, models from sentence_transformers import SentenceTransformer, CrossEncoder from pymongo import MongoClient from typing import List, Dict import google.generativeai as genai from groq import Groq from embedding_model_instance import embedding_model_m3, embedding_dim_m3, embedding_model_large, embedding_dim_large, reranker from qdrant_instance import qdrant_m3, qdrant_large from llm import gemini, groq from mongo_instance import db import json from bson import ObjectId def build_content(doc: dict, entity_type: str) -> str: """Convert MongoDB document into natural text for embeddings.""" parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"] for k, v in doc.items(): if k in ["_id"]: # skip ObjectId continue if isinstance(v, list): parts.append(f"{k}: {', '.join(map(str, v))}") elif isinstance(v, dict): nested = "; ".join([f"{nk}: {nv}" for nk, nv in v.items() if nv]) parts.append(f"{k}: {nested}") else: if v: parts.append(f"{k}: {v}") return "\n".join(parts) class ErrorBot: """Chatbot using RAG (Qdrant + Gemini API).""" def __init__(self, llm_model_name: str, llm_provider: str = "gemini", last_context: list = None): print("๐Ÿš€ Initializing ErrorBot...") self.last_context = last_context #print("last_context", last_context) # --- Embedding model # self.device = "cuda" if torch.cuda.is_available() else "cpu" self.embedding_model_m3 = embedding_model_m3 self.embedding_dim_m3 = embedding_dim_m3 self.embedding_model_large = embedding_model_large self.embedding_dim_large = embedding_dim_large self.db = db # --- Qdrant client self.qdrant_m3 = qdrant_m3 self.qdrant_large = qdrant_large self.collection_name = "technical_errors" #self.collection_name = "json_ingestion" #self._setup_collection() # --- LLM setup self.llm_provider = llm_provider.lower() self.llm_model_name = llm_model_name if self.llm_provider == "gemini": self.llm = gemini elif self.llm_provider == "groq": self.llm = groq else: raise ValueError(f"Unsupported LLM provider: {self.llm_provider}") # --- Cross encoder reranker self.reranker = reranker print(f"โœ… ErrorBot ready with {self.llm_provider.upper()}") # def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.5, rerank: bool = True): # query_embedding = self.embedding_model.encode(query).tolist() # hits = self.qdrant.query_points( # collection_name=self.collection_name, # query=query_embedding, # #limit=top_k * 3 if rerank else top_k, # limit = 100, # with_payload=True, # score_threshold=score_threshold, # search_params=models.SearchParams(hnsw_ef=256), # ).points # candidates = [ # { # "id": hit.payload.get("id"), # # "id": hit.payload.get("raw", {}).get("id"), # "entity_type": hit.payload.get("entity_type", ""), # "content": hit.payload.get("content", ""), # "score": hit.score, # } # for hit in hits # ] # if rerank and candidates: # pairs = [(query, c["content"]) for c in candidates] # scores = self.reranker.predict(pairs) # for i, score in enumerate(scores): # candidates[i]["rerank_score"] = float(score) # candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True) # return candidates[:5] # ================================================== # ๐Ÿงฎ Dual Qdrant Ensemble Retrieval # ================================================== def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.5, rerank: bool = True): """Retrieve documents using ensemble of BGE-M3 and BGE-Large models.""" print(f"\n๐Ÿ” Retrieving context using ensemble (M3 + BGE-Large) for query: {query}") # 1๏ธโƒฃ Encode using both models emb_m3 = self.embedding_model_m3.encode(query).tolist() emb_large = self.embedding_model_large.encode(query).tolist() # 2๏ธโƒฃ Query both Qdrant clusters hits_m3 = self.qdrant_m3.query_points( collection_name=self.collection_name, query=emb_m3, limit=top_k * 3, with_payload=True, score_threshold=score_threshold, ).points hits_large = self.qdrant_large.query_points( collection_name=self.collection_name, query=emb_large, limit=top_k * 3, with_payload=True, score_threshold=score_threshold, ).points # 3๏ธโƒฃ Combine results โ€” average normalized scores all_hits = [] for hit in hits_m3 + hits_large: payload = hit.payload score = hit.score all_hits.append({ "id": payload.get("id"), "entity_type": payload.get("entity_type", ""), "content": payload.get("content", ""), "score": score, "source": "M3" if hit in hits_m3 else "LARGE" }) if not all_hits: print("โš ๏ธ No hits from either model.") return [] # Normalize scores between 0-1 (optional) scores = [h["score"] for h in all_hits] min_s, max_s = min(scores), max(scores) for h in all_hits: h["score_norm"] = (h["score"] - min_s) / (max_s - min_s + 1e-6) # Group by ID and average scores if duplicates exist merged = {} for h in all_hits: _id = h["id"] if _id not in merged: merged[_id] = h else: merged[_id]["score_norm"] = (merged[_id]["score_norm"] + h["score_norm"]) / 2 combined_hits = list(merged.values()) combined_hits = sorted(combined_hits, key=lambda x: x["score_norm"], reverse=True)[:top_k * 2] # 4๏ธโƒฃ (Optional) Rerank using cross encoder if rerank and combined_hits: pairs = [(query, h["content"]) for h in combined_hits] scores = self.reranker.predict(pairs) for i, s in enumerate(scores): combined_hits[i]["rerank_score"] = float(s) combined_hits = sorted(combined_hits, key=lambda x: x["rerank_score"], reverse=True) print(f"โœ… Ensemble retrieved {len(combined_hits)} candidates.") return combined_hits[:top_k] def generate_answer(self, query: str, context: List[Dict], history: list = None, is_followup: bool = False ): """ Generates an answer using the LLM, guiding it to identify which context is useful. """ context_str="" if(is_followup): pass # Aggregation pipeline # pipeline = [ # # Start with problemReports # {"$match": {"_id": {"$in": self.last_context}}}, # # Add faultAnalysis # {"$unionWith": { # "coll": "faultanalysis", # "pipeline": [{"$match": {"id": {"$in": self.last_context}}}] # }}, # # Add corrections # {"$unionWith": { # "coll": "corrections", # "pipeline": [{"$match": {"id": {"$in": self.last_context}}}] # }} # ] pipeline = [ # Start with problemReports { "$match": {"_id": {"$in": self.last_context}} }, { "$addFields": {"entity_type": "ProblemReport"} }, # Add faultAnalysis { "$unionWith": { "coll": "faultanalysis", "pipeline": [ {"$match": {"id": {"$in": self.last_context}}}, {"$addFields": {"entity_type": "FaultAnalysis"}} ] } }, # Add corrections { "$unionWith": { "coll": "corrections", "pipeline": [ {"$match": {"id": {"$in": self.last_context}}}, {"$addFields": {"entity_type": "Correction"}} ] } } ] # Run aggregation on problemReports context_docs = list(db.problemReports.aggregate(pipeline)) # Serialize full documents as text for LLM #print(context_docs) context_str = "\n---\n".join( [f"{c['entity_type']} (ID: {c['_id']}):\n{json.dumps(c, default=str)}" for c in context_docs] ) print("Context String in Follow Up:") #print(context_str) else: context_str = "\n---\n".join( [f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context] ) # --- System prompt # system_prompt = f""" # You are a technical assistant. You have access to Problem Reports (PR), Fault Analyses (FA), and Corrections (CR). # Use the provided context and conversation history to answer the question clearly and concisely. # If context is not relevant, say you do not have enough information. # ### Context # {context_str} # """ system_prompt = f""" You are a versatile assistant. A user may ask questions about: - Problem Reports (PR), Fault Analyses (FA), and Corrections (CR). - Programming, algorithms, and code examples. - Non-technical or general everyday topics. Your tasks are: 1. If the question is about PR, FA, or CR โ†’ Identify which information is relevant and explain clearly in simple, actionable language (summarize, donโ€™t just repeat). 2. If the question is about programming or algorithms โ†’ Provide a correct, clear, and well-structured code example in the requested language, with explanation. 3. If the question is non-technical/general โ†’ Respond politely, clearly, and helpfully in a conversational style. 4. Always keep answers and easy to understand and detailed. ### User Question: ### Context: {context_str} Provide a concise, step-by-step explanation if applicable. """ # --- Conversation history in list-of-dicts format convo = [] if history: for msg in history: convo.append({ "role": "user" if msg["role"] == "user" else "assistant", "content": msg["content"], }) convo.append({"role": "user", "content": query}) # --- Gemini flow if self.llm_provider == "gemini": convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:" response = self.llm.generate_content(prompt) return response.text.strip() # --- Groq flow elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=[{"role": "system", "content": system_prompt}] + convo ) return completion.choices[0].message.content.strip() def fetch_problem_report_with_links(self, pr_id: str): # --- Fetch Problem Report pr_doc = db["problemReports"].find_one({"id": pr_id}) #print("pr_id:", pr_id) #print("pr_doc:", pr_doc) if not pr_doc: return None, [], [], [], [] if "_id" in pr_doc and isinstance(pr_doc["_id"], ObjectId): pr_doc["_id"] = str(pr_doc["_id"]) # --- Extract linked IDs cr_ids = pr_doc.get("correctionIds", []) fa_ids = pr_doc.get("faultAnalysisId", []) # ensure both are lists if isinstance(cr_ids, str): cr_ids = [cr_ids] elif cr_ids is None: cr_ids = [] if isinstance(fa_ids, str): fa_ids = [fa_ids] elif fa_ids is None: fa_ids = [] # --- Fetch Correction Reports cr_docs = list(db["corrections"].find({"id": {"$in": cr_ids}})) if cr_ids else [] for doc in cr_docs: if "_id" in doc and isinstance(doc["_id"], ObjectId): doc["_id"] = str(doc["_id"]) # --- Fetch Fault Analysis Reports fa_docs = list(db["faultanalysis"].find({"id": {"$in": fa_ids}})) if fa_ids else [] for doc in fa_docs: if "_id" in doc and isinstance(doc["_id"], ObjectId): doc["_id"] = str(doc["_id"]) print(pr_doc) return pr_doc, cr_ids, fa_ids, cr_docs, fa_docs def is_technical_query(self, query: str) -> bool: """ Classify query as TECHNICAL or NON-TECHNICAL. """ classification_prompt = f""" You are a classifier. Determine if the following query is TECHNICAL (related to software, debugging, errors, troubleshooting, fault analysis, corrections, technical problem reports) or NON-TECHNICAL (general questions, greetings, chit-chat, unrelated topics). Query: "{query}" Respond with exactly one word: "TECHNICAL" or "NON-TECHNICAL". """ if self.llm_provider == "gemini": response = self.llm.generate_content(classification_prompt) result = response.text.strip().upper() elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=[{"role": "system", "content": classification_prompt}] ) result = completion.choices[0].message.content.strip().upper() return result == "TECHNICAL" def is_followup_query(self, query: str, history: list = None) -> bool: """ Detect if query is a follow-up based on conversation history. """ if not history: return False classification_prompt = f""" You are a classifier. Determine if the following user query is a FOLLOW-UP (depends on the previous conversation) or a NEW QUERY (can be answered independently). Previous conversation: { [msg['content'] for msg in history][-3:] } Current query: "{query}" Respond with exactly one word: "FOLLOW-UP" or "NEW". """ if self.llm_provider == "gemini": response = self.llm.generate_content(classification_prompt) result = response.text.strip().upper() elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=[{"role": "system", "content": classification_prompt}] ) result = completion.choices[0].message.content.strip().upper() print("Follow up: ", result) return result == "FOLLOW-UP" def ask(self, query: str, history: list = None): print(f"\nโ“ Query: {query}") # Step 1: Classify is_technical = self.is_technical_query(query) is_followup = self.is_followup_query(query, history) # Step 2: Non-technical standalone # Step 3: Technical or follow-up print("is_followup", is_followup) #print("last_context", self.last_context) print("is_technical", is_technical) #if not is_technical: if not is_technical and not is_followup: print("โš ๏ธ Non-technical standalone query โ†’ skipping Qdrant.") system_prompt = "You are a helpful assistant. Answer clearly and concisely." convo = [{"role": "system", "content": system_prompt}, {"role": "user", "content": query}] if self.llm_provider == "gemini": convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) response = self.llm.generate_content(convo_str) return response.text.strip(), [] elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=convo ) return completion.choices[0].message.content.strip(), [] elif is_followup and self.last_context: if not is_technical: print("โš ๏ธ Non-technical followup โ†’ skipping Qdrant.") system_prompt = "You are a helpful assistant. Answer clearly and concisely." convo = [{"role": "system", "content": system_prompt}, {"role": "user", "content": query}] if self.llm_provider == "gemini": convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) response = self.llm.generate_content(convo_str) return response.text.strip(), [] elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=convo ) return completion.choices[0].message.content.strip(), [] else: print("๐Ÿ”„ Follow-up query โ†’ reusing previous context.") retrieved_context = self.last_context context_docs = retrieved_context elif is_followup and not self.last_context: if not is_technical: print("โš ๏ธ Non-technical followup โ†’ skipping Qdrant.") system_prompt = "You are a helpful assistant. Answer clearly and concisely." convo = [{"role": "system", "content": system_prompt}, {"role": "user", "content": query}] if self.llm_provider == "gemini": convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) response = self.llm.generate_content(convo_str) return response.text.strip(), [] elif self.llm_provider == "groq": completion = self.llm.chat.completions.create( model=self.llm_model_name, messages=convo ) return completion.choices[0].message.content.strip(), [] else: print("๐Ÿ”„ Follow-up query โ†’ without previous context.") #retrieved_context = self.last_context context_docs = [] else: print("๐Ÿ“ฅ New technical query โ†’ retrieving from Qdrant.") retrieved_context = self.retrieve(query) last_context = [] for i, doc in enumerate(retrieved_context): last_context.append(doc['id']) print(f" - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})") first_doc = retrieved_context[0] context_docs = [] # Step 2: Determine starting point based on entity type pr_docs_to_use = [] if first_doc["entity_type"] == "ProblemReport": pr_id = first_doc["id"] print(f"๐Ÿ“Œ Using PR from context1: {pr_id}") pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id) pr_docs_to_use.append((pr_doc, cr_docs, fa_docs)) elif first_doc["entity_type"] == "Correction": cr_id = first_doc["id"] print(f"๐Ÿ“Œ Using CR from context1: {cr_id}") cr_doc = self.db["corrections"].find_one({"id": cr_id}) pr_ids = cr_doc.get("problemReportIds", []) if cr_doc else [] if isinstance(pr_ids, str): pr_ids = [pr_ids] for pr_id in pr_ids: pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id) pr_docs_to_use.append((pr_doc, cr_docs, fa_docs)) elif first_doc["entity_type"] == "FaultAnalysis": fa_id = first_doc["id"] print(f"๐Ÿ“Œ Using FA from context1: {fa_id}") fa_doc = self.db["faultanalysis"].find_one({"id": fa_id}) pr_ids = fa_doc.get("problemReportIds", []) if fa_doc else [] if isinstance(pr_ids, str): pr_ids = [pr_ids] for pr_id in pr_ids: pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id) pr_docs_to_use.append((pr_doc, cr_docs, fa_docs)) # Step 3: Build context documents for LLM, prioritize CR and FA for pr_doc, cr_docs, fa_docs in pr_docs_to_use: # Include FA first (analysis of problem) for fa in fa_docs: context_docs.append({ "entity_type": "FaultAnalysis", "content": build_content(fa, "FaultAnalysis"), "score": 1.0 }) # Include CR next (solutions/corrections) for cr in cr_docs: context_docs.append({ "entity_type": "Correction", "content": build_content(cr, "Correction"), "score": 1.0 }) # PR last (problem description) if pr_doc: context_docs.append({ "entity_type": "ProblemReport", "content": build_content(pr_doc, "ProblemReport"), "score": 0.9 }) print(f"โœ… Total documents for LLM context: {len(context_docs)}") if(len(last_context)>0): self.last_context = context_docs # save for future follow-ups if not retrieved_context: print("๐Ÿ’ฌ No relevant context found.") return "I could not find any relevant information.", [] #print(f"โœ… Using {len(retrieved_context)} documents as context.") #answer = self.generate_answer(query, retrieved_context, history, is_followup) answer = self.generate_answer(query, context_docs, history, is_followup) last_context = self.last_context #print(f"\n๐Ÿค– Answer: {answer}") return (answer, last_context)