Spaces:
Runtime error
Runtime error
| """ | |
| Main application entry point | |
| """ | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from routers import admin, users, vpn | |
| from core.error_handlers import setup_error_handlers | |
| from core.database import init_db | |
| from core.logger import setup_logging | |
| from core.middleware import RequestLoggerMiddleware, ErrorHandlerMiddleware | |
| import logging | |
| import os | |
| import requests | |
| import socket | |
| from starlette.responses import RedirectResponse as StarletteRedirect | |
| from starlette.status import HTTP_302_FOUND, HTTP_303_SEE_OTHER | |
| import logging | |
| import json | |
| import asyncio | |
| import threading | |
| import os | |
| import json | |
| import uuid | |
| import bcrypt | |
| from datetime import datetime, timedelta | |
| import logging | |
| from typing import Dict, Optional, List | |
| from sqlalchemy.orm import Session | |
# Logging is configured before the module logger is obtained.
setup_logging()
logger = logging.getLogger(__name__)

# Application object; title/description/version appear in the OpenAPI docs.
app = FastAPI(
    title="VPN Server API",
    version="1.0.0",
    description="API for managing VPN server and users",
)

# Cross-origin policy. Configure this properly in production: every origin,
# method and header is currently accepted, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Project middleware. Registration order is significant in Starlette, so the
# logger is added before the error handler.
app.add_middleware(RequestLoggerMiddleware)
app.add_middleware(ErrorHandlerMiddleware)

# Static assets and Jinja2 templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Every feature router is served under the /api prefix.
for _api_router in (admin.router, users.router, vpn.router):
    app.include_router(_api_router, prefix="/api")
del _api_router

# Application-wide exception handlers.
setup_error_handlers(app)
async def startup_event():
    """Initialise the database when the application starts.

    A failure is logged and re-raised so startup aborts rather than running
    half-initialised.

    NOTE(review): a second ``startup_event`` is defined later in this module and
    replaces this binding at import time.
    """
    try:
        await init_db()
        logger.info("Database initialized successfully")
    except Exception as exc:
        logger.error(f"Failed to initialize application: {exc}")
        raise
async def root(request: Request):
    """Serve the landing page by rendering ``index.html``."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
async def health_check():
    """Liveness probe: unconditionally report the service as healthy."""
    payload = {"status": "healthy"}
    return payload
# Database dependency
def get_db():
    """Yield one SQLAlchemy session per request and always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs when the dependency generator is finalised, success or error.
        session.close()
# OAuth2 bearer-token extractor. With auto_error=False a missing or malformed
# Authorization header yields None instead of an automatic 401.
# NOTE(review): both oauth2_scheme and Token are re-assigned further down in
# this module, so these bindings are replaced at import time.
oauth2_scheme = OAuth2PasswordBearer(auto_error=False, tokenUrl="token")


class Token(BaseModel):
    """Bearer-token payload returned to clients."""

    access_token: str
    token_type: str
# Web UI assets and templates.
# Do not create a new FastAPI() here: rebinding ``app`` would discard the CORS
# and custom middleware, the /api routers and the error handlers already
# registered on the application object above.
# NOTE(review): "/static" is also mounted earlier with directory="static";
# Starlette matches routes in registration order, so that earlier mount serves
# the path. Confirm which asset directory is intended.
app.mount("/static", StaticFiles(directory="web/static"), name="static")
templates = Jinja2Templates(directory="web/templates")
# Template helper: URL of a file served from the "/static" mount.
def static_url(path: str) -> str:
    """Return the public URL of *path* under ``/static``."""
    mount_prefix = "/static"
    return f"{mount_prefix}/{path}"
async def get_current_user(request: Request, db: Session = Depends(get_db)):
    """Describe the caller as a small dict, or return None if unauthenticated.

    Looks the user up with ``get_optional_user``; any error is treated as
    "no user". The dict carries ``username``, ``id`` (stringified) and
    ``config_id``.

    NOTE(review): another ``get_current_user`` is defined later in this module
    and replaces this one at import time.
    """
    try:
        found = await get_optional_user(request, db)
        if not found:
            return None
        return {
            "username": found.username,
            "id": str(found.id),
            "config_id": found.config_id,
        }
    except Exception:
        return None
