# LLM_Model/util.py
# ErrorBot: RAG utilities (dual-Qdrant ensemble retrieval, Gemini/Groq generation).
import json
from typing import Dict, List

from bson import ObjectId

from embedding_model_instance import (
    embedding_dim_large,
    embedding_dim_m3,
    embedding_model_large,
    embedding_model_m3,
    reranker,
)
from llm import gemini, groq
from mongo_instance import db
from qdrant_instance import qdrant_m3, qdrant_large


def build_content(doc: dict, entity_type: str) -> str:
"""Convert MongoDB document into natural text for embeddings."""
parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"]
for k, v in doc.items():
        if k == "_id":  # skip the raw ObjectId
continue
if isinstance(v, list):
parts.append(f"{k}: {', '.join(map(str, v))}")
elif isinstance(v, dict):
nested = "; ".join([f"{nk}: {nv}" for nk, nv in v.items() if nv])
parts.append(f"{k}: {nested}")
        elif v:
            parts.append(f"{k}: {v}")
return "\n".join(parts)
class ErrorBot:
    """RAG chatbot: dual-Qdrant ensemble retrieval with Gemini or Groq generation."""
    def __init__(self, llm_model_name: str, llm_provider: str = "gemini", last_context: list = None):
        print("🚀 Initializing ErrorBot...")
        self.last_context = last_context
        # --- Embedding models (shared instances)
        self.embedding_model_m3 = embedding_model_m3
        self.embedding_dim_m3 = embedding_dim_m3
        self.embedding_model_large = embedding_model_large
        self.embedding_dim_large = embedding_dim_large
        self.db = db
        # --- Qdrant clients (one per embedding model)
        self.qdrant_m3 = qdrant_m3
        self.qdrant_large = qdrant_large
        self.collection_name = "technical_errors"
# --- LLM setup
self.llm_provider = llm_provider.lower()
self.llm_model_name = llm_model_name
if self.llm_provider == "gemini":
self.llm = gemini
elif self.llm_provider == "groq":
self.llm = groq
else:
raise ValueError(f"Unsupported LLM provider: {self.llm_provider}")
# --- Cross encoder reranker
self.reranker = reranker
print(f"✅ ErrorBot ready with {self.llm_provider.upper()}")
# ==================================================
# 🧮 Dual Qdrant Ensemble Retrieval
# ==================================================
def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.5, rerank: bool = True):
"""Retrieve documents using ensemble of BGE-M3 and BGE-Large models."""
print(f"\n🔍 Retrieving context using ensemble (M3 + BGE-Large) for query: {query}")
# 1️⃣ Encode using both models
emb_m3 = self.embedding_model_m3.encode(query).tolist()
emb_large = self.embedding_model_large.encode(query).tolist()
# 2️⃣ Query both Qdrant clusters
hits_m3 = self.qdrant_m3.query_points(
collection_name=self.collection_name,
query=emb_m3,
limit=top_k * 3,
with_payload=True,
score_threshold=score_threshold,
).points
hits_large = self.qdrant_large.query_points(
collection_name=self.collection_name,
query=emb_large,
limit=top_k * 3,
with_payload=True,
score_threshold=score_threshold,
).points
        # 3️⃣ Combine results, tagging each hit with the model that produced it
        all_hits = []
        for source, hits in (("M3", hits_m3), ("LARGE", hits_large)):
            for hit in hits:
                payload = hit.payload
                all_hits.append({
                    "id": payload.get("id"),
                    "entity_type": payload.get("entity_type", ""),
                    "content": payload.get("content", ""),
                    "score": hit.score,
                    "source": source,
                })
if not all_hits:
print("⚠️ No hits from either model.")
return []
# Normalize scores between 0-1 (optional)
scores = [h["score"] for h in all_hits]
min_s, max_s = min(scores), max(scores)
for h in all_hits:
h["score_norm"] = (h["score"] - min_s) / (max_s - min_s + 1e-6)
# Group by ID and average scores if duplicates exist
merged = {}
for h in all_hits:
_id = h["id"]
if _id not in merged:
merged[_id] = h
else:
merged[_id]["score_norm"] = (merged[_id]["score_norm"] + h["score_norm"]) / 2
combined_hits = list(merged.values())
combined_hits = sorted(combined_hits, key=lambda x: x["score_norm"], reverse=True)[:top_k * 2]
# 4️⃣ (Optional) Rerank using cross encoder
if rerank and combined_hits:
pairs = [(query, h["content"]) for h in combined_hits]
scores = self.reranker.predict(pairs)
for i, s in enumerate(scores):
combined_hits[i]["rerank_score"] = float(s)
combined_hits = sorted(combined_hits, key=lambda x: x["rerank_score"], reverse=True)
print(f"✅ Ensemble retrieved {len(combined_hits)} candidates.")
return combined_hits[:top_k]
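
    # Sketch of a retrieve call (hypothetical query; requires live Qdrant
    # collections behind qdrant_m3 / qdrant_large):
    #   hits = bot.retrieve("disk full error after node restart", top_k=3)
    #   for h in hits:
    #       print(h["entity_type"], h["id"], h.get("rerank_score", h["score_norm"]))
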
    def generate_answer(self, query: str, context: List[Dict], history: list = None, is_followup: bool = False):
        """
        Generate an answer with the LLM, guiding it to identify which context is useful.
        """
        context_str = ""
        # Only take the follow-up path when there are stored ids to re-fetch;
        # otherwise the Mongo $in below would receive None and fail.
        if is_followup and self.last_context:
            pipeline = [
                # Start with problemReports; match on the domain-level "id"
                # field, since last_context stores those ids (the raw Mongo
                # "_id" is an ObjectId and would never match them)
                {
                    "$match": {"id": {"$in": self.last_context}}
                },
                {
                    "$addFields": {"entity_type": "ProblemReport"}
                },
# Add faultAnalysis
{
"$unionWith": {
"coll": "faultanalysis",
"pipeline": [
{"$match": {"id": {"$in": self.last_context}}},
{"$addFields": {"entity_type": "FaultAnalysis"}}
]
}
},
# Add corrections
{
"$unionWith": {
"coll": "corrections",
"pipeline": [
{"$match": {"id": {"$in": self.last_context}}},
{"$addFields": {"entity_type": "Correction"}}
]
}
}
]
            # Run the aggregation on problemReports and serialize the full
            # documents as text for the LLM
            context_docs = list(db.problemReports.aggregate(pipeline))
            context_str = "\n---\n".join(
                f"{c['entity_type']} (ID: {c['_id']}):\n{json.dumps(c, default=str)}"
                for c in context_docs
            )
else:
context_str = "\n---\n".join(
[f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context]
)
# --- System prompt
# system_prompt = f"""
# You are a technical assistant. You have access to Problem Reports (PR), Fault Analyses (FA), and Corrections (CR).
# Use the provided context and conversation history to answer the question clearly and concisely.
# If context is not relevant, say you do not have enough information.
# ### Context
# {context_str}
# """
        system_prompt = f"""
        You are a versatile assistant. A user may ask questions about:
        - Problem Reports (PR), Fault Analyses (FA), and Corrections (CR).
        - Programming, algorithms, and code examples.
        - Non-technical or general everyday topics.
        Your tasks are:
        1. If the question is about PR, FA, or CR → identify which information is relevant and explain it clearly in simple, actionable language (summarize, don't just repeat).
        2. If the question is about programming or algorithms → provide a correct, clear, and well-structured code example in the requested language, with an explanation.
        3. If the question is non-technical/general → respond politely, clearly, and helpfully in a conversational style.
        4. Always keep answers detailed and easy to understand.
        ### Context:
        {context_str}
        Provide a concise, step-by-step explanation if applicable.
        """
# --- Conversation history in list-of-dicts format
convo = []
if history:
for msg in history:
convo.append({
"role": "user" if msg["role"] == "user" else "assistant",
"content": msg["content"],
})
convo.append({"role": "user", "content": query})
# --- Gemini flow
if self.llm_provider == "gemini":
convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo])
prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:"
response = self.llm.generate_content(prompt)
return response.text.strip()
# --- Groq flow
elif self.llm_provider == "groq":
completion = self.llm.chat.completions.create(
model=self.llm_model_name,
messages=[{"role": "system", "content": system_prompt}] + convo
)
return completion.choices[0].message.content.strip()
    def fetch_problem_report_with_links(self, pr_id: str):
        """Fetch a Problem Report and its linked Corrections and Fault Analyses."""
        # --- Fetch Problem Report
        pr_doc = db["problemReports"].find_one({"id": pr_id})
if not pr_doc:
return None, [], [], [], []
if "_id" in pr_doc and isinstance(pr_doc["_id"], ObjectId):
pr_doc["_id"] = str(pr_doc["_id"])
# --- Extract linked IDs
cr_ids = pr_doc.get("correctionIds", [])
fa_ids = pr_doc.get("faultAnalysisId", [])
# ensure both are lists
if isinstance(cr_ids, str):
cr_ids = [cr_ids]
elif cr_ids is None:
cr_ids = []
if isinstance(fa_ids, str):
fa_ids = [fa_ids]
elif fa_ids is None:
fa_ids = []
# --- Fetch Correction Reports
cr_docs = list(db["corrections"].find({"id": {"$in": cr_ids}})) if cr_ids else []
for doc in cr_docs:
if "_id" in doc and isinstance(doc["_id"], ObjectId):
doc["_id"] = str(doc["_id"])
# --- Fetch Fault Analysis Reports
fa_docs = list(db["faultanalysis"].find({"id": {"$in": fa_ids}})) if fa_ids else []
for doc in fa_docs:
if "_id" in doc and isinstance(doc["_id"], ObjectId):
doc["_id"] = str(doc["_id"])
return pr_doc, cr_ids, fa_ids, cr_docs, fa_docs
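
    # Return shape, with a hypothetical PR id (illustrative only):
    #   pr, cr_ids, fa_ids, cr_docs, fa_docs = bot.fetch_problem_report_with_links("PR-1042")
    #   # -> (pr_doc dict, ["CR-7"], ["FA-3"], [correction docs], [fault-analysis docs])
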
def is_technical_query(self, query: str) -> bool:
"""
Classify query as TECHNICAL or NON-TECHNICAL.
"""
classification_prompt = f"""
You are a classifier. Determine if the following query is TECHNICAL
(related to software, debugging, errors, troubleshooting, fault analysis,
corrections, technical problem reports) or NON-TECHNICAL
(general questions, greetings, chit-chat, unrelated topics).
Query: "{query}"
Respond with exactly one word: "TECHNICAL" or "NON-TECHNICAL".
"""
        if self.llm_provider == "gemini":
            response = self.llm.generate_content(classification_prompt)
            result = response.text.strip().upper()
        else:  # groq (provider validated in __init__, so result is always bound)
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=[{"role": "system", "content": classification_prompt}]
            )
            result = completion.choices[0].message.content.strip().upper()
return result == "TECHNICAL"
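
    # e.g. bot.is_technical_query("hi, how are you?") should be False and
    # bot.is_technical_query("why does PR-1042 keep reopening?") True, though
    # the actual result depends on the configured LLM (illustrative only).
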
def is_followup_query(self, query: str, history: list = None) -> bool:
"""
Detect if query is a follow-up based on conversation history.
"""
if not history:
return False
        recent_messages = "\n".join(msg["content"] for msg in history[-3:])
        classification_prompt = f"""
        You are a classifier. Determine if the following user query
        is a FOLLOW-UP (depends on the previous conversation)
        or a NEW QUERY (can be answered independently).
        Previous conversation:
        {recent_messages}
        Current query: "{query}"
        Respond with exactly one word: "FOLLOW-UP" or "NEW".
        """
        if self.llm_provider == "gemini":
            response = self.llm.generate_content(classification_prompt)
            result = response.text.strip().upper()
        else:  # groq (provider validated in __init__, so result is always bound)
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=[{"role": "system", "content": classification_prompt}]
            )
            result = completion.choices[0].message.content.strip().upper()
print("Follow up: ", result)
return result == "FOLLOW-UP"
    def _answer_without_context(self, query: str):
        """Answer directly with the LLM, skipping Qdrant retrieval."""
        system_prompt = "You are a helpful assistant. Answer clearly and concisely."
        convo = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": query}]
        if self.llm_provider == "gemini":
            convo_str = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in convo)
            response = self.llm.generate_content(convo_str)
            return response.text.strip(), []
        else:  # groq (provider validated in __init__)
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=convo
            )
            return completion.choices[0].message.content.strip(), []

    def ask(self, query: str, history: list = None):
        print(f"\n❓ Query: {query}")
        # Step 1: Classify the query
        is_technical = self.is_technical_query(query)
        is_followup = self.is_followup_query(query, history)
        print("is_followup", is_followup)
        print("is_technical", is_technical)
        # Step 2: Non-technical queries (standalone or follow-up) never need
        # retrieval, so answer them directly
        if not is_technical:
            print("⚠️ Non-technical query → skipping Qdrant.")
            return self._answer_without_context(query)
        # Step 3: Route technical queries
        retrieved_context = []
        last_context = []
        context_docs = []
        if is_followup and self.last_context:
            print("🔄 Follow-up query → reusing previous context.")
            retrieved_context = self.last_context
            context_docs = retrieved_context
        elif is_followup:
            # Follow-up with nothing stored: continue with empty context and
            # let the LLM lean on the conversation history
            print("🔄 Follow-up query → without previous context.")
        else:
            print("📥 New technical query → retrieving from Qdrant.")
            retrieved_context = self.retrieve(query)
            if not retrieved_context:
                print("💬 No relevant context found.")
                return "I could not find any relevant information.", []
            for i, doc in enumerate(retrieved_context):
                last_context.append(doc["id"])
                print(f" - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})")
            first_doc = retrieved_context[0]
            # Determine the starting point based on the top hit's entity type
pr_docs_to_use = []
if first_doc["entity_type"] == "ProblemReport":
pr_id = first_doc["id"]
print(f"📌 Using PR from context1: {pr_id}")
pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id)
pr_docs_to_use.append((pr_doc, cr_docs, fa_docs))
elif first_doc["entity_type"] == "Correction":
cr_id = first_doc["id"]
print(f"📌 Using CR from context1: {cr_id}")
cr_doc = self.db["corrections"].find_one({"id": cr_id})
pr_ids = cr_doc.get("problemReportIds", []) if cr_doc else []
if isinstance(pr_ids, str):
pr_ids = [pr_ids]
for pr_id in pr_ids:
pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id)
pr_docs_to_use.append((pr_doc, cr_docs, fa_docs))
elif first_doc["entity_type"] == "FaultAnalysis":
fa_id = first_doc["id"]
print(f"📌 Using FA from context1: {fa_id}")
fa_doc = self.db["faultanalysis"].find_one({"id": fa_id})
pr_ids = fa_doc.get("problemReportIds", []) if fa_doc else []
if isinstance(pr_ids, str):
pr_ids = [pr_ids]
for pr_id in pr_ids:
pr_doc, cr_ids, fa_ids, cr_docs, fa_docs = self.fetch_problem_report_with_links(pr_id)
pr_docs_to_use.append((pr_doc, cr_docs, fa_docs))
            # Build context documents for the LLM, prioritizing FA and CR over the PR
for pr_doc, cr_docs, fa_docs in pr_docs_to_use:
# Include FA first (analysis of problem)
for fa in fa_docs:
context_docs.append({
"entity_type": "FaultAnalysis",
"content": build_content(fa, "FaultAnalysis"),
"score": 1.0
})
# Include CR next (solutions/corrections)
for cr in cr_docs:
context_docs.append({
"entity_type": "Correction",
"content": build_content(cr, "Correction"),
"score": 1.0
})
# PR last (problem description)
if pr_doc:
context_docs.append({
"entity_type": "ProblemReport",
"content": build_content(pr_doc, "ProblemReport"),
"score": 0.9
})
print(f"✅ Total documents for LLM context: {len(context_docs)}")
        if last_context:
            # Save the IDs of the newly retrieved documents so a follow-up
            # can re-fetch the full PR/FA/CR records from MongoDB (see the
            # aggregation pipeline in generate_answer)
            self.last_context = last_context
        answer = self.generate_answer(query, context_docs, history, is_followup)
        return answer, self.last_context
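

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the local modules imported above
    # (embedding_model_instance, qdrant_instance, llm, mongo_instance) are
    # configured against live services. The model name and query are
    # examples, not pinned choices.
    bot = ErrorBot(llm_model_name="llama-3.1-8b-instant", llm_provider="groq")
    answer, context_ids = bot.ask("What corrections exist for error PR-1042?")
    print("\n🤖 Answer:", answer)
    print("Context IDs carried to the next turn:", context_ids)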