Shreekant Kalwar (Nokia) committed
Commit cd55ee8 · 1 Parent(s): 92db782

new server

Files changed (10)
  1. app.py +41 -19
  2. app2.py +3 -2
  3. app3.py +86 -0
  4. backup_gemini_llm.py +38 -0
  5. bot_instance.py +45 -0
  6. main.py +9 -0
  7. main2.py +28 -0
  8. requirements.txt +2 -0
  9. util.py +206 -0
  10. util2.py +185 -0
app.py CHANGED
@@ -1,38 +1,60 @@
+# app.py
 from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
-import google.generativeai as genai
-import os
-from dotenv import load_dotenv
-
-# Load variables from .env file
-load_dotenv()
-# ✅ Configure API Key (set GOOGLE_API_KEY in environment variables)
-genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
-
-app = FastAPI()
-
-# ✅ Allow all origins
+from bot_instance import gemini_bot, llama_bot  # singleton ErrorBot instances
+from typing import List, Optional
+
+app = FastAPI(title="ErrorBot API")
+
+# ✅ Allow all origins (adjust in production)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
+# ---------------- Request Models ----------------
+class MessageItem(BaseModel):
+    role: str  # "user" or "bot"
+    content: str
+
 class ChatRequest(BaseModel):
     message: str
+    history: Optional[List[MessageItem]] = []  # optional conversation history
 
-# ✅ Load Gemini model (example: gemini-1.5-flash is lightweight & fast)
-model = genai.GenerativeModel("gemini-2.5-flash")
-
+# ---------------- Endpoints ----------------
 @app.get("/")
 def root():
     return {"status": "ok"}
 
-@app.post("/chat")
-def chat(request: ChatRequest):
-    """Chat endpoint using Gemini"""
-    response = model.generate_content(request.message)
-    return {"reply": response.text}
+# @app.post("/chat")
+# def chat(request: ChatRequest):
+#     """
+#     Main chat endpoint:
+#     - Accepts a message and optional conversation history
+#     - Uses ErrorBot with RAG + LLM
+#     """
+#     history_list = [
+#         {"role": msg.role, "content": msg.content} for msg in request.history
+#     ]
+#
+#     # Ask bot with history
+#     answer = bot.ask(request.message, history=history_list)
+#
+#     return {"reply": answer}
+
+@app.post("/gemini/chat")
+def gemini_chat(request: ChatRequest):
+    history_list = [{"role": msg.role, "content": msg.content} for msg in request.history]
+    answer = gemini_bot.ask(request.message, history=history_list)
+    return {"reply": answer}
+
+@app.post("/llama/chat")
+def llama_chat(request: ChatRequest):
+    history_list = [{"role": msg.role, "content": msg.content} for msg in request.history]
+    answer = llama_bot.ask(request.message, history=history_list)
+    return {"reply": answer}
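For reference, a minimal client sketch for the two new endpoints (the host and port are illustrative; the payload mirrors the history example used in main.py below):

import requests

payload = {
    "message": "What is my name?",
    "history": [
        {"role": "user", "content": "My name is Shreekant"},
        {"role": "bot", "content": "Ok"},
    ],
}

# Both endpoints accept the same ChatRequest shape; only the backing LLM differs.
r = requests.post("http://localhost:8000/gemini/chat", json=payload)
print(r.json()["reply"])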
app2.py CHANGED
@@ -37,7 +37,8 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto"
+    device_map="auto",
+    offload_folder="offload"
 )
 print("Model loaded ✅")
 
@@ -54,4 +55,4 @@ def chat(request: ChatRequest):
     reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
-    return {"reply": reply}
+    return {"reply": reply}
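For context on this change: with device_map="auto", Accelerate may spill weights from GPU to CPU RAM and then to disk when memory is tight, and offload_folder names the directory that receives those offloaded shards (./offload here). Without it, from_pretrained raises an error as soon as disk offload is actually needed, which is presumably what this change works around on the new server.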
app3.py ADDED
@@ -0,0 +1,86 @@
+import os
+
+# Ensure Hugging Face cache uses a writable path
+# (must be set before importing transformers, which reads these at import time)
+os.environ["TRANSFORMERS_CACHE"] = "./.cache"
+os.environ["HF_HOME"] = "./.cache"
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from fastapi.middleware.cors import CORSMiddleware
+import torch
+
+app = FastAPI()
+
+# ✅ Allow all origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class ChatRequest(BaseModel):
+    message: str
+    max_tokens: int = 200  # default shorter responses for speed
+
+
+# 🔹 Choose a model (smaller = faster on CPU)
+# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+# model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+model_name = "deepseek-ai/deepseek-coder-1.3b-base"
+
+print("🚀 Loading model... this may take a minute ⏳")
+
+try:
+    if torch.cuda.is_available():
+        # ✅ GPU with 4-bit quantization
+        from transformers import BitsAndBytesConfig
+        quant_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map="auto",
+            quantization_config=quant_config,
+        )
+    else:
+        # ✅ CPU fallback (no quantization)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float32,
+            device_map="auto",
+        )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    print("✅ Model loaded successfully!")
+
+except Exception as e:
+    print("❌ Model loading failed:", str(e))
+    raise
+
+
+@app.get("/")
+def root():
+    return {"status": "ok"}
+
+
+@app.post("/chat")
+def chat(request: ChatRequest):
+    """Chat endpoint"""
+    inputs = tokenizer(request.message, return_tensors="pt").to(model.device)
+
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=request.max_tokens,
+        do_sample=True,
+        top_p=0.9,
+        temperature=0.7,
+    )
+
+    # 🔹 Only decode the newly generated tokens, not the echoed prompt
+    reply_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+    reply = tokenizer.decode(reply_tokens, skip_special_tokens=True)
+
+    return {"reply": reply}
backup_gemini_llm.py ADDED
@@ -0,0 +1,38 @@
+from fastapi import FastAPI
+from pydantic import BaseModel
+from fastapi.middleware.cors import CORSMiddleware
+import google.generativeai as genai
+import os
+from dotenv import load_dotenv
+
+# Load variables from .env file
+load_dotenv()
+# ✅ Configure API Key (set GOOGLE_API_KEY in environment variables)
+genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+
+app = FastAPI()
+
+# ✅ Allow all origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class ChatRequest(BaseModel):
+    message: str
+
+# ✅ Load Gemini model (example: gemini-1.5-flash is lightweight & fast)
+model = genai.GenerativeModel("gemini-2.5-flash")
+
+@app.get("/")
+def root():
+    return {"status": "ok"}
+
+@app.post("/chat")
+def chat(request: ChatRequest):
+    """Chat endpoint using Gemini"""
+    response = model.generate_content(request.message)
+    return {"reply": response.text}
bot_instance.py ADDED
@@ -0,0 +1,45 @@
+import os
+from dotenv import load_dotenv
+from util import ErrorBot
+
+# Load environment variables
+load_dotenv()
+GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+# if not GOOGLE_API_KEY:
+#     raise ValueError("Set GOOGLE_API_KEY in your environment variables")
+
+# EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"
+# LLM_MODEL = "gemini-2.5-flash"  # Gemini model
+
+# # Initialize singleton bot
+# bot = ErrorBot(
+#     embedding_model_name=EMBEDDING_MODEL,
+#     llm_model_name=LLM_MODEL,
+#     google_api_key=GOOGLE_API_KEY,
+# )
+
+# Ingest MongoDB
+# bot.ingest_from_mongodb(
+#     mongo_uri="mongodb+srv://dhaval:Dhaval15@cluster0.rwu1ze6.mongodb.net/prontoDB?retryWrites=true&w=majority&appName=Cluster0",
+#     db_name="prontoDB",
+# )
+
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+
+EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"
+
+# --- Gemini Bot ---
+gemini_bot = ErrorBot(
+    embedding_model_name=EMBEDDING_MODEL,
+    llm_model_name="gemini-2.5-flash",
+    google_api_key=GOOGLE_API_KEY,
+    llm_provider="gemini",
+)
+
+# --- Groq Bot (LLaMA) ---
+llama_bot = ErrorBot(
+    embedding_model_name=EMBEDDING_MODEL,
+    llm_model_name="llama-3.3-70b-versatile",
+    groq_api_key=GROQ_API_KEY,
+    llm_provider="groq",
+)
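bot_instance.py pulls its secrets from the environment via load_dotenv(), and util.py additionally reads the Qdrant credentials. A hypothetical .env for local development would therefore cover four variables (all values below are placeholders):

GOOGLE_API_KEY=your-gemini-api-key
GROQ_API_KEY=your-groq-api-key
QDRANT_URL=https://your-cluster.qdrant.example
QDRANT_API_KEY=your-qdrant-api-key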
main.py ADDED
@@ -0,0 +1,9 @@
+from bot_instance import gemini_bot  # bot_instance exposes gemini_bot/llama_bot, not `bot`
+
+history = [
+    {"role": "user", "content": "My name is Shreekant"},
+    {"role": "bot", "content": "Ok"}
+]
+
+answer = gemini_bot.ask("What is my name?", history=history)
+print(answer)
main2.py ADDED
@@ -0,0 +1,28 @@
+import os
+from dotenv import load_dotenv
+from util2 import ErrorBot
+
+print("hello")
+
+if __name__ == "__main__":
+    load_dotenv()
+    EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"
+    # util2.ErrorBot passes this name to genai.GenerativeModel, so it must be a Gemini model
+    LLM_MODEL = "gemini-2.5-flash"
+
+    bot = ErrorBot(
+        embedding_model_name=EMBEDDING_MODEL,
+        llm_model_name=LLM_MODEL,
+        google_api_key=os.environ["GOOGLE_API_KEY"],  # required by util2.ErrorBot.__init__
+    )
+
+    # Ingest MongoDB
+    bot.ingest_from_mongodb(
+        mongo_uri="mongodb+srv://dhaval:Dhaval15@cluster0.rwu1ze6.mongodb.net/prontoDB?retryWrites=true&w=majority&appName=Cluster0",
+        db_name="prontoDB",
+    )
+
+    # Example queries
+    # bot.ask("who is author of problem Id: PR787807")
+    # bot.ask("Who is the responsiblePerson for correction CR1554963?")
+    bot.ask("What is the solution for this Installation failed In DCA State with NIV services in Stopped State || SprintLab837")
+
+    history = [
+        {"role": "user", "content": "My name is Shreekant"},
+        {"role": "bot", "content": "Ok"}
+    ]
+
+    answer = bot.ask("What is my name?", history=history)
+    print(answer)
requirements.txt CHANGED
@@ -1,6 +1,8 @@
 accelerate==1.10.1
 annotated-types==0.7.0
 anyio==4.10.0
+bitsandbytes==0.47.0
+bitsandbytes-windows==0.37.5
 cachetools==5.5.2
 certifi==2025.8.3
 charset-normalizer==3.4.3
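Both new pins back the optional 4-bit quantization path in app3.py, which is only taken when CUDA is available; bitsandbytes itself requires a CUDA toolchain, while bitsandbytes-windows appears to be a third-party build for Windows hosts, so one of the two pins will likely sit unused on any given deployment image.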
util.py ADDED
@@ -0,0 +1,206 @@
+import os
+import torch
+from qdrant_client import QdrantClient, models
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from pymongo import MongoClient
+from bson import ObjectId
+from typing import List, Dict
+import google.generativeai as genai
+from groq import Groq
+
+def build_content(doc: dict, entity_type: str) -> str:
+    """Convert a MongoDB document into natural text for embeddings."""
+    parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"]
+    for k, v in doc.items():
+        if k in ["_id"]:  # skip ObjectId
+            continue
+        if isinstance(v, list):
+            parts.append(f"{k}: {', '.join(map(str, v))}")
+        elif isinstance(v, dict):
+            nested = "; ".join([f"{nk}: {nv}" for nk, nv in v.items() if nv])
+            parts.append(f"{k}: {nested}")
+        else:
+            if v:
+                parts.append(f"{k}: {v}")
+    return "\n".join(parts)
+
+
+class ErrorBot:
+    """Chatbot using RAG (Qdrant + Gemini or Groq)."""
+
+    def __init__(self, embedding_model_name: str, llm_model_name: str, google_api_key: str = None, groq_api_key: str = None, llm_provider: str = "gemini"):
+        print("🚀 Initializing ErrorBot...")
+
+        # --- Embedding model
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")
+        self.embedding_model = SentenceTransformer(embedding_model_name, device=self.device)
+        self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
+
+        # --- Qdrant client
+        print("Connecting to Qdrant...")
+        self.qdrant = QdrantClient(
+            url=os.getenv("QDRANT_URL"),
+            api_key=os.getenv("QDRANT_API_KEY"),
+        )
+        self.collection_name = "technical_errors"
+        self._setup_collection()
+
+        # --- LLM setup
+        self.llm_provider = llm_provider.lower()
+        self.llm_model_name = llm_model_name
+
+        if self.llm_provider == "gemini":
+            genai.configure(api_key=google_api_key)
+            self.llm = genai.GenerativeModel(llm_model_name)
+
+        elif self.llm_provider == "groq":
+            self.llm = Groq(api_key=groq_api_key)
+
+        else:
+            raise ValueError(f"Unsupported LLM provider: {self.llm_provider}")
+
+        # --- Cross-encoder reranker
+        self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+        print(f"✅ ErrorBot ready with {self.llm_provider.upper()}")
+
+    def _setup_collection(self):
+        if not self.qdrant.collection_exists(self.collection_name):
+            self.qdrant.create_collection(
+                collection_name=self.collection_name,
+                vectors_config=models.VectorParams(
+                    size=self.embedding_dim,
+                    distance=models.Distance.COSINE,
+                ),
+            )
+
+    def ingest_from_mongodb(self, mongo_uri: str, db_name: str, batch_size: int = 32):
+        client = MongoClient(mongo_uri)
+        db = client[db_name]
+
+        collections = {
+            "ProblemReport": db["problemReports"],
+            "FaultAnalysis": db["faultanalysis"],
+            "Correction": db["corrections"],
+        }
+
+        docs = []
+        for entity_type, coll in collections.items():
+            for doc in coll.find():
+                if "_id" in doc and isinstance(doc["_id"], ObjectId):
+                    doc["_id"] = str(doc["_id"])
+                docs.append({"entity_type": entity_type, "data": doc})
+
+        contents = [build_content(d["data"], d["entity_type"]) for d in docs]
+
+        all_embeddings = []
+        for i in range(0, len(contents), batch_size):
+            batch_contents = contents[i:i + batch_size]
+            embeddings = self.embedding_model.encode(batch_contents, show_progress_bar=True).tolist()
+            all_embeddings.extend(embeddings)
+
+        self.qdrant.upsert(
+            collection_name=self.collection_name,
+            points=[
+                models.PointStruct(
+                    id=i,
+                    vector=emb,
+                    payload={
+                        "id": d["data"].get("id", str(d["data"].get("_id", i))),
+                        "entity_type": d["entity_type"],
+                        "raw": d["data"],
+                        "content": c,
+                    },
+                )
+                for i, (d, emb, c) in enumerate(zip(docs, all_embeddings, contents))
+            ],
+            wait=True,
+        )
+        print(f"✅ Ingested {len(docs)} documents into '{self.collection_name}'")
+
+    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.3, rerank: bool = True):
+        query_embedding = self.embedding_model.encode(query).tolist()
+        hits = self.qdrant.query_points(
+            collection_name=self.collection_name,
+            query=query_embedding,
+            limit=top_k * 3 if rerank else top_k,
+            with_payload=True,
+            score_threshold=score_threshold,
+        ).points
+
+        candidates = [
+            {
+                "id": hit.payload.get("id"),
+                "entity_type": hit.payload.get("entity_type", ""),
+                "content": hit.payload.get("content", ""),
+                "score": hit.score,
+            }
+            for hit in hits
+        ]
+
+        if rerank and candidates:
+            pairs = [(query, c["content"]) for c in candidates]
+            scores = self.reranker.predict(pairs)
+            for i, score in enumerate(scores):
+                candidates[i]["rerank_score"] = float(score)
+            candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
+
+        return candidates[:top_k]
+
+    def generate_answer(self, query: str, context: List[Dict], history: list = None):
+        context_str = "\n---\n".join(
+            [f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context]
+        )
+
+        # --- System prompt
+        system_prompt = f"""
+You are a technical assistant. You have access to Problem Reports (PR), Fault Analyses (FA), and Corrections (CR).
+Use the provided context and conversation history to answer the question clearly and concisely.
+If context is not relevant, say you do not have enough information.
+
+### Context
+{context_str}
+"""
+
+        # --- Conversation history in list-of-dicts format
+        convo = []
+        if history:
+            for msg in history:
+                convo.append({
+                    "role": "user" if msg["role"] == "user" else "assistant",
+                    "content": msg["content"],
+                })
+
+        convo.append({"role": "user", "content": query})
+
+        # --- Gemini flow: flatten the chat into a single prompt
+        if self.llm_provider == "gemini":
+            convo_str = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in convo])
+            prompt = system_prompt + "\n\n" + convo_str + "\nAssistant:"
+            response = self.llm.generate_content(prompt)
+            return response.text.strip()
+
+        # --- Groq flow: native chat-completions format
+        elif self.llm_provider == "groq":
+            completion = self.llm.chat.completions.create(
+                model=self.llm_model_name,
+                messages=[{"role": "system", "content": system_prompt}] + convo,
+            )
+            return completion.choices[0].message.content.strip()
+
+    def ask(self, query: str, history: list = None):
+        print(f"\n❓ Query: {query}")
+        retrieved_context = self.retrieve(query)
+
+        if not retrieved_context:
+            print("💬 No relevant context found.")
+            return "I could not find any relevant information."
+
+        print(f"✅ Retrieved {len(retrieved_context)} documents.")
+        for i, doc in enumerate(retrieved_context):
+            print(f"  - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})")
+
+        answer = self.generate_answer(query, retrieved_context, history)
+        print(f"\n🤖 Answer: {answer}")
+        return answer
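For quickly debugging the retrieval stage, ErrorBot can also be driven outside FastAPI. A minimal sketch, assuming QDRANT_URL, QDRANT_API_KEY and GROQ_API_KEY are set and the collection has already been ingested (the query string is illustrative):

import os
from util import ErrorBot

bot = ErrorBot(
    embedding_model_name="BAAI/bge-base-en-v1.5",
    llm_model_name="llama-3.3-70b-versatile",
    groq_api_key=os.environ["GROQ_API_KEY"],
    llm_provider="groq",
)

# Each candidate carries the Qdrant cosine score and, when reranking is on,
# a cross-encoder "rerank_score" that determines the final ordering.
for c in bot.retrieve("Installation failed in DCA state", top_k=3):
    print(c["entity_type"], c["id"], c["score"], c.get("rerank_score"))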
util2.py ADDED
@@ -0,0 +1,185 @@
+import os
+import torch
+from qdrant_client import QdrantClient, models
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from pymongo import MongoClient
+from bson import ObjectId
+from typing import List, Dict
+import google.generativeai as genai
+from groq import Groq
+
+def build_content(doc: dict, entity_type: str) -> str:
+    """Convert a MongoDB document into natural text for embeddings."""
+    parts = [f"{entity_type} ID: {doc.get('id', str(doc.get('_id', '')))}"]
+    for k, v in doc.items():
+        if k in ["_id"]:  # skip ObjectId
+            continue
+        if isinstance(v, list):
+            parts.append(f"{k}: {', '.join(map(str, v))}")
+        elif isinstance(v, dict):
+            nested = "; ".join([f"{nk}: {nv}" for nk, nv in v.items() if nv])
+            parts.append(f"{k}: {nested}")
+        else:
+            if v:
+                parts.append(f"{k}: {v}")
+    return "\n".join(parts)
+
+
+class ErrorBot:
+    """Chatbot using RAG (Qdrant + Gemini API)."""
+
+    def __init__(self, embedding_model_name: str, llm_model_name: str, google_api_key: str):
+        print("🚀 Initializing ErrorBot...")
+
+        # --- Embedding model
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")
+        self.embedding_model = SentenceTransformer(embedding_model_name, device=self.device)
+        self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
+
+        # --- Qdrant client
+        print("Connecting to Qdrant...")
+        self.qdrant = QdrantClient(
+            url=os.getenv("QDRANT_URL"),
+            api_key=os.getenv("QDRANT_API_KEY"),
+        )
+        self.collection_name = "technical_errors"
+        self._setup_collection()
+
+        # --- Gemini LLM
+        genai.configure(api_key=google_api_key)
+        self.llm_model_name = llm_model_name
+        self.llm = genai.GenerativeModel(llm_model_name)
+
+        # --- Cross-encoder reranker
+        print("Loading cross-encoder reranker...")
+        self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+
+        print("✅ ErrorBot ready.")
+
+    def _setup_collection(self):
+        if not self.qdrant.collection_exists(self.collection_name):
+            self.qdrant.create_collection(
+                collection_name=self.collection_name,
+                vectors_config=models.VectorParams(
+                    size=self.embedding_dim,
+                    distance=models.Distance.COSINE,
+                ),
+            )
+
+    def ingest_from_mongodb(self, mongo_uri: str, db_name: str, batch_size: int = 32):
+        client = MongoClient(mongo_uri)
+        db = client[db_name]
+
+        collections = {
+            "ProblemReport": db["problemReports"],
+            "FaultAnalysis": db["faultanalysis"],
+            "Correction": db["corrections"],
+        }
+
+        docs = []
+        for entity_type, coll in collections.items():
+            for doc in coll.find():
+                if "_id" in doc and isinstance(doc["_id"], ObjectId):
+                    doc["_id"] = str(doc["_id"])
+                docs.append({"entity_type": entity_type, "data": doc})
+
+        contents = [build_content(d["data"], d["entity_type"]) for d in docs]
+
+        all_embeddings = []
+        for i in range(0, len(contents), batch_size):
+            batch_contents = contents[i:i + batch_size]
+            embeddings = self.embedding_model.encode(batch_contents, show_progress_bar=True).tolist()
+            all_embeddings.extend(embeddings)
+
+        self.qdrant.upsert(
+            collection_name=self.collection_name,
+            points=[
+                models.PointStruct(
+                    id=i,
+                    vector=emb,
+                    payload={
+                        "id": d["data"].get("id", str(d["data"].get("_id", i))),
+                        "entity_type": d["entity_type"],
+                        "raw": d["data"],
+                        "content": c,
+                    },
+                )
+                for i, (d, emb, c) in enumerate(zip(docs, all_embeddings, contents))
+            ],
+            wait=True,
+        )
+        print(f"✅ Ingested {len(docs)} documents into '{self.collection_name}'")
+
+    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.3, rerank: bool = True):
+        query_embedding = self.embedding_model.encode(query).tolist()
+        hits = self.qdrant.query_points(
+            collection_name=self.collection_name,
+            query=query_embedding,
+            limit=top_k * 3 if rerank else top_k,
+            with_payload=True,
+            score_threshold=score_threshold,
+        ).points
+
+        candidates = [
+            {
+                "id": hit.payload.get("id"),
+                "entity_type": hit.payload.get("entity_type", ""),
+                "content": hit.payload.get("content", ""),
+                "score": hit.score,
+            }
+            for hit in hits
+        ]
+
+        if rerank and candidates:
+            pairs = [(query, c["content"]) for c in candidates]
+            scores = self.reranker.predict(pairs)
+            for i, score in enumerate(scores):
+                candidates[i]["rerank_score"] = float(score)
+            candidates = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
+
+        return candidates[:top_k]
+
+    def generate_answer(self, query: str, context: List[Dict], history: list = None):
+        context_str = "\n---\n".join(
+            [f"{c['entity_type']} (Score: {c['score']:.2f}):\n{c['content']}" for c in context]
+        )
+
+        convo_str = ""
+        if history:
+            for msg in history:
+                role = "User" if msg["role"] == "user" else "Assistant"
+                convo_str += f"{role}: {msg['content']}\n"
+
+        convo_str += f"User: {query}\nAssistant:"
+
+        prompt = f"""
+You are a technical assistant. You have access to Problem Reports (PR), Fault Analyses (FA), and Corrections (CR).
+Use the provided context and conversation history to answer the question clearly and concisely.
+If context is not relevant, say you do not have enough information.
+
+### Context
+{context_str}
+
+### Conversation
+{convo_str}
+"""
+
+        response = self.llm.generate_content(prompt)
+        return response.text.strip()
+
+    def ask(self, query: str, history: list = None):
+        print(f"\n❓ Query: {query}")
+        retrieved_context = self.retrieve(query)
+
+        if not retrieved_context:
+            print("💬 No relevant context found.")
+            return "I could not find any relevant information."
+
+        print(f"✅ Retrieved {len(retrieved_context)} documents.")
+        for i, doc in enumerate(retrieved_context):
+            print(f"  - Context {i+1} ({doc['entity_type']}, ID: {doc['id']}, Score: {doc['score']:.2f})")
+
+        answer = self.generate_answer(query, retrieved_context, history)
+        print(f"\n🤖 Answer: {answer}")
+        return answer