#!/usr/bin/env python3
"""
Integration tests for Telegram Analytics: indexer, search, and dashboard endpoints.
Run with: python -m pytest tests.py -v
Or: python tests.py
"""
import json
import os
import shutil
import sqlite3
import tempfile
import time
import unittest
from pathlib import Path
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
| def _sample_messages(n: int = 5) -> list[dict]: | |
| """Generate N realistic Telegram-format messages.""" | |
| base_ts = 1700000000 | |
| users = [ | |
| ("user1", "Alice"), | |
| ("user2", "Bob"), | |
| ("user3", "Carol"), | |
| ] | |
| msgs = [] | |
| for i in range(1, n + 1): | |
| uid, name = users[i % len(users)] | |
| msgs.append({ | |
| "id": 1000 + i, | |
| "type": "message", | |
| "date": f"2024-01-{(i % 28) + 1:02d}T10:00:00", | |
| "date_unixtime": str(base_ts + i * 3600), | |
| "from": name, | |
| "from_id": uid, | |
| "text": f"Test message number {i} from {name}", | |
| "text_entities": [ | |
| {"type": "plain", "text": f"Test message number {i} from {name}"} | |
| ], | |
| "reply_to_message_id": (1000 + i - 1) if i > 1 else None, | |
| }) | |
| return msgs | |
| def _write_json(path: str, messages: list[dict]): | |
| """Write messages in Telegram export JSON format.""" | |
| with open(path, "w", encoding="utf-8") as f: | |
| json.dump({"messages": messages}, f, ensure_ascii=False) | |
# ---------------------------------------------------------------------------
# 1. Indexer Tests
# ---------------------------------------------------------------------------
| class TestIndexer(unittest.TestCase): | |
| """Tests for OptimizedIndexer and IncrementalIndexer.""" | |
| def setUp(self): | |
| self.tmpdir = tempfile.mkdtemp() | |
| self.db_path = os.path.join(self.tmpdir, "test.db") | |
| self.json_path = os.path.join(self.tmpdir, "messages.json") | |
| self.messages = _sample_messages(10) | |
| _write_json(self.json_path, self.messages) | |
| def tearDown(self): | |
| import shutil | |
| shutil.rmtree(self.tmpdir, ignore_errors=True) | |
| def test_optimized_indexer_indexes_messages(self): | |
| from indexer import OptimizedIndexer | |
| indexer = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| stats = indexer.index_file(self.json_path, show_progress=False) | |
| self.assertGreater(stats["messages"], 0) | |
| conn = sqlite3.connect(self.db_path) | |
| count = conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0] | |
| conn.close() | |
| self.assertEqual(count, stats["messages"]) | |
| def test_incremental_indexer_deduplication(self): | |
| from indexer import OptimizedIndexer, IncrementalIndexer | |
| # First: create DB with OptimizedIndexer | |
| opt = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| opt.index_file(self.json_path, show_progress=False) | |
| # Now use IncrementalIndexer – same data, should all be duplicates | |
| idx = IncrementalIndexer(self.db_path) | |
| stats = idx.update_from_json(self.json_path, show_progress=False) | |
| idx.close() | |
| self.assertEqual(stats["new_messages"], 0) | |
| self.assertGreater(stats["duplicates"], 0) | |
| def test_incremental_indexer_adds_new(self): | |
| from indexer import OptimizedIndexer, IncrementalIndexer | |
| # Create DB with 5 messages | |
| msgs5 = _sample_messages(5) | |
| _write_json(self.json_path, msgs5) | |
| opt = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| opt.index_file(self.json_path, show_progress=False) | |
| # Now add 10 messages (5 old + 5 new) | |
| msgs10 = _sample_messages(10) | |
| json2 = os.path.join(self.tmpdir, "messages2.json") | |
| _write_json(json2, msgs10) | |
| idx = IncrementalIndexer(self.db_path) | |
| stats = idx.update_from_json(json2, show_progress=False) | |
| idx.close() | |
| self.assertEqual(stats["new_messages"], 5) | |
| self.assertEqual(stats["duplicates"], 5) | |
| def test_incremental_indexer_from_json_data(self): | |
| from indexer import OptimizedIndexer, IncrementalIndexer | |
| # Init DB first | |
| opt = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| opt.index_file(self.json_path, show_progress=False) | |
| # Add new messages via json_data | |
| new_msgs = _sample_messages(15) # 10 old + 5 new | |
| idx = IncrementalIndexer(self.db_path) | |
| stats = idx.update_from_json_data(new_msgs, show_progress=False) | |
| idx.close() | |
| self.assertEqual(stats["new_messages"], 5) | |
| def test_fts5_search_works(self): | |
| from indexer import OptimizedIndexer | |
| indexer = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| indexer.index_file(self.json_path, show_progress=False) | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.execute( | |
| "SELECT COUNT(*) FROM messages_fts WHERE messages_fts MATCH 'message'" | |
| ) | |
| count = cursor.fetchone()[0] | |
| conn.close() | |
| self.assertGreater(count, 0, "FTS5 search should find messages with 'message'") | |
| def test_streaming_load_json_messages(self): | |
| from indexer import load_json_messages | |
| msgs = list(load_json_messages(self.json_path)) | |
| self.assertEqual(len(msgs), 10) | |
| self.assertIn("text_plain", msgs[0]) | |
| def test_entities_extracted(self): | |
| """Messages with links/mentions in text_entities should have entities stored.""" | |
| msgs = [ | |
| { | |
| "id": 9001, | |
| "type": "message", | |
| "date": "2024-01-01T10:00:00", | |
| "date_unixtime": "1700000000", | |
| "from": "Alice", | |
| "from_id": "user1", | |
| "text": "Check https://example.com and @bob", | |
| "text_entities": [ | |
| {"type": "plain", "text": "Check "}, | |
| {"type": "link", "text": "https://example.com"}, | |
| {"type": "plain", "text": " and "}, | |
| {"type": "mention", "text": "@bob"}, | |
| ], | |
| } | |
| ] | |
| _write_json(self.json_path, msgs) | |
| from indexer import OptimizedIndexer | |
| indexer = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| indexer.index_file(self.json_path, show_progress=False) | |
| conn = sqlite3.connect(self.db_path) | |
| entities = conn.execute("SELECT type, value FROM entities WHERE message_id = 9001").fetchall() | |
| conn.close() | |
| types = [e[0] for e in entities] | |
| self.assertIn("link", types) | |
| self.assertIn("mention", types) | |
# ---------------------------------------------------------------------------
# 2. Search Tests
# ---------------------------------------------------------------------------
| class TestSearch(unittest.TestCase): | |
| """Tests for FTS search.""" | |
| def setUp(self): | |
| self.tmpdir = tempfile.mkdtemp() | |
| self.db_path = os.path.join(self.tmpdir, "test.db") | |
| self.json_path = os.path.join(self.tmpdir, "messages.json") | |
| _write_json(self.json_path, _sample_messages(20)) | |
| from indexer import OptimizedIndexer | |
| indexer = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False) | |
| indexer.index_file(self.json_path, show_progress=False) | |
| def tearDown(self): | |
| import shutil | |
| shutil.rmtree(self.tmpdir, ignore_errors=True) | |
| def test_fts_match_query(self): | |
| conn = sqlite3.connect(self.db_path) | |
| conn.row_factory = sqlite3.Row | |
| rows = conn.execute( | |
| "SELECT id, text_plain FROM messages WHERE id IN " | |
| "(SELECT rowid FROM messages_fts WHERE messages_fts MATCH 'Alice')" | |
| ).fetchall() | |
| conn.close() | |
| self.assertGreater(len(rows), 0) | |
| for r in rows: | |
| self.assertIn("Alice", r["text_plain"]) | |
| def test_fts_returns_no_results_for_nonsense(self): | |
| conn = sqlite3.connect(self.db_path) | |
| rows = conn.execute( | |
| "SELECT COUNT(*) FROM messages_fts WHERE messages_fts MATCH 'xyzzyplugh'" | |
| ).fetchone()[0] | |
| conn.close() | |
| self.assertEqual(rows, 0) | |
| # --------------------------------------------------------------------------- | |
| # 3. SemanticSearch Empty Embeddings | |
| # --------------------------------------------------------------------------- | |
| try: | |
| import numpy as np | |
| HAS_NUMPY = True | |
| except ImportError: | |
| HAS_NUMPY = False | |
| class TestSemanticSearchEmpty(unittest.TestCase): | |
| """Test that SemanticSearch handles missing/empty embeddings gracefully.""" | |
| def test_is_available_missing_db(self): | |
| from semantic_search import SemanticSearch | |
| ss = SemanticSearch(embeddings_db="/tmp/nonexistent_embeddings_12345.db") | |
| self.assertFalse(ss.is_available()) | |
| def test_is_available_empty_db(self): | |
| from semantic_search import SemanticSearch | |
| tmpdir = tempfile.mkdtemp() | |
| db_path = os.path.join(tmpdir, "empty_emb.db") | |
| conn = sqlite3.connect(db_path) | |
| conn.execute( | |
| "CREATE TABLE embeddings (message_id INTEGER PRIMARY KEY, " | |
| "from_name TEXT, text_preview TEXT, embedding BLOB)" | |
| ) | |
| conn.commit() | |
| conn.close() | |
| ss = SemanticSearch(embeddings_db=db_path) | |
| self.assertFalse(ss.is_available()) | |
| import shutil | |
| shutil.rmtree(tmpdir, ignore_errors=True) | |
| def test_load_empty_embeddings_no_crash(self): | |
| from semantic_search import SemanticSearch | |
| tmpdir = tempfile.mkdtemp() | |
| db_path = os.path.join(tmpdir, "empty_emb.db") | |
| conn = sqlite3.connect(db_path) | |
| conn.execute( | |
| "CREATE TABLE embeddings (message_id INTEGER PRIMARY KEY, " | |
| "from_name TEXT, text_preview TEXT, embedding BLOB)" | |
| ) | |
| conn.commit() | |
| conn.close() | |
| ss = SemanticSearch(embeddings_db=db_path) | |
| ss._load_embeddings() # Should not crash | |
| self.assertTrue(ss.embeddings_loaded) | |
| self.assertEqual(len(ss.message_ids), 0) | |
| import shutil | |
| shutil.rmtree(tmpdir, ignore_errors=True) | |
| def test_stats_empty_db(self): | |
| from semantic_search import SemanticSearch | |
| tmpdir = tempfile.mkdtemp() | |
| db_path = os.path.join(tmpdir, "empty_emb.db") | |
| conn = sqlite3.connect(db_path) | |
| conn.execute( | |
| "CREATE TABLE embeddings (message_id INTEGER PRIMARY KEY, " | |
| "from_name TEXT, text_preview TEXT, embedding BLOB)" | |
| ) | |
| conn.commit() | |
| conn.close() | |
| ss = SemanticSearch(embeddings_db=db_path) | |
| s = ss.stats() | |
| self.assertTrue(s["available"]) # File exists and table exists | |
| self.assertEqual(s["count"], 0) | |
| import shutil | |
| shutil.rmtree(tmpdir, ignore_errors=True) | |
# ---------------------------------------------------------------------------
# 4. Dashboard Endpoint Tests
# ---------------------------------------------------------------------------
# Optional-dependency probe: record whether flask is importable so the
# dashboard endpoint tests can be skipped instead of erroring.
try:
    import flask
    HAS_FLASK = True
