| """ |
| GitHub OAuth service with PKCE support for LandPPT |
| Uses Valkey for state storage with in-memory fallback |
| """ |
|
|
| import secrets |
| import hashlib |
| import base64 |
| import time |
| import httpx |
| import logging |
| from urllib.parse import urlencode |
| from typing import Optional, Tuple, Dict, Any |
| from sqlalchemy.orm import Session |
|
|
| from ..database.models import User |
| from ..core.config import app_config |
|
|
# Module-level logger for the GitHub OAuth service.
logger = logging.getLogger(__name__)


# GitHub OAuth and REST API endpoints used by this module.
GITHUB_AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_USER_URL = "https://api.github.com/user"
GITHUB_EMAILS_URL = "https://api.github.com/user/emails"


# In-memory store for pending OAuth states, keyed by the state token.
# Written to when the state could not be stored in Valkey.
_oauth_states_fallback: Dict[str, Dict[str, Any]] = {}


# Lifetime of a pending OAuth state, in seconds (30 minutes).
OAUTH_STATE_TTL = 30 * 60
|
|
def _strip_default_port(host: str, scheme: str) -> str:
    """Remove an explicit port from *host* when it is the default for *scheme*.

    The default port is 80 for ``http`` and 443 for ``https``. Bracketed IPv6
    literals (``[::1]:443``) are supported.

    Args:
        host: Host, optionally followed by ``:port``.
        scheme: URL scheme; compared case-insensitively after trimming.

    Returns:
        *host* untouched (not even trimmed) when it is empty or the scheme is
        neither ``http`` nor ``https``. Otherwise the whitespace-trimmed host,
        with the port removed only if it equals the scheme's default port.
    """
    if not host:
        return host

    scheme_key = (scheme or "").strip().lower()
    implied_port = {"http": "80", "https": "443"}.get(scheme_key)
    if implied_port is None:
        return host

    trimmed = host.strip()
    port_suffix = ":" + implied_port

    # Bracketed IPv6 literal: the port, if any, follows the closing bracket.
    if trimmed.startswith("["):
        closing = trimmed.find("]")
        if closing > 0 and trimmed[closing + 1:] == port_suffix:
            return trimmed[: closing + 1]
        return trimmed

    # Plain hostname / IPv4: exactly one colon separates host and port.
    # More than one colon (an unbracketed IPv6 address) is left alone.
    if trimmed.count(":") == 1 and trimmed.endswith(port_suffix):
        return trimmed[: -len(port_suffix)]
    return trimmed
|
|
|
|
def _build_callback_url_from_request(request) -> Optional[str]:
    """Derive the GitHub callback URL from the incoming request.

    The host is taken from ``X-Forwarded-Host``, then ``Host``, then
    ``request.url.netloc``. The scheme is taken from ``X-Forwarded-Proto``,
    then ``request.url.scheme``, defaulting to ``http``. When a value is a
    comma-separated list, only its first entry is used.

    Args:
        request: Object exposing ``headers`` and ``url``; may be falsy.

    Returns:
        ``"<scheme>://<host>/auth/github/callback"`` with a default port
        stripped from the host, or ``None`` when there is no request, no host
        can be determined, or any error occurs while reading the request.
    """
    if not request:
        return None

    def _first_entry(raw: str) -> str:
        # Keep only the first item of a comma-separated header value.
        return raw.split(",", 1)[0].strip()

    try:
        incoming = request.headers
        raw_host = (
            incoming.get("x-forwarded-host")
            or incoming.get("host")
            or request.url.netloc
            or ""
        )
        raw_scheme = incoming.get("x-forwarded-proto") or request.url.scheme or "http"

        host = _first_entry(raw_host)
        scheme = _first_entry(raw_scheme).lower()
        if not host:
            return None

        public_host = _strip_default_port(host, scheme)
        return f"{scheme}://{public_host}/auth/github/callback"
    except Exception:
        # Best effort: callers treat None as "no request-derived URL".
        return None
