# MedRegRAG / app.py
# Author: Harrisun — commit 0d1b42d (verified)
import gradio as gr
import os
import json
import pickle
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import fitz # PyMuPDF for PDF processing
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import sqlite3
import hashlib
from typing import List, Dict, Any, Tuple
import logging
import tempfile
import shutil
from urllib.parse import urlparse, urljoin
import re
# Configure root logging at INFO level and create the module-level logger
# used by MedicalRAGSystem and the Gradio handler functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MedicalRAGSystem:
    """SQLite-backed RAG knowledge base for medical-device regulatory content.

    Three source types (uploaded documents, scraped websites, pasted
    standards) are stored verbatim, split into overlapping word chunks,
    and embedded with a sentence-transformer model.  Queries are answered
    by brute-force cosine similarity over every stored chunk embedding.
    """

    def __init__(self):
        # Set by load_embedding_model(); stays None if loading fails, in
        # which case embedding generation and search degrade gracefully.
        self.embedding_model = None
        self.db_path = "medical_rag.db"
        self.embeddings_cache = {}  # currently unused; reserved for in-memory caching
        self.init_database()
        self.load_embedding_model()

    def load_embedding_model(self):
        """Load a free sentence-transformer model into self.embedding_model."""
        try:
            # Lightweight, free model suitable for regulatory text.
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            # self.embedding_model remains None; callers check for that.
            logger.error(f"Error loading embedding model: {e}")
            return None

    def init_database(self):
        """Create the SQLite schema (documents/websites/standards/embeddings) if missing."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        # content_hash is UNIQUE so re-inserting identical content raises
        # sqlite3.IntegrityError, which the add_* methods translate into an
        # "already exists" message.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT NOT NULL,
                content TEXT NOT NULL,
                content_hash TEXT UNIQUE,
                category TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS websites (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                url TEXT NOT NULL,
                content TEXT NOT NULL,
                content_hash TEXT UNIQUE,
                title TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS standards (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                standard_name TEXT NOT NULL,
                content TEXT NOT NULL,
                content_hash TEXT UNIQUE,
                version TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_type TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                chunk_index INTEGER NOT NULL,
                embedding BLOB NOT NULL,
                text_chunk TEXT NOT NULL
            )
        ''')
        conn.commit()
        conn.close()
        logger.info("Database initialized successfully")

    def get_content_hash(self, content: str) -> str:
        """Return the MD5 hex digest of *content*.

        Used only as a deduplication key (UNIQUE column), not for security.
        """
        return hashlib.md5(content.encode()).hexdigest()

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split *text* into overlapping word chunks for better retrieval.

        chunk_size and overlap are word counts; consecutive chunks share
        *overlap* words so sentences straddling a boundary remain searchable.
        """
        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size - overlap):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk)
        return chunks

    def process_pdf_document(self, file_path: str) -> Tuple[str, Dict]:
        """Extract all page text from a PDF via PyMuPDF.

        Returns (text, metadata); ("", {}) on any failure.
        """
        try:
            doc = fitz.open(file_path)
            text_content = ""
            metadata = {"pages": doc.page_count, "format": "PDF"}
            for page_num in range(doc.page_count):
                page = doc[page_num]
                text_content += page.get_text()
            doc.close()
            return text_content, metadata
        except Exception as e:
            logger.error(f"Error processing PDF: {e}")
            return "", {}

    def process_text_document(self, file_path: str) -> Tuple[str, Dict]:
        """Read a UTF-8 text file; returns (content, metadata) or ("", {}) on failure."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            return content, {"format": "TEXT"}
        except Exception as e:
            logger.error(f"Error processing text document: {e}")
            return "", {}

    def scrape_website(self, url: str) -> Tuple[str, str, Dict]:
        """Fetch *url* and return (plain_text, title, metadata).

        Returns ("", "", {}) on any network or parsing failure.
        """
        try:
            # Browser-like User-Agent: some regulatory sites reject the
            # default python-requests client string.
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            # Remove non-content elements before extracting text.
            for script in soup(["script", "style"]):
                script.decompose()
            title = soup.title.string if soup.title else url
            content = soup.get_text()
            # Collapse the whitespace runs left behind by tag removal.
            content = re.sub(r'\s+', ' ', content).strip()
            metadata = {
                "title": title,
                "url": url,
                "scraped_at": datetime.now().isoformat()
            }
            return content, title, metadata
        except Exception as e:
            logger.error(f"Error scraping website {url}: {e}")
            return "", "", {}

    def add_document(self, file_path: str, filename: str, category: str) -> str:
        """Extract, store, and embed one uploaded document.

        Returns a human-readable status message for the UI.
        """
        try:
            # Dispatch on extension: PDFs via PyMuPDF, everything else read
            # as UTF-8 text.
            if filename.lower().endswith('.pdf'):
                content, metadata = self.process_pdf_document(file_path)
            else:
                content, metadata = self.process_text_document(file_path)
            if not content:
                return "Error: Could not extract content from document"
            content_hash = self.get_content_hash(content)
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            try:
                cursor.execute('''
                    INSERT INTO documents (filename, content, content_hash, category, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (filename, content, content_hash, category, json.dumps(metadata)))
                doc_id = cursor.lastrowid
                conn.commit()
                # Commit before embedding: generate_embeddings_for_content
                # opens its own connection to the same database file.
                self.generate_embeddings_for_content(content, 'document', doc_id)
                conn.close()
                # BUG FIX: original returned a literal "(unknown)" instead of
                # interpolating the uploaded filename.
                return f"Document '{filename}' added successfully to category '{category}'"
            except sqlite3.IntegrityError:
                # UNIQUE(content_hash) violation -> duplicate upload.
                conn.close()
                return "Document already exists in the knowledge base"
        except Exception as e:
            logger.error(f"Error adding document: {e}")
            return f"Error adding document: {str(e)}"

    def add_website(self, url: str) -> str:
        """Scrape, store, and embed one website; returns a status message."""
        try:
            content, title, metadata = self.scrape_website(url)
            if not content:
                return "Error: Could not scrape website content"
            content_hash = self.get_content_hash(content)
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            try:
                cursor.execute('''
                    INSERT INTO websites (url, content, content_hash, title, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (url, content, content_hash, title, json.dumps(metadata)))
                website_id = cursor.lastrowid
                conn.commit()
                self.generate_embeddings_for_content(content, 'website', website_id)
                conn.close()
                return f"Website '{title}' added successfully"
            except sqlite3.IntegrityError:
                conn.close()
                return "Website already exists in the knowledge base"
        except Exception as e:
            logger.error(f"Error adding website: {e}")
            return f"Error adding website: {str(e)}"

    def add_standard(self, standard_name: str, content: str, version: str = "") -> str:
        """Store and embed one pasted standard; returns a status message."""
        try:
            if not content.strip():
                return "Error: Standard content cannot be empty"
            content_hash = self.get_content_hash(content)
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            metadata = {"version": version, "added_at": datetime.now().isoformat()}
            try:
                cursor.execute('''
                    INSERT INTO standards (standard_name, content, content_hash, version, metadata)
                    VALUES (?, ?, ?, ?, ?)
                ''', (standard_name, content, content_hash, version, json.dumps(metadata)))
                standard_id = cursor.lastrowid
                conn.commit()
                self.generate_embeddings_for_content(content, 'standard', standard_id)
                conn.close()
                return f"Standard '{standard_name}' added successfully"
            except sqlite3.IntegrityError:
                conn.close()
                return "Standard already exists in the knowledge base"
        except Exception as e:
            logger.error(f"Error adding standard: {e}")
            return f"Error adding standard: {str(e)}"

    def generate_embeddings_for_content(self, content: str, source_type: str, source_id: int):
        """Chunk *content*, embed each chunk, and persist the vectors.

        source_type is one of 'document' / 'website' / 'standard' and pairs
        with source_id to link each embedding row back to its source.
        No-op (with an error log) if the embedding model failed to load.
        """
        if not self.embedding_model:
            logger.error("Embedding model not available")
            return
        chunks = self.chunk_text(content)
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        for i, chunk in enumerate(chunks):
            try:
                embedding = self.embedding_model.encode(chunk)
                # NOTE: pickle is acceptable here because the blobs are
                # written and read only by this local application; never
                # unpickle data from an untrusted database.
                embedding_blob = pickle.dumps(embedding)
                cursor.execute('''
                    INSERT INTO embeddings (source_type, source_id, chunk_index, embedding, text_chunk)
                    VALUES (?, ?, ?, ?, ?)
                ''', (source_type, source_id, i, embedding_blob, chunk))
            except Exception as e:
                logger.error(f"Error generating embedding for chunk {i}: {e}")
        conn.commit()
        conn.close()

    def search_knowledge_base(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the *top_k* chunks most similar to *query*.

        Brute-force scan: every stored embedding is unpickled and compared
        with cosine similarity (fine for small knowledge bases; a vector
        index would be needed at scale).  Returns [] if the model is
        unavailable or on any error.
        """
        if not self.embedding_model:
            return []
        try:
            query_embedding = self.embedding_model.encode(query)
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            # Join each embedding row to its source table to recover a
            # display name for the UI.
            cursor.execute('''
                SELECT e.source_type, e.source_id, e.text_chunk, e.embedding,
                       CASE
                           WHEN e.source_type = 'document' THEN d.filename
                           WHEN e.source_type = 'website' THEN w.title
                           WHEN e.source_type = 'standard' THEN s.standard_name
                       END as source_name
                FROM embeddings e
                LEFT JOIN documents d ON e.source_type = 'document' AND e.source_id = d.id
                LEFT JOIN websites w ON e.source_type = 'website' AND e.source_id = w.id
                LEFT JOIN standards s ON e.source_type = 'standard' AND e.source_id = s.id
            ''')
            results = []
            for row in cursor.fetchall():
                try:
                    stored_embedding = pickle.loads(row[3])
                    similarity = cosine_similarity([query_embedding], [stored_embedding])[0][0]
                    results.append({
                        'source_type': row[0],
                        'source_id': row[1],
                        'text_chunk': row[2],
                        'source_name': row[4],
                        'similarity': similarity
                    })
                except Exception as e:
                    logger.error(f"Error processing embedding: {e}")
            conn.close()
            # Highest similarity first; keep only the top_k hits.
            results.sort(key=lambda x: x['similarity'], reverse=True)
            return results[:top_k]
        except Exception as e:
            logger.error(f"Error searching knowledge base: {e}")
            return []

    def get_knowledge_base_stats(self) -> Dict:
        """Return row counts for each table as a dict for the stats tab."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        stats = {}
        cursor.execute("SELECT COUNT(*) FROM documents")
        stats['documents'] = cursor.fetchone()[0]
        cursor.execute("SELECT COUNT(*) FROM websites")
        stats['websites'] = cursor.fetchone()[0]
        cursor.execute("SELECT COUNT(*) FROM standards")
        stats['standards'] = cursor.fetchone()[0]
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        stats['embeddings'] = cursor.fetchone()[0]
        conn.close()
        return stats
# Module-level singleton: constructing MedicalRAGSystem opens/creates the
# SQLite database and loads the sentence-transformer model at import time.
rag_system = MedicalRAGSystem()
def handle_document_upload(files, category):
    """Gradio callback: add each uploaded file to the knowledge base.

    Returns one status line per file, joined with newlines.
    """
    if not files:
        return "No files selected"
    outcomes = [
        rag_system.add_document(upload.name, os.path.basename(upload.name), category)
        for upload in files
    ]
    return "\n".join(outcomes)
def handle_website_addition(url):
    """Gradio callback: validate *url*, then scrape and store the website."""
    cleaned = url.strip()
    if not cleaned:
        return "Please enter a valid URL"
    return rag_system.add_website(cleaned)
def handle_standard_addition(standard_name, content, version):
    """Gradio callback: validate inputs, then store and embed the standard."""
    name = standard_name.strip()
    body = content.strip()
    if not (name and body):
        return "Please provide both standard name and content"
    return rag_system.add_standard(name, body, version.strip())
def handle_search(query):
    """Gradio callback: semantic search plus an extractive answer.

    Returns (results_markdown, answer_markdown) for the two UI panes.
    """
    cleaned = query.strip()
    if not cleaned:
        return "Please enter a search query", ""
    hits = rag_system.search_knowledge_base(cleaned)
    if not hits:
        return "No relevant results found", ""
    rendered = []
    context = []
    for rank, hit in enumerate(hits, 1):
        chunk = hit['text_chunk']
        # Truncate long chunks in the display; the full chunk still feeds
        # the answer generator.
        preview = chunk[:300] + ('...' if len(chunk) > 300 else '')
        rendered.append(f"""
