| | import gradio as gr |
| | import os |
| | import json |
| | import pickle |
| | from datetime import datetime |
| | import requests |
| | from bs4 import BeautifulSoup |
| | import fitz |
| | import numpy as np |
| | from sentence_transformers import SentenceTransformer |
| | from sklearn.metrics.pairwise import cosine_similarity |
| | import sqlite3 |
| | import hashlib |
| | from typing import List, Dict, Any, Tuple |
| | import logging |
| | import tempfile |
| | import shutil |
| | from urllib.parse import urlparse, urljoin |
| | import re |
| |
|
| | |
# Module-wide logging configuration; all components log through this logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
class MedicalRAGSystem:
    """Persistent RAG knowledge base for medical-device regulatory content.

    Documents, scraped websites, and standards are stored in SQLite.
    Their text is split into overlapping word chunks and indexed with
    sentence-transformer embeddings so queries can be answered by
    cosine-similarity search over the chunks.
    """

    def __init__(self):
        # None until load_embedding_model() succeeds; every embedding-
        # dependent method checks this attribute and degrades gracefully.
        self.embedding_model = None
        self.db_path = "medical_rag.db"
        self.embeddings_cache = {}
        self.init_database()
        self.load_embedding_model()

    def load_embedding_model(self):
        """Load a free sentence transformer model.

        On failure, logs the error and leaves ``self.embedding_model`` as
        None so the rest of the system keeps working without search.
        """
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading embedding model: {e}")
            self.embedding_model = None

    def init_database(self):
        """Initialize SQLite database for persistent storage.

        Creates the documents/websites/standards content tables plus an
        embeddings table holding one pickled vector per text chunk.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    category TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS websites (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    url TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    title TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS standards (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    standard_name TEXT NOT NULL,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE,
                    version TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata TEXT
                )
            ''')

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    source_type TEXT NOT NULL,
                    source_id INTEGER NOT NULL,
                    chunk_index INTEGER NOT NULL,
                    embedding BLOB NOT NULL,
                    text_chunk TEXT NOT NULL
                )
            ''')

            conn.commit()
        finally:
            # Close even if a CREATE fails, so the handle is never leaked.
            conn.close()
        logger.info("Database initialized successfully")

    def get_content_hash(self, content: str) -> str:
        """Generate hash for content to avoid duplicates.

        MD5 is used purely as a dedup fingerprint, not for security, and
        must stay MD5 for compatibility with hashes already stored in
        existing databases.
        """
        return hashlib.md5(content.encode()).hexdigest()

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping word chunks for better retrieval.

        Args:
            text: Raw text to split.
            chunk_size: Maximum number of words per chunk.
            overlap: Number of words shared between consecutive chunks.

        Returns:
            List of non-empty chunk strings (empty list for empty text).
        """
        words = text.split()
        # Clamp the stride to >= 1: a step of 0 would make range() raise
        # ValueError and a negative step would silently yield no chunks
        # whenever overlap >= chunk_size.
        step = max(chunk_size - overlap, 1)
        chunks = []

        for start in range(0, len(words), step):
            chunk = ' '.join(words[start:start + chunk_size])
            if chunk.strip():
                chunks.append(chunk)

        return chunks

    def process_pdf_document(self, file_path: str) -> Tuple[str, Dict]:
        """Extract text content from PDF documents.

        Returns:
            (text, metadata) on success, ("", {}) on any failure.
        """
        try:
            doc = fitz.open(file_path)
            try:
                metadata = {"pages": doc.page_count, "format": "PDF"}
                # Concatenate the plain text of every page in order.
                text_content = "".join(
                    doc[page_num].get_text() for page_num in range(doc.page_count)
                )
            finally:
                # Close the document even if text extraction raises.
                doc.close()
            return text_content, metadata
        except Exception as e:
            logger.error(f"Error processing PDF: {e}")
            return "", {}

    def process_text_document(self, file_path: str) -> Tuple[str, Dict]:
        """Process text documents.

        Returns:
            (content, metadata) on success, ("", {}) when the file cannot
            be read (missing, unreadable, or not valid UTF-8).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            return content, {"format": "TEXT"}
        except Exception as e:
            logger.error(f"Error processing text document: {e}")
            return "", {}

    def scrape_website(self, url: str) -> Tuple[str, str, Dict]:
        """Scrape content from regulatory websites.

        Returns:
            (content, title, metadata) on success, ("", "", {}) on failure.
        """
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Drop scripts and stylesheets so only visible text remains.
            for script in soup(["script", "style"]):
                script.decompose()

            title = soup.title.string if soup.title else url

            # Collapse every run of whitespace into a single space.
            content = soup.get_text()
            content = re.sub(r'\s+', ' ', content).strip()

            metadata = {
                "title": title,
                "url": url,
                "scraped_at": datetime.now().isoformat()
            }

            return content, title, metadata

        except Exception as e:
            logger.error(f"Error scraping website {url}: {e}")
            return "", "", {}

    def add_document(self, file_path: str, filename: str, category: str) -> str:
        """Add document to the knowledge base.

        Extracts the text (PDF vs plain text chosen by extension), stores
        it under *category*, and indexes its chunks for search.

        Returns:
            Human-readable status message.
        """
        try:
            if filename.lower().endswith('.pdf'):
                content, metadata = self.process_pdf_document(file_path)
            else:
                content, metadata = self.process_text_document(file_path)

            if not content:
                return "Error: Could not extract content from document"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute('''
                        INSERT INTO documents (filename, content, content_hash, category, metadata)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (filename, content, content_hash, category, json.dumps(metadata)))
                except sqlite3.IntegrityError:
                    # content_hash is UNIQUE: identical text was added before.
                    return "Document already exists in the knowledge base"

                doc_id = cursor.lastrowid
                conn.commit()

                self.generate_embeddings_for_content(content, 'document', doc_id)

                # Bug fix: the message previously hard-coded '(unknown)'
                # instead of interpolating the actual filename.
                return f"Document '{filename}' added successfully to category '{category}'"
            finally:
                conn.close()

        except Exception as e:
            logger.error(f"Error adding document: {e}")
            return f"Error adding document: {str(e)}"

    def add_website(self, url: str) -> str:
        """Add website content to the knowledge base.

        Scrapes *url*, stores the cleaned text, and indexes its chunks.

        Returns:
            Human-readable status message.
        """
        try:
            content, title, metadata = self.scrape_website(url)

            if not content:
                return "Error: Could not scrape website content"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute('''
                        INSERT INTO websites (url, content, content_hash, title, metadata)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (url, content, content_hash, title, json.dumps(metadata)))
                except sqlite3.IntegrityError:
                    # Same scraped text already stored (UNIQUE content_hash).
                    return "Website already exists in the knowledge base"

                website_id = cursor.lastrowid
                conn.commit()

                self.generate_embeddings_for_content(content, 'website', website_id)

                return f"Website '{title}' added successfully"
            finally:
                conn.close()

        except Exception as e:
            logger.error(f"Error adding website: {e}")
            return f"Error adding website: {str(e)}"

    def add_standard(self, standard_name: str, content: str, version: str = "") -> str:
        """Add standard content to the knowledge base.

        Args:
            standard_name: Display name, e.g. "ISO 13485:2016".
            content: Full text of the standard.
            version: Optional version label stored alongside the text.

        Returns:
            Human-readable status message.
        """
        try:
            if not content.strip():
                return "Error: Standard content cannot be empty"

            content_hash = self.get_content_hash(content)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()

                metadata = {"version": version, "added_at": datetime.now().isoformat()}

                try:
                    cursor.execute('''
                        INSERT INTO standards (standard_name, content, content_hash, version, metadata)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (standard_name, content, content_hash, version, json.dumps(metadata)))
                except sqlite3.IntegrityError:
                    # Identical standard text already stored.
                    return "Standard already exists in the knowledge base"

                standard_id = cursor.lastrowid
                conn.commit()

                self.generate_embeddings_for_content(content, 'standard', standard_id)

                return f"Standard '{standard_name}' added successfully"
            finally:
                conn.close()

        except Exception as e:
            logger.error(f"Error adding standard: {e}")
            return f"Error adding standard: {str(e)}"

    def generate_embeddings_for_content(self, content: str, source_type: str, source_id: int):
        """Generate and persist embeddings for every chunk of *content*.

        Args:
            content: Full text to chunk and embed.
            source_type: One of 'document', 'website', or 'standard'.
            source_id: Row id in the corresponding source table.
        """
        if not self.embedding_model:
            logger.error("Embedding model not available")
            return

        chunks = self.chunk_text(content)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            for i, chunk in enumerate(chunks):
                try:
                    embedding = self.embedding_model.encode(chunk)
                    # Vectors are pickled; only this application writes
                    # them, so unpickling them later is acceptable.
                    embedding_blob = pickle.dumps(embedding)

                    cursor.execute('''
                        INSERT INTO embeddings (source_type, source_id, chunk_index, embedding, text_chunk)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (source_type, source_id, i, embedding_blob, chunk))

                except Exception as e:
                    # Skip the failing chunk but keep indexing the rest.
                    logger.error(f"Error generating embedding for chunk {i}: {e}")

            conn.commit()
        finally:
            conn.close()

    def search_knowledge_base(self, query: str, top_k: int = 5) -> List[Dict]:
        """Search the knowledge base using semantic similarity.

        Embeds the query, scores it against every stored chunk with cosine
        similarity, and returns the *top_k* best matches.

        Returns:
            List of dicts with source_type, source_id, text_chunk,
            source_name, and similarity keys; empty list on failure or
            when no embedding model is loaded.
        """
        if not self.embedding_model:
            return []

        try:
            query_embedding = self.embedding_model.encode(query)

            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()

                # Resolve a display name per chunk from whichever source
                # table the chunk came from.
                cursor.execute('''
                    SELECT e.source_type, e.source_id, e.text_chunk, e.embedding,
                           CASE
                               WHEN e.source_type = 'document' THEN d.filename
                               WHEN e.source_type = 'website' THEN w.title
                               WHEN e.source_type = 'standard' THEN s.standard_name
                           END as source_name
                    FROM embeddings e
                    LEFT JOIN documents d ON e.source_type = 'document' AND e.source_id = d.id
                    LEFT JOIN websites w ON e.source_type = 'website' AND e.source_id = w.id
                    LEFT JOIN standards s ON e.source_type = 'standard' AND e.source_id = s.id
                ''')

                results = []
                for row in cursor.fetchall():
                    try:
                        # Blobs are produced by generate_embeddings_for_content
                        # in this same app, so unpickling them is acceptable.
                        stored_embedding = pickle.loads(row[3])
                        similarity = cosine_similarity([query_embedding], [stored_embedding])[0][0]

                        results.append({
                            'source_type': row[0],
                            'source_id': row[1],
                            'text_chunk': row[2],
                            'source_name': row[4],
                            'similarity': similarity
                        })
                    except Exception as e:
                        logger.error(f"Error processing embedding: {e}")
            finally:
                conn.close()

            # Highest similarity first, truncated to the requested count.
            results.sort(key=lambda x: x['similarity'], reverse=True)
            return results[:top_k]

        except Exception as e:
            logger.error(f"Error searching knowledge base: {e}")
            return []

    def get_knowledge_base_stats(self) -> Dict:
        """Get statistics about the knowledge base.

        Returns:
            Dict with row counts for 'documents', 'websites', 'standards',
            and 'embeddings' (total indexed chunks).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            stats = {}
            # Simple COUNT(*) per table; tables are small enough that this
            # stays cheap to run on every UI refresh.
            for key in ('documents', 'websites', 'standards', 'embeddings'):
                cursor.execute(f"SELECT COUNT(*) FROM {key}")
                stats[key] = cursor.fetchone()[0]

            return stats
        finally:
            conn.close()