|
|
|
|
def _get_system_oauth_config_sync() -> Dict[str, Any]:
    """Best-effort synchronous read of the system-level config from the DB.

    Returns:
        Whatever ``get_all_config_sync(user_id=None)`` returns, or an empty
        dict when importing or calling the DB config service fails (the
        failure is logged as a warning). Falling back to ``app_config`` is
        left to the callers.
    """
    try:
        # Imported lazily, inside the try, so an import failure is also handled.
        from ..services.db_config_service import get_db_config_service

        return get_db_config_service().get_all_config_sync(user_id=None)
    except Exception as exc:
        logger.warning("Failed to load GitHub OAuth config from DB: %s", exc)
        return {}
|
|
|
|
def generate_pkce_pair() -> Tuple[str, str]:
    """Create a fresh PKCE verifier and its S256 challenge.

    Returns:
        ``(code_verifier, code_challenge)`` where the verifier is a random
        URL-safe token and the challenge is the unpadded URL-safe base64
        encoding of the verifier's SHA-256 digest.
    """
    # 32 random bytes encode to a 43-character URL-safe string.
    verifier = secrets.token_urlsafe(32)

    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    # Drop the base64 "=" padding from the encoded digest.
    challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")

    return verifier, challenge
|
|
|
|
def generate_state() -> str:
    """Return a random URL-safe token to use as the OAuth ``state`` (CSRF guard)."""
    csrf_token = secrets.token_urlsafe(32)
    return csrf_token
|
|
|
|
async def store_oauth_state(
    state: str,
    code_verifier: str,
    redirect_url: str = "/dashboard",
    callback_url: Optional[str] = None,
    invite_code: Optional[str] = None,
) -> None:
    """Persist the data needed to complete the OAuth flow for *state*.

    Valkey is tried first; if it is not connected, reports failure, or raises,
    the entry is kept in the in-memory fallback store instead.

    Args:
        state: The OAuth state token (used as the storage key).
        code_verifier: The PKCE code verifier to keep for the token exchange.
        redirect_url: URL to redirect to after a successful login.
        callback_url: Callback URL to store alongside the state, if any.
        invite_code: Optional invite code; stored trimmed and upper-cased,
            or as ``None`` when empty.
    """
    normalized_invite = str(invite_code or "").strip().upper() or None
    payload = {
        "code_verifier": code_verifier,
        "redirect_url": redirect_url,
        "callback_url": callback_url,
        "invite_code": normalized_invite,
        "created_at": time.time(),
    }

    # Preferred storage: Valkey.
    try:
        from ..services.cache_service import get_cache_service

        valkey = await get_cache_service()
        logger.info(f"GitHub OAuth: Valkey connected={valkey.is_connected}")

        if valkey.is_connected:
            stored = await valkey.set_oauth_state("github", state, payload, OAUTH_STATE_TTL)
            if not stored:
                logger.warning("GitHub OAuth: Valkey set_oauth_state returned False")
            else:
                logger.info(f"GitHub OAuth state stored in Valkey: {state[:8]}...")
                return
    except Exception as e:
        logger.warning(f"Failed to store OAuth state in Valkey: {e}")

    # Fallback storage: process-local dictionary.
    _oauth_states_fallback[state] = payload
    logger.info(f"GitHub OAuth state stored in memory (fallback): {state[:8]}... (total in memory: {len(_oauth_states_fallback)})")

    # Drop expired fallback entries on every fallback write.
    _cleanup_old_states_fallback()
|
|
|
|
async def get_and_consume_oauth_state(state: str) -> Optional[Dict[str, Any]]:
    """Look up the stored data for *state* and remove it (one-time use).

    Valkey is consulted first; when it is not connected, has no entry, or
    raises, the in-memory fallback store is checked.

    Args:
        state: The OAuth state token to look up.

    Returns:
        The stored state data, or ``None`` when the state is unknown or its
        in-memory entry is older than ``OAUTH_STATE_TTL`` seconds.
    """
    # Preferred storage: Valkey.
    try:
        from ..services.cache_service import get_cache_service

        valkey = await get_cache_service()
        logger.info(f"GitHub OAuth callback: Valkey connected={valkey.is_connected}")

        if valkey.is_connected:
            cached = await valkey.get_and_consume_oauth_state("github", state)
            if not cached:
                logger.warning(f"GitHub OAuth: state not found in Valkey: {state[:8]}...")
            else:
                logger.info(f"GitHub OAuth state retrieved from Valkey: {state[:8]}...")
                return cached
    except Exception as e:
        logger.warning(f"Failed to get OAuth state from Valkey: {e}")

    # Fallback storage: pop so the state cannot be replayed.
    logger.info(f"GitHub OAuth: checking memory fallback (total stored: {len(_oauth_states_fallback)})")
    entry = _oauth_states_fallback.pop(state, None)
    if not entry:
        return entry

    logger.info(f"GitHub OAuth state retrieved from memory (fallback): {state[:8]}...")

    # The entry is already removed; an expired one is simply not returned.
    age_seconds = time.time() - entry.get("created_at", 0)
    if age_seconds > OAUTH_STATE_TTL:
        logger.warning(f"OAuth state expired in memory fallback: {state[:8]}...")
        return None

    return entry
