ayush2917's picture
Update agent/memory.py
7b99272 verified
import logging
import os
import sqlite3
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)


class AgentMemory:
    """
    SQLite-backed store for news articles and per-user alert preferences.

    Every method that touches the database opens its own connection via
    ``_get_db_connection`` and closes it before returning, so an instance
    holds no open database handle between calls.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize memory with the database path.

        Args:
            db_path: Path to the SQLite database file. When omitted, the
                ``DATABASE_PATH`` environment variable is used, falling back to
                ``/app/data/news.db`` (the location defined in the Dockerfile).

        The actual connection is established per operation.
        """
        if db_path is None:
            db_path = os.getenv("DATABASE_PATH", "/app/data/news.db")
        self.db_path = db_path
        logger.info("AgentMemory initialized with database path: %s", self.db_path)

    @staticmethod
    def _escape_like(value: str) -> str:
        """
        Escape LIKE wildcards so ``value`` is matched literally.

        Backslash, percent and underscore are each prefixed with a backslash;
        the SQL clause must declare a backslash ESCAPE character to match.
        """
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _list_entry_clause(column: str) -> str:
        """Return a SQL condition testing whether comma-separated ``column`` holds one exact entry."""
        # Wrapping the stored list in commas lets ',value,' match the first,
        # middle and last entries alike without matching partial entries.
        # ``column`` is always a literal column name from this class, never user input.
        return f"(',' || {column} || ',') LIKE ? ESCAPE '\\'"

    def _get_db_connection(self) -> sqlite3.Connection:
        """
        Open and return a new connection with the tables initialized.

        The parent directory of the database file is created first, but only
        when the path has one: bare filenames and ``:memory:`` have no
        directory component, and ``os.makedirs("")`` would raise.
        If table initialization fails, the connection is closed before the
        error is re-raised, so it is never leaked.

        Intended for internal use; the caller must close the returned connection.
        """
        db_dir = os.path.dirname(self.db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        db = sqlite3.connect(self.db_path, check_same_thread=False)
        db.row_factory = sqlite3.Row  # lets callers turn rows into dicts via dict(row)
        try:
            self._initialize_tables(db)
        except Exception:
            db.close()
            raise
        return db

    def _initialize_tables(self, db: sqlite3.Connection):
        """
        Create the ``news`` and ``user_prefs`` tables if they don't exist.

        Called for every new connection made by ``_get_db_connection`` so the
        tables are always present for operations.

        Raises:
            sqlite3.Error: if table creation fails (logged, then re-raised).
        """
        try:
            # Table creation is idempotent due to IF NOT EXISTS
            db.execute("""
                CREATE TABLE IF NOT EXISTS news (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT,
                    summary TEXT,
                    sentiment_score REAL,
                    category TEXT,
                    published_at TEXT,
                    source TEXT,
                    url TEXT UNIQUE, -- Added URL for news articles
                    content TEXT, -- Added full content for news articles
                    description TEXT -- Added description for news articles
                )
            """)
            db.execute("""
                CREATE TABLE IF NOT EXISTS user_prefs (
                    user_id TEXT PRIMARY KEY,
                    stocks TEXT, -- Stored as comma-separated string
                    categories TEXT, -- Stored as comma-separated string
                    alert_frequency TEXT
                )
            """)
            db.commit()
            logger.debug("AgentMemory database tables initialized (if not existing).")
        except sqlite3.Error as e:
            logger.error("Error initializing AgentMemory tables: %s", e, exc_info=True)
            raise  # a setup failure is critical; let the caller see it

    def store_news(self, title: str, summary: str, sentiment_score: float, category: str,
                   published_at: str, source: str, url: str, content: str, description: str):
        """
        Store a news article in the database (best effort).

        Uses ``INSERT OR REPLACE``. ``url`` is a UNIQUE column, so storing an
        article whose URL already exists replaces the existing row.
        Errors are logged and not re-raised.
        """
        db = None
        try:
            db = self._get_db_connection()
            db.execute(
                """
                INSERT OR REPLACE INTO news (title, summary, sentiment_score, category, published_at, source, url, content, description)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (title, summary, sentiment_score, category, published_at, source, url, content, description),
            )
            db.commit()
            logger.info("Stored news article: %s (URL: %s)", title, url)
        except Exception as e:
            logger.error("Error storing news article '%s' (URL: %s): %s", title, url, e, exc_info=True)
        finally:
            if db is not None:
                db.close()

    def search_news(self, query: str, category: Optional[str] = None) -> List[Dict]:
        """
        Search news articles whose title, content or description contains ``query``.

        ``query`` is matched literally: percent and underscore characters are
        escaped instead of acting as LIKE wildcards. Matching uses SQLite's
        LIKE, which is case-insensitive for ASCII letters by default.

        Args:
            query: Substring to look for. An empty string matches every article
                that has a non-NULL title, content or description.
            category: If truthy, only articles with exactly this category are returned.

        Returns:
            Matching rows as dicts ordered by ``published_at`` descending, or an
            empty list on error (the error is logged).
        """
        db = None
        try:
            pattern = f"%{self._escape_like(query)}%"
            sql_query = (
                "SELECT * FROM news WHERE ("
                "title LIKE ? ESCAPE '\\' OR "
                "content LIKE ? ESCAPE '\\' OR "
                "description LIKE ? ESCAPE '\\')"
            )
            params = [pattern, pattern, pattern]
            if category:
                sql_query += " AND category = ?"
                params.append(category)
            # published_at is TEXT: "latest first" presumes a lexicographically
            # sortable format such as ISO-8601 — TODO confirm with producers.
            sql_query += " ORDER BY published_at DESC"

            db = self._get_db_connection()
            results = [dict(row) for row in db.execute(sql_query, tuple(params)).fetchall()]
            logger.info("Found %d news articles for query: '%s' (category: %s)",
                        len(results), query, category or 'any')
            return results
        except Exception as e:
            logger.error("Error searching news for query '%s': %s", query, e, exc_info=True)
            return []
        finally:
            if db is not None:
                db.close()

    def update_user_prefs(self, user_id: str, stocks: list, categories: list, alert_frequency: str):
        """
        Create or replace a user's preferences (best effort).

        ``stocks`` and ``categories`` are stored as comma-separated strings;
        ``None`` is treated as an empty list. Errors are logged, not raised.
        """
        stocks = stocks if stocks is not None else []
        categories = categories if categories is not None else []
        db = None
        try:
            db = self._get_db_connection()
            db.execute(
                """
                INSERT OR REPLACE INTO user_prefs (user_id, stocks, categories, alert_frequency)
                VALUES (?, ?, ?, ?)
                """,
                (user_id, ",".join(stocks), ",".join(categories), alert_frequency),
            )
            db.commit()
            logger.info("Updated preferences for user %s: Stocks=%s, Categories=%s, Alert Frequency=%s",
                        user_id, stocks, categories, alert_frequency)
        except Exception as e:
            logger.error("Error updating user preferences for '%s': %s", user_id, e, exc_info=True)
        finally:
            if db is not None:
                db.close()

    def get_users_for_article(self, category: str, stocks: List[str]) -> List[Dict]:
        """
        Get users interested in ``category`` or in any of ``stocks``.

        Preferences are stored as comma-separated lists by ``update_user_prefs``,
        so each value must equal a whole list entry: ``"tech"`` matches
        ``"ai,tech"`` but not ``"fintech"``, and ``"AA"`` does not match ``"AAPL"``.
        Matching uses SQLite's LIKE, so it is case-insensitive for ASCII letters.

        Empty or ``None`` values (including ``stocks=None``) are ignored. If
        nothing remains to match, ``[]`` is returned without querying.

        Returns:
            Matching ``user_prefs`` rows as dicts, or an empty list on error
            (the error is logged).
        """
        conditions = []
        params = []
        if category:
            conditions.append(self._list_entry_clause("categories"))
            params.append(f"%,{self._escape_like(category)},%")
        for stock_sym in stocks or []:
            if stock_sym:
                conditions.append(self._list_entry_clause("stocks"))
                params.append(f"%,{self._escape_like(stock_sym)},%")
        if not conditions:
            logger.debug("No category or stocks given; no users to match")
            return []

        db = None
        try:
            db = self._get_db_connection()
            where_clause = " OR ".join(conditions)
            cursor = db.execute(f"SELECT * FROM user_prefs WHERE {where_clause}", tuple(params))
            results = [dict(row) for row in cursor.fetchall()]
            logger.info("Found %d users interested in category '%s' or stocks %s",
                        len(results), category, stocks)
            return results
        except Exception as e:
            logger.error("Error fetching users for article (category: %s, stocks: %s): %s",
                         category, stocks, e, exc_info=True)
            return []
        finally:
            if db is not None:
                db.close()