# Detections / app.py
# Author: prernajeet01 · last commit: "Update app.py" (405dcde, verified)
import base64
import html
import io
import os

import matplotlib
matplotlib.use('Agg')  # select the headless backend before pyplot is imported
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from openai import OpenAI
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.preprocessing import scale
from sklearn.svm import OneClassSVM
# Path to the transactions CSV read on every analysis run.
CSV_PATH = 'FI_Transactions.csv'

# OpenAI API key from the environment; None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Only build the client when a key is available. OpenAI(api_key=None) raises
# at construction time when OPENAI_API_KEY is also missing from the
# environment, which would crash the app on import instead of letting
# get_ai_insights() report that AI insights are unavailable.
client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
def detect_anomalies(nu_value, n_clusters):
    """Run One-Class SVM anomaly detection and KMeans clustering on CSV_PATH.

    Only the numeric columns of the CSV are used as features; they are
    standardized with ``sklearn.preprocessing.scale`` before fitting both
    models. Two columns are added to the loaded DataFrame: ``SVM_Anomaly``
    ('Anomaly' / 'Normal') and ``KMeans_Cluster`` (cluster label).

    Args:
        nu_value: ``nu`` passed to ``OneClassSVM`` (the UI labels it as the
            control for the anomaly threshold).
        n_clusters: Number of KMeans clusters. Coerced to ``int`` because a
            UI slider may deliver it as a float.

    Returns:
        Tuple ``(pie_chart_img, kmeans_img, svm_img, summary_html,
        anomalies_html)``: three images produced by ``plt_to_img`` followed
        by two HTML strings.

    Raises:
        ValueError: If the CSV has no numeric columns or no rows.
    """
    df = pd.read_csv(CSV_PATH)

    # Both models only see numeric columns; other columns are kept for display.
    features = df.select_dtypes(include=[np.number])
    if features.empty:
        raise ValueError(
            f"No numeric data found in {CSV_PATH}; anomaly detection needs "
            "at least one numeric column and at least one row."
        )
    feature_names = features.columns.tolist()
    scaled_features = scale(features)

    # One-Class SVM: predictions of -1 are treated as anomalies.
    svm_model = OneClassSVM(kernel='rbf', nu=nu_value, gamma='scale')
    svm_model.fit(scaled_features)
    svm_preds = svm_model.predict(scaled_features)
    is_anomaly = svm_preds == -1
    df['SVM_Anomaly'] = ['Anomaly' if flag else 'Normal' for flag in is_anomaly]

    anomaly_count = int(is_anomaly.sum())
    normal_count = len(df) - anomaly_count

    kmeans = KMeans(n_clusters=int(n_clusters), random_state=42)
    kmeans.fit(scaled_features)
    df['KMeans_Cluster'] = kmeans.labels_

    # Scatter plots use the first two numeric features; with a single numeric
    # column that feature is plotted against itself.
    x_idx = 0
    y_idx = 1 if len(feature_names) > 1 else 0
    x_label = feature_names[x_idx]
    y_label = feature_names[y_idx]

    # 1. Anomalies vs normal pie chart.
    plt.figure(figsize=(8, 6))
    plt.pie([anomaly_count, normal_count], labels=['Anomalies', 'Normal'],
            autopct='%1.1f%%', colors=['#FF9999', '#66B2FF'])
    plt.title('SVM Anomaly Detection Results')
    pie_chart_img = plt_to_img()

    # 2. KMeans clustering scatter plot.
    plt.figure(figsize=(10, 6))
    scatter = plt.scatter(scaled_features[:, x_idx],
                          scaled_features[:, y_idx],
                          c=kmeans.labels_,
                          cmap='viridis',
                          alpha=0.7)
    plt.colorbar(scatter, label='Cluster')
    plt.title('KMeans Clustering Results')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    kmeans_img = plt_to_img()

    # 3. SVM anomalies scatter plot. Each class is drawn as its own labelled
    # collection so the legend gets one correctly coloured entry per class;
    # a single multi-coloured scatter yields only one legend handle.
    plt.figure(figsize=(10, 6))
    plt.scatter(scaled_features[~is_anomaly, x_idx],
                scaled_features[~is_anomaly, y_idx],
                color='blue', alpha=0.7, label='Normal')
    plt.scatter(scaled_features[is_anomaly, x_idx],
                scaled_features[is_anomaly, y_idx],
                color='red', alpha=0.7, label='Anomaly')
    plt.title('SVM Anomaly Detection')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    svm_img = plt_to_img()

    # reset_index() keeps each anomaly's original row index as a column.
    anomalies_df = df[is_anomaly].reset_index()

    ai_insights = get_ai_insights(df, anomalies_df)
    anomalies_html = anomalies_df.to_html(classes='table table-striped')

    # The insights text comes from an external model or an exception message,
    # so escape it before embedding; pre-wrap preserves its line breaks.
    safe_insights = html.escape(str(ai_insights))
    total = len(df)
    summary_html = f"""
<h3>Analysis Summary</h3>
<p>Total transactions: {total}</p>
<p>Anomalies detected: {anomaly_count} ({anomaly_count/total*100:.2f}%)</p>
<p>Normal transactions: {normal_count} ({normal_count/total*100:.2f}%)</p>
<h3>AI Insights</h3>
<p style="white-space: pre-wrap;">{safe_insights}</p>
"""
    return pie_chart_img, kmeans_img, svm_img, summary_html, anomalies_html