|
|
|
|
def _cleanup_old_states_fallback() -> None:
    """Drop in-memory OAuth states older than ``OAUTH_STATE_TTL`` seconds."""
    now = time.time()
    # Iterate over a snapshot so entries can be removed while scanning.
    for key, entry in list(_oauth_states_fallback.items()):
        if now - entry.get("created_at", 0) > OAUTH_STATE_TTL:
            _oauth_states_fallback.pop(key, None)
|
|
|
|
def build_authorization_url(state: str, code_challenge: str, callback_url: Optional[str] = None) -> str:
    """Build the GitHub authorization URL for a PKCE (S256) login.

    Args:
        state: CSRF protection state parameter.
        code_challenge: PKCE code challenge derived from the verifier.
        callback_url: Redirect URI to send; when falsy, the result of
            ``get_callback_url()`` is used.

    Returns:
        The full authorization URL including the encoded query string.
    """
    oauth_settings = _get_system_oauth_config_sync()
    # The DB-stored client id wins over the one from app_config.
    client_id = str(oauth_settings.get("github_client_id") or app_config.github_client_id or "").strip()
    redirect_uri = callback_url or get_callback_url()

    query = urlencode(
        {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": "read:user user:email",
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "allow_signup": "true",
        }
    )
    return GITHUB_AUTHORIZE_URL + "?" + query
|
|
|
|
def get_callback_url(request=None) -> str:
    """Resolve the OAuth callback URL to use.

    Precedence:
        1. The configured callback URL (DB config, then ``app_config``),
           unless the "use request host" flag is set and a request-derived
           URL is available, in which case the request-derived URL wins.
        2. The request-derived URL when no callback URL is configured.
        3. ``http://<host>:<port>/auth/github/callback`` built from
           ``app_config``, with ``0.0.0.0`` replaced by ``localhost``.

    Args:
        request: Optional incoming request used to derive the public URL.

    Returns:
        The callback URL as a string.
    """
    settings = _get_system_oauth_config_sync()
    from_request = _build_callback_url_from_request(request)

    from_config = str(settings.get("github_callback_url") or app_config.github_callback_url or "").strip()
    prefer_request_host = bool(
        settings.get("github_callback_use_request_host", getattr(app_config, "github_callback_use_request_host", False))
    )

    # A configured URL is used unless the request host is available and preferred.
    if from_config and not (from_request and prefer_request_host):
        return from_config

    if from_request:
        if from_config:
            # Only reachable when the request host overrides a configured URL.
            logger.info(
                "GitHub OAuth callback uses request host because GITHUB_CALLBACK_USE_REQUEST_HOST=true: %s",
                request_callback := from_request,
            ) if False else logger.info(
                "GitHub OAuth callback uses request host because GITHUB_CALLBACK_USE_REQUEST_HOST=true: %s",
                from_request,
            )
        return from_request

    # Last resort: build the URL from the configured bind address.
    bind_host = app_config.host
    bind_port = app_config.port
    if bind_host == "0.0.0.0":
        bind_host = "localhost"

    return f"http://{bind_host}:{bind_port}/auth/github/callback"
