Spaces:
Sleeping
Sleeping
implement NegaBot API with FastAPI for tweet sentiment classification and add SQLite logging system
92a3517
| """ | |
| Database and Logging System for NegaBot API | |
| Handles prediction logging using SQLite database | |
| """ | |
| import sqlite3 | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from typing import List, Dict | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Database configuration
DB_PATH = "negabot_predictions.db"


class PredictionLogger:
    """
    Store sentiment predictions in a SQLite database and read them back.

    A new connection is opened for every operation and is always closed
    afterwards, so an instance holds no open database handle between calls.
    """

    # Column order shared by every SELECT and by the dictionaries returned
    # to callers.
    _COLUMNS = (
        "id",
        "text",
        "sentiment",
        "confidence",
        "predicted_class",
        "timestamp",
        "metadata",
        "created_at",
    )

    _SELECT_SQL = (
        "SELECT id, text, sentiment, confidence, predicted_class, "
        "timestamp, metadata, created_at FROM predictions"
    )

    # created_at only has one-second resolution, so id breaks ties and keeps
    # "newest first" deterministic for rows written within the same second.
    _NEWEST_FIRST = " ORDER BY created_at DESC, id DESC"

    def __init__(self, db_path: str = DB_PATH):
        """
        Initialize the prediction logger with SQLite database

        Args:
            db_path (str): Path to SQLite database file

        Raises:
            Exception: Whatever error init_database raised while creating
                the schema.
        """
        self.db_path = db_path
        self.init_database()

    def _fetch_all(self, query: str, params: tuple = ()) -> List[tuple]:
        """Run a read-only query and return all rows, closing the connection."""
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(query, params).fetchall()
        finally:
            conn.close()

    @staticmethod
    def _row_to_dict(row: tuple) -> Dict:
        """Convert one result row (in _COLUMNS order) into a prediction dict."""
        record = dict(zip(PredictionLogger._COLUMNS, row))
        raw_metadata = record["metadata"]
        # Metadata is stored as a JSON string; NULL/empty means "no metadata".
        record["metadata"] = json.loads(raw_metadata) if raw_metadata else None
        return record

    def init_database(self):
        """
        Initialize the database with required tables

        Creates the predictions table and its sentiment/timestamp indexes if
        they do not exist yet.

        Raises:
            Exception: Any error raised while creating the schema is logged
                and re-raised unchanged.
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                # "with conn" commits on success and rolls back on error; it
                # does NOT close the connection, hence the outer finally.
                with conn:
                    # Create predictions table
                    conn.execute("""
                        CREATE TABLE IF NOT EXISTS predictions (
                            id INTEGER PRIMARY KEY AUTOINCREMENT,
                            text TEXT NOT NULL,
                            sentiment TEXT NOT NULL,
                            confidence REAL NOT NULL,
                            predicted_class INTEGER NOT NULL,
                            timestamp TEXT NOT NULL,
                            metadata TEXT,
                            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
                        )
                    """)
                    # Create indexes for faster queries
                    conn.execute("""
                        CREATE INDEX IF NOT EXISTS idx_sentiment ON predictions(sentiment)
                    """)
                    conn.execute("""
                        CREATE INDEX IF NOT EXISTS idx_timestamp ON predictions(timestamp)
                    """)
            finally:
                conn.close()
            logger.info("Database initialized successfully")
        except Exception as e:
            logger.error("Error initializing database: %s", e)
            raise

    def log_prediction(self, text: str, sentiment: str, confidence: float,
                       predicted_class: int = None, metadata: Dict = None):
        """
        Log a prediction to the database

        Args:
            text (str): Input text
            sentiment (str): Predicted sentiment
            confidence (float): Prediction confidence
            predicted_class (int): Predicted class (0 or 1). When omitted it
                is inferred: 1 if sentiment is "Negative", otherwise 0.
            metadata (dict): Optional metadata, stored as a JSON string.
                Empty or missing metadata is stored as NULL.

        Raises:
            Exception: Any error raised while serializing or inserting the
                record is logged and re-raised unchanged.
        """
        try:
            # Infer predicted_class if not provided
            if predicted_class is None:
                predicted_class = 1 if sentiment == "Negative" else 0

            record = (
                text,
                sentiment,
                confidence,
                predicted_class,
                datetime.now().isoformat(),
                json.dumps(metadata) if metadata else None,
            )

            conn = sqlite3.connect(self.db_path)
            try:
                # Commit on success, roll back on error.
                with conn:
                    conn.execute("""
                        INSERT INTO predictions (text, sentiment, confidence, predicted_class, timestamp, metadata)
                        VALUES (?, ?, ?, ?, ?, ?)
                    """, record)
            finally:
                conn.close()
        except Exception as e:
            logger.error("Error logging prediction: %s", e)
            raise

    def get_all_predictions(self, limit: int = None) -> List[Dict]:
        """
        Get all predictions from the database, newest first

        Args:
            limit (int): Maximum number of records to return. None or 0
                means no limit.

        Returns:
            List of prediction dictionaries. An empty list is returned (and
            the error logged) if the query fails or limit is not an integer.
        """
        try:
            query = self._SELECT_SQL + self._NEWEST_FIRST
            params = ()
            if limit:
                # Bind the limit instead of formatting it into the SQL text so
                # a caller-supplied value can never alter the statement.
                query += " LIMIT ?"
                params = (int(limit),)
            return [self._row_to_dict(row) for row in self._fetch_all(query, params)]
        except Exception as e:
            logger.error("Error getting predictions: %s", e)
            return []

    def get_predictions_by_sentiment(self, sentiment: str) -> List[Dict]:
        """
        Get predictions filtered by sentiment, newest first

        Args:
            sentiment (str): Sentiment to filter by ("Positive" or "Negative")

        Returns:
            List of prediction dictionaries. An empty list is returned (and
            the error logged) if the query fails.
        """
        try:
            query = self._SELECT_SQL + " WHERE sentiment = ?" + self._NEWEST_FIRST
            return [self._row_to_dict(row) for row in self._fetch_all(query, (sentiment,))]
        except Exception as e:
            logger.error("Error getting predictions by sentiment: %s", e)
            return []

    def get_stats(self) -> Dict:
        """
        Get prediction statistics

        Returns:
            Dictionary with the keys total_predictions, positive_count,
            negative_count and average_confidence (rounded to 4 decimals,
            0 when there are no rows). An empty dictionary is returned (and
            the error logged) if the query fails.
        """
        try:
            # One aggregate query so all figures describe the same set of rows.
            total, positive, negative, avg_confidence = self._fetch_all("""
                SELECT COUNT(*),
                       SUM(sentiment = 'Positive'),
                       SUM(sentiment = 'Negative'),
                       AVG(confidence)
                FROM predictions
            """)[0]
            return {
                "total_predictions": total,
                # SUM/AVG yield NULL on an empty table; report those as 0.
                "positive_count": positive or 0,
                "negative_count": negative or 0,
                "average_confidence": round(avg_confidence, 4) if avg_confidence else 0,
            }
        except Exception as e:
            logger.error("Error getting stats: %s", e)
            return {}
# Shared PredictionLogger, created lazily by get_logger()
_logger_instance = None


def get_logger():
    """Return the shared PredictionLogger, building it on the first call."""
    global _logger_instance
    if _logger_instance is not None:
        return _logger_instance
    _logger_instance = PredictionLogger()
    return _logger_instance
def log_prediction(text: str, sentiment: str, confidence: float, metadata: Dict = None):
    """Record one prediction through the shared PredictionLogger."""
    get_logger().log_prediction(text, sentiment, confidence, metadata=metadata)
def get_all_predictions(limit: int = None) -> List[Dict]:
    """Fetch stored predictions through the shared PredictionLogger."""
    return get_logger().get_all_predictions(limit=limit)
def get_predictions_by_sentiment(sentiment: str) -> List[Dict]:
    """Fetch predictions with the given sentiment through the shared PredictionLogger."""
    return get_logger().get_predictions_by_sentiment(sentiment)
def get_prediction_stats() -> Dict:
    """Fetch prediction statistics through the shared PredictionLogger."""
    return get_logger().get_stats()
if __name__ == "__main__":
    # Manual smoke test: write a few sample predictions to the default
    # database, then read them back and print the aggregate statistics.
    demo_logger = PredictionLogger()

    # (text, sentiment, confidence) samples to store
    samples = [
        ("This product is amazing!", "Positive", 0.95),
        ("Terrible quality, waste of money", "Negative", 0.89),
        ("It's okay, nothing special", "Positive", 0.67),
        ("Awful customer service", "Negative", 0.92),
    ]

    print("Testing prediction logging...")
    for sample_text, label, score in samples:
        demo_logger.log_prediction(sample_text, label, score)
        print(f"Logged: {label} - {sample_text}")

    # Read everything back
    print("\nRetrieving all predictions:")
    for entry in demo_logger.get_all_predictions():
        print(f"ID: {entry['id']}, Sentiment: {entry['sentiment']}, Text: {entry['text'][:50]}...")

    # Aggregate statistics
    print("\nPrediction statistics:")
    summary = demo_logger.get_stats()
    print(json.dumps(summary, indent=2))