SRA25 committed on
Commit
a8e8980
·
verified ·
1 Parent(s): a48bc9f

Upload 3 files

Browse files
Files changed (3) hide show
  1. src/config.py +38 -0
  2. src/database_telemetry.db +0 -0
  3. src/langgraph_init.py +613 -0
src/config.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
class Config:
    """Static configuration for rate limiting and content moderation."""

    # Security settings
    RATE_LIMIT_REQUESTS = 100  # Max requests per window
    RATE_LIMIT_WINDOW = 3600  # 1 hour in seconds

    # Content moderation settings
    # Case-insensitive substrings that flag a message as restricted.
    BLACKLIST_WORDS = [
        "password", "credit card", "ssn", "social security",
        "exploit", "hack", "bypass", "ignore previous", "ignore above",
        "suicide", "self-harm", "kill myself", "hurt myself",
        "bomb", "terrorist", "attack", "shoot", "weapon"
    ]

    # Regexes flagging likely prompt-injection attempts.  Each pattern
    # already embeds an inline (?i) case-insensitivity flag.
    SUSPICIOUS_PATTERNS = [
        r"(?i)(ignore|disregard).*(previous|above|instructions)",
        r"(?i)(system|assistant).*(prompt|instructions)",
        r"(?i)(as an? ai|you are an? ai)",
        r"(?i)(human|user).*response",
        r"(?i)(role play|pretend|act as)",
        r"(?i)(hack|exploit|vulnerability|bypass)",
        r"(?i)(password|credentials|login|admin)"
    ]

    # Allowed topics (optional allowlist approach)
    # NOTE(review): stored on the moderator but not enforced anywhere
    # visible in this commit.
    ALLOWED_TOPICS = [
        "general knowledge", "science", "technology", "history",
        "culture", "education", "creative writing", "programming"
    ]

    # Response templates for restricted content, keyed by the
    # "response_type" produced by the moderator.
    RESTRICTED_RESPONSES = {
        "injection": "I cannot process this request as it appears to be attempting to manipulate the system.",
        "harmful": "I cannot provide information that may be harmful or dangerous.",
        "sensitive": "I cannot provide sensitive personal or security information.",
        "general": "This request has been restricted due to content policy violations."
    }
src/database_telemetry.db ADDED
Binary file (81.9 kB). View file
 
src/langgraph_init.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field, validator
2
+ from typing import List, Optional, Dict, Any, TypedDict,Generic, TypeVar
3
+ import uuid
4
+ import io
5
+ import os
6
+ import PyPDF2
7
+ import re
8
+ import logging
9
+ import time
10
+ from docx import Document as dx
11
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
12
+ import tempfile
13
+ import faiss
14
+ from langchain_community.docstore.in_memory import InMemoryDocstore
15
+ from langchain_community.vectorstores import FAISS
16
+ from langchain_core.prompts import PromptTemplate
17
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
18
+ from langchain_core.documents import Document
19
+ from langchain_huggingface import HuggingFaceEmbeddings
20
+ from langgraph.checkpoint.memory import MemorySaver
21
+ from langgraph.graph import StateGraph, END
22
+ from sqlalchemy import create_engine, Column, String, Integer, DateTime, ForeignKey, Text
23
+ from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON
24
+ # from sqlalchemy.ext.declarative import declarative_base
25
+ from sqlalchemy.orm import sessionmaker, relationship
26
+ from sentence_transformers import SentenceTransformer
27
+ from huggingface_hub import login
28
+ from langchain_google_genai import ChatGoogleGenerativeAI
29
+ import datetime
30
+ from enum import Enum as PyEnum
31
+ from sqlalchemy.orm import DeclarativeBase
32
+ from config import Config
33
+ from functools import lru_cache
34
+ from dotenv import load_dotenv
35
+
36
+ load_dotenv()
37
+ # hf_token = os.environ.get("hf_user_token") or os.getenv("hf_user_token")
38
def login_hf():
    """Log in to the Hugging Face Hub using the ``hf_user_token`` env var.

    Raises:
        ValueError: if the ``hf_user_token`` environment variable is unset.
    """
    # os.environ.get and os.getenv are the same function — one call suffices.
    hf_token = os.environ.get("hf_user_token")
    if hf_token:
        login(token=hf_token, add_to_git_credential=True)
    else:
        # The old message named "HF_TOKEN", which is not the variable read.
        raise ValueError("hf_user_token environment variable is not set.")
