# py-learn-backend / auth / json_database.py
"""
JSON file-based database module
Handles:
- JSON file storage for users, tokens, and blacklisted tokens
- User management and authentication
- Token blacklist management
- Refresh token storage
- Thread-safe JSON operations with file locking
- Hugging Face friendly seeding + writable fallback
"""
import os
import json
import threading
import shutil
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
from datetime import datetime
# =========================================================
# PATHS
# =========================================================
def _project_root() -> Path:
    """Return the project root directory (the parent of the ``auth`` package)."""
    # This file lives at <root>/auth/json_database.py, so go up two levels.
    return Path(__file__).resolve().parents[1]
# ---- Environment-driven configuration --------------------------------------
# Live storage dir; overridable via AUTH_DB_DIR.
DEFAULT_DB_DIR = str((_project_root() / "data").resolve())  # <project>/data
DB_DIR_ENV = os.getenv("AUTH_DB_DIR", DEFAULT_DB_DIR)
# Seed dir (where the repo's checked-in JSON files live); defaults to DB dir.
SEED_DIR_ENV = os.getenv("AUTH_SEED_DIR", DEFAULT_DB_DIR)
# Seeding controls: AUTH_SEED_ON_START=1 copies seeds for missing files;
# AUTH_SEED_IF_EMPTY=1 additionally re-seeds files that exist but are empty.
SEED_ON_START = os.getenv("AUTH_SEED_ON_START", "1") == "1"
SEED_IF_EMPTY = os.getenv("AUTH_SEED_IF_EMPTY", "1") == "1"
# Initialize DB files on startup (set RUN_INIT_DB=0 to skip entirely).
RUN_INIT_DB = os.getenv("RUN_INIT_DB", "1") == "1"
# Resolved absolute paths.
# NOTE: DB_DIR and the *_FILE paths below may be re-pointed at a /tmp
# fallback by ensure_data_directory() when the target is not writable.
DB_DIR = str(Path(DB_DIR_ENV).expanduser().resolve())
SEED_DIR = str(Path(SEED_DIR_ENV).expanduser().resolve())
USERS_FILE = os.path.join(DB_DIR, "users.json")
BLACKLISTED_TOKENS_FILE = os.path.join(DB_DIR, "blacklisted_tokens.json")
REFRESH_TOKENS_FILE = os.path.join(DB_DIR, "refresh_tokens.json")
# One lock per JSON file: serializes read-modify-write cycles across threads.
_file_locks = {
    "users": threading.Lock(),
    "blacklisted_tokens": threading.Lock(),
    "refresh_tokens": threading.Lock(),
}
# One-shot init guard used by ensure_database_initialized().
_db_init_done = False
_db_init_lock = threading.Lock()
# =========================================================
# HELPERS
# =========================================================
def _log(msg: str) -> None:
    """Emit *msg* to stdout immediately (shows up in Hugging Face container logs)."""
    print(msg, flush=True)
def ensure_data_directory():
    """
    Ensure the database directory exists and is writable.

    Creates DB_DIR and verifies it with a throwaway write probe. If that
    fails (e.g. a read-only filesystem on Hugging Face), the module-level
    DB_DIR and *_FILE paths are re-pointed at a writable /tmp fallback.
    """
    global DB_DIR, USERS_FILE, BLACKLISTED_TOKENS_FILE, REFRESH_TOKENS_FILE
    try:
        Path(DB_DIR).mkdir(parents=True, exist_ok=True)
        # Write probe: mkdir succeeding does not guarantee write permission.
        probe = os.path.join(DB_DIR, ".write_test")
        with open(probe, "w", encoding="utf-8") as fh:
            fh.write("ok")
        os.remove(probe)
        return
    except Exception as e:
        _log(f"[AUTH][WARN] DB_DIR not writable: {DB_DIR}. Error: {e}")
    # Fallback path: repoint every module-level path at /tmp.
    fallback = "/tmp/auth_db"
    _log(f"[AUTH][WARN] Falling back to: {fallback}")
    DB_DIR = fallback
    Path(DB_DIR).mkdir(parents=True, exist_ok=True)
    USERS_FILE = os.path.join(DB_DIR, "users.json")
    BLACKLISTED_TOKENS_FILE = os.path.join(DB_DIR, "blacklisted_tokens.json")
    REFRESH_TOKENS_FILE = os.path.join(DB_DIR, "refresh_tokens.json")
def load_json_file(filepath: str) -> Dict[str, Any]:
    """
    Read a JSON file and return its contents.

    Missing files and unreadable/corrupt JSON both yield an empty dict, so
    callers never have to special-case a first run or damaged storage.
    """
    path = Path(filepath)
    if not path.exists():
        return {}
    try:
        with path.open("r", encoding="utf-8") as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError) as e:
        _log(f"[AUTH][ERROR] Error loading {filepath}: {e}")
        return {}
