File size: 19,293 Bytes
58fbd44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943787c
 
58fbd44
 
 
 
 
 
943787c
58fbd44
 
 
 
 
 
 
 
 
 
 
 
943787c
 
 
 
 
 
 
58fbd44
943787c
 
 
 
 
58fbd44
 
 
943787c
58fbd44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67de256
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
"""
In-Memory Database for Main DB Server.
Stores:
  - users_registry: {username: {username, telegram_id, server_num, tokens: [...]}}
  - server_counts: {server_num_str: count}
  - tokens_index: {fingerprint_token: username}  (reverse lookup)

IMPORTANT: Uses file-based locking + shared memory to sync between
Gunicorn workers. Each write operation saves every store as a JSON file
(plus an md5 hash file) in SHARED_DIR, and reads first check those hash
files to pick up updates made by other workers.
"""

import time
import json
import base64
import os
import tempfile
import hashlib
from datetime import datetime
from threading import RLock
from filelock import FileLock

import requests

# Directory holding the files every Gunicorn worker reads and writes to stay
# in sync; override it with the SHARED_DATA_DIR environment variable.
SHARED_DIR = os.getenv("SHARED_DATA_DIR", "/dev/shm/maindb")
# Inter-process lock file that serializes access to the shared store files.
SHARED_LOCK_PATH = os.path.join(SHARED_DIR, "maindb.lock")


def _ensure_shared_dir():
    """Create SHARED_DIR (and any missing parents); do nothing if it already exists."""
    target_dir = SHARED_DIR
    os.makedirs(target_dir, mode=0o777, exist_ok=True)


