# Litellm-proxy / app/auth.py
# NitinBot001
# Update app/auth.py
# 1a76e37 verified
import os
import bcrypt
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from cryptography.fernet import Fernet
from sqlalchemy.orm import Session
from .database import get_db
from . import models
# ── JWT ────────────────────────────────────────────────────────────────────────
# Signing key, algorithm and default lifetime for access tokens. All values are
# read from the environment once, at import time.
_DEFAULT_SECRET_KEY = "CHANGE_ME_super_secret_key_32bytes!"
SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    # The placeholder is published in the source, so tokens signed with it can
    # be forged by anyone. Surface the misconfiguration the same way the
    # missing FERNET_KEY is surfaced below.
    print(
        "[WARN] SECRET_KEY not set. Using the built-in placeholder; "
        "set SECRET_KEY in .env before deploying."
    )
ALGORITHM = "HS256"
# Default token lifetime in minutes; 60 * 24 * 7 is seven days.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 60 * 24 * 7))

# ── Fernet encryption for API keys ─────────────────────────────────────────────
_raw_fernet_key = os.getenv("FERNET_KEY", "")
if not _raw_fernet_key:
    # No key configured: generate one for this process and print it so it can
    # be copied into the environment. A generated key differs on every start,
    # so values encrypted with it are only readable again if the key is saved.
    _raw_fernet_key = Fernet.generate_key().decode()
    print(f"[WARN] FERNET_KEY not set. Generated key (add to .env): {_raw_fernet_key}")
# _raw_fernet_key is always a str at this point (from the environment or
# decoded above), so it is encoded unconditionally.
fernet = Fernet(_raw_fernet_key.encode())

# Bearer-token scheme consumed by get_current_user via Depends(); tokenUrl is
# the path advertised to clients for obtaining a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
# ── Password hashing (pure bcrypt — no passlib) ────────────────────────────────
def hash_password(plain: str) -> str:
    """Hash *plain* with bcrypt and return the hash as text.

    Args:
        plain: The password to hash.

    Returns:
        The bcrypt hash (including its freshly generated salt), decoded as
        UTF-8.
    """
    # bcrypt accepts at most 72 bytes, so clip the UTF-8 encoded password.
    encoded = plain.encode("utf-8")
    clipped = encoded[:72]
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(clipped, salt)
    return digest.decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a stored bcrypt hash.

    Args:
        plain: The candidate password.
        hashed: The stored bcrypt hash as text.

    Returns:
        The result of ``bcrypt.checkpw`` for the clipped password and hash.
    """
    # Clip to 72 bytes, mirroring what is done when the hash is created.
    candidate = plain.encode("utf-8")[:72]
    stored = hashed.encode("utf-8")
    return bcrypt.checkpw(candidate, stored)
# ── Fernet helpers ─────────────────────────────────────────────────────────────
def encrypt_api_key(api_key: str) -> str:
    """Encrypt *api_key* with the module-level Fernet instance.

    Args:
        api_key: The plaintext API key.

    Returns:
        The Fernet token, decoded to ``str``.
    """
    plaintext = api_key.encode()
    token = fernet.encrypt(plaintext)
    return token.decode()
def decrypt_api_key(encrypted: str) -> str:
    """Decrypt a Fernet token back into the plaintext API key.

    Args:
        encrypted: The token text to decrypt.

    Returns:
        The decrypted value, decoded to ``str``. Any error raised by
        ``fernet.decrypt`` propagates to the caller.
    """
    token = encrypted.encode()
    plaintext = fernet.decrypt(token)
    return plaintext.decode()
# ── JWT helpers ────────────────────────────────────────────────────────────────
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Return a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Lifetime of the token. When ``None``, the module-level
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` is used.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Compare against None instead of relying on truthiness: timedelta(0) is
    # falsy, and an explicit zero lifetime must not silently turn into the
    # (much longer) default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    # Use an aware UTC timestamp for the expiry claim.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
) -> models.User:
    """Resolve the bearer token to the matching ``models.User`` row.

    Args:
        token: The bearer token supplied through ``oauth2_scheme``.
        db: Database session supplied through ``get_db``.

    Returns:
        The first user whose ``username`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token cannot be decoded, has no ``sub`` claim, or no user matches.
    """
    # One shared error for every failure mode, so the response does not reveal
    # which check rejected the request.
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_error

    username: str = claims.get("sub")
    if not username:
        raise credentials_error

    user_query = db.query(models.User).filter(models.User.username == username)
    user = user_query.first()
    if not user:
        raise credentials_error
    return user