# tetanus / src/app.py
# sowmyaiyer21's picture
# Upload folder using huggingface_hub
# fe3fd9c verified
import os
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# ====== Page Configuration ======
# Browser-tab title and icon, a full-width layout, and the sidebar shown open
# when the page first loads.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Tetanus Risk Classifier",
    page_icon="🩺",
)
# ====== Custom CSS for Modern UI ======
# The whole stylesheet is emitted as one raw <style> element.
# unsafe_allow_html=True is needed because st.markdown escapes HTML by default.
_CUSTOM_CSS = """
<style>
/* Import Google Fonts */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
/* Global Styling */
.main {
font-family: 'Inter', sans-serif;
background: linear-gradient(135deg, #fffaf0 0%, #fdf6e3 100%);
min-height: 100vh;
}
.stApp {
background: #fefcf7;
color: #3a3a3a;
}
/* Header Styling */
.main-title {
font-size: 3rem;
font-weight: 700;
text-align: center;
color: #2b2b2b;
margin-bottom: 0.5rem;
text-shadow: 1px 1px 2px rgba(0,0,0,0.1);
}
.sub-title {
font-size: 1.2rem;
text-align: center;
color: #7a6a4f;
margin-bottom: 3rem;
font-weight: 400;
}
/* Card Styling */
.custom-card {
background: #fffdf8;
border-radius: 16px;
padding: 2rem;
box-shadow: 0 6px 18px rgba(0,0,0,0.08);
border: 1px solid #f1e7d0;
margin-bottom: 2rem;
}
.upload-card {
background: #fffef9;
border-radius: 16px;
padding: 2rem;
text-align: center;
border: 2px dashed #e0d6b8;
transition: all 0.3s ease;
margin: 1rem 0;
}
.upload-card:hover {
border-color: #b08968;
transform: translateY(-2px);
box-shadow: 0 12px 25px rgba(0,0,0,0.1);
}
/* Risk Level Indicators */
.risk-badge-high {
background: #fbe9e7;
color: #c62828;
padding: 1rem 2rem;
border-radius: 12px;
text-align: center;
font-size: 1.2rem;
font-weight: 700;
margin: 1rem 0;
border: 1px solid #ef9a9a;
}
.risk-badge-mid {
background: #fff8e1;
color: #b37400;
padding: 1rem 2rem;
border-radius: 12px;
text-align: center;
font-size: 1.2rem;
font-weight: 700;
margin: 1rem 0;
border: 1px solid #ffd54f;
}
.risk-badge-low {
background: #f1fbe9;
color: #2e7d32;
padding: 1rem 2rem;
border-radius: 12px;
text-align: center;
font-size: 1.2rem;
font-weight: 700;
margin: 1rem 0;
border: 1px solid #a5d6a7;
}
/* Section Headers */
.section-header {
font-size: 1.5rem;
font-weight: 700;
color: #5c4d36;
margin: 2rem 0 1rem 0;
padding-bottom: 0.5rem;
border-bottom: 2px solid #e0d6b8;
text-align: center;
}
/* Metrics Styling */
.metric-container {
background: #fffdf6;
border-radius: 12px;
padding: 1.2rem;
text-align: center;
border: 1px solid #e7dbc2;
margin: 1rem 0;
color: #3a3a3a;
}
/* Recommendations */
.recommendation-box {
padding: 1.5rem;
margin: 1.5rem 0;
border-radius: 12px;
border-left: 5px solid;
background: #fffdf9;
box-shadow: 0 6px 12px rgba(0,0,0,0.05);
}
.recommendation-high {
border-left-color: #c62828;
}
.recommendation-mid {
border-left-color: #b37400;
}
.recommendation-low {
border-left-color: #2e7d32;
}
/* Sidebar Styling */
.sidebar .sidebar-content {
background: #fffef9;
border-radius: 12px;
padding: 1rem;
margin: 0.5rem 0;
box-shadow: 0 4px 8px rgba(0,0,0,0.05);
border: 1px solid #f1e7d0;
}
/* Hide Streamlit branding */
.stDeployButton, footer {
display: none !important;
}
/* Custom info boxes */
.info-box {
background: #fffdf6;
border-radius: 10px;
padding: 1.2rem;
margin: 1rem 0;
border-left: 4px solid #b08968;
color: #3a3a3a;
}
.info-title {
font-weight: 700;
color: #7a6a4f;
font-size: 1.1rem;
margin-bottom: 0.8rem;
}
/* Progress bars */
.stProgress > div > div > div > div {
background: linear-gradient(90deg, #b08968, #d4a373);
border-radius: 6px;
}
/* Upload button styling */
.stFileUploader label {
background: #f1e3cf !important;
color: #3a3a3a !important;
border-radius: 10px !important;
border: 1px solid #d9c9a8 !important;
padding: 0.8rem 1.5rem !important;
font-weight: 600 !important;
transition: all 0.3s ease !important;
}
.stFileUploader label:hover {
background: #e6d3b3 !important;
transform: translateY(-2px) !important;
box-shadow: 0 6px 12px rgba(0,0,0,0.15) !important;
}
.stAlert div {
color: black !important;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ====== Main Title ======
# Page heading and tagline. They pick up the .main-title / .sub-title rules
# from the injected stylesheet.
_TITLE_HTML = '<h1 class="main-title">Tetanus Risk Assessment System</h1>'
_SUBTITLE_HTML = '<p class="sub-title">AI-powered medical imaging analysis for tetanus risk evaluation</p>'
for _heading_html in (_TITLE_HTML, _SUBTITLE_HTML):
    st.markdown(_heading_html, unsafe_allow_html=True)
# ====== Enhanced Sidebar Configuration ======
# Builds the sidebar. It also defines the module-level ``model_path`` that
# main() reads when loading the model.
with st.sidebar:
    st.markdown('<div class="sidebar-content">', unsafe_allow_html=True)
    st.markdown("## Configuration")
    # Path of the saved Keras model; main() passes it to load_tetanus_model().
    model_path = st.text_input(
        "Model File Path",
        value="final_tetanus_model.keras",
        help="Enter the path to your trained .keras model file"
    )
    st.markdown("---")
    # Legend: one coloured marker per category. The colours are the same
    # red / amber / green used by the probability chart and breakdown, so the
    # markers actually distinguish the categories (they were all uncoloured).
    st.markdown("## Risk Categories")
    marker_col, label_col = st.columns([1, 3])
    with marker_col:
        for dot_color in ("#ef4444", "#f59e0b", "#10b981"):
            st.markdown(f'<span style="color: {dot_color};">●</span>', unsafe_allow_html=True)
    with label_col:
        st.markdown("**High Risk** - Immediate medical attention")
        st.markdown("**Moderate Risk** - Clinical evaluation needed")
        st.markdown("**Low Risk** - Standard wound care")
    st.markdown("---")
    # Blank lines before each heading stop Markdown from folding the heading
    # into the last bullet of the previous list (lazy continuation).
    with st.expander("Detailed Risk Information"):
        st.markdown("""