def _template_url_for(name: str, **params) -> str:
    """Minimal ``url_for`` replacement exposed to Jinja templates.

    ``url_for("static", filename=...)`` (Flask style) and
    ``url_for("static", path=...)`` (Starlette style) both resolve under
    ``/static``. ``filename`` wins when both are given. Any other route name
    maps to ``/<name>``, and its params are ignored.
    """
    if name == "static":
        asset = params.get("filename", params.get("path", ""))
        return static_url(asset)
    return f"/{name}"


# Make the helpers callable from every template rendered by ``templates``.
templates.env.globals.update({
    "static_url": static_url,
    "url_for": _template_url_for,
})
# OAuth2 bearer-token extractor for token-protected endpoints. No auto_error
# argument is passed, so FastAPI's default applies and a missing token is
# rejected before the endpoint runs.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


class Token(BaseModel):
    """Access token returned by the login and signup endpoints."""

    access_token: str
    token_type: str


class TokenData(BaseModel):
    """Claims extracted from a decoded JWT."""

    username: Optional[str] = None  # value of the token's "sub" claim


class UserBase(BaseModel):
    """Fields shared by every user schema."""

    email: EmailStr


class UserCreate(UserBase):
    """Signup payload: e-mail plus plaintext password (hashed before storage)."""

    password: str