class MainMemoryDB:
    """
    Central registry database. RAM + GitHub backup.

    Uses /dev/shm (shared memory filesystem) to sync data between
    Gunicorn workers. Every write saves to shared files, every read
    checks if shared files have been updated by another worker.

    Locking rules (keep them when adding methods):
      * ``self._file_lock`` (a FileLock on SHARED_LOCK_PATH) serializes access
        to the shared files across worker processes. filelock locks are
        reentrant for their holder, so helpers may nest ``with self._file_lock``.
      * ``self._locks[store]`` (one RLock per store) guards the in-process
        dict for that store.
      * Always take the file lock BEFORE any store lock, never the reverse,
        so the two kinds of lock cannot deadlock against each other.
      * Write operations hold the file lock across sync -> mutate -> save.
        Without that, two workers could each sync, mutate and save, and the
        second save would silently discard the first worker's change.
    """

    _instance = None
    _init_lock = RLock()

    STORES = ['users_registry', 'server_counts', 'tokens_index']

    # GitHub backup filename for each store.
    STORE_FILES = {
        'users_registry': 'users_registry.json',
        'server_counts': 'server_counts.json',
        'tokens_index': 'tokens_index.json',
    }

    # Seconds a starting worker waits for the shared-file lock while another
    # worker bootstraps. That worker may be downloading every store from
    # GitHub (30 s request timeout per file), so the normal 10 s is too short.
    INIT_LOCK_TIMEOUT = 120

    # Error returned by write operations when the shared-file lock is unavailable.
    BUSY_MESSAGE = "Database busy, please retry"

    @classmethod
    def get_instance(cls):
        """Return the per-process singleton, creating it on first use."""
        if cls._instance is None:
            with cls._init_lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        _ensure_shared_dir()

        self._data = {}
        self._locks = {}
        self._data_hashes = {}  # md5 of each shared store file as last seen by this worker
        self._file_shas = {}    # GitHub blob sha per backup filename (needed for updates)

        for store in self.STORES:
            self._locks[store] = RLock()
            self._data[store] = {}
            self._data_hashes[store] = ""

        # File lock for cross-worker synchronization
        self._file_lock = FileLock(SHARED_LOCK_PATH, timeout=10)

        # Hold the file lock for the whole bootstrap. Otherwise two workers
        # starting together could both miss the shared files. Worse, a slow
        # GitHub load could later overwrite writes made by a worker that is
        # already serving requests.
        try:
            self._file_lock.acquire(timeout=self.INIT_LOCK_TIMEOUT)
        except TimeoutError:
            print("  [warn] Timed out waiting for shared-memory lock; bootstrapping without it")
            self._bootstrap()
        else:
            try:
                self._bootstrap()
            finally:
                self._file_lock.release()

        print(f"βœ… MainMemoryDB initialized (PID: {os.getpid()})")
        for store in self.STORES:
            print(f"   {store}: {len(self._data[store])} records")

    def _bootstrap(self):
        """Fill the stores from shared memory, or from GitHub if no worker has published yet."""
        if not self._load_from_shared():
            # First worker to start - load from GitHub
            self._load_from_github()
            self._rebuild_server_counts()
            # Save to shared memory so other workers can pick it up
            self._save_to_shared()

    # ─── Shared Memory (Inter-Worker Sync) ───

    def _shared_file_path(self, store_name):
        """Path to shared memory file for a store."""
        return os.path.join(SHARED_DIR, f"{store_name}.json")

    def _shared_hash_path(self, store_name):
        """Path to hash file that tracks last update."""
        return os.path.join(SHARED_DIR, f"{store_name}.hash")

    def _read_shared_hash(self, store_name):
        """Return the hash stored in the store's shared hash file, or None if it does not exist."""
        try:
            with open(self._shared_hash_path(store_name), 'r') as hf:
                return hf.read().strip()
        except FileNotFoundError:
            return None

    @staticmethod
    def _atomic_write(path, text):
        """
        Write ``text`` (UTF-8) to ``path`` via a temp file + ``os.replace``.

        Readers therefore see either the complete old file or the complete
        new one, never a truncated file. A crash mid-write leaves the
        previous contents intact.
        """
        directory = os.path.dirname(path) or "."
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".part")
        try:
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                f.write(text)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def _load_from_shared(self):
        """
        Try to load data from shared memory files.

        Returns True if at least one store file was found and loaded.
        """
        try:
            with self._file_lock:
                found_any = False
                for store_name in self.STORES:
                    fpath = self._shared_file_path(store_name)
                    if not os.path.exists(fpath):
                        continue
                    with open(fpath, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, dict):
                        self._data[store_name] = data
                        stored_hash = self._read_shared_hash(store_name)
                        if stored_hash is not None:
                            self._data_hashes[store_name] = stored_hash
                        found_any = True

                if found_any:
                    print(f"  πŸ“‚ Loaded data from shared memory (PID: {os.getpid()})")
                return found_any
        except Exception as e:
            print(f"  ⚠️ Could not load from shared memory: {e}")
            return False

    def _save_to_shared(self):
        """
        Save all stores to shared memory files.

        Called after every write operation. Errors are logged, not raised,
        so the in-memory state of this worker stays usable.
        """
        try:
            with self._file_lock:
                for store_name in self.STORES:
                    with self._locks[store_name]:
                        content = json.dumps(self._data[store_name], ensure_ascii=False)
                    new_hash = hashlib.md5(content.encode()).hexdigest()

                    # Data first, then hash. A worker whose peek sees the new
                    # hash therefore always finds the matching data file.
                    self._atomic_write(self._shared_file_path(store_name), content)
                    self._atomic_write(self._shared_hash_path(store_name), new_hash)

                    self._data_hashes[store_name] = new_hash

        except Exception as e:
            print(f"  ❌ Error saving to shared memory: {e}")

    def _sync_from_shared(self):
        """
        Reload the stores if another worker has updated the shared files.

        Called before every read operation and at the start of every write.
        Errors are logged and the worker keeps serving its local copy.
        """
        try:
            # Cheap lock-free peek. Hash files are replaced atomically, so this
            # sees either a complete old hash or a complete new one.
            needs_reload = False
            for store_name in self.STORES:
                current_hash = self._read_shared_hash(store_name)
                if current_hash is not None and current_hash != self._data_hashes.get(store_name, ''):
                    needs_reload = True
                    break

            if not needs_reload:
                return

            with self._file_lock:
                for store_name in self.STORES:
                    fpath = self._shared_file_path(store_name)
                    if not os.path.exists(fpath):
                        continue
                    with open(fpath, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, dict):
                        with self._locks[store_name]:
                            self._data[store_name] = data
                    stored_hash = self._read_shared_hash(store_name)
                    if stored_hash is not None:
                        self._data_hashes[store_name] = stored_hash

        except Exception as e:
            # Don't crash on sync errors - keep using local data, but make it visible.
            print(f"  [warn] Shared-memory sync failed, serving local data (PID: {os.getpid()}): {e}")

    # ─── GitHub Operations ───

    def _github_headers(self):
        """HTTP headers for the GitHub contents API, authenticated with GITHUB_TOKEN."""
        from config import GITHUB_TOKEN
        return {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json",
            "Content-Type": "application/json",
        }

    def _github_file_url(self, filename):
        """Contents-API URL for ``filename`` in the configured repo/branch."""
        from config import GITHUB_REPO, GITHUB_BRANCH
        return (
            f"https://api.github.com/repos/{GITHUB_REPO}"
            f"/contents/{filename}?ref={GITHUB_BRANCH}"
        )

    def _fetch_github_sha(self, filename):
        """Return the current blob sha of ``filename`` on GitHub, or None if it cannot be fetched."""
        try:
            check = requests.get(
                self._github_file_url(filename),
                headers=self._github_headers(),
                timeout=15
            )
            if check.status_code == 200:
                return check.json().get('sha', '') or None
        except Exception:
            pass
        return None

    def _put_github_file(self, filename, payload, sha):
        """PUT ``payload`` to ``filename``, adding ``sha`` (required to update an existing file)."""
        body = dict(payload)
        if sha:
            body["sha"] = sha
        return requests.put(
            self._github_file_url(filename),
            headers=self._github_headers(),
            json=body,
            timeout=30
        )

    def _load_from_github(self):
        """Populate the stores from the GitHub backup; leaves a store empty on any failure."""
        from config import GITHUB_TOKEN
        if not GITHUB_TOKEN or len(GITHUB_TOKEN) < 10:
            print("  ⚠️ GitHub not configured - starting empty")
            return

        print("  πŸ“₯ Loading registry from GitHub...")
        self._file_shas = {}

        for store_name, filename in self.STORE_FILES.items():
            try:
                resp = requests.get(
                    self._github_file_url(filename),
                    headers=self._github_headers(),
                    timeout=30
                )
                if resp.status_code == 200:
                    data = resp.json()
                    self._file_shas[filename] = data.get('sha', '')
                    content = base64.b64decode(data.get('content', '')).decode('utf-8')
                    parsed = json.loads(content)
                    if isinstance(parsed, dict):
                        self._data[store_name] = parsed
                        print(f"  βœ… {filename}: {len(parsed)} records")
                    else:
                        print(f"  ⚠️ {filename}: not a dict, skipping")
                elif resp.status_code == 404:
                    print(f"  ℹ️ {filename}: not found (first run)")
                    self._file_shas[filename] = None
                else:
                    print(f"  ❌ {filename}: HTTP {resp.status_code}")
                    self._file_shas[filename] = None
            except Exception as e:
                print(f"  ❌ Error loading {filename}: {e}")
                self._file_shas[filename] = None

    def _rebuild_server_counts(self):
        """Rebuild server_counts from users_registry (unknown server numbers are ignored)."""
        from config import TOTAL_SERVERS
        counts = {str(i): 0 for i in range(1, TOTAL_SERVERS + 1)}

        with self._locks['users_registry']:
            server_nums = [
                str(user_data.get('server_num', 0))
                for user_data in self._data.get('users_registry', {}).values()
            ]

        for server_num in server_nums:
            if server_num in counts:
                counts[server_num] += 1

        with self._locks['server_counts']:
            self._data['server_counts'] = counts

        print(f"  πŸ“Š Server counts rebuilt: {sum(counts.values())} total users across {TOTAL_SERVERS} servers")

    def push_to_github(self):
        """
        Push all stores to GitHub.

        Returns:
            (ok, errors): ``ok`` is True when every file was pushed; ``errors``
            lists a human-readable message per failed file.
        """
        from config import GITHUB_TOKEN, GITHUB_BRANCH
        if not GITHUB_TOKEN or len(GITHUB_TOKEN) < 10:
            return False, ["GitHub not configured"]

        # Sync first to make sure we have latest data
        self._sync_from_shared()

        errors = []
        for store_name, filename in self.STORE_FILES.items():
            try:
                with self._locks[store_name]:
                    data = dict(self._data.get(store_name, {}))

                content_str = json.dumps(data, indent=2, ensure_ascii=False)
                content_b64 = base64.b64encode(content_str.encode('utf-8')).decode('utf-8')

                payload = {
                    "message": f"Backup {filename} - {len(data)} records - {datetime.now().isoformat()}",
                    "content": content_b64,
                    "branch": GITHUB_BRANCH,
                }

                sha = self._file_shas.get(filename) or self._fetch_github_sha(filename)
                resp = self._put_github_file(filename, payload, sha)

                if resp.status_code == 409:
                    # Cached sha is stale. Each worker caches its own sha, so
                    # another worker's push invalidates ours. Refresh once and retry.
                    resp = self._put_github_file(filename, payload, self._fetch_github_sha(filename))

                if resp.status_code in (200, 201):
                    new_sha = resp.json().get('content', {}).get('sha', '')
                    if new_sha:
                        self._file_shas[filename] = new_sha
                    print(f"  βœ… Pushed {filename}")
                else:
                    err_msg = resp.text[:200]
                    errors.append(f"{filename}: HTTP {resp.status_code} - {err_msg}")
                    print(f"  ❌ Failed {filename}: {resp.status_code}")

            except Exception as e:
                errors.append(f"{filename}: {e}")
                print(f"  ❌ Error pushing {filename}: {e}")

        return len(errors) == 0, errors

    # ─── Read Operations ───

    def get_user(self, username):
        """Return the registry record for ``username`` (the live dict, not a copy), or None."""
        self._sync_from_shared()  # Check for updates from other workers
        with self._locks['users_registry']:
            return self._data['users_registry'].get(username)

    def get_user_by_token(self, token):
        """Return the registry record of the user owning ``token``, or None."""
        self._sync_from_shared()  # Check for updates from other workers
        with self._locks['tokens_index']:
            username = self._data['tokens_index'].get(token)
        if not username:
            return None
        # Already synced above; read the registry directly instead of syncing again.
        with self._locks['users_registry']:
            return self._data['users_registry'].get(username)

    def get_server_counts(self):
        """Return a copy of the {server_num_str: user_count} mapping."""
        self._sync_from_shared()
        with self._locks['server_counts']:
            return dict(self._data['server_counts'])

    def get_best_server(self):
        """Find server with lowest user count below MAX_USERS_PER_SERVER; None if all are full."""
        from config import MAX_USERS_PER_SERVER, TOTAL_SERVERS
        counts = self.get_server_counts()

        best_server = None
        best_count = float('inf')

        for i in range(1, TOTAL_SERVERS + 1):
            count = counts.get(str(i), 0)
            if count < MAX_USERS_PER_SERVER and count < best_count:
                best_count = count
                best_server = i

        return best_server

    def get_stats(self):
        """Return record counts per store, with 'server_counts' replaced by the full mapping."""
        self._sync_from_shared()
        stats = {}
        for store in self.STORES:
            with self._locks[store]:
                stats[store] = len(self._data[store])
        stats['server_counts'] = self.get_server_counts()
        return stats

    def get_total_users(self):
        """Return the number of registered users."""
        self._sync_from_shared()
        with self._locks['users_registry']:
            return len(self._data['users_registry'])

    # ─── Write Operations ───
    # Each write holds self._file_lock across sync -> mutate -> save (see class docstring).

    def _strip_token_from_previous_owner(self, token, new_owner):
        """
        Remove ``token`` from the token list of the other user it is indexed to.

        Enforces one device = one account. Caller must hold the
        users_registry lock; this takes the tokens_index lock briefly.
        """
        with self._locks['tokens_index']:
            previous_owner = self._data['tokens_index'].get(token)
        if previous_owner and previous_owner != new_owner:
            previous_user = self._data['users_registry'].get(previous_owner)
            if previous_user:
                previous_user['tokens'] = [
                    t for t in previous_user.get('tokens', [])
                    if t.get('token') != token
                ]

    def register_user(self, username, telegram_id, server_num, token=None):
        """
        Register a new user in the central registry.

        Returns:
            (True, None) on success, (False, reason) if the username exists
            or the cross-worker lock could not be acquired.
        """
        now = datetime.now().isoformat()
        user_data = {
            "username": username,
            "telegram_id": telegram_id,
            "server_num": server_num,
            "tokens": [],
            "created_at": now,
            "last_login": now,
        }
        if token:
            user_data["tokens"].append({
                "token": token,
                "created_at": now,
            })

        try:
            with self._file_lock:
                self._sync_from_shared()  # Get latest before writing

                with self._locks['users_registry']:
                    if username in self._data['users_registry']:
                        return False, "Username already exists"
                    if token:
                        # Keep registry consistent with the index we overwrite below.
                        self._strip_token_from_previous_owner(token, username)
                    self._data['users_registry'][username] = user_data

                # Update server count
                with self._locks['server_counts']:
                    num_str = str(server_num)
                    self._data['server_counts'][num_str] = \
                        self._data['server_counts'].get(num_str, 0) + 1

                # Update token index
                if token:
                    with self._locks['tokens_index']:
                        self._data['tokens_index'][token] = username

                # Save to shared memory so other workers see it immediately
                self._save_to_shared()
        except TimeoutError:  # filelock.Timeout subclasses TimeoutError
            return False, self.BUSY_MESSAGE

        print(f"  βœ… Registered user '{username}' on server {server_num} (PID: {os.getpid()})")
        return True, None

    def link_token(self, username, token):
        """
        Link a device fingerprint token to a user. REPLACES all previous tokens.

        Returns:
            (True, None) on success, (False, reason) if the user is unknown or
            the cross-worker lock could not be acquired.
        """
        try:
            with self._file_lock:
                self._sync_from_shared()

                with self._locks['users_registry']:
                    user = self._data['users_registry'].get(username)
                    if not user:
                        return False, "User not found"

                    # Remove this token from any OTHER user (one device = one account)
                    self._strip_token_from_previous_owner(token, username)

                    # Remove ALL old tokens for this user from the index (1 device policy)
                    old_token_vals = [t.get('token', '') for t in user.get('tokens', [])]
                    with self._locks['tokens_index']:
                        for old_token_val in old_token_vals:
                            if old_token_val and old_token_val != token:
                                self._data['tokens_index'].pop(old_token_val, None)

                    # Set ONLY this token for the user (replace, not append)
                    now = datetime.now().isoformat()
                    user['tokens'] = [{
                        "token": token,
                        "created_at": now,
                    }]
                    user['last_login'] = now

                # Update token index - this token points to this user ONLY
                with self._locks['tokens_index']:
                    self._data['tokens_index'][token] = username

                self._save_to_shared()
        except TimeoutError:  # filelock.Timeout subclasses TimeoutError
            return False, self.BUSY_MESSAGE

        return True, None

    def unlink_token(self, token):
        """
        Remove a token from the index and from the user that owned it.

        Returns True (also when the token was unknown); False only if the
        cross-worker lock could not be acquired.
        """
        try:
            with self._file_lock:
                self._sync_from_shared()  # Get latest before writing

                with self._locks['tokens_index']:
                    username = self._data['tokens_index'].pop(token, None)

                if username:
                    with self._locks['users_registry']:
                        user = self._data['users_registry'].get(username)
                        if user:
                            user['tokens'] = [
                                t for t in user.get('tokens', [])
                                if t.get('token') != token
                            ]

                # Save to shared memory
                self._save_to_shared()
        except TimeoutError:  # filelock.Timeout subclasses TimeoutError
            return False

        return True

    def update_user_login(self, username):
        """
        Update last_login timestamp.

        Returns True if the user exists and was updated. Returns False if
        the user is unknown or the cross-worker lock could not be acquired.
        """
        try:
            with self._file_lock:
                self._sync_from_shared()

                with self._locks['users_registry']:
                    user = self._data['users_registry'].get(username)
                    if not user:
                        return False
                    user['last_login'] = datetime.now().isoformat()

                # Save outside the store lock: _save_to_shared takes store locks itself.
                self._save_to_shared()
        except TimeoutError:  # filelock.Timeout subclasses TimeoutError
            return False

        return True

    def search_users(self, query, limit=10):
        """
        Search users by case-insensitive username prefix.

        Returns at most ``limit`` dicts of {"username", "server_num"}. Returns
        [] for an empty/whitespace-only query or a non-positive limit.
        """
        query_lower = (query or '').lower().strip()
        if not query_lower or limit <= 0:
            return []

        self._sync_from_shared()

        results = []
        with self._locks['users_registry']:
            for uname, udata in self._data['users_registry'].items():
                if uname.lower().startswith(query_lower):
                    results.append({
                        "username": uname,
                        "server_num": udata.get('server_num'),
                    })
                    if len(results) >= limit:
                        break

        return results


def get_db():
    """Return this process's MainMemoryDB singleton, creating it on first use."""
    db = MainMemoryDB.get_instance()
    return db