Spaces:
Build error
Build error
import getpass
import os
import tempfile
import warnings
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Set, Union

import gradio as gr
import numpy as np
import pytesseract
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pdf2image import convert_from_path

openai_api_key = os.getenv("OPENAI_API_KEY")

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
class RiskLevel:
    """String labels for risk severity."""

    # Declared from the lowest severity label to the highest.
    LOW = "Low"
    MEDIUM = "Medium"
    HIGH = "High"
    CRITICAL = "Critical"
class DocumentProcessor:
    """Load PDF, DOCX and TXT documents and split them into chunks.

    PDFs that cannot be read by the regular loader, or that yield no text,
    are run through Tesseract OCR instead.
    """

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        """Set up the text splitter.

        Args:
            chunk_size: ``chunk_size`` passed to the text splitter.
            chunk_overlap: ``chunk_overlap`` passed to the text splitter.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    def process_document(self, content: bytes, doc_type: str) -> List[Document]:
        """Write raw bytes to a temporary file, load it and split it.

        Args:
            content: Raw bytes of the document.
            doc_type: File extension that selects the loader, with or
                without the leading dot (".pdf" or "pdf").

        Returns:
            The chunked documents.

        Raises:
            ValueError: If the extension is not supported.
        """
        # The loader is chosen from the file suffix, so the suffix must carry
        # its leading dot to be recognised.
        if doc_type and not doc_type.startswith("."):
            doc_type = f".{doc_type}"

        temp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=doc_type) as temp_file:
                # Record the path before writing so a failed write is still
                # cleaned up by the ``finally`` below.
                temp_file_path = temp_file.name
                temp_file.write(content)
            # The file is closed here, so loaders can reopen it by path.
            documents = self.load_document(temp_file_path)
            return self.split_documents(documents)
        finally:
            if temp_file_path is not None:
                os.unlink(temp_file_path)

    def load_document(self, file_path: Union[str, Path]) -> List[Document]:
        """Load a document with the loader matching its (case-insensitive) suffix.

        Args:
            file_path: Path to a ``.pdf``, ``.docx`` or ``.txt`` file.

        Returns:
            The loaded documents; for PDFs without extractable text, the
            OCR result.

        Raises:
            ValueError: If the suffix is not one of the supported types.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == ".pdf":
            try:
                documents = PyPDFLoader(str(file_path)).load()
                has_text = any(doc.page_content.strip() for doc in documents)
            except Exception:
                # Deliberate best-effort: any loader failure falls back to
                # OCR. ``Exception`` (not a bare ``except``) lets
                # KeyboardInterrupt and SystemExit propagate.
                has_text = False
            if has_text:
                return documents
            return self._process_pdf_with_ocr(file_path)

        if suffix == ".docx":
            return Docx2txtLoader(str(file_path)).load()

        if suffix == ".txt":
            return TextLoader(str(file_path)).load()

        raise ValueError(f"Unsupported file type: {suffix}")

    def _process_pdf_with_ocr(self, file_path: Path) -> List[Document]:
        """Render each PDF page to an image and OCR it with Tesseract.

        Pages whose OCR output is blank are skipped. The ``page`` metadata
        value is the 1-based position of the page image.
        """
        documents = []
        images = convert_from_path(str(file_path))
        for page_number, image in enumerate(images, start=1):
            text = pytesseract.image_to_string(image)
            if not text.strip():
                continue
            documents.append(
                Document(
                    page_content=text,
                    metadata={"source": str(file_path), "page": page_number},
                )
            )
        return documents

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents into chunks using the configured splitter."""
        return self.text_splitter.split_documents(documents)
class ComplianceAssistant:
    """Compliance and Audit Assistant with risk assessment capabilities.

    Processed documents are embedded into a FAISS vector store; queries are
    answered by a ``gpt-4`` chat model using the retrieved chunks as context.
    """

    # Prompt used by ``get_compliance_response``.
    _COMPLIANCE_TEMPLATE = (
        "You are a compliance and audit expert. Answer the following question "
        "based on the provided context:\n"
        "Context: {context}\n"
        "Question: {question}\n"
        "Provide a detailed answer that:\n"
        "1. Addresses compliance requirements and regulations\n"
        "2. Identifies potential risks and their severity\n"
        "3. Suggests mitigation strategies where applicable\n"
        "4. Cites specific sources and regulations\n"
        "Response:"
    )

    # Prompt used by ``generate_risk_assessment``.
    _RISK_TEMPLATE = (
        "Analyze the following audit document content and provide a "
        "structured risk assessment:\n"
        "Content: {content}\n"
        "Provide:\n"
        "1. Executive Summary\n"
        "2. Key Risk Factors (with severity ratings)\n"
        "3. Compliance Issues\n"
        "4. Recommended Actions\n"
        "5. Timeline for Remediation\n"
        "Assessment:"
    )

    def __init__(self, openai_api_key: str):
        """Create the embeddings client, document processor and chat model.

        Args:
            openai_api_key: Key passed to both the embeddings and chat clients.
        """
        self.openai_api_key = openai_api_key
        self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
        # Created on the first successful call to ``process_documents``.
        self.vector_store = None
        self.doc_processor = DocumentProcessor()
        self.llm = ChatOpenAI(
            temperature=0,
            model_name="gpt-4",
            openai_api_key=openai_api_key,
        )

    def process_documents(self, file_paths: List[str]) -> Dict[str, str]:
        """Process documents and add them to the knowledge base.

        Args:
            file_paths: Paths of the files to ingest.

        Returns:
            A mapping of each path to ``"Success"`` or ``"Error: <reason>"``.
            A failure on one file does not stop the remaining files.
        """
        results = {}
        for file_path in file_paths:
            try:
                with open(file_path, "rb") as f:
                    content = f.read()
                doc_type = Path(file_path).suffix
                texts = self.doc_processor.process_document(content, doc_type)
                if not texts:
                    # Report an explicit reason instead of handing an empty
                    # batch to the vector store.
                    results[file_path] = "Error: No text content could be extracted"
                    continue
                if self.vector_store is None:
                    self.vector_store = FAISS.from_documents(texts, self.embeddings)
                else:
                    self.vector_store.add_documents(texts)
                results[file_path] = "Success"
            except Exception as e:
                results[file_path] = f"Error: {str(e)}"
        return results

    def get_compliance_response(self, query: str) -> Dict[str, Any]:
        """Generate a compliance-focused response to a query.

        Args:
            query: The user's question.

        Returns:
            A dict with ``"answer"`` (the model's text) and ``"sources"``
            (file names of the retrieved documents).

        Raises:
            ValueError: If the query is empty or only whitespace.
            RuntimeError: If no documents have been processed yet.
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")
        if self.vector_store is None:
            raise RuntimeError("No compliance documents have been processed yet")

        # Retrieve once and reuse the same documents for both the prompt
        # context and the reported sources.
        retriever = self.vector_store.as_retriever(search_kwargs={"k": 4})
        source_docs = retriever.invoke(query)

        # Feed the model the plain page text of the retrieved chunks.
        context = "\n\n".join(doc.page_content for doc in source_docs)

        prompt = ChatPromptTemplate.from_template(self._COMPLIANCE_TEMPLATE)
        chain = prompt | self.llm | StrOutputParser()
        answer = chain.invoke({"context": context, "question": query})

        return {
            "answer": answer,
            "sources": self._format_sources(source_docs),
        }

    def generate_risk_assessment(self, document_path: str) -> Dict[str, Any]:
        """Generate a risk assessment for a specific document.

        Args:
            document_path: Path of the document to assess.

        Returns:
            A dict with ``"assessment"`` (the model's text), ``"document"``
            (the file name) and ``"timestamp"`` (ISO-8601 local time).

        Raises:
            RuntimeError: If reading, processing or assessing the document
                fails; the original exception is attached as the cause.
        """
        try:
            with open(document_path, "rb") as f:
                content = f.read()
            texts = self.doc_processor.process_document(
                content, Path(document_path).suffix
            )
            if not texts:
                # An assessment of an empty document would be meaningless.
                raise ValueError("No text content could be extracted")

            full_content = "\n".join(doc.page_content for doc in texts)

            prompt = ChatPromptTemplate.from_template(self._RISK_TEMPLATE)
            chain = prompt | self.llm | StrOutputParser()
            assessment = chain.invoke({"content": full_content})

            return {
                "assessment": assessment,
                "document": Path(document_path).name,
                "timestamp": datetime.now().isoformat(),
            }
        except Exception as e:
            raise RuntimeError(f"Risk assessment failed: {str(e)}") from e

    def _format_sources(self, source_documents: List[Document]) -> Set[str]:
        """Return the distinct file names of the documents' ``source`` metadata.

        Documents without a ``source`` entry are skipped.
        """
        names = set()
        for doc in source_documents:
            source = doc.metadata.get("source")
            if source:
                names.add(Path(str(source)).name)
        return names
def create_gradio_interface(assistant: ComplianceAssistant) -> gr.Blocks:
    """Create the Gradio interface for the compliance assistant.

    The interface has three tabs: document processing, compliance query and
    risk assessment. Each handler turns exceptions into an ``"Error: ..."``
    string so failures are shown in the output box.

    Args:
        assistant: The assistant that processes documents and answers queries.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` interface.
    """

    def _upload_path(upload: Any) -> str:
        """Return the filesystem path of an uploaded file.

        Accepts either an object exposing the path as ``.name`` or a plain
        path string, since the value handed to the callback may be either
        depending on the Gradio version and component settings.
        """
        return str(getattr(upload, "name", upload))

    def handle_file_upload(files: List[Any]) -> str:
        """Ingest the uploaded files and report one status line per file."""
        try:
            if not files:
                return "No files uploaded."
            results = assistant.process_documents(
                [_upload_path(f) for f in files]
            )
            output_lines = []
            for file_path, status in results.items():
                file_name = Path(file_path).name
                if status == "Success":
                    output_lines.append(f"✓ Successfully processed {file_name}")
                else:
                    output_lines.append(f"❌ {file_name}: {status}")
            return "\n".join(output_lines)
        except Exception as e:
            return f"Error: {str(e)}"

    def handle_compliance_query(query: str) -> str:
        """Answer a query and append the source file names, if any."""
        try:
            result = assistant.get_compliance_response(query)
            response = result["answer"]
            if result["sources"]:
                # Sorted so the source list is shown in a stable order.
                response += f"\n\nSources: {', '.join(sorted(result['sources']))}"
            return response
        except Exception as e:
            return f"Error: {str(e)}"

    def handle_risk_assessment(file: Any) -> str:
        """Run a risk assessment on the selected file."""
        try:
            if not file:
                return "No file selected for risk assessment."
            result = assistant.generate_risk_assessment(_upload_path(file))
            return f"Risk Assessment for {result['document']}\n\n{result['assessment']}"
        except Exception as e:
            return f"Error: {str(e)}"

    # Create interface
    with gr.Blocks(title="Compliance and Audit Assistant") as interface:
        gr.Markdown("# Compliance and Audit Assistant")

        with gr.Tab("Document Processing"):
            with gr.Row():
                file_input = gr.File(
                    file_count="multiple",
                    label="Upload Compliance Documents (PDF, DOCX, TXT)",
                )
            upload_button = gr.Button("Process Documents")
            upload_output = gr.Textbox(label="Processing Status")

        with gr.Tab("Compliance Query"):
            with gr.Row():
                query_input = gr.Textbox(
                    lines=3,
                    label="Enter your compliance or regulatory query",
                )
            query_button = gr.Button("Submit Query")
            query_output = gr.Textbox(
                lines=10,
                label="Assistant Response",
            )

        with gr.Tab("Risk Assessment"):
            with gr.Row():
                assessment_file = gr.File(
                    label="Select Document for Risk Assessment",
                )
            assess_button = gr.Button("Generate Risk Assessment")
            assessment_output = gr.Textbox(
                lines=15,
                label="Risk Assessment Report",
            )

        # Set up event handlers
        upload_button.click(
            fn=handle_file_upload,
            inputs=[file_input],
            outputs=[upload_output],
        )
        query_button.click(
            fn=handle_compliance_query,
            inputs=[query_input],
            outputs=[query_output],
        )
        assess_button.click(
            fn=handle_risk_assessment,
            inputs=[assessment_file],
            outputs=[assessment_output],
        )

    return interface
def main():
    """Main function to run the compliance assistant.

    Reads the OpenAI API key from ``OPENAI_API_KEY``; if it is not set, the
    key is requested interactively and stored in the environment. The Gradio
    interface is then built and launched.

    Raises:
        SystemExit: If no key is configured and none can be read from the
            terminal (empty input or a non-interactive session).
    """
    # Get OpenAI API key
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        try:
            # getpass keeps the secret from being echoed to the terminal.
            api_key = getpass.getpass("Please enter your OpenAI API key: ").strip()
        except EOFError:
            # No stdin to read from (e.g. a hosted, non-interactive run).
            api_key = ""
        if not api_key:
            raise SystemExit(
                "OPENAI_API_KEY is not set and no API key was entered."
            )
        os.environ["OPENAI_API_KEY"] = api_key

    # Initialize assistant
    assistant = ComplianceAssistant(api_key)

    # Launch interface
    interface = create_gradio_interface(assistant)
    interface.launch(share=True, debug=True)


if __name__ == "__main__":
    main()