# Author: Abs6187
# Last commit: "Update app.py" (8809fef, verified)
import gradio as gr
import joblib
import pandas as pd
import numpy as np
import os
import json
from PIL import Image
import google.generativeai as genai
# --- Gemini API Configuration ---
# The key is read from the environment so it never has to be committed to source.
# GOOGLE_API_KEY is left as None/"" when unset; predict_churn() checks it and
# refuses to call Gemini in that case.
try:
    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
    if not GOOGLE_API_KEY:
        print("Warning: GOOGLE_API_KEY environment variable not found. The application will not be able to predict churn.")
    else:
        genai.configure(api_key=GOOGLE_API_KEY)
except Exception as err:
    # Best-effort: a configuration failure is reported but does not stop the UI from starting.
    print(f"Error configuring Gemini API: {err}")
# --- Legacy model assets ---
# Loaded for reference only; the Gemini-based predict_churn() does not use them.
def load_models():
    """Load the legacy churn model, label encoders and scaler with joblib.

    The three files are read in order from ``models/churn_model.pkl``,
    ``models/label_encoders.pkl`` and ``models/scaler.pkl``.

    Returns:
        A ``(model, encoders, scaler)`` tuple, or ``(None, None, None)`` when a
        file is missing or loading fails for any other reason (a message is
        printed in both cases).
    """
    unavailable = (None, None, None)
    try:
        loaded = tuple(
            joblib.load(f"models/{asset}.pkl")
            for asset in ("churn_model", "label_encoders", "scaler")
        )
    except FileNotFoundError:
        print("Warning: Original model files not found in 'models/' directory.")
        return unavailable
    except Exception as exc:
        print(f"An unexpected error occurred while loading local models: {exc}")
        return unavailable
    return loaded


model, encoders, scaler = load_models()
def load_image(image_name, folder="images"):
    """Load an image for display in the UI, fully read into memory.

    Args:
        image_name: File name, relative to *folder*.
        folder: Directory containing the chart images; defaults to ``"images"``.

    Returns:
        An independent in-memory ``PIL.Image.Image``. Returns ``None`` if the
        file does not exist, or if it cannot be opened or decoded; that second
        case is also printed.
    """
    try:
        img_path = os.path.join(folder, image_name)
        # Image.open is lazy and keeps the file handle open until the pixels are
        # read. copy() forces the load and returns a detached image, so the
        # context manager can close the file right away.
        with Image.open(img_path) as img:
            return img.copy()
    except FileNotFoundError:
        # A missing chart is expected; the caller shows a placeholder instead.
        return None
    except Exception as e:
        print(f"Error loading image {image_name}: {e}")
        return None
# --- Feature Constants ---
# Choices offered by the Gradio dropdowns. The selected value is forwarded
# unchanged into the Gemini prompt built by predict_churn().
REGIONS = [
    'North',
    'South',
    'East',
    'West',
    'Central',
]
PLAN_TYPES = [
    'Prepaid',
    'Postpaid',
]
CONTRACT_TYPES = [
    'Month-to-month',
    'One year',
    'Two year',
]
COMPLAINT_STATUS = [
    'Open',
    'Closed',
    'Not Applicable',
]
PAYMENT_METHODS = [
    'Electronic check',
    'Mailed check',
    'Bank transfer',
    'Credit card',
]
# --- Prediction Logic using Gemini API ---
def _build_churn_prompt(customer_id, region, plan_type, monthly_charges, total_charges,
                        tenure_months, contract_type, paperless_billing, payment_method,
                        data_usage_gb, call_minutes, sms_count, complaint_status,
                        complaint_count):
    """Build the Gemini prompt: customer profile, risk guidelines and the required JSON reply format."""
    return f"""
You are an expert telecom churn prediction analyst AI. Your analysis must be sharp, concise, and business-focused.
Based on the following customer data, predict the churn probability.
**Customer Data:**
- **Customer ID:** {customer_id}
- **Region:** {region}
- **Plan Type:** {plan_type}
- **Contract Type:** {contract_type}
- **Tenure:** {tenure_months} months
- **Monthly Charges:** ₹{monthly_charges}
- **Total Charges:** ₹{total_charges}
- **Data Usage:** {data_usage_gb} GB/month
- **Call Minutes:** {call_minutes} mins/month
- **SMS Count:** {sms_count} texts/month
- **Payment Method:** {payment_method}
- **Paperless Billing:** {'Yes' if paperless_billing else 'No'}
- **Last Complaint Status:** {complaint_status}
- **Total Complaint Count:** {complaint_count}
**Analysis Guidelines:**
- **High Risk Factors:** Month-to-month contracts, short tenure (especially < 12 months), high monthly charges relative to tenure, and any open complaints are strong indicators of churn.
- **Low Risk Factors:** Long-term contracts (One year, Two year), long tenure, a high ratio of total charges to monthly charges, and no complaint history suggest a loyal customer.
Return your response **only** as a single, valid JSON object with two keys:
1. "churn_probability": a float number between 0.0 and 1.0.
2. "analysis": a brief, professional, single-paragraph analysis explaining the key factors that influenced your prediction for this specific customer.
**Example JSON output:**
{{
"churn_probability": 0.72,
"analysis": "This customer is at a high risk of churning primarily due to their month-to-month contract, which offers little long-term commitment. Combined with a relatively short tenure of 12 months, they are susceptible to competitive offers. Proactive engagement with a long-term plan incentive is recommended."
}}
"""


