Spaces:
Runtime error
Runtime error
| """ | |
| JSON file-based database module | |
| Handles: | |
| - JSON file storage for users, tokens, and blacklisted tokens | |
| - User management and authentication | |
| - Token blacklist management | |
| - Refresh token storage | |
| - Thread-safe JSON operations using per-file in-process locks (threading.Lock) | |
| - Hugging Face friendly seeding + writable fallback | |
| """ | |
| import os | |
| import json | |
| import threading | |
| import shutil | |
| from typing import Optional, Dict, Any, List, Tuple | |
| from pathlib import Path | |
| from datetime import datetime | |
| # ========================================================= | |
| # PATHS | |
| # ========================================================= | |
| def _project_root() -> Path: | |
| # auth/json_database.py -> auth (parent) -> project root (parent) | |
| return Path(__file__).resolve().parent.parent | |
# Default storage location: <project>/data
DEFAULT_DB_DIR = str((_project_root() / "data").resolve())

# Raw (unresolved) directory settings; both default to <project>/data.
DB_DIR_ENV = os.environ.get("AUTH_DB_DIR", DEFAULT_DB_DIR)
# Seed dir: where the JSON files shipped with the repo live.
SEED_DIR_ENV = os.environ.get("AUTH_SEED_DIR", DEFAULT_DB_DIR)

# Boolean switches: enabled only when the variable is exactly "1" (the default).
SEED_ON_START = os.environ.get("AUTH_SEED_ON_START", "1") == "1"
SEED_IF_EMPTY = os.environ.get("AUTH_SEED_IF_EMPTY", "1") == "1"
RUN_INIT_DB = os.environ.get("RUN_INIT_DB", "1") == "1"

# Absolute, user-expanded directories actually used at runtime.
DB_DIR = str(Path(DB_DIR_ENV).expanduser().resolve())
SEED_DIR = str(Path(SEED_DIR_ENV).expanduser().resolve())

# One JSON file per collection, all inside DB_DIR.
USERS_FILE = os.path.join(DB_DIR, "users.json")
BLACKLISTED_TOKENS_FILE = os.path.join(DB_DIR, "blacklisted_tokens.json")
REFRESH_TOKENS_FILE = os.path.join(DB_DIR, "refresh_tokens.json")

# One in-process lock per collection, keyed by collection name.
_file_locks = {
    collection: threading.Lock()
    for collection in ("users", "blacklisted_tokens", "refresh_tokens")
}

# One-time initialisation flag and the lock that guards it.
_db_init_done = False
_db_init_lock = threading.Lock()
| # ========================================================= | |
| # HELPERS | |
| # ========================================================= | |
| def _log(msg: str): | |
| # Simple stdout logging (Hugging Face logs show this) | |
| print(msg, flush=True) | |
def ensure_data_directory():
    """Make sure the database directory exists and accepts writes.

    The directory named by ``DB_DIR`` is created if needed, then probed by
    writing and deleting a small marker file. If any step fails, a warning is
    logged and the module switches to ``/tmp/auth_db``: ``DB_DIR`` and the
    three ``*_FILE`` module globals are rebound to point into that directory.
    """
    global DB_DIR, USERS_FILE, BLACKLISTED_TOKENS_FILE, REFRESH_TOKENS_FILE

    try:
        Path(DB_DIR).mkdir(parents=True, exist_ok=True)
        # Creating the directory is not enough: prove we can write into it.
        probe_path = os.path.join(DB_DIR, ".write_test")
        with open(probe_path, "w", encoding="utf-8") as probe:
            probe.write("ok")
        os.remove(probe_path)
    except Exception as exc:
        _log(f"[AUTH][WARN] DB_DIR not writable: {DB_DIR}. Error: {exc}")
        fallback_dir = "/tmp/auth_db"
        _log(f"[AUTH][WARN] Falling back to: {fallback_dir}")
        DB_DIR = fallback_dir
        Path(DB_DIR).mkdir(parents=True, exist_ok=True)
        # Re-derive every file path so later reads/writes use the fallback.
        USERS_FILE, BLACKLISTED_TOKENS_FILE, REFRESH_TOKENS_FILE = (
            os.path.join(DB_DIR, file_name)
            for file_name in (
                "users.json",
                "blacklisted_tokens.json",
                "refresh_tokens.json",
            )
        )