**High Risk Indicators:**
- Deep puncture wounds
- Contaminated wounds
- Foreign object presence
- Rusty metal exposure

**Moderate Risk Indicators:**
- Minor cuts with debris
- Moderate depth wounds
- Delayed treatment (>6 hours)
- Animal bites

**Low Risk Indicators:**
- Superficial cuts
- Clean wounds
- Fresh injuries (<1 hour)
- Proper wound cleaning
""")
    st.markdown("---")
    # NOTE(review): these status lines are static text. They render before
    # main() tries to load the model and show "Ready" even when loading fails.
    # The accuracy claim is not backed by anything in this file; confirm it
    # or reword it.
    st.markdown("## System Info")
    st.info("**Model Status:** Ready for analysis")
    st.info("**Processing:** Real-time inference")
    st.info("**Accuracy:** Clinical-grade assessment")
    st.markdown('</div>', unsafe_allow_html=True)
# ====== Model Loading Function ======
@st.cache_resource
def _load_model_cached(model_path):
    """Load and cache the Keras model stored at *model_path*.

    Failures propagate as exceptions on purpose. Streamlit does not cache a
    call that raises, so a failed load is retried on the next rerun instead
    of being remembered.
    """
    return load_model(model_path)


def load_tetanus_model(model_path):
    """Load the trained model, reporting problems as a message instead of raising.

    Args:
        model_path: Filesystem path to the saved model (e.g. a ``.keras`` file).

    Returns:
        ``(model, None)`` on success. ``(None, message)`` when the path does
        not exist or the Keras loader raises.
    """
    # Checked outside the cache, so a model file added after a failed attempt
    # is picked up on the next rerun. Previously the ``(None, error)`` tuple
    # itself was cached.
    if not os.path.exists(model_path):
        return None, f"Model file not found at: {model_path}"
    try:
        return _load_model_cached(model_path), None
    except Exception as e:  # UI boundary: show any loader failure to the user
        return None, f"Error loading model: {str(e)}"
# ====== Enhanced Image Preprocessing ======
def preprocess_image(img):
    """Turn a PIL image into a normalised single-image batch for the classifier.

    Args:
        img: PIL ``Image``. It is left untouched; ``convert``/``resize``
            return new images.

    Returns:
        ``(batch, original_size)``.
        - ``batch`` is the 224x224 RGB image as a float array with a leading
          batch axis of length 1, scaled from 0-255 to 0-1. This assumes
          Keras' default channels_last layout, i.e. (1, 224, 224, 3).
        - ``original_size`` is the ``(width, height)`` before resizing, kept
          for display.
    """
    rgb = img if img.mode == 'RGB' else img.convert('RGB')
    original_size = rgb.size
    pixels = image.img_to_array(rgb.resize((224, 224))) / 255.0
    return np.expand_dims(pixels, axis=0), original_size
# ====== Enhanced Prediction Function ======
def make_prediction(model, img_array):
    """Classify one preprocessed image and summarise the scores as percentages.

    Args:
        model: Object with a Keras-style ``predict(x, verbose=0)`` method.
        img_array: Preprocessed batch containing a single image.

    Returns:
        ``(label, confidence, probabilities, error)``.
        On success, ``error`` is ``None``:
        - ``label`` is the category at the arg-max score.
        - ``confidence`` is that score x 100.
        - ``probabilities`` is the first output row x 100.
        If anything raises, the first three are ``None`` and ``error`` holds
        a message string.
    """
    # NOTE(review): this index -> label order must match the class indices
    # used in training. Alphabetical directory ordering (e.g. Keras
    # image_dataset_from_directory) would be High, Low, Mid. Confirm.
    labels = ['High Risk', 'Mid Risk', 'Low Risk']
    try:
        scores = model.predict(img_array, verbose=0)
        best = np.argmax(scores)
        first_row = scores[0]
        return labels[best], first_row[best] * 100, first_row * 100, None
    except Exception as e:
        return None, None, None, f"Error making prediction: {str(e)}"
# ====== Enhanced Visualization Functions ======
def create_confidence_chart(confidence):
    """Build a Plotly gauge for the predicted class's confidence.

    Args:
        confidence: Confidence of the predicted class, in percent.

    Returns:
        A ``go.Figure`` holding one gauge indicator. The displayed delta is
        relative to 80, and a red threshold marker sits at 90.
    """
    # Background bands: light red below 50, amber 50-80, light green 80-100.
    bands = [
        {'range': [0, 50], 'color': "#fee2e2"},
        {'range': [50, 80], 'color': "#fef3c7"},
        {'range': [80, 100], 'color': "#d1fae5"},
    ]
    gauge_spec = {
        'axis': {'range': [None, 100]},
        'bar': {'color': "#4f46e5"},
        'steps': bands,
        'threshold': {
            'line': {'color': "red", 'width': 4},
            'thickness': 0.75,
            'value': 90,
        },
    }
    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=confidence,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Confidence Level"},
        delta={'reference': 80},
        gauge=gauge_spec,
    )
    figure = go.Figure(indicator)
    # Transparent backgrounds so the gauge sits on the page's own colours.
    figure.update_layout(
        height=300,
        font={'color': "#4f46e5", 'family': "Inter"},
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )
    return figure
def create_probability_chart(probabilities, categories):
    """Build a bar chart of per-category probabilities.

    Args:
        probabilities: Percent values, one per entry in ``categories``.
        categories: Bar labels for the x axis.

    Returns:
        A ``go.Figure`` with one bar per category. Each bar is annotated with
        its value to one decimal place.
    """
    # Red / amber / green, applied in order to the bars.
    palette = ['#ef4444', '#f59e0b', '#10b981']
    bar_labels = [f'{p:.1f}%' for p in probabilities]
    bars = go.Bar(
        x=categories,
        y=probabilities,
        marker_color=palette,
        text=bar_labels,
        textposition='auto',
    )
    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Risk Probability Distribution",
        xaxis_title="Risk Categories",
        yaxis_title="Probability (%)",
        font={'color': "#374151", 'family': "Inter"},
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        height=400,
    )
    return figure
# ====== Main Application ======
def _render_image_metadata(img, original_size, source_file):
    """Show the image's dimensions, file format and byte size in three metric columns.

    Args:
        img: PIL image opened from ``source_file``.
        original_size: ``(width, height)`` as returned by ``preprocess_image``.
        source_file: Streamlit upload/camera object; ``getvalue()`` returns its bytes.
    """
    col_meta1, col_meta2, col_meta3 = st.columns(3)
    with col_meta1:
        st.markdown('<div class="metric-container">', unsafe_allow_html=True)
        st.metric("Dimensions", f"{original_size[0]} × {original_size[1]}")
        st.markdown('</div>', unsafe_allow_html=True)
    with col_meta2:
        st.markdown('<div class="metric-container">', unsafe_allow_html=True)
        # PIL always defines ``format``, but it can be None (e.g. images not
        # read from a file). Fall back on the value rather than on hasattr().
        st.metric("Format", img.format or 'Unknown')
        st.markdown('</div>', unsafe_allow_html=True)
    with col_meta3:
        st.markdown('<div class="metric-container">', unsafe_allow_html=True)
        file_size = len(source_file.getvalue()) / 1024  # KB
        st.metric("Size", f"{file_size:.1f} KB")
        st.markdown('</div>', unsafe_allow_html=True)


def _render_risk_badge(predicted_label):
    """Show the coloured risk badge. Any label other than High/Mid renders as low."""
    if predicted_label == "High Risk":
        st.markdown('<div class="risk-badge-high">HIGH RISK DETECTED</div>', unsafe_allow_html=True)
    elif predicted_label == "Mid Risk":
        st.markdown('<div class="risk-badge-mid">MODERATE RISK DETECTED</div>', unsafe_allow_html=True)
    else:
        st.markdown('<div class="risk-badge-low">LOW RISK DETECTED</div>', unsafe_allow_html=True)


def _render_probability_analysis(all_probabilities):
    """Show the full-width probability bar chart and one percentage card per category."""
    categories = ['High Risk', 'Mid Risk', 'Low Risk']
    colors = ['#ef4444', '#f59e0b', '#10b981']
    st.markdown('<div class="custom-card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">Detailed Probability Analysis</h2>', unsafe_allow_html=True)
    prob_chart = create_probability_chart(all_probabilities, categories)
    st.plotly_chart(prob_chart, use_container_width=True)
    for col, category, color, prob in zip(st.columns(3), categories, colors, all_probabilities):
        with col:
            st.markdown(f"""
