AML_Shield / modules /database.py
AJAY KASU
Initial commit AML Shield
7d391cb
import os
import json
import pandas as pd
from supabase import create_client
def get_supabase_client():
    """Create a Supabase client from environment configuration.

    Reads ``SUPABASE_URL`` and ``SUPABASE_KEY`` from the process environment.

    Returns:
        Whatever ``create_client(url, key)`` returns when both variables are
        set to non-empty values; otherwise ``None`` after printing a warning.
    """
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_KEY")

    if supabase_url and supabase_key:
        return create_client(supabase_url, supabase_key)

    # Either value is unset or empty: callers treat None as "no database".
    print("Warning: Missing SUPABASE_URL or SUPABASE_KEY.")
    return None
def save_upload(filename, total, flagged, high_risk, medium_risk, avg_score, date_range):
    """Insert one summary row into the ``uploads`` table.

    The count arguments are coerced with ``int()`` and ``avg_score`` with
    ``float()`` before the insert; a value that cannot be coerced raises
    from here because the coercion happens outside the ``try``.

    Returns:
        The ``id`` field of the first row in the insert response, or ``None``
        when no client is available, the response holds no rows, or the
        insert raised (the error is printed, not propagated).
    """
    client = get_supabase_client()
    if not client:
        return None

    upload_row = {
        "filename": filename,
        "total_transactions": int(total),
        "flagged_count": int(flagged),
        "high_risk_count": int(high_risk),
        "medium_risk_count": int(medium_risk),
        "avg_risk_score": float(avg_score),
        "date_range": date_range,
    }

    try:
        result = client.table("uploads").insert(upload_row).execute()
        inserted_rows = result.data
        # len() is kept inside the try so an unexpected payload is reported
        # through the same error path as a failed request.
        if len(inserted_rows) > 0:
            return inserted_rows[0]['id']
    except Exception as exc:
        print(f"Error saving upload: {exc}")

    return None
def save_transactions(df, upload_id):
    """Insert scored transactions into the ``transactions`` table in chunks.

    A copy of ``df`` is tagged with ``upload_id``; its ``timestamp`` column is
    converted to strings and each ``rule_flags`` value is JSON-encoded. Only
    the columns listed in ``cols`` are sent. Rows are inserted 500 at a time;
    a failing chunk is printed and skipped so later chunks are still attempted.

    Missing numeric values (float NaN) are sent as ``None`` (SQL NULL). NaN is
    not valid JSON, so leaving it in place makes the whole chunk containing it
    fail to insert.

    Args:
        df: DataFrame containing at least every column in ``cols`` except
            ``upload_id``. The caller's frame is not modified.
        upload_id: Value written to the ``upload_id`` column of every row.

    Returns:
        None. Nothing is done when no Supabase client is available.
    """
    supabase = get_supabase_client()
    if not supabase:
        return

    df_to_save = df.copy()

    # Format for DB
    df_to_save['upload_id'] = upload_id
    df_to_save['timestamp'] = df_to_save['timestamp'].astype(str)
    df_to_save['rule_flags'] = df_to_save['rule_flags'].apply(json.dumps)

    cols = [
        'upload_id', 'transaction_id', 'customer_id', 'amount', 'timestamp',
        'transaction_type', 'origin_country', 'dest_country', 'account_age_days',
        'risk_score', 'risk_level', 'ml_anomaly_flag', 'rule_flags', 'is_flagged'
    ]

    # NaN is the only float that is unequal to itself; swap it for None so the
    # request body stays valid JSON and the value lands as NULL.
    records = [
        {
            key: (None if isinstance(value, float) and value != value else value)
            for key, value in record.items()
        }
        for record in df_to_save[cols].to_dict(orient='records')
    ]

    chunk_size = 500
    for i in range(0, len(records), chunk_size):
        chunk = records[i:i + chunk_size]
        try:
            supabase.table("transactions").insert(chunk).execute()
        except Exception as e:
            # Best effort: report the failed chunk and keep going.
            print(f"Error saving transactions chunk: {e}")
def save_customer_profiles(profile_df, upload_id):
    """Insert customer profile rows into the ``customer_profiles`` table.

    Works on a copy of ``profile_df`` that gets an ``upload_id`` column and has
    every missing value replaced with ``0``. Rows are sent in batches of 500;
    a failing batch is printed and the remaining batches are still attempted.

    Returns:
        None. Nothing is done when no Supabase client is available.
    """
    client = get_supabase_client()
    if not client:
        return

    prepared = profile_df.copy()
    prepared['upload_id'] = upload_id
    # Replace missing values with 0 before serialising the rows.
    prepared = prepared.fillna(0)

    rows = prepared.to_dict(orient='records')
    batch_size = 500
    batches = (
        rows[start:start + batch_size]
        for start in range(0, len(rows), batch_size)
    )

    for batch in batches:
        try:
            client.table("customer_profiles").insert(batch).execute()
        except Exception as exc:
            print(f"Error saving customer profiles chunk: {exc}")