44
+
45
T = TypeVar("T")

# --- 1. Database Setup ---
# Use the copy shipped under src/ when present, else fall back to the CWD.
# NOTE: the previous version called os.path.exists() on the SQLAlchemy URL
# string ("sqlite:///src/...") instead of the file path, so the src/ branch
# could never be taken.
_DB_FILE = "src/database_telemetry.db"
if os.path.exists(_DB_FILE):
    DATABASE_URL = f"sqlite:///{_DB_FILE}"
else:
    DATABASE_URL = "sqlite:///database_telemetry.db"
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
55
+
56
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module."""
    pass
58
+
59
class FeedbackScore(PyEnum):
    """Thumbs-up / thumbs-down values for Feedback.feedback_score."""
    POSITIVE = 1
    NEGATIVE = -1
62
+
63
class Telemetry(Base):
    """One row per LLM call: question, answer, context and usage stats."""
    __tablename__ = "telemetry_table"
    transaction_id = Column(String, primary_key=True)  # uuid4 generated per call
    session_id = Column(String)  # conversation this call belongs to
    user_question = Column(Text)
    response = Column(Text)
    context = Column(Text)  # concatenated retrieved chunks fed to the LLM
    model_name = Column(String)
    input_tokens = Column(Integer)
    output_tokens = Column(Integer)
    total_tokens = Column(Integer)
    latency = Column(Integer)  # provider-reported total_duration
    dtcreatedon = Column(DateTime)

    # At most one Feedback row per telemetry entry.
    feedback = relationship("Feedback", back_populates="telemetry_entry", uselist=False)
78
+
79
class Feedback(Base):
    """User feedback attached 1:1 to a Telemetry row."""
    __tablename__ = "feedback_table"
    id = Column(Integer, primary_key=True, autoincrement=True)
    telemetry_entry_id = Column(String, ForeignKey("telemetry_table.transaction_id"), nullable=False, unique=True)
    feedback_score = Column(Integer, nullable=False)  # see FeedbackScore values
    feedback_text = Column(Text, nullable=True)
    user_query = Column(Text, nullable=False)  # denormalized from Telemetry
    llm_response = Column(Text, nullable=False)  # denormalized from Telemetry
    timestamp = Column(DateTime, default=datetime.datetime.now)

    telemetry_entry = relationship("Telemetry", back_populates="feedback")
90
+
91
class ConversationHistory(Base):
    """Full message list per session, stored as one JSON column."""
    __tablename__ = "conversation_history"
    session_id = Column(String, primary_key=True)
    messages = Column(SQLiteJSON, nullable=False)  # list of role/content dicts
    last_updated = Column(DateTime, default=datetime.datetime.now)
96
+
97
+ # --- 2. Initialize LLM and Embeddings ---
98
+
99
# Gemini API key; None when the Gapi_key env var is unset (the client will
# then fail at call time, not import time).
gak = os.environ.get("Gapi_key")
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=gak)
101
+
102
def init_embed():
    """Create the CPU-based HuggingFace embedding model used for indexing."""
    return HuggingFaceEmbeddings(
        model_name="ibm-granite/granite-embedding-english-r2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )
110
+ # embedding_model = SentenceTransformer("ibm-granite/granite-embedding-english-r2")
111
+
112
+ # my_model_name = "gemma3:1b-it-qat"
113
+ # llm = ChatOllama(model=my_model_name)
114
+
115
+
116
+ # --- 3. LangGraph State and Workflow ---
117
class GraphState(TypedDict):
    """State passed between LangGraph nodes.

    Keys:
        chat_history: conversation so far, as role/content dicts.
        retrieved_documents: page contents of the retrieved chunks.
        user_question: the current question.
        session_id: conversation/session identifier.
        telemetry_id: transaction id of the persisted telemetry row
            (populated by generate_response).
    """
    chat_history: List[Dict[str, Any]]
    retrieved_documents: List[str]
    user_question: str
    session_id: str
    # TypedDict fields cannot carry default values ("= None" is invalid in a
    # TypedDict body per PEP 589); the key is simply absent until populated.
    telemetry_id: Optional[str]
123
+
124
# Module-level singletons: the FAISS store (set by upload_documents), the
# compiled LangGraph app (set after the workflow is built below), and the
# in-memory checkpointer used by the graph.
vectorstore_retriever = None
compiled_app = None
memory = MemorySaver()
127
+
128
+ # --- 4. LangGraph Nodes ---
129
def retrieve_documents(state: GraphState):
    """LangGraph node: fetch the chunks most relevant to the question.

    Returns a partial state update with ``retrieved_documents`` as a list
    of plain strings (MMR search, k=3).

    Raises:
        ValueError: if no knowledge base has been uploaded yet.
    """
    global vectorstore_retriever
    user_question = state["user_question"]
    if vectorstore_retriever is None:
        raise ValueError("Knowledge base not loaded. Please upload documents first.")
    retriever = vectorstore_retriever.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    top_docs = retriever.invoke(user_question)
    print("Top Docs: ", top_docs)
    # Always emit strings: generate_response joins this list with "\n\n",
    # so falling back to the raw Document object (as the old code did when
    # page_content was empty) would crash downstream.
    retrieved_docs_content = [doc.page_content or "" for doc in top_docs]
    print("retrieved_documents List: ", retrieved_docs_content)
    return {"retrieved_documents": retrieved_docs_content}
140
+
141
def generate_response(state: GraphState):
    """LangGraph node: answer the question and persist telemetry/history.

    Builds a RAG prompt from the retrieved context and chat history, calls
    the LLM (or returns a canned fallback when nothing was retrieved), then
    writes a Telemetry row and upserts the ConversationHistory row for this
    session.  Persistence is best-effort: DB failures are logged and rolled
    back, and the chat response is still returned.
    """
    global llm
    user_question = state["user_question"]
    retrieved_documents = state["retrieved_documents"]

    # Convert dict-based history into LangChain message objects.
    formatted_chat_history = []
    for msg in state["chat_history"]:
        if msg['role'] == 'user':
            formatted_chat_history.append(HumanMessage(content=msg['content']))
        elif msg['role'] == 'assistant':
            formatted_chat_history.append(AIMessage(content=msg['content']))

    if not retrieved_documents:
        response_content = "I couldn't find any relevant information in the uploaded documents for your question. Can you please rephrase or provide more context?"
        response_obj = AIMessage(content=response_content)
    else:
        context = "\n\n".join(retrieved_documents)
        template = """
        You are a helpful AI assistant. Answer the user's question based on the provided context {context} and the conversation history {chat_history}.
        If the answer is not in the context, state that you don't have enough information.
        Do not make up answers. Only use the given context and chat_history.
        Remove unwanted words like 'Response:' or 'Answer:' from answers.
        \n\nHere is the Question:\n{user_question}
        """
        rag_prompt = PromptTemplate(
            input_variables=["context", "chat_history", "user_question"],
            template=template
        )
        rag_chain = rag_prompt | llm
        time.sleep(3)  # crude throttle to stay under the API rate limit
        response_obj = rag_chain.invoke({
            "context": [SystemMessage(content=context)],
            "chat_history": formatted_chat_history,
            "user_question": [HumanMessage(content=user_question)]
        })

    # Token/latency accounting from the provider's response metadata;
    # every field defaults to 0/'unknown' when absent.
    telemetry_data = response_obj.model_dump()
    usage_metadata = telemetry_data.get('usage_metadata', {})
    input_tokens = usage_metadata.get('input_tokens', 0)
    output_tokens = usage_metadata.get('output_tokens', 0)
    total_tokens = usage_metadata.get('total_tokens', 0)
    response_metadata = telemetry_data.get('response_metadata', {})
    model_name = response_metadata.get('model', 'unknown')
    total_duration = response_metadata.get('total_duration', 0)

    transaction_id = str(uuid.uuid4())
    # Build the updated history BEFORE the DB work: the previous version
    # assigned new_messages inside the try block, so a DB failure made the
    # final return raise NameError instead of returning the response.
    new_messages = state["chat_history"] + [
        {"role": "user", "content": user_question},
        {"role": "assistant", "content": response_obj.content, "telemetry_id": transaction_id}
    ]

    db = SessionLocal()
    try:
        telemetry_record = Telemetry(
            transaction_id=transaction_id,
            session_id=state.get("session_id"),
            user_question=user_question,
            response=response_obj.content,
            context="\n\n".join(retrieved_documents) if retrieved_documents else "No documents retrieved",
            model_name=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            latency=total_duration,
            dtcreatedon=datetime.datetime.now()
        )
        db.add(telemetry_record)

        # Upsert the conversation history for this session.
        print(f"Saving conversation for session_id: {state.get('session_id')}")
        conversation_entry = db.query(ConversationHistory).filter_by(session_id=state.get("session_id")).first()
        if conversation_entry:
            print(f"Updating existing conversation for session_id: {state.get('session_id')}")
            conversation_entry.messages = new_messages
            conversation_entry.last_updated = datetime.datetime.now()
        else:
            print(f"Creating new conversation for session_id: {state.get('session_id')}")
            db.add(ConversationHistory(
                session_id=state.get("session_id"),
                messages=new_messages,
                last_updated=datetime.datetime.now()
            ))

        db.commit()
        print(f"Successfully saved conversation for session_id: {state.get('session_id')}")

    except Exception as e:
        db.rollback()
        print(f"***CRITICAL ERROR***: Failed to save data to database. Error: {e}")
    finally:
        db.close()

    return {
        "chat_history": new_messages,
        "telemetry_id": transaction_id
    }
236
+
237
+
238
# Build and compile the workflow: retrieve -> generate -> END, with
# MemorySaver checkpointing keyed by the per-session thread_id.
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_documents)
workflow.add_node("generate", generate_response)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_app = workflow.compile(checkpointer=memory)
246
+
247
+
248
+ # --- 5. API Models ---
249
class ChatHistoryEntry(BaseModel):
    """One chat message: role ("user"/"assistant"), its text, and — for
    assistant messages — the telemetry transaction id."""
    role: str
    content: str
    telemetry_id: Optional[str] = None
253
+
254
class ChatRequest(BaseModel):
    """Incoming chat payload: question, session id and prior history."""
    user_question: str
    session_id: str
    chat_history: Optional[List[ChatHistoryEntry]] = Field(default_factory=list)

    # NOTE(review): pydantic v1-style @validator; v2 deprecates it in
    # favor of field_validator — confirm the installed pydantic version.
    @validator('user_question')
    def validate_prompt(cls, v):
        """Strip surrounding whitespace and reject blank questions."""
        v = v.strip()
        if not v:
            raise ValueError('Question cannot be empty')
        return v
265
+
266
class ChatResponse(BaseModel):
    """Outgoing chat payload, including moderation verdict fields."""
    ai_response: str
    updated_chat_history: List[ChatHistoryEntry]
    telemetry_entry_id: str
    is_restricted: bool = False  # True when the moderator blocked the query
    moderation_reason: Optional[str] = None
272
+
273
class FeedbackRequest(BaseModel):
    """Feedback submission tied to one telemetry entry of a session."""
    session_id: str
    telemetry_entry_id: str
    feedback_score: int  # expected values: see FeedbackScore enum
    feedback_text: Optional[str] = None
278
+
279
class ConversationSummary(BaseModel):
    """Sidebar listing entry: session id plus a short derived title."""
    session_id: str
    title: str
282
+
283
+ # Content Moderation Service
284
class ContentModerator:
    """Heuristic input moderation: blacklist, regexes and obfuscation checks."""

    def __init__(self):
        self.blacklist_words = Config.BLACKLIST_WORDS
        # Patterns already embed (?i); re.IGNORECASE is kept for safety.
        self.suspicious_patterns = [re.compile(pattern, re.IGNORECASE)
                                    for pattern in Config.SUSPICIOUS_PATTERNS]
        self.allowed_topics = Config.ALLOWED_TOPICS

    def contains_blacklisted_words(self, text: str) -> bool:
        """True if any blacklisted substring occurs (case-insensitive)."""
        text_lower = text.lower()
        return any(word in text_lower for word in self.blacklist_words)

    def contains_suspicious_patterns(self, text: str) -> bool:
        """True if any injection-style regex matches."""
        return any(pattern.search(text) for pattern in self.suspicious_patterns)

    def has_encoding_attempts(self, text: str) -> bool:
        """Detect URL/hex/HTML-entity encodings used to smuggle content."""
        encoding_patterns = [
            r"%[0-9A-Fa-f]{2}",  # URL encoding
            r"\\x[0-9A-Fa-f]{2}",  # Hex encoding
            r"&#x?[0-9a-f]+;",  # HTML entities
        ]
        return any(re.search(pattern, text) for pattern in encoding_patterns)

    def has_excessive_special_chars(self, text: str) -> bool:
        """True when more than 30% of characters are non-alphanumeric."""
        special_chars = len(re.findall(r'[^\w\s]', text))
        total_chars = len(text)
        if total_chars == 0:
            return False
        return (special_chars / total_chars) > 0.3

    def is_prompt_injection(self, text: str) -> bool:
        """Any heuristic signal (patterns, blacklist, encoding, char mix)."""
        return (self.contains_suspicious_patterns(text)
                or self.contains_blacklisted_words(text)
                or self.has_encoding_attempts(text)
                or self.has_excessive_special_chars(text))

    def moderate_content(self, text: str) -> Dict[str, Any]:
        """Classify *text* into injection / harmful / clean.

        Pattern-based injection signals are checked first; a plain
        blacklist hit is reported as "harmful".  (Previously
        is_prompt_injection also matched on blacklisted words, which made
        the "harmful" branch unreachable dead code.)
        """
        if (self.contains_suspicious_patterns(text)
                or self.has_encoding_attempts(text)
                or self.has_excessive_special_chars(text)):
            return {
                "is_restricted": True,
                "reason": "Potential prompt injection detected",
                "response_type": "injection"
            }

        if self.contains_blacklisted_words(text):
            harmful_words = [word for word in self.blacklist_words if word in text.lower()]
            return {
                "is_restricted": True,
                "reason": f"Contains restricted content: {', '.join(harmful_words[:3])}",
                "response_type": "harmful"
            }

        return {"is_restricted": False, "reason": None, "response_type": None}
344
+
345
# Shared moderator instance used by the request path.
moderator = ContentModerator()
346
+
347
def process_text(file):
    """Decode an uploaded text file as UTF-8 and return it as a string.

    The lru_cache decorator was removed: it keyed on the file handle's
    identity (keeping up to 5 uploads alive) while the function consumes
    the stream, so a cache miss on an already-read handle returned "".
    """
    file.seek(0)  # the stream may already have been consumed once
    return file.read().decode("utf-8")
351
+
352
def process_pdf(file):
    """Extract the text of every page of an uploaded PDF, one page per line.

    lru_cache removed (identity-keyed on ephemeral file handles — see
    process_text).  extract_text() may return None for image-only pages;
    guard it so the concatenation cannot raise TypeError.
    """
    file.seek(0)
    reader = PyPDF2.PdfReader(io.BytesIO(file.read()))
    return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
358
+
359
def process_docx(file):
    """Concatenate the paragraph text of an uploaded .docx file.

    lru_cache removed (identity-keyed on ephemeral file handles — see
    process_text).
    """
    file.seek(0)
    document = dx(io.BytesIO(file.read()))
    return "\n".join(para.text for para in document.paragraphs)
365
+
366
+
367
+ # @app.post("/upload-documents")
368
def upload_documents(files):
    """Build the FAISS knowledge base from uploaded txt/pdf/docx files.

    Each upload is converted to a Document, split into 1000-char chunks
    (100 overlap), given a deterministic ``doc_id`` and indexed exactly
    once.  The resulting store replaces the module-level
    ``vectorstore_retriever``.

    Raises:
        ValueError: on an unsupported file type or when nothing usable was
            uploaded.  (The previous ``raise Exception(status_code=...,
            detail=...)`` itself crashed with TypeError, since Exception
            accepts no keyword arguments.)
    """
    global vectorstore_retriever

    embedding_model = init_embed()

    all_documents = []
    for uploaded_file in files:
        if uploaded_file.type == "text/plain":
            content = process_text(uploaded_file)
        elif uploaded_file.type == "application/pdf":
            content = process_pdf(uploaded_file)
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            content = process_docx(uploaded_file)
        else:
            raise ValueError(f"Unsupported file type: {uploaded_file.name} ({uploaded_file.type})")
        all_documents.append(Document(page_content=content, metadata={"source": uploaded_file.name}))

    if not all_documents:
        raise ValueError("No supported documents uploaded.")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    text_chunks = text_splitter.split_documents(all_documents)
    print("text_chunks: ", text_chunks[:100])

    # Deterministic per-chunk ids: <source-with-dots-replaced>_chunk_<i>.
    processed_chunks_with_ids = []
    for i, chunk in enumerate(text_chunks):
        file_source = chunk.metadata.get('source', 'unknown_source')
        chunk.metadata['doc_id'] = f"{file_source.replace('.', '_')}_chunk_{i}"
        processed_chunks_with_ids.append(chunk)

    print(f"Split {len(processed_chunks_with_ids)} chunks.")
    print("Assigned unique 'doc_id' to each chunk in metadata.")

    # Index each chunk exactly once, under its doc_id.  The previous code
    # called from_documents() and then add_documents() with the same
    # chunks, which stored every chunk twice.
    vectorstore = FAISS.from_documents(
        documents=processed_chunks_with_ids,
        embedding=embedding_model,
        ids=[chunk.metadata['doc_id'] for chunk in processed_chunks_with_ids],
    )
    vectorstore_retriever = vectorstore
    return f"Successfully processed {len(files)} documents and created knowledge base."
443
+
444
+ # @app.post("/chat", response_model=ChatResponse)
445
def chat_with_rag(chatdata):
    """Run one RAG turn and return the ChatResponse as a plain dict.

    *chatdata* is a dict with keys ``user_question``, ``session_id`` and
    ``chat_history`` (list of role/content dicts).

    Raises:
        ValueError: when no knowledge base has been uploaded yet.
        RuntimeError: when the graph invocation fails.
    """
    global compiled_app
    global vectorstore_retriever
    if vectorstore_retriever is None:
        # Exception() takes no keyword arguments, so the old
        # raise Exception(status_code=..., detail=...) died with TypeError.
        raise ValueError("Knowledge base not loaded. Please upload documents first.")
    print(f"Received request: {chatdata}")
    # NOTE(review): the moderator check that populated is_restricted /
    # moderation_reason was commented out upstream; re-enable it via
    # moderator.moderate_content() before trusting this banner.
    print("✅ Question passed the RAI check.........")
    initial_state = {
        "chat_history": list(chatdata.get('chat_history')),
        "retrieved_documents": [],
        "user_question": chatdata.get('user_question'),
        "session_id": chatdata.get('session_id')
    }

    try:
        # thread_id keys the MemorySaver checkpoint per session.
        config = {"configurable": {"thread_id": chatdata.get('session_id')}}
        final_state = compiled_app.invoke(initial_state, config=config)

        ai_response_message = final_state["chat_history"][-1]["content"]

        response_chat = ChatResponse(
            ai_response=ai_response_message,
            updated_chat_history=final_state["chat_history"],
            telemetry_entry_id=final_state.get("telemetry_id"),
            is_restricted=False,
        )
        return response_chat.dict()
    except Exception as e:
        print(f"Internal Server Error: {e}")
        raise RuntimeError(f"An error occurred during chat processing: {e}") from e
498
+
499
+ # @app.post("/feedback")
500
+ # def submit_feedback(feedbackdata):
501
+ # db = SessionLocal()
502
+ # try:
503
+ # telemetry_record = db.query(Telemetry).filter(
504
+ # Telemetry.transaction_id == feedbackdata.telemetry_entry_id,
505
+ # Telemetry.session_id == feedbackdata.session_id
506
+ # ).first()
507
+
508
+ # if not telemetry_record:
509
+ # raise Exception(status_code=404, detail="Telemetry entry not found or session ID mismatch.")
510
+
511
+ # existing_feedback = db.query(Feedback).filter(
512
+ # Feedback.telemetry_entry_id == feedbackdata.telemetry_entry_id
513
+ # ).first()
514
+
515
+ # if existing_feedback:
516
+ # existing_feedback.feedback_score = feedbackdata.feedback_score
517
+ # existing_feedback.feedback_text = feedbackdata.feedback_text
518
+ # existing_feedback.timestamp = datetime.datetime.now()
519
+ # else:
520
+ # feedback_record = Feedback(
521
+ # telemetry_entry_id=feedbackdata.telemetry_entry_id,
522
+ # feedback_score=feedbackdata.feedback_score,
523
+ # feedback_text=feedbackdata.feedback_text,
524
+ # user_query=telemetry_record.user_question,
525
+ # llm_response=telemetry_record.response,
526
+ # timestamp=datetime.datetime.now()
527
+ # )
528
+ # db.add(feedback_record)
529
+
530
+ # db.commit()
531
+
532
+ # return {"message": "Feedback submitted successfully."}
533
+
534
+ # except Exception as e:
535
+ # raise e
536
+ # except Exception as e:
537
+ # db.rollback()
538
+ # raise Exception(status_code=500, detail=f"An error occurred: {str(e)}")
539
+ # finally:
540
+ # db.close()
541
+
542
def submit_feedback(feedbackdata):
    """Create or update the Feedback row for a telemetry entry.

    *feedbackdata* is a dict with ``session_id``, ``telemetry_entry_id``,
    ``feedback_score`` and ``feedback_text``.

    Raises:
        LookupError: if the telemetry entry doesn't exist for the session.
        RuntimeError: wrapping any database failure (after rollback).
    """
    db = SessionLocal()
    try:
        telemetry_record = db.query(Telemetry).filter(
            Telemetry.transaction_id == feedbackdata['telemetry_entry_id'],
            Telemetry.session_id == feedbackdata['session_id']
        ).first()

        if not telemetry_record:
            # Exception() rejects keyword args; raise a catchable error.
            raise LookupError("Telemetry entry not found or session ID mismatch.")

        existing_feedback = db.query(Feedback).filter(
            Feedback.telemetry_entry_id == feedbackdata['telemetry_entry_id']
        ).first()

        if existing_feedback:
            existing_feedback.feedback_score = feedbackdata['feedback_score']
            existing_feedback.feedback_text = feedbackdata['feedback_text']
            existing_feedback.timestamp = datetime.datetime.now()
        else:
            db.add(Feedback(
                telemetry_entry_id=feedbackdata['telemetry_entry_id'],
                feedback_score=feedbackdata['feedback_score'],
                feedback_text=feedbackdata['feedback_text'],
                user_query=telemetry_record.user_question,
                llm_response=telemetry_record.response,
                timestamp=datetime.datetime.now()
            ))

        db.commit()
        return {"message": "Feedback submitted successfully."}

    except LookupError:
        raise
    except Exception as e:
        # The old code had two identical "except Exception" clauses; the
        # first re-raised everything, so this rollback was unreachable.
        db.rollback()
        raise RuntimeError(f"An error occurred: {e}") from e
    finally:
        db.close()
583
+
584
+ # @app.get("/conversations", response_model=List[ConversationSummary])
585
def get_conversations():
    """List stored conversations, newest first.

    Returns [{"session_id", "title"}] where the title is the first user
    message truncated to 30 characters (or "New Conversation").
    """
    db = SessionLocal()
    try:
        conversations = db.query(ConversationHistory).order_by(ConversationHistory.last_updated.desc()).all()
        summaries = []
        for conv in conversations:
            first_user_message = next((msg for msg in conv.messages if msg["role"] == "user"), None)
            # Guard against a user message without "content": the old code
            # called len(None) in that case.
            title = first_user_message.get("content") if first_user_message else None
            if not title:
                title = "New Conversation"
            summaries.append({
                "session_id": conv.session_id,
                "title": title[:30] + "..." if len(title) > 30 else title
            })
        return summaries
    finally:
        db.close()
599
+
600
+ # @app.get("/conversations/{session_id}", response_model=List[ChatHistoryEntry])
601
def get_conversation_history(session_id: str):
    """Return the stored message list for *session_id*.

    Raises:
        LookupError: when the session has no stored conversation.
    """
    db = SessionLocal()
    try:
        conversation = db.query(ConversationHistory).filter(ConversationHistory.session_id == session_id).first()
        if not conversation:
            # Exception() rejects keyword args; raise a catchable error.
            raise LookupError("Conversation not found.")
        return conversation.messages
    finally:
        db.close()
610
+
611
if __name__ == "__main__":
    # Library-style module: nothing runs on direct execution.
    pass
    # uvicorn.run(app, host="0.0.0.0", port=8000)