Spaces:
Paused
Paused
| import logging | |
| from datetime import datetime, timezone | |
| from typing import Optional | |
| import requests | |
| from flask import current_app, redirect, request | |
| from flask_restful import Resource | |
| from werkzeug.exceptions import Unauthorized | |
| from configs import dify_config | |
| from constants.languages import languages | |
| from events.tenant_event import tenant_was_created | |
| from extensions.ext_database import db | |
| from libs.helper import extract_remote_ip | |
| from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | |
| from models import Account | |
| from models.account import AccountStatus | |
| from services.account_service import AccountService, RegisterService, TenantService | |
| from services.errors.account import AccountNotFoundError | |
| from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError | |
| from services.feature_service import FeatureService | |
| from .. import api | |
def get_oauth_providers():
    """Return the OAuth clients keyed by provider name.

    The mapping always has the keys ``"github"`` and ``"google"``. A key maps to
    ``None`` when the provider's client ID or client secret is not set in
    ``dify_config``. Otherwise it maps to a client whose redirect URI is
    ``CONSOLE_API_URL + "/console/api/oauth/authorize/<provider>"``.
    """
    with current_app.app_context():
        github_client = None
        if dify_config.GITHUB_CLIENT_ID and dify_config.GITHUB_CLIENT_SECRET:
            github_client = GitHubOAuth(
                client_id=dify_config.GITHUB_CLIENT_ID,
                client_secret=dify_config.GITHUB_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
            )

        google_client = None
        if dify_config.GOOGLE_CLIENT_ID and dify_config.GOOGLE_CLIENT_SECRET:
            google_client = GoogleOAuth(
                client_id=dify_config.GOOGLE_CLIENT_ID,
                client_secret=dify_config.GOOGLE_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
            )

        return {"github": github_client, "google": google_client}
class OAuthLogin(Resource):
    """Start of the OAuth sign-in flow, registered at ``/oauth/login/<provider>``."""

    def get(self, provider: str):
        """Redirect the browser to ``provider``'s authorization URL.

        Args:
            provider: Provider key from the URL path. It must be a key of
                ``get_oauth_providers()`` whose client is configured.

        Returns:
            A redirect to the provider's authorization page, or
            ``({"error": "Invalid provider"}, 400)`` when ``provider`` is unknown
            or not configured.
        """
        # An empty ``invite_token`` query value is treated the same as a missing one.
        invite_token = request.args.get("invite_token") or None
        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)
        # Never print or log the provider client. It is built with the client
        # secret, and it is None for unknown providers (vars(None) would raise).
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400
        # OAuthCallback reads the invite token back from the ``state`` query parameter.
        auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
        return redirect(auth_url)
class OAuthCallback(Resource):
    """OAuth redirect target, registered at ``/oauth/authorize/<provider>``."""

    def get(self, provider: str):
        """Finish the OAuth flow for ``provider`` and redirect to the console web app.

        Steps:

        1. Exchange the ``code`` query parameter for an access token and fetch
           the user's profile.
        2. If ``state`` carries a valid invite token with a matching invitation,
           redirect to the invite-settings page. If the invitation email differs
           from the OAuth email, redirect to sign-in with an error instead.
        3. Otherwise resolve or create the account, activate it if pending, and
           make sure it owns a workspace.
        4. Redirect to ``CONSOLE_WEB_URL`` with a fresh access/refresh token pair.

        Args:
            provider: Provider key from the URL path.

        Returns:
            ``({"error": ...}, 400)`` for an unknown provider, a missing ``code``,
            or a failed token exchange. Otherwise a redirect response. Sign-in
            failures are reported via ``?message=`` on the sign-in page.
        """
        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        code = request.args.get("code")
        # The invite token, if any, comes back in the OAuth ``state`` parameter.
        # An empty value is treated as no token.
        invite_token = request.args.get("state") or None

        # Without an authorization code there is nothing to exchange. Fail the same
        # way as a rejected exchange instead of calling the provider with code=None.
        if not code:
            return {"error": "OAuth process failed"}, 400

        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
        except requests.exceptions.HTTPError as e:
            # ``HTTPError.response`` may be None. Fall back to the exception text so
            # the handler itself cannot raise AttributeError.
            detail = e.response.text if e.response is not None else str(e)
            logging.exception("An error occurred during the OAuth process with %s: %s", provider, detail)
            return {"error": "OAuth process failed"}, 400

        if invite_token and RegisterService.is_valid_invite_token(invite_token):
            invitation = RegisterService._get_invitation_by_token(token=invite_token)
            if invitation:
                invitation_email = invitation.get("email", None)
                # The OAuth identity must belong to the invited address.
                if invitation_email != user_info.email:
                    return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")

                return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
            # No invitation found for the token: continue with the regular sign-in flow.

        try:
            account = _generate_account(provider, user_info)
        except AccountNotFoundError:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
        except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        # Check account status
        if account.status == AccountStatus.BANNED.value:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")

        if account.status == AccountStatus.PENDING.value:
            # A successful OAuth sign-in activates a pending account.
            # initialized_at is stored as a naive UTC timestamp.
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
            db.session.commit()

        try:
            TenantService.create_owner_tenant_if_not_exist(account)
        except Unauthorized:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
        except WorkSpaceNotAllowedCreateError:
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        token_pair = AccountService.login(
            account=account,
            ip_address=extract_remote_ip(request),
        )

        # NOTE(review): tokens are handed off in the redirect query string, where they
        # can land in browser history or proxy logs. Consider a safer handoff.
        return redirect(
            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
        )
| def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | |
| account = Account.get_by_openid(provider, user_info.id) | |
| if not account: | |
| account = Account.query.filter_by(email=user_info.email).first() | |
| return account | |
def _generate_account(provider: str, user_info: OAuthUserInfo):
    """Resolve or create the account for an OAuth identity and link it to the provider.

    - Existing account with no joined tenant: a new "<name>'s Workspace" tenant
      is created with the account as owner. If workspace creation is disabled,
      WorkSpaceNotAllowedCreateError is raised instead.
    - No existing account: a password-less account is registered (name falls
      back to "Dify"). Its interface language is taken from the request's
      Accept-Language header, defaulting to ``languages[0]``. If registration
      is disabled, AccountNotFoundError is raised instead.

    In both cases the provider identity is linked to the account before it is
    returned.
    """
    # Get account by openid or email.
    account = _get_account_by_openid_or_email(provider, user_info)

    if not account:
        if not FeatureService.get_system_features().is_allow_register:
            raise AccountNotFoundError()
        account = RegisterService.register(
            email=user_info.email,
            name=user_info.name or "Dify",
            password=None,
            open_id=user_info.id,
            provider=provider,
        )

        # Set interface language
        best_lang = request.accept_languages.best_match(languages)
        account.interface_language = best_lang if best_lang and best_lang in languages else languages[0]
        db.session.commit()
    else:
        joined_tenants = TenantService.get_join_tenants(account)
        if not joined_tenants:
            if not FeatureService.get_system_features().is_allow_create_workspace:
                raise WorkSpaceNotAllowedCreateError()
            new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
            TenantService.create_tenant_member(new_tenant, account, role="owner")
            account.current_tenant = new_tenant
            tenant_was_created.send(new_tenant)

    # Link account
    AccountService.link_account_integrate(provider, user_info.id, account)
    return account
# Route registration.
# - /oauth/login/<provider> starts the flow (OAuthLogin).
# - /oauth/authorize/<provider> is the callback (OAuthCallback).
# NOTE(review): presumably `api` is mounted under /console/api, so the callback
# route matches the redirect_uri built in get_oauth_providers(). Confirm.
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")