def save_ai_report(upload_id, report_text, model_used):
    """Insert one row into the ``ai_reports`` table.

    The row holds ``upload_id``, ``report_text`` and ``model_used`` exactly as
    passed in. An insert failure is printed rather than raised.

    Returns:
        None. Nothing is done when no Supabase client is available.
    """
    client = get_supabase_client()
    if not client:
        return

    report_row = {
        "upload_id": upload_id,
        "report_text": report_text,
        "model_used": model_used,
    }

    try:
        reports_table = client.table("ai_reports")
        reports_table.insert(report_row).execute()
    except Exception as exc:
        print(f"Error saving ai report: {exc}")
def get_all_uploads():
    """Fetch every row of the ``uploads`` table, newest ``uploaded_at`` first.

    Returns:
        The ``data`` attribute of the query response, or an empty list when
        no Supabase client is available or the query raised (the error is
        printed, not propagated).
    """
    client = get_supabase_client()
    if not client:
        return []

    try:
        newest_first = (
            client.table("uploads")
            .select("*")
            .order("uploaded_at", desc=True)
        )
        result = newest_first.execute()
        return result.data
    except Exception as exc:
        print(f"Error fetching uploads: {exc}")
        return []
def get_transactions_by_upload(upload_id, page_size=1000):
    """Load all transactions belonging to one upload as a DataFrame.

    The rows are fetched page by page. A single unpaginated ``select`` is
    capped by the server's maximum-rows setting (1,000 by default on Supabase),
    which would silently truncate larger uploads.

    Args:
        upload_id: Value matched against the ``upload_id`` column.
        page_size: Number of rows requested per page. Values below 1 are
            treated as 1.

    Returns:
        A DataFrame built from every fetched row, ordered by
        ``transaction_id``. An empty DataFrame is returned when no Supabase
        client is available or any page request raised; partial results are
        discarded so callers never see a silently incomplete data set.
    """
    supabase = get_supabase_client()
    if not supabase:
        return pd.DataFrame()

    page_size = max(int(page_size), 1)
    rows = []
    offset = 0

    try:
        while True:
            response = (
                supabase.table("transactions")
                .select("*")
                .eq("upload_id", upload_id)
                # A fixed ordering keeps page boundaries consistent between
                # requests.
                .order("transaction_id")
                .range(offset, offset + page_size - 1)
                .execute()
            )
            page = response.data or []
            if not page:
                break
            rows.extend(page)
            # Advance by what was actually received: the server may return
            # fewer rows than requested if its own cap is below page_size.
            offset += len(page)
    except Exception as e:
        print(f"Error fetching transactions: {e}")
        return pd.DataFrame()

    return pd.DataFrame(rows)
def get_global_stats():
    """Aggregate headline statistics across every stored upload.

    Returns:
        A dict with the keys ``total_transactions_ever``,
        ``total_flagged_ever``, ``total_uploads``,
        ``most_common_rule_triggered`` and ``avg_risk_score_global``.
        With no uploads the numeric fields are zero and the rule is "N/A".
        Otherwise the totals are column sums over the upload rows and the
        average is the plain mean of the per-upload ``avg_risk_score`` values.
    """
    upload_rows = get_all_uploads()

    if not upload_rows:
        return {
            "total_transactions_ever": 0,
            "total_flagged_ever": 0,
            "total_uploads": 0,
            "most_common_rule_triggered": "N/A",
            "avg_risk_score_global": 0.0
        }

    frame = pd.DataFrame(upload_rows)
    transactions_sum = frame['total_transactions'].sum()
    flagged_sum = frame['flagged_count'].sum()
    upload_count = len(frame)
    mean_score = frame['avg_risk_score'].mean()

    # NOTE: the most common rule is NOT computed from stored data. It is a
    # hard-coded placeholder; a real value would need a dedicated query or RPC.
    return {
        "total_transactions_ever": int(transactions_sum),
        "total_flagged_ever": int(flagged_sum),
        "total_uploads": int(upload_count),
        "most_common_rule_triggered": "Structuring",
        "avg_risk_score_global": float(mean_score)
    }