cx_ai_agent_v1 / mcp /database /repositories.py
muzakkirhussain011's picture
Add application files (text files only)
8bab08d
"""
Enterprise-Grade Repository Layer for Database Operations
Provides clean interface with tenant isolation, transactions, and error handling
"""
import logging
from typing import List, Optional, Dict, Any
from datetime import datetime
from sqlalchemy import select, update, delete, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from .models import (
Company, Prospect, Contact, Fact, Activity,
Suppression, Handoff, AuditLog
)
# Module-level logger; repositories emit an info line after each write.
logger = logging.getLogger(__name__)
class BaseRepository:
    """Base repository providing tenant scoping and audit logging.

    Subclasses share one ``AsyncSession`` and an optional ``tenant_id``. When a
    tenant is configured, queries passed through ``_apply_tenant_filter`` are
    restricted to it and every audit entry is stamped with it.
    """

    def __init__(self, session: AsyncSession, tenant_id: Optional[str] = None):
        """
        Args:
            session: Async SQLAlchemy session used for all queries and writes.
                Repositories in this module flush but never commit it.
            tenant_id: Tenant to scope queries to, or ``None`` for unscoped
                access across all tenants.
        """
        self.session = session
        self.tenant_id = tenant_id

    def _apply_tenant_filter(self, query, model):
        """Restrict ``query`` to the current tenant when applicable.

        A ``WHERE model.tenant_id = :tenant_id`` clause is added only when a
        (truthy) ``tenant_id`` is configured and ``model`` has a ``tenant_id``
        attribute. Otherwise ``query`` is returned unchanged.
        """
        if self.tenant_id and hasattr(model, 'tenant_id'):
            return query.where(model.tenant_id == self.tenant_id)
        return query

    async def _log_audit(
        self,
        action: str,
        resource_type: str,
        resource_id: str,
        old_value: Optional[Dict] = None,
        new_value: Optional[Dict] = None,
        user_id: Optional[str] = None
    ):
        """Stage an ``AuditLog`` row in the session.

        The row is only added to the session here; it is written together with
        the caller's other changes on the next flush/commit.

        Args:
            action: Verb describing the change (e.g. ``'create'``, ``'update'``).
            resource_type: Kind of resource affected (e.g. ``'company'``).
            resource_id: Identifier of the affected resource.
            old_value: Values before the change, if any.
            new_value: Values after the change, if any.
            user_id: Acting user, if known.
        """
        audit_log = AuditLog(
            tenant_id=self.tenant_id,
            user_id=user_id,
            action=action,
            resource_type=resource_type,
            resource_id=resource_id,
            # Snapshot the dicts: the row is serialized at a later flush, so
            # keeping the caller's dict by reference would let any subsequent
            # mutation of it silently rewrite the audit record.
            old_value=dict(old_value) if old_value is not None else None,
            new_value=dict(new_value) if new_value is not None else None,
        )
        self.session.add(audit_log)
