"""database.py — SQLAlchemy ORM models and session management.

Tables:
    contracts        — one row per ingested contract
    clauses          — one row per extracted clause (FK → contracts)
    analysis_results — one row per analysis run linking all scores

Timestamps are generated as timezone-aware UTC ``datetime`` objects and
persisted through SQLAlchemy ``DateTime`` columns.
"""

import sys
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Optional

from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    event,
)
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker

sys.path.insert(0, str(Path(__file__).parent.parent))
import config

# Engine & Session

# SQLite-only options (check_same_thread, PRAGMAs) must not be sent to other
# backends, so detect the backend from the URL scheme once.
_IS_SQLITE = str(config.DB_URL).startswith("sqlite")

engine = create_engine(
    config.DB_URL,
    # SQLite requires this for FastAPI
    connect_args={"check_same_thread": False} if _IS_SQLITE else {},
    echo=False,
)


# Enable WAL mode for better concurrent read performance with SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Enable WAL mode and foreign key enforcement on every new SQLite connection.

    Does nothing when the configured database is not SQLite.

    Args:
        dbapi_connection: Raw DBAPI connection that was just opened.
        connection_record: Pool record for the connection (unused).
    """
    if not _IS_SQLITE:
        return
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even if a PRAGMA statement fails.
        cursor.close()


SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# ORM Base
class Base(DeclarativeBase):
    """Declarative base for all ORM models."""

    pass


def _describe(instance: object, *field_names: str) -> str:
    """Build a ``<ClassName field=value ...>`` string for ``__repr__``.

    Values are read from the instance ``__dict__`` only, so formatting an
    expired or detached ORM object never triggers a lazy load (which could
    emit SQL or raise). Fields that are not currently loaded are shown as
    ``<not loaded>``.

    Args:
        instance: Object to describe.
        field_names: Attribute names to include, in display order.

    Returns:
        Human-readable representation of ``instance``.
    """
    loaded = vars(instance)
    parts = [
        f"{name}={loaded[name]!r}" if name in loaded else f"{name}=<not loaded>"
        for name in field_names
    ]
    return f"<{type(instance).__name__} {' '.join(parts)}>"


# ORM Models
class Contract(Base):
    """Represents a single ingested contract document.

    Attributes:
        contract_id: MD5-based unique identifier derived from filename.
        filename: Original filename or identifier string.
        source: Origin of contract ('CUAD', 'upload', 'text_input').
        page_count: Number of pages (for PDFs).
        created_at: UTC timestamp of ingestion.
        clauses: Relationship to associated Clause rows.
        results: Relationship to associated AnalysisResult rows.
    """

    __tablename__ = "contracts"

    contract_id = Column(String(64), primary_key=True, index=True)
    filename = Column(String(512), nullable=False)
    source = Column(String(64), nullable=False, default="upload")
    page_count = Column(Integer, nullable=True)
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )

    # Deleting a contract through the ORM also deletes its clauses and results.
    clauses = relationship("Clause", back_populates="contract", cascade="all, delete-orphan")
    results = relationship("AnalysisResult", back_populates="contract", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        return _describe(self, "contract_id", "filename", "source")


class Clause(Base):
    """Represents a single clause extracted from a contract.

    Attributes:
        clause_id: MD5-based unique identifier.
        contract_id: FK to contracts table.
        clause_text: Raw text of the clause.
        clause_type: Pipe-separated list of CUAD clause type labels.
        party_a: Name of Party A as detected in clause.
        party_b: Name of Party B as detected in clause.
        source: Origin ('CUAD', 'upload', 'text_input').
        anomaly_score: Combined anomaly risk score (0–100).
        is_anomalous: True if anomaly_score > ANOMALY_FLAG_THRESHOLD.
        power_imbalance_score: Bilateral imbalance score (-100 to +100).
        party_a_leverage: Party A leverage score (0–100).
        party_b_leverage: Party B leverage score (0–100).
        sentiment_score: Sentiment feature value.
        modal_score: Modal verb feature value.
        obligation_score: Obligation assignment feature value.
        assertiveness_score: Assertiveness feature value.
        shap_plot_path: Filesystem path to the SHAP PNG for this clause.
        created_at: UTC timestamp of processing.
        contract: Relationship to the owning Contract row.
    """

    __tablename__ = "clauses"

    clause_id = Column(String(64), primary_key=True, index=True)
    contract_id = Column(String(64), ForeignKey("contracts.contract_id"), nullable=False, index=True)
    clause_text = Column(Text, nullable=False)
    clause_type = Column(String(512), nullable=False, default="")
    party_a = Column(String(256), nullable=True, default="")
    party_b = Column(String(256), nullable=True, default="")
    source = Column(String(64), nullable=True, default="")

    # Anomaly detection fields
    anomaly_score = Column(Float, nullable=True)
    is_anomalous = Column(Boolean, nullable=True, default=False)

    # Power imbalance fields
    power_imbalance_score = Column(Float, nullable=True)
    party_a_leverage = Column(Float, nullable=True)
    party_b_leverage = Column(Float, nullable=True)

    # Feature-level scores
    sentiment_score = Column(Float, nullable=True)
    modal_score = Column(Float, nullable=True)
    obligation_score = Column(Float, nullable=True)
    assertiveness_score = Column(Float, nullable=True)

    # Explainability
    shap_plot_path = Column(String(512), nullable=True)

    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )

    contract = relationship("Contract", back_populates="clauses")

    def __repr__(self) -> str:
        return _describe(self, "clause_id", "contract_id", "clause_type", "anomaly_score")


class AnalysisResult(Base):
    """Stores aggregate analysis results at the contract level.

    Attributes:
        result_id: Auto-incremented primary key.
        contract_id: FK to contracts table.
        overall_imbalance_index: Contract-level power imbalance (-100 to +100).
        total_clauses: Total clause count.
        anomalous_clauses: Count of flagged anomalous clauses.
        dominant_clause_type: Most frequent clause type in the contract.
        analysis_metadata: JSON string with additional metadata.
        created_at: UTC timestamp.
        contract: Relationship to the owning Contract row.
    """

    __tablename__ = "analysis_results"

    result_id = Column(Integer, primary_key=True, autoincrement=True)
    contract_id = Column(String(64), ForeignKey("contracts.contract_id"), nullable=False, index=True)
    overall_imbalance_index = Column(Float, nullable=True)
    total_clauses = Column(Integer, nullable=True)
    anomalous_clauses = Column(Integer, nullable=True)
    dominant_clause_type = Column(String(256), nullable=True)
    analysis_metadata = Column(Text, nullable=True)  # JSON string
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )

    contract = relationship("Contract", back_populates="results")

    def __repr__(self) -> str:
        return _describe(self, "result_id", "contract_id", "overall_imbalance_index")


# Utility functions
def create_tables() -> None:
    """Create all database tables if they do not already exist.

    Safe to call multiple times — uses CREATE TABLE IF NOT EXISTS semantics.
    """
    Base.metadata.create_all(bind=engine)


def get_db() -> Generator[Session, None, None]:
    """FastAPI dependency that provides a database session per request.

    The session is never committed here; callers commit explicitly.

    Yields:
        SQLAlchemy Session instance. Automatically closed after use.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


@contextmanager
def managed_session() -> Generator[Session, None, None]:
    """Context manager for database sessions outside of FastAPI request scope.

    Commits when the ``with`` body finishes without error, rolls back and
    re-raises on any exception, and always closes the session.

    Usage:
        with managed_session() as session:
            session.add(some_object)
        # committed here; no explicit session.commit() is needed

    Yields:
        SQLAlchemy Session instance.

    Raises:
        Exception: Whatever the ``with`` body or the commit raised, after
            the transaction has been rolled back.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()