def load_json_file(filepath: str) -> Dict[str, Any]:
    """Load a JSON object from *filepath*, never raising on bad input.

    Args:
        filepath: Path of the JSON file to read (decoded as UTF-8).

    Returns:
        The decoded top-level JSON object. An empty dict is returned when the
        file does not exist (silently), when it cannot be read or decoded
        (logged), or when its top-level value is not a JSON object (logged).
    """
    try:
        # EAFP: open directly instead of exists()+open(), which can race with
        # another writer removing/replacing the file in between.
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        _log(f"[AUTH][ERROR] Error loading {filepath}: {e}")
        return {}

    if not isinstance(data, dict):
        # Callers rely on dict methods (.get, item assignment); anything else
        # is treated like a corrupt file.
        _log(
            f"[AUTH][ERROR] Error loading {filepath}: "
            f"expected a JSON object, got {type(data).__name__}"
        )
        return {}
    return data
def save_json_file(filepath: str, data: Dict[str, Any]):
    """Write *data* to *filepath* as JSON via a temp file and atomic rename.

    The JSON is first written to ``<filepath>.tmp`` and then moved over the
    target with ``os.replace``, so readers never observe a half-written file.

    Args:
        filepath: Destination path. Missing parent directories are created.
        data: JSON-serialisable mapping to store.

    Raises:
        Exception: Any error from serialising or writing is logged, the temp
            file is removed on a best-effort basis, and the error is re-raised.
    """
    parent = os.path.dirname(filepath)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; the current directory already exists in that case.
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = filepath + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, filepath)
    except Exception as e:
        _log(f"[AUTH][ERROR] Error saving {filepath}: {e}")
        # Best-effort cleanup: a leftover/missing temp file must not mask the
        # original error that is re-raised below.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
| def _file_is_empty_json(filepath: str) -> bool: | |
| if not os.path.exists(filepath): | |
| return True | |
| try: | |
| data = load_json_file(filepath) | |
| return not bool(data) | |
| except Exception: | |
| return True | |
def _seed_file_if_needed(target_path: str, seed_path: str):
    """Copy the seed file over the target when the target needs seeding.

    Nothing happens unless ``SEED_ON_START`` is enabled and the seed file
    exists. The copy is made when the target is missing, or when the target
    exists, ``SEED_IF_EMPTY`` is enabled, and the target holds empty JSON.
    """
    if not SEED_ON_START:
        return
    if not os.path.exists(seed_path):
        # No seed available: there is nothing to copy in either case.
        return

    if os.path.exists(target_path):
        # An existing target is only replaced when it is empty and
        # empty-file seeding is switched on.
        if not SEED_IF_EMPTY:
            return
        if not _file_is_empty_json(target_path):
            return

    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    shutil.copy2(seed_path, target_path)
    _log(f"[AUTH][SEED] Copied seed: {seed_path} -> {target_path}")
| # ========================================================= | |
| # STORAGE CLASSES | |
| # ========================================================= | |
class JSONUsers:
    """User records stored in ``USERS_FILE``, keyed by lower-cased username.

    All methods are static: the class is only a namespace and holds no state.
    Every file access happens while holding ``_file_locks["users"]``.
    """

    @staticmethod
    def load_all() -> Dict[str, Any]:
        """Return the full username -> record mapping from the users file."""
        with _file_locks["users"]:
            return load_json_file(USERS_FILE)

    @staticmethod
    def find_by_username(username: str) -> Optional[Dict[str, Any]]:
        """Return the record for *username* (case-insensitive), or None."""
        users = JSONUsers.load_all()
        return users.get(username.lower())

    @staticmethod
    def create_user(username: str, password_hash: str, role: str = "user") -> bool:
        """Add a new user record.

        Args:
            username: Login name; stored and looked up lower-cased.
            password_hash: Stored as given under ``"password_hash"``.
            role: Role label for the new user (defaults to ``"user"``).

        Returns:
            False if the username already exists, True after the record has
            been saved.
        """
        username_lower = username.lower()
        with _file_locks["users"]:
            users = load_json_file(USERS_FILE)
            if username_lower in users:
                return False
            # Derive the id from the highest stored id rather than the record
            # count, so non-contiguous ids (e.g. seeded data) never produce a
            # duplicate. For ids 1..n this yields the same n + 1 as before.
            known_ids = [
                record["id"]
                for record in users.values()
                if isinstance(record, dict) and isinstance(record.get("id"), int)
            ]
            users[username_lower] = {
                "id": max(known_ids, default=0) + 1,
                "username": username_lower,
                "password_hash": password_hash,
                "role": role,
                "created_at": datetime.now().isoformat(),
            }
            save_json_file(USERS_FILE, users)
            return True

    @staticmethod
    def get_all_users() -> List[Dict[str, Any]]:
        """Return id, username and role for every user (no password hashes)."""
        users = JSONUsers.load_all()
        return [
            {"id": u.get("id"), "username": u.get("username"), "role": u.get("role")}
            for u in users.values()
        ]

    @staticmethod
    def promote_to_admin(username: str) -> bool:
        """Set the user's role to ``"admin"``; return False if the user is unknown."""
        username_lower = username.lower()
        with _file_locks["users"]:
            users = load_json_file(USERS_FILE)
            if username_lower not in users:
                return False
            users[username_lower]["role"] = "admin"
            save_json_file(USERS_FILE, users)
            return True

    @staticmethod
    def user_count() -> int:
        """Return the number of stored user records."""
        return len(JSONUsers.load_all())
