| """ |
| In-Memory Database for Main DB Server. |
| Stores: |
| - users_registry: {username: {username, telegram_id, server_num, tokens: [...]}} |
| - server_counts: {server_num_str: count} |
| - tokens_index: {fingerprint_token: username} (reverse lookup) |
| |
| IMPORTANT: Uses file-based locking + shared memory to sync between |
| Gunicorn workers. Each write operation saves to a shared temp file, |
| and reads always check for updates first. |
| """ |
|
|
| import time |
| import json |
| import base64 |
| import os |
| import tempfile |
| import hashlib |
| from datetime import datetime |
| from threading import RLock |
| from filelock import FileLock |
|
|
| import requests |
|
|
| |
def _default_shared_dir():
    """Return the default directory for files shared between workers.

    Prefers the RAM-backed ``/dev/shm`` where it exists (Linux) and falls
    back to the platform temp directory otherwise (e.g. macOS), so that
    creating the directory does not fail on hosts without ``/dev/shm``.
    """
    if os.path.isdir("/dev/shm"):
        return "/dev/shm/maindb"
    return os.path.join(tempfile.gettempdir(), "maindb")


# Directory shared by every worker process on this host. Override with the
# SHARED_DATA_DIR environment variable; an empty value counts as "unset".
SHARED_DIR = os.environ.get("SHARED_DATA_DIR") or _default_shared_dir()
# Cross-process lock file guarding reads/writes of the shared store files.
SHARED_LOCK_PATH = os.path.join(SHARED_DIR, "maindb.lock")
|
|
|
|
def _ensure_shared_dir():
    """Make sure SHARED_DIR exists; does nothing when it is already present."""
    target_dir = SHARED_DIR
    os.makedirs(target_dir, exist_ok=True)
