# telegram-analytics / tests.py
# Uploaded via huggingface_hub (revision 4a21e7e).
#!/usr/bin/env python3
"""
Integration tests for Telegram Analytics: indexer, search, and dashboard endpoints.
Run with: python -m pytest tests.py -v
Or: python tests.py
"""
import json
import os
import sqlite3
import tempfile
import time
import unittest
from pathlib import Path
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _sample_messages(n: int = 5) -> list[dict]:
"""Generate N realistic Telegram-format messages."""
base_ts = 1700000000
users = [
("user1", "Alice"),
("user2", "Bob"),
("user3", "Carol"),
]
msgs = []
for i in range(1, n + 1):
uid, name = users[i % len(users)]
msgs.append({
"id": 1000 + i,
"type": "message",
"date": f"2024-01-{(i % 28) + 1:02d}T10:00:00",
"date_unixtime": str(base_ts + i * 3600),
"from": name,
"from_id": uid,
"text": f"Test message number {i} from {name}",
"text_entities": [
{"type": "plain", "text": f"Test message number {i} from {name}"}
],
"reply_to_message_id": (1000 + i - 1) if i > 1 else None,
})
return msgs
def _write_json(path: str, messages: list[dict]):
"""Write messages in Telegram export JSON format."""
with open(path, "w", encoding="utf-8") as f:
json.dump({"messages": messages}, f, ensure_ascii=False)
# ---------------------------------------------------------------------------
# 1. Indexer Tests
# ---------------------------------------------------------------------------
class TestIndexer(unittest.TestCase):
    """Tests for OptimizedIndexer and IncrementalIndexer."""

    def setUp(self):
        # Fresh scratch directory plus a 10-message export for every test.
        self.tmpdir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.tmpdir, "test.db")
        self.json_path = os.path.join(self.tmpdir, "messages.json")
        self.messages = _sample_messages(10)
        _write_json(self.json_path, self.messages)

    def tearDown(self):
        import shutil
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _index(self):
        """Run a full (non-incremental) index of self.json_path; return its stats."""
        from indexer import OptimizedIndexer
        builder = OptimizedIndexer(self.db_path, build_trigrams=False, build_graph=False)
        return builder.index_file(self.json_path, show_progress=False)

    def test_optimized_indexer_indexes_messages(self):
        stats = self._index()
        self.assertGreater(stats["messages"], 0)
        conn = sqlite3.connect(self.db_path)
        stored = conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0]
        conn.close()
        self.assertEqual(stored, stats["messages"])

    def test_incremental_indexer_deduplication(self):
        from indexer import IncrementalIndexer
        # Seed the DB with the full export first.
        self._index()
        # Re-feeding the identical export must yield only duplicates.
        inc = IncrementalIndexer(self.db_path)
        stats = inc.update_from_json(self.json_path, show_progress=False)
        inc.close()
        self.assertEqual(stats["new_messages"], 0)
        self.assertGreater(stats["duplicates"], 0)

    def test_incremental_indexer_adds_new(self):
        from indexer import IncrementalIndexer
        # Seed with only the first five messages.
        _write_json(self.json_path, _sample_messages(5))
        self._index()
        # Feed a superset: the same five plus five unseen ones.
        superset_path = os.path.join(self.tmpdir, "messages2.json")
        _write_json(superset_path, _sample_messages(10))
        inc = IncrementalIndexer(self.db_path)
        stats = inc.update_from_json(superset_path, show_progress=False)
        inc.close()
        self.assertEqual(stats["new_messages"], 5)
        self.assertEqual(stats["duplicates"], 5)

    def test_incremental_indexer_from_json_data(self):
        from indexer import IncrementalIndexer
        # Ten messages already stored; feed 15 via the in-memory path.
        self._index()
        inc = IncrementalIndexer(self.db_path)
        stats = inc.update_from_json_data(_sample_messages(15), show_progress=False)
        inc.close()
        self.assertEqual(stats["new_messages"], 5)  # 10 known + 5 new

    def test_fts5_search_works(self):
        self._index()
        conn = sqlite3.connect(self.db_path)
        hits = conn.execute(
            "SELECT COUNT(*) FROM messages_fts WHERE messages_fts MATCH 'message'"
        ).fetchone()[0]
        conn.close()
        self.assertGreater(hits, 0, "FTS5 search should find messages with 'message'")

    def test_streaming_load_json_messages(self):
        from indexer import load_json_messages
        loaded = list(load_json_messages(self.json_path))
        self.assertEqual(len(loaded), 10)
        self.assertIn("text_plain", loaded[0])

    def test_entities_extracted(self):
        """Messages with links/mentions in text_entities should have entities stored."""
        rich_message = {
            "id": 9001,
            "type": "message",
            "date": "2024-01-01T10:00:00",
            "date_unixtime": "1700000000",
            "from": "Alice",
            "from_id": "user1",
            "text": "Check https://example.com and @bob",
            "text_entities": [
                {"type": "plain", "text": "Check "},
                {"type": "link", "text": "https://example.com"},
                {"type": "plain", "text": " and "},
                {"type": "mention", "text": "@bob"},
            ],
        }
        _write_json(self.json_path, [rich_message])
        self._index()
        conn = sqlite3.connect(self.db_path)
        rows = conn.execute(
            "SELECT type, value FROM entities WHERE message_id = 9001"
        ).fetchall()
        conn.close()
        kinds = [row[0] for row in rows]
        self.assertIn("link", kinds)
        self.assertIn("mention", kinds)
