|
|
import os |
|
|
import torch |
|
|
from qdrant_client import QdrantClient, models |
|
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
|
from pymongo import MongoClient |
|
|
from bson import ObjectId |
|
|
from typing import List, Dict |
|
|
import google.generativeai as genai |
|
|
from groq import Groq |
|
|
|
|
|
from embedding_model_instance import embedding_model, embedding_dim, reranker |
|
|
from qdrant_instance import qdrant |
|
|
from llm import gemini, groq |
|
|
from mongo_instance import db |
|
|
import json |
|
|
from bson import ObjectId |
|
|
|
|
|
def build_content(doc: dict, entity_type: str) -> str: |
|
|
"""Convert MongoDB document into natural text for embeddings.""" |
|
|
parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"] |
|
|
for k, v in doc.items(): |
|
|
if k in ["_id"]: |
|
|
continue |
|
|
if isinstance(v, list): |
|
|
parts.append(f"{k}: {', '.join(map(str, v))}") |
|
|
elif isinstance(v, dict): |
|
|
nested = "; ".join([f"{nk}: {nv}" for nk, nv in v.items() if nv]) |
|
|
parts.append(f"{k}: {nested}") |
|
|
else: |
|
|
if v: |
|
|
parts.append(f"{k}: {v}") |
|
|
return "\n".join(parts) |
|
|
|
|
|
|
|
|
class ErrorBot:
    """Chatbot using RAG (Qdrant + Gemini API).

    Answers questions about Problem Reports (PR), Fault Analyses (FA) and
    Corrections (CR) by retrieving from Qdrant and generating with either
    Gemini or Groq.
    """

    def __init__(self, embedding_model_name: str, llm_model_name: str,
                 google_api_key: str = None, groq_api_key: str = None,
                 llm_provider: str = "gemini", last_context: list = None):
        """Wire the bot to the shared singletons and select an LLM backend.

        Args:
            embedding_model_name: Accepted for interface compatibility; the
                preloaded ``embedding_model`` singleton is used regardless.
            llm_model_name: Model identifier passed to the Groq chat API.
            google_api_key: Accepted for interface compatibility; the Gemini
                client is a preconfigured singleton. TODO confirm unused.
            groq_api_key: Accepted for interface compatibility; the Groq
                client is a preconfigured singleton. TODO confirm unused.
            llm_provider: "gemini" or "groq" (case-insensitive).
            last_context: Document ids used for the previous answer, enabling
                follow-up questions to reuse that context.

        Raises:
            ValueError: If ``llm_provider`` is not supported.
        """
        # fix: mangled/mojibake emoji log strings replaced with plain ASCII.
        print("Initializing ErrorBot...")
        self.last_context = last_context
        print("last_context", last_context)

        # Shared, preloaded singletons (embedding_model_instance,
        # mongo_instance, qdrant_instance modules).
        self.embedding_model = embedding_model
        self.embedding_dim = embedding_dim
        self.db = db
        self.qdrant = qdrant
        self.collection_name = "technical_errors"

        self.llm_provider = llm_provider.lower()
        self.llm_model_name = llm_model_name

        if self.llm_provider == "gemini":
            self.llm = gemini
        elif self.llm_provider == "groq":
            self.llm = groq
        else:
            raise ValueError(f"Unsupported LLM provider: {self.llm_provider}")

        # Cross-encoder used to rerank Qdrant candidates in retrieve().
        self.reranker = reranker
        print(f"ErrorBot ready with {self.llm_provider.upper()}")