def _parse_churn_response(response_text):
    """Extract ``(churn_probability, analysis)`` from Gemini's JSON reply.

    Raises:
        json.JSONDecodeError: no valid JSON could be decoded.
        KeyError: a required key is missing.
        TypeError: the reply is not a JSON object, or the probability is not numeric.
        ValueError: the probability is not a number in [0.0, 1.0] (NaN included).
    """
    # The model may wrap the JSON in a ```json fence or add prose around it, so
    # parse only the outermost {...} span rather than string-replacing fences.
    start = response_text.find("{")
    end = response_text.rfind("}")
    payload = response_text[start:end + 1] if 0 <= start < end else response_text
    parsed = json.loads(payload)
    churn_probability = float(parsed["churn_probability"])
    # Reject values such as 72 (a percentage) or NaN, which fails every
    # comparison, instead of displaying a nonsensical probability.
    if not 0.0 <= churn_probability <= 1.0:
        raise ValueError(f"churn_probability out of range: {churn_probability}")
    return churn_probability, parsed["analysis"]


def _risk_level(churn_probability):
    """Return the UI risk label: HIGH above 0.7, MEDIUM above 0.4, otherwise LOW."""
    if churn_probability > 0.7:
        return "🔴 HIGH RISK"
    if churn_probability > 0.4:
        return "🟡 MEDIUM RISK"
    return "🟢 LOW RISK"


def predict_churn(customer_id, region, plan_type, monthly_charges, total_charges,
                  tenure_months, contract_type, paperless_billing, payment_method,
                  data_usage_gb, call_minutes, sms_count, complaint_status, complaint_count):
    """Predict customer churn with Gemini and format the result for the UI.

    Args:
        customer_id, region, plan_type, monthly_charges, total_charges,
        tenure_months, contract_type, paperless_billing, payment_method,
        data_usage_gb, call_minutes, sms_count, complaint_status,
        complaint_count: Customer attributes collected by the Gradio form,
            in the same order as the button's ``inputs`` list.

    Returns:
        A ``(markdown, churn_probability)`` tuple. On success the probability
        is in [0.0, 1.0]. On any failure the markdown is an error message and
        the probability is 0.0.
    """
    if not GOOGLE_API_KEY:
        return "🔴 **Error:** `GOOGLE_API_KEY` not configured. Please set it as an environment variable to enable predictions.", 0.0

    try:
        prompt = _build_churn_prompt(customer_id, region, plan_type, monthly_charges,
                                     total_charges, tenure_months, contract_type,
                                     paperless_billing, payment_method, data_usage_gb,
                                     call_minutes, sms_count, complaint_status,
                                     complaint_count)
        gemini_model = genai.GenerativeModel('gemini-2.5-flash')
        # .text is read inside the try so any SDK error is reported here too.
        response_text = gemini_model.generate_content(prompt).text
    except Exception as e:
        print(f"An unexpected error occurred during Gemini prediction: {e}")
        return "An unexpected error occurred while contacting the AI model. Please check the logs and try again.", 0.0

    try:
        churn_probability, analysis = _parse_churn_response(response_text)
    except (KeyError, TypeError, ValueError) as e:  # json.JSONDecodeError is a ValueError
        # A malformed reply is a format problem, not a connectivity problem.
        print(f"JSON Parsing Error: {e}")
        print(f"Gemini Response that caused the error: {response_text}")
        return "🔴 **Error:** Could not parse the analysis from the AI model. The response was not in the expected format.", 0.0

    risk_level = _risk_level(churn_probability)
    result = f"""
### Prediction for Customer `{customer_id}`
- **Churn Risk Level:** **{risk_level}**
- **Probability of Churn:** **{churn_probability:.1%}**
- **AI-Powered Analysis:** {analysis}
"""
    return result, churn_probability
