"""Medical-device regulatory RAG system.

A Gradio UI over a SQLite-backed semantic-search knowledge base that stores
three source types (uploaded documents, scraped websites, pasted standards),
embeds their text chunks with a free sentence-transformers model, and answers
queries by cosine-similarity retrieval plus a simple extractive summary.
"""

import gradio as gr
import os
import json
import pickle
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import fitz  # PyMuPDF for PDF processing
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import sqlite3
import hashlib
from typing import List, Dict, Any, Tuple
import logging
import tempfile
import shutil
from urllib.parse import urlparse, urljoin
import re

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MedicalRAGSystem:
    """SQLite-backed knowledge base with semantic (embedding) search."""

    def __init__(self):
        self.embedding_model = None
        self.db_path = "medical_rag.db"
        self.embeddings_cache = {}
        self.init_database()
        self.load_embedding_model()

    def load_embedding_model(self):
        """Load a free sentence transformer model.

        Leaves ``self.embedding_model`` as None on failure; callers check it.
        """
        try:
            # Lightweight, free model suitable for regulatory text.
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading embedding model: {e}")
            return None

    def init_database(self):
        """Initialize SQLite database for persistent storage.

        Creates one table per source type plus a shared ``embeddings`` table
        keyed by (source_type, source_id, chunk_index). ``content_hash`` is
        UNIQUE so duplicate content raises IntegrityError on insert.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    category TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS websites (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    url TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    title TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS standards (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    standard_name TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    version TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    source_type TEXT NOT NULL,
                    source_id INTEGER NOT NULL,
                    chunk_index INTEGER NOT NULL,
                    embedding BLOB NOT NULL,
                    text_chunk TEXT NOT NULL
                )
            ''')
            conn.commit()
        finally:
            conn.close()
        logger.info("Database initialized successfully")

    def get_content_hash(self, content: str) -> str:
        """Generate hash for content to avoid duplicates.

        MD5 is used only as a dedup fingerprint, not for security.
        """
        return hashlib.md5(content.encode()).hexdigest()

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping word chunks for better retrieval."""
        words = text.split()
        # Clamp the stride to at least 1 so overlap >= chunk_size cannot
        # raise ValueError from range() with a non-positive step.
        step = max(1, chunk_size - overlap)
        chunks = []
        for i in range(0, len(words), step):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk)
        return chunks

    def process_pdf_document(self, file_path: str) -> Tuple[str, Dict]:
        """Extract text content from PDF documents.

        Returns ("", {}) on failure so callers can report a clean error.
        """
        try:
            doc = fitz.open(file_path)
            try:
                metadata = {"pages": doc.page_count, "format": "PDF"}
                text_content = "".join(
                    doc[page_num].get_text() for page_num in range(doc.page_count)
                )
            finally:
                doc.close()
            return text_content, metadata
        except Exception as e:
            logger.error(f"Error processing PDF: {e}")
            return "", {}

    def process_text_document(self, file_path: str) -> Tuple[str, Dict]:
        """Process plain-text documents; returns ("", {}) on failure."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            return content, {"format": "TEXT"}
        except Exception as e:
            logger.error(f"Error processing text document: {e}")
            return "", {}

    def scrape_website(self, url: str) -> Tuple[str, str, Dict]:
        """Scrape content from regulatory websites.

        Returns (content, title, metadata); ("", "", {}) on any failure.
        """
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Remove script and style elements before extracting text.
            for script in soup(["script", "style"]):
                script.decompose()

            # <title> may exist but be empty, in which case .string is None;
            # fall back to the URL so the stored title is never NULL.
            title = url
            if soup.title and soup.title.string:
                title = soup.title.string

            # Collapse all whitespace runs into single spaces.
            content = soup.get_text()
            content = re.sub(r'\s+', ' ', content).strip()

            metadata = {
                "title": title,
                "url": url,
                "scraped_at": datetime.now().isoformat()
            }
            return content, title, metadata
        except Exception as e:
            logger.error(f"Error scraping website {url}: {e}")
            return "", "", {}

    def add_document(self, file_path: str, filename: str, category: str) -> str:
        """Add a document (PDF or text) to the knowledge base.

        Returns a human-readable status string for the UI.
        """
        try:
            # Determine file type and process accordingly.
            if filename.lower().endswith('.pdf'):
                content, metadata = self.process_pdf_document(file_path)
            else:
                content, metadata = self.process_text_document(file_path)

            if not content:
                return "Error: Could not extract content from document"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO documents (filename, content, content_hash, category, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (filename, content, content_hash, category, json.dumps(metadata)))
                doc_id = cursor.lastrowid
                conn.commit()

                # Generate embeddings (uses its own connection).
                self.generate_embeddings_for_content(content, 'document', doc_id)

                # Fixed: interpolate the actual filename in the success message.
                return f"Document '{filename}' added successfully to category '{category}'"
            except sqlite3.IntegrityError:
                # UNIQUE(content_hash) violated => identical content already stored.
                return "Document already exists in the knowledge base"
            finally:
                conn.close()
        except Exception as e:
            logger.error(f"Error adding document: {e}")
            return f"Error adding document: {str(e)}"

    def add_website(self, url: str) -> str:
        """Scrape a website and add its content to the knowledge base."""
        try:
            content, title, metadata = self.scrape_website(url)

            if not content:
                return "Error: Could not scrape website content"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO websites (url, content, content_hash, title, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (url, content, content_hash, title, json.dumps(metadata)))
                website_id = cursor.lastrowid
                conn.commit()

                self.generate_embeddings_for_content(content, 'website', website_id)

                return f"Website '{title}' added successfully"
            except sqlite3.IntegrityError:
                return "Website already exists in the knowledge base"
            finally:
                conn.close()
        except Exception as e:
            logger.error(f"Error adding website: {e}")
            return f"Error adding website: {str(e)}"

    def add_standard(self, standard_name: str, content: str, version: str = "") -> str:
        """Add pasted standard text to the knowledge base."""
        try:
            if not content.strip():
                return "Error: Standard content cannot be empty"

            content_hash = self.get_content_hash(content)
            metadata = {"version": version, "added_at": datetime.now().isoformat()}

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO standards (standard_name, content, content_hash, version, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (standard_name, content, content_hash, version, json.dumps(metadata)))
                standard_id = cursor.lastrowid
                conn.commit()

                self.generate_embeddings_for_content(content, 'standard', standard_id)

                return f"Standard '{standard_name}' added successfully"
            except sqlite3.IntegrityError:
                return "Standard already exists in the knowledge base"
            finally:
                conn.close()
        except Exception as e:
            logger.error(f"Error adding standard: {e}")
            return f"Error adding standard: {str(e)}"

    def generate_embeddings_for_content(self, content: str, source_type: str, source_id: int):
        """Embed each text chunk of *content* and persist it.

        Embeddings are pickled numpy arrays stored as BLOBs; this DB is
        local/trusted, which is why pickle is acceptable here.
        """
        if not self.embedding_model:
            logger.error("Embedding model not available")
            return

        chunks = self.chunk_text(content)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            for i, chunk in enumerate(chunks):
                try:
                    embedding = self.embedding_model.encode(chunk)
                    embedding_blob = pickle.dumps(embedding)
                    cursor.execute('''
                        INSERT INTO embeddings (source_type, source_id, chunk_index, embedding, text_chunk)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (source_type, source_id, i, embedding_blob, chunk))
                except Exception as e:
                    logger.error(f"Error generating embedding for chunk {i}: {e}")
            conn.commit()
        finally:
            conn.close()

    def search_knowledge_base(self, query: str, top_k: int = 5) -> List[Dict]:
        """Search the knowledge base using cosine similarity over embeddings.

        Returns up to *top_k* dicts sorted by descending similarity.
        """
        if not self.embedding_model:
            return []

        try:
            query_embedding = self.embedding_model.encode(query)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                # Join each embedding row back to its source's display name.
                cursor.execute('''
                    SELECT e.source_type, e.source_id, e.text_chunk, e.embedding,
                           CASE
                               WHEN e.source_type = 'document' THEN d.filename
                               WHEN e.source_type = 'website' THEN w.title
                               WHEN e.source_type = 'standard' THEN s.standard_name
                           END as source_name
                    FROM embeddings e
                    LEFT JOIN documents d ON e.source_type = 'document' AND e.source_id = d.id
                    LEFT JOIN websites w ON e.source_type = 'website' AND e.source_id = w.id
                    LEFT JOIN standards s ON e.source_type = 'standard' AND e.source_id = s.id
                ''')

                results = []
                for row in cursor.fetchall():
                    try:
                        # Trusted local data: these blobs were written by us.
                        stored_embedding = pickle.loads(row[3])
                        similarity = cosine_similarity([query_embedding], [stored_embedding])[0][0]
                        results.append({
                            'source_type': row[0],
                            'source_id': row[1],
                            'text_chunk': row[2],
                            'source_name': row[4],
                            'similarity': similarity
                        })
                    except Exception as e:
                        logger.error(f"Error processing embedding: {e}")
            finally:
                conn.close()

            results.sort(key=lambda x: x['similarity'], reverse=True)
            return results[:top_k]
        except Exception as e:
            logger.error(f"Error searching knowledge base: {e}")
            return []

    def get_knowledge_base_stats(self) -> Dict:
        """Return row counts for each source table plus total embeddings."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            stats = {}
            cursor.execute("SELECT COUNT(*) FROM documents")
            stats['documents'] = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM websites")
            stats['websites'] = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM standards")
            stats['standards'] = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM embeddings")
            stats['embeddings'] = cursor.fetchone()[0]
        finally:
            conn.close()
        return stats