| |
|
# Single shared knowledge-base instance used by all Gradio handlers below.
# Instantiation has side effects: it creates/opens medical_rag.db and
# downloads/loads the sentence-transformer model.
rag_system = MedicalRAGSystem()
| |
|
def handle_document_upload(files, category):
    """Add each uploaded file to the knowledge base and report per-file results."""
    if not files:
        return "No files selected"

    # One status line per uploaded file, in upload order.
    messages = [
        rag_system.add_document(uploaded.name, os.path.basename(uploaded.name), category)
        for uploaded in files
    ]
    return "\n".join(messages)
| |
|
def handle_website_addition(url):
    """Validate the URL field, then forward it to the RAG system."""
    cleaned = url.strip()
    if not cleaned:
        return "Please enter a valid URL"
    return rag_system.add_website(cleaned)
| |
|
def handle_standard_addition(standard_name, content, version):
    """Validate the required fields, then store the standard."""
    name = standard_name.strip()
    body = content.strip()
    if not (name and body):
        return "Please provide both standard name and content"
    return rag_system.add_standard(name, body, version.strip())
| |
|
def handle_search(query):
    """Run a semantic search and return (formatted hits, generated answer)."""
    cleaned = query.strip()
    if not cleaned:
        return "Please enter a search query", ""

    hits = rag_system.search_knowledge_base(cleaned)
    if not hits:
        return "No relevant results found", ""

    formatted = []
    context = []
    for rank, hit in enumerate(hits, 1):
        score = hit['similarity'] * 100
        # Truncate long chunks to a 300-character preview with ellipsis.
        snippet = hit['text_chunk'][:300]
        suffix = '...' if len(hit['text_chunk']) > 300 else ''
        formatted.append(
            f"\n**Result {rank}** (Similarity: {score:.1f}%)\n"
            f"**Source:** {hit['source_name']} ({hit['source_type']})\n"
            f"**Content:** {snippet}{suffix}\n"
            f"---\n"
        )
        context.append(hit['text_chunk'])

    # Build an extractive answer from the full (untruncated) chunks.
    answer = generate_answer(cleaned, context)

    return "\n".join(formatted), answer
