# app.py — Audit copilot Gradio app (Hugging Face Space, commit 22f579b by prernajeet01)
import gradio as gr
import os
import tempfile
import pandas as pd
import boto3
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, UnstructuredPowerPointLoader, UnstructuredExcelLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.chat_models import BedrockChat
from langchain_openai import ChatOpenAI
from langchain.schema import Document
from pathlib import Path
from typing import List, Union
import logging
# Optional OCR support: pdf2image renders PDF pages to images and
# pytesseract runs Tesseract OCR on them. Both are needed for the
# scanned-PDF fallback path; if either is missing, OCR is disabled
# and PDFs with no extractable text will fail with an error instead.
try:
    from pdf2image import convert_from_path
    import pytesseract
    OCR_AVAILABLE = True  # both OCR dependencies importable
except ImportError:
    OCR_AVAILABLE = False  # OCR fallback disabled
# Set up logging
# Root-logger configuration: INFO level with timestamped messages,
# which Hugging Face Spaces surfaces in the container log viewer.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
def get_api_keys():
    """Get API keys from Hugging Face Spaces secrets.

    Reads AWS and OpenAI credentials from environment variables.
    AWS_REGION falls back to "us-east-1" when unset.

    Returns:
        dict: ``{"status": "error", "message": ...}`` when any required
        key is missing/empty, otherwise ``{"status": "success", ...}``
        with the individual credential values.
    """
    keys = {
        "aws_access_key": os.environ.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
        "aws_region": os.environ.get("AWS_REGION", "us-east-1"),  # default region
        "openai_key": os.environ.get("OPENAI_API_KEY"),
    }
    # The region is optional (it has a default); everything else must be
    # present and non-empty.
    required = ("aws_access_key", "aws_secret_key", "openai_key")
    if not all(keys[name] for name in required):
        return {
            "status": "error",
            "message": "Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and OPENAI_API_KEY in your Hugging Face Space secrets."
        }
    return {"status": "success", **keys}