class CompanyRepository(BaseRepository):
    """Repository for Company operations."""

    # Fields that ``update`` must never overwrite: changing the primary key or
    # the owning tenant would break record identity and tenant isolation.
    _PROTECTED_FIELDS = frozenset({'id', 'tenant_id'})

    @staticmethod
    def _normalize(data: Dict[str, Any]) -> Dict[str, Any]:
        """Lower-case ``domain`` in place so stored values match ``get_by_domain``."""
        domain = data.get('domain')
        if isinstance(domain, str):
            data['domain'] = domain.lower()
        return data

    async def create(self, company_data: Dict[str, Any]) -> Company:
        """Create a new company, scoped to the current tenant if one is set.

        Args:
            company_data: Column values for the new Company. The caller's dict
                is not modified; ``domain`` is stored lower-cased.

        Returns:
            The new Company after a flush (its ``id`` is used for the audit entry).
        """
        data = self._normalize(dict(company_data))
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        company = Company(**data)
        self.session.add(company)
        await self.session.flush()
        await self._log_audit('create', 'company', company.id, new_value=data)
        logger.info("Created company: %s", company.id)
        return company

    async def get_by_id(self, company_id: str) -> Optional[Company]:
        """Return the company with ``company_id`` in the current tenant, or ``None``."""
        query = select(Company).where(Company.id == company_id)
        query = self._apply_tenant_filter(query, Company)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_by_domain(self, domain: str) -> Optional[Company]:
        """Return the company whose domain matches ``domain`` case-insensitively, or ``None``."""
        query = select(Company).where(Company.domain == domain.lower())
        query = self._apply_tenant_filter(query, Company)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        industry: Optional[str] = None,
        is_active: Optional[bool] = True
    ) -> List[Company]:
        """List companies, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip (for pagination).
            industry: If given, only companies in this industry.
            is_active: Filter on ``is_active``; pass ``None`` to include both
                active and soft-deleted companies.
        """
        query = select(Company)
        query = self._apply_tenant_filter(query, Company)
        if is_active is not None:
            query = query.where(Company.is_active == is_active)
        if industry:
            query = query.where(Company.industry == industry)
        query = query.order_by(Company.created_at.desc()).limit(limit).offset(offset)
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def update(self, company_id: str, company_data: Dict[str, Any]) -> Optional[Company]:
        """Apply ``company_data`` to an existing company.

        Keys that are not attributes of Company, and the protected ``id`` and
        ``tenant_id`` fields, are ignored. The audit entry records exactly the
        applied changes and their previous values.

        Returns:
            The updated Company, or ``None`` if it does not exist in this tenant.
        """
        company = await self.get_by_id(company_id)
        if company is None:
            return None
        changes = {
            key: value
            for key, value in self._normalize(dict(company_data)).items()
            if key not in self._PROTECTED_FIELDS and hasattr(company, key)
        }
        old_data = {key: getattr(company, key) for key in changes}
        for key, value in changes.items():
            setattr(company, key, value)
        await self.session.flush()
        await self._log_audit('update', 'company', company_id, old_value=old_data, new_value=changes)
        logger.info("Updated company: %s", company_id)
        return company

    async def delete(self, company_id: str) -> bool:
        """Soft-delete a company by setting ``is_active = False``.

        Returns:
            ``True`` if the company was found in this tenant, ``False`` otherwise.
        """
        company = await self.get_by_id(company_id)
        if company is None:
            return False
        company.is_active = False
        await self.session.flush()
        await self._log_audit('delete', 'company', company_id)
        logger.info("Soft deleted company: %s", company_id)
        return True
