|
|
""" |
|
|
Enterprise-Grade Repository Layer for Database Operations |
|
|
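Provides a clean data-access interface with tenant isolation and audit logging; writes are flushed but transaction boundaries (commit/rollback) and error handling are left to the caller's session.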
|
|
""" |
|
|
import logging |
|
|
from typing import List, Optional, Dict, Any |
|
|
from datetime import datetime |
|
|
from sqlalchemy import select, update, delete, and_, or_ |
|
|
from sqlalchemy.ext.asyncio import AsyncSession |
|
|
from sqlalchemy.orm import selectinload |
|
|
|
|
|
from .models import ( |
|
|
Company, Prospect, Contact, Fact, Activity, |
|
|
Suppression, Handoff, AuditLog |
|
|
) |
|
|
|
|
|
# Module-level logger named after this module, so its output can be
# configured or filtered through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
class BaseRepository:
    """Shared plumbing for the tenant-aware repositories.

    Holds the async session and an optional tenant id. When a tenant id is
    present, ``_apply_tenant_filter`` narrows queries on models that expose a
    ``tenant_id`` attribute, and ``_log_audit`` stamps audit rows with it.
    """

    def __init__(self, session: AsyncSession, tenant_id: Optional[str] = None):
        self.session = session
        self.tenant_id = tenant_id

    def _apply_tenant_filter(self, query, model):
        """Return *query* restricted to the current tenant when applicable.

        The restriction is added only if ``self.tenant_id`` is truthy and
        *model* has a ``tenant_id`` attribute; otherwise *query* comes back
        unchanged.
        """
        scoped = bool(self.tenant_id) and hasattr(model, 'tenant_id')
        if not scoped:
            return query
        return query.where(model.tenant_id == self.tenant_id)

    async def _log_audit(
        self,
        action: str,
        resource_type: str,
        resource_id: str,
        old_value: Optional[Dict] = None,
        new_value: Optional[Dict] = None,
        user_id: Optional[str] = None
    ) -> None:
        """Stage an ``AuditLog`` entry on the session.

        The entry is only added to the session here; it is persisted by the
        session's next flush or commit, which the caller controls.
        """
        entry = AuditLog(
            action=action,
            resource_type=resource_type,
            resource_id=resource_id,
            old_value=old_value,
            new_value=new_value,
            tenant_id=self.tenant_id,
            user_id=user_id,
        )
        self.session.add(entry)
|
|
|
|
|
|
|
|
class CompanyRepository(BaseRepository):
    """Repository for Company operations.

    Reads are scoped to ``self.tenant_id`` (when set) through
    ``_apply_tenant_filter``. Writes are flushed but never committed, so the
    caller owns the transaction. Domains are stored lower-cased so that
    ``get_by_domain`` (which lower-cases its argument) can find them.
    """

    # Fields a generic ``update`` must never overwrite: the primary key and
    # the tenant scope (rewriting ``tenant_id`` would move the row into
    # another tenant and defeat tenant isolation).
    _IMMUTABLE_FIELDS = frozenset({'id', 'tenant_id'})

    async def create(self, company_data: Dict[str, Any]) -> Company:
        """Create a new company and stage a ``create`` audit entry.

        Args:
            company_data: Column values for the new ``Company``. The mapping is
                copied, so the caller's dict is not modified. When the
                repository has a tenant, ``tenant_id`` is forced to it.

        Returns:
            The flushed ``Company`` instance.
        """
        data = dict(company_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        # Keep stored domains consistent with get_by_domain's lower-casing.
        if isinstance(data.get('domain'), str):
            data['domain'] = data['domain'].lower()

        company = Company(**data)
        self.session.add(company)
        await self.session.flush()

        await self._log_audit('create', 'company', company.id, new_value=data)
        logger.info("Created company: %s", company.id)
        return company

    async def get_by_id(self, company_id: str) -> Optional[Company]:
        """Return the company with *company_id* in this tenant, or ``None``."""
        query = select(Company).where(Company.id == company_id)
        query = self._apply_tenant_filter(query, Company)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_by_domain(self, domain: str) -> Optional[Company]:
        """Return the company whose (lower-cased) domain matches, or ``None``."""
        query = select(Company).where(Company.domain == domain.lower())
        query = self._apply_tenant_filter(query, Company)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        industry: Optional[str] = None,
        is_active: Optional[bool] = True
    ) -> List[Company]:
        """List companies, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip (offset pagination).
            industry: If given (non-empty), only companies in this industry.
            is_active: Filter on the active flag; pass ``None`` to include
                both active and inactive companies.
        """
        query = select(Company)
        query = self._apply_tenant_filter(query, Company)

        if is_active is not None:
            query = query.where(Company.is_active == is_active)
        if industry:
            query = query.where(Company.industry == industry)

        # ``id`` breaks ties between equal ``created_at`` values so that
        # offset pagination is deterministic (no skipped/duplicated rows).
        query = (
            query.order_by(Company.created_at.desc(), Company.id.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def update(self, company_id: str, company_data: Dict[str, Any]) -> Optional[Company]:
        """Apply *company_data* to an existing company and audit the change.

        Keys that are not attributes of ``Company``, private (``_``-prefixed)
        names, and the immutable ``id``/``tenant_id`` fields are ignored. The
        audit entry records only the changes that were actually applied.

        Returns:
            The updated company, or ``None`` if it does not exist for this
            tenant.
        """
        company = await self.get_by_id(company_id)
        if company is None:
            return None

        changes = {
            key: value
            for key, value in company_data.items()
            if key not in self._IMMUTABLE_FIELDS
            and not key.startswith('_')
            and hasattr(company, key)
        }
        # Same normalisation as create() so get_by_domain keeps working.
        if isinstance(changes.get('domain'), str):
            changes['domain'] = changes['domain'].lower()

        old_data = {key: getattr(company, key) for key in changes}
        for key, value in changes.items():
            setattr(company, key, value)

        await self.session.flush()
        await self._log_audit('update', 'company', company_id, old_value=old_data, new_value=changes)

        logger.info("Updated company: %s", company_id)
        return company

    async def delete(self, company_id: str) -> bool:
        """Soft-delete a company by setting ``is_active`` to ``False``.

        Returns:
            ``True`` if the company existed (for this tenant), else ``False``.
        """
        company = await self.get_by_id(company_id)
        if company is None:
            return False

        company.is_active = False
        await self.session.flush()
        await self._log_audit('delete', 'company', company_id)

        logger.info("Soft deleted company: %s", company_id)
        return True
|
|
|
|
|
|
|
|
class ProspectRepository(BaseRepository):
    """Repository for Prospect operations.

    Reads are scoped to ``self.tenant_id`` (when set); writes are flushed but
    not committed, leaving transaction control to the caller.
    """

    # Fields a generic ``update`` must never overwrite: the primary key and
    # the tenant scope (rewriting ``tenant_id`` would move the row into
    # another tenant and defeat tenant isolation).
    _IMMUTABLE_FIELDS = frozenset({'id', 'tenant_id'})

    # (attribute, weight) pairs combined into ``overall_score``. Weights sum
    # to 1.0, so a fully scored prospect gets 0.5*fit + 0.3*eng + 0.2*intent.
    _SCORE_WEIGHTS = (
        ('fit_score', 0.5),
        ('engagement_score', 0.3),
        ('intent_score', 0.2),
    )

    async def create(self, prospect_data: Dict[str, Any]) -> Prospect:
        """Create a new prospect and stage a ``create`` audit entry.

        Args:
            prospect_data: Column values for the new ``Prospect``. The mapping
                is copied, so the caller's dict is not modified. When the
                repository has a tenant, ``tenant_id`` is forced to it.

        Returns:
            The flushed ``Prospect`` instance.
        """
        data = dict(prospect_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id

        prospect = Prospect(**data)
        self.session.add(prospect)
        await self.session.flush()

        await self._log_audit('create', 'prospect', prospect.id, new_value=data)
        logger.info("Created prospect: %s", prospect.id)
        return prospect

    async def get_by_id(self, prospect_id: str, load_relationships: bool = False) -> Optional[Prospect]:
        """Return the prospect with *prospect_id* in this tenant, or ``None``.

        Args:
            load_relationships: When true, eagerly load ``company``,
                ``activities`` and ``handoffs`` via ``selectinload``.
        """
        query = select(Prospect).where(Prospect.id == prospect_id)
        query = self._apply_tenant_filter(query, Prospect)

        if load_relationships:
            query = query.options(
                selectinload(Prospect.company),
                selectinload(Prospect.activities),
                selectinload(Prospect.handoffs),
            )

        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        status: Optional[str] = None,
        stage: Optional[str] = None,
        min_score: Optional[float] = None
    ) -> List[Prospect]:
        """List prospects, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip (offset pagination).
            status: If given (non-empty), only prospects with this status.
            stage: If given (non-empty), only prospects in this stage.
            min_score: If not ``None``, only prospects whose
                ``overall_score`` is at least this value.
        """
        query = select(Prospect)
        query = self._apply_tenant_filter(query, Prospect)

        if status:
            query = query.where(Prospect.status == status)
        if stage:
            query = query.where(Prospect.stage == stage)
        if min_score is not None:
            query = query.where(Prospect.overall_score >= min_score)

        # ``id`` breaks ties between equal ``created_at`` values so that
        # offset pagination is deterministic (no skipped/duplicated rows).
        query = (
            query.order_by(Prospect.created_at.desc(), Prospect.id.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def update(self, prospect_id: str, prospect_data: Dict[str, Any]) -> Optional[Prospect]:
        """Apply *prospect_data* to an existing prospect and audit the change.

        Keys that are not attributes of ``Prospect``, private (``_``-prefixed)
        names, and the immutable ``id``/``tenant_id`` fields are ignored. The
        audit entry records only the changes that were actually applied.

        Returns:
            The updated prospect, or ``None`` if it does not exist for this
            tenant.
        """
        prospect = await self.get_by_id(prospect_id)
        if prospect is None:
            return None

        changes = {
            key: value
            for key, value in prospect_data.items()
            if key not in self._IMMUTABLE_FIELDS
            and not key.startswith('_')
            and hasattr(prospect, key)
        }
        old_data = {key: getattr(prospect, key) for key in changes}
        for key, value in changes.items():
            setattr(prospect, key, value)

        await self.session.flush()
        await self._log_audit('update', 'prospect', prospect_id, old_value=old_data, new_value=changes)

        logger.info("Updated prospect: %s", prospect_id)
        return prospect

    async def update_score(
        self,
        prospect_id: str,
        fit_score: Optional[float] = None,
        engagement_score: Optional[float] = None,
        intent_score: Optional[float] = None
    ) -> Optional[Prospect]:
        """Update component scores and recompute ``overall_score``.

        Only the scores passed as non-``None`` are changed. ``overall_score``
        is the weighted average (see ``_SCORE_WEIGHTS``) of whichever
        component scores are set on the prospect after the update,
        renormalised over the components present. With all three set it is
        ``0.5*fit + 0.3*engagement + 0.2*intent``. With only ``fit`` set it
        equals ``fit``. If no component is set, ``overall_score`` is left
        unchanged. The change is recorded in the audit trail.

        Returns:
            The updated prospect, or ``None`` if it does not exist for this
            tenant.
        """
        prospect = await self.get_by_id(prospect_id)
        if prospect is None:
            return None

        tracked = [name for name, _ in self._SCORE_WEIGHTS] + ['overall_score']
        old_scores = {name: getattr(prospect, name) for name in tracked}

        requested = {
            'fit_score': fit_score,
            'engagement_score': engagement_score,
            'intent_score': intent_score,
        }
        for name, value in requested.items():
            if value is not None:
                setattr(prospect, name, value)

        # Renormalise by the weights actually present so that a partially
        # scored prospect stays on the same scale as a fully scored one.
        weighted_sum = 0.0
        total_weight = 0.0
        for name, weight in self._SCORE_WEIGHTS:
            score = getattr(prospect, name)
            if score is not None:
                weighted_sum += score * weight
                total_weight += weight
        if total_weight:
            prospect.overall_score = weighted_sum / total_weight

        await self.session.flush()
        new_scores = {name: getattr(prospect, name) for name in tracked}
        await self._log_audit('update', 'prospect', prospect_id, old_value=old_scores, new_value=new_scores)

        logger.info("Updated prospect scores: %s", prospect_id)
        return prospect
|
|
|
|
|
|
|
|
class ContactRepository(BaseRepository):
    """Repository for Contact operations.

    E-mail addresses are stored lower-cased and every lookup lower-cases its
    argument, so e-mail and domain matching is case-insensitive.
    """

    async def create(self, contact_data: Dict[str, Any]) -> Contact:
        """Create a new contact and stage a ``create`` audit entry.

        Args:
            contact_data: Column values for the new ``Contact``. The mapping is
                copied, so the caller's dict is not modified. When the
                repository has a tenant, ``tenant_id`` is forced to it. A
                string ``email`` is lower-cased; a missing or ``None`` email is
                passed through unchanged.

        Returns:
            The flushed ``Contact`` instance.
        """
        data = dict(contact_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id

        # Normalise casing so get_by_email / list_by_domain can match it.
        email = data.get('email')
        if isinstance(email, str):
            data['email'] = email.lower()

        contact = Contact(**data)
        self.session.add(contact)
        await self.session.flush()

        await self._log_audit('create', 'contact', contact.id, new_value=data)
        logger.info("Created contact: %s", contact.id)
        return contact

    async def get_by_id(self, contact_id: str) -> Optional[Contact]:
        """Return the contact with *contact_id* in this tenant, or ``None``."""
        query = select(Contact).where(Contact.id == contact_id)
        query = self._apply_tenant_filter(query, Contact)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[Contact]:
        """Return the contact with this e-mail (case-insensitive), or ``None``."""
        query = select(Contact).where(Contact.email == email.lower())
        query = self._apply_tenant_filter(query, Contact)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list_by_company(self, company_id: str) -> List[Contact]:
        """List active contacts of a company, primary contacts first."""
        query = select(Contact).where(Contact.company_id == company_id)
        query = self._apply_tenant_filter(query, Contact)
        # ``== True`` builds a SQL comparison here; it is not a Python bool test.
        query = query.where(Contact.is_active == True).order_by(  # noqa: E712
            Contact.is_primary_contact.desc()
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def list_by_domain(self, domain: str) -> List[Contact]:
        """List active contacts whose e-mail address ends with ``@<domain>``.

        Matching is case-insensitive (stored e-mails are lower-cased), and
        LIKE wildcards in *domain* are escaped so they match literally.
        """
        suffix = f"@{domain.lower()}"
        # autoescape=True escapes '%' and '_' (and the escape char) in the
        # caller-supplied suffix, so input cannot act as a LIKE pattern.
        query = select(Contact).where(Contact.email.endswith(suffix, autoescape=True))
        query = self._apply_tenant_filter(query, Contact)
        query = query.where(Contact.is_active == True)  # noqa: E712
        result = await self.session.execute(query)
        return list(result.scalars().all())
|
|
|
|
|
|
|
|
class FactRepository(BaseRepository):
    """Repository for Fact operations (no audit entries are written)."""

    async def create(self, fact_data: Dict[str, Any]) -> Fact:
        """Create a new fact.

        Args:
            fact_data: Column values for the new ``Fact``. The mapping is
                copied, so the caller's dict is not modified. When the
                repository has a tenant, ``tenant_id`` is forced to it.

        Returns:
            The flushed ``Fact`` instance.
        """
        data = dict(fact_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id

        fact = Fact(**data)
        self.session.add(fact)
        await self.session.flush()

        logger.info("Created fact: %s", fact.id)
        return fact

    async def list_by_company(
        self,
        company_id: str,
        fact_type: Optional[str] = None,
        limit: int = 50
    ) -> List[Fact]:
        """List facts for a company, most recently published first.

        Args:
            company_id: Company whose facts are returned.
            fact_type: If given (non-empty), only facts of this type.
            limit: Maximum number of rows to return.
        """
        query = select(Fact).where(Fact.company_id == company_id)
        query = self._apply_tenant_filter(query, Fact)

        if fact_type:
            query = query.where(Fact.fact_type == fact_type)

        query = query.order_by(Fact.published_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())
|
|
|
|
|
|
|
|
class ActivityRepository(BaseRepository):
    """Repository for Activity operations (no audit entries are written)."""

    async def create(self, activity_data: Dict[str, Any]) -> Activity:
        """Create a new activity.

        Args:
            activity_data: Column values for the new ``Activity``. The mapping
                is copied, so the caller's dict is not modified. When the
                repository has a tenant, ``tenant_id`` is forced to it.

        Returns:
            The flushed ``Activity`` instance.
        """
        data = dict(activity_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id

        activity = Activity(**data)
        self.session.add(activity)
        await self.session.flush()

        logger.info("Created activity: %s", activity.id)
        return activity

    async def list_by_prospect(
        self,
        prospect_id: str,
        activity_type: Optional[str] = None,
        limit: int = 100
    ) -> List[Activity]:
        """List a prospect's activities, newest first.

        Args:
            prospect_id: Prospect whose activities are returned.
            activity_type: If given (non-empty), only activities of this type.
            limit: Maximum number of rows to return.
        """
        query = select(Activity).where(Activity.prospect_id == prospect_id)
        query = self._apply_tenant_filter(query, Activity)

        if activity_type:
            query = query.where(Activity.activity_type == activity_type)

        query = query.order_by(Activity.created_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())
|
|
|
|
|
|
|
|
class SuppressionRepository(BaseRepository):
    """Repository for Suppression operations.

    Suppression values are stored lower-cased and ``check`` lower-cases its
    argument, so matching is case-insensitive. Rows whose ``expires_at`` is
    in the past are treated as inactive.
    """

    async def create(self, suppression_data: Dict[str, Any]) -> Suppression:
        """Create a new suppression.

        Args:
            suppression_data: Column values for the new ``Suppression``. The
                mapping is copied, so the caller's dict is not modified. When
                the repository has a tenant, ``tenant_id`` is forced to it. A
                string ``value`` is lower-cased; ``None`` is passed through.

        Returns:
            The flushed ``Suppression`` instance.
        """
        data = dict(suppression_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id

        # Normalise casing so check() (which lower-cases its input) matches.
        value = data.get('value')
        if isinstance(value, str):
            data['value'] = value.lower()

        suppression = Suppression(**data)
        self.session.add(suppression)
        await self.session.flush()

        logger.info("Created suppression: %s", suppression.id)
        return suppression

    async def check(
        self,
        suppression_type: str,
        value: str
    ) -> bool:
        """Return ``True`` if an unexpired suppression matches *value*.

        Args:
            suppression_type: Suppression category to match exactly.
            value: Value to look up (compared lower-cased).
        """
        value = value.lower()
        # NOTE(review): naive UTC timestamp; assumes ``expires_at`` is stored
        # as naive UTC - confirm against the model before switching to aware.
        now = datetime.utcnow()

        query = select(Suppression).where(
            and_(
                Suppression.suppression_type == suppression_type,
                Suppression.value == value,
            )
        )
        query = self._apply_tenant_filter(query, Suppression)

        # A suppression with no expiry never lapses.
        query = query.where(
            or_(
                Suppression.expires_at.is_(None),
                Suppression.expires_at > now,
            )
        )

        # This is an existence test. Nothing here enforces uniqueness, so
        # several rows can match (duplicates, or other tenants' rows when no
        # tenant is set). Fetch at most one row instead of using
        # scalar_one_or_none(), which raises MultipleResultsFound.
        query = query.limit(1)
        result = await self.session.execute(query)
        return result.scalars().first() is not None

    async def list(
        self,
        suppression_type: Optional[str] = None,
        limit: int = 100
    ) -> List[Suppression]:
        """List unexpired suppressions, newest first.

        Args:
            suppression_type: If given (non-empty), only this category.
            limit: Maximum number of rows to return.
        """
        query = select(Suppression)
        query = self._apply_tenant_filter(query, Suppression)

        if suppression_type:
            query = query.where(Suppression.suppression_type == suppression_type)

        # Same expiry rule as check(): no expiry, or expiry in the future.
        now = datetime.utcnow()
        query = query.where(
            or_(
                Suppression.expires_at.is_(None),
                Suppression.expires_at > now,
            )
        )

        query = query.order_by(Suppression.created_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())
|
|
|
|
|
|
|
|
class HandoffRepository(BaseRepository):
    """Repository for Handoff operations.

    Reads are scoped to ``self.tenant_id`` (when set); writes are flushed but
    not committed, leaving transaction control to the caller.
    """

    # Fields a generic ``update`` must never overwrite: the primary key and
    # the tenant scope (rewriting ``tenant_id`` would move the row into
    # another tenant and defeat tenant isolation).
    _IMMUTABLE_FIELDS = frozenset({'id', 'tenant_id'})

    async def create(self, handoff_data: Dict[str, Any]) -> Handoff:
        """Create a new handoff and stage a ``create`` audit entry.

        Args:
            handoff_data: Column values for the new ``Handoff``. The mapping is
                copied, so the caller's dict is not modified. When the
                repository has a tenant, ``tenant_id`` is forced to it.

        Returns:
            The flushed ``Handoff`` instance.
        """
        data = dict(handoff_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id

        handoff = Handoff(**data)
        self.session.add(handoff)
        await self.session.flush()

        await self._log_audit('create', 'handoff', handoff.id, new_value=data)
        logger.info("Created handoff: %s", handoff.id)
        return handoff

    async def get_by_id(self, handoff_id: str) -> Optional[Handoff]:
        """Return the handoff with *handoff_id* in this tenant, or ``None``."""
        query = select(Handoff).where(Handoff.id == handoff_id)
        query = self._apply_tenant_filter(query, Handoff)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list(
        self,
        status: Optional[str] = None,
        priority: Optional[str] = None,
        assigned_to: Optional[str] = None,
        limit: int = 100
    ) -> List[Handoff]:
        """List handoffs, newest first.

        Args:
            status: If given (non-empty), only handoffs with this status.
            priority: If given (non-empty), only handoffs with this priority.
            assigned_to: If given (non-empty), only handoffs assigned to it.
            limit: Maximum number of rows to return.
        """
        query = select(Handoff)
        query = self._apply_tenant_filter(query, Handoff)

        if status:
            query = query.where(Handoff.status == status)
        if priority:
            query = query.where(Handoff.priority == priority)
        if assigned_to:
            query = query.where(Handoff.assigned_to == assigned_to)

        query = query.order_by(Handoff.created_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def update(self, handoff_id: str, handoff_data: Dict[str, Any]) -> Optional[Handoff]:
        """Apply *handoff_data* to an existing handoff and audit the change.

        Keys that are not attributes of ``Handoff``, private (``_``-prefixed)
        names, and the immutable ``id``/``tenant_id`` fields are ignored. The
        audit entry records only the changes that were actually applied.

        Returns:
            The updated handoff, or ``None`` if it does not exist for this
            tenant.
        """
        handoff = await self.get_by_id(handoff_id)
        if handoff is None:
            return None

        changes = {
            key: value
            for key, value in handoff_data.items()
            if key not in self._IMMUTABLE_FIELDS
            and not key.startswith('_')
            and hasattr(handoff, key)
        }
        old_data = {key: getattr(handoff, key) for key in changes}
        for key, value in changes.items():
            setattr(handoff, key, value)

        await self.session.flush()
        await self._log_audit('update', 'handoff', handoff_id, old_value=old_data, new_value=changes)

        logger.info("Updated handoff: %s", handoff_id)
        return handoff
|
|
|