**Result {rank}** (Similarity: {hit['similarity'] * 100:.1f}%)
**Source:** {hit['source_name']} ({hit['source_type']})
**Content:** {preview}
---
""")
        context.append(chunk)
    return "\n".join(rendered), generate_answer(cleaned, context)
def generate_answer(query: str, context: List[str]) -> str:
    """Build an extractive answer from retrieved *context* chunks.

    Collects every sentence (split on '.') containing at least one query
    term, de-duplicates while preserving order, and returns up to three of
    them; falls back to a canned message when nothing matches.  A
    generative model could replace this in a production system.
    """
    terms = query.lower().split()
    matched = []
    for chunk in context:
        for sentence in chunk.split('.'):
            lowered = sentence.lower()
            if any(term in lowered for term in terms):
                matched.append(sentence.strip())
    if not matched:
        return "The retrieved content may contain relevant information, but I couldn't extract a specific answer. Please review the search results above."
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    deduped = list(dict.fromkeys(matched))
    return "Based on the regulatory documents:\n\n" + "\n\n".join(deduped[:3])
def get_stats():
    """Gradio callback: render knowledge-base row counts as display text."""
    counts = rag_system.get_knowledge_base_stats()
    return f"""
Knowledge Base Statistics:
- Documents: {counts['documents']}
- Websites: {counts['websites']}
- Standards: {counts['standards']}
- Total Text Chunks: {counts['embeddings']}
"""
# Build the Gradio UI: one tab per workflow (search, add documents, add
# websites, add standards, statistics).  Emoji in labels were mojibake in
# the corrupted source and have been restored to their intended UTF-8
# characters (the surviving globe glyph confirmed the double-decoding).
with gr.Blocks(title="Medical Devices RAG System", theme=gr.themes.Soft()) as demo:
    # Markdown string content stays at column 0: leading spaces would be
    # rendered as Markdown code blocks.
    gr.Markdown("""