class ProspectRepository(BaseRepository):
    """Repository for Prospect operations."""

    # Fields that ``update`` must never overwrite: changing the primary key or
    # the owning tenant would break record identity and tenant isolation.
    _PROTECTED_FIELDS = frozenset({'id', 'tenant_id'})

    # Relative weight of each component score in ``overall_score``.
    _SCORE_WEIGHTS = (
        ('fit_score', 0.5),
        ('engagement_score', 0.3),
        ('intent_score', 0.2),
    )

    async def create(self, prospect_data: Dict[str, Any]) -> Prospect:
        """Create a new prospect, scoped to the current tenant if one is set.

        The caller's ``prospect_data`` dict is not modified.
        """
        data = dict(prospect_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        prospect = Prospect(**data)
        self.session.add(prospect)
        await self.session.flush()
        await self._log_audit('create', 'prospect', prospect.id, new_value=data)
        logger.info("Created prospect: %s", prospect.id)
        return prospect

    async def get_by_id(self, prospect_id: str, load_relationships: bool = False) -> Optional[Prospect]:
        """Return the prospect with ``prospect_id`` in the current tenant, or ``None``.

        Args:
            load_relationships: If true, eagerly load ``company``, ``activities``
                and ``handoffs`` via ``selectinload``. Eager loading avoids lazy
                loads, which an ``AsyncSession`` cannot perform implicitly.
        """
        query = select(Prospect).where(Prospect.id == prospect_id)
        query = self._apply_tenant_filter(query, Prospect)
        if load_relationships:
            query = query.options(
                selectinload(Prospect.company),
                selectinload(Prospect.activities),
                selectinload(Prospect.handoffs)
            )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        status: Optional[str] = None,
        stage: Optional[str] = None,
        min_score: Optional[float] = None
    ) -> List[Prospect]:
        """List prospects, newest first, optionally filtered by status, stage and minimum overall score."""
        query = select(Prospect)
        query = self._apply_tenant_filter(query, Prospect)
        if status:
            query = query.where(Prospect.status == status)
        if stage:
            query = query.where(Prospect.stage == stage)
        if min_score is not None:
            query = query.where(Prospect.overall_score >= min_score)
        query = query.order_by(Prospect.created_at.desc()).limit(limit).offset(offset)
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def update(self, prospect_id: str, prospect_data: Dict[str, Any]) -> Optional[Prospect]:
        """Apply ``prospect_data`` to an existing prospect.

        Keys that are not attributes of Prospect, and the protected ``id`` and
        ``tenant_id`` fields, are ignored. The audit entry records exactly the
        applied changes and their previous values.

        Returns:
            The updated Prospect, or ``None`` if it does not exist in this tenant.
        """
        prospect = await self.get_by_id(prospect_id)
        if prospect is None:
            return None
        changes = {
            key: value
            for key, value in prospect_data.items()
            if key not in self._PROTECTED_FIELDS and hasattr(prospect, key)
        }
        old_data = {key: getattr(prospect, key) for key in changes}
        for key, value in changes.items():
            setattr(prospect, key, value)
        await self.session.flush()
        await self._log_audit('update', 'prospect', prospect_id, old_value=old_data, new_value=changes)
        logger.info("Updated prospect: %s", prospect_id)
        return prospect

    async def update_score(
        self,
        prospect_id: str,
        fit_score: Optional[float] = None,
        engagement_score: Optional[float] = None,
        intent_score: Optional[float] = None
    ) -> Optional[Prospect]:
        """Update component scores and recompute ``overall_score``.

        Only the scores passed as non-``None`` are changed. ``overall_score``
        becomes the weighted average (fit 50%, engagement 30%, intent 20%) of
        the components that are set. Missing components are excluded and the
        remaining weights renormalized, so the result stays on the same scale
        as the inputs. If no component is set, ``overall_score`` is left
        unchanged. No audit entry is written.

        Returns:
            The updated Prospect, or ``None`` if it does not exist in this tenant.
        """
        prospect = await self.get_by_id(prospect_id)
        if prospect is None:
            return None
        if fit_score is not None:
            prospect.fit_score = fit_score
        if engagement_score is not None:
            prospect.engagement_score = engagement_score
        if intent_score is not None:
            prospect.intent_score = intent_score

        weighted_sum = 0.0
        weight_total = 0.0
        for attr, weight in self._SCORE_WEIGHTS:
            score = getattr(prospect, attr)
            if score is not None:
                weighted_sum += score * weight
                weight_total += weight
        if weight_total:
            # Dividing by the sum of present weights (not the component count)
            # yields a true weighted average: e.g. all components at 100 -> 100.
            prospect.overall_score = weighted_sum / weight_total

        await self.session.flush()
        logger.info("Updated prospect scores: %s", prospect_id)
        return prospect
