|
|
import base64 |
|
|
import random |
|
|
import warnings |
|
|
from collections.abc import Coroutine |
|
|
from datetime import datetime, timedelta, timezone |
|
|
from typing import TYPE_CHECKING, Annotated |
|
|
from uuid import UUID |
|
|
|
|
|
from cryptography.fernet import Fernet |
|
|
from fastapi import Depends, HTTPException, Security, status |
|
|
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer |
|
|
from jose import JWTError, jwt |
|
|
from loguru import logger |
|
|
from sqlmodel.ext.asyncio.session import AsyncSession |
|
|
from starlette.websockets import WebSocket |
|
|
|
|
|
from langflow.services.database.models.api_key.crud import check_key |
|
|
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at |
|
|
from langflow.services.database.models.user.model import User, UserRead |
|
|
from langflow.services.deps import get_db_service, get_session, get_settings_service |
|
|
from langflow.services.settings.service import SettingsService |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from langflow.services.database.models.api_key.model import ApiKey |
|
|
|
|
|
# Bearer-token extractor for the OAuth2 password flow whose token endpoint is "api/v1/login".
# auto_error=False: a missing token is not rejected here; get_current_user checks `if token`
# itself and falls back to API-key authentication.
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)

# Name under which an API key is accepted, both as a query parameter and as a header.
API_KEY_NAME = "x-api-key"

# Optional extractors for the API key (auto_error=False); api_key_security decides what to do
# when neither is supplied.
api_key_query = APIKeyQuery(name=API_KEY_NAME, scheme_name="API key query", auto_error=False)
api_key_header = APIKeyHeader(name=API_KEY_NAME, scheme_name="API key header", auto_error=False)

# Secrets shorter than this are not used directly as Fernet keys; ensure_valid_key derives a
# 32-byte key from them instead.
MINIMUM_KEY_LENGTH = 32
|
|
|
|
|
|
|
|
|
|
|
async def api_key_security(
    query_param: Annotated[str, Security(api_key_query)],
    header_param: Annotated[str, Security(api_key_header)],
) -> UserRead | None:
    """Authenticate a request by API key, or as the configured superuser under AUTO_LOGIN.

    With ``AUTO_LOGIN`` enabled the supplied keys are ignored and the user named by the
    ``SUPERUSER`` setting is looked up. Otherwise the query-string key is preferred over the
    header key and validated with ``check_key``.

    Args:
        query_param: API key from the ``x-api-key`` query parameter, if any.
        header_param: API key from the ``x-api-key`` header, if any.

    Returns:
        The authenticated user as a ``UserRead``.

    Raises:
        HTTPException: 400 if AUTO_LOGIN is on but no superuser is configured; 403 if no key
            was supplied or no user could be resolved from it.
        ValueError: if the lookup produced something other than a ``User``.
    """
    auth_settings = get_settings_service().auth_settings
    result: ApiKey | User | None

    async with get_db_service().with_async_session() as db:
        if auth_settings.AUTO_LOGIN:
            superuser_name = auth_settings.SUPERUSER
            if not superuser_name:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Missing first superuser credentials",
                )
            result = await get_user_by_username(db, superuser_name)
        else:
            # The query-string key wins when both are present.
            supplied_key = query_param or header_param
            if not supplied_key:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="An API key must be passed as query or header",
                )
            result = await check_key(db, supplied_key)

        if not result:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid or missing API key",
            )
        if isinstance(result, User):
            return UserRead.model_validate(result, from_attributes=True)

    msg = "Invalid result type"
    raise ValueError(msg)
|
|
|
|
|
|
|
|
async def get_current_user(
    token: Annotated[str, Security(oauth2_login)],
    query_param: Annotated[str, Security(api_key_query)],
    header_param: Annotated[str, Security(api_key_header)],
    db: Annotated[AsyncSession, Depends(get_session)],
) -> User:
    """Resolve the caller from a bearer token, falling back to API-key authentication.

    Args:
        token: Bearer token from the Authorization header, if any.
        query_param: API key from the ``x-api-key`` query parameter, if any.
        header_param: API key from the ``x-api-key`` header, if any.
        db: Session used for the JWT user lookup.

    Returns:
        The authenticated user.
        NOTE(review): on the API-key path this is the ``UserRead`` built by
        ``api_key_security``, not a ``User`` instance as the annotation suggests.

    Raises:
        HTTPException: 403 if API-key authentication yields no user; JWT and API-key
            helpers raise their own HTTP errors.
    """
    # A bearer token takes precedence over any API key.
    if token:
        return await get_current_user_by_jwt(token, db)

    api_key_user = await api_key_security(query_param, header_param)
    if not api_key_user:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or missing API key",
        )
    return api_key_user