# Initialize the RAG system
rag_system = MedicalRAGSystem()


def handle_document_upload(files, category):
    """Handle document upload from the Gradio File component."""
    if not files:
        return "No files selected"

    results = []
    for file in files:
        filename = os.path.basename(file.name)
        result = rag_system.add_document(file.name, filename, category)
        results.append(result)

    return "\n".join(results)


def handle_website_addition(url):
    """Handle website addition."""
    if not url.strip():
        return "Please enter a valid URL"
    return rag_system.add_website(url.strip())


def handle_standard_addition(standard_name, content, version):
    """Handle standard addition."""
    if not standard_name.strip() or not content.strip():
        return "Please provide both standard name and content"
    return rag_system.add_standard(standard_name.strip(), content.strip(), version.strip())


def handle_search(query):
    """Run a search and return (formatted results markdown, generated answer)."""
    if not query.strip():
        return "Please enter a search query", ""

    results = rag_system.search_knowledge_base(query.strip())

    if not results:
        return "No relevant results found", ""

    formatted_results = []
    context = []

    for i, result in enumerate(results, 1):
        similarity_pct = result['similarity'] * 100
        truncated = result['text_chunk'][:300]
        ellipsis = '...' if len(result['text_chunk']) > 300 else ''
        formatted_results.append(f"""
**Result {i}** (Similarity: {similarity_pct:.1f}%)
**Source:** {result['source_name']} ({result['source_type']})
**Content:** {truncated}{ellipsis}
---
""")
        context.append(result['text_chunk'])

    answer = generate_answer(query, context)
    return "\n".join(formatted_results), answer