# ---------------------------------------------------------------------------
# 2. Search Tests
# ---------------------------------------------------------------------------
class TestSearch(unittest.TestCase):
    """Tests for FTS search."""

    def setUp(self):
        # Build a fresh 20-message indexed DB for each test.
        self.tmpdir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.tmpdir, "test.db")
        self.json_path = os.path.join(self.tmpdir, "messages.json")
        _write_json(self.json_path, _sample_messages(20))
        from indexer import OptimizedIndexer
        OptimizedIndexer(
            self.db_path, build_trigrams=False, build_graph=False
        ).index_file(self.json_path, show_progress=False)

    def tearDown(self):
        import shutil
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def test_fts_match_query(self):
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        hits = conn.execute(
            "SELECT id, text_plain FROM messages WHERE id IN "
            "(SELECT rowid FROM messages_fts WHERE messages_fts MATCH 'Alice')"
        ).fetchall()
        conn.close()
        self.assertGreater(len(hits), 0)
        for row in hits:
            self.assertIn("Alice", row["text_plain"])

    def test_fts_returns_no_results_for_nonsense(self):
        conn = sqlite3.connect(self.db_path)
        total = conn.execute(
            "SELECT COUNT(*) FROM messages_fts WHERE messages_fts MATCH 'xyzzyplugh'"
        ).fetchone()[0]
        conn.close()
        self.assertEqual(total, 0)
# ---------------------------------------------------------------------------
# 3. SemanticSearch Empty Embeddings
# ---------------------------------------------------------------------------
HAS_NUMPY = False
try:
    import numpy as np  # noqa: F401 -- imported only to detect availability
    HAS_NUMPY = True
except ImportError:
    pass
@unittest.skipUnless(HAS_NUMPY, "numpy not installed")
class TestSemanticSearchEmpty(unittest.TestCase):
    """Test that SemanticSearch handles missing/empty embeddings gracefully."""

    def _make_empty_embeddings_db(self) -> str:
        """Create an embeddings DB with the expected schema but zero rows.

        Returns the path to the new database file.  The temp directory is
        removed via ``addCleanup``, so it is reclaimed even when the test
        body fails -- the previous inline version only cleaned up on
        success, leaking a directory per failing test.
        """
        import shutil
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        db_path = os.path.join(tmpdir, "empty_emb.db")
        conn = sqlite3.connect(db_path)
        try:
            conn.execute(
                "CREATE TABLE embeddings (message_id INTEGER PRIMARY KEY, "
                "from_name TEXT, text_preview TEXT, embedding BLOB)"
            )
            conn.commit()
        finally:
            conn.close()
        return db_path

    def test_is_available_missing_db(self):
        from semantic_search import SemanticSearch
        ss = SemanticSearch(embeddings_db="/tmp/nonexistent_embeddings_12345.db")
        self.assertFalse(ss.is_available())

    def test_is_available_empty_db(self):
        from semantic_search import SemanticSearch
        ss = SemanticSearch(embeddings_db=self._make_empty_embeddings_db())
        self.assertFalse(ss.is_available())

    def test_load_empty_embeddings_no_crash(self):
        from semantic_search import SemanticSearch
        ss = SemanticSearch(embeddings_db=self._make_empty_embeddings_db())
        ss._load_embeddings()  # must not raise on a zero-row table
        self.assertTrue(ss.embeddings_loaded)
        self.assertEqual(len(ss.message_ids), 0)

    def test_stats_empty_db(self):
        from semantic_search import SemanticSearch
        ss = SemanticSearch(embeddings_db=self._make_empty_embeddings_db())
        s = ss.stats()
        self.assertTrue(s["available"])  # File exists and table exists
        self.assertEqual(s["count"], 0)
# ---------------------------------------------------------------------------
# 4. Dashboard Endpoint Tests
# ---------------------------------------------------------------------------
HAS_FLASK = False
try:
    import flask  # noqa: F401 -- imported only to detect availability
    HAS_FLASK = True
except ImportError:
    pass