class JSONBlacklistedTokens:
    """Blacklisted tokens stored in ``BLACKLISTED_TOKENS_FILE``, keyed by token.

    All methods are static: the class is only a namespace and holds no state.
    Every file access happens while holding
    ``_file_locks["blacklisted_tokens"]``.
    """

    @staticmethod
    def load_all() -> Dict[str, Any]:
        """Return the full token -> record mapping from the blacklist file."""
        with _file_locks["blacklisted_tokens"]:
            return load_json_file(BLACKLISTED_TOKENS_FILE)

    @staticmethod
    def is_blacklisted(token: str) -> bool:
        """Return True if *token* is a key in the blacklist file."""
        tokens = JSONBlacklistedTokens.load_all()
        return token in tokens

    @staticmethod
    def add_to_blacklist(token: str) -> bool:
        """Record *token* as blacklisted.

        Returns:
            Always True; a token that is already present is left untouched
            and the file is not rewritten.
        """
        with _file_locks["blacklisted_tokens"]:
            tokens = load_json_file(BLACKLISTED_TOKENS_FILE)
            if token in tokens:
                return True
            tokens[token] = {"token": token, "created_at": datetime.now().isoformat()}
            save_json_file(BLACKLISTED_TOKENS_FILE, tokens)
            return True
class JSONRefreshTokens:
    """Refresh tokens stored in ``REFRESH_TOKENS_FILE``.

    Records are keyed ``"<username>_<n>"`` and hold the lower-cased username,
    the token and a creation timestamp. All methods are static: the class is
    only a namespace and holds no state. Every file access happens while
    holding ``_file_locks["refresh_tokens"]``.
    """

    @staticmethod
    def load_all() -> Dict[str, Any]:
        """Return the full key -> record mapping from the refresh-token file."""
        with _file_locks["refresh_tokens"]:
            return load_json_file(REFRESH_TOKENS_FILE)

    @staticmethod
    def find_by_token(token: str) -> Optional[str]:
        """Return the username stored with *token*, or None if not found."""
        tokens = JSONRefreshTokens.load_all()
        for token_data in tokens.values():
            if token_data.get("token") == token:
                return token_data.get("username")
        return None

    @staticmethod
    def create_token(username: str, token: str) -> bool:
        """Store *token* for *username* under a fresh key; always returns True."""
        username_lower = username.lower()
        with _file_locks["refresh_tokens"]:
            tokens = load_json_file(REFRESH_TOKENS_FILE)
            # Start from the historical "<user>_<count + 1>" key, but skip
            # ahead while it is taken: after delete_user_tokens() shrinks the
            # mapping, count + 1 can equal the suffix of a surviving key and
            # would otherwise overwrite that token.
            sequence = len(tokens) + 1
            token_key = f"{username_lower}_{sequence}"
            while token_key in tokens:
                sequence += 1
                token_key = f"{username_lower}_{sequence}"
            tokens[token_key] = {
                "username": username_lower,
                "token": token,
                "created_at": datetime.now().isoformat(),
            }
            save_json_file(REFRESH_TOKENS_FILE, tokens)
            return True

    @staticmethod
    def delete_user_tokens(username: str) -> bool:
        """Remove every token stored for *username*; always returns True."""
        username_lower = username.lower()
        with _file_locks["refresh_tokens"]:
            tokens = load_json_file(REFRESH_TOKENS_FILE)
            to_remove = [k for k, v in tokens.items() if v.get("username") == username_lower]
            for k in to_remove:
                del tokens[k]
            save_json_file(REFRESH_TOKENS_FILE, tokens)
            return True