|
|
|
|
|
|
|
|
async def get_current_user_by_jwt(
    token: str,
    db: AsyncSession,
) -> User:
    """Return the active user identified by a signed JWT.

    Args:
        token: The encoded JWT. An un-awaited coroutine is tolerated and awaited first.
        db: Session used to look up the user named by the token's ``sub`` claim.

    Returns:
        The active ``User`` whose id is carried in ``sub``.

    Raises:
        HTTPException: 401 if the secret key is not configured, the token cannot be
            decoded/verified, it has expired, it lacks ``sub`` or ``type``, ``sub`` is not a
            UUID, or the user does not exist or is inactive.
    """
    settings_service = get_settings_service()

    # Tolerate callers that pass a coroutine producing the token instead of the token itself.
    if isinstance(token, Coroutine):
        token = await token

    secret_key = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    # get_secret_value() returns a string; an empty key is as unusable as None and must
    # never be used as the HMAC verification key.
    if not secret_key:
        logger.error("Secret key is not set in settings.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication failure: Verify authentication settings.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM])
        raw_user_id = payload.get("sub")
        token_type: str | None = payload.get("type")
        if expires := payload.get("exp", None):
            expires_datetime = datetime.fromtimestamp(expires, timezone.utc)
            if datetime.now(timezone.utc) > expires_datetime:
                logger.info("Token expired for user")
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Token has expired.",
                    headers={"WWW-Authenticate": "Bearer"},
                )

        # NOTE(review): any non-None `type` is accepted, including the "refresh" and
        # "api_key" tokens minted elsewhere in this module — confirm that is intended.
        if raw_user_id is None or token_type is None:
            logger.info(f"Invalid token payload. Token type: {token_type}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token details.",
                headers={"WWW-Authenticate": "Bearer"},
            )
    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    # A `sub` that is not a UUID is a bad credential (401), not a server error.
    try:
        user_id = UUID(str(raw_user_id))
    except ValueError as e:
        logger.info("Invalid token payload. Subject is not a valid user id.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token details.",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    user = await get_user_by_id(db, user_id)
    if user is None or not user.is_active:
        logger.info("User not found or inactive.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or is inactive.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
|
|
|
|
|
|
|
|
async def get_current_user_for_websocket(
    websocket: WebSocket,
    db: Annotated[AsyncSession, Depends(get_session)],
    query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
    """Authenticate a websocket connection from its query string.

    A ``token`` query parameter (JWT) is tried first, then an ``x-api-key`` query parameter.

    Returns:
        The authenticated user, or ``None`` when neither parameter is present.
    """
    params = websocket.query_params
    token = params.get("token")
    api_key = params.get("x-api-key")

    if token:
        return await get_current_user_by_jwt(token, db)
    if not api_key:
        return None
    # NOTE(review): both arguments carry the "x-api-key" query value (query_param is read
    # from the same query parameter), so the header slot receives a query value here.
    return await api_key_security(api_key, query_param)
|
|
|
|
|
|
|
|
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    """Return the authenticated user, rejecting deactivated accounts.

    Raises:
        HTTPException: 401 if the user is not active.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
|
|
|
|
|
|
|
|
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Return the authenticated user only if it is both active and a superuser.

    Raises:
        HTTPException: 401 if the user is inactive; 403 if it lacks superuser rights.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
    if current_user.is_superuser:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The user doesn't have enough privileges")
|
|
|
|
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password`` via the configured password context."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
|
|
|
|
def get_password_hash(password):
    """Hash ``password`` with the configured password context."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.hash(password)
|
|
|
|
|
|
|
|
def create_token(data: dict, expires_delta: timedelta):
    """Encode ``data`` as a signed JWT expiring ``expires_delta`` from now.

    The caller's dict is left untouched; the ``exp`` claim (UTC) is set on a shallow copy,
    overriding any ``exp`` already present.
    """
    auth_settings = get_settings_service().auth_settings
    claims = {**data, "exp": datetime.now(timezone.utc) + expires_delta}
    return jwt.encode(
        claims,
        auth_settings.SECRET_KEY.get_secret_value(),
        algorithm=auth_settings.ALGORITHM,
    )
|
|
|
|
|
|
|
|
async def create_super_user(
    username: str,
    password: str,
    db: AsyncSession,
) -> User:
    """Return the user named ``username``, creating it as an active superuser if absent.

    An existing user is returned unchanged; its password and flags are not modified.
    """
    existing_user = await get_user_by_username(db, username)
    if existing_user:
        return existing_user

    new_superuser = User(
        username=username,
        password=get_password_hash(password),
        is_superuser=True,
        is_active=True,
        last_login_at=None,
    )
    db.add(new_superuser)
    await db.commit()
    await db.refresh(new_superuser)
    return new_superuser
|
|
|
|
|
|
|
|
async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
    """Issue a one-year access token for the configured superuser and record the login.

    Returns:
        ``(superuser_id, tokens)`` where ``tokens`` holds the access token, a ``None``
        refresh token and the ``"bearer"`` token type.

    Raises:
        HTTPException: 400 if the configured superuser does not exist.
    """
    settings_service = get_settings_service()
    super_user = await get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
    if not super_user:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")

    one_year = timedelta(days=365)
    access_token = create_token(
        data={"sub": str(super_user.id), "type": "access"},
        expires_delta=one_year,
    )
    await update_user_last_login_at(super_user.id, db)

    tokens = {"access_token": access_token, "refresh_token": None, "token_type": "bearer"}
    return super_user.id, tokens
|
|
|
|
|
|
|
|
def create_user_api_key(user_id: UUID) -> dict:
    """Mint a JWT of type ``"api_key"`` for ``user_id``, valid for two years (730 days)."""
    two_years = timedelta(days=365 * 2)
    claims = {"sub": str(user_id), "type": "api_key"}
    return {"api_key": create_token(data=claims, expires_delta=two_years)}
|
|
|
|
|
|
|
|
def get_user_id_from_token(token: str) -> UUID:
    """Return the user id in ``token``'s ``sub`` claim WITHOUT verifying the token.

    Neither the signature nor the expiry is checked, so the result must not be treated as
    proof of identity.

    Returns:
        The parsed ``UUID``, or the nil UUID ``UUID(int=0)`` when the token cannot be parsed,
        has no ``sub`` claim, or ``sub`` is not a valid UUID string.
    """
    try:
        user_id = jwt.get_unverified_claims(token)["sub"]
        # Coerce to str: a non-string `sub` (number, null, list) would otherwise make UUID()
        # raise TypeError/AttributeError, bypassing the nil-UUID fallback below.
        return UUID(str(user_id))
    except (KeyError, JWTError, ValueError):
        return UUID(int=0)
|
|
|
|
|
|
|
|
async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
    """Mint an access/refresh token pair for ``user_id``.

    Lifetimes come from ``ACCESS_TOKEN_EXPIRE_SECONDS`` and ``REFRESH_TOKEN_EXPIRE_SECONDS``.
    When ``update_last_login`` is true the user's last-login timestamp is updated via ``db``.

    Returns:
        A dict with ``access_token``, ``refresh_token`` and ``token_type`` (``"bearer"``).
    """
    auth_settings = get_settings_service().auth_settings
    subject = str(user_id)

    access_token = create_token(
        data={"sub": subject, "type": "access"},
        expires_delta=timedelta(seconds=auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS),
    )
    refresh_token = create_token(
        data={"sub": subject, "type": "refresh"},
        expires_delta=timedelta(seconds=auth_settings.REFRESH_TOKEN_EXPIRE_SECONDS),
    )

    if update_last_login:
        await update_user_last_login_at(user_id, db)

    return dict(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
    )
|
|
|
|
|
|
|
|
async def create_refresh_token(refresh_token: str, db: AsyncSession):
    """Exchange a valid refresh token for a fresh access/refresh token pair.

    Args:
        refresh_token: Encoded JWT previously issued with ``"type": "refresh"``.
        db: Session used to look up the token's user.

    Returns:
        The dict produced by ``create_user_tokens`` for the token's user.

    Raises:
        HTTPException: 401 "Invalid refresh token" if the token cannot be decoded or
            verified, is not a refresh token, has a missing or malformed subject, or
            refers to a user who no longer exists or is inactive.
    """
    settings_service = get_settings_service()

    # Only decoding is guarded: a JWTError here means the presented token is bad.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(
                refresh_token,
                settings_service.auth_settings.SECRET_KEY.get_secret_value(),
                algorithms=[settings_service.auth_settings.ALGORITHM],
            )
    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from e

    raw_user_id = payload.get("sub")
    token_type = payload.get("type")

    # Access and api_key tokens are signed with the same key, so the type must be checked
    # explicitly; otherwise any of them (or an untyped token) could mint new sessions.
    if raw_user_id is None or token_type != "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    try:
        user_id = UUID(str(raw_user_id))
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") from e

    user = await get_user_by_id(db, user_id)
    # Same rule as get_current_user_by_jwt: a deactivated account cannot renew its session.
    if user is None or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    return await create_user_tokens(user_id, db)
|
|
|
|
|
|
|
|
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
    """Check a username/password pair and return the matching active user.

    Account status is only revealed once the password has been verified, so callers that
    do not know the password cannot probe whether an account exists or is pending/disabled.

    Returns:
        The ``User`` on success, or ``None`` if the user does not exist or the password
        is wrong.

    Raises:
        HTTPException: 400 "Waiting for approval" for an inactive user who has never logged
            in; 401 "Inactive user" for any other inactive user. Both only after a correct
            password.
    """
    user = await get_user_by_username(db, username)
    if not user:
        return None

    if not verify_password(password, user.password):
        return None

    if not user.is_active:
        # Never-logged-in inactive accounts are treated as pending approval.
        if not user.last_login_at:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Waiting for approval")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")

    return user
|
|
|
|
|
|
|
|
def add_padding(s):
    """Pad ``s`` with ``=`` so its length is a multiple of 4, as base64 decoding expects.

    A string whose length is already a multiple of 4 (including ``""``) is returned
    unchanged.
    """
    # (-len) % 4 is 0 for aligned input; `4 - len % 4` would append a spurious "====".
    padding_needed = (-len(s)) % 4
    return s + "=" * padding_needed
|
|
|
|
|
|
|
|
def ensure_valid_key(s: str) -> bytes:
    """Turn a secret string into url-safe base64 key bytes for ``Fernet``.

    Secrets shorter than ``MINIMUM_KEY_LENGTH`` seed a deterministic generator from which
    32 bytes are drawn and base64-encoded. Longer secrets are padded and encoded as-is
    (presumably they are already base64 key material; verify with callers).
    """
    if len(s) < MINIMUM_KEY_LENGTH:
        # A private generator instead of `random.seed(s)`: seeding the module-level PRNG
        # with the secret made every other `random` user in the process predictable.
        # Random(s) seeds exactly like random.seed(s), so the derived key is unchanged and
        # existing ciphertexts still decrypt.
        # NOTE(review): Mersenne Twister is not a KDF; kept only for key compatibility.
        rng = random.Random(s)
        key = bytes(rng.getrandbits(8) for _ in range(32))
        key = base64.urlsafe_b64encode(key)
    else:
        key = add_padding(s).encode()
    return key
|
|
|
|
|
|
|
|
def get_fernet(settings_service: SettingsService):
    """Build a ``Fernet`` cipher keyed from the configured ``SECRET_KEY``."""
    raw_secret: str = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    return Fernet(ensure_valid_key(raw_secret))
|
|
|
|
|
|
|
|
def encrypt_api_key(api_key: str, settings_service: SettingsService):
    """Encrypt ``api_key`` with the app's Fernet key and return the token as text."""
    token_bytes = get_fernet(settings_service).encrypt(api_key.encode())
    return token_bytes.decode()
|
|
|
|
|
|
|
|
def decrypt_api_key(encrypted_api_key: str, settings_service: SettingsService):
    """Decrypt a Fernet token produced by ``encrypt_api_key``.

    Returns:
        The plaintext key, or ``""`` when ``encrypted_api_key`` is not a string.

    Raises:
        Whatever the fallback ``fernet.decrypt`` call raises if decryption fails.
    """
    # Built before the type check, exactly as callers have always observed.
    fernet = get_fernet(settings_service)

    if not isinstance(encrypted_api_key, str):
        return ""

    try:
        return fernet.decrypt(encrypted_api_key.encode()).decode()
    except Exception:
        logger.debug("Failed to decrypt API key")
        # NOTE(review): this retries with the str token; it is unclear this can succeed
        # where the bytes form failed — it mostly surfaces the decryption error.
        return fernet.decrypt(encrypted_api_key).decode()
|
|
|