Nyan-Proxy / src /services /user_store.py
WasabiDrop's picture
Clean up excessive debug output and fix admin authentication
3149a5d
import uuid
import hashlib
import json
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Literal
from dataclasses import dataclass, asdict
from enum import Enum
from contextlib import contextmanager
try:
import firebase_admin
from firebase_admin import credentials, db
FIREBASE_AVAILABLE = True
except ImportError:
FIREBASE_AVAILABLE = False
print("Warning: Firebase not available. User data will be stored in memory only.")
from ..config.auth import auth_config
class ReadWriteLock:
"""A reader-writer lock implementation for better concurrency"""
def __init__(self):
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0
self._writers = 0
self._write_ready = threading.Condition(threading.RLock())
@contextmanager
def read_lock(self):
"""Acquire read lock"""
with self._read_ready:
while self._writers > 0:
self._read_ready.wait()
self._readers += 1
try:
yield
finally:
with self._read_ready:
self._readers -= 1
if self._readers == 0:
self._read_ready.notify_all()
@contextmanager
def write_lock(self):
"""Acquire write lock"""
with self._write_ready:
while self._writers > 0:
self._write_ready.wait()
self._writers += 1
with self._read_ready:
while self._readers > 0:
self._read_ready.wait()
try:
yield
finally:
with self._write_ready:
self._writers -= 1
self._write_ready.notify_all()
with self._read_ready:
self._read_ready.notify_all()
class UserType(Enum):
NORMAL = "normal"
SPECIAL = "special"
TEMPORARY = "temporary"
class AuthResult(Enum):
SUCCESS = "success"
DISABLED = "disabled"
NOT_FOUND = "not_found"
LIMITED = "limited"
@dataclass
class TokenCount:
input: int = 0
output: int = 0
total: int = 0
requests: int = 0
cost_usd: float = 0.0
first_request: Optional[datetime] = None
last_request: Optional[datetime] = None
def add_usage(self, input_tokens: int, output_tokens: int, cost: float = 0.0):
"""Add usage data and update counts"""
self.input += input_tokens
self.output += output_tokens
self.total += input_tokens + output_tokens
self.requests += 1
self.cost_usd += cost
now = datetime.now()
if self.first_request is None:
self.first_request = now
self.last_request = now
def to_dict(self):
return {
'input': self.input,
'output': self.output,
'total': self.total,
'requests': self.requests,
'cost_usd': self.cost_usd,
'first_request': self.first_request.isoformat() if self.first_request else None,
'last_request': self.last_request.isoformat() if self.last_request else None
}
@classmethod
def from_dict(cls, data: dict):
return cls(
input=data.get('input', 0),
output=data.get('output', 0),
total=data.get('total', 0),
requests=data.get('requests', 0),
cost_usd=data.get('cost_usd', 0.0),
first_request=datetime.fromisoformat(data['first_request']) if data.get('first_request') else None,
last_request=datetime.fromisoformat(data['last_request']) if data.get('last_request') else None
)
@dataclass
class IPUsage:
ip: str
prompt_count: int = 0
first_seen: Optional[datetime] = None
last_used: Optional[datetime] = None
total_requests: int = 0
total_tokens: int = 0
models_used: Dict[str, int] = None
user_agent: Optional[str] = None
is_suspicious: bool = False
def __post_init__(self):
if self.models_used is None:
self.models_used = {}
if self.first_seen is None:
self.first_seen = datetime.now()
def add_request(self, model_family: str, tokens: int, user_agent: str = None):
"""Add request data and update counts"""
self.prompt_count += 1
self.total_requests += 1
self.total_tokens += tokens
self.last_used = datetime.now()
if user_agent:
self.user_agent = user_agent
# Track model usage
if model_family not in self.models_used:
self.models_used[model_family] = 0
self.models_used[model_family] += 1
def to_dict(self):
return {
'ip': self.ip,
'prompt_count': self.prompt_count,
'first_seen': self.first_seen.isoformat() if self.first_seen else None,
'last_used': self.last_used.isoformat() if self.last_used else None,
'total_requests': self.total_requests,
'total_tokens': self.total_tokens,
'models_used': self.models_used,
'user_agent': self.user_agent,
'is_suspicious': self.is_suspicious
}
@classmethod
def from_dict(cls, data: dict):
return cls(
ip=data['ip'],
prompt_count=data.get('prompt_count', 0),
first_seen=datetime.fromisoformat(data['first_seen']) if data.get('first_seen') else None,
last_used=datetime.fromisoformat(data['last_used']) if data.get('last_used') else None,
total_requests=data.get('total_requests', 0),
total_tokens=data.get('total_tokens', 0),
models_used=data.get('models_used', {}),
user_agent=data.get('user_agent'),
is_suspicious=data.get('is_suspicious', False)
)
@dataclass
class User:
token: str
type: UserType = UserType.NORMAL
created_at: datetime = None
last_used: Optional[datetime] = None
disabled_at: Optional[datetime] = None
disabled_reason: Optional[str] = None
ip: List[str] = None
ip_usage: List[IPUsage] = None
token_counts: Dict[str, TokenCount] = None
token_limits: Dict[str, int] = None
nickname: Optional[str] = None
# Temporary user limits
prompt_limits: Optional[int] = None # Maximum number of prompts for temporary users
max_ips: Optional[int] = None # Maximum number of IPs for this user
# Enhanced tracking features
total_requests: int = 0
total_cost: float = 0.0
rate_limit_hits: int = 0
last_rate_limit: Optional[datetime] = None
suspicious_activity_count: int = 0
preferred_models: Dict[str, int] = None
average_tokens_per_request: float = 0.0
peak_requests_per_hour: int = 0
favorite_user_agent: Optional[str] = None
# Cat-themed status tracking
mood: str = "happy" # happy, sleepy, grumpy, playful
favorite_treat: Optional[str] = None
paw_prints: int = 0 # number of successful requests
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now()
if self.ip is None:
self.ip = []
if self.ip_usage is None:
self.ip_usage = []
if self.token_counts is None:
self.token_counts = {}
if self.token_limits is None:
# Default limits per model family
self.token_limits = {
'openai': 100000,
'anthropic': 100000,
'google': 100000,
'mistral': 100000
}
if self.preferred_models is None:
self.preferred_models = {}
# Initialize cat mood based on user type
if self.type == UserType.SPECIAL:
self.mood = "playful"
elif self.type == UserType.TEMPORARY:
self.mood = "sleepy"
else:
self.mood = "happy"
def is_disabled(self) -> bool:
return self.disabled_at is not None
def is_ip_limit_exceeded(self) -> bool:
# Use user-specific max_ips if set, otherwise use global setting
max_ips = self.max_ips if self.max_ips is not None else auth_config.max_ips_per_user
return len(self.ip) >= max_ips
def is_prompt_limit_exceeded(self) -> bool:
"""Check if user has exceeded their prompt limits (for temporary users)"""
if self.prompt_limits is None:
return False
total_prompts = sum(ip_usage.prompt_count for ip_usage in self.ip_usage)
return total_prompts >= self.prompt_limits
def get_remaining_prompts(self) -> Optional[int]:
"""Get remaining prompts for temporary users"""
if self.prompt_limits is None:
return None
total_prompts = sum(ip_usage.prompt_count for ip_usage in self.ip_usage)
return max(0, self.prompt_limits - total_prompts)
def get_token_usage(self, model_family: str) -> TokenCount:
return self.token_counts.get(model_family, TokenCount())
def get_token_limit(self, model_family: str) -> int:
return self.token_limits.get(model_family, 100000)
def add_request_tracking(self, model_family: str, input_tokens: int, output_tokens: int,
cost: float, ip_hash: str, user_agent: str = None):
"""Add comprehensive request tracking with cat-themed elements"""
# Update token counts
if model_family not in self.token_counts:
self.token_counts[model_family] = TokenCount()
self.token_counts[model_family].add_usage(input_tokens, output_tokens, cost)
# Update general stats
self.total_requests += 1
self.total_cost += cost
self.last_used = datetime.now()
self.paw_prints += 1 # Successful request = paw print!
# Update IP usage tracking
ip_usage = self.get_or_create_ip_usage(ip_hash)
ip_usage.add_request(model_family, input_tokens + output_tokens, user_agent)
# Update preferred models
if model_family not in self.preferred_models:
self.preferred_models[model_family] = 0
self.preferred_models[model_family] += 1
# Update average tokens per request
total_tokens = sum(count.total for count in self.token_counts.values())
self.average_tokens_per_request = total_tokens / self.total_requests if self.total_requests > 0 else 0
# Update mood based on usage patterns
self.update_mood()
# Update favorite treat based on most used model
most_used_model = max(self.preferred_models.items(), key=lambda x: x[1], default=('', 0))[0]
treat_map = {
'openai': 'Fish treats',
'anthropic': 'Catnip',
'google': 'Tuna',
'mistral': 'Chicken bits'
}
self.favorite_treat = treat_map.get(most_used_model, 'Generic cat treats')
def get_or_create_ip_usage(self, ip_hash: str) -> IPUsage:
"""Get or create IP usage tracking for a given IP hash"""
for ip_usage in self.ip_usage:
if ip_usage.ip == ip_hash:
return ip_usage
# Create new IP usage if not found
new_ip_usage = IPUsage(ip=ip_hash)
self.ip_usage.append(new_ip_usage)
# Also add to IP list if not already there
if ip_hash not in self.ip:
self.ip.append(ip_hash)
return new_ip_usage
def update_mood(self):
"""Update cat mood based on usage patterns"""
if self.is_disabled():
self.mood = "grumpy"
elif self.rate_limit_hits > 10:
self.mood = "grumpy"
elif self.total_requests > 1000:
self.mood = "playful"
elif self.type == UserType.TEMPORARY:
self.mood = "sleepy"
else:
self.mood = "happy"
def get_mood_emoji(self) -> str:
"""Get emoji representation of cat mood"""
mood_emojis = {
"happy": "😸",
"sleepy": "😴",
"grumpy": "😾",
"playful": "😼"
}
return mood_emojis.get(self.mood, "😺")
def add_rate_limit_hit(self):
"""Track rate limit violations"""
self.rate_limit_hits += 1
self.last_rate_limit = datetime.now()
self.update_mood()
def mark_suspicious_activity(self):
"""Mark suspicious activity for this user"""
self.suspicious_activity_count += 1
if self.suspicious_activity_count > 5:
self.mood = "grumpy"
def to_dict(self) -> dict:
return {
'token': self.token,
'type': self.type.value,
'created_at': self.created_at.isoformat(),
'last_used': self.last_used.isoformat() if self.last_used else None,
'disabled_at': self.disabled_at.isoformat() if self.disabled_at else None,
'disabled_reason': self.disabled_reason,
'ip': self.ip,
'ip_usage': [usage.to_dict() for usage in self.ip_usage],
'token_counts': {k: v.to_dict() for k, v in self.token_counts.items()},
'token_limits': self.token_limits,
'nickname': self.nickname,
'prompt_limits': self.prompt_limits,
'max_ips': self.max_ips,
# Enhanced tracking fields
'total_requests': self.total_requests,
'total_cost': self.total_cost,
'rate_limit_hits': self.rate_limit_hits,
'last_rate_limit': self.last_rate_limit.isoformat() if self.last_rate_limit else None,
'suspicious_activity_count': self.suspicious_activity_count,
'preferred_models': self.preferred_models,
'average_tokens_per_request': self.average_tokens_per_request,
'peak_requests_per_hour': self.peak_requests_per_hour,
'favorite_user_agent': self.favorite_user_agent,
# Cat-themed fields
'mood': self.mood,
'mood_emoji': self.get_mood_emoji(),
'favorite_treat': self.favorite_treat,
'paw_prints': self.paw_prints
}
@classmethod
def from_dict(cls, data: dict):
# Handle backward compatibility for created_at
created_at = data.get('created_at')
if created_at:
created_at = datetime.fromisoformat(created_at)
else:
# Use current time for legacy users without created_at
created_at = datetime.now()
# Handle backward compatibility for token_counts
token_counts = {}
for k, v in data.get('token_counts', {}).items():
if isinstance(v, dict) and 'add_usage' not in v:
# New format with to_dict/from_dict
token_counts[k] = TokenCount.from_dict(v)
else:
# Old format - convert to new format
token_counts[k] = TokenCount(
input=v.get('input', 0),
output=v.get('output', 0),
total=v.get('total', 0),
requests=v.get('requests', 0),
cost_usd=v.get('cost_usd', 0.0)
)
return cls(
token=data['token'],
type=UserType(data.get('type', 'normal')),
created_at=created_at,
last_used=datetime.fromisoformat(data['last_used']) if data.get('last_used') else None,
disabled_at=datetime.fromisoformat(data['disabled_at']) if data.get('disabled_at') else None,
disabled_reason=data.get('disabled_reason'),
ip=data.get('ip', []),
ip_usage=[IPUsage.from_dict(usage) for usage in data.get('ip_usage', [])],
token_counts=token_counts,
token_limits=data.get('token_limits', {}),
nickname=data.get('nickname'),
prompt_limits=data.get('prompt_limits'),
max_ips=data.get('max_ips'),
# Enhanced tracking fields with backward compatibility
total_requests=data.get('total_requests', 0),
total_cost=data.get('total_cost', 0.0),
rate_limit_hits=data.get('rate_limit_hits', 0),
last_rate_limit=datetime.fromisoformat(data['last_rate_limit']) if data.get('last_rate_limit') else None,
suspicious_activity_count=data.get('suspicious_activity_count', 0),
preferred_models=data.get('preferred_models', {}),
average_tokens_per_request=data.get('average_tokens_per_request', 0.0),
peak_requests_per_hour=data.get('peak_requests_per_hour', 0),
favorite_user_agent=data.get('favorite_user_agent'),
# Cat-themed fields
mood=data.get('mood', 'happy'),
favorite_treat=data.get('favorite_treat'),
paw_prints=data.get('paw_prints', 0)
)
class UserStore:
def __init__(self):
self.users: Dict[str, User] = {}
self.lock = ReadWriteLock() # Use reader-writer lock for better concurrency
self.firebase_db = None
self.flush_queue = set()
self.flush_queue_lock = threading.Lock() # Separate lock for flush queue
self.cleanup_thread = None
# Initialize Firebase if available
if FIREBASE_AVAILABLE and auth_config.firebase_url:
self._initialize_firebase()
# Start cleanup thread
self._start_cleanup_thread()
def _initialize_firebase(self):
"""Initialize Firebase connection"""
try:
# Initializing Firebase
if auth_config.firebase_service_account_key:
# Use service account key
# Using service account key
key_str = auth_config.firebase_service_account_key
decoded_key = json.loads(key_str)
cred = credentials.Certificate(decoded_key)
else:
# Use default credentials
# Using default credentials
cred = credentials.ApplicationDefault()
if not firebase_admin._apps:
# Initializing Firebase app
firebase_admin.initialize_app(cred, {
'databaseURL': auth_config.firebase_url
})
self.firebase_db = db.reference()
# Firebase initialized successfully
# Load existing users from Firebase
self._load_users_from_firebase()
except Exception as e:
print(f"Failed to initialize Firebase: {e}")
self.firebase_db = None
def _load_users_from_firebase(self):
"""Load existing users from Firebase"""
try:
if not self.firebase_db:
return
users_ref = self.firebase_db.child('users')
users_data = users_ref.get()
if users_data:
with self.lock.write_lock():
for token, user_data in users_data.items():
try:
user = User.from_dict(user_data)
self.users[token] = user
except Exception as e:
print(f"Failed to load user {token}: {e}")
# Loaded users from Firebase
except Exception as e:
print(f"Failed to load users from Firebase: {e}")
def _sanitize_firebase_key(self, key: str) -> str:
"""Sanitize key for Firebase (remove invalid characters)"""
return key.replace('.', '_').replace('#', '_').replace('$', '_').replace('[', '_').replace(']', '_')
def _flush_to_firebase(self, token: str):
"""Flush user to Firebase"""
if not self.firebase_db:
return
try:
sanitized_token = self._sanitize_firebase_key(token)
if token in self.users:
# Update user
user_ref = self.firebase_db.child('users').child(sanitized_token)
user_ref.set(self.users[token].to_dict())
else:
# Delete user
user_ref = self.firebase_db.child('users').child(sanitized_token)
user_ref.delete()
except Exception as e:
print(f"Failed to flush user {token} to Firebase: {e}")
def _batch_flush_to_firebase(self, tokens: List[str]):
"""Batch flush multiple users to Firebase for better performance"""
if not self.firebase_db or not tokens:
return
try:
# Prepare batch update data
batch_data = {}
to_delete = []
for token in tokens:
sanitized_token = self._sanitize_firebase_key(token)
if token in self.users:
user = self.users[token]
batch_data[f'users/{sanitized_token}'] = user.to_dict()
else:
# User was deleted
to_delete.append(sanitized_token)
# Perform batch update for existing users
if batch_data:
self.firebase_db.update(batch_data)
# Delete users that were removed
for sanitized_token in to_delete:
user_ref = self.firebase_db.child('users').child(sanitized_token)
user_ref.delete()
except Exception as e:
print(f"Failed to batch flush users to Firebase: {e}")
# Fallback to individual flushes
for token in tokens:
self._flush_to_firebase(token)
def _start_cleanup_thread(self):
"""Start background cleanup thread"""
def cleanup_loop():
while True:
try:
# Flush pending changes to Firebase in batches
with self.flush_queue_lock:
flush_tokens = list(self.flush_queue.copy())
self.flush_queue.clear()
if flush_tokens:
# Process in batches for better performance
batch_size = 10
for i in range(0, len(flush_tokens), batch_size):
batch = flush_tokens[i:i + batch_size]
self._batch_flush_to_firebase(batch)
# Clean up expired temporary users
self._cleanup_expired_users()
time.sleep(60) # Run every minute
except Exception as e:
print(f"Cleanup thread error: {e}")
time.sleep(60)
self.cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
self.cleanup_thread.start()
def _cleanup_expired_users(self):
"""Clean up expired temporary users"""
with self.lock.write_lock():
to_disable = []
for token, user in self.users.items():
if user.type == UserType.TEMPORARY and not user.is_disabled():
# Check if temporary user has expired (24 hours)
if user.created_at < datetime.now() - timedelta(hours=24):
to_disable.append(token)
for token in to_disable:
self._disable_user_internal(token, "Temporary token expired")
def _hash_ip(self, ip: str) -> str:
"""Hash IP address for privacy"""
return hashlib.sha256(ip.encode()).hexdigest()
def _disable_user_internal(self, token: str, reason: str):
"""Internal method to disable user (assumes lock is held)"""
if token in self.users:
user = self.users[token]
user.disabled_at = datetime.now()
user.disabled_reason = reason
with self.flush_queue_lock:
self.flush_queue.add(token)
def create_user(self, user_type: UserType = UserType.NORMAL,
token_limits: Optional[Dict[str, int]] = None,
nickname: Optional[str] = None,
prompt_limits: Optional[int] = None,
max_ips: Optional[int] = None) -> str:
"""Create a new user and return their token"""
with self.lock.write_lock():
token = str(uuid.uuid4())
user = User(
token=token,
type=user_type,
nickname=nickname,
prompt_limits=prompt_limits,
max_ips=max_ips
)
if token_limits:
user.token_limits.update(token_limits)
self.users[token] = user
with self.flush_queue_lock:
self.flush_queue.add(token)
return token
def authenticate(self, token: str, ip: str) -> Tuple[AuthResult, Optional[User]]:
"""Authenticate user token and track IP"""
with self.lock.write_lock():
if token not in self.users:
return AuthResult.NOT_FOUND, None
user = self.users[token]
# Check if user is disabled
if user.is_disabled():
return AuthResult.DISABLED, user
# Check if temporary user has exceeded prompt limits
if user.type == UserType.TEMPORARY and user.is_prompt_limit_exceeded():
self._disable_user_internal(token, "Prompt limit exceeded")
return AuthResult.DISABLED, user
# Hash IP for privacy
hashed_ip = self._hash_ip(ip)
# Check if IP limit exceeded
if hashed_ip not in user.ip and user.is_ip_limit_exceeded():
if auth_config.max_ips_auto_ban:
self._disable_user_internal(token, "IP address limit exceeded")
return AuthResult.DISABLED, user
else:
return AuthResult.LIMITED, user
# Add IP if not already tracked
if hashed_ip not in user.ip:
user.ip.append(hashed_ip)
user.ip_usage.append(IPUsage(ip=hashed_ip, last_used=datetime.now()))
else:
# Update existing IP usage
for usage in user.ip_usage:
if usage.ip == hashed_ip:
usage.last_used = datetime.now()
break
# Update last used
user.last_used = datetime.now()
with self.flush_queue_lock:
self.flush_queue.add(token)
return AuthResult.SUCCESS, user
def increment_token_count(self, token: str, model_family: str,
input_tokens: int, output_tokens: int, cost: float = 0.0) -> bool:
"""Increment token count for user and model family (legacy method)"""
with self.lock.write_lock():
if token not in self.users:
return False
user = self.users[token]
if user.is_disabled():
return False
# Get or create token count for model family
if model_family not in user.token_counts:
user.token_counts[model_family] = TokenCount()
# Use the new add_usage method
user.token_counts[model_family].add_usage(input_tokens, output_tokens, cost)
# Update general user stats
user.total_requests += 1
user.total_cost += cost
user.last_used = datetime.now()
user.paw_prints += 1
# Update IP usage if we can determine current IP
# Note: This is a fallback method, prefer using add_request_tracking
if user.ip_usage and len(user.ip_usage) > 0:
# Update the most recently used IP
latest_ip_usage = max(user.ip_usage, key=lambda x: x.last_used or datetime.min)
latest_ip_usage.add_request(model_family, input_tokens + output_tokens)
with self.flush_queue_lock:
self.flush_queue.add(token)
return True
def disable_user(self, token: str, reason: str = None) -> bool:
"""Disable a user account"""
with self.lock.write_lock():
if token not in self.users:
return False
self._disable_user_internal(token, reason or "Disabled by admin")
return True
def reactivate_user(self, token: str) -> bool:
"""Reactivate a disabled user"""
with self.lock.write_lock():
if token not in self.users:
return False
user = self.users[token]
user.disabled_at = None
user.disabled_reason = None
with self.flush_queue_lock:
self.flush_queue.add(token)
return True
def delete_user(self, token: str) -> bool:
"""Permanently delete a user (must be disabled first)"""
with self.lock.write_lock():
if token not in self.users:
return False
user = self.users[token]
if not user.is_disabled():
raise ValueError("User must be disabled before deletion")
del self.users[token]
with self.flush_queue_lock:
self.flush_queue.add(token) # This will trigger Firebase deletion
return True
def rotate_user_token(self, old_token: str) -> Optional[str]:
"""Rotate user token (generate new token, preserve data)"""
with self.lock.write_lock():
if old_token not in self.users:
return None
user = self.users[old_token]
new_token = str(uuid.uuid4())
# Update token in user object
user.token = new_token
# Move user to new token key
self.users[new_token] = user
del self.users[old_token]
# Queue both for Firebase update
self.flush_queue.add(old_token) # Delete old
self.flush_queue.add(new_token) # Add new
return new_token
def get_user(self, token: str) -> Optional[User]:
"""Get user by token"""
with self.lock.read_lock():
return self.users.get(token)
def get_all_users(self) -> List[User]:
"""Get all users"""
with self.lock.read_lock():
return list(self.users.values())
def get_user_count(self) -> int:
"""Get total user count"""
with self.lock.read_lock():
return len(self.users)
def check_quota(self, token: str, model_family: str) -> Tuple[bool, int, int]:
"""Check if user has quota remaining. Returns (has_quota, used, limit)"""
with self.lock.read_lock():
if token not in self.users:
return False, 0, 0
user = self.users[token]
if user.is_disabled():
return False, 0, 0
used = user.get_token_usage(model_family).total
limit = user.get_token_limit(model_family)
return used < limit, used, limit
# Global user store instance
user_store = UserStore()