|
|
|
|
async def exchange_code_for_token(
    code: str,
    code_verifier: str,
    callback_url: Optional[str] = None
) -> Optional[str]:
    """Trade a GitHub authorization code for an access token (PKCE flow).

    Args:
        code: Authorization code received on the callback.
        code_verifier: PKCE code verifier matching the challenge sent earlier.
        callback_url: Redirect URI to send; when falsy, the result of
            ``get_callback_url()`` is used.

    Returns:
        The ``access_token`` value from GitHub's response, or ``None`` when
        the HTTP status is not 200, the response contains an ``error`` field,
        or the request raises.
    """
    settings = _get_system_oauth_config_sync()
    # DB-stored credentials win over the ones from app_config.
    client_id = str(settings.get("github_client_id") or app_config.github_client_id or "").strip()
    client_secret = str(settings.get("github_client_secret") or app_config.github_client_secret or "").strip()

    try:
        async with httpx.AsyncClient() as http:
            form = {
                "client_id": client_id,
                "client_secret": client_secret,
                "code": code,
                "redirect_uri": callback_url or get_callback_url(),
                "code_verifier": code_verifier,
            }
            response = await http.post(
                GITHUB_TOKEN_URL,
                data=form,
                # Ask for JSON so the body can be parsed with response.json().
                headers={"Accept": "application/json"},
                timeout=30.0,
            )

            if response.status_code != 200:
                logger.error(f"GitHub token exchange failed: {response.status_code} - {response.text}")
                return None

            payload = response.json()

            # The response body is checked for an error field even on status 200.
            if "error" in payload:
                logger.error(f"GitHub OAuth error: {payload.get('error')} - {payload.get('error_description')}")
                return None

            return payload.get("access_token")

    except Exception as e:
        logger.error(f"Failed to exchange code for token: {e}")
        return None
|
|
|
|
def _select_verified_email(emails: Any) -> Optional[str]:
    """Pick the best verified address from a GitHub ``/user/emails`` payload.

    Prefers an entry that is both primary and verified, then falls back to
    the first verified entry. Anything that is not a list of dicts, or has no
    verified entry with a non-empty address, yields ``None``.
    """
    if not isinstance(emails, list):
        return None
    verified = [
        entry for entry in emails
        if isinstance(entry, dict) and entry.get("verified") and entry.get("email")
    ]
    for entry in verified:
        if entry.get("primary"):
            return entry["email"]
    return verified[0]["email"] if verified else None


async def get_github_user_info(access_token: str) -> Optional[Dict[str, Any]]:
    """Fetch the authenticated GitHub user's profile.

    When the profile carries no email, the user's email list is queried and a
    verified address is chosen (primary first). A failed or malformed email
    lookup response does not discard the profile; ``email`` is then left as is.

    Args:
        access_token: GitHub access token.

    Returns:
        Dict with ``id`` (as a string), ``login``, ``email``, ``name`` and
        ``avatar_url``; or ``None`` when the profile request does not return
        200, the profile has no ``id``, or the request raises.
    """
    auth_headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/json",
    }
    try:
        async with httpx.AsyncClient() as client:
            user_response = await client.get(
                GITHUB_USER_URL,
                headers=auth_headers,
                timeout=30.0,
            )

            if user_response.status_code != 200:
                logger.error(f"Failed to get GitHub user: {user_response.status_code}")
                return None

            user_data = user_response.json()

            # Without this check a missing id would become the literal string
            # "None" and be used as a shared account key by callers.
            github_id = user_data.get("id")
            if github_id is None:
                logger.error("GitHub user response did not include an id")
                return None

            email = user_data.get("email")
            if not email:
                # No email on the profile: look one up in the email list.
                emails_response = await client.get(
                    GITHUB_EMAILS_URL,
                    headers=auth_headers,
                    timeout=30.0,
                )
                if emails_response.status_code == 200:
                    email = _select_verified_email(emails_response.json()) or email

            return {
                "id": str(github_id),
                "login": user_data.get("login"),
                "email": email,
                "name": user_data.get("name"),
                "avatar_url": user_data.get("avatar_url"),
            }

    except Exception as e:
        logger.error(f"Failed to get GitHub user info: {e}")
        return None