class UserInDB(UserBase):
    """User schema including server-side fields, readable from ORM objects."""

    hashed_password: str
    config_id: str
    created_at: datetime

    class Config:
        orm_mode = True
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer *token* to its ``User`` row.

    Raises:
        HTTPException 401: the token does not decode with SECRET_KEY/ALGORITHM,
            lacks a ``sub`` claim, or names a user that does not exist.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")
        if subject is None:
            raise unauthorized
        token_data = TokenData(username=subject)
    except JWTError:
        raise unauthorized

    account = db.query(User).filter(User.username == token_data.username).first()
    if account is None:
        raise unauthorized
    return account
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: claims to embed (e.g. ``{"sub": username}``); the mapping itself
            is not modified.
        expires_delta: token lifetime. ``None`` means the 15 minute default.
            An explicit ``timedelta(0)`` is honoured rather than being treated
            as "not given" (a zero timedelta is falsy).

    Returns:
        The encoded token produced by ``jwt.encode`` with SECRET_KEY/ALGORITHM.
    """
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12.
    # JWT libraries convert the datetime to a UTC timestamp either way.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# Global VPN server state, populated by initialize_vpn_server().
# The annotations are quoted because module-level variable annotations are
# evaluated eagerly (Python < 3.14), and these types are not imported here.
vpn_server: "Optional[OutlineServer]" = None
session_tracker: "Optional[SessionTracker]" = None
# IKEv2 server handle, assigned by initialize_ikev2_server(). It is declared
# here so code running before initialisation sees None, not a NameError.
ikev2_server = None
# ``logger`` is deliberately NOT reset to None here. Doing so replaced the
# stdlib logger created at the top of the module, so any logging performed
# before initialize_vpn_server() installs its LogManager raised AttributeError.

# Initialize database
# NOTE(review): init_db() is awaited elsewhere in this module. If it is a
# coroutine function, this bare call only creates an un-awaited coroutine and
# initialises nothing. Confirm.
init_db()

CONFIG_DIR = 'config'
USERS_FILE = os.path.join(CONFIG_DIR, 'users.json')
os.makedirs(CONFIG_DIR, exist_ok=True)
def load_users():
    """Return the user registry stored in USERS_FILE, or ``{}`` if it is absent."""
    try:
        with open(USERS_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def save_users(users):
    """Persist *users* to USERS_FILE as JSON.

    The data goes to a sibling temporary file that is then moved into place
    with ``os.replace``, so a crash mid-write cannot leave a truncated
    users.json behind.
    """
    tmp_path = USERS_FILE + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(users, f)
    os.replace(tmp_path, USERS_FILE)
def get_server_ip():
    """Return this server's IP address as a string.

    Sources are tried in order:
    1. Public echo services (api.ipify.org, then ifconfig.me), each with a
       5 second timeout. A reply is used only if it parses as an IP address.
    2. The local address of the interface used for outbound traffic.
    3. ``'127.0.0.1'`` as a last resort.
    """
    import ipaddress

    for url in ('https://api.ipify.org', 'https://ifconfig.me'):
        try:
            response = requests.get(url, timeout=5)
        except requests.RequestException:
            continue
        if response.status_code != 200:
            continue
        candidate = response.text.strip()
        # Ignore empty bodies or error pages instead of returning them as an IP.
        try:
            ipaddress.ip_address(candidate)
        except ValueError:
            continue
        return candidate

    # Fallback: connect() on a UDP socket sends no packets. It only selects
    # the outgoing interface, whose address getsockname() then reports.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('8.8.8.8', 80))
            return s.getsockname()[0]
    except OSError:
        # Last resort fallback
        return '127.0.0.1'
def initialize_ikev2_server():
    """Create the module-level IKEv2 server bound to the detected server IP."""
    global ikev2_server
    ikev2_server = IKEv2Server(get_server_ip(), logger)
    logger.log(LogLevel.INFO, LogCategory.SYSTEM, "app", "IKEv2 server initialized")
def _build_vpn_server_config(server_ip):
    """Return the OutlineServer configuration dict for *server_ip*."""
    return {
        "server": {
            "host": server_ip,  # Use automatically detected server IP
            "port": 8388,  # Default Shadowsocks port
            "virtual_network": "10.7.0.0/24",  # Virtual network for client IPs
            "protocols": {
                "shadowsocks": {"enabled": True, "port": 8388},
                "wireguard": {"enabled": True, "port": 51820},
                "openvpn": {"enabled": True, "port": 1194},
                "ikev2": {"enabled": True, "port": 500},
            },
        },
        "security": {
            "cipher": "aes-256-gcm",
            "auth": "sha256",
            "enable_perfect_forward_secrecy": True,
        },
    }


def initialize_vpn_server():
    """Bring up logging, session tracking, the IKEv2 server and the VPN server.

    Rebinds the module globals ``logger``, ``session_tracker``, ``ikev2_server``
    (through initialize_ikev2_server) and ``vpn_server``. The VPN server's
    ``start()`` coroutine runs on its own asyncio event loop in a daemon
    thread, so this function returns as soon as that thread is started.
    """
    global vpn_server, session_tracker, logger, ikev2_server

    logger = LogManager()
    logger.log(LogLevel.INFO, LogCategory.SYSTEM, "app", "Initializing VPN server")

    session_tracker = SessionTracker()
    initialize_ikev2_server()

    server_ip = get_server_ip()
    vpn_server = OutlineServer(_build_vpn_server_config(server_ip))

    def _serve_in_background():
        # This thread owns a fresh event loop. run_forever() keeps that loop
        # running after vpn_server.start() has completed.
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
        event_loop.run_until_complete(vpn_server.start())
        event_loop.run_forever()

    threading.Thread(target=_serve_in_background, daemon=True).start()
    logger.log(LogLevel.INFO, LogCategory.SYSTEM, "app", f"VPN server initialized and started on {server_ip}")
def load_users():
    """Return the user registry stored in USERS_FILE, or ``{}`` if it is absent.

    NOTE(review): this redefines an earlier load_users/save_users pair in this
    module. These later definitions are the ones in effect.
    """
    try:
        with open(USERS_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def save_users(users):
    """Persist *users* to USERS_FILE as JSON.

    The data goes to a sibling temporary file that is then moved into place
    with ``os.replace``, so a crash mid-write cannot leave a truncated
    users.json behind.
    """
    tmp_path = USERS_FILE + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(users, f)
    os.replace(tmp_path, USERS_FILE)
def login_required(func):
    """Decorate an async endpoint so it only runs with a valid JWT in ``token``.

    The wrapped coroutine is awaited only when ``kwargs['token']`` decodes with
    SECRET_KEY/ALGORITHM and carries a ``sub`` claim. Otherwise the wrapper
    returns a 303 redirect to ``/login``.

    ``functools.wraps`` copies the endpoint's name and docstring and sets
    ``__wrapped__``. This lets ``inspect.signature``, which FastAPI uses to
    discover parameters, report the endpoint's real signature instead of
    ``(*args, **kwargs)``.
    """
    import functools

    def _redirect_to_login():
        return StarletteRedirect('/login', status_code=HTTP_303_SEE_OTHER)

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        token = kwargs.get('token')
        if not token:
            return _redirect_to_login()
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            return _redirect_to_login()
        if payload.get("sub") is None:
            return _redirect_to_login()
        return await func(*args, **kwargs)

    return wrapper
async def get_optional_user(
    request: Request,
    db: Session = Depends(get_db)
) -> Optional[User]:
    """Best-effort lookup of the user named by a ``Bearer`` Authorization header.

    Returns None when:
    - the header is missing or uses another scheme;
    - the token fails to decode or has no ``sub`` claim;
    - any other error occurs during the lookup.
    """
    try:
        header = request.headers.get("Authorization")
        if not header:
            return None
        scheme, _, token = header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        try:
            claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            subject: str = claims.get("sub")
            if subject is None:
                return None
            return db.query(User).filter(User.username == subject).first()
        except JWTError:
            return None
    except Exception:
        return None
async def index(request: Request):
    """Render the landing page template."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
