# Policy/src/sf_tools.py
# Uploaded by vishalkatheriya ("Upload 6 files", commit adf61d6, verified).
import json
import os
from contextlib import closing
from datetime import datetime

from dotenv import load_dotenv
from snowflake.connector import connect # type: ignore[import-untyped]
# Load variables from a .env file into the process environment so the
# SNOWFLAKE_* os.getenv() lookups in get_snowflake_connection can read them.
load_dotenv()
# -----------------------------
# Snowflake connection config
# -----------------------------
def get_snowflake_connection():
    """Open and return a new Snowflake connection.

    Each connection parameter is read from the matching ``SNOWFLAKE_*``
    environment variable. A variable that is not set is passed to
    ``connect`` as ``None``. The returned connection is not closed here;
    the caller must close it.
    """
    # connect() keyword -> environment variable holding its value.
    env_keys = {
        "user": "SNOWFLAKE_USER",
        "password": "SNOWFLAKE_PASSWORD",
        "account": "SNOWFLAKE_ACCOUNT",
        "role": "SNOWFLAKE_ROLE",
        "warehouse": "SNOWFLAKE_WAREHOUSE",
        "database": "SNOWFLAKE_DATABASE",
        "schema": "SNOWFLAKE_SCHEMA",
    }
    settings = {param: os.getenv(var) for param, var in env_keys.items()}
    return connect(**settings)
# -----------------------------
# Fetch Employee Details
# -----------------------------
def get_employee(employee_id: str):
    """Fetch an active employee and their leave balances from Snowflake.

    Args:
        employee_id: Value matched against ``employees.employee_id``.

    Returns:
        On success, a dict with keys ``employee_id``, ``country``,
        ``employment_status`` and ``leave_balances``. ``leave_balances``
        maps each ``leave_type`` to its ``balance`` from the
        ``leave_balances`` table.

        If no employee row matches, or ``employment_status`` is not
        ``"active"``, the result is instead a dict holding a single
        ``"error"`` message.
    """
    print("get_employee called")
    # closing() releases the connection even if opening the cursor fails,
    # and still closes the connection if closing the cursor raises.
    with closing(get_snowflake_connection()) as conn, closing(conn.cursor()) as cur:
        cur.execute(
            """
            SELECT employee_id, country, employment_status
            FROM employees
            WHERE employee_id = %s
            """,
            (employee_id,),
        )
        row = cur.fetchone()
        if not row:
            return {"error": f"Employee '{employee_id}' not found"}

        employee = {
            "employee_id": row[0],
            "country": row[1],
            "employment_status": row[2],
        }
        if employee["employment_status"] != "active":
            return {"error": f"Employee '{employee_id}' is not active"}

        # Balances are only fetched for active employees.
        cur.execute(
            """
            SELECT leave_type, balance
            FROM leave_balances
            WHERE employee_id = %s
            """,
            (employee_id,),
        )
        employee["leave_balances"] = {
            leave_type: balance for leave_type, balance in cur.fetchall()
        }
        return employee
# -----------------------------
# Fetch Leave Policy
# -----------------------------
def get_leave_policy(country: str, leave_type: str):
    """Fetch the leave policy for one country / leave type from Snowflake.

    Args:
        country: Value matched against ``leave_policies.country``.
        leave_type: Value matched against ``leave_policies.leave_type``.

    Returns:
        The ``policy`` column as a Python object, normally a dict. A JSON
        string is decoded with ``json.loads``.

        If no row matches, returns a dict with an ``"error"`` message and
        ``"available_leave_types"``, the leave types defined for
        ``country``.

        If the matching row's ``policy`` is NULL, returns a dict with an
        ``"error"`` message.
    """
    print("get_leave_policy called")
    # closing() releases the connection even if opening the cursor fails,
    # and still closes the connection if closing the cursor raises.
    with closing(get_snowflake_connection()) as conn, closing(conn.cursor()) as cur:
        cur.execute(
            """
            SELECT policy
            FROM leave_policies
            WHERE country = %s AND leave_type = %s
            """,
            (country, leave_type),
        )
        row = cur.fetchone()
        if not row:
            # Tell the caller which leave types do exist for this country.
            cur.execute(
                """
                SELECT leave_type
                FROM leave_policies
                WHERE country = %s
                """,
                (country,),
            )
            leave_types = [r[0] for r in cur.fetchall()]
            return {
                "error": f"Leave type '{leave_type}' not found",
                "available_leave_types": leave_types,
            }
        raw = row[0]

    # policy is VARIANT JSON in Snowflake (may be dict or JSON string).
    if raw is None:
        # SQL NULL: dict(None) would raise TypeError. Report it in the same
        # error-dict form callers already check for.
        return {
            "error": f"Leave type '{leave_type}' has no policy defined for country '{country}'"
        }
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        return json.loads(raw)
    return dict(raw)
# -----------------------------
# Check Leave Eligibility
# -----------------------------
def check_leave_eligibility(employee_id: str, leave_type: str, requested_days: int):
    """Decide whether an employee may take ``requested_days`` of ``leave_type``.

    Checks are applied in order, and the first failure is returned:

    1. ``requested_days`` is positive.
    2. ``get_employee`` finds the employee (active, with balances).
    3. ``get_leave_policy`` finds a policy for the employee's country.
    4. The policy's ``max_consecutive_days`` and ``annual_allowance`` limits
       are respected. Each limit applies only when the key is present and
       not null.
    5. The employee's balance for ``leave_type`` covers the request. A
       missing or NULL balance counts as 0.

    Returns:
        On success, ``{"eligible": True, "employee_id", "leave_type",
        "approved_days", "remaining_balance"}``.

        Otherwise, a dict with ``"eligible": False`` and either a
        ``"reason"`` (plus ``"available_balance"`` for a balance shortfall),
        or the ``"error"`` dict returned by the employee or policy lookup.
    """
    print("check_leave_eligibility called")
    if requested_days <= 0:
        return {"eligible": False, "reason": "Requested leave days must be > 0"}

    employee = get_employee(employee_id)
    if "error" in employee:
        # Keep the lookup's keys; add "eligible" so every failure has it.
        return {"eligible": False, **employee}

    policy = get_leave_policy(employee["country"], leave_type)
    if "error" in policy:
        return {"eligible": False, **policy}

    # Treat a NULL balance the same as a missing leave_balances row.
    balance = employee["leave_balances"].get(leave_type)
    if balance is None:
        balance = 0

    # Max consecutive days (a null limit would otherwise raise on comparison).
    max_consecutive = policy.get("max_consecutive_days")
    if max_consecutive is not None and requested_days > max_consecutive:
        return {
            "eligible": False,
            "reason": f"Maximum consecutive days allowed for {leave_type} is {max_consecutive}",
        }

    # Annual allowance
    allowance = policy.get("annual_allowance")
    if allowance is not None and requested_days > allowance:
        return {
            "eligible": False,
            "reason": "Requested days exceed annual allowance",
        }

    # Available balance
    if requested_days > balance:
        return {
            "eligible": False,
            "reason": "Insufficient leave balance",
            "available_balance": balance,
        }

    return {
        "eligible": True,
        "employee_id": employee_id,
        "leave_type": leave_type,
        "approved_days": requested_days,
        "remaining_balance": balance - requested_days,
    }