|
|
import json
from typing import Dict, List

from bson import ObjectId

from embedding_model_instance import (
    embedding_dim_large,
    embedding_dim_m3,
    embedding_model_large,
    embedding_model_m3,
    reranker,
)
from llm import gemini, groq
from mongo_instance import db
from qdrant_instance import qdrant_large, qdrant_m3
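
# ErrorBot: RAG chatbot over problem reports (PR), fault analyses (FA), and
# corrections (CR). Queries are embedded with two models (BGE-M3 and
# BGE-Large), searched against parallel Qdrant stores, fused via min-max
# normalization, reranked with a cross-encoder, expanded with linked MongoDB
# documents, and finally answered by a Gemini or Groq LLM.
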
def build_content(doc: dict, entity_type: str) -> str:
    """Convert a MongoDB document into natural text for embeddings."""
    parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"]
    for k, v in doc.items():
        if k == "_id":
            continue
        if isinstance(v, list):
            parts.append(f"{k}: {', '.join(map(str, v))}")
        elif isinstance(v, dict):
            nested = "; ".join(f"{nk}: {nv}" for nk, nv in v.items() if nv)
            parts.append(f"{k}: {nested}")
        elif v:
            parts.append(f"{k}: {v}")
    return "\n".join(parts)
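
# Example (illustrative): for a document like
#   {"_id": "665f...", "id": "PR-42", "title": "Login crash", "tags": ["auth", "ui"]}
# build_content(doc, "ProblemReport") returns:
#   ProblemReport ID: PR-42
#   id: PR-42
#   title: Login crash
#   tags: auth, ui
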
class ErrorBot:
    """Chatbot using RAG (Qdrant retrieval + Gemini/Groq generation)."""
    def __init__(self, llm_model_name: str, llm_provider: str = "gemini", last_context: list = None):
        print("🚀 Initializing ErrorBot...")
        self.last_context = last_context

        # Preconfigured ensemble embedding models (BGE-M3 and BGE-Large).
        self.embedding_model_m3 = embedding_model_m3
        self.embedding_dim_m3 = embedding_dim_m3
        self.embedding_model_large = embedding_model_large
        self.embedding_dim_large = embedding_dim_large

        # MongoDB handle and one Qdrant client per embedding model; both
        # stores use the same logical collection name.
        self.db = db
        self.qdrant_m3 = qdrant_m3
        self.qdrant_large = qdrant_large
        self.collection_name = "technical_errors"

        # Select the LLM backend.
        self.llm_provider = llm_provider.lower()
        self.llm_model_name = llm_model_name
        if self.llm_provider == "gemini":
            self.llm = gemini
        elif self.llm_provider == "groq":
            self.llm = groq
        else:
            raise ValueError(f"Unsupported LLM provider: {self.llm_provider}")

        # Cross-encoder used to rerank merged retrieval candidates.
        self.reranker = reranker
        print(f"✅ ErrorBot ready with {self.llm_provider.upper()}")
    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.5, rerank: bool = True):
        """Retrieve documents using an ensemble of BGE-M3 and BGE-Large models."""
        print(f"\n🔍 Retrieving context using ensemble (M3 + BGE-Large) for query: {query}")

        # Embed the query once per model.
        emb_m3 = self.embedding_model_m3.encode(query).tolist()
        emb_large = self.embedding_model_large.encode(query).tolist()

        # Over-fetch (3x top_k) from each store to give the fusion step slack.
        hits_m3 = self.qdrant_m3.query_points(
            collection_name=self.collection_name,
            query=emb_m3,
            limit=top_k * 3,
            with_payload=True,
            score_threshold=score_threshold,
        ).points

        hits_large = self.qdrant_large.query_points(
            collection_name=self.collection_name,
            query=emb_large,
            limit=top_k * 3,
            with_payload=True,
            score_threshold=score_threshold,
        ).points

        # Flatten the payloads, tagging each hit with the model it came from.
        all_hits = []
        for source, hits in (("M3", hits_m3), ("LARGE", hits_large)):
            for hit in hits:
                payload = hit.payload
                all_hits.append({
                    "id": payload.get("id"),
                    "entity_type": payload.get("entity_type", ""),
                    "content": payload.get("content", ""),
                    "score": hit.score,
                    "source": source,
                })

        if not all_hits:
            print("⚠️ No hits from either model.")
            return []

        # Min-max normalize scores so the two models' scales are comparable.
        scores = [h["score"] for h in all_hits]
        min_s, max_s = min(scores), max(scores)
        for h in all_hits:
            h["score_norm"] = (h["score"] - min_s) / (max_s - min_s + 1e-6)

        # Merge duplicates by document id, averaging their normalized scores.
        merged = {}
        for h in all_hits:
            _id = h["id"]
            if _id not in merged:
                merged[_id] = h
            else:
                merged[_id]["score_norm"] = (merged[_id]["score_norm"] + h["score_norm"]) / 2

        combined_hits = sorted(merged.values(), key=lambda x: x["score_norm"], reverse=True)[:top_k * 2]

        # Optionally rerank the fused candidates with the cross-encoder.
        if rerank and combined_hits:
            pairs = [(query, h["content"]) for h in combined_hits]
            rerank_scores = self.reranker.predict(pairs)
            for h, s in zip(combined_hits, rerank_scores):
                h["rerank_score"] = float(s)
            combined_hits.sort(key=lambda x: x["rerank_score"], reverse=True)

        print(f"✅ Ensemble retrieved {len(combined_hits)} candidates.")
        return combined_hits[:top_k]
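
    # Example (illustrative values): each fused hit is a dict like
    #   {"id": "PR-42", "entity_type": "ProblemReport", "content": "...",
    #    "score": 0.83, "source": "M3", "score_norm": 0.97, "rerank_score": 4.1}
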
    def generate_answer(self, query: str, context: List[Dict], history: list = None, is_followup: bool = False):
        """Generate an answer with the LLM, guiding it to identify which context is useful."""
        context_str = ""

        # For follow-ups with remembered ids, rebuild the context by pulling the
        # previously used documents straight from MongoDB across all three
        # collections ($unionWith merges the per-collection matches). Matching
        # uses the domain "id" field, the same key fetch_problem_report_with_links
        # queries on.
        if is_followup and self.last_context:
            pipeline = [
                {"$match": {"id": {"$in": self.last_context}}},
                {"$addFields": {"entity_type": "ProblemReport"}},
                {
                    "$unionWith": {
                        "coll": "faultanalysis",
                        "pipeline": [
                            {"$match": {"id": {"$in": self.last_context}}},
                            {"$addFields": {"entity_type": "FaultAnalysis"}},
                        ],
                    }
                },
                {
                    "$unionWith": {
                        "coll": "corrections",
                        "pipeline": [
                            {"$match": {"id": {"$in": self.last_context}}},
                            {"$addFields": {"entity_type": "Correction"}},
                        ],
                    }
                },
            ]
            context_docs = list(db.problemReports.aggregate(pipeline))
            context_str = "\n---\n".join(
                f"{c['entity_type']} (ID: {c['_id']}):\n{json.dumps(c, default=str)}"
                for c in context_docs
            )
            print("Context string in follow-up:\n", context_str)
        else:
            context_str = "\n---\n".join(
                f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context
            )

        system_prompt = f"""
        You are a versatile assistant. A user may ask questions about:
        - Problem Reports (PR), Fault Analyses (FA), and Corrections (CR).
        - Programming, algorithms, and code examples.
        - Non-technical or general everyday topics.

        Your tasks are:
        1. If the question is about PR, FA, or CR → Identify which information is relevant and explain it clearly in simple, actionable language (summarize, don't just repeat).
        2. If the question is about programming or algorithms → Provide a correct, clear, and well-structured code example in the requested language, with an explanation.
        3. If the question is non-technical/general → Respond politely, clearly, and helpfully in a conversational style.
        4. Always keep answers detailed and easy to understand.

        ### User Question:
        {query}

        ### Context:
        {context_str}

        Provide a concise, step-by-step explanation if applicable.
        """

        convo = []
        if history:
            for msg in history:
                convo.append({
                    "role": "user" if msg["role"] == "user" else "assistant",
                    "content": msg["content"],
                })
        convo.append({"role": "user", "content": query})

        if self.llm_provider == "gemini":
            # Gemini takes a single prompt string, so flatten the conversation.
            convo_str = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in convo)
            prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:"
            response = self.llm.generate_content(prompt)
            return response.text.strip()
        elif self.llm_provider == "groq":
            # Groq uses the OpenAI-style chat format with an explicit system message.
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=[{"role": "system", "content": system_prompt}] + convo,
            )
            return completion.choices[0].message.content.strip()
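
    # Note: `history` entries are dicts like {"role": "user", "content": "..."};
    # any other role is normalized to "assistant". On follow-ups the `context`
    # argument is bypassed and documents are re-fetched from MongoDB using the
    # ids stored in `self.last_context`.
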
    def fetch_problem_report_with_links(self, pr_id: str):
        """Fetch a problem report plus its linked corrections and fault analyses."""
        pr_doc = db["problemReports"].find_one({"id": pr_id})
        if not pr_doc:
            return None, [], [], [], []

        if "_id" in pr_doc and isinstance(pr_doc["_id"], ObjectId):
            pr_doc["_id"] = str(pr_doc["_id"])

        cr_ids = pr_doc.get("correctionIds", [])
        fa_ids = pr_doc.get("faultAnalysisId", [])

        # Normalize link fields: they may be a single id string, a list, or missing.
        if isinstance(cr_ids, str):
            cr_ids = [cr_ids]
        elif cr_ids is None:
            cr_ids = []

        if isinstance(fa_ids, str):
            fa_ids = [fa_ids]
        elif fa_ids is None:
            fa_ids = []

        cr_docs = list(db["corrections"].find({"id": {"$in": cr_ids}})) if cr_ids else []
        for doc in cr_docs:
            if "_id" in doc and isinstance(doc["_id"], ObjectId):
                doc["_id"] = str(doc["_id"])

        fa_docs = list(db["faultanalysis"].find({"id": {"$in": fa_ids}})) if fa_ids else []
        for doc in fa_docs:
            if "_id" in doc and isinstance(doc["_id"], ObjectId):
                doc["_id"] = str(doc["_id"])

        print(pr_doc)
        return pr_doc, cr_ids, fa_ids, cr_docs, fa_docs
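
    # Example (illustrative ids):
    #   pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = bot.fetch_problem_report_with_links("PR-42")
    #   cr_ids -> ["CR-7"], fa_ids -> ["FA-3"], with the matching full documents
    #   in cr_docs / fa_docs (ObjectIds stringified).
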
    def is_technical_query(self, query: str) -> bool:
        """Classify the query as TECHNICAL or NON-TECHNICAL."""
        classification_prompt = f"""
        You are a classifier. Determine if the following query is TECHNICAL
        (related to software, debugging, errors, troubleshooting, fault analysis,
        corrections, technical problem reports) or NON-TECHNICAL
        (general questions, greetings, chit-chat, unrelated topics).

        Query: "{query}"

        Respond with exactly one word: "TECHNICAL" or "NON-TECHNICAL".
        """

        if self.llm_provider == "gemini":
            response = self.llm.generate_content(classification_prompt)
            result = response.text.strip().upper()
        elif self.llm_provider == "groq":
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=[{"role": "system", "content": classification_prompt}],
            )
            result = completion.choices[0].message.content.strip().upper()

        return result == "TECHNICAL"
    def is_followup_query(self, query: str, history: list = None) -> bool:
        """Detect whether the query is a follow-up based on conversation history."""
        if not history:
            return False

        # Show the classifier only the last three messages of the history.
        recent = [msg["content"] for msg in history[-3:]]
        classification_prompt = f"""
        You are a classifier. Determine if the following user query
        is a FOLLOW-UP (depends on the previous conversation)
        or a NEW QUERY (can be answered independently).

        Previous conversation:
        {recent}

        Current query: "{query}"

        Respond with exactly one word: "FOLLOW-UP" or "NEW".
        """

        if self.llm_provider == "gemini":
            response = self.llm.generate_content(classification_prompt)
            result = response.text.strip().upper()
        elif self.llm_provider == "groq":
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=[{"role": "system", "content": classification_prompt}],
            )
            result = completion.choices[0].message.content.strip().upper()

        print("Follow up:", result)
        return result == "FOLLOW-UP"
    def _answer_without_context(self, query: str) -> str:
        """Answer directly with the LLM, skipping retrieval entirely."""
        system_prompt = "You are a helpful assistant. Answer clearly and concisely."
        convo = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": query}]

        if self.llm_provider == "gemini":
            convo_str = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in convo)
            response = self.llm.generate_content(convo_str)
            return response.text.strip()
        elif self.llm_provider == "groq":
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=convo,
            )
            return completion.choices[0].message.content.strip()

    def ask(self, query: str, history: list = None):
        print(f"\n❓ Query: {query}")

        is_technical = self.is_technical_query(query)
        is_followup = self.is_followup_query(query, history)
        print("is_followup", is_followup)
        print("is_technical", is_technical)

        # Non-technical queries never need retrieval, follow-up or not.
        if not is_technical:
            print("⚠️ Non-technical query → skipping Qdrant.")
            return self._answer_without_context(query), []

        if is_followup and self.last_context:
            print("🔄 Follow-up query → reusing previous context.")
            context_docs = self.last_context
        elif is_followup:
            print("🔄 Follow-up query → without previous context.")
            context_docs = []
        else:
            print("📥 New technical query → retrieving from Qdrant.")
            retrieved_context = self.retrieve(query)
            if not retrieved_context:
                print("💬 No relevant context found.")
                return "I could not find any relevant information.", []

            last_context = [doc["id"] for doc in retrieved_context]
            for i, doc in enumerate(retrieved_context):
                print(f" - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})")

            # Expand the top hit into its full PR/FA/CR document graph.
            first_doc = retrieved_context[0]
            context_docs = []
            pr_docs_to_use = []

            if first_doc["entity_type"] == "ProblemReport":
                pr_id = first_doc["id"]
                print(f"📌 Using PR from context1: {pr_id}")
                pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id)
                pr_docs_to_use.append((pr_doc, cr_docs, fa_docs))
            elif first_doc["entity_type"] == "Correction":
                cr_id = first_doc["id"]
                print(f"📌 Using CR from context1: {cr_id}")
                cr_doc = self.db["corrections"].find_one({"id": cr_id})
                pr_ids = cr_doc.get("problemReportIds", []) if cr_doc else []
                if isinstance(pr_ids, str):
                    pr_ids = [pr_ids]
                for pr_id in pr_ids:
                    pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id)
                    pr_docs_to_use.append((pr_doc, cr_docs, fa_docs))
            elif first_doc["entity_type"] == "FaultAnalysis":
                fa_id = first_doc["id"]
                print(f"📌 Using FA from context1: {fa_id}")
                fa_doc = self.db["faultanalysis"].find_one({"id": fa_id})
                pr_ids = fa_doc.get("problemReportIds", []) if fa_doc else []
                if isinstance(pr_ids, str):
                    pr_ids = [pr_ids]
                for pr_id in pr_ids:
                    pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id)
                    pr_docs_to_use.append((pr_doc, cr_docs, fa_docs))

            # Flatten linked documents into LLM-ready context entries; linked
            # FA/CR docs get full weight, the parent PR slightly less.
            for pr_doc, cr_docs, fa_docs in pr_docs_to_use:
                for fa in fa_docs:
                    context_docs.append({
                        "entity_type": "FaultAnalysis",
                        "content": build_content(fa, "FaultAnalysis"),
                        "score": 1.0,
                    })
                for cr in cr_docs:
                    context_docs.append({
                        "entity_type": "Correction",
                        "content": build_content(cr, "Correction"),
                        "score": 1.0,
                    })
                if pr_doc:
                    context_docs.append({
                        "entity_type": "ProblemReport",
                        "content": build_content(pr_doc, "ProblemReport"),
                        "score": 0.9,
                    })

            print(f"✅ Total documents for LLM context: {len(context_docs)}")

            # Remember the retrieved ids so follow-ups can re-fetch these docs.
            if last_context:
                self.last_context = last_context

        answer = self.generate_answer(query, context_docs, history, is_followup)
        return answer, self.last_context
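

if __name__ == "__main__":
    # Minimal usage sketch. The model name and query below are illustrative
    # placeholders; this assumes the imported Qdrant, MongoDB, and LLM
    # instances are already configured.
    bot = ErrorBot(llm_model_name="llama-3.3-70b-versatile", llm_provider="groq")
    answer, last_context = bot.ask("What caused problem report PR-42?")
    print(answer)

    # Pass `last_context` back in so a later session can answer follow-ups
    # against the same documents:
    # bot2 = ErrorBot("llama-3.3-70b-versatile", "groq", last_context=last_context)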