AgentGuard / app.py
shreemanigandan's picture
Update app.py
f82fe39 verified
#!/usr/bin/env python3
"""
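AgentGuard compliance API: ingests policy PDFs into a Chroma vector store and grades contact-center transcripts against them with a Groq-hosted LLM.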
"""
import json
import logging
import sys
import sentence_transformers
from flask import Flask, request, jsonify
from flask_cors import CORS
from groq import Groq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
# Route every INFO-and-above record to stdout as "time - logger - level - message".
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Name of the Chroma collection that stores the ingested policy chunks.
compliance_collection = 'compliance_collection'

# Embedder shared by ingestion and retrieval. It is built at import time,
# so importing this module loads the sentence-transformer model.
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

# Assigned by ingest_documents(); remains None until an ingestion has run
# in this process.
vectorstore = None

# Flask app with flask_cors applied app-wide (library defaults).
app = Flask(__name__)
CORS(app)
@app.route('/', methods=['GET'])
def index():
    """Health check that also bootstraps the vector store on first use.

    If nothing has been ingested in this process yet (``vectorstore`` is
    None), a hard-coded policy PDF is ingested synchronously before the
    response is sent, so that first request can be slow. The JSON payload
    returned is always the same fixed "running" message.
    """
    logger.info("test1")
    already_ingested = vectorstore is not None
    if not already_ingested:
        logger.info("ingesting docs..")
        # NOTE(review): the __main__ entry point ingests "ComplianceFile.pdf"
        # instead of this file; confirm which policy document is intended.
        ingest_documents(
            pdf_folder_location="Bank_Contact_Center_Compliance_Policies.pdf",
            tenant_id="tenant_123",
            policy_set_id="policy_set_abc",
            domain="banking",
        )
    return jsonify({'status': 'running', 'message': 'Hello World API Server'})
@app.route('/api/v1/transcript/process', methods=['POST'])
def process():
    """Grade a call transcript against the ingested compliance policies.

    Expects a JSON object body containing ``transcript`` and ``tenant_id``
    (both are read by ``callLlm``). On success, the model's JSON verdict is
    returned as-is.

    Error responses (JSON objects with an ``error`` key):
        400 - the body is not a JSON object or lacks ``transcript``/``tenant_id``.
        503 - no policy documents have been ingested in this process yet.
        502 - the model's reply (or ``callLlm``'s error text) is not valid
              JSON; the raw text is echoed back in ``raw_response``.
    """
    data = request.get_json()
    logger.info("test2")
    if not isinstance(data, dict) or not all(key in data for key in ("transcript", "tenant_id")):
        return jsonify({'error': "Request body must be a JSON object with 'transcript' and 'tenant_id'"}), 400
    # vectorstore is only set once ingest_documents() has run (at startup
    # under __main__, or lazily via GET /); retrieval cannot work without it.
    if vectorstore is None:
        return jsonify({'error': 'Compliance documents have not been ingested yet'}), 503

    result = callLlm(data)

    # Chat models sometimes wrap JSON in a Markdown code fence; remove it
    # before parsing. Plain JSON passes through untouched.
    cleaned = result.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").strip()
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[4:]

    try:
        parsed_result = json.loads(cleaned)
    except json.JSONDecodeError:
        # Covers both malformed model output and callLlm's plain-text error
        # message; report it instead of failing with an unhandled 500.
        logger.error("LLM response is not valid JSON: %s", result)
        return jsonify({'error': 'Model returned a non-JSON response', 'raw_response': result}), 502
    return jsonify(parsed_result)
def ingest_documents(pdf_folder_location, tenant_id=None, policy_set_id=None, domain=None):
    """Load one policy PDF, chunk it, tag the chunks and store them in Chroma.

    The PDF is split into small token-based chunks (64 tokens with a 16-token
    overlap, counted with the ``cl100k_base`` tiktoken encoding). Each chunk
    is tagged with whichever of ``tenant_id`` / ``policy_set_id`` / ``domain``
    are truthy, so retrieval can filter on them (``callLlm`` filters on
    ``tenant_id``).

    Chunks are written with deterministic ids derived from their tags, the
    source path, their position and their text. The langchain Chroma wrapper
    upserts by id, so re-ingesting the same PDF (e.g. on every process start)
    overwrites the existing entries in ``./compliance_db`` instead of
    appending duplicates.

    Args:
        pdf_folder_location: Path to a single PDF file. Despite the name, it
            is handed directly to ``PyPDFLoader``, which loads one file.
        tenant_id: Optional tenant tag added to every chunk's metadata.
        policy_set_id: Optional policy-set tag added to every chunk's metadata.
        domain: Optional domain tag added to every chunk's metadata.

    Returns:
        The Chroma vector store. It is also assigned to the module-level
        ``vectorstore`` global.
    """
    import hashlib
    import os

    global vectorstore

    pdf_loader = PyPDFLoader(pdf_folder_location)
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name='cl100k_base',
        chunk_size=64,
        chunk_overlap=16,
    )
    compliance_chunks = pdf_loader.load_and_split(text_splitter)

    # The tag set is the same for every chunk, so build it once. Only truthy
    # values are kept, matching the original per-field checks.
    tags = {
        key: value
        for key, value in (
            ("tenant_id", tenant_id),
            ("policy_set_id", policy_set_id),
            ("domain", domain),
        )
        if value
    }
    for chunk in compliance_chunks:
        chunk.metadata.update(tags)

    # Stable ids make repeated ingestion idempotent (see docstring).
    chunk_ids = [
        hashlib.sha256(
            f"{tenant_id}|{policy_set_id}|{domain}|{pdf_folder_location}|{position}|{chunk.page_content}".encode("utf-8")
        ).hexdigest()
        for position, chunk in enumerate(compliance_chunks)
    ]

    os.environ["CHROMA_TELEMETRY"] = "FALSE"
    # chromadb's own telemetry switch is ANONYMIZED_TELEMETRY; setdefault
    # leaves an explicit operator setting untouched.
    os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")
    vectorstore = Chroma(
        collection_name=compliance_collection,
        persist_directory='./compliance_db',
        embedding_function=embedding_model,
    )
    if compliance_chunks:
        vectorstore.add_documents(compliance_chunks, ids=chunk_ids)
    else:
        # E.g. an image-only PDF yields no text. An empty add would be
        # rejected by Chroma, so skip it.
        logger.warning("No text chunks extracted from %s; nothing ingested", pdf_folder_location)
    vectorstore.persist()
    logger.info(f"Ingested {len(compliance_chunks)} document chunks")
    return vectorstore
# System prompt for the compliance grader.
# NOTE(review): the prompt says to answer a bare "Compliant" when no context
# is given, which is not JSON, although the endpoint parses the reply as
# JSON. It also ranks "recommendation" (ERROR) above "guideline" (WARNING).
# Confirm both are intended.
_QNA_SYSTEM_MESSAGE = """
You are an assistant to a contact center human agent who checks if whatever the agent is speaking is compliant with the company policies based on the policy doc provided.
Agent utterances will need to be compared with the portions of relevent compliance document provided in the context and find the violations and their degree, if any.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the agent utterances.
A portion of the Transcript between the human agent and a customer will begin with the token: ###Transcript.
Please find policy violations only using the context provided in the input.
Do not mention anything about the context in your final answer. Your response should only contain the severity of the violation.
If no context is provided, respond with "Compliant".
Pick the highest severity if multiple violations are there. Supported categories are - WARNING, ERROR and CRITICAL, in the order of lowest to highest level of violation.
If a policy is defined as an enforcement action, classify it as CRITICAL.
If a policy is defined as a guideline, classify it as WARNING.
If a policy is defined as a recommendation, classify it as ERROR.
If there are no violations, respond with "Compliant".
Also give some reasoning for your classification.
Response should be in the following json format:
{
"violation_severity": "<One of Compliant, WARNING, ERROR, CRITICAL>",
"reasoning": "<detailed reasoning here>"
}
"""

# User-message template; filled with the retrieved policy text and the transcript.
_QNA_USER_MESSAGE_TEMPLATE = """
###Context
Here are some documents that are relevant to the question mentioned below.
{context}
###Transcript
{transcript}
"""


def callLlm(data):
    """Classify a transcript's policy compliance with a retrieval-augmented LLM call.

    The turn texts are joined into one retrieval query. The 5 most similar
    policy chunks tagged with ``data["tenant_id"]`` are fetched from the
    module-level ``vectorstore``. The Groq-hosted ``openai/gpt-oss-20b``
    model (temperature 0) then grades the transcript against those chunks.

    Args:
        data: Request payload. ``data["transcript"]`` is an iterable of turn
            dicts; only turns with a ``"content"`` key feed the retrieval
            query, while the whole transcript object is formatted into the
            prompt. ``data["tenant_id"]`` is used as a metadata filter.

    Returns:
        The model's reply, stripped, as a string. The system prompt requests
        JSON, but the reply is not validated here. If the chat-completion call
        fails, a plain-text "Sorry, I encountered the following error: ..."
        message is returned instead (and the exception is logged).

    Raises:
        RuntimeError: if no documents have been ingested (``vectorstore`` is None).
        KeyError: if ``transcript`` or ``tenant_id`` is missing from ``data``.
    """
    if vectorstore is None:
        raise RuntimeError("No compliance documents have been ingested yet")

    transcript = data["transcript"]
    tenant_id = data["tenant_id"]
    # One query string built from every turn's text, used for similarity search.
    combined_text = " ".join(turn["content"] for turn in transcript if "content" in turn)

    client = Groq()
    model_name = 'openai/gpt-oss-20b'

    retriever = vectorstore.as_retriever(
        search_type='similarity',
        # Only consider chunks tagged with this tenant's id.
        search_kwargs={'k': 5, 'filter': {"tenant_id": tenant_id}},
    )
    relevant_document_chunks = retriever.get_relevant_documents(combined_text)

    logger.info("relevent chunks: ")
    if relevant_document_chunks:
        # Only the top hit is logged, to keep log output short.
        logger.info(relevant_document_chunks[0].page_content.replace("\t", " "))

    context_for_query = ". ".join(chunk.page_content for chunk in relevant_document_chunks)
    prompt = [
        {'role': 'system', 'content': _QNA_SYSTEM_MESSAGE},
        {
            'role': 'user',
            'content': _QNA_USER_MESSAGE_TEMPLATE.format(
                context=context_for_query,
                transcript=transcript,
            ),
        },
    ]
    logger.info(prompt)

    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0,
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        # Service boundary: keep the original best-effort reply, but record
        # the traceback instead of silently discarding it.
        logger.exception("LLM chat completion failed")
        prediction = f'Sorry, I encountered the following error: \n {e}'
    logger.info(prediction)
    return prediction
if __name__ == '__main__':
    import os

    # TODO: list all policy documents and ingest them once.
    # NOTE(review): index() lazily ingests
    # "Bank_Contact_Center_Compliance_Policies.pdf" rather than this file;
    # confirm which policy document is intended.
    startup_pdf = "ComplianceFile.pdf"
    ingest_documents(
        pdf_folder_location=startup_pdf,
        tenant_id="tenant_123",
        policy_set_id="policy_set_abc",
        domain="banking",
    )

    # Hugging Face Spaces serves on 7860; a PORT env var overrides it.
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}")
    logger.info(f"POST endpoint: http://0.0.0.0:{port}/api/v1/transcript/process")
    # Keep the Werkzeug debugger off: the server listens on all interfaces.
    app.run(host='0.0.0.0', port=port, debug=False)