# BankChurn3 / app.py
# Uploaded by XRachel ("Upload 7 files"), commit 5d3fcf2 (verified).
import os
import subprocess
import sys

import gradio as gr
import joblib
import pandas as pd
# Relative path (resolved against the working directory) of the serialized
# pipeline; presumably produced by scripts/pipeline.py — confirm.
MODEL_PATH = "models/pipeline.joblib"
def load_model(path=None):
    """Load the trained pipeline from disk if the file exists.

    Args:
        path: File to load. When ``None`` (the default), ``MODEL_PATH`` is used,
            so existing zero-argument callers behave exactly as before.

    Returns:
        The object returned by ``joblib.load``, or ``None`` when no file exists
        at ``path`` (for example, before the pipeline has been run).
    """
    if path is None:
        path = MODEL_PATH
    if not os.path.exists(path):
        return None
    # joblib.load unpickles the file, so only point this at trusted artifacts.
    return joblib.load(path)
# Module-level model handle, loaded once at import time; it is None when no
# trained pipeline file exists at that moment.
model = load_model()
def predict(age, balance):
    """Score a single customer with the trained pipeline.

    Args:
        age: Value from the "Age" number input (``None`` if the field is empty).
        balance: Value from the "Balance" number input (``None`` if empty).

    Returns:
        ``"Prediction: <label>"`` on success, otherwise a human-readable message
        explaining why no prediction could be made.
    """
    global model
    if model is None:
        # The model file may have been created (e.g. by running the pipeline)
        # after this module was imported, so retry loading it before giving up.
        model = load_model()
    if model is None:
        return "Model not trained yet. Run pipeline first."
    if age is None or balance is None:
        return "Please enter both Age and Balance."
    # One-row frame; the column names presumably match the features the
    # pipeline was trained on in scripts/pipeline.py — confirm.
    df = pd.DataFrame([[age, balance]], columns=["Age", "Balance"])
    p = model.predict(df)[0]
    return f"Prediction: {p}"
def run_pipeline():
    """Run ``scripts/pipeline.py`` and stream its combined stdout/stderr.

    This is a generator: each ``yield`` is the full log accumulated so far.
    After the script exits, a final status line is appended. On a zero exit
    code the global ``model`` is reloaded from disk, so predictions use the
    freshly trained pipeline without restarting the app.
    """
    global model
    # sys.executable runs the script with the same interpreter (and packages)
    # as this app; "-u" disables the child's stdout buffering so lines arrive
    # here as they are printed instead of all at once when the script exits.
    cmd = [sys.executable, "-u", "scripts/pipeline.py"]
    log = ""
    # The context manager closes the pipe and waits for the child on exit,
    # so no zombie process or open file descriptor is left behind.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    ) as proc:
        for line in proc.stdout:
            log += line
            yield log
    if proc.returncode == 0:
        model = load_model()
        if model is not None:
            log += "\n[pipeline finished successfully; model reloaded]\n"
        else:
            log += f"\n[pipeline finished but no model found at {MODEL_PATH}]\n"
    else:
        log += f"\n[pipeline failed with exit code {proc.returncode}]\n"
    yield log
def build_ui():
    """Build the Gradio app with a "Pipeline" tab and a "Prediction" tab.

    Custom CSS is read from ``style.css`` (relative to the working directory)
    and injected through an inline ``<style>`` element. If the file is
    missing, the app is built with default styling instead of failing.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    try:
        with open("style.css", encoding="utf-8") as fh:
            css = fh.read()
    except FileNotFoundError:
        css = ""
    with gr.Blocks() as demo:
        if css:
            gr.HTML(f"<style>{css}</style>")
        gr.Markdown("# Bank Churn Demo")
        with gr.Tab("Pipeline"):
            btn = gr.Button("Run Pipeline")
            log = gr.Textbox(lines=20, label="Execution Log")
            # run_pipeline is a generator, so each yielded log replaces the
            # textbox contents as the script produces output.
            btn.click(run_pipeline, outputs=log)
        with gr.Tab("Prediction"):
            age = gr.Number(label="Age")
            balance = gr.Number(label="Balance")
            btn2 = gr.Button("Predict")
            out = gr.Textbox()
            btn2.click(predict, [age, balance], out)
    return demo
if __name__ == "__main__":
    # Build the interface and enable request queuing (run_pipeline is a
    # generator whose output is streamed to the UI).
    demo = build_ui()
    demo.queue()
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": int(os.environ.get("PORT", 7860)),
        "show_api": False,
    }
    demo.launch(**launch_options)