def save_json_file(filepath: str, data: Dict[str, Any]):
    """
    Atomically write *data* as JSON to *filepath*.

    The payload is first written to a sibling ``.tmp`` file and then moved
    into place with os.replace(), so concurrent readers never observe a
    half-written file. On failure the temp file is removed and the original
    exception re-raised.

    Fix: guard the makedirs call — os.path.dirname() returns "" for a bare
    filename, and os.makedirs("") raises FileNotFoundError.
    """
    parent = os.path.dirname(filepath)
    if parent:  # bare filenames (cwd-relative) have no directory component
        os.makedirs(parent, exist_ok=True)
    tmp_path = filepath + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        # Atomic on POSIX; also replaces an existing destination on Windows.
        os.replace(tmp_path, filepath)
    except Exception as e:
        _log(f"[AUTH][ERROR] Error saving {filepath}: {e}")
        # Best-effort cleanup of the partial temp file; never mask the error.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except Exception:
            pass
        raise
def _file_is_empty_json(filepath: str) -> bool:
    """True when *filepath* is absent, unreadable, or holds an empty JSON object."""
    if not os.path.exists(filepath):
        return True
    try:
        return not load_json_file(filepath)
    except Exception:
        # load_json_file already swallows parse errors; this is belt-and-braces.
        return True
def _seed_file_if_needed(target_path: str, seed_path: str):
    """
    Seed *target_path* from *seed_path* when seeding is enabled.

    The copy happens when the target is missing, or when it exists but is
    empty and AUTH_SEED_IF_EMPTY=1. A missing seed file makes this a no-op.
    """
    if not SEED_ON_START:
        return
    if not os.path.exists(seed_path):
        return
    needs_seed = not os.path.exists(target_path) or (
        SEED_IF_EMPTY and _file_is_empty_json(target_path)
    )
    if not needs_seed:
        return
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    shutil.copy2(seed_path, target_path)
    _log(f"[AUTH][SEED] Copied seed: {seed_path} -> {target_path}")
# =========================================================
# STORAGE CLASSES
# =========================================================
class JSONUsers:
    """User storage backed by USERS_FILE (a dict keyed by lowercased username)."""

    @staticmethod
    def load_all() -> Dict[str, Any]:
        """Return the full username -> user-record mapping."""
        with _file_locks["users"]:
            return load_json_file(USERS_FILE)

    @staticmethod
    def find_by_username(username: str) -> Optional[Dict[str, Any]]:
        """Return the user record for *username* (case-insensitive), or None."""
        users = JSONUsers.load_all()
        return users.get(username.lower())

    @staticmethod
    def create_user(username: str, password_hash: str, role: str = "user") -> bool:
        """
        Create a new user; returns False if the username is already taken.

        Fix: the numeric id is one greater than the highest existing id rather
        than len(users) + 1, so ids stay unique even if records are ever
        removed from the store.
        """
        username_lower = username.lower()
        with _file_locks["users"]:
            users = load_json_file(USERS_FILE)
            if username_lower in users:
                return False
            next_id = max((u.get("id", 0) for u in users.values()), default=0) + 1
            users[username_lower] = {
                "id": next_id,
                "username": username_lower,
                "password_hash": password_hash,
                "role": role,
                # NOTE(review): naive local timestamp, matching the rest of
                # this module; consider UTC-aware timestamps project-wide.
                "created_at": datetime.now().isoformat(),
            }
            save_json_file(USERS_FILE, users)
            return True

    @staticmethod
    def get_all_users() -> List[Dict[str, Any]]:
        """Return a public projection (id, username, role) of every user."""
        users = JSONUsers.load_all()
        return [
            {"id": u.get("id"), "username": u.get("username"), "role": u.get("role")}
            for u in users.values()
        ]

    @staticmethod
    def promote_to_admin(username: str) -> bool:
        """Set *username*'s role to 'admin'; returns False if the user is unknown."""
        username_lower = username.lower()
        with _file_locks["users"]:
            users = load_json_file(USERS_FILE)
            if username_lower not in users:
                return False
            users[username_lower]["role"] = "admin"
            save_json_file(USERS_FILE, users)
            return True

    @staticmethod
    def user_count() -> int:
        """Return the number of stored users."""
        return len(JSONUsers.load_all())
class JSONBlacklistedTokens:
    """Revoked-token store backed by BLACKLISTED_TOKENS_FILE."""

    @staticmethod
    def load_all() -> Dict[str, Any]:
        """Return every blacklisted-token record, keyed by the token string."""
        with _file_locks["blacklisted_tokens"]:
            return load_json_file(BLACKLISTED_TOKENS_FILE)

    @staticmethod
    def is_blacklisted(token: str) -> bool:
        """True when *token* has been revoked."""
        return token in JSONBlacklistedTokens.load_all()

    @staticmethod
    def add_to_blacklist(token: str) -> bool:
        """Record *token* as revoked. Idempotent; always returns True."""
        with _file_locks["blacklisted_tokens"]:
            entries = load_json_file(BLACKLISTED_TOKENS_FILE)
            if token not in entries:
                entries[token] = {
                    "token": token,
                    "created_at": datetime.now().isoformat(),
                }
                save_json_file(BLACKLISTED_TOKENS_FILE, entries)
            return True