def get_ai_insights(df, anomalies_df, model="gpt-3.5-turbo"):
    """Ask an OpenAI chat model for a short fraud-oriented read of the anomalies.

    Args:
        df: Full transactions DataFrame; only ``df.describe()`` is sent.
        anomalies_df: Rows flagged as anomalies; at most the first 5 are sent.
        model: Chat model name passed to ``client.chat.completions.create``.

    Returns:
        The model's reply text, or a human-readable message explaining why no
        insights could be generated. API and data errors are caught and
        reported in the returned text, because the insights are a best-effort
        add-on to the analysis.
    """
    if not OPENAI_API_KEY or client is None:
        return "OpenAI API key not found in environment variables. AI insights are unavailable."
    try:
        # Summarise the dataset and a small sample of anomalies for the prompt.
        df_info = df.describe().to_string()
        anomaly_info = anomalies_df.head(5).to_string() if not anomalies_df.empty else "No anomalies detected"
        prompt = f"""
Analyze the following financial transaction data and detected anomalies:
Dataset Statistics:
{df_info}
Sample Anomalies (top 5):
{anomaly_info}
Please provide:
1. Possible patterns or reasons for these anomalies
2. Recommendations for further investigation
3. Potential risk factors these anomalies might indicate
Keep your analysis concise and focused on financial fraud detection.
"""
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a financial fraud detection expert."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=500
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Deliberate best-effort boundary: surface the error in the UI text.
        return f"Could not generate AI insights. Error: {e}"
    # The message content can be None/empty; avoid rendering "None".
    if not content:
        return "Could not generate AI insights. The model returned an empty response."
    return content
def plt_to_img():
    """Render the current pyplot figure to a PIL image, then close the figure.

    The figure is written as PNG to an in-memory buffer. Closing it matters
    in a long-running app: pyplot keeps every figure alive until it is
    closed, so each analysis run would otherwise leak its figures.

    Returns:
        A PIL image decoded from the PNG buffer.
    """
    fig = plt.gcf()
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # Decode eagerly so the pixel data is loaded before the image is returned.
    img.load()
    return img
# Gradio UI: controls and summary on top, charts in the middle, the detailed
# anomaly table last. Components render in the order they are created.
with gr.Blocks(title="Financial Transaction Anomaly Detection") as demo:
    gr.Markdown("# Financial Transaction Anomaly Detection")
    gr.Markdown(f"Analyzing data from {CSV_PATH}")

    with gr.Row():
        with gr.Column():
            nu_slider = gr.Slider(
                minimum=0.01,
                maximum=0.2,
                value=0.05,
                step=0.01,
                label="SVM nu parameter (controls anomaly threshold)",
            )
            cluster_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=2,
                step=1,
                label="Number of KMeans clusters",
            )
            detect_button = gr.Button(value="Detect Anomalies")
        with gr.Column():
            summary_output = gr.HTML(label="Summary")

    with gr.Row():
        pie_output = gr.Image(label="Anomaly Distribution")
        svm_output = gr.Image(label="SVM Anomaly Detection")
    with gr.Row():
        kmeans_output = gr.Image(label="KMeans Clustering")
    with gr.Row():
        anomalies_output = gr.HTML(label="Detected Anomalies")

    # The outputs list must follow the order of detect_anomalies' return tuple:
    # (pie, kmeans, svm, summary html, anomalies html).
    detect_button.click(
        fn=detect_anomalies,
        inputs=[nu_slider, cluster_slider],
        outputs=[
            pie_output,
            kmeans_output,
            svm_output,
            summary_output,
            anomalies_output,
        ],
    )

    gr.Markdown("""
## How to Use
1. Adjust the SVM nu parameter (controls anomaly detection sensitivity)
2. Choose the number of clusters for KMeans
3. Click 'Detect Anomalies' to analyze the data
## Interpretation
- The pie chart shows the proportion of normal vs anomalous transactions
- The scatter plots visualize the clusters and anomalies
- The AI insights provide expert analysis of detected anomalies
- The table displays detailed information about detected anomalies
""")

# Start the server only when run as a script; share=True is passed to launch().
if __name__ == "__main__":
    demo.launch(share=True)