<div style="text-align: center; padding: 1rem; background: rgba(255,255,255,0.8); border-radius: 10px; margin: 0.5rem 0;">
<div style="width: 20px; height: 20px; background-color: {color}; border-radius: 50%; margin: 0 auto 0.5rem;"></div>
<div style="font-weight: 700; font-size: 1.2rem;">{category}</div>
<div style="font-size: 1.5rem; font-weight: 600; color: #4f46e5;">{prob:.1f}%</div>
</div>
""", unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)


def _render_recommendations(predicted_label):
    """Show the care recommendations for the label. Anything other than High/Mid gets the low-risk advice."""
    st.markdown('<div class="custom-card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">Medical Recommendations</h2>', unsafe_allow_html=True)
    if predicted_label == "High Risk":
        st.markdown("""
<div class="recommendation-box recommendation-high">
<h3 style="color: #dc2626; font-size: 1.5rem; margin-bottom: 1rem;">IMMEDIATE MEDICAL ATTENTION REQUIRED</h3>
<ul style="font-size: 1.1rem; line-height: 1.8;">
<li style="color:black;"><strong>Seek emergency medical care immediately</strong></li>
<li style="color:black;" >Do not delay professional treatment</li>
<li style="color:black;">Verify tetanus vaccination status with healthcare provider</li>
<li style="color:black;">Clean wound with sterile saline if available</li>
<li style="color:black;">Avoid home remedies - professional care is essential</li>
<li style="color:black;">Monitor for signs of infection or tetanus symptoms</li>
</ul>
</div>
""", unsafe_allow_html=True)
    elif predicted_label == "Mid Risk":
        st.markdown("""
