File size: 3,566 Bytes
41dfc48
5e7c541
41dfc48
5e7c541
41dfc48
 
 
 
 
 
5e7c541
41dfc48
 
 
 
d7da01b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41dfc48
 
 
 
 
 
 
 
 
 
5e7c541
 
 
41dfc48
 
 
 
 
 
 
 
 
310c607
 
 
41dfc48
 
 
 
 
 
5e7c541
 
41dfc48
 
 
 
 
 
 
5e7c541
 
 
41dfc48
 
 
5e7c541
41dfc48
 
 
 
5e7c541
41dfc48
 
 
 
5e7c541
41dfc48
 
 
 
 
5e7c541
41dfc48
 
5e7c541
41dfc48
5e7c541
41dfc48
 
 
 
 
 
 
 
 
5e7c541
 
41dfc48
 
 
 
 
 
 
 
 
 
 
 
 
5e7c541
41dfc48
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# src/payslip/utils.py
from datetime import date, datetime
from typing import Optional, Union

from dateutil.relativedelta import relativedelta
from fastapi import Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from src.core.database import get_async_session
from src.core.models import Users
from src.core.config import settings


from cryptography.fernet import Fernet
from src.core.config import settings


# Single Fernet instance shared by encrypt_token/decrypt_token, built once when
# this module is imported, from settings.FERNET_KEY. Fernet() rejects a key that
# is not a valid Fernet key, so a misconfigured key fails at import time rather
# than on the first encrypt/decrypt call.
fernet = Fernet(settings.FERNET_KEY.encode())


def encrypt_token(token: str) -> str:
    """Encrypt *token* (e.g. a refresh token) with the module Fernet key.

    Meant to be called before the token is written to the database.

    Returns:
        The Fernet ciphertext, decoded from bytes to ``str`` so it can be
        stored in a text column.
    """
    ciphertext = fernet.encrypt(token.encode())
    return ciphertext.decode()


def decrypt_token(token: str) -> str:
    """Recover the plain-text token from a value stored via ``encrypt_token``.

    Raises:
        cryptography.fernet.InvalidToken: if *token* was not produced with
            this module's Fernet key or has been altered.
    """
    plaintext = fernet.decrypt(token.encode())
    return plaintext.decode()

# FastAPI security dependency that pulls the bearer token out of the request;
# get_current_user_model receives its result as ``credentials``.
bearer_scheme = HTTPBearer()

# JWT verification parameters, read once at import time from app settings and
# passed to jwt.decode in get_current_user_model.
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = settings.JWT_ALGORITHM


def _parse_month(month_str: str) -> date:
    """
    "2024-05" -> date(2024, 5, 1)
    """
    try:
        d = datetime.strptime(month_str, "%Y-%m")
        return date(d.year, d.month, 1)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail="Invalid month format. Use YYYY-MM, e.g. 2024-05",
        )


def validate_join_date(
    join_date: Optional[Union[str, date]],
    period_start: date,
    *,
    default_join_date: date = date(2020, 4, 1),
) -> None:
    """Reject a payslip request whose period starts before the user joined.

    Args:
        join_date: The user's join date, as a ``YYYY-MM-DD`` string or as a
            ``date``/``datetime`` (a ``datetime`` is truncated to its date).
            ``None`` or an empty string means "unknown" and falls back to
            *default_join_date*.
        period_start: First day of the requested payslip period.
        default_join_date: Join date assumed when *join_date* is missing.
            Defaults to 2020-04-01, the historical fallback for this check.

    Raises:
        HTTPException: 400 if *period_start* is earlier than the join date.
        ValueError: if *join_date* is a string not in ``YYYY-MM-DD`` format.
    """
    if not join_date:
        join = default_join_date
    elif isinstance(join_date, datetime):
        # Test datetime before date: datetime subclasses date, but comparing
        # a datetime with the date ``period_start`` raises TypeError.
        join = join_date.date()
    elif isinstance(join_date, date):
        join = join_date
    else:
        join = datetime.strptime(join_date, "%Y-%m-%d").date()

    if period_start < join:
        raise HTTPException(
            400,
            f"You joined on {join}. You cannot request payslips before joining date.",
        )


def calculate_period(
    mode: str,
    start_month: Optional[str] = None,
    end_month: Optional[str] = None,
):
    """Resolve a payslip request *mode* into a ``(start, end)`` pair of dates.

    Modes:
      - ``"3_months"`` / ``"6_months"``: ``end`` is the first day of the
        current month and ``start`` is 3 or 6 months before it.
      - ``"manual"``: *start_month* and *end_month* are ``YYYY-MM`` strings;
        each is returned as the first day of its month.

    Both returned dates are always the first day of a month.

    NOTE(review): in the rolling modes ``end`` is the *current* month, while in
    manual mode ``end`` is the last month the user asked for. Whether callers
    treat ``end`` as inclusive or exclusive is not visible here; confirm the
    two modes are interpreted consistently.

    Raises:
        HTTPException: 400 for an unknown mode, missing manual bounds, a
            malformed month string, or a start month after the end month.
    """
    # Rolling windows: mode -> number of months back from the current month.
    window = {"3_months": 3, "6_months": 6}.get(mode)
    if window is not None:
        end = date.today().replace(day=1)
        return end - relativedelta(months=window), end

    if mode != "manual":
        raise HTTPException(400, "Invalid payslip request mode")

    if not start_month or not end_month:
        raise HTTPException(400, "Manual mode requires start_month and end_month")

    try:
        start = datetime.strptime(start_month, "%Y-%m").date()
        end = datetime.strptime(end_month, "%Y-%m").date()
    except ValueError:
        # The parse error is fully described by the 400 response; don't chain it.
        raise HTTPException(400, "Invalid month format. Use YYYY-MM") from None

    if start > end:
        raise HTTPException(400, "Start month cannot be after end month")

    return start, end


async def get_current_user_model(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    session: AsyncSession = Depends(get_async_session),
) -> Users:
    """FastAPI dependency: resolve the request's bearer JWT to a ``Users`` row.

    The token is verified with ``SECRET_KEY`` and only ``ALGORITHM`` is
    accepted. Its ``sub`` claim is looked up as ``Users.id``.

    Raises:
        HTTPException: 401 when the token fails to decode or verify
            ("Invalid or expired token"), lacks a ``sub`` claim
            ("Invalid token"), or names no existing user ("User not found").
    """
    # Keep the try scoped to decoding so only JWT failures map to this message.
    try:
        payload = jwt.decode(
            credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]
        )
    except JWTError:
        raise HTTPException(401, "Invalid or expired token") from None

    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(401, "Invalid token")

    # NOTE(review): ``sub`` is conventionally a string; if Users.id is an
    # integer column, confirm the DB driver accepts the string comparison.
    result = await session.execute(select(Users).where(Users.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(401, "User not found")

    return user