# Spaces status captured together with this source (not part of the code):
# Build error
# Build error
import logging
import os
import sqlite3
from contextlib import closing
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)


class AgentMemory:
    """
    SQLite-backed storage for news articles and per-user alert preferences.

    No connection is held between calls: every public method opens a fresh
    connection through ``_get_db_connection`` and closes it before returning.
    Public methods log failures instead of raising them.
    """

    def __init__(self) -> None:
        """
        Initialize memory with the database path.

        The path is read from the ``DATABASE_PATH`` environment variable,
        falling back to ``/app/data/news.db`` (the path defined in the
        Dockerfile). The actual connection is established per operation.
        """
        self.db_path = os.getenv("DATABASE_PATH", "/app/data/news.db")
        logger.info("AgentMemory initialized with database path: %s", self.db_path)

    @staticmethod
    def _escape_like(text: str) -> str:
        """
        Return *text* with SQL LIKE wildcards escaped so it matches literally.

        ``%`` and ``_`` are prefixed with a backslash and backslashes are
        doubled; the SQL must use ``LIKE ? ESCAPE '\\'`` for this to apply.
        """
        # str() keeps the previous behaviour for non-str input (None -> "None").
        text = str(text)
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def _get_db_connection(self) -> sqlite3.Connection:
        """
        Open a new connection to ``self.db_path`` with the tables in place.

        Creates the parent directory of the database file when the path names
        one. Rows are returned as ``sqlite3.Row`` so callers can ``dict()``
        them. The caller owns the returned connection and must close it.

        Raises:
            OSError: if the database directory cannot be created.
            sqlite3.Error: if connecting or creating the tables fails.
        """
        db_dir = os.path.dirname(self.db_path)
        # A bare filename (or ":memory:") has no directory part, and
        # os.makedirs("") raises FileNotFoundError, so only create a named dir.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        db = sqlite3.connect(self.db_path, check_same_thread=False)
        db.row_factory = sqlite3.Row
        try:
            self._initialize_tables(db)
        except Exception:
            # Don't leak the connection when schema setup fails.
            db.close()
            raise
        return db

    def _initialize_tables(self, db: sqlite3.Connection) -> None:
        """
        Create the ``news`` and ``user_prefs`` tables if they don't exist.

        Called for every new connection opened by ``_get_db_connection`` so the
        tables are always present; ``IF NOT EXISTS`` makes this idempotent.

        NOTE(review): ``IF NOT EXISTS`` does not add columns to an existing
        table; a database created before the url/content/description columns
        (marked "Added" below) would need a manual migration.

        Raises:
            sqlite3.Error: re-raised after logging, as a critical setup failure.
        """
        try:
            db.execute("""
                CREATE TABLE IF NOT EXISTS news (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT,
                    summary TEXT,
                    sentiment_score REAL,
                    category TEXT,
                    published_at TEXT,
                    source TEXT,
                    url TEXT UNIQUE, -- Added URL for news articles
                    content TEXT, -- Added full content for news articles
                    description TEXT -- Added description for news articles
                )
            """)
            db.execute("""
                CREATE TABLE IF NOT EXISTS user_prefs (
                    user_id TEXT PRIMARY KEY,
                    stocks TEXT, -- Stored as comma-separated string
                    categories TEXT, -- Stored as comma-separated string
                    alert_frequency TEXT
                )
            """)
            db.commit()
            logger.debug("AgentMemory database tables initialized (if not existing).")
        except sqlite3.Error as e:
            logger.error("Error initializing AgentMemory tables: %s", e, exc_info=True)
            raise

    def store_news(self, title: str, summary: str, sentiment_score: float, category: str,
                   published_at: str, source: str, url: str, content: str,
                   description: str) -> bool:
        """
        Store a news article in the database.

        Uses ``INSERT OR REPLACE``: because ``url`` is UNIQUE, storing an
        article whose url already exists deletes the old row and inserts a new
        one (so its ``id`` changes).

        Returns:
            True if the article was stored, False on failure (error is logged).
        """
        try:
            with closing(self._get_db_connection()) as db:
                db.execute(
                    """
                    INSERT OR REPLACE INTO news (title, summary, sentiment_score, category, published_at, source, url, content, description)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (title, summary, sentiment_score, category, published_at, source, url, content, description),
                )
                db.commit()
        except Exception as e:
            logger.error("Error storing news article '%s' (URL: %s): %s", title, url, e, exc_info=True)
            return False
        logger.info("Stored news article: %s (URL: %s)", title, url)
        return True

    def search_news(self, query: str, category: Optional[str] = None) -> List[Dict]:
        """
        Search news articles whose title, content or description contains *query*.

        The match is a literal substring match: ``%`` and ``_`` in *query* are
        not wildcards. SQLite LIKE is case-insensitive for ASCII letters.
        Results are ordered by ``published_at`` descending (TEXT comparison).

        Args:
            query: Text to look for.
            category: If non-empty, only articles with exactly this category.

        Returns:
            Matching rows as dicts keyed by column name, or ``[]`` on error
            (the error is logged).
        """
        pattern = f"%{self._escape_like(query)}%"
        sql_query = (
            "SELECT * FROM news WHERE (title LIKE ? ESCAPE '\\' "
            "OR content LIKE ? ESCAPE '\\' OR description LIKE ? ESCAPE '\\')"
        )
        params = [pattern, pattern, pattern]
        if category:
            sql_query += " AND category = ?"
            params.append(category)
        sql_query += " ORDER BY published_at DESC"  # Order by latest news
        try:
            with closing(self._get_db_connection()) as db:
                rows = db.execute(sql_query, params).fetchall()
        except Exception as e:
            logger.error("Error searching news for query '%s': %s", query, e, exc_info=True)
            return []
        results = [dict(row) for row in rows]
        logger.info(
            "Found %d news articles for query: '%s' (category: %s)",
            len(results), query, category or "any",
        )
        return results

    def update_user_prefs(self, user_id: str, stocks: list, categories: list,
                          alert_frequency: str) -> bool:
        """
        Insert or replace the whole preferences row for *user_id*.

        *stocks* and *categories* are stored as comma-separated strings
        (``None`` is stored as an empty string). An entry that itself contains
        a comma cannot be distinguished from two entries when read back.

        Returns:
            True on success, False on failure (the error is logged).
        """
        try:
            row = (user_id, ",".join(stocks or []), ",".join(categories or []), alert_frequency)
            with closing(self._get_db_connection()) as db:
                db.execute(
                    """
                    INSERT OR REPLACE INTO user_prefs (user_id, stocks, categories, alert_frequency)
                    VALUES (?, ?, ?, ?)
                    """,
                    row,
                )
                db.commit()
        except Exception as e:
            logger.error("Error updating user preferences for '%s': %s", user_id, e, exc_info=True)
            return False
        logger.info(
            "Updated preferences for user %s: Stocks=%s, Categories=%s, Alert Frequency=%s",
            user_id, stocks, categories, alert_frequency,
        )
        return True

    def get_users_for_article(self, category: str, stocks: List[str]) -> List[Dict]:
        """
        Get users interested in *category* or in any of *stocks*.

        Preferences are stored as comma-separated lists (see
        ``update_user_prefs``), so each value is matched against whole list
        entries: category ``"ai"`` does not match a stored ``"retail"`` and
        stock ``"A"`` does not match ``"AAPL"``. LIKE keeps the match
        case-insensitive for ASCII. Empty/``None`` values are ignored; if
        nothing is left to match on, no users are returned.

        Returns:
            Matching ``user_prefs`` rows as dicts, or ``[]`` on error (logged).
        """
        try:
            # Wrapping the stored list in commas lets the pattern "%,value,%"
            # match the first, middle and last entries alike.
            conditions = []
            params = []
            if category:
                conditions.append("(',' || categories || ',') LIKE ? ESCAPE '\\'")
                params.append(f"%,{self._escape_like(category)},%")
            for stock_sym in stocks or []:
                if stock_sym:
                    conditions.append("(',' || stocks || ',') LIKE ? ESCAPE '\\'")
                    params.append(f"%,{self._escape_like(stock_sym)},%")
            results = []
            if conditions:
                # Only fixed SQL fragments are interpolated; values stay bound.
                where_clause = " OR ".join(conditions)
                with closing(self._get_db_connection()) as db:
                    rows = db.execute(f"SELECT * FROM user_prefs WHERE {where_clause}", params).fetchall()
                results = [dict(row) for row in rows]
        except Exception as e:
            logger.error(
                "Error fetching users for article (category: %s, stocks: %s): %s",
                category, stocks, e, exc_info=True,
            )
            return []
        logger.info("Found %d users interested in category '%s' or stocks %s", len(results), category, stocks)
        return results