<div class="recommendation-box recommendation-mid">
<h3 style="color: #d97706; font-size: 1.5rem; margin-bottom: 1rem;">CLINICAL EVALUATION RECOMMENDED</h3>
<ul style="font-size: 1.1rem; line-height: 1.8;">
<li style="color:black;"><strong>Clean wound thoroughly with soap and water</strong></li>
<li style="color:black;">Monitor for signs of infection (redness, swelling, warmth)</li>
<li style="color:black;">Consult healthcare provider within 24 hours</li>
<li style="color:black;">Update tetanus vaccination if necessary (>5 years)</li>
<li style="color:black;">Apply clean dressing and change regularly</li>
<li style="color:black;">Take photos to track healing progress</li>
</ul>
</div>
""", unsafe_allow_html=True)
    else:
        st.markdown("""
<div class="recommendation-box recommendation-low">
<h3 style="color: #059669; font-size: 1.5rem; margin-bottom: 1rem;">STANDARD WOUND CARE PROTOCOL</h3>
<ul style="font-size: 1.1rem; line-height: 1.8; color:black;">
<li style="color:black;"><strong>Clean wound gently with soap and water</strong></li>
<li style="color:black;">Apply antiseptic and clean bandage</li>
<li style="color:black;">Monitor for changes or infection signs</li>
<li style="color:black;">Keep wound clean and dry</li>
<li style="color:black;">Consider tetanus booster if >5 years since last vaccination</li>
<li style="color:black;">Follow up if wound doesn't heal properly</li>
</ul>
</div>
""", unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)


def _render_info_section():
    """Show the static "System Overview" and "Technical Specs" info boxes."""
    st.markdown("---")
    info_col1, info_col2 = st.columns(2)
    with info_col1:
        st.markdown("""
