import os
from typing import Dict, List

import torch
import google.generativeai as genai
from bson import ObjectId
from groq import Groq
from pymongo import MongoClient
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer, CrossEncoder
|
|
def build_content(doc: dict, entity_type: str) -> str:
    """Flatten a MongoDB document into natural-language text for embedding."""
    parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"]
    for k, v in doc.items():
        if k == "_id":
            continue
        if isinstance(v, list):
            parts.append(f"{k}: {', '.join(map(str, v))}")
        elif isinstance(v, dict):
            # Keep only non-empty nested fields.
            nested = "; ".join(f"{nk}: {nv}" for nk, nv in v.items() if nv)
            parts.append(f"{k}: {nested}")
        elif v:  # Skip empty/None scalar values.
            parts.append(f"{k}: {v}")
    return "\n".join(parts)
|
|
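# A quick illustration (hypothetical document, not from the real collections):
#
#   build_content(
#       {"_id": "66f0a1", "title": "Pump overheats", "tags": ["hydraulic", "urgent"]},
#       "ProblemReport",
#   )
#
# returns:
#
#   ProblemReport ID: 66f0a1
#   title: Pump overheats
#   tags: hydraulic, urgent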
|
|
class ErrorBot:
    """Chatbot using RAG (Qdrant retrieval + Gemini or Groq generation)."""

    def __init__(
        self,
        embedding_model_name: str,
        llm_model_name: str,
        google_api_key: str = None,
        groq_api_key: str = None,
        llm_provider: str = "gemini",
    ):
        print("🚀 Initializing ErrorBot...")

        # Embedding model (GPU if available).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.embedding_model = SentenceTransformer(embedding_model_name, device=self.device)
        self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()

        # Vector store (expects QDRANT_URL and QDRANT_API_KEY in the environment).
        print("Connecting to Qdrant...")
        self.qdrant = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.collection_name = "technical_errors"
        self._setup_collection()

        # LLM backend.
        self.llm_provider = llm_provider.lower()
        self.llm_model_name = llm_model_name
        if self.llm_provider == "gemini":
            genai.configure(api_key=google_api_key)
            self.llm = genai.GenerativeModel(llm_model_name)
        elif self.llm_provider == "groq":
            self.llm = Groq(api_key=groq_api_key)
        else:
            raise ValueError(f"Unsupported LLM provider: {self.llm_provider}")

        # Cross-encoder used to rerank retrieved candidates.
        self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        print(f"✅ ErrorBot ready with {self.llm_provider.upper()}")
|
    def _setup_collection(self):
        if not self.qdrant.collection_exists(self.collection_name):
            self.qdrant.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )
|
|
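    # Note: cosine distance is a sensible default for sentence-transformer
    # embeddings. Collection creation is skipped if it already exists, so
    # switching to an embedding model with a different dimension later
    # requires dropping the collection (vector size is fixed at creation).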
|
    def ingest_from_mongodb(self, mongo_uri: str, db_name: str, batch_size: int = 32):
        client = MongoClient(mongo_uri)
        db = client[db_name]

        collections = {
            "ProblemReport": db["problemReports"],
            "FaultAnalysis": db["faultanalysis"],
            "Correction": db["corrections"],
        }

        # Pull every document, stringifying ObjectIds so payloads are JSON-serializable.
        docs = []
        for entity_type, coll in collections.items():
            for doc in coll.find():
                if "_id" in doc and isinstance(doc["_id"], ObjectId):
                    doc["_id"] = str(doc["_id"])
                docs.append({"entity_type": entity_type, "data": doc})

        contents = [build_content(d["data"], d["entity_type"]) for d in docs]

        # Embed in batches to bound memory usage.
        all_embeddings = []
        for i in range(0, len(contents), batch_size):
            batch_contents = contents[i:i + batch_size]
            embeddings = self.embedding_model.encode(batch_contents, show_progress_bar=True).tolist()
            all_embeddings.extend(embeddings)

        self.qdrant.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=i,
                    vector=emb,
                    payload={
                        "id": d["data"].get("id", str(d["data"].get("_id", i))),
                        "entity_type": d["entity_type"],
                        "raw": d["data"],
                        "content": c,
                    },
                )
                for i, (d, emb, c) in enumerate(zip(docs, all_embeddings, contents))
            ],
            wait=True,
        )
        print(f"✅ Ingested {len(docs)} documents into '{self.collection_name}'")
|
|
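    # Note: point ids are positional (0..N-1), so re-running ingestion
    # overwrites earlier points rather than duplicating them -- but only if
    # the corpus and its ordering are unchanged between runs.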
|
    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.3, rerank: bool = True):
        query_embedding = self.embedding_model.encode(query).tolist()
        # Over-fetch 3x candidates when reranking so the cross-encoder has
        # more material to reorder.
        hits = self.qdrant.query_points(
            collection_name=self.collection_name,
            query=query_embedding,
            limit=top_k * 3 if rerank else top_k,
            with_payload=True,
            score_threshold=score_threshold,
        ).points

        candidates = [
            {
                "id": hit.payload.get("id"),
                "entity_type": hit.payload.get("entity_type", ""),
                "content": hit.payload.get("content", ""),
                "score": hit.score,
            }
            for hit in hits
        ]

        if rerank and candidates:
            pairs = [(query, c["content"]) for c in candidates]
            scores = self.reranker.predict(pairs)
            for i, score in enumerate(scores):
                candidates[i]["rerank_score"] = float(score)
            candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)

        return candidates[:top_k]
|
|
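    # Example (illustrative output; ids and scores are made up):
    #
    #   bot.retrieve("pump overheating after firmware update", top_k=2)
    #   -> [{"id": "PR-101", "entity_type": "ProblemReport",
    #        "content": "...", "score": 0.62, "rerank_score": 4.87},
    #       {"id": "FA-17", "entity_type": "FaultAnalysis",
    #        "content": "...", "score": 0.55, "rerank_score": 3.12}]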
|
    def generate_answer(self, query: str, context: List[Dict], history: list = None):
        context_str = "\n---\n".join(
            f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context
        )

        system_prompt = f"""
You are a technical assistant. You have access to Problem Reports (PR), Fault Analyses (FA), and Corrections (CR).
Use the provided context and conversation history to answer the question clearly and concisely.
If the context is not relevant, say you do not have enough information.

### Context
{context_str}
"""

        # Normalize history into chat messages, then append the current query.
        convo = []
        if history:
            for msg in history:
                convo.append({
                    "role": "user" if msg["role"] == "user" else "assistant",
                    "content": msg["content"],
                })
        convo.append({"role": "user", "content": query})

        if self.llm_provider == "gemini":
            # Gemini is called with a single flattened prompt here, so fold
            # the system prompt and turn history into one string.
            convo_str = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in convo)
            prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:"
            response = self.llm.generate_content(prompt)
            return response.text.strip()

        elif self.llm_provider == "groq":
            # Groq exposes an OpenAI-style chat API, so pass structured messages.
            completion = self.llm.chat.completions.create(
                model=self.llm_model_name,
                messages=[{"role": "system", "content": system_prompt}] + convo,
            )
            return completion.choices[0].message.content.strip()
|
    def ask(self, query: str, history: list = None):
        print(f"\n❓ Query: {query}")
        retrieved_context = self.retrieve(query)

        if not retrieved_context:
            print("💬 No relevant context found.")
            return "I could not find any relevant information."

        print(f"✅ Retrieved {len(retrieved_context)} documents.")
        for i, doc in enumerate(retrieved_context):
            print(f"  - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})")

        answer = self.generate_answer(query, retrieved_context, history)
        print(f"\n🤖 Answer: {answer}")
        return answer
|
|
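if __name__ == "__main__":
    # Minimal usage sketch -- the model names, database name, and env vars
    # below are illustrative assumptions, not values fixed by this module.
    bot = ErrorBot(
        embedding_model_name="all-MiniLM-L6-v2",  # assumed embedding model
        llm_model_name="gemini-1.5-flash",        # assumed Gemini model id
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        llm_provider="gemini",
    )
    # One-time ingestion from MongoDB (connection string and db name are hypothetical).
    bot.ingest_from_mongodb(
        mongo_uri=os.getenv("MONGO_URI", "mongodb://localhost:27017"),
        db_name="error_tracking",
    )
    bot.ask("What corrections exist for hydraulic pump overheating?")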
|