File size: 7,941 Bytes
f07e102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
database.py — SQLAlchemy ORM models and session management.

Tables:
  contracts — one row per ingested contract
  clauses   — one row per extracted clause (FK → contracts)
  analysis_results — one row per analysis run linking all scores

All timestamps are stored as ISO-8601 UTC strings for SQLite compatibility.
"""

import sys
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Optional

from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    event,
)
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker

sys.path.insert(0, str(Path(__file__).parent.parent))
import config

# Engine & Session

engine = create_engine(
    config.DB_URL,
    echo=False,  # do not log emitted SQL
    # Disable sqlite3's same-thread check so a connection can be used from a
    # thread other than the one that opened it (needed when served by FastAPI).
    connect_args={"check_same_thread": False},
)

# Enable WAL mode for better concurrent read performance with SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Apply per-connection SQLite PRAGMAs each time a new DBAPI connection opens.

    Switches the journal to WAL mode and turns on foreign key enforcement.

    Args:
        dbapi_connection: Raw DBAPI connection passed in by the "connect" event.
        connection_record: Pool record passed in by the event (unused here).
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even when a PRAGMA statement fails.
        cursor.close()


# Session factory bound to the module-level engine; sessions it creates have
# both autoflush and autocommit disabled.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)


# ORM Base

class Base(DeclarativeBase):
    """Common declarative base that every ORM model in this module subclasses."""


# ORM Models

class Contract(Base):
    """ORM model for the ``contracts`` table: one row per ingested contract.

    Attributes:
        contract_id: String primary key (max 64 chars); per the original
            notes, an MD5-based identifier derived from the filename.
        filename: Original filename or identifier string (required).
        source: Origin of the contract ('CUAD', 'upload', 'text_input');
            defaults to 'upload'.
        page_count: Number of pages (for PDFs); optional.
        created_at: Ingestion timestamp, defaulting to the current UTC time.
        clauses: Associated Clause rows.
        results: Associated AnalysisResult rows.
    """

    __tablename__ = "contracts"

    contract_id = Column(String(64), index=True, primary_key=True)
    filename = Column(String(512), nullable=False)
    source = Column(String(64), default="upload", nullable=False)
    page_count = Column(Integer, nullable=True)
    created_at = Column(
        DateTime,
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )

    # Both child collections use the "all, delete-orphan" cascade, so child
    # rows follow the parent contract's lifecycle.
    clauses = relationship(
        "Clause",
        cascade="all, delete-orphan",
        back_populates="contract",
    )
    results = relationship(
        "AnalysisResult",
        cascade="all, delete-orphan",
        back_populates="contract",
    )

    def __repr__(self) -> str:
        """Return a short debug representation with the id and filename."""
        return f"<Contract id={self.contract_id} file={self.filename}>"


class Clause(Base):
    """Represents a single clause extracted from a contract.

    Attributes:
        clause_id: String primary key (described in the original notes as
            MD5-based).
        contract_id: FK to contracts table (required, indexed).
        clause_text: Raw text of the clause.
        clause_type: Pipe-separated list of CUAD clause type labels;
            defaults to the empty string on insert.
        party_a: Name of Party A as detected in clause.
        party_b: Name of Party B as detected in clause.
        source: Origin ('CUAD', 'upload', 'text_input').
        anomaly_score: Combined anomaly risk score (0 to 100).
        is_anomalous: Flag for anomalous clauses; per the original notes it is
            True if anomaly_score > ANOMALY_FLAG_THRESHOLD (the threshold is
            not defined in this module).
        power_imbalance_score: Bilateral imbalance score (-100 to +100).
        party_a_leverage: Party A leverage score (0 to 100).
        party_b_leverage: Party B leverage score (0 to 100).
        sentiment_score: Sentiment feature value.
        modal_score: Modal verb feature value.
        obligation_score: Obligation assignment feature value.
        assertiveness_score: Assertiveness feature value.
        shap_plot_path: Filesystem path to the SHAP PNG for this clause.
        created_at: Processing timestamp, defaulting to the current UTC time.
        contract: The owning Contract (other side of ``Contract.clauses``).
    """

    __tablename__ = "clauses"

    clause_id             = Column(String(64), primary_key=True, index=True)
    contract_id           = Column(String(64), ForeignKey("contracts.contract_id"), nullable=False, index=True)
    clause_text           = Column(Text, nullable=False)
    clause_type           = Column(String(512), nullable=False, default="")
    party_a               = Column(String(256), nullable=True, default="")
    party_b               = Column(String(256), nullable=True, default="")
    source                = Column(String(64), nullable=True, default="")

    # Anomaly detection fields
    anomaly_score         = Column(Float, nullable=True)
    is_anomalous          = Column(Boolean, nullable=True, default=False)

    # Power imbalance fields
    power_imbalance_score = Column(Float, nullable=True)
    party_a_leverage      = Column(Float, nullable=True)
    party_b_leverage      = Column(Float, nullable=True)

    # Feature-level scores
    sentiment_score       = Column(Float, nullable=True)
    modal_score           = Column(Float, nullable=True)
    obligation_score      = Column(Float, nullable=True)
    assertiveness_score   = Column(Float, nullable=True)

    # Explainability
    shap_plot_path        = Column(String(512), nullable=True)

    created_at            = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )

    contract = relationship("Contract", back_populates="clauses")

    def __repr__(self) -> str:
        """Return a short debug representation with the id and a type preview."""
        # Column defaults are applied at INSERT time, so clause_type is still
        # None on a freshly constructed instance; guard the slice so repr()
        # never raises TypeError.
        type_preview = (self.clause_type or "")[:30]
        return f"<Clause id={self.clause_id} type={type_preview}>"


class AnalysisResult(Base):
    """ORM model for ``analysis_results``: contract-level aggregate scores.

    Attributes:
        result_id: Auto-incrementing integer primary key.
        contract_id: Required, indexed FK to ``contracts.contract_id``.
        overall_imbalance_index: Contract-level power imbalance
            (documented range -100 to +100).
        total_clauses: Total clause count.
        anomalous_clauses: Count of clauses flagged as anomalous.
        dominant_clause_type: Most frequent clause type in the contract.
        analysis_metadata: Additional metadata held as a JSON string.
        created_at: Timestamp defaulting to the current UTC time.
        contract: The owning Contract (other side of ``Contract.results``).
    """

    __tablename__ = "analysis_results"

    result_id = Column(Integer, autoincrement=True, primary_key=True)
    contract_id = Column(
        String(64),
        ForeignKey("contracts.contract_id"),
        index=True,
        nullable=False,
    )
    overall_imbalance_index = Column(Float, nullable=True)
    total_clauses = Column(Integer, nullable=True)
    anomalous_clauses = Column(Integer, nullable=True)
    dominant_clause_type = Column(String(256), nullable=True)
    # Serialized JSON is kept in a plain Text column.
    analysis_metadata = Column(Text, nullable=True)
    created_at = Column(
        DateTime,
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )

    contract = relationship("Contract", back_populates="results")

    def __repr__(self) -> str:
        """Return a short debug representation with the contract id and index."""
        return f"<AnalysisResult contract={self.contract_id} imbalance={self.overall_imbalance_index}>"


# Utility functions

def create_tables() -> None:
    """Create the tables for every model registered on ``Base``.

    Safe to call repeatedly: with ``checkfirst=True`` (also the library
    default) tables that already exist are skipped.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine, checkfirst=True)


def get_db() -> Generator[Session, None, None]:
    """Yield one database session per request, for use as a FastAPI dependency.

    Yields:
        A new Session from ``SessionLocal``. It is closed once the consumer
        is finished with it, whether or not an error occurred.
    """
    # Leaving the Session's ``with`` block closes it, which is equivalent to
    # an explicit try/finally around close().
    with SessionLocal() as session:
        yield session


@contextmanager
def managed_session() -> Generator[Session, None, None]:
    """Provide a commit-or-rollback session scope outside FastAPI requests.

    The session is committed when the ``with`` body finishes normally. If the
    body (or the commit) raises an ``Exception``, the session is rolled back
    and the error is re-raised. The session is closed in every case.

    Usage:
        with managed_session() as session:
            session.add(some_object)
            # the commit happens automatically on a clean exit
    """
    # The outer ``with`` closes the session on the way out.
    with SessionLocal() as db:
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise