# CAI / app.py — Hugging Face Space entry point
# (uploaded by AliceRolan, "Update app.py", commit dd29e87 verified)
import gradio as gr
from huggingface_hub import InferenceClient
import os
import torch
import transformers
from tensorflow import keras
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM,AutoModelForCausalLM
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain.llms import HuggingFacePipeline
import gradio as gr
import re
from bs4 import BeautifulSoup
from guardrails.validators import Validator, register_validator, ValidationResult, FailResult, PassResult
from presidio_analyzer import AnalyzerEngine
from presidio_analyzer.nlp_engine import SpacyNlpEngine, NlpEngineProvider
from better_profanity import profanity
from presidio_analyzer import PatternRecognizer, Pattern
import inflection
from guardrails import Guard
import warnings
# Suppress all warnings
warnings.filterwarnings("ignore")
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
print("GPU Available:", torch.cuda.is_available())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU Found")
pdf_files = ["Apple-10K-2023.pdf", "Apple-10K-2024.pdf"]
"""### πŸ“Œ Step 1: Load Multiple 10-K Financial Report PDFs """
all_documents = []
def preprocess_text(text):
    """Return *text* with HTML markup removed and whitespace normalized.

    Used to clean raw page content extracted from the 10-K PDFs before
    chunking and embedding.
    """
    # Strip any HTML tags, keeping only the visible text.
    plain = BeautifulSoup(text, "html.parser").get_text()
    # Collapse runs of spaces/newlines into single spaces and trim the ends.
    return re.sub(r'\s+', ' ', plain).strip()
# Load each PDF, clean every page's text in place, and collect all pages.
for pdf_path in pdf_files:
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()  # one Document per page, with source metadata
    for doc in documents:
        doc.page_content = preprocess_text(doc.page_content)
    all_documents.extend(documents)