def generate_answer(query: str, context: List[str]) -> str:
    """Generate an answer based on the retrieved context.

    Simple extractive approach - in a production system, you might use a
    generative model instead.
    """
    relevant_info = []
    query_lower = query.lower()

    for chunk in context:
        # Keep sentences that contain any query term.
        for sentence in chunk.split('.'):
            if any(term in sentence.lower() for term in query_lower.split()):
                relevant_info.append(sentence.strip())

    if relevant_info:
        # dict.fromkeys deduplicates while preserving first-seen order.
        unique_info = list(dict.fromkeys(relevant_info))
        return "Based on the regulatory documents:\n\n" + "\n\n".join(unique_info[:3])
    else:
        return ("The retrieved content may contain relevant information, but I couldn't "
                "extract a specific answer. Please review the search results above.")


def get_stats():
    """Get knowledge base statistics as a display string."""
    stats = rag_system.get_knowledge_base_stats()
    return f"""
Knowledge Base Statistics:
- Documents: {stats['documents']}
- Websites: {stats['websites']}
- Standards: {stats['standards']}
- Total Text Chunks: {stats['embeddings']}
"""


# Create Gradio interface
with gr.Blocks(title="Medical Devices RAG System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 Medical Devices Regulatory RAG System

    A comprehensive knowledge base system for medical device regulatory analysts.
    Add documents, websites, and standards to build your regulatory knowledge base.
    """)

    with gr.Tabs():
        # Search Tab
        with gr.Tab("🔍 Search Knowledge Base"):
            gr.Markdown("### Search your regulatory knowledge base")
            search_input = gr.Textbox(
                placeholder="Enter your regulatory question (e.g., 'What are the requirements for Class II medical devices?')",
                label="Search Query",
                lines=2
            )
            search_button = gr.Button("Search", variant="primary")

            with gr.Row():
                with gr.Column():
                    search_results = gr.Markdown(label="Search Results")
                with gr.Column():
                    answer_output = gr.Markdown(label="Generated Answer")

            search_button.click(
                handle_search,
                inputs=[search_input],
                outputs=[search_results, answer_output]
            )

        # Add Documents Tab
        with gr.Tab("📄 Add Documents"):
            gr.Markdown("### Add regulatory documents (PDF, TXT)")
            document_files = gr.File(
                label="Upload Documents",
                file_count="multiple",
                file_types=[".pdf", ".txt", ".docx"]
            )
            document_category = gr.Dropdown(
                choices=["EU MDR 2017/745", "CMDR SOR/98-282", "MDCG",
                         "MDSAP Audit Approach", "UK MDR", "Other"],
                label="Document Category",
                value="Other"
            )
            add_doc_button = gr.Button("Add Documents", variant="primary")
            doc_output = gr.Textbox(label="Result", lines=3)

            add_doc_button.click(
                handle_document_upload,
                inputs=[document_files, document_category],
                outputs=[doc_output]
            )

        # Add Websites Tab
        with gr.Tab("🌐 Add Websites"):
            gr.Markdown("### Add regulatory websites")
            website_url = gr.Textbox(
                placeholder="https://www.fda.gov/medical-devices/...",
                label="Website URL",
                lines=1
            )
            add_website_button = gr.Button("Add Website", variant="primary")
            website_output = gr.Textbox(label="Result", lines=3)

            gr.Markdown("**Suggested regulatory websites:**")
            gr.Markdown("""
            - US FDA 21CFR: https://www.accessdata.fda.gov/scripts/cdrh/cfdocs/cfcfr/cfrsearch.cfm
            - EU Medical Devices: https://ec.europa.eu/health/medical-devices-sector_en
            - Health Canada Medical Devices: https://www.canada.ca/en/health-canada/services/drugs-health-products/medical-devices.html
            """)

            add_website_button.click(
                handle_website_addition,
                inputs=[website_url],
                outputs=[website_output]
            )

        # Add Standards Tab
        with gr.Tab("📋 Add Standards"):
            gr.Markdown("### Add regulatory standards")
            standard_name = gr.Textbox(
                placeholder="ISO 13485:2016",
                label="Standard Name",
                lines=1
            )
            standard_version = gr.Textbox(
                placeholder="2016 (optional)",
                label="Version",
                lines=1
            )
            standard_content = gr.Textbox(
                placeholder="Enter or paste the standard content here...",
                label="Standard Content",
                lines=10
            )
            add_standard_button = gr.Button("Add Standard", variant="primary")
            standard_output = gr.Textbox(label="Result", lines=3)

            add_standard_button.click(
                handle_standard_addition,
                inputs=[standard_name, standard_content, standard_version],
                outputs=[standard_output]
            )

        # Statistics Tab
        with gr.Tab("📊 Knowledge Base Stats"):
            gr.Markdown("### Knowledge Base Statistics")
            stats_button = gr.Button("Refresh Statistics", variant="secondary")
            stats_output = gr.Textbox(label="Statistics", lines=8)

            stats_button.click(
                get_stats,
                outputs=[stats_output]
            )

    # Load initial stats
    demo.load(get_stats, outputs=[stats_output])


if __name__ == "__main__":
    demo.launch(share=True)