|
def _setup_collection(self): |
|
|
if not self.qdrant.collection_exists(self.collection_name): |
|
|
self.qdrant.create_collection( |
|
|
collection_name=self.collection_name, |
|
|
vectors_config=models.VectorParams( |
|
|
size=self.embedding_dim, |
|
|
distance=models.Distance.COSINE, |
|
|
), |
|
|
) |
|
|
|
|
|
def ingest_from_mongodb(self, mongo_uri: str, db_name: str, batch_size: int = 32): |
|
|
client = MongoClient(mongo_uri) |
|
|
db = client[db_name] |
|
|
|
|
|
collections = { |
|
|
"ProblemReport": db["problemReports"], |
|
|
"FaultAnalysis": db["faultanalysis"], |
|
|
"Correction": db["corrections"], |
|
|
} |
|
|
|
|
|
docs = [] |
|
|
for entity_type, coll in collections.items(): |
|
|
for doc in coll.find(): |
|
|
if "_id" in doc and isinstance(doc["_id"], ObjectId): |
|
|
doc["_id"] = str(doc["_id"]) |
|
|
docs.append({"entity_type": entity_type, "data": doc}) |
|
|
|
|
|
contents = [build_content(d["data"], d["entity_type"]) for d in docs] |
|
|
|
|
|
all_embeddings = [] |
|
|
for i in range(0, len(contents), batch_size): |
|
|
batch_contents = contents[i:i + batch_size] |
|
|
embeddings = self.embedding_model.encode(batch_contents, show_progress_bar=True).tolist() |
|
|
all_embeddings.extend(embeddings) |
|
|
|
|
|
self.qdrant.upsert( |
|
|
collection_name=self.collection_name, |
|
|
points=[ |
|
|
models.PointStruct( |
|
|
id=i, |
|
|
vector=emb, |
|
|
payload={ |
|
|
"id": d["data"].get("id", str(d["data"].get("_id", i))), |
|
|
"entity_type": d["entity_type"], |
|
|
"raw": d["data"], |
|
|
"content": c, |
|
|
}, |
|
|
) |
|
|
for i, (d, emb, c) in enumerate(zip(docs, all_embeddings, contents)) |
|
|
], |
|
|
wait=True, |
|
|
) |
|
|
print(f"β
Ingested {len(docs)} documents into '{self.collection_name}'") |
|
|
|
|
|
def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.3, rerank: bool = True): |
|
|
query_embedding = self.embedding_model.encode(query).tolist() |
|
|
hits = self.qdrant.query_points( |
|
|
collection_name=self.collection_name, |
|
|
query=query_embedding, |
|
|
limit=top_k * 3 if rerank else top_k, |
|
|
with_payload=True, |
|
|
score_threshold=score_threshold, |
|
|
).points |
|
|
|
|
|
candidates = [ |
|
|
{ |
|
|
"id": hit.payload.get("id"), |
|
|
"entity_type": hit.payload.get("entity_type", ""), |
|
|
"content": hit.payload.get("content", ""), |
|
|
"score": hit.score, |
|
|
} |
|
|
for hit in hits |
|
|
] |
|
|
|
|
|
if rerank and candidates: |
|
|
pairs = [(query, c["content"]) for c in candidates] |
|
|
scores = self.reranker.predict(pairs) |
|
|
for i, score in enumerate(scores): |
|
|
candidates[i]["rerank_score"] = float(score) |
|
|
candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True) |
|
|
|
|
|
return candidates[:top_k] |
|
|
|
|
|
def generate_answer(self, query: str, context: List[Dict], history: list = None, is_followup: bool = False ): |
|
|
""" |
|
|
Generates an answer using the LLM, guiding it to identify which context is useful. |
|
|
""" |
|
|
context_str="" |
|
|
|
|
|
if(is_followup): |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline = [ |
|
|
|
|
|
{ |
|
|
"$match": {"_id": {"$in": self.last_context}} |
|
|
}, |
|
|
{ |
|
|
"$addFields": {"entity_type": "ProblemReport"} |
|
|
}, |
|
|
|
|
|
|
|
|
{ |
|
|
"$unionWith": { |
|
|
"coll": "faultanalysis", |
|
|
"pipeline": [ |
|
|
{"$match": {"id": {"$in": self.last_context}}}, |
|
|
{"$addFields": {"entity_type": "FaultAnalysis"}} |
|
|
] |
|
|
} |
|
|
}, |
|
|
|
|
|
|
|
|
{ |
|
|
"$unionWith": { |
|
|
"coll": "corrections", |
|
|
"pipeline": [ |
|
|
{"$match": {"id": {"$in": self.last_context}}}, |
|
|
{"$addFields": {"entity_type": "Correction"}} |
|
|
] |
|
|
} |
|
|
} |
|
|
] |
|
|
|
|
|
|
|
|
context_docs = list(db.problemReports.aggregate(pipeline)) |
|
|
|
|
|
|
|
|
context_str = "\n---\n".join( |
|
|
[f"{c['entity_type']} (ID: {c['_id']}):\n{json.dumps(c, default=str)}" |
|
|
for c in context_docs] |
|
|
) |
|
|
print("Context String in Follow Up:") |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
context_str = "\n---\n".join( |
|
|
[f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are a technical assistant. A user may ask questions about Problem Reports (PR), Fault Analyses (FA), and Corrections (CR). |
|
|
Your task is to: |
|
|
1. Identify which information (PR, FA, CR) is relevant to answering the user's question. |
|
|
2. Explain the solution in simple, clear, actionable language. |
|
|
3. Do not just repeat the content; summarize and explain. |
|
|
|
|
|
### User Question: |
|
|
|
|
|
|
|
|
### Context: |
|
|
{context_str} |
|
|
|
|
|
Provide a concise, step-by-step explanation if applicable. |
|
|
""" |
|
|
|
|
|
|
|
|
convo = [] |
|
|
if history: |
|
|
for msg in history: |
|
|
convo.append({ |
|
|
"role": "user" if msg["role"] == "user" else "assistant", |
|
|
"content": msg["content"], |
|
|
}) |
|
|
|
|
|
convo.append({"role": "user", "content": query}) |
|
|
|
|
|
|
|
|
if self.llm_provider == "gemini": |
|
|
convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) |
|
|
prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:" |
|
|
response = self.llm.generate_content(prompt) |
|
|
return response.text.strip() |
|
|
|
|
|
|
|
|
elif self.llm_provider == "groq": |
|
|
completion = self.llm.chat.completions.create( |
|
|
model=self.llm_model_name, |
|
|
messages=[{"role": "system", "content": system_prompt}] + convo |
|
|
) |
|
|
return completion.choices[0].message.content.strip() |
|
|
|
|
|
def fetch_problem_report_with_links(self, pr_id: str): |
|
|
|
|
|
|
|
|
pr_doc = db["problemReports"].find_one({"id": pr_id}) |
|
|
if not pr_doc: |
|
|
return None, [], [], [], [] |
|
|
|
|
|
if "_id" in pr_doc and isinstance(pr_doc["_id"], ObjectId): |
|
|
pr_doc["_id"] = str(pr_doc["_id"]) |
|
|
|
|
|
|
|
|
cr_ids = pr_doc.get("correctionIds", []) |
|
|
fa_ids = pr_doc.get("faultAnalysisId", []) |
|
|
|
|
|
|
|
|
if isinstance(cr_ids, str): |
|
|
cr_ids = [cr_ids] |
|
|
elif cr_ids is None: |
|
|
cr_ids = [] |
|
|
|
|
|
if isinstance(fa_ids, str): |
|
|
fa_ids = [fa_ids] |
|
|
elif fa_ids is None: |
|
|
fa_ids = [] |
|
|
|
|
|
|
|
|
cr_docs = list(db["corrections"].find({"id": {"$in": cr_ids}})) if cr_ids else [] |
|
|
for doc in cr_docs: |
|
|
if "_id" in doc and isinstance(doc["_id"], ObjectId): |
|
|
doc["_id"] = str(doc["_id"]) |
|
|
|
|
|
|
|
|
fa_docs = list(db["faultanalysis"].find({"id": {"$in": fa_ids}})) if fa_ids else [] |
|
|
for doc in fa_docs: |
|
|
if "_id" in doc and isinstance(doc["_id"], ObjectId): |
|
|
doc["_id"] = str(doc["_id"]) |
|
|
|
|
|
return pr_doc, cr_ids, fa_ids, cr_docs, fa_docs |
|
|
|
|
|
|
|
|
def is_technical_query(self, query: str) -> bool: |
|
|
""" |
|
|
Classify query as TECHNICAL or NON-TECHNICAL. |
|
|
""" |
|
|
classification_prompt = f""" |
|
|
You are a classifier. Determine if the following query is TECHNICAL |
|
|
(related to software, debugging, errors, troubleshooting, fault analysis, |
|
|
corrections, technical problem reports) or NON-TECHNICAL |
|
|
(general questions, greetings, chit-chat, unrelated topics). |
|
|
|
|
|
Query: "{query}" |
|
|
|
|
|
Respond with exactly one word: "TECHNICAL" or "NON-TECHNICAL". |
|
|
""" |
|
|
|
|
|
if self.llm_provider == "gemini": |
|
|
response = self.llm.generate_content(classification_prompt) |
|
|
result = response.text.strip().upper() |
|
|
|
|
|
elif self.llm_provider == "groq": |
|
|
completion = self.llm.chat.completions.create( |
|
|
model=self.llm_model_name, |
|
|
messages=[{"role": "system", "content": classification_prompt}] |
|
|
) |
|
|
result = completion.choices[0].message.content.strip().upper() |
|
|
|
|
|
return result == "TECHNICAL" |
|
|
|
|
|
def is_followup_query(self, query: str, history: list = None) -> bool: |
|
|
""" |
|
|
Detect if query is a follow-up based on conversation history. |
|
|
""" |
|
|
if not history: |
|
|
return False |
|
|
|
|
|
classification_prompt = f""" |
|
|
You are a classifier. Determine if the following user query |
|
|
is a FOLLOW-UP (depends on the previous conversation) |
|
|
or a NEW QUERY (can be answered independently). |
|
|
|
|
|
Previous conversation: |
|
|
{ [msg['content'] for msg in history][-3:] } |
|
|
|
|
|
Current query: "{query}" |
|
|
|
|
|
Respond with exactly one word: "FOLLOW-UP" or "NEW". |
|
|
""" |
|
|
|
|
|
if self.llm_provider == "gemini": |
|
|
response = self.llm.generate_content(classification_prompt) |
|
|
result = response.text.strip().upper() |
|
|
|
|
|
elif self.llm_provider == "groq": |
|
|
completion = self.llm.chat.completions.create( |
|
|
model=self.llm_model_name, |
|
|
messages=[{"role": "system", "content": classification_prompt}] |
|
|
) |
|
|
result = completion.choices[0].message.content.strip().upper() |
|
|
print("Follow up: ", result) |
|
|
return result == "FOLLOW-UP" |
|
|
|
|
|
def ask(self, query: str, history: list = None): |
|
|
print(f"\nβ Query: {query}") |
|
|
|
|
|
|
|
|
is_technical = self.is_technical_query(query) |
|
|
is_followup = self.is_followup_query(query, history) |
|
|
|
|
|
|
|
|
|
|
|
if not is_technical and not is_followup: |
|
|
print("β οΈ Non-technical standalone query β skipping Qdrant.") |
|
|
system_prompt = "You are a helpful assistant. Answer clearly and concisely." |
|
|
convo = [{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": query}] |
|
|
|
|
|
if self.llm_provider == "gemini": |
|
|
convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) |
|
|
response = self.llm.generate_content(convo_str) |
|
|
return response.text.strip(), [] |
|
|
|
|
|
elif self.llm_provider == "groq": |
|
|
completion = self.llm.chat.completions.create( |
|
|
model=self.llm_model_name, |
|
|
messages=convo |
|
|
) |
|
|
return completion.choices[0].message.content.strip(), [] |
|
|
|
|
|
|
|
|
print("is_followup", is_followup) |
|
|
print("last_context", self.last_context) |
|
|
print("is_technical", is_technical) |
|
|
|
|
|
if is_followup and self.last_context: |
|
|
if not is_technical: |
|
|
print("β οΈ Non-technical followup β skipping Qdrant.") |
|
|
system_prompt = "You are a helpful assistant. Answer clearly and concisely." |
|
|
convo = [{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": query}] |
|
|
|
|
|
if self.llm_provider == "gemini": |
|
|
convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo]) |
|
|
response = self.llm.generate_content(convo_str) |
|
|
return response.text.strip(), [] |
|
|
|
|
|
elif self.llm_provider == "groq": |
|
|
completion = self.llm.chat.completions.create( |
|
|
model=self.llm_model_name, |
|
|
messages=convo |
|
|
) |
|
|
return completion.choices[0].message.content.strip(), [] |
|
|
else: |
|
|
print("π Follow-up query β reusing previous context.") |
|
|
retrieved_context = self.last_context |
|
|
context_docs = retrieved_context |
|
|
|
|
|
else: |
|
|
print("π₯ New technical query β retrieving from Qdrant.") |
|
|
retrieved_context = self.retrieve(query) |
|
|
last_context = [] |
|
|
for i, doc in enumerate(retrieved_context): |
|
|
last_context.append(doc['id']) |
|
|
print(f" - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})") |
|
|
|
|
|
first_doc = retrieved_context[0] |
|
|
context_docs = [] |
|
|
|
|
|
|
|
|
pr_docs_to_use = [] |
|
|
|
|
|
if first_doc["entity_type"] == "ProblemReport": |
|
|
pr_id = first_doc["id"] |
|
|
print(f"π Using PR from context1: {pr_id}") |
|
|
pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id) |
|
|
pr_docs_to_use.append((pr_doc, cr_docs, fa_docs)) |
|
|
|
|
|
elif first_doc["entity_type"] == "Correction": |
|
|
cr_id = first_doc["id"] |
|
|
print(f"π Using CR from context1: {cr_id}") |
|
|
cr_doc = self.db["corrections"].find_one({"id": cr_id}) |
|
|
pr_ids = cr_doc.get("problemReportIds", []) if cr_doc else [] |
|
|
|
|
|
if isinstance(pr_ids, str): |
|
|
pr_ids = [pr_ids] |
|
|
for pr_id in pr_ids: |
|
|
pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id) |
|
|
pr_docs_to_use.append((pr_doc, cr_docs, fa_docs)) |
|
|
|
|
|
elif first_doc["entity_type"] == "FaultAnalysis": |
|
|
fa_id = first_doc["id"] |
|
|
print(f"π Using FA from context1: {fa_id}") |
|
|
fa_doc = self.db["faultanalysis"].find_one({"id": fa_id}) |
|
|
pr_ids = fa_doc.get("problemReportIds", []) if fa_doc else [] |
|
|
|
|
|
if isinstance(pr_ids, str): |
|
|
pr_ids = [pr_ids] |
|
|
for pr_id in pr_ids: |
|
|
pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id) |
|
|
pr_docs_to_use.append((pr_doc, cr_docs, fa_docs)) |
|
|
|
|
|
|
|
|
for pr_doc, cr_docs, fa_docs in pr_docs_to_use: |
|
|
|
|
|
for fa in fa_docs: |
|
|
context_docs.append({ |
|
|
"entity_type": "FaultAnalysis", |
|
|
"content": build_content(fa, "FaultAnalysis"), |
|
|
"score": 1.0 |
|
|
}) |
|
|
|
|
|
for cr in cr_docs: |
|
|
context_docs.append({ |
|
|
"entity_type": "Correction", |
|
|
"content": build_content(cr, "Correction"), |
|
|
"score": 1.0 |
|
|
}) |
|
|
|
|
|
if pr_doc: |
|
|
context_docs.append({ |
|
|
"entity_type": "ProblemReport", |
|
|
"content": build_content(pr_doc, "ProblemReport"), |
|
|
"score": 0.9 |
|
|
}) |
|
|
|
|
|
print(f"β
Total documents for LLM context: {len(context_docs)}") |
|
|
|
|
|
if(len(last_context)>0): |
|
|
self.last_context = context_docs |
|
|
if not retrieved_context: |
|
|
print("π¬ No relevant context found.") |
|
|
return "I could not find any relevant information.", [] |
|
|
|
|
|
print(f"β
Using {len(retrieved_context)} documents as context.") |
|
|
|
|
|
|
|
|
answer = self.generate_answer(query, context_docs, history, is_followup) |
|
|
last_context = self.last_context |
|
|
print(f"\nπ€ Answer: {answer}") |
|
|
return (answer, last_context) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|