# Mantles2 / app.py
# Last change by Sirivennela: "Update app.py" (commit a556612, verified)
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from risk_model import predict_risk
from datetime import datetime
import pytz
# Running log of every prediction made so far, one row per "Predict" click.
# NOTE(review): this is a module-level global that the handlers rebind via
# ``global``, so it is presumably shared by every browser session served by
# this process; per-user history would need ``gr.State`` instead.
history_df = pd.DataFrame(
    columns=[
        "Max_Temperature",
        "Duration",
        "Risk_Level",
        "Confidence",
        "Timestamp",
    ]
)
def get_interpretation(risk_level):
    """Return a one-line, human-readable reading of ``risk_level``.

    "High" and "Moderate" map to warning messages. Every other value falls
    through to the low-risk message. That includes "Low", but also any
    unrecognised label or ``None``.
    """
    # Checked in order; the first label that compares equal wins.
    for label, message in (
        ("High", "⚠️ High risk of overheating! Immediate action recommended."),
        ("Moderate", "⚠️ Moderate risk detected. Monitor closely."),
    ):
        if risk_level == label:
            return message
    return "Low risk. Operation is within safe limits."
def get_alert(risk_level):
    """Return the short alert banner text for ``risk_level``.

    "High" yields an alert and "Moderate" a warning. Every other value,
    including unrecognised labels, is reported as "no alert".
    """
    return (
        "ALERT: High risk detected!" if risk_level == "High"
        else "Warning: Moderate risk detected." if risk_level == "Moderate"
        else "No alert. Risk level is low."
    )
def get_bullet_points(risk_level, confidence):
    """Build the multi-line summary shown in the "Summary" textbox.

    Args:
        risk_level: Risk label as produced by the model.
        confidence: Confidence value, rendered with a trailing ``%``.

    Returns:
        Four lines joined with newlines, each prefixed with a bullet. They hold
        the level, the confidence, the alert text and the interpretation text.
    """
    bodies = (
        f"Risk Level: {risk_level}",
        f"Confidence: {confidence}%",
        get_alert(risk_level),
        get_interpretation(risk_level),
    )
    return "\n".join("• " + body for body in bodies)
def gradio_predict(temp, duration):
    """Score one (temperature, duration) reading and append it to the history.

    Args:
        temp: Maximum temperature in °C, as entered in the UI.
        duration: Duration in minutes, as entered in the UI.

    Returns:
        A 5-tuple matching the outputs wired to the "Predict" button:
        ``(risk_label, confidence rounded to 1 decimal, summary text,
        updated history DataFrame, chart figure or None)``.

    Raises:
        gr.Error: If either input is empty. Gradio passes ``None`` for a
            cleared ``gr.Number`` field.
    """
    global history_df
    # Surface missing input in the UI instead of handing None to the model.
    if temp is None or duration is None:
        raise gr.Error("Please enter both a maximum temperature and a duration.")

    risk_label, confidence = predict_risk(temp, duration)
    # Round once so the history row, the score box and the summary all agree.
    confidence = round(confidence, 1)
    # Timestamps are recorded in India Standard Time (Asia/Kolkata).
    timestamp = datetime.now(pytz.timezone("Asia/Kolkata")).strftime("%Y-%m-%d %H:%M:%S")

    new_row = pd.DataFrame([{
        "Max_Temperature": temp,
        "Duration": duration,
        "Risk_Level": risk_label,
        "Confidence": confidence,
        "Timestamp": timestamp,
    }])
    # Concatenating onto the initial empty (all-object) frame is deprecated in
    # pandas >= 2.1 (FutureWarning). Under the new behaviour every column would
    # become object dtype, so start the history from the first row instead.
    if history_df.empty:
        history_df = new_row
    else:
        history_df = pd.concat([history_df, new_row], ignore_index=True)

    chart = create_chart(history_df)
    bullets = get_bullet_points(risk_label, confidence)
    return risk_label, confidence, bullets, history_df, chart
def reset_all():
    """Clear the prediction history and blank every result component.

    Returns:
        A 5-tuple matching the outputs wired to the "Reset" button:
        - empty strings for the status, score and summary boxes;
        - the freshly emptied history DataFrame;
        - ``None`` to clear the chart.
    """
    global history_df
    history_df = pd.DataFrame(
        columns=["Max_Temperature", "Duration", "Risk_Level", "Confidence", "Timestamp"]
    )
    # Return the emptied frame rather than a bare ``pd.DataFrame()``, so the
    # history table keeps its column headers after a reset.
    return "", "", "", history_df, None
def create_chart(df):
    """Draw a bar chart of how many predictions fell into each risk level.

    Args:
        df: History frame with a ``Risk_Level`` column.

    Returns:
        A matplotlib Figure with one bar each for "Low", "Moderate" and "High",
        always in that order. Each bar is annotated with its count. Returns
        ``None`` when ``df`` has no rows. Labels other than those three are
        not shown.
    """
    if df.empty:
        return None

    # Reindex so every level gets a bar (count 0 if unseen), in severity order.
    counts = df["Risk_Level"].value_counts().reindex(["Low", "Moderate", "High"], fill_value=0)

    fig, ax = plt.subplots(figsize=(6, 4))
    bars = ax.bar(counts.index, counts.values, color=["#2ca02c", "#ff7f0e", "#d62728"])
    ax.set_xlabel("Risk Level")
    ax.set_ylabel("Number of Predictions")
    ax.set_title("Prediction Counts by Risk Level")
    for bar in bars:
        height = bar.get_height()
        # Print the count just above the top of each bar.
        ax.annotate(f"{int(height)}",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center", va="bottom")
    # Lay out this figure explicitly. Under concurrent requests, pyplot's
    # global "current figure" may not be ``fig``.
    fig.tight_layout()
    # Drop the figure from pyplot's global registry. Otherwise every request
    # leaks a figure in this long-running server. The Figure object itself
    # stays usable (e.g. ``fig.savefig``) for the caller.
    plt.close(fig)
    return fig
with gr.Blocks() as demo:
    # Components render in creation order, so the sequence below is the page layout.
    reset_btn = gr.Button("Reset")
    with gr.Row():
        temp_input = gr.Number(
            label="Max Temperature (°C)", value=25, precision=1, minimum=0
        )
        duration_input = gr.Number(
            label="Duration (minutes)", value=10, precision=1, minimum=0
        )
    predict_btn = gr.Button("Predict")

    risk_status = gr.Textbox(label="Risk Status", interactive=False)
    risk_score = gr.Textbox(label="Risk Score (%)", interactive=False)
    bullet_points = gr.Textbox(label="Summary", interactive=False, lines=5)
    history_table = gr.Dataframe(label="Prediction History")
    risk_chart = gr.Plot(label="Risk Level Chart")

    # Predict and Reset both refresh the same five result components, in this
    # order; each handler returns a 5-tuple matching it.
    result_outputs = [risk_status, risk_score, bullet_points, history_table, risk_chart]
    predict_btn.click(
        gradio_predict,
        inputs=[temp_input, duration_input],
        outputs=result_outputs,
    )
    reset_btn.click(reset_all, inputs=[], outputs=result_outputs)

if __name__ == "__main__":
    demo.launch()