@unittest.skipUnless(HAS_FLASK, "flask not installed")
class TestDashboardEndpoints(unittest.TestCase):
    """Test Flask dashboard API endpoints."""

    @classmethod
    def setUpClass(cls):
        """Index a 50-message export and point the dashboard app at it."""
        cls.tmpdir = tempfile.mkdtemp()
        cls.db_path = os.path.join(cls.tmpdir, "test.db")
        cls.json_path = os.path.join(cls.tmpdir, "messages.json")
        _write_json(cls.json_path, _sample_messages(50))
        from indexer import OptimizedIndexer
        OptimizedIndexer(
            cls.db_path, build_trigrams=False, build_graph=False
        ).index_file(cls.json_path, show_progress=False)
        import dashboard
        dashboard.DB_PATH = cls.db_path
        dashboard.app.config["TESTING"] = True
        cls.client = dashboard.app.test_client()

    @classmethod
    def tearDownClass(cls):
        import shutil
        shutil.rmtree(cls.tmpdir, ignore_errors=True)

    def _get_json(self, url):
        """GET *url*, assert HTTP 200, and return the decoded JSON body."""
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        return response.get_json()

    def test_overview_endpoint(self):
        payload = self._get_json("/api/overview?timeframe=all")
        self.assertIn("total_messages", payload)
        self.assertGreater(payload["total_messages"], 0)

    def test_users_endpoint(self):
        payload = self._get_json("/api/users?timeframe=all&limit=10")
        self.assertIn("users", payload)
        self.assertGreater(len(payload["users"]), 0)
        first = payload["users"][0]
        for field in ("user_id", "name", "messages", "percentage"):
            self.assertIn(field, first)

    def test_users_include_inactive(self):
        payload = self._get_json("/api/users?timeframe=all&include_inactive=0")
        for user in payload["users"]:
            self.assertGreater(user["messages"], 0)

    def test_search_fts_endpoint(self):
        payload = self._get_json("/api/search?q=message&mode=fts&limit=5")
        self.assertIn("results", payload)

    def test_chart_hourly_endpoint(self):
        payload = self._get_json("/api/chart/hourly?timeframe=all")
        self.assertIsInstance(payload, list)
        self.assertEqual(len(payload), 24)

    def test_chart_daily_endpoint(self):
        payload = self._get_json("/api/chart/daily?timeframe=all")
        self.assertIsInstance(payload, list)

    def test_cache_invalidate_endpoint(self):
        payload = self._get_json("/api/cache/invalidate")
        self.assertEqual(payload["status"], "invalidated")

    def test_page_routes_return_200(self):
        """All page routes should return 200."""
        for route in ("/", "/users", "/search", "/chat", "/moderation", "/settings"):
            response = self.client.get(route)
            self.assertEqual(response.status_code, 200, f"Route {route} failed")

    def test_user_profile_endpoint(self):
        # NOTE: the original did not status-check this first request, so
        # neither do we -- only the profile request is asserted.
        listing = self.client.get("/api/users?timeframe=all&limit=1").get_json()
        if listing["users"]:
            uid = listing["users"][0]["user_id"]
            profile = self._get_json(f"/api/user/{uid}/profile")
            self.assertIn("total_messages", profile)
            self.assertIn("hourly_activity", profile)

    def test_overview_has_expected_keys(self):
        payload = self.client.get("/api/overview?timeframe=all").get_json()
        for key in ("total_messages", "total_users", "links_count", "media_count"):
            self.assertIn(key, payload, f"Missing key: {key}")
# ---------------------------------------------------------------------------
# 5. AI Search Schema Test
# ---------------------------------------------------------------------------
class TestAISearchSchema(unittest.TestCase):
    """Test that AI search schema generation matches actual DB."""

    def test_dynamic_schema_includes_real_columns(self):
        """_get_db_schema() must reflect the live schema, not stale hardcoded names."""
        import shutil
        tmpdir = tempfile.mkdtemp()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails (the old version only cleaned up on success).
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        db_path = os.path.join(tmpdir, "test.db")

        # Initialize DB with the real project schema.
        from indexer import init_database
        conn = init_database(db_path)
        conn.close()

        from ai_search import AISearchEngine
        # Bypass __init__ so no AI-provider connection is attempted.
        engine = AISearchEngine.__new__(AISearchEngine)
        engine.db_path = db_path
        schema = engine._get_db_schema()

        # Real column names must be present in the generated schema text.
        for column in (
            "text_plain", "date_unixtime", "has_links",
            "has_media", "from_id", "participants",
        ):
            self.assertIn(column, schema)

        # Old wrong column names must NOT be in the dynamic output.
        self.assertNotIn("char_count", schema)
        # "media_type" should not appear as a column name (has_media is the real one).
        self.assertNotIn("media_type (", schema.lower())
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
# Allow running directly (python tests.py) in addition to pytest discovery.
if __name__ == "__main__":
    unittest.main(verbosity=2)