File size: 6,316 Bytes
0c591a7
 
 
53fe655
0c591a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
Analysis Cache - Supabase PostgreSQL caching for final SWOT analysis results.

Caches final SWOT analysis output with 24h TTL to avoid re-running the full pipeline.
Uses schema: asa.analysis_cache
"""

import os
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

import psycopg2
from psycopg2.extras import RealDictCursor
from dotenv import load_dotenv

# Load environment variables (project .env first, then ~/.env for local overrides)
load_dotenv()  # Project .env or HF Space secrets
load_dotenv(os.path.expanduser("~/.env"))  # Local development overrides

logger = logging.getLogger("analysis-cache")

# Default TTL: 24 hours
DEFAULT_TTL_HOURS = 24

# Supabase PostgreSQL connection string
SUPABASE_DB_URL = os.getenv("PIPELINE_SUPABASE_URL")


def get_connection():
    """Open a new PostgreSQL connection to Supabase.

    Returns:
        A live psycopg2 connection built from ``PIPELINE_SUPABASE_URL``.

    Raises:
        RuntimeError: If no connection string is configured.
    """
    # Fall back to a live env lookup so a URL set after module import
    # (late-loaded secrets, test fixtures) is still honored.
    db_url = SUPABASE_DB_URL or os.getenv("PIPELINE_SUPABASE_URL")
    if not db_url:
        raise RuntimeError("PIPELINE_SUPABASE_URL not set in environment")
    return psycopg2.connect(db_url)


def get_cached_analysis(ticker: str) -> Optional[dict]:
    """
    Get cached analysis for a ticker if it exists and hasn't expired.

    Args:
        ticker: Stock ticker symbol (matched case-insensitively via upper()).

    Returns:
        Cached analysis dict (with "_cache_info" metadata added) or None
        if not found, expired, or on any database error.
    """
    try:
        conn = get_connection()
        try:
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # Auto-cleanup: delete expired entries before checking cache
            cursor.execute("DELETE FROM asa.analysis_cache WHERE expires_at <= NOW()")
            deleted = cursor.rowcount
            if deleted > 0:
                conn.commit()
                logger.info("Auto-cleanup: removed %d expired cache entries", deleted)

            cursor.execute("""
                SELECT data, expires_at
                FROM asa.analysis_cache
                WHERE ticker = %s AND expires_at > NOW()
            """, (ticker.upper(),))

            row = cursor.fetchone()
            cursor.close()
        finally:
            # Always release the connection, even if a query raised —
            # otherwise every failed call leaks a Postgres connection.
            conn.close()

        if row is None:
            logger.info("Cache MISS for %s", ticker)
            return None

        data = row['data']
        # jsonb columns arrive as dicts; json/text columns arrive as strings.
        if isinstance(data, str):
            data = json.loads(data)

        # Annotate so callers can distinguish cached results from fresh ones;
        # set_cached_analysis strips this key back out before storing.
        data["_cache_info"] = {
            "cached": True,
            "expires_at": row['expires_at'].isoformat() if row['expires_at'] else None
        }
        logger.info("Cache HIT for %s", ticker)
        return data

    except Exception as e:
        # Cache failures must never break the pipeline; degrade to a miss.
        logger.error("Cache read error for %s: %s", ticker, e)
        return None


def set_cached_analysis(ticker: str, company_name: str, data: dict, ttl_hours: int = DEFAULT_TTL_HOURS):
    """
    Store analysis result in cache (upsert keyed on ticker).

    Args:
        ticker: Stock ticker symbol (stored upper-cased).
        company_name: Company name
        data: Full analysis result dict (swot_data, score, critique, etc.)
        ttl_hours: Time-to-live in hours (default 24)
    """
    try:
        conn = get_connection()
        try:
            cursor = conn.cursor()

            # Auto-cleanup: delete expired entries before inserting new one
            # (committed together with the upsert below).
            cursor.execute("DELETE FROM asa.analysis_cache WHERE expires_at <= NOW()")
            deleted = cursor.rowcount
            if deleted > 0:
                logger.info("Auto-cleanup: removed %d expired cache entries", deleted)

            expires_at = datetime.now(timezone.utc) + timedelta(hours=ttl_hours)

            # Remove cache metadata before storing — it is re-added on read.
            data_to_store = {k: v for k, v in data.items() if k != "_cache_info"}

            cursor.execute("""
                INSERT INTO asa.analysis_cache (ticker, company_name, data, created_at, expires_at)
                VALUES (%s, %s, %s, NOW(), %s)
                ON CONFLICT (ticker)
                DO UPDATE SET
                    company_name = EXCLUDED.company_name,
                    data = EXCLUDED.data,
                    created_at = NOW(),
                    expires_at = EXCLUDED.expires_at
            """, (ticker.upper(), company_name, json.dumps(data_to_store, default=str), expires_at))

            conn.commit()
            cursor.close()
        finally:
            # Release the connection even if the insert raised, so failed
            # writes don't leak Postgres connections.
            conn.close()

        logger.info("Cached analysis for %s (expires: %s)", ticker, expires_at)

    except Exception as e:
        # Best-effort cache: never let a write failure break the pipeline.
        logger.error("Cache write error for %s: %s", ticker, e)


def clear_cache(ticker: Optional[str] = None):
    """
    Clear cache entries.

    Args:
        ticker: If provided, clear only this ticker (case-insensitive).
            Otherwise clear all entries.
    """
    try:
        conn = get_connection()
        try:
            cursor = conn.cursor()

            if ticker:
                cursor.execute("DELETE FROM asa.analysis_cache WHERE ticker = %s", (ticker.upper(),))
                logger.info("Cleared cache for %s", ticker)
            else:
                cursor.execute("DELETE FROM asa.analysis_cache")
                logger.info("Cleared all cache entries")

            conn.commit()
            cursor.close()
        finally:
            # Guarantee the connection is released even on query failure.
            conn.close()

    except Exception as e:
        # Best-effort: log and continue rather than propagate.
        logger.error("Cache clear error: %s", e)


def clear_expired_cache() -> int:
    """Remove all expired cache entries.

    Returns:
        Count of deleted entries (0 on any database error).
    """
    try:
        conn = get_connection()
        try:
            cursor = conn.cursor()

            cursor.execute("DELETE FROM asa.analysis_cache WHERE expires_at <= NOW()")
            deleted = cursor.rowcount

            conn.commit()
            cursor.close()
        finally:
            # Guarantee the connection is released even on query failure.
            conn.close()

        logger.info("Cleared %d expired cache entries", deleted)
        return deleted

    except Exception as e:
        # Best-effort maintenance: report failure as "nothing deleted".
        logger.error("Cache cleanup error: %s", e)
        return 0


def get_cache_stats() -> dict:
    """Get cache statistics.

    Returns:
        Dict with "total_entries", "valid_entries", and "expired_entries"
        counts; all zeros on any database error.
    """
    try:
        conn = get_connection()
        try:
            cursor = conn.cursor()

            # Single atomic query: separate COUNT statements could disagree
            # if entries expire or are written between them.
            cursor.execute("""
                SELECT COUNT(*) AS total,
                       COUNT(*) FILTER (WHERE expires_at > NOW()) AS valid
                FROM asa.analysis_cache
            """)
            total, valid = cursor.fetchone()
            cursor.close()
        finally:
            # Guarantee the connection is released even on query failure.
            conn.close()

        return {
            "total_entries": total,
            "valid_entries": valid,
            "expired_entries": total - valid
        }

    except Exception as e:
        logger.error("Cache stats error: %s", e)
        return {"total_entries": 0, "valid_entries": 0, "expired_entries": 0}