# 🏥 Medical Devices Regulatory RAG System
A comprehensive knowledge base system for medical device regulatory analysts.
Add documents, websites, and standards to build your regulatory knowledge base.
""")
    with gr.Tabs():
        # --- Search tab ---
        with gr.Tab("🔍 Search Knowledge Base"):
            gr.Markdown("### Search your regulatory knowledge base")
            search_input = gr.Textbox(
                placeholder="Enter your regulatory question (e.g., 'What are the requirements for Class II medical devices?')",
                label="Search Query",
                lines=2
            )
            search_button = gr.Button("Search", variant="primary")
            with gr.Row():
                with gr.Column():
                    search_results = gr.Markdown(label="Search Results")
                with gr.Column():
                    answer_output = gr.Markdown(label="Generated Answer")
            search_button.click(
                handle_search,
                inputs=[search_input],
                outputs=[search_results, answer_output]
            )
        # --- Add documents tab ---
        with gr.Tab("📄 Add Documents"):
            gr.Markdown("### Add regulatory documents (PDF, TXT)")
            document_files = gr.File(
                label="Upload Documents",
                file_count="multiple",
                file_types=[".pdf", ".txt", ".docx"]
            )
            document_category = gr.Dropdown(
                choices=["EU MDR 2017/745", "CMDR SOR/98-282", "MDCG", "MDSAP Audit Approach", "UK MDR", "Other"],
                label="Document Category",
                value="Other"
            )
            add_doc_button = gr.Button("Add Documents", variant="primary")
            doc_output = gr.Textbox(label="Result", lines=3)
            add_doc_button.click(
                handle_document_upload,
                inputs=[document_files, document_category],
                outputs=[doc_output]
            )
        # --- Add websites tab ---
        with gr.Tab("🌐 Add Websites"):
            gr.Markdown("### Add regulatory websites")
            website_url = gr.Textbox(
                placeholder="https://www.fda.gov/medical-devices/...",
                label="Website URL",
                lines=1
            )
            add_website_button = gr.Button("Add Website", variant="primary")
            website_output = gr.Textbox(label="Result", lines=3)
            gr.Markdown("**Suggested regulatory websites:**")
            gr.Markdown("""
