# Hugging Face Space (status: Running)
| import gradio as gr | |
| import os | |
| import tempfile | |
| import pandas as pd | |
| import boto3 | |
| from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, UnstructuredPowerPointLoader, UnstructuredExcelLoader, TextLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import OpenAIEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain_community.chat_models import BedrockChat | |
| from langchain_openai import ChatOpenAI | |
| from langchain.schema import Document | |
| from pathlib import Path | |
| from typing import List, Union | |
| import logging | |
# OCR is optional: both packages must import for the flag to be set.
# OCR_AVAILABLE is read later to decide whether a PDF without extractable
# text can be retried through Tesseract.
try:
    from pdf2image import convert_from_path
    import pytesseract
except ImportError:
    OCR_AVAILABLE = False
else:
    OCR_AVAILABLE = True

# Root logger configuration: INFO and above, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def get_api_keys():
    """Collect the AWS and OpenAI credentials from environment variables.

    Reads ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, ``AWS_REGION``
    (falling back to ``"us-east-1"`` when unset) and ``OPENAI_API_KEY``.

    Returns:
        A dict whose ``"status"`` is ``"success"`` (plus the keys
        ``aws_access_key``, ``aws_secret_key``, ``aws_region`` and
        ``openai_key``) when all three required secrets are non-empty, or
        ``"error"`` (plus a human-readable ``"message"``) otherwise.
    """
    env = os.environ
    credentials = {
        "aws_access_key": env.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_key": env.get("AWS_SECRET_ACCESS_KEY"),
        # The region is the only optional setting; it has a default.
        "aws_region": env.get("AWS_REGION", "us-east-1"),
        "openai_key": env.get("OPENAI_API_KEY"),
    }

    required = ("aws_access_key", "aws_secret_key", "openai_key")
    if all(credentials[field] for field in required):
        return {"status": "success", **credentials}

    return {
        "status": "error",
        "message": "Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and OPENAI_API_KEY in your Hugging Face Space secrets.",
    }