<div class="info-box">
<div class="info-title">System Overview</div>
<p><strong>AI Technology:</strong> Convolutional Neural Networks</p>
<p><strong>Processing:</strong> Real-time image analysis</p>
<p><strong>Classification:</strong> Three-tier risk assessment</p>
<p><strong>Guidelines:</strong> Evidence-based medical protocols</p>
</div>
""", unsafe_allow_html=True)
    with info_col2:
        st.markdown("""
<div class="info-box">
<div class="info-title">Technical Specs</div>
<p><strong>Model Architecture:</strong> Deep CNN</p>
<p><strong>Input Resolution:</strong> 224×224 pixels</p>
<p><strong>Framework:</strong> TensorFlow/Keras</p>
<p><strong>Inference Time:</strong> <2 seconds</p>
</div>
""", unsafe_allow_html=True)


def main():
    """Streamlit entry point: load the model, take an image, and show the assessment.

    Reads the module-level ``model_path`` set in the sidebar. Calls
    ``st.stop()`` if the model cannot be loaded or the prediction fails.
    """
    with st.spinner("Loading AI model..."):
        model, error = load_tetanus_model(model_path)
    if error:
        st.error(f"**Model Loading Error:** {error}")
        st.info("**Tip:** Please verify the model path in the sidebar configuration.")
        st.stop()
    st.info("**AI Model loaded successfully!** Ready for medical image analysis.")

    col1, col2 = st.columns([1.2, 1], gap="large")
    # img_array is set only when a readable image was supplied; predicted_label
    # only when a prediction succeeded. These replace the old
    # ``'predicted_label' in locals()`` check.
    img_array = None
    predicted_label = None
    all_probabilities = None

    with col1:
        st.markdown('<div class="custom-card">', unsafe_allow_html=True)
        st.markdown('<h2 class="section-header">Upload or Capture Medical Image</h2>', unsafe_allow_html=True)
        uploaded_file = st.file_uploader(
            "Upload Medical Image",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a clear, high-quality image of the wound for analysis",
            label_visibility="collapsed"
        )
        camera_file = st.camera_input(
            "Capture Medical Image",
            label_visibility="collapsed"
        )
        # An uploaded file takes priority over a camera capture.
        final_file = uploaded_file if uploaded_file is not None else camera_file
        if final_file is not None:
            try:
                img = Image.open(final_file)
                # Decode and preprocess once; the results column reuses img_array.
                img_array, original_size = preprocess_image(img)
            except (OSError, ValueError) as exc:
                # OSError also covers PIL.UnidentifiedImageError and truncated
                # files. Without this, a bad upload crashed the script.
                st.error(f"❌ **Image Error:** Could not read the image ({exc})")
            else:
                st.image(img, caption="Medical Image for Analysis", use_container_width=True)
                _render_image_metadata(img, original_size, final_file)
        else:
            st.markdown("### Drop your medical image here or capture using the camera")
            st.markdown("Supported formats: PNG, JPG, JPEG, BMP, TIFF")
            st.markdown("Maximum file size: 10MB")
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown('<div class="custom-card">', unsafe_allow_html=True)
        st.markdown('<h2 class="section-header">Results</h2>', unsafe_allow_html=True)
        if img_array is not None:
            with st.spinner("Analyzing image with AI model..."):
                predicted_label, confidence, all_probabilities, pred_error = make_prediction(model, img_array)
            if pred_error:
                st.error(f"❌ **Prediction Error:** {pred_error}")
                st.markdown('</div>', unsafe_allow_html=True)
                st.stop()
            _render_risk_badge(predicted_label)
            st.markdown("### Confidence Analysis")
            confidence_chart = create_confidence_chart(confidence)
            st.plotly_chart(confidence_chart, use_container_width=True)
        else:
            st.markdown("""
<div style="text-align: center; padding: 3rem; color: #9ca3af;">
<div style="font-size: 4rem; margin-bottom: 1rem;">⚕</div>
<h3>Ready for Analysis</h3>
<p>Upload or capture a medical image to begin AI-powered risk assessment</p>
</div>
""", unsafe_allow_html=True)
        st.markdown('</div>', unsafe_allow_html=True)

    # The full-width result sections appear only once a prediction exists.
    if predicted_label is not None:
        _render_probability_analysis(all_probabilities)
        _render_recommendations(predicted_label)

    _render_info_section()
# ====== Run Application ======
# `streamlit run` executes this script as the __main__ module, so the guard
# calls main() on every Streamlit rerun.
if __name__ == "__main__":
    main()