| # ========================================================= | |
| # INIT / DIAG | |
| # ========================================================= | |
def init_db():
    """Prepare the JSON database files.

    Steps, in order: make sure the data directory is usable, seed each file
    from ``SEED_DIR`` when seeding applies, create an empty ``{}`` file for
    any file that is still missing, then log the storage directory and the
    number of users.
    """
    ensure_data_directory()

    # Read the targets only now: ensure_data_directory() may have rebound
    # them to the fallback directory.
    db_files = (
        (USERS_FILE, "users.json"),
        (BLACKLISTED_TOKENS_FILE, "blacklisted_tokens.json"),
        (REFRESH_TOKENS_FILE, "refresh_tokens.json"),
    )

    # Pass 1: seed every file that needs it.
    for target, seed_name in db_files:
        _seed_file_if_needed(target, os.path.join(SEED_DIR, seed_name))

    # Pass 2: anything still missing starts out as an empty JSON object.
    for target, _seed_name in db_files:
        if os.path.exists(target):
            continue
        save_json_file(target, {})
        _log(f"[AUTH][INIT] Created empty {target}")

    # Startup summary line for the deployment logs.
    try:
        users_count = len(load_json_file(USERS_FILE))
        _log(f"[AUTH][READY] DB_DIR={DB_DIR} | users={users_count}")
    except Exception as exc:
        _log(f"[AUTH][WARN] Could not read users count: {exc}")
def ensure_database_initialized() -> bool:
    """Run ``init_db()`` at most once per process.

    Returns:
        False when ``RUN_INIT_DB`` is disabled (nothing is done); True
        otherwise, whether initialisation ran now or had already run.
    """
    global _db_init_done

    if not RUN_INIT_DB:
        return False
    if _db_init_done:
        # Already initialised: skip taking the lock.
        return True

    with _db_init_lock:
        # Re-check under the lock; another thread may have finished first.
        if not _db_init_done:
            init_db()
            _db_init_done = True
    return True
def get_database_info() -> Dict[str, Any]:
    """Build a diagnostic summary of the JSON database.

    Returns:
        A dict with the database type, storage and seed directories, a
        ``"files"`` entry per collection (existence, path, size in bytes and
        record count), and ``"connection_status"`` set to ``"ok"``, or to
        ``"error"`` with an ``"error"`` message if inspection raised.
    """
    files: Dict[str, Any] = {}
    info: Dict[str, Any] = {
        "database_type": "JSON",
        "storage_location": DB_DIR,
        "seed_location": SEED_DIR,
        "files": files,
    }
    tracked = (
        ("users", USERS_FILE),
        ("blacklisted_tokens", BLACKLISTED_TOKENS_FILE),
        ("refresh_tokens", REFRESH_TOKENS_FILE),
    )

    try:
        for label, path in tracked:
            if not os.path.exists(path):
                files[label] = {"exists": False}
                continue
            payload = load_json_file(path)
            files[label] = {
                "exists": True,
                "path": path,
                "size_bytes": os.path.getsize(path),
                # Only a JSON object has a meaningful record count.
                "count": len(payload) if isinstance(payload, dict) else 0,
            }
    except Exception as exc:
        info["connection_status"] = "error"
        info["error"] = str(exc)
    else:
        info["connection_status"] = "ok"
    return info
def test_database_connection() -> Tuple[bool, str]:
    """Check that the JSON database can be read and written.

    Ensures the data directory, loads the users file, then writes and removes
    a small ``.test`` file inside ``DB_DIR``.

    Returns:
        ``(True, message)`` when every step succeeds, otherwise
        ``(False, message)`` with the error text appended.
    """
    try:
        ensure_data_directory()
        JSONUsers.load_all()
        probe_path = os.path.join(DB_DIR, ".test")
        save_json_file(probe_path, {"test": "ok"})
        os.remove(probe_path)
    except Exception as exc:
        return False, f"JSON database connection failed: {exc}"
    return True, "JSON database connection successful"