except ImportError:
    HAS_FLASK = False
class TestDashboardEndpoints(unittest.TestCase):
    """Test Flask dashboard API endpoints.

    Bug fix: ``setUpClass``/``tearDownClass`` were plain instance methods.
    unittest invokes them as ``cls.setUpClass()`` with no arguments, so
    without the ``@classmethod`` decorator every test in this class errored
    with a TypeError before running.
    """

    @classmethod
    def setUpClass(cls):
        """Create a test DB and configure Flask test client."""
        if not HAS_FLASK:
            # Skip the whole class when flask is missing (the probe above
            # existed for this purpose but was never consulted).
            raise unittest.SkipTest("flask not installed")
        cls.tmpdir = tempfile.mkdtemp()
        cls.db_path = os.path.join(cls.tmpdir, "test.db")
        cls.json_path = os.path.join(cls.tmpdir, "messages.json")
        _write_json(cls.json_path, _sample_messages(50))
        from indexer import OptimizedIndexer
        indexer = OptimizedIndexer(cls.db_path, build_trigrams=False, build_graph=False)
        indexer.index_file(cls.json_path, show_progress=False)
        import dashboard
        dashboard.DB_PATH = cls.db_path
        dashboard.app.config["TESTING"] = True
        cls.client = dashboard.app.test_client()

    @classmethod
    def tearDownClass(cls):
        # tmpdir only exists if setUpClass ran past the skip guard.
        if hasattr(cls, "tmpdir"):
            shutil.rmtree(cls.tmpdir, ignore_errors=True)

    def test_overview_endpoint(self):
        resp = self.client.get("/api/overview?timeframe=all")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertIn("total_messages", data)
        self.assertGreater(data["total_messages"], 0)

    def test_users_endpoint(self):
        resp = self.client.get("/api/users?timeframe=all&limit=10")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertIn("users", data)
        self.assertGreater(len(data["users"]), 0)
        user = data["users"][0]
        for field in ("user_id", "name", "messages", "percentage"):
            self.assertIn(field, user)

    def test_users_include_inactive(self):
        # include_inactive=0 must filter out users with zero messages.
        resp = self.client.get("/api/users?timeframe=all&include_inactive=0")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        for user in data["users"]:
            self.assertGreater(user["messages"], 0)

    def test_search_fts_endpoint(self):
        resp = self.client.get("/api/search?q=message&mode=fts&limit=5")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertIn("results", data)

    def test_chart_hourly_endpoint(self):
        resp = self.client.get("/api/chart/hourly?timeframe=all")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertIsInstance(data, list)
        # One bucket per hour of the day.
        self.assertEqual(len(data), 24)

    def test_chart_daily_endpoint(self):
        resp = self.client.get("/api/chart/daily?timeframe=all")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertIsInstance(data, list)

    def test_cache_invalidate_endpoint(self):
        resp = self.client.get("/api/cache/invalidate")
        self.assertEqual(resp.status_code, 200)
        data = resp.get_json()
        self.assertEqual(data["status"], "invalidated")

    def test_page_routes_return_200(self):
        """All page routes should return 200."""
        for route in ("/", "/users", "/search", "/chat", "/moderation", "/settings"):
            resp = self.client.get(route)
            self.assertEqual(resp.status_code, 200, f"Route {route} failed")

    def test_user_profile_endpoint(self):
        resp = self.client.get("/api/users?timeframe=all&limit=1")
        data = resp.get_json()
        if data["users"]:
            uid = data["users"][0]["user_id"]
            resp2 = self.client.get(f"/api/user/{uid}/profile")
            self.assertEqual(resp2.status_code, 200)
            profile = resp2.get_json()
            self.assertIn("total_messages", profile)
            self.assertIn("hourly_activity", profile)

    def test_overview_has_expected_keys(self):
        resp = self.client.get("/api/overview?timeframe=all")
        data = resp.get_json()
        for key in ("total_messages", "total_users", "links_count", "media_count"):
            self.assertIn(key, data, f"Missing key: {key}")