class JSONRefreshTokens:
    """Refresh-token store backed by REFRESH_TOKENS_FILE."""

    @staticmethod
    def load_all() -> Dict[str, Any]:
        """Return every refresh-token record, keyed by '<username>_<n>'."""
        with _file_locks["refresh_tokens"]:
            return load_json_file(REFRESH_TOKENS_FILE)

    @staticmethod
    def find_by_token(token: str) -> Optional[str]:
        """Return the username that owns *token*, or None if unknown."""
        for record in JSONRefreshTokens.load_all().values():
            if record.get("token") == token:
                return record.get("username")
        return None

    @staticmethod
    def create_token(username: str, token: str) -> bool:
        """
        Store a refresh token for *username*. Always returns True.

        Fix: the key suffix is probed upward until unused. The previous
        len(tokens) + 1 scheme could collide with a surviving key after
        delete_user_tokens() shrank the dict, silently overwriting another
        live token.
        """
        username_lower = username.lower()
        with _file_locks["refresh_tokens"]:
            tokens = load_json_file(REFRESH_TOKENS_FILE)
            n = len(tokens) + 1
            while f"{username_lower}_{n}" in tokens:
                n += 1
            tokens[f"{username_lower}_{n}"] = {
                "username": username_lower,
                "token": token,
                "created_at": datetime.now().isoformat(),
            }
            save_json_file(REFRESH_TOKENS_FILE, tokens)
            return True

    @staticmethod
    def delete_user_tokens(username: str) -> bool:
        """Remove every refresh token belonging to *username*. Always returns True."""
        username_lower = username.lower()
        with _file_locks["refresh_tokens"]:
            tokens = load_json_file(REFRESH_TOKENS_FILE)
            stale = [k for k, v in tokens.items() if v.get("username") == username_lower]
            for key in stale:
                del tokens[key]
            save_json_file(REFRESH_TOKENS_FILE, tokens)
            return True
# =========================================================
# INIT / DIAG
# =========================================================
def init_db():
    """
    Initialize the JSON database files.

    Order of operations:
      1. ensure the data directory exists (may repoint paths to /tmp),
      2. seed each file from AUTH_SEED_DIR when enabled and needed,
      3. create any file still missing as an empty JSON object,
      4. log a readiness line with the current user count.
    """
    ensure_data_directory()
    # Built AFTER ensure_data_directory() so a /tmp fallback is honored.
    targets = {
        USERS_FILE: os.path.join(SEED_DIR, "users.json"),
        BLACKLISTED_TOKENS_FILE: os.path.join(SEED_DIR, "blacklisted_tokens.json"),
        REFRESH_TOKENS_FILE: os.path.join(SEED_DIR, "refresh_tokens.json"),
    }
    for target, seed in targets.items():
        _seed_file_if_needed(target, seed)
    for target in targets:
        if not os.path.exists(target):
            save_json_file(target, {})
            _log(f"[AUTH][INIT] Created empty {target}")
    # Startup log (very useful in HF logs)
    try:
        users_count = len(load_json_file(USERS_FILE))
        _log(f"[AUTH][READY] DB_DIR={DB_DIR} | users={users_count}")
    except Exception as e:
        _log(f"[AUTH][WARN] Could not read users count: {e}")
def ensure_database_initialized() -> bool:
    """
    Run init_db() at most once per process (double-checked locking).

    Returns True when the database is initialized (now or previously),
    False when initialization is disabled via RUN_INIT_DB=0.
    """
    global _db_init_done
    if not RUN_INIT_DB:
        return False
    if _db_init_done:
        return True
    with _db_init_lock:
        # Re-check under the lock: another thread may have won the race.
        if not _db_init_done:
            init_db()
            _db_init_done = True
    return True
def get_database_info() -> Dict[str, Any]:
    """Return a diagnostics snapshot of the JSON database (paths, sizes, counts)."""
    info: Dict[str, Any] = {
        "database_type": "JSON",
        "storage_location": DB_DIR,
        "seed_location": SEED_DIR,
        "files": {},
    }

    def _describe(path: str, label: str):
        # Per-file stats; a missing file is reported rather than raising.
        if not os.path.exists(path):
            info["files"][label] = {"exists": False}
            return
        contents = load_json_file(path)
        info["files"][label] = {
            "exists": True,
            "path": path,
            "size_bytes": os.path.getsize(path),
            "count": len(contents) if isinstance(contents, dict) else 0,
        }

    try:
        _describe(USERS_FILE, "users")
        _describe(BLACKLISTED_TOKENS_FILE, "blacklisted_tokens")
        _describe(REFRESH_TOKENS_FILE, "refresh_tokens")
        info["connection_status"] = "ok"
    except Exception as e:
        info["connection_status"] = "error"
        info["error"] = str(e)
    return info
def test_database_connection() -> Tuple[bool, str]:
    """Probe the JSON store: directory writable, users readable, write round-trip."""
    try:
        ensure_data_directory()
        JSONUsers.load_all()  # confirms the users file is readable
        probe = os.path.join(DB_DIR, ".test")
        save_json_file(probe, {"test": "ok"})
        os.remove(probe)
    except Exception as e:
        return False, f"JSON database connection failed: {str(e)}"
    return True, "JSON database connection successful"