class AuditAgent:
    """Audit assistant backed by one chat model plus a document vector store.

    An instance wraps a single LLM (AWS Bedrock or OpenAI), an OpenAI
    embeddings client, and a FAISS store that is created on the first
    successful ``process_documents`` call and searched by ``query_documents``.
    """

    # Extensions that ``load_document`` has a loader for.
    SUPPORTED_EXTENSIONS = ('.pdf', '.docx', '.pptx', '.xlsx', '.xls', '.txt')

    # Used only when a Bedrock agent is created without a model id.
    DEFAULT_BEDROCK_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"

    SYSTEM_PROMPT = (
        "You are an expert auditor assistant. Provide clear, detailed responses to audit-related queries.\n"
        "For numerical problems, show your calculations step by step. "
        "Always consider relevant accounting standards and auditing principles."
    )

    def __init__(self, model_name, provider):
        """Create the text splitter, embeddings client and chat model.

        Args:
            model_name: Provider-specific model identifier (a Bedrock model id
                or an OpenAI model name).
            provider: Either ``"bedrock"`` or ``"openai"``.

        Raises:
            ValueError: If required API keys are missing, the provider is
                unknown, or the Bedrock client/model cannot be created.
        """
        self.model_name = model_name
        self.provider = provider
        self.document_store = None
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )

        api_keys = get_api_keys()
        if api_keys["status"] == "error":
            raise ValueError(api_keys["message"])

        # Embeddings always come from OpenAI, whichever chat provider is used.
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_keys["openai_key"])

        if provider == "bedrock":
            try:
                self.bedrock_client = boto3.client(
                    service_name="bedrock-runtime",
                    aws_access_key_id=api_keys["aws_access_key"],
                    aws_secret_access_key=api_keys["aws_secret_key"],
                    region_name=api_keys["aws_region"]
                )
                self.llm = BedrockChat(
                    client=self.bedrock_client,
                    # Honour the requested model instead of a hard-coded id.
                    model_id=model_name or self.DEFAULT_BEDROCK_MODEL_ID,
                    model_kwargs={"temperature": 0.2}
                )
            except Exception as e:
                logging.error(f"Bedrock initialization error: {str(e)}")
                raise ValueError(f"Bedrock initialization error: {str(e)}") from e
        elif provider == "openai":
            llm_kwargs = {
                "model_name": model_name,
                "openai_api_key": api_keys["openai_key"],
            }
            # OpenAI's o-series reasoning models reject a custom temperature,
            # so it is only sent to the conventional chat models.
            if not self._is_reasoning_model(model_name):
                llm_kwargs["temperature"] = 0.2
            self.llm = ChatOpenAI(**llm_kwargs)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    @staticmethod
    def _is_reasoning_model(model_name):
        """Return True for OpenAI o-series model names (o1*, o3*, o4*)."""
        return str(model_name).startswith(("o1", "o3", "o4"))

    @staticmethod
    def _label_source(documents, file_path):
        """Overwrite each chunk's ``source`` metadata with the upload's file name.

        Loaders run on a temporary copy, so without this the recorded source
        would be a temp path that no longer exists.
        """
        display_name = os.path.basename(file_path)
        for doc in documents:
            doc.metadata["source"] = display_name

    def process_query(self, query):
        """Answer a general query or numerical problem with the chat model.

        Args:
            query: The user's question.

        Returns:
            The model's answer, a prompt to enter a query when ``query`` is
            empty/``None``, or an ``"Error processing query: ..."`` string if
            the model call fails (exceptions are not propagated).
        """
        if not query or not query.strip():
            return "Please provide a non-empty query."

        # Both providers receive the same role-tagged messages.
        messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        try:
            if self.provider not in ("bedrock", "openai"):
                raise ValueError(f"Unsupported provider: {self.provider}")
            response = self.llm.invoke(messages)
            return response.content if hasattr(response, 'content') else str(response)
        except Exception as e:
            return f"Error processing query: {str(e)}"

    def process_documents(self, file_paths):
        """Load, chunk and index several files.

        Args:
            file_paths: Iterable of file path strings.

        Returns:
            A dict mapping each input path to a status string:
            ``"Success (N chunks extracted)"``, ``"No content could be
            extracted"``, ``"Unsupported file type: <ext>"`` or the text of
            the exception raised while processing that file.
        """
        results = {}
        for file_path in file_paths:
            try:
                file_ext = os.path.splitext(file_path.lower())[1]
                if file_ext not in self.SUPPORTED_EXTENSIONS:
                    results[file_path] = f"Unsupported file type: {file_ext}"
                    continue

                with open(file_path, 'rb') as f:
                    content = f.read()

                documents = self.process_document(content, file_ext)
                if not documents:
                    results[file_path] = "No content could be extracted"
                    continue

                # Record the uploaded file's name rather than the temp copy's.
                self._label_source(documents, file_path)

                if self.document_store is None:
                    self.document_store = FAISS.from_documents(documents, self.embeddings)
                else:
                    self.document_store.add_documents(documents)

                results[file_path] = f"Success ({len(documents)} chunks extracted)"
            except Exception as e:
                # One bad file must not abort the rest of the batch.
                logging.error(f"Error processing document {file_path}: {str(e)}")
                results[file_path] = str(e)
        return results

    def process_document(self, content, doc_type):
        """Write raw bytes to a temp file, load it and return the split chunks.

        Args:
            content: File contents as bytes.
            doc_type: File extension including the dot (used as the temp
                file's suffix so the right loader is selected).

        Returns:
            The chunks produced by ``split_documents``.
        """
        # delete=False so the file can be reopened by the loader after the
        # handle is closed; it is removed in the ``finally`` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=doc_type) as temp_file:
            temp_file.write(content)
            temp_file_path = temp_file.name
        try:
            documents = self.load_document(temp_file_path)
            return self.split_documents(documents)
        finally:
            if os.path.exists(temp_file_path):
                os.unlink(temp_file_path)

    def load_document(self, file_path):
        """Load a file with the loader matching its extension.

        PDFs with no extractable text fall back to OCR when it is available.

        Raises:
            ValueError: For unsupported extensions, for PDFs that cannot be
                read when OCR is unavailable, and for Word documents that
                fail to load or contain no text.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == '.pdf':
            return self._load_pdf(file_path)
        if suffix == '.docx':
            return self._load_docx(file_path)
        if suffix == '.pptx':
            return UnstructuredPowerPointLoader(str(file_path)).load()
        if suffix in ('.xlsx', '.xls'):
            return UnstructuredExcelLoader(str(file_path)).load()
        if suffix == '.txt':
            return TextLoader(str(file_path)).load()
        raise ValueError(f"Unsupported file type: {suffix}")

    def _load_pdf(self, file_path):
        """Extract PDF text with PyPDFLoader, retrying with OCR on failure."""
        try:
            documents = PyPDFLoader(str(file_path)).load()
            if not any(doc.page_content.strip() for doc in documents):
                # Handled by the except below, which triggers the OCR retry.
                raise ValueError("No text content found")
            return documents
        except Exception as e:
            logging.warning(f"Standard PDF extraction failed: {str(e)}")
            if not OCR_AVAILABLE:
                raise ValueError("PDF extraction failed and OCR is not available") from e
            logging.info("Attempting PDF extraction with OCR")
            return self._process_pdf_with_ocr(file_path)

    def _load_docx(self, file_path):
        """Load a Word document, raising ValueError if nothing is extracted."""
        try:
            documents = Docx2txtLoader(str(file_path)).load()
            if not documents or not any(doc.page_content.strip() for doc in documents):
                raise ValueError("No content extracted from Word document")
            return documents
        except Exception as e:
            logging.error(f"Word document processing error: {str(e)}")
            raise ValueError(f"Failed to process Word document: {str(e)}") from e

    def _process_pdf_with_ocr(self, file_path):
        """Render each PDF page to an image and OCR it with Tesseract.

        Returns:
            One ``Document`` per page that produced non-blank text; the
            ``page`` metadata is 1-based.

        Raises:
            ImportError: If pdf2image/pytesseract are not installed.
        """
        if not OCR_AVAILABLE:
            raise ImportError("pdf2image and pytesseract required for OCR processing")

        documents = []
        for page_index, image in enumerate(convert_from_path(str(file_path))):
            text = pytesseract.image_to_string(image)
            if text.strip():
                documents.append(Document(
                    page_content=text,
                    metadata={"source": str(file_path), "page": page_index + 1}
                ))
        return documents

    def split_documents(self, documents):
        """Split documents into chunks using the configured text splitter."""
        return self.text_splitter.split_documents(documents)

    def query_documents(self, query):
        """Answer a question from the indexed documents and list the sources.

        Args:
            query: The user's question.

        Returns:
            The answer followed by a ``**Sources:**`` list when source
            documents are returned, a guidance message when no documents are
            indexed or the query is empty/``None``, or an
            ``"Error querying documents: ..."`` string on failure.
        """
        if self.document_store is None:
            return "Please upload and process documents first"
        if not query or not query.strip():
            return "Please provide a non-empty query."

        try:
            qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.document_store.as_retriever(),
                return_source_documents=True
            )
            response = qa_chain.invoke({"query": query})
            result = response['result']
            source_docs = response.get('source_documents', [])
            if source_docs:
                citations = "".join(
                    f"{i}. {doc.metadata.get('source', 'Unknown source')}, "
                    f"page {doc.metadata.get('page', 'N/A')}\n"
                    for i, doc in enumerate(source_docs, 1)
                )
                result += "\n\n**Sources:**\n" + citations
            return result
        except Exception as e:
            return f"Error querying documents: {str(e)}"
# Selectable models, keyed by the id shown in the UI. Each entry holds the
# provider-specific model name, the provider ("bedrock" or "openai") and a
# short description for the tab label. Insertion order matters: the model
# tabs are created in this order and a selected tab is mapped back to a key
# by its position.
llm_configs = {
    model_id: {"name": name, "provider": provider, "description": description}
    for model_id, name, provider, description in (
        (
            "claude-3-sonnet",
            "anthropic.claude-3-sonnet-20240229-v1:0",
            "bedrock",
            "Balanced performance (AWS Bedrock)",
        ),
        ("gpt-4", "gpt-4", "openai", "Advanced reasoning"),
        ("gpt-3.5-turbo", "gpt-3.5-turbo", "openai", "Fast responses"),
        ("o3-mini", "o3-mini", "openai", "Compact OpenAI model"),
    )
}
def create_interface():
    """Build the Gradio UI for the audit copilot.

    When the required API keys are missing, a small page describing how to
    configure them is returned instead of the full application.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    api_keys = get_api_keys()
    if api_keys["status"] == "error":
        with gr.Blocks(theme=gr.themes.Base()) as demo:
            gr.Markdown("# ⚠️ Configuration Error")
            gr.Markdown(api_keys["message"])
            gr.Markdown("""
            To set up your Hugging Face Space:
            1. Go to your Space's Settings
            2. Add your API keys as secrets:
               - AWS_ACCESS_KEY_ID
               - AWS_SECRET_ACCESS_KEY
               - AWS_REGION
               - OPENAI_API_KEY
            3. Restart your Space
            """)
        return demo

    # Agents are created lazily, one per model id, and cached here.
    # NOTE(review): this dict lives in the closure, so every handler call
    # shares it, and each model's agent keeps its own document store.
    # Documents processed under one model tab are not visible to another.
    audit_agents = {}

    # The first tab is the one shown initially, so it is the default model.
    default_model = next(iter(llm_configs), "claude-3-sonnet")

    with gr.Blocks(theme=gr.themes.Base()) as demo:
        gr.Markdown("# 🔍 Amy - Your Audit Copilot")
        status_message = gr.Textbox(label="Status", value="Ready")

        gr.Markdown("## 📑 Document Processing")
        with gr.Row():
            file_upload = gr.File(
                file_count="multiple",
                label="Upload Audit Documents (PDF, DOCX, PPTX, TXT, XLSX)",
                type="filepath"
            )
            upload_button = gr.Button("Process Documents")
        upload_output = gr.Textbox(label="Processing Status", lines=10)

        # The model tabs are empty; they only act as a model selector.
        with gr.Tabs() as model_tabs:
            for model_id, config in llm_configs.items():
                with gr.Tab(f"{model_id} - {config['description']}"):
                    pass

        with gr.Tabs():
            with gr.Tab("💬 Conversation"):
                chat_history = gr.Chatbot(height=400)
                chat_input = gr.Textbox(
                    lines=3,
                    label="Ask your audit question",
                    placeholder="Enter your question here..."
                )
                chat_clear = gr.Button("Clear Chat")
                chat_button = gr.Button("Send")
            with gr.Tab("🔢 Numerical Problem"):
                problem_input = gr.Textbox(
                    lines=5,
                    label="Describe the Problem",
                    placeholder="Enter your numerical audit problem..."
                )
                solve_button = gr.Button("Solve")
                solution_output = gr.Markdown(label="Solution")
            with gr.Tab("🔍 Document Query"):
                query_input = gr.Textbox(
                    lines=3,
                    label="Query Documents",
                    placeholder="Ask about your uploaded documents..."
                )
                query_button = gr.Button("Query")
                query_output = gr.Markdown(label="Response")

        selected_model = gr.State(default_model)

        def update_selected_model(evt: gr.SelectData):
            """Map the selected tab's position to its model id."""
            model_ids = list(llm_configs.keys())
            if evt.index < len(model_ids):
                return model_ids[evt.index]
            return default_model

        model_tabs.select(update_selected_model, outputs=[selected_model])

        def get_or_initialize_agent(model_name):
            """Return ``(agent, status)``; ``agent`` is None if creation failed."""
            if model_name in audit_agents:
                return audit_agents[model_name], f"{model_name} ready"
            try:
                config = llm_configs[model_name]
                logging.info(f"Initializing {model_name}...")
                agent = AuditAgent(config["name"], config["provider"])
                audit_agents[model_name] = agent
                success_message = f"{model_name} initialized successfully"
                logging.info(success_message)
                return agent, success_message
            except Exception as e:
                error_message = f"Error initializing {model_name}: {str(e)}"
                logging.error(error_message)
                return None, error_message

        def respond_to_chat(message, history, model_name):
            """Append the model's reply to the chat.

            Always returns three values (cleared input, history, status) to
            match the three output components wired to this handler.
            """
            # Work on a copy and tolerate an unset chatbot value.
            history = list(history or [])
            if not message or not message.strip():
                return "", history, "No message entered"

            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                history.append((message, f"Could not initialize {model_name}. {init_status}"))
                return "", history, f"Error: {init_status}"

            try:
                result = agent.process_query(message)
                history.append((message, result))
                return "", history, f"Response from {model_name}"
            except Exception as e:
                error_msg = f"Error: {str(e)}"
                history.append((message, error_msg))
                return "", history, error_msg

        def clear_chat_history():
            """Reset the chat window."""
            return [], "Chat history cleared"

        def handle_problem(problem, model_name):
            """Solve a numerical problem; returns ``(solution, status)``."""
            if not problem or not problem.strip():
                return "Please provide a problem description", "No problem entered"

            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                return f"Could not initialize {model_name}. {init_status}", init_status

            try:
                result = agent.process_query(problem)
                return result, f"Problem solved with {model_name}"
            except Exception as e:
                error_msg = f"Error solving problem: {str(e)}"
                return error_msg, error_msg

        def handle_file_upload(file_paths, model_name):
            """Index the uploaded files and return a per-file report."""
            if not file_paths:
                return "No files uploaded. Please upload files."

            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                return init_status

            logging.info(f"Processing {len(file_paths)} files")
            try:
                results = agent.process_documents(file_paths)
                output_lines = ["## Document Processing Results"]
                any_success = False
                for file_path, status in results.items():
                    file_name = os.path.basename(file_path)
                    # Success statuses start with "Success"; a substring test
                    # would also match error text that merely contains it.
                    if status.startswith("Success"):
                        any_success = True
                        output_lines.append(f"✓ {file_name}: {status}")
                    else:
                        output_lines.append(f"❌ {file_name}: {status}")
                if any_success:
                    output_lines.append("\n✅ Documents are ready for querying!")
                return "\n".join(output_lines)
            except Exception as e:
                logging.error(f"File upload error: {str(e)}")
                return f"Error processing files: {str(e)}"

        def handle_query(query, model_name):
            """Query the indexed documents; returns ``(answer, status)``."""
            if not query or not query.strip():
                return "Please provide a query", "No query entered"

            agent, init_status = get_or_initialize_agent(model_name)
            if agent is None:
                return f"Could not initialize {model_name}. {init_status}", init_status

            try:
                result = agent.query_documents(query)
                return result, f"Documents queried with {model_name}"
            except Exception as e:
                error_msg = f"Error querying documents: {str(e)}"
                return error_msg, error_msg

        # Event wiring: each handler returns one value per listed output.
        chat_button.click(
            respond_to_chat,
            inputs=[chat_input, chat_history, selected_model],
            outputs=[chat_input, chat_history, status_message]
        )
        chat_clear.click(
            clear_chat_history,
            outputs=[chat_history, status_message]
        )
        solve_button.click(
            handle_problem,
            inputs=[problem_input, selected_model],
            outputs=[solution_output, status_message]
        )
        upload_button.click(
            handle_file_upload,
            inputs=[file_upload, selected_model],
            outputs=[upload_output]
        )
        query_button.click(
            handle_query,
            inputs=[query_input, selected_model],
            outputs=[query_output, status_message]
        )
    return demo
# Script entry point: build the UI, then serve it.
if __name__ == "__main__":
    demo = create_interface()
    # share=True additionally asks Gradio for a public share link.
    launch_options = {"share": True}
    demo.launch(**launch_options)