""" Enterprise-Grade Repository Layer for Database Operations Provides clean interface with tenant isolation, transactions, and error handling """ import logging from typing import List, Optional, Dict, Any from datetime import datetime from sqlalchemy import select, update, delete, and_, or_ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from .models import ( Company, Prospect, Contact, Fact, Activity, Suppression, Handoff, AuditLog ) logger = logging.getLogger(__name__) class BaseRepository: """Base repository with common operations and tenant isolation""" def __init__(self, session: AsyncSession, tenant_id: Optional[str] = None): self.session = session self.tenant_id = tenant_id def _apply_tenant_filter(self, query, model): """Apply tenant filter to query if tenant_id is set""" if self.tenant_id and hasattr(model, 'tenant_id'): return query.where(model.tenant_id == self.tenant_id) return query async def _log_audit( self, action: str, resource_type: str, resource_id: str, old_value: Optional[Dict] = None, new_value: Optional[Dict] = None, user_id: Optional[str] = None ): """Log audit trail""" audit_log = AuditLog( tenant_id=self.tenant_id, user_id=user_id, action=action, resource_type=resource_type, resource_id=resource_id, old_value=old_value, new_value=new_value ) self.session.add(audit_log) class CompanyRepository(BaseRepository): """Repository for Company operations""" async def create(self, company_data: Dict[str, Any]) -> Company: """Create a new company""" if self.tenant_id: company_data['tenant_id'] = self.tenant_id company = Company(**company_data) self.session.add(company) await self.session.flush() await self._log_audit('create', 'company', company.id, new_value=company_data) logger.info(f"Created company: {company.id}") return company async def get_by_id(self, company_id: str) -> Optional[Company]: """Get company by ID""" query = select(Company).where(Company.id == company_id) query = self._apply_tenant_filter(query, Company) result = await self.session.execute(query) return result.scalar_one_or_none() async def get_by_domain(self, domain: str) -> Optional[Company]: """Get company by domain""" query = select(Company).where(Company.domain == domain.lower()) query = self._apply_tenant_filter(query, Company) result = await self.session.execute(query) return result.scalar_one_or_none() async def list( self, limit: int = 100, offset: int = 0, industry: Optional[str] = None, is_active: bool = True ) -> List[Company]: """List companies with filters""" query = select(Company) query = self._apply_tenant_filter(query, Company) if is_active is not None: query = query.where(Company.is_active == is_active) if industry: query = query.where(Company.industry == industry) query = query.limit(limit).offset(offset).order_by(Company.created_at.desc()) result = await self.session.execute(query) return list(result.scalars().all()) async def update(self, company_id: str, company_data: Dict[str, Any]) -> Optional[Company]: """Update a company""" company = await self.get_by_id(company_id) if not company: return None old_data = {key: getattr(company, key) for key in company_data.keys() if hasattr(company, key)} for key, value in company_data.items(): if hasattr(company, key): setattr(company, key, value) await self.session.flush() await self._log_audit('update', 'company', company_id, old_value=old_data, new_value=company_data) logger.info(f"Updated company: {company_id}") return company async def delete(self, company_id: str) -> bool: """Delete a company (soft delete by marking inactive)""" company = await self.get_by_id(company_id) if not company: return False company.is_active = False await self.session.flush() await self._log_audit('delete', 'company', company_id) logger.info(f"Soft deleted company: {company_id}") return True class ProspectRepository(BaseRepository): """Repository for Prospect operations""" async def create(self, prospect_data: Dict[str, Any]) -> Prospect: """Create a new prospect""" if self.tenant_id: prospect_data['tenant_id'] = self.tenant_id prospect = Prospect(**prospect_data) self.session.add(prospect) await self.session.flush() await self._log_audit('create', 'prospect', prospect.id, new_value=prospect_data) logger.info(f"Created prospect: {prospect.id}") return prospect async def get_by_id(self, prospect_id: str, load_relationships: bool = False) -> Optional[Prospect]: """Get prospect by ID with optional relationship loading""" query = select(Prospect).where(Prospect.id == prospect_id) query = self._apply_tenant_filter(query, Prospect) if load_relationships: query = query.options( selectinload(Prospect.company), selectinload(Prospect.activities), selectinload(Prospect.handoffs) ) result = await self.session.execute(query) return result.scalar_one_or_none() async def list( self, limit: int = 100, offset: int = 0, status: Optional[str] = None, stage: Optional[str] = None, min_score: Optional[float] = None ) -> List[Prospect]: """List prospects with filters""" query = select(Prospect) query = self._apply_tenant_filter(query, Prospect) if status: query = query.where(Prospect.status == status) if stage: query = query.where(Prospect.stage == stage) if min_score is not None: query = query.where(Prospect.overall_score >= min_score) query = query.limit(limit).offset(offset).order_by(Prospect.created_at.desc()) result = await self.session.execute(query) return list(result.scalars().all()) async def update(self, prospect_id: str, prospect_data: Dict[str, Any]) -> Optional[Prospect]: """Update a prospect""" prospect = await self.get_by_id(prospect_id) if not prospect: return None old_data = {key: getattr(prospect, key) for key in prospect_data.keys() if hasattr(prospect, key)} for key, value in prospect_data.items(): if hasattr(prospect, key): setattr(prospect, key, value) await self.session.flush() await self._log_audit('update', 'prospect', prospect_id, old_value=old_data, new_value=prospect_data) logger.info(f"Updated prospect: {prospect_id}") return prospect async def update_score( self, prospect_id: str, fit_score: Optional[float] = None, engagement_score: Optional[float] = None, intent_score: Optional[float] = None ) -> Optional[Prospect]: """Update prospect scores and calculate overall score""" prospect = await self.get_by_id(prospect_id) if not prospect: return None if fit_score is not None: prospect.fit_score = fit_score if engagement_score is not None: prospect.engagement_score = engagement_score if intent_score is not None: prospect.intent_score = intent_score # Calculate overall score (weighted average) scores = [] if prospect.fit_score is not None: scores.append(prospect.fit_score * 0.5) # 50% weight if prospect.engagement_score is not None: scores.append(prospect.engagement_score * 0.3) # 30% weight if prospect.intent_score is not None: scores.append(prospect.intent_score * 0.2) # 20% weight if scores: prospect.overall_score = sum(scores) / (len(scores) * 0.1) * 0.1 await self.session.flush() logger.info(f"Updated prospect scores: {prospect_id}") return prospect class ContactRepository(BaseRepository): """Repository for Contact operations""" async def create(self, contact_data: Dict[str, Any]) -> Contact: """Create a new contact""" if self.tenant_id: contact_data['tenant_id'] = self.tenant_id # Normalize email if 'email' in contact_data: contact_data['email'] = contact_data['email'].lower() contact = Contact(**contact_data) self.session.add(contact) await self.session.flush() await self._log_audit('create', 'contact', contact.id, new_value=contact_data) logger.info(f"Created contact: {contact.id}") return contact async def get_by_id(self, contact_id: str) -> Optional[Contact]: """Get contact by ID""" query = select(Contact).where(Contact.id == contact_id) query = self._apply_tenant_filter(query, Contact) result = await self.session.execute(query) return result.scalar_one_or_none() async def get_by_email(self, email: str) -> Optional[Contact]: """Get contact by email""" query = select(Contact).where(Contact.email == email.lower()) query = self._apply_tenant_filter(query, Contact) result = await self.session.execute(query) return result.scalar_one_or_none() async def list_by_company(self, company_id: str) -> List[Contact]: """List contacts for a company""" query = select(Contact).where(Contact.company_id == company_id) query = self._apply_tenant_filter(query, Contact) query = query.where(Contact.is_active == True).order_by(Contact.is_primary_contact.desc()) result = await self.session.execute(query) return list(result.scalars().all()) async def list_by_domain(self, domain: str) -> List[Contact]: """List contacts by domain (from email)""" query = select(Contact).where(Contact.email.endswith(f"@{domain}")) query = self._apply_tenant_filter(query, Contact) query = query.where(Contact.is_active == True) result = await self.session.execute(query) return list(result.scalars().all()) class FactRepository(BaseRepository): """Repository for Fact operations""" async def create(self, fact_data: Dict[str, Any]) -> Fact: """Create a new fact""" if self.tenant_id: fact_data['tenant_id'] = self.tenant_id fact = Fact(**fact_data) self.session.add(fact) await self.session.flush() logger.info(f"Created fact: {fact.id}") return fact async def list_by_company( self, company_id: str, fact_type: Optional[str] = None, limit: int = 50 ) -> List[Fact]: """List facts for a company""" query = select(Fact).where(Fact.company_id == company_id) query = self._apply_tenant_filter(query, Fact) if fact_type: query = query.where(Fact.fact_type == fact_type) query = query.order_by(Fact.published_at.desc()).limit(limit) result = await self.session.execute(query) return list(result.scalars().all()) class ActivityRepository(BaseRepository): """Repository for Activity operations""" async def create(self, activity_data: Dict[str, Any]) -> Activity: """Create a new activity""" if self.tenant_id: activity_data['tenant_id'] = self.tenant_id activity = Activity(**activity_data) self.session.add(activity) await self.session.flush() logger.info(f"Created activity: {activity.id}") return activity async def list_by_prospect( self, prospect_id: str, activity_type: Optional[str] = None, limit: int = 100 ) -> List[Activity]: """List activities for a prospect""" query = select(Activity).where(Activity.prospect_id == prospect_id) query = self._apply_tenant_filter(query, Activity) if activity_type: query = query.where(Activity.activity_type == activity_type) query = query.order_by(Activity.created_at.desc()).limit(limit) result = await self.session.execute(query) return list(result.scalars().all()) class SuppressionRepository(BaseRepository): """Repository for Suppression operations""" async def create(self, suppression_data: Dict[str, Any]) -> Suppression: """Create a new suppression""" if self.tenant_id: suppression_data['tenant_id'] = self.tenant_id # Normalize value if 'value' in suppression_data: suppression_data['value'] = suppression_data['value'].lower() suppression = Suppression(**suppression_data) self.session.add(suppression) await self.session.flush() logger.info(f"Created suppression: {suppression.id}") return suppression async def check( self, suppression_type: str, value: str ) -> bool: """Check if a value is suppressed""" value = value.lower() query = select(Suppression).where( and_( Suppression.suppression_type == suppression_type, Suppression.value == value ) ) query = self._apply_tenant_filter(query, Suppression) # Check expiry query = query.where( or_( Suppression.expires_at.is_(None), Suppression.expires_at > datetime.utcnow() ) ) result = await self.session.execute(query) suppression = result.scalar_one_or_none() return suppression is not None async def list( self, suppression_type: Optional[str] = None, limit: int = 100 ) -> List[Suppression]: """List suppressions""" query = select(Suppression) query = self._apply_tenant_filter(query, Suppression) if suppression_type: query = query.where(Suppression.suppression_type == suppression_type) # Only active suppressions query = query.where( or_( Suppression.expires_at.is_(None), Suppression.expires_at > datetime.utcnow() ) ) query = query.limit(limit).order_by(Suppression.created_at.desc()) result = await self.session.execute(query) return list(result.scalars().all()) class HandoffRepository(BaseRepository): """Repository for Handoff operations""" async def create(self, handoff_data: Dict[str, Any]) -> Handoff: """Create a new handoff""" if self.tenant_id: handoff_data['tenant_id'] = self.tenant_id handoff = Handoff(**handoff_data) self.session.add(handoff) await self.session.flush() await self._log_audit('create', 'handoff', handoff.id, new_value=handoff_data) logger.info(f"Created handoff: {handoff.id}") return handoff async def get_by_id(self, handoff_id: str) -> Optional[Handoff]: """Get handoff by ID""" query = select(Handoff).where(Handoff.id == handoff_id) query = self._apply_tenant_filter(query, Handoff) result = await self.session.execute(query) return result.scalar_one_or_none() async def list( self, status: Optional[str] = None, priority: Optional[str] = None, assigned_to: Optional[str] = None, limit: int = 100 ) -> List[Handoff]: """List handoffs with filters""" query = select(Handoff) query = self._apply_tenant_filter(query, Handoff) if status: query = query.where(Handoff.status == status) if priority: query = query.where(Handoff.priority == priority) if assigned_to: query = query.where(Handoff.assigned_to == assigned_to) query = query.limit(limit).order_by(Handoff.created_at.desc()) result = await self.session.execute(query) return list(result.scalars().all()) async def update(self, handoff_id: str, handoff_data: Dict[str, Any]) -> Optional[Handoff]: """Update a handoff""" handoff = await self.get_by_id(handoff_id) if not handoff: return None old_data = {key: getattr(handoff, key) for key in handoff_data.keys() if hasattr(handoff, key)} for key, value in handoff_data.items(): if hasattr(handoff, key): setattr(handoff, key, value) await self.session.flush() await self._log_audit('update', 'handoff', handoff_id, old_value=old_data, new_value=handoff_data) logger.info(f"Updated handoff: {handoff_id}") return handoff