- US FDA 21CFR: https://www.accessdata.fda.gov/scripts/cdrh/cfdocs/cfcfr/cfrsearch.cfm
- EU Medical Devices: https://ec.europa.eu/health/medical-devices-sector_en
- Health Canada Medical Devices: https://www.canada.ca/en/health-canada/services/drugs-health-products/medical-devices.html
""")
            add_website_button.click(
                handle_website_addition,
                inputs=[website_url],
                outputs=[website_output]
            )
        # --- Add standards tab ---
        with gr.Tab("📋 Add Standards"):
            gr.Markdown("### Add regulatory standards")
            standard_name = gr.Textbox(
                placeholder="ISO 13485:2016",
                label="Standard Name",
                lines=1
            )
            standard_version = gr.Textbox(
                placeholder="2016 (optional)",
                label="Version",
                lines=1
            )
            standard_content = gr.Textbox(
                placeholder="Enter or paste the standard content here...",
                label="Standard Content",
                lines=10
            )
            add_standard_button = gr.Button("Add Standard", variant="primary")
            standard_output = gr.Textbox(label="Result", lines=3)
            add_standard_button.click(
                handle_standard_addition,
                inputs=[standard_name, standard_content, standard_version],
                outputs=[standard_output]
            )
        # --- Statistics tab ---
        with gr.Tab("📊 Knowledge Base Stats"):
            gr.Markdown("### Knowledge Base Statistics")
            stats_button = gr.Button("Refresh Statistics", variant="secondary")
            stats_output = gr.Textbox(label="Statistics", lines=8)
            stats_button.click(
                get_stats,
                outputs=[stats_output]
            )
    # Populate the statistics pane when the page first loads.
    demo.load(get_stats, outputs=[stats_output])

if __name__ == "__main__":
    # share=True requests a public gradio.live tunnel; hosting platforms
    # that already expose the app (e.g. Hugging Face Spaces) ignore it.
    demo.launch(share=True)