|
|
|
|
def get_or_create_user_by_github(
    db: Session,
    github_id: str,
    github_login: str,
    email: Optional[str],
    name: Optional[str],
    avatar_url: Optional[str],
    invite_code: Optional[str] = None,
) -> tuple[Optional[User], bool, Optional[str]]:
    """Find, link, or create the local user for a GitHub identity.

    Resolution order:
        1. A user whose ``github_id`` matches is returned (last login and,
           if provided, avatar are refreshed).
        2. Otherwise, if *email* is given and matches an existing user, the
           GitHub identity is linked to that user.
        3. Otherwise the invite code is resolved and a new user is created.

    Args:
        db: Database session.
        github_id: GitHub user ID.
        github_login: GitHub username; used as the local username, with a
            random suffix appended when that username is already taken.
        email: User's email, if known.
        name: User's display name (accepted but not used by this function).
        avatar_url: User's avatar URL, if any.
        invite_code: Optional invite code passed to the community service
            when a new user has to be created.

    Returns:
        A ``(user, created, error_message)`` tuple: ``(user, False, None)``
        for an existing or linked user, ``(user, True, None)`` for a newly
        created one, and ``(None, False, message)`` when invite resolution
        raises ``ValueError`` or any other error occurs (in the latter case
        the session is rolled back).
    """
    try:
        # 1. Known GitHub identity: refresh login metadata and return.
        linked_user = db.query(User).filter(User.github_id == github_id).first()
        if linked_user:
            linked_user.last_login = time.time()
            if avatar_url:
                linked_user.avatar = avatar_url
            db.commit()
            return linked_user, False, None

        # 2. Same email as an existing account: attach GitHub to it.
        email_owner = db.query(User).filter(User.email == email).first() if email else None
        if email_owner:
            email_owner.github_id = github_id
            email_owner.oauth_provider = "github"
            email_owner.last_login = time.time()
            # Only fill in an avatar if the account does not have one yet.
            if avatar_url and not email_owner.avatar:
                email_owner.avatar = avatar_url
            db.commit()
            logger.info(f"Linked GitHub account {github_login} to existing user {email_owner.username}")
            return email_owner, False, None

        # 3. New account: the invite is resolved before anything is written.
        from ..services.community_service import community_service

        try:
            invite = community_service.resolve_registration_invite(db, invite_code, "github")
        except ValueError as exc:
            return None, False, str(exc)

        # Keep the GitHub login as username unless it is already taken.
        name_taken = db.query(User).filter(User.username == github_login).first()
        username = f"{github_login}_{secrets.token_hex(4)}" if name_taken else github_login

        starting_credits = (
            app_config.default_credits_for_new_users
            if app_config.enable_credits_system
            else 0
        )

        new_user = User(
            username=username,
            email=email,
            avatar=avatar_url,
            github_id=github_id,
            oauth_provider="github",
            registration_channel="github",
            is_active=True,
            is_admin=False,
            credits_balance=starting_credits,
            created_at=time.time(),
            last_login=time.time(),
        )
        # Random placeholder hash; the account is created without a password.
        new_user.password_hash = hashlib.sha256(secrets.token_bytes(32)).hexdigest()

        db.add(new_user)
        # Flush before applying the invite so the new row exists in the session.
        db.flush()
        if invite is not None:
            community_service.apply_invite_code_to_user(db, new_user, invite, "github")
        db.commit()
        db.refresh(new_user)

        logger.info(f"Created new user from GitHub OAuth: {username} (GitHub: {github_login})")
        return new_user, True, None

    except Exception as e:
        logger.error(f"Failed to get or create user from GitHub: {e}")
        db.rollback()
        return None, False, "GitHub 注册失败,请稍后重试"
|
|
|
|
def is_github_oauth_enabled() -> bool:
    """Report whether GitHub OAuth is switched on and has credentials.

    Returns:
        ``True`` only when the enabled flag is truthy and both the client id
        and client secret are non-empty after trimming. Each value is read
        from the DB config first, then from ``app_config``.
    """
    settings = _get_system_oauth_config_sync()
    switched_on = bool(settings.get("github_oauth_enabled", app_config.github_oauth_enabled))
    client_id = str(settings.get("github_client_id") or app_config.github_client_id or "").strip()
    client_secret = str(settings.get("github_client_secret") or app_config.github_client_secret or "").strip()
    return all((switched_on, client_id, client_secret))
|
|