# ---------------------------------------------------------------------------
# 5. AI Search Schema Test
# ---------------------------------------------------------------------------
class TestAISearchSchema(unittest.TestCase):
    """Test that AI search schema generation matches actual DB."""

    def test_dynamic_schema_includes_real_columns(self):
        tmpdir = tempfile.mkdtemp()
        # Fix: cleanup via addCleanup runs even when an assertion fails;
        # the previous trailing shutil.rmtree leaked the directory then.
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        db_path = os.path.join(tmpdir, "test.db")
        # Initialize DB with real schema
        from indexer import init_database
        conn = init_database(db_path)
        conn.close()
        from ai_search import AISearchEngine
        # __new__ bypasses __init__ so no AI provider connection is attempted.
        engine = AISearchEngine.__new__(AISearchEngine)
        engine.db_path = db_path
        schema = engine._get_db_schema()
        # Verify real column names are present
        for column in (
            "text_plain",
            "date_unixtime",
            "has_links",
            "has_media",
            "from_id",
            "participants",
        ):
            self.assertIn(column, schema)
        # Verify old wrong column names are NOT in the dynamic output
        self.assertNotIn("char_count", schema)
        # "media_type" should not appear as a column name (has_media is the
        # real one); the "name (" form matches how columns are rendered.
        self.assertNotIn("media_type (", schema.lower())
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
# Allow running this file directly (`python tests.py`) as well as via pytest.
if __name__ == "__main__":
    unittest.main(verbosity=2)