|
|
""" |
|
|
Enhanced User Model with Authentication and VPN Client Management |
|
|
|
|
|
This module provides comprehensive user management with security features, |
|
|
VPN client management, and session tracking capabilities. |
|
|
""" |
|
|
|
|
|
|
|
|
from werkzeug.security import generate_password_hash, check_password_hash |
|
|
from datetime import datetime, timedelta |
|
|
import secrets |
|
|
import jwt |
|
|
import re |
|
|
from flask import current_app |
|
|
|
|
|
from .user import db |
|
|
|
|
|
class User(db.Model):
    """Enhanced User model with authentication and VPN management.

    Holds login credentials, account-security state (lockout counters, email
    verification, password reset, 2FA fields), subscription limits, and the
    user's VPN clients and sessions.
    """

    __tablename__ = 'users'

    # Consecutive failed logins that trigger a temporary account lock.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long the account stays locked once the threshold is reached.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Maximum number of *active* VPN clients per subscription tier.
    # Tiers not listed here may not create any clients.
    VPN_CLIENT_LIMITS = {'free': 1, 'premium': 5, 'enterprise': 50}

    # Used with fullmatch(): with re.match, `$` would also accept a trailing "\n".
    _EMAIL_RE = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
    _USERNAME_RE = re.compile(r'^[a-zA-Z0-9_-]+$')

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    # Extra per-user salt appended to the plaintext before hashing (see set_password).
    salt = db.Column(db.String(32), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_login = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    is_admin = db.Column(db.Boolean, default=False)
    subscription_type = db.Column(db.String(20), default='free')
    subscription_expires = db.Column(db.DateTime)
    max_concurrent_connections = db.Column(db.Integer, default=1)
    bandwidth_limit_mbps = db.Column(db.Integer, default=10)
    email_verified = db.Column(db.Boolean, default=False)
    email_verification_token = db.Column(db.String(64))
    two_factor_enabled = db.Column(db.Boolean, default=False)
    two_factor_secret = db.Column(db.String(32))
    password_reset_token = db.Column(db.String(64))
    password_reset_expires = db.Column(db.DateTime)
    failed_login_attempts = db.Column(db.Integer, default=0)
    account_locked_until = db.Column(db.DateTime)

    vpn_clients = db.relationship('VPNClient', backref='user', lazy=True, cascade='all, delete-orphan')
    vpn_sessions = db.relationship('VPNSession', backref='user', lazy=True)

    def __init__(self, username, email, password=None):
        """Create a user and issue a fresh email-verification token.

        If ``password`` is omitted, ``password_hash`` and ``salt`` stay unset;
        both columns are non-nullable, so set_password() must be called
        before the row is flushed.
        """
        self.username = username
        self.email = email
        if password:
            self.set_password(password)
        self.email_verification_token = secrets.token_urlsafe(32)

    def set_password(self, password):
        """Hash and store ``password`` and clear any lockout state.

        Raises:
            ValueError: if the password fails validate_password_strength().
        """
        if not self.validate_password_strength(password):
            raise ValueError("Password does not meet security requirements")

        # werkzeug's generate_password_hash salts internally as well. The extra
        # salt is kept so that existing stored hashes continue to verify.
        self.salt = secrets.token_hex(16)
        self.password_hash = generate_password_hash(password + self.salt)
        self.failed_login_attempts = 0
        self.account_locked_until = None

    def check_password(self, password):
        """Return True if ``password`` matches the stored hash.

        Returns False without checking while the account is locked. A
        successful check resets the failure counter and stamps ``last_login``.
        A failed check increments the counter and locks the account for
        LOCKOUT_DURATION once it reaches MAX_FAILED_LOGIN_ATTEMPTS.
        """
        if self.is_account_locked():
            return False

        # No credentials set yet (constructed without a password): nothing can match.
        if not self.password_hash or self.salt is None:
            return False

        is_valid = isinstance(password, str) and check_password_hash(
            self.password_hash, password + self.salt)

        if is_valid:
            self.failed_login_attempts = 0
            self.last_login = datetime.utcnow()
        else:
            # The column default (0) is only applied on INSERT, so the counter
            # may still be None on an unflushed instance.
            self.failed_login_attempts = (self.failed_login_attempts or 0) + 1
            if self.failed_login_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
                self.account_locked_until = datetime.utcnow() + self.LOCKOUT_DURATION

        return bool(is_valid)

    def is_account_locked(self):
        """Return True while the account is inside its lockout window.

        Side effect: once the window has passed, the lock and the failure
        counter are cleared on this instance.
        """
        if self.account_locked_until is None:
            return False
        if datetime.utcnow() < self.account_locked_until:
            return True
        # The lock has expired; give the user a fresh set of attempts.
        self.account_locked_until = None
        self.failed_login_attempts = 0
        return False

    @staticmethod
    def validate_password_strength(password):
        """Return True if ``password`` meets the security requirements.

        A valid password is a string of at least 8 characters that contains
        an uppercase letter, a lowercase letter, a digit, and one of the
        special characters ``!@#$%^&*(),.?":{}|<>``.
        """
        if not isinstance(password, str) or len(password) < 8:
            return False
        if not re.search(r'[A-Z]', password):
            return False
        if not re.search(r'[a-z]', password):
            return False
        if not re.search(r'\d', password):
            return False
        if not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
            return False
        return True

    @staticmethod
    def validate_email(email):
        """Return True if ``email`` looks like ``local@domain.tld`` (format check only)."""
        if not isinstance(email, str):
            return False
        return User._EMAIL_RE.fullmatch(email) is not None

    @staticmethod
    def validate_username(username):
        """Return True for 3-80 characters drawn from letters, digits, ``_`` and ``-``."""
        if not isinstance(username, str) or not 3 <= len(username) <= 80:
            return False
        return User._USERNAME_RE.fullmatch(username) is not None

    def generate_auth_token(self, expires_in=3600):
        """Return an HS256 JWT access token valid for ``expires_in`` seconds."""
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'username': self.username,
            'email': self.email,
            'subscription_type': self.subscription_type,
            'is_admin': self.is_admin,
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    def generate_refresh_token(self, expires_in=2592000):
        """Return an HS256 JWT refresh token (default lifetime: 30 days).

        The token carries ``type: 'refresh'`` so verify_auth_token() rejects it.
        """
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'type': 'refresh',
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    @staticmethod
    def _decode_token(token):
        """Verify and decode an HS256 JWT signed with the app's SECRET_KEY.

        Returns the payload, or None if the token is expired, malformed, or
        wrongly signed.
        """
        try:
            return jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
        except jwt.InvalidTokenError:  # also covers ExpiredSignatureError
            return None

    @staticmethod
    def verify_auth_token(token):
        """Return the User for a valid access token, or None.

        Refresh tokens, and tokens without a ``user_id`` claim, are rejected.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') == 'refresh':
            return None
        user_id = payload.get('user_id')
        return User.query.get(user_id) if user_id is not None else None

    @staticmethod
    def verify_refresh_token(token):
        """Return the User for a valid refresh token, or None.

        Tokens not marked ``type: 'refresh'`` are rejected.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') != 'refresh':
            return None
        user_id = payload.get('user_id')
        return User.query.get(user_id) if user_id is not None else None

    @staticmethod
    def _tokens_match(expected, supplied):
        """Compare a stored token with a supplied one in constant time.

        An unset stored token never matches, so a ``None`` or empty token
        cannot pass verification.
        """
        if not expected or not isinstance(supplied, str) or not supplied:
            return False
        # Compare bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(expected.encode('utf-8'), supplied.encode('utf-8'))

    def generate_password_reset_token(self):
        """Issue, store, and return a password-reset token valid for one hour."""
        self.password_reset_token = secrets.token_urlsafe(32)
        self.password_reset_expires = datetime.utcnow() + timedelta(hours=1)
        return self.password_reset_token

    def verify_password_reset_token(self, token):
        """Return True if ``token`` matches the stored reset token and it has not expired."""
        return bool(
            self._tokens_match(self.password_reset_token, token)
            and self.password_reset_expires
            and datetime.utcnow() < self.password_reset_expires
        )

    def reset_password(self, new_password, token):
        """Set ``new_password`` if ``token`` is valid, then invalidate the token.

        Returns False if the token is invalid or expired.

        Raises:
            ValueError: if ``new_password`` is too weak. The token is left
                intact so the user can retry.
        """
        if not self.verify_password_reset_token(token):
            return False

        self.set_password(new_password)
        self.password_reset_token = None
        self.password_reset_expires = None
        return True

    def verify_email(self, token):
        """Mark the email as verified if ``token`` matches; the token is single-use."""
        if self._tokens_match(self.email_verification_token, token):
            self.email_verified = True
            self.email_verification_token = None
            return True
        return False

    def _count_active_clients(self):
        """Number of this user's VPN clients whose ``is_active`` flag is set."""
        return sum(1 for client in self.vpn_clients if client.is_active)

    def can_create_vpn_client(self):
        """Return True if the user's tier allows another active VPN client.

        NOTE(review): the subscription's expiry is not consulted here. An
        expired paid tier keeps its client limit; confirm this is intended.
        """
        limit = self.VPN_CLIENT_LIMITS.get(self.subscription_type)
        if limit is None:
            return False
        return self._count_active_clients() < limit

    def get_active_sessions_count(self):
        """Count VPN sessions that have not been disconnected."""
        return sum(1 for session in self.vpn_sessions if session.disconnected_at is None)

    def can_connect_vpn(self):
        """Return True if another concurrent VPN session is allowed."""
        limit = self.max_concurrent_connections
        if limit is None:
            # Column default is only applied on INSERT; fall back to it here.
            limit = 1
        return self.get_active_sessions_count() < limit

    def get_bandwidth_usage_today(self):
        """Total bytes (sent + received) for sessions that connected today (UTC)."""
        today = datetime.utcnow().date()
        # Byte counters may still be None on sessions not yet flushed.
        return sum(
            (session.bytes_received or 0) + (session.bytes_sent or 0)
            for session in self.vpn_sessions
            if session.connected_at and session.connected_at.date() == today
        )

    def is_subscription_active(self):
        """Return True for free users, or for paid users whose expiry lies in the future."""
        if self.subscription_type == 'free':
            return True
        return self.subscription_expires is not None and datetime.utcnow() < self.subscription_expires

    def to_dict(self, include_sensitive=False):
        """Serialize the user to a JSON-friendly dict.

        ``include_sensitive`` adds the lockout/admin/bandwidth fields. This
        happens when the serialized user is an admin, or when the flag equals
        the string ``'self'``. Calling this may clear an expired account lock
        (via is_account_locked()).
        """
        data = {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'last_login': self.last_login.isoformat() if self.last_login else None,
            'is_active': self.is_active,
            'subscription_type': self.subscription_type,
            'subscription_expires': self.subscription_expires.isoformat() if self.subscription_expires else None,
            'max_concurrent_connections': self.max_concurrent_connections,
            'bandwidth_limit_mbps': self.bandwidth_limit_mbps,
            'email_verified': self.email_verified,
            'two_factor_enabled': self.two_factor_enabled,
            'is_subscription_active': self.is_subscription_active(),
            'active_vpn_clients': self._count_active_clients(),
            'active_sessions': self.get_active_sessions_count(),
            'can_create_vpn_client': self.can_create_vpn_client(),
            'can_connect_vpn': self.can_connect_vpn(),
        }

        # NOTE(review): `self.is_admin` is the *serialized* user's flag, not the
        # requester's. Confirm this gate is intended.
        if include_sensitive and (self.is_admin or include_sensitive == 'self'):
            data.update({
                'is_admin': self.is_admin,
                'failed_login_attempts': self.failed_login_attempts,
                'account_locked': self.is_account_locked(),
                'bandwidth_usage_today': self.get_bandwidth_usage_today(),
            })

        return data
|
|
|
|
|
|
|
|
class VPNClient(db.Model):
    """VPN client configuration and management.

    One row per client a user provisions. Sessions opened with the client
    are reachable through ``sessions``.
    """

    __tablename__ = 'vpn_clients'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_name = db.Column(db.String(100), nullable=False)
    protocol = db.Column(db.String(20), nullable=False)
    certificate_serial = db.Column(db.String(50), unique=True)
    private_key_path = db.Column(db.String(255))
    certificate_path = db.Column(db.String(255))
    config_file_path = db.Column(db.String(255))
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_connected = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    device_type = db.Column(db.String(50))
    public_key = db.Column(db.Text)

    # NOTE(review): this relationship has no delete cascade, while
    # vpn_sessions.client_id is NOT NULL. Deleting a client (e.g. through the
    # User.vpn_clients cascade) will likely try to null the FK on its sessions
    # and fail. Confirm the intended retention policy.
    sessions = db.relationship('VPNSession', backref='vpn_client', lazy=True)

    def __init__(self, user_id, client_name, protocol, device_type=None):
        """Create a client record for ``user_id``; other columns keep their defaults."""
        self.user_id = user_id
        self.client_name = client_name
        self.protocol = protocol
        self.device_type = device_type

    def get_active_sessions(self):
        """Return this client's sessions that have not been disconnected."""
        return [s for s in self.sessions if s.disconnected_at is None]

    def get_total_bandwidth_usage(self):
        """Total bytes (sent + received) across all of this client's sessions."""
        # Byte counters receive their column default (0) only on INSERT, so
        # sessions not yet flushed may still hold None.
        return sum((s.bytes_received or 0) + (s.bytes_sent or 0) for s in self.sessions)

    def to_dict(self):
        """Serialize the client (without key/certificate file paths) to a dict."""
        return {
            'id': self.id,
            'client_name': self.client_name,
            'protocol': self.protocol,
            'device_type': self.device_type,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'last_connected': self.last_connected.isoformat() if self.last_connected else None,
            'is_active': self.is_active,
            'certificate_serial': self.certificate_serial,
            'active_sessions': len(self.get_active_sessions()),
            'total_bandwidth_usage': self.get_total_bandwidth_usage(),
        }
|
|
|
|
|
|
|
|
class VPNSession(db.Model):
    """A single VPN connection: who connected, from where, and how much traffic flowed."""

    __tablename__ = 'vpn_sessions'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_id = db.Column(db.Integer, db.ForeignKey('vpn_clients.id'), nullable=False)
    server_protocol = db.Column(db.String(20), nullable=False)
    # NOTE(review): 15 characters fits dotted IPv4 only; client_real_ip uses 45
    # (IPv6 length). Confirm tunnel addresses are always IPv4.
    client_ip = db.Column(db.String(15))
    server_ip = db.Column(db.String(15))
    client_real_ip = db.Column(db.String(45))
    connected_at = db.Column(db.DateTime, default=datetime.utcnow)
    disconnected_at = db.Column(db.DateTime)
    bytes_received = db.Column(db.BigInteger, default=0)
    bytes_sent = db.Column(db.BigInteger, default=0)
    # Seconds between connected_at and disconnected_at, filled in by disconnect().
    session_duration = db.Column(db.Integer)
    disconnect_reason = db.Column(db.String(100))

    def __init__(self, user_id, client_id, server_protocol, client_ip=None, server_ip=None, client_real_ip=None):
        """Open a session record; ``connected_at`` receives its default on insert."""
        self.user_id = user_id
        self.client_id = client_id
        self.server_protocol = server_protocol
        self.client_ip = client_ip
        self.server_ip = server_ip
        self.client_real_ip = client_real_ip

    def disconnect(self, reason=None):
        """Mark the session as disconnected and record its duration.

        A session that is already disconnected is left untouched. This keeps
        the original disconnect time, reason, and duration from being
        overwritten (and the duration from being inflated).
        """
        if self.disconnected_at is not None:
            return
        now = datetime.utcnow()
        self.disconnected_at = now
        self.disconnect_reason = reason
        if self.connected_at:
            self.session_duration = int((now - self.connected_at).total_seconds())

    def is_active(self):
        """Return True while the session has not been disconnected."""
        return self.disconnected_at is None

    def get_duration(self):
        """Return the session length in whole seconds; always an int.

        For a live session, the elapsed time so far. For a disconnected
        session, the recorded duration, or, if none was recorded, the
        non-negative gap between the timestamps when both are known, else 0.
        """
        if self.disconnected_at:
            if self.session_duration is not None:
                return self.session_duration
            if self.connected_at:
                return max(0, int((self.disconnected_at - self.connected_at).total_seconds()))
            return 0
        if self.connected_at:
            return int((datetime.utcnow() - self.connected_at).total_seconds())
        return 0

    def to_dict(self):
        """Serialize the session to a JSON-friendly dict."""
        return {
            'id': self.id,
            'client_id': self.client_id,
            'server_protocol': self.server_protocol,
            'client_ip': self.client_ip,
            'server_ip': self.server_ip,
            'client_real_ip': self.client_real_ip,
            'connected_at': self.connected_at.isoformat() if self.connected_at else None,
            'disconnected_at': self.disconnected_at.isoformat() if self.disconnected_at else None,
            'bytes_received': self.bytes_received,
            'bytes_sent': self.bytes_sent,
            'session_duration': self.get_duration(),
            'disconnect_reason': self.disconnect_reason,
            'is_active': self.is_active(),
        }
|
|
|
|
|
|
|
|
class ServerConfiguration(db.Model):
    """Persisted settings for one VPN server endpoint.

    Covers the protocol, listen port, tunnel network, optional DNS/route
    text, and a client cap.
    """

    __tablename__ = 'server_configurations'

    id = db.Column(db.Integer, primary_key=True)
    protocol = db.Column(db.String(20), nullable=False)
    server_name = db.Column(db.String(100), nullable=False)
    listen_port = db.Column(db.Integer, nullable=False)
    network_cidr = db.Column(db.String(18), nullable=False)
    # Free-form text; the encoding (e.g. comma-separated list) is not defined here.
    dns_servers = db.Column(db.Text)
    routes = db.Column(db.Text)
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    max_clients = db.Column(db.Integer, default=100)

    def __init__(self, protocol, server_name, listen_port, network_cidr):
        """Set the four required fields.

        The other columns are left unset; those with a declared default
        receive it on insert.
        """
        self.protocol, self.server_name = protocol, server_name
        self.listen_port, self.network_cidr = listen_port, network_cidr

    def to_dict(self):
        """Serialize the configuration, rendering timestamps as ISO-8601 strings."""
        def as_iso(moment):
            return moment.isoformat() if moment else None

        serialized = {
            'id': self.id,
            'protocol': self.protocol,
            'server_name': self.server_name,
            'listen_port': self.listen_port,
            'network_cidr': self.network_cidr,
            'dns_servers': self.dns_servers,
            'routes': self.routes,
            'is_active': self.is_active,
            'created_at': as_iso(self.created_at),
            'updated_at': as_iso(self.updated_at),
            'max_clients': self.max_clients,
        }
        return serialized
|
|
|
|
|
|