| """
|
| Neo4j Database Manager
|
| Handle graph database connections and operations
|
| """
|
|
|
| from neo4j import GraphDatabase
|
| from typing import Dict, List, Optional, Any
|
| import yaml
|
| import logging
|
|
|
# Configure root logging at INFO level. basicConfig is a no-op when the root
# logger already has handlers. NOTE(review): configuring logging at import time
# in a library module overrides the application's choice -- consider moving
# this to the script entry point.
logging.basicConfig(level="INFO")
# Module-wide logger shared by every class below.
logger = logging.getLogger(__name__)
|
|
|
|
|
class DatabaseManager:
    """Manage Neo4j database connections and schema.

    Can be used as a context manager: the driver is closed when the ``with``
    block exits, and exceptions raised inside the block are not suppressed.
    """

    def __init__(self, config_path: str = "config.yml"):
        """Load Neo4j settings from a YAML file and create the driver.

        Args:
            config_path: Path to a YAML file with a top-level ``neo4j`` mapping
                that provides ``uri``, ``username`` and ``password``.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If the ``neo4j`` section or one of its keys is missing.
            TypeError: If the YAML document is empty (``safe_load`` -> None).
        """
        # YAML streams are Unicode; read as UTF-8 instead of relying on the
        # platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = yaml.safe_load(f)["neo4j"]

        self.driver = GraphDatabase.driver(
            self.config["uri"],
            auth=(self.config["username"], self.config["password"]),
        )

        # NOTE: creating the driver does not by itself prove the server is
        # reachable; call self.driver.verify_connectivity() for an eager check.
        logger.info("Connected to Neo4j at %s", self.config["uri"])

    def close(self):
        """Close the underlying driver and release its connections."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) lets any in-flight exception propagate.
        self.close()

    def execute_query(self, query: str, parameters: Optional[Dict] = None) -> List[Dict]:
        """Execute a Cypher query in a fresh session and return all records.

        Args:
            query: Cypher text; values should be passed via ``parameters``.
            parameters: Query parameters; ``None`` means no parameters.

        Returns:
            One dict per record, built with ``Record.data()``. The whole result
            is materialised before the session is closed.
        """
        with self.driver.session() as session:
            result = session.run(query, parameters or {})
            return [record.data() for record in result]

    def initialize_schema(self):
        """Create uniqueness constraints and lookup indexes (best effort).

        Each statement uses ``IF NOT EXISTS``. A failing statement is logged as
        a warning and the remaining statements are still attempted.
        """
        queries = [
            # Uniqueness constraints on the business keys used by MERGE.
            "CREATE CONSTRAINT gene_id IF NOT EXISTS FOR (g:Gene) REQUIRE g.gene_id IS UNIQUE",
            "CREATE CONSTRAINT mutation_id IF NOT EXISTS FOR (m:Mutation) REQUIRE m.mutation_id IS UNIQUE",
            "CREATE CONSTRAINT patient_id IF NOT EXISTS FOR (p:Patient) REQUIRE p.patient_id IS UNIQUE",
            "CREATE CONSTRAINT cancer_type_id IF NOT EXISTS FOR (c:CancerType) REQUIRE c.cancer_type_id IS UNIQUE",
            # Secondary (non-unique) indexes for common lookups.
            "CREATE INDEX gene_symbol IF NOT EXISTS FOR (g:Gene) ON (g.symbol)",
            "CREATE INDEX mutation_position IF NOT EXISTS FOR (m:Mutation) ON (m.chromosome, m.position)",
            "CREATE INDEX patient_project IF NOT EXISTS FOR (p:Patient) ON (p.project_id)",
        ]

        with self.driver.session() as session:
            for query in queries:
                try:
                    # The driver's Result is lazy: consume() forces this
                    # statement to complete here, so a server-side failure is
                    # raised inside this try (and attributed to this query)
                    # rather than from the next run() or the session's close.
                    session.run(query).consume()
                    logger.info("Executed: %s...", query[:50])
                except Exception as e:  # deliberate best-effort per statement
                    logger.warning("Schema query failed (may already exist): %s", e)

        logger.info("Database schema initialized")

    def clear_database(self):
        """Delete every node and relationship (use with caution!).

        NOTE(review): this runs as one auto-commit statement; on a large graph
        it may need batching -- TODO confirm against expected data volume.
        """
        query = "MATCH (n) DETACH DELETE n"
        with self.driver.session() as session:
            # consume() makes any failure surface here, before the
            # success message is logged.
            session.run(query).consume()
        logger.info("Database cleared")
|
|
|
|
|
class GeneRepository:
    """Repository for Gene nodes."""

    def __init__(self, db_manager: DatabaseManager):
        # Every query goes through the shared manager's execute_query().
        self.db = db_manager

    def create_gene(self, gene_data: Dict) -> Dict:
        """Create or update a Gene node keyed on ``gene_id``.

        Args:
            gene_data: Must supply every parameter the query references:
                gene_id, symbol, name, chromosome, start_position,
                end_position, strand, gene_type. Neo4j rejects the query if
                one of them is missing.

        Returns:
            The node's property dict, or ``{}`` if no row came back.
        """
        query = """
            MERGE (g:Gene {gene_id: $gene_id})
            SET g.symbol = $symbol,
                g.name = $name,
                g.chromosome = $chromosome,
                g.start_position = $start_position,
                g.end_position = $end_position,
                g.strand = $strand,
                g.gene_type = $gene_type
            RETURN g
        """
        result = self.db.execute_query(query, gene_data)
        return result[0]['g'] if result else {}

    def get_gene_by_symbol(self, symbol: str) -> Optional[Dict]:
        """Find a gene by symbol.

        ``symbol`` is not guaranteed unique. When several genes share it, the
        one with the lowest ``gene_id`` is returned, so the choice is
        deterministic, and only that single row is fetched.

        Returns:
            The gene's property dict, or ``None`` if no gene has this symbol.
        """
        query = """
            MATCH (g:Gene {symbol: $symbol})
            RETURN g
            ORDER BY g.gene_id
            LIMIT 1
        """
        result = self.db.execute_query(query, {'symbol': symbol})
        return result[0]['g'] if result else None

    def get_gene_mutations(self, gene_id: str) -> List[Dict]:
        """Return all Mutation nodes that AFFECT the gene, ordered by position.

        Returns:
            A list of mutation property dicts; empty if the gene is unknown or
            has no mutations.
        """
        query = """
            MATCH (g:Gene {gene_id: $gene_id})<-[:AFFECTS]-(m:Mutation)
            RETURN m
            ORDER BY m.position
        """
        result = self.db.execute_query(query, {'gene_id': gene_id})
        return [r['m'] for r in result]
|
|
|
|
|
class MutationRepository:
    """Repository for Mutation nodes and their AFFECTS / HAS_MUTATION edges."""

    def __init__(self, db_manager: DatabaseManager):
        # Every query goes through the shared manager's execute_query().
        self.db = db_manager

    def create_mutation(self, mutation_data: Dict, gene_id: str) -> Dict:
        """Create or update a Mutation node and link it to an existing Gene.

        The Gene is looked up with MATCH. If no Gene with ``gene_id`` exists,
        the query yields no rows: nothing is created and ``{}`` is returned.

        Args:
            mutation_data: Values for mutation_id, chromosome, position,
                reference, alternate, consequence, variant_type and quality.
                Any ``gene_id`` key it contains is overridden by ``gene_id``.
            gene_id: Business key of the Gene the mutation AFFECTS.

        Returns:
            The mutation's property dict, or ``{}`` if nothing was written.
        """
        query = """
            MATCH (g:Gene {gene_id: $gene_id})
            MERGE (m:Mutation {mutation_id: $mutation_id})
            SET m.chromosome = $chromosome,
                m.position = $position,
                m.reference = $reference,
                m.alternate = $alternate,
                m.consequence = $consequence,
                m.variant_type = $variant_type,
                m.quality = $quality
            MERGE (m)-[:AFFECTS]->(g)
            RETURN m
        """
        params = {**mutation_data, 'gene_id': gene_id}
        result = self.db.execute_query(query, params)
        return result[0]['m'] if result else {}

    def link_mutation_to_patient(self, mutation_id: str, patient_id: str, properties: Optional[Dict] = None):
        """Create (or update) the Patient-[:HAS_MUTATION]->Mutation edge.

        If either node does not exist, the MATCH produces no rows and nothing
        is written; no error is raised.

        Args:
            properties: Optional ``allele_frequency`` and ``depth``; each
                defaults to 0 when absent.
        """
        query = """
            MATCH (p:Patient {patient_id: $patient_id})
            MATCH (m:Mutation {mutation_id: $mutation_id})
            MERGE (p)-[r:HAS_MUTATION]->(m)
            SET r.allele_frequency = $allele_frequency,
                r.depth = $depth
            RETURN r
        """
        props = properties or {}
        params = {
            'patient_id': patient_id,
            'mutation_id': mutation_id,
            'allele_frequency': props.get('allele_frequency', 0),
            'depth': props.get('depth', 0),
        }
        self.db.execute_query(query, params)

    def get_mutation_frequency(self, mutation_id: str) -> Dict:
        """Return the fraction of all patients that carry this mutation.

        The carrier count and the total patient count are aggregated
        independently, so no carriers x all-patients cross product is built.
        A mutation with no carriers yields ``frequency`` 0.0. With no patients
        at all, the zero-guard also gives 0.0.

        Returns:
            A dict with mutation_id, patients_with_mutation, total_patients
            and frequency, or ``{}`` if the mutation does not exist.
        """
        query = """
            MATCH (m:Mutation {mutation_id: $mutation_id})
            OPTIONAL MATCH (carrier:Patient)-[:HAS_MUTATION]->(m)
            WITH m, count(DISTINCT carrier) AS patients_with_mutation
            OPTIONAL MATCH (patient:Patient)
            WITH m, patients_with_mutation, count(patient) AS total_patients
            RETURN m.mutation_id AS mutation_id,
                   patients_with_mutation,
                   total_patients,
                   CASE WHEN total_patients > 0
                        THEN toFloat(patients_with_mutation) / total_patients
                        ELSE 0.0
                   END AS frequency
        """
        result = self.db.execute_query(query, {'mutation_id': mutation_id})
        return result[0] if result else {}
|
|
|
|
|
class PatientRepository:
    """Read/write access to ``Patient`` nodes and their relationships."""

    def __init__(self, db_manager: DatabaseManager):
        # All Cypher is executed through the shared manager.
        self.db = db_manager

    def create_patient(self, patient_data: Dict) -> Dict:
        """Upsert a ``Patient`` keyed on ``patient_id`` and set its attributes.

        ``patient_data`` must provide every parameter the query references:
        patient_id, project_id, age, gender, race, ethnicity, vital_status.

        Returns:
            The node's property dict, or ``{}`` when no row comes back.
        """
        cypher = """
            MERGE (p:Patient {patient_id: $patient_id})
            SET p.project_id = $project_id,
                p.age = $age,
                p.gender = $gender,
                p.race = $race,
                p.ethnicity = $ethnicity,
                p.vital_status = $vital_status
            RETURN p
        """
        rows = self.db.execute_query(cypher, patient_data)
        if not rows:
            return {}
        return rows[0]['p']

    def link_patient_to_cancer_type(self, patient_id: str, cancer_type_id: str, properties: Optional[Dict] = None):
        """Create (or update) the Patient-[:DIAGNOSED_WITH]->CancerType edge.

        Missing ``stage``, ``grade`` or ``diagnosis_date`` entries in
        ``properties`` (or no ``properties`` at all) are written as null.
        Nothing is written if either node is absent.
        """
        cypher = """
            MATCH (p:Patient {patient_id: $patient_id})
            MATCH (c:CancerType {cancer_type_id: $cancer_type_id})
            MERGE (p)-[r:DIAGNOSED_WITH]->(c)
            SET r.stage = $stage,
                r.grade = $grade,
                r.diagnosis_date = $diagnosis_date
            RETURN r
        """
        extras = properties or {}
        self.db.execute_query(cypher, {
            'patient_id': patient_id,
            'cancer_type_id': cancer_type_id,
            'stage': extras.get('stage'),
            'grade': extras.get('grade'),
            'diagnosis_date': extras.get('diagnosis_date'),
        })

    def get_patient_mutations(self, patient_id: str) -> List[Dict]:
        """List the patient's mutations with the genes they affect.

        Each row holds keys ``m``, ``g``, ``allele_frequency`` and ``depth``,
        ordered by gene symbol.
        """
        cypher = """
            MATCH (p:Patient {patient_id: $patient_id})-[r:HAS_MUTATION]->(m:Mutation)-[:AFFECTS]->(g:Gene)
            RETURN m, g, r.allele_frequency as allele_frequency, r.depth as depth
            ORDER BY g.symbol
        """
        return self.db.execute_query(cypher, {'patient_id': patient_id})
|
|
|
|
|
class CancerTypeRepository:
    """Repository for CancerType nodes and cohort-level statistics."""

    def __init__(self, db_manager: DatabaseManager):
        # Every query goes through the shared manager's execute_query().
        self.db = db_manager

    def create_cancer_type(self, cancer_data: Dict) -> Dict:
        """Create or update a CancerType node keyed on ``cancer_type_id``.

        Args:
            cancer_data: Must supply cancer_type_id, name, tissue and
                disease_type.

        Returns:
            The node's property dict, or ``{}`` if no row came back.
        """
        query = """
            MERGE (c:CancerType {cancer_type_id: $cancer_type_id})
            SET c.name = $name,
                c.tissue = $tissue,
                c.disease_type = $disease_type
            RETURN c
        """
        result = self.db.execute_query(query, cancer_data)
        return result[0]['c'] if result else {}

    def get_common_mutations(self, cancer_type_id: str, limit: int = 10) -> List[Dict]:
        """Return the mutations carried by the most patients of this cancer type.

        Args:
            limit: Maximum number of rows. It is passed as a Cypher parameter
                to LIMIT, so it should be a non-negative int.

        Returns:
            Rows with keys ``m``, ``g`` and ``patient_count``, by descending
            count. A mutation affecting several genes appears once per gene.
        """
        query = """
            MATCH (c:CancerType {cancer_type_id: $cancer_type_id})<-[:DIAGNOSED_WITH]-(p:Patient)
            MATCH (p)-[:HAS_MUTATION]->(m:Mutation)-[:AFFECTS]->(g:Gene)
            WITH m, g, count(DISTINCT p) as patient_count
            RETURN m, g, patient_count
            ORDER BY patient_count DESC
            LIMIT $limit
        """
        result = self.db.execute_query(query, {'cancer_type_id': cancer_type_id, 'limit': limit})
        return result

    def get_statistics(self, cancer_type_id: str) -> Dict:
        """Get patient and mutation statistics for a cancer type.

        ``total_mutations`` counts distinct Mutation nodes across the cohort.
        ``avg_mutations_per_patient`` is the mean number of HAS_MUTATION links
        per diagnosed patient. Mutations shared by several patients therefore
        count once per carrier, not once overall. Patients are matched
        optionally, so a cancer type with no patients yields zeros instead of
        no row.

        Returns:
            A dict with cancer_type, total_patients, total_mutations and
            avg_mutations_per_patient, or ``{}`` if the cancer type is unknown.
        """
        query = """
            MATCH (c:CancerType {cancer_type_id: $cancer_type_id})
            OPTIONAL MATCH (c)<-[:DIAGNOSED_WITH]-(p:Patient)
            OPTIONAL MATCH (p)-[r:HAS_MUTATION]->(m:Mutation)
            WITH c,
                 count(DISTINCT p) AS total_patients,
                 count(DISTINCT m) AS total_mutations,
                 count(DISTINCT r) AS patient_mutation_links
            RETURN c.name AS cancer_type,
                   total_patients,
                   total_mutations,
                   CASE WHEN total_patients > 0
                        THEN toFloat(patient_mutation_links) / total_patients
                        ELSE 0.0
                   END AS avg_mutations_per_patient
        """
        result = self.db.execute_query(query, {'cancer_type_id': cancer_type_id})
        return result[0] if result else {}
|
|
|