"""### πŸ“Œ Step 2: Split Text into Chunks
<p> Here each split will also have a metadata defining the location of the chunk in the actual document for citation,also other details.As the pdf text is clean with no html tags etc , we use it as such with no cleaning
"""
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
all_splits = text_splitter.split_documents(all_documents)
"""### πŸ“Œ Step 3: Create Embeddings using Sentence Transformers"""
# Check if CUDA (GPU) is available; otherwise, use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={"device": device})
"""### πŸ“Œ Step 4: Store & Retrieve using ChromaDB"""
vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="content/drive/MyDrive/RAG_DB/chroma_db")
retriever = vectordb.as_retriever()
# Choose a smaller T5 model
model_name = "google/flan-t5-large"
# Load model & tokenizer; device_map="auto" places the model on GPU if present.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
# Create Hugging Face pipeline (truncation keeps long prompts within the
# model's input limit instead of erroring).
hf_pipeline = pipeline( "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    truncation=True)
# Integrate with LangChain so the pipeline can be used inside an LCEL chain.
llm = HuggingFacePipeline(pipeline=hf_pipeline)
"""### πŸ“Œ Step 6: Define RAG Prompt"""
# Define RAG Prompt
template = """You are an AI assistant answering financial questions using retrieved financial reports.
Use the following retrieved context to answer the question concisely.
Question: {question}
Context: {context}
Answer:
"""
prompt = ChatPromptTemplate.from_template(template)
"""### πŸ“Œ Step 7: Create RAG pipeline"""
# Create RAG Pipeline
conversation_chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
def to_camel_case(text):
    """Convert *text* to CamelCase using the inflection package.

    NOTE(review): with uppercase_first_letter=True, inflection.camelize
    produces UpperCamelCase (PascalCase), not lowerCamelCase as the original
    comment claimed. inflection.camelize also expects snake_case input;
    applied to free-form model output it mostly just capitalises the first
    letter — confirm this is the intended display formatting.
    """
    camel_text = inflection.camelize(text, uppercase_first_letter=True)
    return camel_text
"""### πŸ“Œ Step 8: Create a function to get the confidence score"""
# Function to Get Confidence Score
def get_confidence_score(question):
retrieved_docs_with_scores = vectordb.similarity_search_with_score(question, k=5)
max_score = max([doc[1] for doc in retrieved_docs_with_scores]) if retrieved_docs_with_scores else 0
return min(1.0, round(max_score, 2)) # Normalize to 0-1 scale
## GuardRail validators

# Define NLP Configuration with lang_code
# NOTE(review): this configuration is never used — AnalyzerEngine() below is
# created with its defaults. Wire it through NlpEngineProvider or remove it.
nlp_configuration = {
    "nlp_engine_name": "spacy",
    "models": [{"lang_code": "en", "model_name": "en_core_web_lg"}],  # Specify language
}

# Define SSN Pattern
ssn_regex = r"\b\d{3}-\d{2}-\d{4}\b"  # Matches US SSN format (123-45-6789)
ssn_pattern = Pattern(name="SSN Pattern", regex=ssn_regex, score=0.85)  # Score between 0-1

# Create Custom SSN Recognizer and register it on the default Presidio analyzer
# so "SSN" can be requested as an entity alongside the built-in recognizers.
ssn_recognizer = PatternRecognizer(supported_entity="SSN", patterns=[ssn_pattern])
analyzer = AnalyzerEngine()
analyzer.registry.add_recognizer(ssn_recognizer)
@register_validator(name="custom_pii_detector", data_type="string")
class CustomPIIDetector(Validator):
def validate(self, value, metadata={}) -> ValidationResult:
# Analyze text for PII
results = analyzer.analyze(text=value, entities=["PHONE_NUMBER", "EMAIL_ADDRESS", "CREDIT_CARD", "SSN"], language="en")
if results:
detected_entities = ", ".join(set([res.entity_type for res in results]))
return FailResult(
error_message=f"Query contains PII: {detected_entities}."
)
return PassResult()
# Custom Profanity Detector using better-profanity
@register_validator(name="custom_profanity_detector", data_type="string")
class CustomProfanityDetector(Validator):
    """Guardrails validator that rejects queries containing profanity."""

    def validate(self, value, metadata=None) -> ValidationResult:
        """Fail if better-profanity flags any word in *value*."""
        # Fixed mutable default argument (was metadata={}).
        metadata = metadata if metadata is not None else {}
        if profanity.contains_profanity(value):
            return FailResult(
                error_message="Query contains profanity."
            )
        return PassResult()
# Custom Relevance Validator for Finance and Apple-related Queries
@register_validator(name="custom_relevance_detector", data_type="string")
class CustomRelevanceDetector(Validator):
    """Guardrails validator that only allows finance- or Apple-related queries.

    A query passes when at least one keyword from either set appears
    (case-insensitively) as a substring of the input.
    """

    # Built once at class-creation time instead of on every validate() call.
    FINANCE_KEYWORDS = frozenset({"revenue", "profit", "expenses", "balance sheet", "earnings", "financial", "investment", "dividends", "assets", "liabilities", "cash flow", "loss", "turnover"})
    APPLE_KEYWORDS = frozenset({"apple", "iphone", "macbook", "tim cook", "apple inc", "ios", "mac", "ipad"})
    _ALL_KEYWORDS = FINANCE_KEYWORDS | APPLE_KEYWORDS

    def validate(self, value, metadata=None) -> ValidationResult:
        """Fail unless *value* mentions a finance- or Apple-related keyword."""
        # Fixed mutable default argument (was metadata={}).
        metadata = metadata if metadata is not None else {}
        text_lower = value.lower()
        # Check if any finance-related or Apple-related keyword appears in the query
        if not any(keyword in text_lower for keyword in self._ALL_KEYWORDS):
            return FailResult(
                error_message="Query is not related to finance or Apple."
            )
        return PassResult()
# Compose the three custom validators into a single Guard input pipeline.
guard = Guard().use(CustomPIIDetector).use(CustomProfanityDetector).use(CustomRelevanceDetector)
"""### πŸ“Œ Step 10: Integrate with Gradio UI"""
# Define Chatbot Function
def chat_with_rag(message, history):
# try:
# res = guard.validate(message)
# except Exception as e:
# return f"❌ Guardrail {str(e)}"
try:
response = conversation_chain.invoke(message)
confidence_score = get_confidence_score(message)
formatted_response = f"**Answer:** {to_camel_case(response)}\n\n**Confidence Score:** {confidence_score:.2f}"
return formatted_response
except Exception as e:
return f"Error: {str(e)}"
# A relevant financial question (high-confidence).
user_input = "what are the biggest challenges for Apple?"
confidence_score = get_confidence_score(user_input)
output = conversation_chain.invoke(user_input)
print(f"πŸ“Œ **Answer:** {to_camel_case(output)}\n\n**Confidence Score:** {confidence_score:.2f}")
# A relevant financial question (low-confidence).
user_input = "what was apple's Total revenue in 2023?"
confidence_score = get_confidence_score(user_input)
output = conversation_chain.invoke(user_input)
print(f"πŸ“Œ **Answer:** {to_camel_case(output)}\n\n**Confidence Score:** {confidence_score:.2f}")
# An irrelevant question (e.g., "What is the capital of France?") to check system robustness.
user_input = "What is the capital of France?"
output = conversation_chain.invoke(user_input)
confidence_score = get_confidence_score(user_input)
print(f"πŸ“Œ **Answer:** {to_camel_case(output)}\n\n**Confidence Score:** {confidence_score:.2f}")
# Create Gradio Chatbot UI with Auto-Clearing Input
demo = gr.ChatInterface(
    fn=chat_with_rag,  # Function to generate responses
    title="📊 Financial Basic RAG Chatbot",
    description="Ask questions about Apple's financial reports and get AI-powered answers!",
    theme="soft",
    # Clickable example queries shown under the chat box.
    examples=[
        ["What are factors impacting Apple's financial growth?"],
        ["what was apple's Total revenue in 2023?"],
        ["What is the capital of France?"],
    ],
    submit_btn="Ask",
    stop_btn=None,
)

# Launch the web app only when executed directly (not on import).
if __name__ == "__main__":
    demo.launch()