class AuditAgent:
    """LLM-backed audit assistant with retrieval-augmented document Q&A.

    Answers free-form audit questions through either AWS Bedrock
    (Claude 3 Sonnet) or OpenAI chat models, and supports uploading
    documents (PDF/DOCX/PPTX/XLSX/XLS/TXT) that are chunked, embedded
    with OpenAI embeddings, and indexed in an in-memory FAISS store for
    later querying via a RetrievalQA chain.
    """

    def __init__(self, model_name, provider):
        """Initialize the agent for a given model/provider pair.

        Args:
            model_name: Provider-specific model identifier
                (e.g. "gpt-4" for OpenAI).
            provider: Either "bedrock" or "openai".

        Raises:
            ValueError: If required API keys are missing, the Bedrock
                client cannot be created, or the provider is unknown.
        """
        self.model_name = model_name
        self.provider = provider
        # FAISS vector store; created lazily on the first successful
        # document upload (see process_documents).
        self.document_store = None
        # 1000-char chunks with 200-char overlap so content that
        # straddles a chunk boundary is still retrievable.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        # Fail fast if secrets are not configured.
        api_keys = get_api_keys()
        if api_keys["status"] == "error":
            raise ValueError(api_keys["message"])
        # OpenAI embeddings are used for the vector store regardless of
        # which chat provider answers the queries.
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_keys["openai_key"])
        if provider == "bedrock":
            try:
                self.bedrock_client = boto3.client(
                    service_name="bedrock-runtime",
                    aws_access_key_id=api_keys["aws_access_key"],
                    aws_secret_access_key=api_keys["aws_secret_key"],
                    region_name=api_keys["aws_region"]
                )
                # BedrockChat exposes the same chat interface as
                # ChatOpenAI, so downstream code is provider-agnostic.
                self.llm = BedrockChat(
                    client=self.bedrock_client,
                    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
                    model_kwargs={"temperature": 0.2}
                )
            except Exception as e:
                logging.error(f"Bedrock initialization error: {str(e)}")
                raise ValueError(f"Bedrock initialization error: {str(e)}")
        elif provider == "openai":
            self.llm = ChatOpenAI(
                model_name=model_name,
                openai_api_key=api_keys["openai_key"],
                temperature=0.2
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def process_query(self, query):
        """Answer a general audit query or numerical problem.

        Args:
            query: The user's question; blank input is rejected.

        Returns:
            str: The model's answer, or an error message on failure.
        """
        if not query.strip():
            return "Please provide a non-empty query."
        system_prompt = """You are an expert auditor assistant. Provide clear, detailed responses to audit-related queries.
For numerical problems, show your calculations step by step. Always consider relevant accounting standards and auditing principles."""
        try:
            if self.provider == "bedrock":
                # Bedrock path: send a single concatenated prompt string.
                response = self.llm.invoke(
                    f"{system_prompt}\n\nUser: {query}\nAssistant:"
                )
                # BedrockChat normally returns a message object; fall
                # back to str() for any other response shape.
                return response.content if hasattr(response, 'content') else str(response)
            elif self.provider == "openai":
                # OpenAI path: use structured chat messages.
                response = self.llm.invoke(
                    [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": query}
                    ]
                )
                return response.content
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        except Exception as e:
            return f"Error processing query: {str(e)}"

    def process_documents(self, file_paths):
        """Process multiple documents and index them for querying.

        Args:
            file_paths: Iterable of paths to uploaded files.

        Returns:
            dict: Maps each input path to a human-readable status string
            ("Success (...)", "Unsupported file type: ...", or an error).
        """
        results = {}
        for file_path in file_paths:
            try:
                # Lowercase before splitext so ".PDF" matches ".pdf".
                file_ext = os.path.splitext(file_path.lower())[1]
                supported_exts = ['.pdf', '.docx', '.pptx', '.xlsx', '.xls', '.txt']
                if file_ext not in supported_exts:
                    results[file_path] = f"Unsupported file type: {file_ext}"
                    continue
                with open(file_path, 'rb') as f:
                    content = f.read()
                # Extract text and split into chunks.
                documents = self.process_document(content, file_ext)
                if documents:
                    if not self.document_store:
                        # First upload: create the FAISS index.
                        self.document_store = FAISS.from_documents(documents, self.embeddings)
                    else:
                        # Subsequent uploads: extend the existing index.
                        self.document_store.add_documents(documents)
                    num_chunks = len(documents)
                    results[file_path] = f"Success ({num_chunks} chunks extracted)"
                else:
                    results[file_path] = "No content could be extracted"
            except Exception as e:
                # A failure on one file must not abort the whole batch.
                logging.error(f"Error processing document {file_path}: {str(e)}")
                results[file_path] = str(e)
        return results

    def process_document(self, content, doc_type):
        """Extract and chunk text from raw document bytes.

        Args:
            content: Raw file bytes.
            doc_type: File extension including the dot (e.g. ".pdf");
                used as the temp-file suffix so loaders pick the format.

        Returns:
            list: Chunked langchain Documents (may be empty).
        """
        # The loaders require a file path, so spill the bytes to a
        # temporary file and always clean it up afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix=doc_type) as temp_file:
            temp_file.write(content)
            temp_file_path = temp_file.name
        try:
            documents = self.load_document(temp_file_path)
            return self.split_documents(documents)
        finally:
            if os.path.exists(temp_file_path):
                os.unlink(temp_file_path)

    def load_document(self, file_path):
        """Load a document with the loader matching its extension.

        PDFs fall back to OCR (if available) when standard extraction
        yields no text — e.g. for scanned documents.

        Raises:
            ValueError: On extraction failure or unsupported extension.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix == '.pdf':
            try:
                loader = PyPDFLoader(str(file_path))
                documents = loader.load()
                # A scanned PDF loads "successfully" but with empty
                # pages; treat that as a failure so OCR can kick in.
                if not any(doc.page_content.strip() for doc in documents):
                    raise ValueError("No text content found")
                return documents
            except Exception as e:
                logging.warning(f"Standard PDF extraction failed: {str(e)}")
                if OCR_AVAILABLE:
                    logging.info("Attempting PDF extraction with OCR")
                    return self._process_pdf_with_ocr(file_path)
                else:
                    raise ValueError("PDF extraction failed and OCR is not available")
        elif suffix == '.docx':
            try:
                loader = Docx2txtLoader(str(file_path))
                documents = loader.load()
                # Verify the loader actually produced text.
                if not documents or not any(doc.page_content.strip() for doc in documents):
                    raise ValueError("No content extracted from Word document")
                return documents
            except Exception as e:
                logging.error(f"Word document processing error: {str(e)}")
                raise ValueError(f"Failed to process Word document: {str(e)}")
        elif suffix == '.pptx':
            loader = UnstructuredPowerPointLoader(str(file_path))
            return loader.load()
        elif suffix in ['.xlsx', '.xls']:
            loader = UnstructuredExcelLoader(str(file_path))
            return loader.load()
        elif suffix == '.txt':
            loader = TextLoader(str(file_path))
            return loader.load()
        else:
            raise ValueError(f"Unsupported file type: {suffix}")

    def _process_pdf_with_ocr(self, file_path):
        """OCR every page of a PDF with Tesseract.

        Returns:
            list: One Document per page that produced non-blank text,
            with source path and 1-based page number in metadata.
        """
        if not OCR_AVAILABLE:
            raise ImportError("pdf2image and pytesseract required for OCR processing")
        documents = []
        images = convert_from_path(str(file_path))
        for i, image in enumerate(images):
            text = pytesseract.image_to_string(image)
            if text.strip():
                documents.append(Document(
                    page_content=text,
                    metadata={"source": str(file_path), "page": i + 1}
                ))
        return documents

    def split_documents(self, documents):
        """Split documents into overlapping chunks for embedding."""
        return self.text_splitter.split_documents(documents)

    def query_documents(self, query):
        """Answer a question against the indexed documents.

        Returns:
            str: The answer with a "Sources" section appended, or a
            guidance/error message.
        """
        if not self.document_store:
            return "Please upload and process documents first"
        if not query.strip():
            return "Please provide a non-empty query."
        try:
            qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.document_store.as_retriever(),
                return_source_documents=True
            )
            # FIX: Chain.__call__ (qa_chain({...})) is deprecated in
            # LangChain; .invoke() is the supported entry point and
            # returns the same result dict.
            response = qa_chain.invoke({"query": query})
            result = response['result']
            source_docs = response.get('source_documents', [])
            if source_docs:
                # Append a citation list so the auditor can verify.
                result += "\n\n**Sources:**\n"
                for i, doc in enumerate(source_docs, 1):
                    result += f"{i}. {doc.metadata.get('source', 'Unknown source')}, page {doc.metadata.get('page', 'N/A')}\n"
            return result
        except Exception as e:
            return f"Error querying documents: {str(e)}"
# Updated LLM configurations - replaced openorca-mini with o3-mini
# Registry of selectable models. Each entry maps a UI-facing model id to:
#   name        - the provider-specific model identifier passed to the SDK
#   provider    - which backend constructs the LLM ("bedrock" or "openai")
#   description - short label shown on the model-selection tabs
llm_configs = {
    "claude-3-sonnet": {
        "name": "anthropic.claude-3-sonnet-20240229-v1:0",
        "provider": "bedrock",
        "description": "Balanced performance (AWS Bedrock)"
    },
    "gpt-4": {
        "name": "gpt-4",
        "provider": "openai",
        "description": "Advanced reasoning"
    },
    "gpt-3.5-turbo": {
        "name": "gpt-3.5-turbo",
        "provider": "openai",
        "description": "Fast responses"
    },
    "o3-mini": {
        "name": "o3-mini",
        "provider": "openai",
        "description": "Compact OpenAI model"
    }
}
def create_interface():
    """Build the Gradio Blocks UI for the audit copilot.

    Returns:
        gr.Blocks: The full app, or a minimal configuration-error page
        when the required API-key secrets are missing.
    """
    # Fail fast with a setup page when secrets are not configured.
    api_keys = get_api_keys()
    if api_keys["status"] == "error":
        with gr.Blocks(theme=gr.themes.Base()) as demo:
            gr.Markdown("# ⚠️ Configuration Error")
            gr.Markdown(api_keys["message"])
            gr.Markdown("""
To set up your Hugging Face Space:
1. Go to your Space's Settings
2. Add your API keys as secrets:
- AWS_ACCESS_KEY_ID
- AWS_SECRET_ACCESS_KEY
- AWS_REGION
- OPENAI_API_KEY
3. Restart your Space
""")
        return demo
    # Agents are created lazily per model (see get_or_initialize_agent)
    # so startup stays fast and unused providers are never contacted.
    audit_agents = {}
    with gr.Blocks(theme=gr.themes.Base()) as demo:
        gr.Markdown("# 🔍 Amy - Your Audit Copilot")
        # Shared status line updated by every handler.
        status_message = gr.Textbox(label="Status", value="Ready")
        # Document processing section.
        gr.Markdown("## 📑 Document Processing")
        with gr.Row():
            file_upload = gr.File(
                file_count="multiple",
                label="Upload Audit Documents (PDF, DOCX, PPTX, TXT, XLSX)",
                type="filepath"
            )
            upload_button = gr.Button("Process Documents")
        upload_output = gr.Textbox(label="Processing Status", lines=10)
        # One tab per model; the selected tab's index drives model choice
        # (see update_selected_model). The tabs carry no content.
        with gr.Tabs() as model_tabs:
            for model_id, config in llm_configs.items():
                with gr.Tab(f"{model_id} - {config['description']}"):
                    pass
        with gr.Tabs():
            # Chat interface with history.
            with gr.Tab("💬 Conversation"):
                chat_history = gr.Chatbot(height=400)
                chat_input = gr.Textbox(
                    lines=3,
                    label="Ask your audit question",
                    placeholder="Enter your question here..."
                )
                chat_clear = gr.Button("Clear Chat")
                chat_button = gr.Button("Send")
            with gr.Tab("🔢 Numerical Problem"):
                problem_input = gr.Textbox(
                    lines=5,
                    label="Describe the Problem",
                    placeholder="Enter your numerical audit problem..."
                )
                solve_button = gr.Button("Solve")
                solution_output = gr.Markdown(label="Solution")
            # Document query tab.
            with gr.Tab("🔍 Document Query"):
                query_input = gr.Textbox(
                    lines=3,
                    label="Query Documents",
                    placeholder="Ask about your uploaded documents..."
                )
                query_button = gr.Button("Query")
                query_output = gr.Markdown(label="Response")
        # Currently selected model id, tracked across handlers.
        selected_model = gr.State("claude-3-sonnet")

        def update_selected_model(evt: gr.SelectData):
            """Map the selected tab index back to a model id."""
            model_ids = list(llm_configs.keys())
            if evt.index < len(model_ids):
                return model_ids[evt.index]
            return "claude-3-sonnet"  # Default

        model_tabs.select(update_selected_model, outputs=[selected_model])

        def get_or_initialize_agent(model_name):
            """Return (agent, status message); agent is None on failure.

            Caches one AuditAgent per model id in audit_agents.
            """
            if model_name in audit_agents:
                return audit_agents[model_name], f"{model_name} ready"
            try:
                config = llm_configs[model_name]
                logging.info(f"Initializing {model_name}...")
                agent = AuditAgent(config["name"], config["provider"])
                audit_agents[model_name] = agent
                success_message = f"{model_name} initialized successfully"
                logging.info(success_message)
                return agent, success_message
            except Exception as e:
                error_message = f"Error initializing {model_name}: {str(e)}"
                logging.error(error_message)
                return None, error_message

        def respond_to_chat(message, history, model_name):
            """Chat handler: returns (cleared input, history, status)."""
            if not message.strip():
                # BUG FIX: this branch previously returned only two
                # values while the click handler declares three outputs
                # (chat_input, chat_history, status_message), which
                # raises a Gradio error on empty input.
                return "", history, "Please enter a message"
            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                history.append((message, f"Could not initialize {model_name}. {init_status}"))
                return "", history, f"Error: {init_status}"
            try:
                result = agent.process_query(message)
                history.append((message, result))
                return "", history, f"Response from {model_name}"
            except Exception as e:
                error_msg = f"Error: {str(e)}"
                history.append((message, error_msg))
                return "", history, error_msg

        def clear_chat_history():
            """Reset the chatbot and report it in the status line."""
            return [], "Chat history cleared"

        def handle_problem(problem, model_name):
            """Solve a numerical problem: returns (solution, status)."""
            if not problem.strip():
                return "Please provide a problem description", "No problem entered"
            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                return f"Could not initialize {model_name}. {init_status}", init_status
            try:
                result = agent.process_query(problem)
                return result, f"Problem solved with {model_name}"
            except Exception as e:
                error_msg = f"Error solving problem: {str(e)}"
                return error_msg, error_msg

        def handle_file_upload(file_paths, model_name):
            """Index uploaded files and return a markdown status report."""
            if not file_paths:
                return "No files uploaded. Please upload files."
            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                return init_status
            logging.info(f"Processing {len(file_paths)} files")
            try:
                results = agent.process_documents(file_paths)
                # One line per file, marked ✓ / ❌ by outcome.
                output_lines = ["## Document Processing Results"]
                for file_path, status in results.items():
                    file_name = os.path.basename(file_path)
                    if "Success" in status:
                        output_lines.append(f"✓ {file_name}: {status}")
                    else:
                        output_lines.append(f"❌ {file_name}: {status}")
                if any("Success" in status for status in results.values()):
                    output_lines.append("\n✅ Documents are ready for querying!")
                return "\n".join(output_lines)
            except Exception as e:
                logging.error(f"File upload error: {str(e)}")
                return f"Error processing files: {str(e)}"

        def handle_query(query, model_name):
            """Query indexed documents: returns (answer, status)."""
            if not query.strip():
                return "Please provide a query", "No query entered"
            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                return f"Could not initialize {model_name}. {init_status}", init_status
            try:
                result = agent.query_documents(query)
                return result, f"Documents queried with {model_name}"
            except Exception as e:
                error_msg = f"Error querying documents: {str(e)}"
                return error_msg, error_msg

        # Wire UI events to their handlers.
        chat_button.click(
            respond_to_chat,
            inputs=[chat_input, chat_history, selected_model],
            outputs=[chat_input, chat_history, status_message]
        )
        chat_clear.click(
            clear_chat_history,
            outputs=[chat_history, status_message]
        )
        solve_button.click(
            handle_problem,
            inputs=[problem_input, selected_model],
            outputs=[solution_output, status_message]
        )
        upload_button.click(
            handle_file_upload,
            inputs=[file_upload, selected_model],
            outputs=[upload_output]
        )
        query_button.click(
            handle_query,
            inputs=[query_input, selected_model],
            outputs=[query_output, status_message]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it with a public Gradio share link.
    app = create_interface()
    app.launch(share=True)