# --- Gradio UI ---
with gr.Blocks(title="Telecom Churn Prediction - BRBRAITT Group 5", theme=gr.themes.Soft()) as app:
    # Page header shown above the tabs.
    gr.Markdown("""
# 🔮 Telecom Churn Prediction System
**TIRTC Course: Advance AI/ML Training (Nokia) | Institution: BRBRAITT, Jabalpur | Group 5**
---
This AI-powered system predicts customer churn using Google's Gemini model for real-time analysis.
""")
    with gr.Tabs():
        # Tab 1: Prediction Interface (input form on the left, result panel on the right)
        with gr.TabItem("🎯 Churn Prediction"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### Enter Customer Details")
                    with gr.Row():
                        customer_id = gr.Textbox(label="Customer ID", value="CUST-001")
                        region = gr.Dropdown(choices=REGIONS, label="Region", value="North")
                        plan_type = gr.Dropdown(choices=PLAN_TYPES, label="Plan Type", value="Postpaid")
                    with gr.Row():
                        contract_type = gr.Dropdown(choices=CONTRACT_TYPES, label="Contract Type", value="Month-to-month")
                        payment_method = gr.Dropdown(choices=PAYMENT_METHODS, label="Payment Method", value="Electronic check")
                        paperless_billing = gr.Checkbox(label="Paperless Billing", value=True)
                    gr.Markdown("#### Service Usage & Charges")
                    with gr.Row():
                        monthly_charges = gr.Number(label="Monthly Charges (₹)", value=1000)
                        total_charges = gr.Number(label="Total Charges (₹)", value=12000)
                        tenure_months = gr.Number(label="Tenure (Months)", value=12)
                    with gr.Row():
                        data_usage_gb = gr.Number(label="Data Usage (GB)", value=15)
                        call_minutes = gr.Number(label="Call Minutes", value=500)
                        sms_count = gr.Number(label="SMS Count", value=100)
                    gr.Markdown("#### Customer Complaints")
                    with gr.Row():
                        complaint_status = gr.Dropdown(choices=COMPLAINT_STATUS, label="Last Complaint Status", value="Not Applicable")
                        complaint_count = gr.Number(label="Total Complaint Count", value=0)
                    predict_btn = gr.Button("🔮 Predict Churn Risk", variant="primary", size="lg")
                with gr.Column(scale=1):
                    gr.Markdown("### 📊 Prediction Result")
                    prediction_output = gr.Markdown(value="*Results will be displayed here...*")
                    probability_gauge = gr.Number(label="Churn Probability", value=0.0, interactive=False)
            # The inputs are passed positionally, so this order must match
            # predict_churn's parameter order.
            predict_btn.click(
                fn=predict_churn,
                inputs=[customer_id, region, plan_type, monthly_charges, total_charges,
                        tenure_months, contract_type, paperless_billing, payment_method,
                        data_usage_gb, call_minutes, sms_count, complaint_status, complaint_count],
                outputs=[prediction_output, probability_gauge],
            )
        # Tab 2: Business Insights (static findings)
        with gr.TabItem("💡 Business Insights"):
            gr.Markdown("### Key Findings & Recommendations")
            gr.Markdown("""
#### 🎯 Model Performance
*These metrics describe the original Random Forest model; live predictions in the first tab come from Gemini.*
- **Accuracy:** 90%
- **AUC Score:** 0.95
- **Best Algorithm:** Random Forest Classifier
#### 💼 Business Impact
- **Monthly Revenue at Risk:** Over ₹12,250.
- **Potential Annual Loss:** Over ₹147,000.
#### 🔴 Top Churn Drivers
1. **Contract Type:** `Month-to-month` customers are at highest risk.
2. **Tenure:** New customers (0-12 months) are most likely to churn.
3. **Complaints:** A single open complaint doubles the likelihood of churn.
""")
        # Tab 3: Visualizations (pre-rendered charts from the images/ folder)
        with gr.TabItem("📈 Visualizations"):
            gr.Markdown("### Data Analysis Dashboard")
            image_files = [
                ("churn_distribution.png", "Overall Churn Distribution"),
                ("churn_by_contract.png", "Churn by Contract Type"),
                ("revenue_vs_churn.png", "Revenue Impact Analysis"),
                ("complaints_analysis.png", "Complaints Impact on Churn"),
                ("correlation_matrix.png", "Feature Correlation Matrix"),
            ]
            for img_file, title in image_files:
                img = load_image(img_file)
                # load_image returns None for a missing or unreadable chart; show a placeholder instead.
                if img is not None:
                    gr.Image(img, label=title, show_label=True)
                else:
                    gr.Markdown(f"*{title} - Image not available*")
        # Tab 4: About Project
        with gr.TabItem("ℹ️ About"):
            gr.Markdown("""
### 🎓 Academic Project Details
- **Course:** TIRTC - Advance AI/ML Training (Nokia)
- **Institution:** BRBRAITT, Jabalpur
- **Team (Group 5):** Abhay Gupta, Jay Kumar, Kripanshu Gupta, Ruhy Namdeo
- **Tech Stack:** Scikit-learn, Pandas, Gradio, Gemini, Hugging Face
---
**🏆 Project Status:** Complete | **📅 Last Updated:** October 2025 | **🔢 Version:** 1.3.0
""")
    gr.Markdown("--- \n © 2025 BRBRAITT Group 5 | TIRTC Advance AI/ML Training")

# Launch the app; share=True also asks Gradio for a public share link.
if __name__ == "__main__":
    app.launch(share=True)