|
|
|
|
class MainMemoryDB:
    """
    Central registry database. RAM + GitHub backup.

    Each worker process keeps its own in-memory copy of the stores and syncs
    it with the other Gunicorn workers through files in SHARED_DIR
    (/dev/shm by default): for every store there is a JSON snapshot and a
    hash file. Every write saves all stores to the shared files; every read
    first checks whether a hash file changed (i.e. another worker wrote) and
    reloads if so.

    Lock order: ``self._file_lock`` (cross-process) is always acquired
    before any per-store RLock. Write operations hold ``self._file_lock``
    across their whole sync -> modify -> save cycle so writes made by
    different workers cannot overwrite each other.
    """

    # Process-wide singleton, created lazily by get_instance().
    _instance = None
    _init_lock = RLock()

    # Names of the in-memory stores; each one is a dict.
    STORES = ['users_registry', 'server_counts', 'tokens_index']

    # Filename used for each store in the GitHub backup repository.
    STORE_FILES = {
        'users_registry': 'users_registry.json',
        'server_counts': 'server_counts.json',
        'tokens_index': 'tokens_index.json',
    }

    @classmethod
    def get_instance(cls):
        """Return this process's singleton instance, creating it on first use."""
        if cls._instance is None:
            with cls._init_lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        _ensure_shared_dir()

        self._data = {}
        self._locks = {}
        self._data_hashes = {}
        # GitHub blob SHAs of the backup files (filename -> sha or None);
        # the contents API needs them to update an existing file.
        self._file_shas = {}

        for store in self.STORES:
            self._locks[store] = RLock()
            self._data[store] = {}
            self._data_hashes[store] = ""

        # Cross-process lock guarding the shared files in SHARED_DIR.
        self._file_lock = FileLock(SHARED_LOCK_PATH, timeout=10)

        # Another worker may already have populated the shared files; only
        # fall back to GitHub when nothing usable is there yet.
        loaded_from_shared = self._load_from_shared()

        if not loaded_from_shared:
            self._load_from_github()
            self._rebuild_server_counts()
            # Publish what we loaded so the other workers can pick it up.
            self._save_to_shared()

        print(f"✅ MainMemoryDB initialized (PID: {os.getpid()})")
        for store in self.STORES:
            print(f" {store}: {len(self._data[store])} records")

    # ------------------------------------------------------------------
    # Shared-file synchronisation between workers
    # ------------------------------------------------------------------

    def _shared_file_path(self, store_name):
        """Path to shared memory file for a store."""
        return os.path.join(SHARED_DIR, f"{store_name}.json")

    def _shared_hash_path(self, store_name):
        """Path to hash file that tracks last update."""
        return os.path.join(SHARED_DIR, f"{store_name}.hash")

    def _read_shared_hash(self, store_name):
        """Return the hash recorded in ``store_name``'s hash file, or None if it doesn't exist."""
        try:
            with open(self._shared_hash_path(store_name), 'r', encoding='utf-8') as hf:
                return hf.read().strip()
        except FileNotFoundError:
            return None

    def _shared_changed(self):
        """Return True if any store's shared hash differs from the one this worker last saw."""
        for store_name in self.STORES:
            current_hash = self._read_shared_hash(store_name)
            if current_hash is not None and current_hash != self._data_hashes.get(store_name, ''):
                return True
        return False

    @staticmethod
    def _atomic_write(path, text):
        """Write ``text`` to ``path`` without ever exposing a partially written file.

        The text goes to a sibling temp file that is then renamed over
        ``path`` with os.replace(). A worker killed mid-write therefore
        leaves the previous file intact instead of truncated JSON.
        """
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(text)
            os.replace(tmp_path, path)
        except BaseException:
            # Don't leave a stray temp file behind, then let the caller see the error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _load_from_shared(self):
        """
        Try to load data from shared memory files.
        Returns True if data was found and loaded.

        Stores whose file is missing or does not hold a JSON object keep
        their current (empty) contents. Any error is logged and reported as
        False so the caller falls back to GitHub.
        """
        try:
            with self._file_lock:
                found_any = False
                for store_name in self.STORES:
                    fpath = self._shared_file_path(store_name)
                    if not os.path.exists(fpath):
                        continue
                    with open(fpath, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, dict):
                        self._data[store_name] = data
                        stored_hash = self._read_shared_hash(store_name)
                        if stored_hash is not None:
                            self._data_hashes[store_name] = stored_hash
                        found_any = True

                if found_any:
                    print(f" 📂 Loaded data from shared memory (PID: {os.getpid()})")
                return found_any
        except Exception as e:
            print(f" ⚠️ Could not load from shared memory: {e}")
            return False

    def _save_to_shared(self):
        """
        Save all stores to shared memory files.
        Called after every write operation.

        For each store the JSON snapshot is written before its hash file,
        both atomically. Errors are logged and swallowed (best effort): the
        change stays applied in this worker's memory.
        """
        try:
            with self._file_lock:
                for store_name in self.STORES:
                    with self._locks[store_name]:
                        content = json.dumps(self._data[store_name], ensure_ascii=False)

                    self._atomic_write(self._shared_file_path(store_name), content)

                    new_hash = hashlib.md5(content.encode()).hexdigest()
                    self._atomic_write(self._shared_hash_path(store_name), new_hash)

                    # Remember what we wrote so our own next read does not
                    # mistake it for a change made by another worker.
                    self._data_hashes[store_name] = new_hash
        except Exception as e:
            print(f" ❌ Error saving to shared memory: {e}")

    def _sync_from_shared(self):
        """
        Check if another worker has updated the shared files.
        If so, reload the data. Called before every read operation.

        The hash files are compared without taking the file lock (cheap
        path); only when one differs are all stores re-read under the lock.
        On failure the error is logged and the current in-memory data is kept.
        """
        try:
            if not self._shared_changed():
                return

            with self._file_lock:
                for store_name in self.STORES:
                    fpath = self._shared_file_path(store_name)
                    if os.path.exists(fpath):
                        with open(fpath, 'r', encoding='utf-8') as f:
                            data = json.load(f)
                        if isinstance(data, dict):
                            with self._locks[store_name]:
                                self._data[store_name] = data

                    current_hash = self._read_shared_hash(store_name)
                    if current_hash is not None:
                        self._data_hashes[store_name] = current_hash
        except Exception as e:
            # Best effort: keep serving the data we already have.
            print(f" ⚠️ Could not sync from shared memory: {e}")

    # ------------------------------------------------------------------
    # GitHub backup
    # ------------------------------------------------------------------

    def _github_headers(self):
        """HTTP headers for the GitHub REST API, authenticated with GITHUB_TOKEN."""
        from config import GITHUB_TOKEN
        return {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json",
            "Content-Type": "application/json",
        }

    def _github_file_url(self, filename):
        """Contents-API URL of ``filename`` in GITHUB_REPO on GITHUB_BRANCH."""
        from config import GITHUB_REPO, GITHUB_BRANCH
        return (
            f"https://api.github.com/repos/{GITHUB_REPO}"
            f"/contents/{filename}?ref={GITHUB_BRANCH}"
        )

    def _load_from_github(self):
        """Populate the stores from the GitHub backup files.

        Files that are missing, fail to download, or do not contain a JSON
        object leave their store untouched. Also records each file's blob SHA
        (or None) in ``self._file_shas`` for later pushes.
        """
        from config import GITHUB_TOKEN
        if not GITHUB_TOKEN or len(GITHUB_TOKEN) < 10:
            print(" ⚠️ GitHub not configured - starting empty")
            return

        print(" 📥 Loading registry from GitHub...")
        self._file_shas = {}

        for store_name, filename in self.STORE_FILES.items():
            try:
                resp = requests.get(
                    self._github_file_url(filename),
                    headers=self._github_headers(),
                    timeout=30
                )
                if resp.status_code == 200:
                    data = resp.json()
                    self._file_shas[filename] = data.get('sha', '')
                    content = base64.b64decode(data.get('content', '')).decode('utf-8')
                    parsed = json.loads(content)
                    if isinstance(parsed, dict):
                        self._data[store_name] = parsed
                        print(f" ✅ {filename}: {len(parsed)} records")
                    else:
                        print(f" ⚠️ {filename}: not a dict, skipping")
                elif resp.status_code == 404:
                    print(f" ℹ️ {filename}: not found (first run)")
                    self._file_shas[filename] = None
                else:
                    print(f" ❌ {filename}: HTTP {resp.status_code}")
                    self._file_shas[filename] = None
            except Exception as e:
                print(f" ❌ Error loading {filename}: {e}")
                self._file_shas[filename] = None

    def _rebuild_server_counts(self):
        """Rebuild server_counts from users_registry.

        Every server 1..TOTAL_SERVERS gets a key (as a string) starting at 0;
        users whose server_num falls outside that range are not counted.
        """
        from config import TOTAL_SERVERS
        counts = {str(i): 0 for i in range(1, TOTAL_SERVERS + 1)}

        registry = self._data.get('users_registry', {})
        for user_data in registry.values():
            server_num = str(user_data.get('server_num', 0))
            if server_num in counts:
                counts[server_num] += 1

        with self._locks['server_counts']:
            self._data['server_counts'] = counts

        print(f" 📊 Server counts rebuilt: {sum(counts.values())} total users across {TOTAL_SERVERS} servers")

    def _fetch_remote_sha(self, filename):
        """Return the current blob SHA of ``filename`` on GitHub, or None if unavailable."""
        try:
            check = requests.get(
                self._github_file_url(filename),
                headers=self._github_headers(),
                timeout=15
            )
            if check.status_code == 200:
                return check.json().get('sha', '') or None
        except Exception:
            pass
        return None

    def _put_github_file(self, filename, payload, sha):
        """PUT ``payload`` to the contents API, adding ``sha`` when one is known."""
        body = dict(payload)
        if sha:
            body["sha"] = sha
        return requests.put(
            self._github_file_url(filename),
            headers=self._github_headers(),
            json=body,
            timeout=30
        )

    def push_to_github(self):
        """Push all stores to GitHub.

        Returns:
            (ok, errors): ``ok`` is True when every file was pushed; ``errors``
            holds one message per file that failed.
        """
        from config import GITHUB_TOKEN, GITHUB_BRANCH
        if not GITHUB_TOKEN or len(GITHUB_TOKEN) < 10:
            return False, ["GitHub not configured"]

        # Back up the newest state, including writes made by other workers.
        self._sync_from_shared()

        errors = []
        for store_name, filename in self.STORE_FILES.items():
            try:
                with self._locks[store_name]:
                    data = dict(self._data.get(store_name, {}))

                content_str = json.dumps(data, indent=2, ensure_ascii=False)
                content_b64 = base64.b64encode(content_str.encode('utf-8')).decode('utf-8')

                payload = {
                    "message": f"Backup {filename} - {len(data)} records - {datetime.now().isoformat()}",
                    "content": content_b64,
                    "branch": GITHUB_BRANCH,
                }

                sha = self._file_shas.get(filename) or self._fetch_remote_sha(filename)
                resp = self._put_github_file(filename, payload, sha)

                # 409 (SHA mismatch) / 422 (SHA missing): the cached SHA is stale,
                # e.g. because another worker pushed since we last did. Refresh it
                # and retry once instead of failing on every future push.
                if resp.status_code in (409, 422):
                    fresh_sha = self._fetch_remote_sha(filename)
                    if fresh_sha and fresh_sha != sha:
                        resp = self._put_github_file(filename, payload, fresh_sha)

                if resp.status_code in [200, 201]:
                    new_sha = resp.json().get('content', {}).get('sha', '')
                    if new_sha:
                        self._file_shas[filename] = new_sha
                    print(f" ✅ Pushed {filename}")
                else:
                    # Forget the cached SHA so the next push fetches it again.
                    self._file_shas[filename] = None
                    err_msg = resp.text[:200]
                    errors.append(f"{filename}: HTTP {resp.status_code} - {err_msg}")
                    print(f" ❌ Failed {filename}: {resp.status_code}")

            except Exception as e:
                errors.append(f"{filename}: {e}")
                print(f" ❌ Error pushing {filename}: {e}")

        return len(errors) == 0, errors

    # ------------------------------------------------------------------
    # Reads (each syncs from the shared files first)
    # ------------------------------------------------------------------

    def get_user(self, username):
        """Return the registry record for ``username``, or None if unknown."""
        self._sync_from_shared()
        with self._locks['users_registry']:
            return self._data['users_registry'].get(username)

    def get_user_by_token(self, token):
        """Return the registry record of the user owning ``token``, or None."""
        self._sync_from_shared()
        with self._locks['tokens_index']:
            username = self._data['tokens_index'].get(token)
        if not username:
            return None
        with self._locks['users_registry']:
            return self._data['users_registry'].get(username)

    def get_server_counts(self):
        """Return a copy of the {server_num_str: user_count} mapping."""
        self._sync_from_shared()
        with self._locks['server_counts']:
            return dict(self._data['server_counts'])

    def get_best_server(self):
        """Find server with lowest user count.

        Only servers below MAX_USERS_PER_SERVER qualify; ties go to the lowest
        server number. Returns None when every server is full.
        """
        from config import MAX_USERS_PER_SERVER, TOTAL_SERVERS
        counts = self.get_server_counts()

        best_server = None
        best_count = float('inf')

        for i in range(1, TOTAL_SERVERS + 1):
            count = counts.get(str(i), 0)
            if count < MAX_USERS_PER_SERVER and count < best_count:
                best_count = count
                best_server = i

        return best_server

    def get_stats(self):
        """Return record counts per store, with 'server_counts' replaced by a copy of the counts dict."""
        self._sync_from_shared()
        stats = {}
        for store in self.STORES:
            with self._locks[store]:
                stats[store] = len(self._data[store])
        with self._locks['server_counts']:
            stats['server_counts'] = dict(self._data['server_counts'])
        return stats

    def get_total_users(self):
        """Return the number of registered users."""
        self._sync_from_shared()
        with self._locks['users_registry']:
            return len(self._data['users_registry'])

    # ------------------------------------------------------------------
    # Writes (each holds the file lock for sync -> modify -> save)
    # ------------------------------------------------------------------

    def _remove_token_from_user(self, username, token):
        """Drop ``token`` from ``username``'s token list; no-op if the user is unknown."""
        with self._locks['users_registry']:
            user = self._data['users_registry'].get(username)
            if user:
                user['tokens'] = [
                    t for t in user.get('tokens', [])
                    if t.get('token') != token
                ]

    def register_user(self, username, telegram_id, server_num, token=None):
        """Register a new user in the central registry.

        Args:
            username: Key in users_registry; must not already exist.
            telegram_id: Stored as-is on the user record.
            server_num: Assigned server; its entry in server_counts is incremented.
            token: Optional device fingerprint token to link right away. If
                another user currently owns it, it is removed from that user
                (as link_token does) so the token maps to a single user.

        Returns:
            (True, None) on success, (False, "Username already exists") otherwise.
        """
        now = datetime.now().isoformat()
        user_data = {
            "username": username,
            "telegram_id": telegram_id,
            "server_num": server_num,
            "tokens": [],
            "created_at": now,
            "last_login": now,
        }
        if token:
            user_data["tokens"].append({"token": token, "created_at": now})

        # Holding the cross-process lock for the whole cycle stops two workers
        # from both passing the "already exists" check or clobbering each
        # other's save.
        with self._file_lock:
            self._sync_from_shared()

            with self._locks['users_registry']:
                registry = self._data['users_registry']
                if username in registry:
                    return False, "Username already exists"
                registry[username] = user_data

                if token:
                    with self._locks['tokens_index']:
                        previous_owner = self._data['tokens_index'].get(token)
                        self._data['tokens_index'][token] = username
                    if previous_owner and previous_owner != username:
                        self._remove_token_from_user(previous_owner, token)

            with self._locks['server_counts']:
                num_str = str(server_num)
                counts = self._data['server_counts']
                counts[num_str] = counts.get(num_str, 0) + 1

            self._save_to_shared()

        print(f" ✅ Registered user '{username}' on server {server_num} (PID: {os.getpid()})")
        return True, None

    def link_token(self, username, token):
        """Link a device fingerprint token to a user. REPLACES all previous tokens.

        If the token belonged to another user it is removed from that user.
        The user's previous tokens are dropped from tokens_index, and
        last_login is refreshed.

        Returns:
            (True, None) on success, (False, "User not found") if ``username`` is unknown.
        """
        with self._file_lock:
            self._sync_from_shared()

            with self._locks['users_registry']:
                user = self._data['users_registry'].get(username)
                if not user:
                    return False, "User not found"

                with self._locks['tokens_index']:
                    index = self._data['tokens_index']

                    # A device token belongs to one user: take it from its previous owner.
                    old_username = index.get(token)
                    if old_username and old_username != username:
                        self._remove_token_from_user(old_username, token)

                    # This user's other tokens are being replaced; unindex them.
                    for old_t in user.get('tokens', []):
                        old_token_val = old_t.get('token', '')
                        if old_token_val and old_token_val != token:
                            index.pop(old_token_val, None)

                    now = datetime.now().isoformat()
                    user['tokens'] = [{"token": token, "created_at": now}]
                    user['last_login'] = now
                    index[token] = username

            self._save_to_shared()

        return True, None

    def unlink_token(self, token):
        """Remove a token from the index and from the user. Always returns True."""
        with self._file_lock:
            self._sync_from_shared()

            with self._locks['tokens_index']:
                username = self._data['tokens_index'].pop(token, None)

            if username:
                self._remove_token_from_user(username, token)
                self._save_to_shared()

        return True

    def update_user_login(self, username):
        """Update last_login timestamp. Returns False if the user is unknown."""
        with self._file_lock:
            self._sync_from_shared()

            with self._locks['users_registry']:
                user = self._data['users_registry'].get(username)
                if not user:
                    return False
                user['last_login'] = datetime.now().isoformat()

            self._save_to_shared()

        return True

    def search_users(self, query, limit=10):
        """Search users by case-insensitive username prefix.

        Returns at most ``limit`` dicts ``{"username", "server_num"}`` in
        registry order. An empty or whitespace-only query, or a non-positive
        ``limit``, returns ``[]``.
        """
        query_lower = (query or '').lower().strip()
        # Check after stripping: "   " would otherwise match every user.
        if not query_lower or limit <= 0:
            return []

        self._sync_from_shared()

        results = []
        with self._locks['users_registry']:
            for uname, udata in self._data['users_registry'].items():
                if uname.lower().startswith(query_lower):
                    results.append({
                        "username": uname,
                        "server_num": udata.get('server_num'),
                    })
                    if len(results) >= limit:
                        break

        return results
|
|
|
|
def get_db():
    """Return the process-wide MainMemoryDB singleton, building it on first call."""
    instance = MainMemoryDB.get_instance()
    return instance