async def login(
    request: Request,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db)
):
    """Authenticate with username/password form data and issue a bearer token.

    On success:
    - a session is recorded through ``UserService.create_session``;
    - the successful attempt is recorded on the user;
    - the changes are committed.

    Raises:
        HTTPException 401: unknown user or wrong password. Both cases use the
            same message, so callers cannot tell which one occurred.
        HTTPException 403: the account status is ``UserStatus.LOCKED``.
    """
    user = db.query(User).filter(User.username == form_data.username).first()
    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if user.status == UserStatus.LOCKED:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is locked. Please contact support."
        )

    access_token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    # request.client is None when the ASGI server supplies no client address.
    # Dereferencing it unconditionally turned such logins into a 500.
    # NOTE(review): confirm the session model accepts a missing ip_address.
    client_host = request.client.host if request.client is not None else None
    UserService(db).create_session(
        user=user,
        ip_address=client_host,
        device_info=request.headers.get("user-agent", ""),
    )
    user.record_login_attempt(success=True)
    db.commit()

    return {"access_token": access_token, "token_type": "bearer"}
async def signup(user: UserCreate, db: Session = Depends(get_db)):
    """Register an account keyed by e-mail and return a bearer token.

    On success:
    - the user row is committed;
    - a VPN profile is generated with ``create_user_config``;
    - a token valid for ACCESS_TOKEN_EXPIRE_MINUTES is returned.

    Raises:
        HTTPException 400: the e-mail is already registered.
    """
    if db.query(User).filter(User.username == user.email).first():
        raise HTTPException(status_code=400, detail="Email already registered")

    config_id = str(uuid.uuid4())
    new_account = User(
        username=user.email,
        hashed_password=get_password_hash(user.password),
        config_id=config_id,
        created_at=datetime.utcnow(),
    )
    db.add(new_account)
    db.commit()
    db.refresh(new_account)

    # Per-user VPN configuration file.
    create_user_config(config_id)

    token = create_access_token(
        data={"sub": user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
async def signup_form(request: Request):
    """Render the signup page."""
    page_context = {"request": request}
    return templates.TemplateResponse("signup.html", page_context)
async def dashboard(request: Request, current_user: User = Depends(get_current_user)):
    """Render the dashboard with the caller's VPN usage statistics.

    ``stats`` may be None when statistics are unavailable (see get_user_stats).
    """
    usage = get_user_stats(current_user.config_id)
    page_context = {"request": request, "stats": usage}
    return templates.TemplateResponse("dashboard.html", page_context)
async def download_config(current_user: User = Depends(get_current_user)):
    """Return the caller's stored VPN profile (``<config_id>.json``) as JSON.

    Raises:
        HTTPException 404: no profile file exists for the caller.
    """
    profile_path = os.path.join(CONFIG_DIR, f"{current_user.config_id}.json")
    if not os.path.exists(profile_path):
        raise HTTPException(status_code=404, detail="Configuration not found")
    with open(profile_path, 'r') as handle:
        profile = json.load(handle)
    return JSONResponse(content=profile)
async def get_stats(current_user: User = Depends(get_current_user)):
    """Return the caller's VPN usage statistics as a JSON response."""
    usage = get_user_stats(current_user.config_id)
    return JSONResponse(content=usage)
def get_server_ip():
    """Return this server's IP address as a string.

    Sources are tried in order:
    1. Public echo services (api.ipify.org, then ifconfig.me), each with a
       5 second timeout so startup and signup cannot hang indefinitely. A reply
       is used only if it parses as an IP address.
    2. The local address of the interface used for outbound traffic.
    3. ``'127.0.0.1'`` as a last resort.

    NOTE(review): this redefines an earlier get_server_ip in this module, and
    this later definition is the one in effect.
    """
    import ipaddress

    for url in ('https://api.ipify.org', 'https://ifconfig.me'):
        try:
            response = requests.get(url, timeout=5)
        except requests.RequestException:
            continue
        if response.status_code != 200:
            continue
        candidate = response.text.strip()
        # Ignore empty bodies or error pages instead of returning them as an IP.
        try:
            ipaddress.ip_address(candidate)
        except ValueError:
            continue
        return candidate

    # Fallback: connect() on a UDP socket sends no packets. It only selects
    # the outgoing interface, whose address getsockname() then reports.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('8.8.8.8', 80))
            return s.getsockname()[0]
    except OSError:
        return '127.0.0.1'  # Last resort fallback
def initialize_ikev2_server():
    """Create the module-level IKEv2 server for the detected server IP and log it.

    NOTE(review): identical to an earlier definition in this module, which this
    one replaces at import time.
    """
    global ikev2_server
    detected_ip = get_server_ip()
    ikev2_server = IKEv2Server(detected_ip, logger)
    logger.log(LogLevel.INFO, LogCategory.SYSTEM, "app", "IKEv2 server initialized")
def generate_ikev2_certificate(
    config_id: str,
    username: Optional[str] = None,
    password: Optional[str] = None,
    psk: Optional[str] = None,
) -> Optional[Dict]:
    """Register *config_id* with the IKEv2 server and return its certificate data.

    Any credential not supplied is generated:
    - username defaults to ``user_<first 8 chars of config_id>``;
    - password and pre-shared key default to random UUID4 strings.

    Passing credentials explicitly lets a caller hand out exactly the values
    that were registered.

    Returns:
        The value returned by ``ikev2_server.add_user``, or None if
        registration failed (the error is logged).
    """
    if username is None:
        username = f"user_{config_id[:8]}"
    if password is None:
        password = str(uuid.uuid4())
    if psk is None:
        psk = str(uuid.uuid4())

    try:
        cert_data = ikev2_server.add_user(config_id, username, password, psk)
        logger.info(LogCategory.SYSTEM, "app", f"Generated IKEv2 certificates for user {config_id}")
        return cert_data
    except Exception as e:
        logger.error(LogCategory.SYSTEM, "app", f"Failed to generate IKEv2 certificates: {e}")
        return None
def create_user_config(config_id):
    """Create and persist the multi-protocol VPN profile for a new user.

    Writes ``<CONFIG_DIR>/<config_id>.json``. The file holds Shadowsocks,
    IKEv2, L2TP/IPsec and PPTP profiles pointing at the detected server IP,
    plus the per-platform recommended protocol.

    Args:
        config_id: the user's configuration id; also used as the file name.
    """
    os.makedirs(CONFIG_DIR, exist_ok=True)  # no exists()/makedirs() race
    server_ip = get_server_ip()
    # One timestamp for the whole profile so every entry agrees.
    created_at = datetime.now().isoformat()
    username = f"user_{config_id[:8]}"

    def new_credentials():
        """Return the user's protocol username with a fresh random password."""
        return {'username': username, 'password': str(uuid.uuid4())}

    # Outline/Shadowsocks config
    ss_config = {
        'id': config_id,
        'server': {
            'host': server_ip,
            'port': 8388  # Shadowsocks port
        },
        'access_key': str(uuid.uuid4()),
        'protocol': 'shadowsocks',
        'created_at': created_at
    }

    # IKEv2 config (Windows 10/11, Android 10+)
    # NOTE(review): generate_ikev2_certificate(config_id) registers the user
    # with its own randomly generated password/PSK. The credentials and psk
    # written here therefore differ from what the IKEv2 server knows. Confirm,
    # and pass the same values through.
    ikev2_config = {
        'id': f"{config_id}_ikev2",
        'server': {
            'host': server_ip,
            'port': 500  # IKEv2 port
        },
        'credentials': new_credentials(),
        'psk': str(uuid.uuid4()),  # Pre-shared key
        'certificate': generate_ikev2_certificate(config_id),
        'protocol': 'ikev2',
        'created_at': created_at
    }

    # L2TP/IPsec config (Built-in Windows, Android, iOS)
    l2tp_config = {
        'id': f"{config_id}_l2tp",
        'server': {
            'host': server_ip,
            'port': 1701,  # L2TP port
        },
        'credentials': new_credentials(),
        'ipsec': {
            'psk': str(uuid.uuid4())  # Pre-shared key for IPsec
        },
        'protocol': 'l2tp_ipsec',
        'created_at': created_at
    }

    # PPTP config (Legacy support - Windows, Android)
    pptp_config = {
        'id': f"{config_id}_pptp",
        'server': {
            'host': server_ip,
            'port': 1723  # PPTP port
        },
        'credentials': new_credentials(),
        'protocol': 'pptp',
        'encryption': 'require-mppe',  # Maximum PPTP security
        'warning': 'PPTP is considered less secure, use IKEv2 or L2TP/IPsec when possible',
        'created_at': created_at
    }

    # OpenVPN and WireGuard profiles are intentionally not built. They were
    # never included in the saved config, and their helpers
    # (generate_openvpn_certificates, generate_openvpn_config,
    # generate_wireguard_keys) are neither defined nor imported in this module.
    # Building them only raised NameError, after signup had already committed
    # the user row.

    # Combined config with all supported protocols
    config = {
        'id': config_id,
        'protocols': {
            'shadowsocks': ss_config,
            'ikev2': ikev2_config,
            'l2tp': l2tp_config,
            'pptp': pptp_config
        },
        'recommended_protocol': {
            'windows': 'ikev2',
            'android': 'ikev2',
            'fallback': 'l2tp'
        },
        'created_at': created_at
    }

    # Write to a temporary file, then swap it in atomically, so a crash cannot
    # leave a truncated profile that download_config would fail to parse.
    config_path = os.path.join(CONFIG_DIR, f"{config_id}.json")
    tmp_path = config_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(config, f)
    os.replace(tmp_path, config_path)
def get_user_stats(config_id):
    """Aggregate VPN usage for *config_id* across the sessions the tracker reports.

    Returns None (after logging) when the session tracker is not initialised or
    any error occurs. Otherwise the returned dict contains:
    - byte totals;
    - the earliest connection and latest activity as ISO strings;
    - an overall ``status``;
    - per-session details;
    - the protocols seen.

    ``status`` is:
    - 'active' if any session was seen within the last 300 seconds;
    - otherwise 'offline_available' if any session has ``is_offline`` set;
    - otherwise 'offline'.

    With no sessions at all it is 'disconnected'.
    """
    try:
        if not session_tracker:
            logger.error(LogCategory.SYSTEM, "app", "Session tracker not initialized")
            return None

        sessions = session_tracker.get_user_sessions(config_id)
        if not sessions:
            return {
                'bytes_sent': 0,
                'bytes_received': 0,
                'connected_since': None,
                'last_seen': None,
                'status': 'disconnected',
                'active_sessions': [],
                'protocols': []
            }

        sent_total = 0
        received_total = 0
        start_times = []
        seen_times = []
        protocols_seen = set()
        session_details = []

        for sess in sessions:
            sent_total += sess.bytes_out
            received_total += sess.bytes_in

            started_at = datetime.fromtimestamp(sess.start_time)
            seen_at = datetime.fromtimestamp(sess.last_seen)
            start_times.append(started_at)
            seen_times.append(seen_at)

            protocols_seen.add(sess.protocol)
            session_details.append({
                'id': sess.session_id,
                'protocol': sess.protocol,
                'assigned_ip': sess.assigned_ip,
                'connected_since': started_at.isoformat(),
                'last_seen': seen_at.isoformat(),
                'bytes_sent': sess.bytes_out,
                'bytes_received': sess.bytes_in,
                'is_offline': sess.is_offline
            })

        earliest = min(start_times) if start_times else None
        latest = max(seen_times) if seen_times else None

        now = datetime.now()
        recently_seen = any(
            (now - datetime.fromtimestamp(s.last_seen)).total_seconds() < 300  # 5 minutes
            for s in sessions
        )
        if recently_seen:
            status = 'active'
        elif any(s.is_offline for s in sessions):
            status = 'offline_available'
        else:
            status = 'offline'

        return {
            'bytes_sent': sent_total,
            'bytes_received': received_total,
            'connected_since': earliest.isoformat() if earliest else None,
            'last_seen': latest.isoformat() if latest else None,
            'status': status,
            'active_sessions': session_details,
            'protocols': list(protocols_seen)
        }
    except Exception as e:
        logger.error(LogCategory.SYSTEM, "app", f"Error getting user stats: {e}")
        return None
async def logout(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """End the caller's most recent session and redirect to ``/``.

    The newest ``UserSession`` row for the user (ordered by ``created_at``)
    gets ``expires_at`` set to the current UTC time.

    Raises:
        HTTPException 500: any failure. The transaction is rolled back, and the
            original error is kept as the exception's cause.
    """
    try:
        current_session = (
            db.query(UserSession)
            .filter(UserSession.user_id == current_user.id)
            .order_by(UserSession.created_at.desc())
            .first()
        )
        if current_session:
            current_session.expires_at = datetime.utcnow()
            db.commit()
        return StarletteRedirect('/', status_code=HTTP_303_SEE_OTHER)
    except Exception as e:
        # Roll back so a failed commit does not leave the session in an
        # unusable state before get_db closes it.
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail="Error during logout"
        ) from e
async def forgot_password_form(request: Request):
    """Render the "forgot password" page."""
    page_context = {"request": request}
    return templates.TemplateResponse("forgot_password.html", page_context)
async def forgot_password(email: str, db: Session = Depends(get_db)):
    """Start a password reset for *email*.

    If an account exists, a reset token valid for 24 hours is stored on it.
    The response is the same whether or not the account exists, which
    prevents user enumeration.

    Raises:
        HTTPException 500: any error while processing the request.
    """
    try:
        account = db.query(User).filter(User.username == email).first()
        if account:
            reset_service = UserService(db)
            account.reset_token = reset_service.generate_reset_token()
            account.reset_token_expires = datetime.utcnow() + timedelta(hours=24)
            db.commit()
            # TODO: Send reset email with token
        return JSONResponse(
            content={"message": "Password reset link has been sent to your email address"}
        )
    except Exception:
        raise HTTPException(
            status_code=500,
            detail="Error processing password reset request"
        )
async def startup_event():
    """Start the VPN stack via initialize_vpn_server().

    This replaces the earlier ``startup_event`` of the same name in this module.

    NOTE(review): no ``@app.on_event`` decorator or ``add_event_handler`` call
    for this coroutine is visible in this module. Confirm it is registered.
    """
    return initialize_vpn_server()
async def shutdown_event():
    """Stop the VPN server if one is running, then log the shutdown."""
    if not (vpn_server and vpn_server.is_running):
        return
    await vpn_server.stop()
    logger.log(LogLevel.INFO, LogCategory.SYSTEM, "app", "VPN server shut down.")
if __name__ == '__main__':
    import uvicorn

    # Listen on all interfaces, on port 7860.
    _bind_host, _bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=_bind_host, port=_bind_port)