| |
|
| | def generate_answer(query: str, context: List[str]) -> str: |
| | """Generate an answer based on the retrieved context""" |
| | |
| | relevant_info = [] |
| | |
| | query_lower = query.lower() |
| | for chunk in context: |
| | |
| | sentences = chunk.split('.') |
| | for sentence in sentences: |
| | if any(term in sentence.lower() for term in query_lower.split()): |
| | relevant_info.append(sentence.strip()) |
| | |
| | if relevant_info: |
| | |
| | unique_info = list(dict.fromkeys(relevant_info)) |
| | return "Based on the regulatory documents:\n\n" + "\n\n".join(unique_info[:3]) |
| | else: |
| | return "The retrieved content may contain relevant information, but I couldn't extract a specific answer. Please review the search results above." |
| |
|
def get_stats():
    """Render the knowledge-base counters as a readable multi-line summary."""
    counts = rag_system.get_knowledge_base_stats()
    return (
        "\nKnowledge Base Statistics:\n"
        f"- Documents: {counts['documents']}\n"
        f"- Websites: {counts['websites']}\n"
        f"- Standards: {counts['standards']}\n"
        f"- Total Text Chunks: {counts['embeddings']}\n"
    )
| |
|
# ---------------------------------------------------------------------------
# Gradio UI: a tabbed interface wiring the handler functions above to widgets.
# Built at import time; `demo` is launched from the __main__ guard below.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Medical Devices RAG System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # π₯ Medical Devices Regulatory RAG System
    
    A comprehensive knowledge base system for medical device regulatory analysts.
    Add documents, websites, and standards to build your regulatory knowledge base.
    """)

    with gr.Tabs():
        # --- Tab 1: semantic search over everything ingested so far -------
        with gr.Tab("π Search Knowledge Base"):
            gr.Markdown("### Search your regulatory knowledge base")

            search_input = gr.Textbox(
                placeholder="Enter your regulatory question (e.g., 'What are the requirements for Class II medical devices?')",
                label="Search Query",
                lines=2
            )
            search_button = gr.Button("Search", variant="primary")

            # Results on the left, generated extractive answer on the right.
            with gr.Row():
                with gr.Column():
                    search_results = gr.Markdown(label="Search Results")
                with gr.Column():
                    answer_output = gr.Markdown(label="Generated Answer")

            search_button.click(
                handle_search,
                inputs=[search_input],
                outputs=[search_results, answer_output]
            )

        # --- Tab 2: upload regulatory documents ---------------------------
        with gr.Tab("π Add Documents"):
            gr.Markdown("### Add regulatory documents (PDF, TXT)")

            document_files = gr.File(
                label="Upload Documents",
                file_count="multiple",
                file_types=[".pdf", ".txt", ".docx"]
            )
            document_category = gr.Dropdown(
                choices=["EU MDR 2017/745", "CMDR SOR/98-282", "MDCG", "MDSAP Audit Approach", "UK MDR", "Other"],
                label="Document Category",
                value="Other"
            )
            add_doc_button = gr.Button("Add Documents", variant="primary")
            doc_output = gr.Textbox(label="Result", lines=3)

            add_doc_button.click(
                handle_document_upload,
                inputs=[document_files, document_category],
                outputs=[doc_output]
            )

        # --- Tab 3: scrape and ingest regulatory websites -----------------
        with gr.Tab("π Add Websites"):
            gr.Markdown("### Add regulatory websites")

            website_url = gr.Textbox(
                placeholder="https://www.fda.gov/medical-devices/...",
                label="Website URL",
                lines=1
            )
            add_website_button = gr.Button("Add Website", variant="primary")
            website_output = gr.Textbox(label="Result", lines=3)

            gr.Markdown("**Suggested regulatory websites:**")
            gr.Markdown("""
            - US FDA 21CFR: https://www.accessdata.fda.gov/scripts/cdrh/cfdocs/cfcfr/cfrsearch.cfm
            - EU Medical Devices: https://ec.europa.eu/health/medical-devices-sector_en
            - Health Canada Medical Devices: https://www.canada.ca/en/health-canada/services/drugs-health-products/medical-devices.html
            """)

            add_website_button.click(
                handle_website_addition,
                inputs=[website_url],
                outputs=[website_output]
            )

        # --- Tab 4: paste standards text directly -------------------------
        with gr.Tab("π Add Standards"):
            gr.Markdown("### Add regulatory standards")

            standard_name = gr.Textbox(
                placeholder="ISO 13485:2016",
                label="Standard Name",
                lines=1
            )
            standard_version = gr.Textbox(
                placeholder="2016 (optional)",
                label="Version",
                lines=1
            )
            standard_content = gr.Textbox(
                placeholder="Enter or paste the standard content here...",
                label="Standard Content",
                lines=10
            )
            add_standard_button = gr.Button("Add Standard", variant="primary")
            standard_output = gr.Textbox(label="Result", lines=3)

            add_standard_button.click(
                handle_standard_addition,
                inputs=[standard_name, standard_content, standard_version],
                outputs=[standard_output]
            )

        # --- Tab 5: row counts for each content table ----------------------
        with gr.Tab("π Knowledge Base Stats"):
            gr.Markdown("### Knowledge Base Statistics")

            stats_button = gr.Button("Refresh Statistics", variant="secondary")
            stats_output = gr.Textbox(label="Statistics", lines=8)

            stats_button.click(
                get_stats,
                outputs=[stats_output]
            )

    # Populate the stats box once when the page first loads.
    demo.load(get_stats, outputs=[stats_output])
| |
|
if __name__ == "__main__":
    # share=True asks Gradio to create a temporary public tunnel URL in
    # addition to the local server.
    demo.launch(share=True)