class ContactRepository(BaseRepository):
    """Repository for Contact operations."""

    async def create(self, contact_data: Dict[str, Any]) -> Contact:
        """Create a new contact, scoped to the current tenant if one is set.

        ``email`` is stored lower-cased so that ``get_by_email`` and
        ``list_by_domain``, which lower-case their input, can match it. The
        caller's ``contact_data`` dict is not modified.
        """
        data = dict(contact_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        email = data.get('email')
        if isinstance(email, str):
            data['email'] = email.lower()
        contact = Contact(**data)
        self.session.add(contact)
        await self.session.flush()
        await self._log_audit('create', 'contact', contact.id, new_value=data)
        logger.info("Created contact: %s", contact.id)
        return contact

    async def get_by_id(self, contact_id: str) -> Optional[Contact]:
        """Return the contact with ``contact_id`` in the current tenant, or ``None``."""
        query = select(Contact).where(Contact.id == contact_id)
        query = self._apply_tenant_filter(query, Contact)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[Contact]:
        """Return the contact with ``email`` (case-insensitive), or ``None``."""
        query = select(Contact).where(Contact.email == email.lower())
        query = self._apply_tenant_filter(query, Contact)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list_by_company(self, company_id: str) -> List[Contact]:
        """List active contacts of a company, primary contacts first."""
        query = select(Contact).where(Contact.company_id == company_id)
        query = self._apply_tenant_filter(query, Contact)
        query = query.where(Contact.is_active == True).order_by(  # noqa: E712 - SQL comparison
            Contact.is_primary_contact.desc()
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def list_by_domain(self, domain: str) -> List[Contact]:
        """List active contacts whose email address is at ``domain`` (case-insensitive)."""
        # Stored emails are lower-cased, so the domain must be too. autoescape
        # makes '%' and '_' in the domain match literally instead of acting as
        # LIKE wildcards.
        pattern = f"@{domain.lower()}"
        query = select(Contact).where(Contact.email.endswith(pattern, autoescape=True))
        query = self._apply_tenant_filter(query, Contact)
        query = query.where(Contact.is_active == True)  # noqa: E712 - SQL comparison
        result = await self.session.execute(query)
        return list(result.scalars().all())
class FactRepository(BaseRepository):
    """Repository for Fact operations."""

    async def create(self, fact_data: Dict[str, Any]) -> Fact:
        """Create a new fact, scoped to the current tenant if one is set.

        The caller's ``fact_data`` dict is not modified. No audit entry is written.
        """
        data = dict(fact_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        fact = Fact(**data)
        self.session.add(fact)
        await self.session.flush()
        logger.info("Created fact: %s", fact.id)
        return fact

    async def list_by_company(
        self,
        company_id: str,
        fact_type: Optional[str] = None,
        limit: int = 50
    ) -> List[Fact]:
        """List up to ``limit`` facts for a company, most recently published first.

        Args:
            company_id: Company whose facts to return.
            fact_type: If given, only facts of this type.
            limit: Maximum number of rows to return.
        """
        query = select(Fact).where(Fact.company_id == company_id)
        query = self._apply_tenant_filter(query, Fact)
        if fact_type:
            query = query.where(Fact.fact_type == fact_type)
        query = query.order_by(Fact.published_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())
class ActivityRepository(BaseRepository):
    """Repository for Activity operations."""

    async def create(self, activity_data: Dict[str, Any]) -> Activity:
        """Create a new activity, scoped to the current tenant if one is set.

        The caller's ``activity_data`` dict is not modified. No audit entry is written.
        """
        data = dict(activity_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        activity = Activity(**data)
        self.session.add(activity)
        await self.session.flush()
        logger.info("Created activity: %s", activity.id)
        return activity

    async def list_by_prospect(
        self,
        prospect_id: str,
        activity_type: Optional[str] = None,
        limit: int = 100
    ) -> List[Activity]:
        """List up to ``limit`` activities for a prospect, newest first.

        Args:
            prospect_id: Prospect whose activities to return.
            activity_type: If given, only activities of this type.
            limit: Maximum number of rows to return.
        """
        query = select(Activity).where(Activity.prospect_id == prospect_id)
        query = self._apply_tenant_filter(query, Activity)
        if activity_type:
            query = query.where(Activity.activity_type == activity_type)
        query = query.order_by(Activity.created_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())
class SuppressionRepository(BaseRepository):
    """Repository for Suppression operations."""

    async def create(self, suppression_data: Dict[str, Any]) -> Suppression:
        """Create a new suppression, scoped to the current tenant if one is set.

        ``value`` is stored lower-cased to match the case-insensitive lookup
        in ``check``. The caller's ``suppression_data`` dict is not modified.
        """
        data = dict(suppression_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        value = data.get('value')
        if isinstance(value, str):
            data['value'] = value.lower()
        suppression = Suppression(**data)
        self.session.add(suppression)
        await self.session.flush()
        logger.info("Created suppression: %s", suppression.id)
        return suppression

    async def check(
        self,
        suppression_type: str,
        value: str
    ) -> bool:
        """Return ``True`` if ``value`` has an unexpired suppression of ``suppression_type``.

        The comparison is case-insensitive. A suppression with no
        ``expires_at`` never expires.
        """
        query = select(Suppression.id).where(
            and_(
                Suppression.suppression_type == suppression_type,
                Suppression.value == value.lower()
            )
        )
        query = self._apply_tenant_filter(query, Suppression)
        # NOTE(review): compares against naive UTC; assumes expires_at is stored
        # as naive UTC - confirm before switching to timezone-aware datetimes.
        query = query.where(
            or_(
                Suppression.expires_at.is_(None),
                Suppression.expires_at > datetime.utcnow()
            )
        ).limit(1)
        result = await self.session.execute(query)
        # Existence check: duplicate rows (or matches across tenants when no
        # tenant is set) must not raise MultipleResultsFound, as
        # scalar_one_or_none() would.
        return result.first() is not None

    async def list(
        self,
        suppression_type: Optional[str] = None,
        limit: int = 100
    ) -> List[Suppression]:
        """List up to ``limit`` unexpired suppressions, newest first.

        Args:
            suppression_type: If given, only suppressions of this type.
            limit: Maximum number of rows to return.
        """
        query = select(Suppression)
        query = self._apply_tenant_filter(query, Suppression)
        if suppression_type:
            query = query.where(Suppression.suppression_type == suppression_type)
        query = query.where(
            or_(
                Suppression.expires_at.is_(None),
                Suppression.expires_at > datetime.utcnow()
            )
        )
        query = query.order_by(Suppression.created_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())
class HandoffRepository(BaseRepository):
    """Repository for Handoff operations."""

    # Fields that ``update`` must never overwrite: changing the primary key or
    # the owning tenant would break record identity and tenant isolation.
    _PROTECTED_FIELDS = frozenset({'id', 'tenant_id'})

    async def create(self, handoff_data: Dict[str, Any]) -> Handoff:
        """Create a new handoff, scoped to the current tenant if one is set.

        The caller's ``handoff_data`` dict is not modified.
        """
        data = dict(handoff_data)
        if self.tenant_id:
            data['tenant_id'] = self.tenant_id
        handoff = Handoff(**data)
        self.session.add(handoff)
        await self.session.flush()
        await self._log_audit('create', 'handoff', handoff.id, new_value=data)
        logger.info("Created handoff: %s", handoff.id)
        return handoff

    async def get_by_id(self, handoff_id: str) -> Optional[Handoff]:
        """Return the handoff with ``handoff_id`` in the current tenant, or ``None``."""
        query = select(Handoff).where(Handoff.id == handoff_id)
        query = self._apply_tenant_filter(query, Handoff)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list(
        self,
        status: Optional[str] = None,
        priority: Optional[str] = None,
        assigned_to: Optional[str] = None,
        limit: int = 100
    ) -> List[Handoff]:
        """List up to ``limit`` handoffs, newest first, optionally filtered by status, priority and assignee."""
        query = select(Handoff)
        query = self._apply_tenant_filter(query, Handoff)
        if status:
            query = query.where(Handoff.status == status)
        if priority:
            query = query.where(Handoff.priority == priority)
        if assigned_to:
            query = query.where(Handoff.assigned_to == assigned_to)
        query = query.order_by(Handoff.created_at.desc()).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def update(self, handoff_id: str, handoff_data: Dict[str, Any]) -> Optional[Handoff]:
        """Apply ``handoff_data`` to an existing handoff.

        Keys that are not attributes of Handoff, and the protected ``id`` and
        ``tenant_id`` fields, are ignored. The audit entry records exactly the
        applied changes and their previous values.

        Returns:
            The updated Handoff, or ``None`` if it does not exist in this tenant.
        """
        handoff = await self.get_by_id(handoff_id)
        if handoff is None:
            return None
        changes = {
            key: value
            for key, value in handoff_data.items()
            if key not in self._PROTECTED_FIELDS and hasattr(handoff, key)
        }
        old_data = {key: getattr(handoff, key) for key in changes}
        for key, value in changes.items():
            setattr(handoff, key, value)
        await self.session.flush()
        await self._log_audit('update', 'handoff', handoff_id, old_value=old_data, new_value=changes)
        logger.info("Updated handoff: %s", handoff_id)
        return handoff