# nephro2 / app.py
# anujkum0x's picture
# Create app.py
# b3d2671 verified
import pandas as pd
import numpy as np
import pulp as pl
import matplotlib.pyplot as plt
import gradio as gr
from itertools import product
import io
import base64
import tempfile
import os
from datetime import datetime
def am_pm(hour):
    """Format an hour of the day as a zero-padded 12-hour clock label.

    Args:
        hour: Hour on a 24-hour clock. Any value accepted by ``int()`` works;
            fractional parts are truncated and values outside 0-23 wrap
            around (24 is midnight, -1 is 11 PM).

    Returns:
        A string such as ``"08:00 AM"`` or ``"12:00 PM"``. Minutes are always
        rendered as ``00``.
    """
    # Truncate and normalise first so that inputs like 12.5 or 24 cannot
    # produce labels such as "00:00 PM" or a midnight reported as "12:00 PM".
    hour = int(hour) % 24
    period = "PM" if hour >= 12 else "AM"
    # 0 and 12 are both shown as 12 on a 12-hour clock.
    display_hour = hour % 12 or 12
    return f"{display_hour:02d}:00 {period}"
def show_dataframe(csv_path):
    """Load a CSV with pandas, reporting failures as text instead of raising.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        The parsed ``DataFrame`` on success; otherwise a string of the form
        ``"Error loading CSV: <reason>"``.
    """
    try:
        frame = pd.read_csv(csv_path)
    except Exception as err:
        return f"Error loading CSV: {err}"
    else:
        return frame
def _read_staffing_csv(csv_file):
    """Load the demand table from a path, raw bytes, or an uploaded file object."""
    if csv_file is None:
        raise ValueError("no CSV file was provided")
    if isinstance(csv_file, (str, os.PathLike)):
        return pd.read_csv(csv_file)
    if hasattr(csv_file, 'decode'):
        # Raw bytes content.
        return pd.read_csv(io.StringIO(csv_file.decode('utf-8')))
    # Temp-file style upload wrappers expose their on-disk location as `.name`.
    path = getattr(csv_file, 'name', None)
    if isinstance(path, str):
        return pd.read_csv(path)
    # Last resort: let pandas try to read it as a file-like object.
    return pd.read_csv(csv_file)


def _new_temp_path(suffix):
    """Create an empty, already-closed temporary file and return its path."""
    # Closing before anything writes to the path keeps this portable to
    # platforms where an open temp file cannot be reopened by name.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _build_shifts(shift_types, clinic_start, clinic_hours, cycle_windows):
    """Enumerate candidate shifts together with the cycles each one overlaps."""
    shifts = []
    last_start = clinic_start + int(clinic_hours) - min(shift_types)
    for duration in shift_types:
        for start_abs in range(clinic_start, last_start + 1):
            end_abs = start_abs + duration
            # Compare on the un-wrapped hour axis (hours since midnight of the
            # day the shift starts) so shifts/cycles that cross midnight are
            # matched correctly. Any overlap counts as covering the cycle.
            covered = [
                cycle
                for cycle, (cycle_start, cycle_end) in cycle_windows.items()
                if start_abs < cycle_end and cycle_start < end_abs
            ]
            if covered:  # Only keep shifts that cover at least one cycle
                shifts.append({
                    'id': f"{duration}hr_{start_abs % 24:02d}",
                    'start': start_abs % 24,
                    'end': end_abs % 24,
                    'duration': duration,
                    'cycles_covered': covered,
                })
    return shifts


def _solve_schedule(num_staff, num_days, possible_shifts, cycle_cols,
                    staff_required, max_hours_per_staff, rest_days_per_week):
    """Solve the binary shift-assignment model for a fixed number of staff.

    Returns ``(schedule, objective)`` where ``schedule`` is a list of
    assignment dicts, or ``(None, None)`` when no usable solution exists.
    """
    staff_ids = range(1, num_staff + 1)
    days = range(1, num_days + 1)

    model = pl.LpProblem("Staff_Scheduling", pl.LpMinimize)

    # x[s, d, shift] = 1 if staff s works shift on day d
    x = pl.LpVariable.dicts(
        "shift",
        [(s, d, shift['id']) for s in staff_ids for d in days for shift in possible_shifts],
        cat='Binary',
    )

    # Objective: minimise total staff hours.
    model += pl.lpSum(
        x[(s, d, shift['id'])] * shift['duration']
        for s in staff_ids
        for d in days
        for shift in possible_shifts
    )

    # Each staff member works at most one shift per day.
    for s in staff_ids:
        for d in days:
            model += pl.lpSum(x[(s, d, shift['id'])] for shift in possible_shifts) <= 1

    # Rest days: at most (7 - rest days) working days in any calendar week.
    # A short trailing week is capped by its own length rather than by
    # (length - rest days), which could go negative and make the whole model
    # infeasible (e.g. an 8-day horizon with 2 rest days per week).
    work_days_per_full_week = max(0, 7 - rest_days_per_week)
    for s in staff_ids:
        for week_start in range(1, num_days + 1, 7):
            week_end = min(week_start + 6, num_days)
            allowed = min(week_end - week_start + 1, work_days_per_full_week)
            model += pl.lpSum(
                x[(s, d, shift['id'])]
                for d in range(week_start, week_end + 1)
                for shift in possible_shifts
            ) <= allowed

    # Each staff member works at most max_hours_per_staff over the period.
    for s in staff_ids:
        model += pl.lpSum(
            x[(s, d, shift['id'])] * shift['duration']
            for d in days
            for shift in possible_shifts
        ) <= max_hours_per_staff

    # Every cycle on every day has enough staff on an overlapping shift.
    covering = {
        cycle: [shift for shift in possible_shifts if cycle in shift['cycles_covered']]
        for cycle in cycle_cols
    }
    for d in days:
        for cycle in cycle_cols:
            model += pl.lpSum(
                x[(s, d, shift['id'])]
                for s in staff_ids
                for shift in covering[cycle]
            ) >= staff_required[(d, cycle)]

    # Solve with a time limit
    model.solve(pl.PULP_CBC_CMD(timeLimit=300, msg=0))

    if model.status not in (pl.LpStatusOptimal, pl.LpStatusNotSolved):
        return None, None

    schedule = []
    for s in staff_ids:
        for d in days:
            for shift in possible_shifts:
                chosen = pl.value(x[(s, d, shift['id'])])
                # Solvers may report binaries as 0.999...; never test `== 1`.
                if chosen is not None and chosen > 0.5:
                    schedule.append({
                        'staff_id': s,
                        'day': d,
                        'shift_id': shift['id'],
                        'start': shift['start'],
                        'end': shift['end'],
                        'duration': shift['duration'],
                        'cycles_covered': list(shift['cycles_covered']),
                    })

    if model.status != pl.LpStatusOptimal:
        # The solver stopped early (e.g. on the time limit). Accept whatever
        # assignment it left behind only if it genuinely meets the demand.
        for (d, cycle), needed in staff_required.items():
            assigned = sum(
                1 for entry in schedule
                if entry['day'] == d and cycle in entry['cycles_covered']
            )
            if assigned < needed:
                return None, None

    return schedule, pl.value(model.objective)


def optimize_staffing(
    csv_file,
    beds_per_staff,
    max_hours_per_staff,
    hours_per_cycle,
    rest_days_per_week,
    clinic_start,
    clinic_end,
    overlap_time,
    max_start_time_change
):
    """Build a minimum-hours staff schedule that covers the demand in a CSV.

    The first CSV column is treated as the day label and every column whose
    name starts with ``cycle`` (and does not end in ``_staff``) holds the
    number of beds in use for that cycle on that day.

    Args:
        csv_file: Path, raw bytes, or an uploaded file object with ``.name``.
        beds_per_staff: Beds one staff member can look after (> 0).
        max_hours_per_staff: Hours one staff member may work over the whole
            period covered by the CSV (> 0).
        hours_per_cycle: Length of each cycle in hours (> 0).
        rest_days_per_week: Rest days each staff member gets per 7-day week.
        clinic_start: Clinic opening hour (0-23).
        clinic_end: Clinic closing hour; a value below ``clinic_start`` means
            the clinic runs past midnight.
        overlap_time: Accepted for interface compatibility; not used by the
            model.
        max_start_time_change: Accepted for interface compatibility; not used
            by the model.

    Returns:
        A 5-tuple ``(results, plot_path, schedule_csv_path, gantt_path,
        staff_assignment_csv_path)``. ``results`` is a text report. On any
        failure the four paths are ``None`` and ``results`` holds the reason.
    """
    # Load data
    try:
        data = _read_staffing_csv(csv_file)
    except Exception as e:
        print(f"Error loading CSV file: {e}")
        return f"Error loading CSV file: {e}", None, None, None, None

    # Parameters
    try:
        BEDS_PER_STAFF = float(beds_per_staff)
        MAX_HOURS_PER_STAFF = float(max_hours_per_staff)
        HOURS_PER_CYCLE = float(hours_per_cycle)
        REST_DAYS_PER_WEEK = int(rest_days_per_week)
        CLINIC_START = int(clinic_start)
        CLINIC_END = int(clinic_end)
    except (TypeError, ValueError) as e:
        return f"Invalid parameter: {e}", None, None, None, None
    if min(BEDS_PER_STAFF, MAX_HOURS_PER_STAFF, HOURS_PER_CYCLE) <= 0:
        return (
            "Invalid parameter: beds per staff, max hours per staff and "
            "hours per cycle must all be positive."
        ), None, None, None, None
    SHIFT_TYPES = [6, 8, 10, 12]  # Standard shift types
    MAX_STAFF_TO_TRY = 15  # Upper bound for the incremental search below

    # Calculate clinic hours (an end before the start means "next day")
    if CLINIC_END < CLINIC_START:
        CLINIC_HOURS = 24 - CLINIC_START + CLINIC_END
    else:
        CLINIC_HOURS = CLINIC_END - CLINIC_START

    # Rename the index column if necessary
    if data.columns[0] not in ['day', 'Day', 'DAY']:
        data = data.rename(columns={data.columns[0]: 'day'})

    # Fill missing values (assign back rather than chained in-place fillna)
    for col in [c for c in data.columns if c.startswith('cycle')]:
        data[col] = data[col].fillna(0)

    # Get cycle names and number of days in the dataset
    cycle_cols = [
        col for col in data.columns
        if col.startswith('cycle') and not col.endswith('_staff')
    ]
    num_days = len(data)
    if not cycle_cols:
        return "Error: the CSV file has no 'cycle*' columns.", None, None, None, None
    if num_days == 0:
        return "Error: the CSV file contains no rows.", None, None, None, None

    # Calculate staff needed per cycle (beds/BEDS_PER_STAFF, rounded up)
    for col in cycle_cols:
        data[f'{col}_staff'] = np.ceil(data[col] / BEDS_PER_STAFF)
    staff_required = {
        (d, cycle): float(needed)
        for cycle in cycle_cols
        for d, needed in enumerate(data[f'{cycle}_staff'], start=1)
    }

    # Define cycle times as un-wrapped hours counted from midnight of the
    # clinic's opening day, so a cycle that runs past 24:00 keeps end > start.
    cycle_windows = {
        cycle: (CLINIC_START + i * HOURS_PER_CYCLE,
                CLINIC_START + (i + 1) * HOURS_PER_CYCLE)
        for i, cycle in enumerate(cycle_cols)
    }

    # Generate all possible shifts
    possible_shifts = _build_shifts(SHIFT_TYPES, CLINIC_START, CLINIC_HOURS, cycle_windows)
    if not possible_shifts:
        return (
            "Failed to build any shift: the clinic must be open for at least "
            f"{min(SHIFT_TYPES)} hours."
        ), None, None, None, None

    # Estimate minimum number of staff needed, with one extra as buffer for
    # constraints like rest days.
    max_staff_needed = max(staff_required.values())
    total_staff_hours = sum(staff_required.values()) * HOURS_PER_CYCLE
    min_staff_estimate = np.ceil(total_staff_hours / MAX_HOURS_PER_STAFF)
    estimated_staff = max(min_staff_estimate, max_staff_needed + 1)

    def solve_with(num_staff):
        return _solve_schedule(
            num_staff, num_days, possible_shifts, cycle_cols,
            staff_required, MAX_HOURS_PER_STAFF, REST_DAYS_PER_WEEK,
        )

    # Try the estimate first, then add staff one at a time until feasible.
    staff_count = int(estimated_staff)
    report = [f"Trying with {staff_count} staff...\n"]
    schedule, objective = solve_with(staff_count)
    while schedule is None and staff_count < MAX_STAFF_TO_TRY:
        staff_count += 1
        report.append(f"Trying with {staff_count} staff...\n")
        schedule, objective = solve_with(staff_count)

    if schedule is None:
        report.append("Failed to find a feasible solution. Try relaxing some constraints.")
        return "".join(report), None, None, None, None

    report.append(f"Optimal solution found with {staff_count} staff\n")
    report.append(f"Total staff hours: {objective}\n")

    # Convert to DataFrame for analysis. Explicit columns keep the frame
    # usable even when nobody needs to be scheduled at all.
    schedule_df = pd.DataFrame(
        schedule,
        columns=['staff_id', 'day', 'shift_id', 'start', 'end', 'duration', 'cycles_covered'],
    )

    # Analyze staff workload
    staff_hours = {
        s: schedule_df.loc[schedule_df['staff_id'] == s, 'duration'].sum()
        for s in range(1, staff_count + 1)
    }
    report.append("\nStaff Hours:\n")
    for staff_id, hours in staff_hours.items():
        utilization = (hours / MAX_HOURS_PER_STAFF) * 100
        report.append(f"Staff {staff_id}: {hours} hours ({utilization:.1f}% utilization)\n")
    avg_utilization = sum(staff_hours.values()) / (staff_count * MAX_HOURS_PER_STAFF) * 100
    report.append(f"\nAverage staff utilization: {avg_utilization:.1f}%\n")

    # Check coverage for each day and cycle
    assigned_counts = {}
    coverage_check = []
    for d in range(1, num_days + 1):
        covered_today = schedule_df.loc[schedule_df['day'] == d, 'cycles_covered'].tolist()
        for cycle in cycle_cols:
            required = staff_required[(d, cycle)]
            assigned = sum(1 for covered in covered_today if cycle in covered)
            assigned_counts[(d, cycle)] = assigned
            coverage_check.append({
                'day': d,
                'cycle': cycle,
                'required': required,
                'assigned': assigned,
                'satisfied': assigned >= required,
            })
    coverage_df = pd.DataFrame(coverage_check)
    satisfaction = coverage_df['satisfied'].mean() * 100
    report.append(f"Coverage satisfaction: {satisfaction:.1f}%\n")
    if satisfaction < 100:
        report.append("Warning: Not all staffing requirements are met!\n")
        unsatisfied = coverage_df[~coverage_df['satisfied']]
        report.append(unsatisfied.to_string() + "\n")
    results = "".join(report)

    # Generate schedule visualization (shift hours per staff per day)
    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        staff_days = {s: [0] * num_days for s in range(1, staff_count + 1)}  # 0 means off duty
        for staff_id, day, duration in zip(
            schedule_df['staff_id'], schedule_df['day'], schedule_df['duration']
        ):
            staff_days[staff_id][day - 1] = duration
        for s, hours in staff_days.items():
            ax.bar(range(1, num_days + 1), hours, label=f'Staff {s}')
        ax.set_xlabel('Day')
        ax.set_ylabel('Shift Hours')
        ax.set_title('Staff Schedule')
        ax.set_xticks(range(1, num_days + 1))
        ax.legend()
        plot_path = _new_temp_path('.png')
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    # Human-readable times, then sort by staff then day
    schedule_df['start_ampm'] = schedule_df['start'].apply(am_pm)
    schedule_df['end_ampm'] = schedule_df['end'].apply(am_pm)
    schedule_df = schedule_df.sort_values(['staff_id', 'day'])

    # Create a Gantt chart
    gantt_fig, gantt_ax = plt.subplots(figsize=(30, 12))
    try:
        colors = plt.cm.tab20.colors  # Visually distinct color palette
        for staff_id in range(1, staff_count + 1):
            staff_shifts = schedule_df[schedule_df['staff_id'] == staff_id]
            y_pos = staff_id
            bar_color = colors[staff_id % len(colors)]
            for i, shift in staff_shifts.iterrows():
                day = shift['day']
                start_hour = shift['start']
                end_hour = shift['end']
                duration = shift['duration']
                if end_hour < start_hour:
                    # Overnight shift: one bar up to midnight, one after it.
                    segments = [(day - 1 + start_hour / 24, (24 - start_hour) / 24),
                                (day, end_hour / 24)]
                else:
                    segments = [(day - 1 + start_hour / 24, duration / 24)]
                gantt_ax.broken_barh(segments, (y_pos - 0.3, 0.6), facecolors=bar_color)
                # Stagger labels so neighbouring ones do not overlap.
                text_y_offset = 0.1 if (i % 2) == 0 else -0.1
                # Always show the staff ID; add the time range on longer bars.
                text_label = f"Staff {staff_id}"
                if duration > 6:
                    text_label += f"\n{shift['start_ampm']}-{shift['end_ampm']}"
                gantt_ax.text(day - 1 + start_hour / 24 + duration / 48,
                              y_pos + text_y_offset,
                              text_label,
                              horizontalalignment='center',
                              verticalalignment='center',
                              fontsize=7)
        gantt_ax.set_xlabel('Day')
        gantt_ax.set_yticks(range(1, staff_count + 1))
        gantt_ax.set_yticklabels([f'Staff {s}' for s in range(1, staff_count + 1)])
        gantt_ax.set_xlim(0, num_days)
        gantt_ax.set_title('Staff Schedule (Full Period)')
        gantt_ax.grid(False)  # Remove grid lines
        gantt_path = _new_temp_path('.png')
        gantt_fig.savefig(gantt_path)
    finally:
        plt.close(gantt_fig)

    # Write the schedule as CSV
    schedule_csv_path = _new_temp_path('.csv')
    schedule_df[
        ['staff_id', 'day', 'start_ampm', 'end_ampm', 'duration', 'cycles_covered']
    ].to_csv(schedule_csv_path, index=False)

    # Create staff assignment table: staff count per cycle for each day
    staff_assignment_df = pd.DataFrame(
        [[d] + [assigned_counts[(d, cycle)] for cycle in cycle_cols]
         for d in range(1, num_days + 1)],
        columns=['Day'] + cycle_cols,
    )
    staff_assignment_csv_path = _new_temp_path('.csv')
    staff_assignment_df.to_csv(staff_assignment_csv_path, index=False)

    return results, plot_path, schedule_csv_path, gantt_path, staff_assignment_csv_path
def convert_to_24h(time_str):
    """Convert an on-the-hour AM/PM label to its 24-hour clock hour.

    Args:
        time_str: Text such as ``"08:00 AM"`` or ``"12:00 PM"``. Leading and
            trailing whitespace is ignored. Only ``:00`` minutes are accepted.

    Returns:
        The hour as an ``int`` in 0-23, or ``None`` when ``time_str`` is not a
        string in that format (including ``None`` itself).
    """
    try:
        return datetime.strptime(time_str.strip(), "%I:00 %p").hour
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: not a string (e.g. an empty dropdown
        # yielding None); ValueError: a string in the wrong format.
        return None
def gradio_wrapper(
    csv_file, beds_per_staff, max_hours_per_staff, hours_per_cycle,
    rest_days_per_week, clinic_start_ampm, clinic_end_ampm, overlap_time, max_start_time_change
):
    """Adapt the Gradio inputs to ``optimize_staffing`` and shape its outputs.

    Clinic hours arrive as AM/PM labels and are converted to 24-hour integers
    before the optimizer is called. The CSV files the optimizer writes are
    read back into DataFrames for display and then deleted.

    Returns:
        A 5-tuple ``(results, staff_assignment_df, gantt_img, schedule_df,
        plot_img)``. ``results`` is always a string; each of the other four
        is ``None`` when it could not be produced.
    """
    clinic_start = convert_to_24h(clinic_start_ampm)
    clinic_end = convert_to_24h(clinic_end_ampm)
    if clinic_start is None or clinic_end is None:
        # Do not hand unparseable times to the optimizer.
        return (
            "Invalid clinic hours: please choose both a start and an end time.",
            None, None, None, None,
        )

    results, plot_path, schedule_csv_path, gantt_path, staff_assignment_csv_path = optimize_staffing(
        csv_file, beds_per_staff, max_hours_per_staff, hours_per_cycle,
        rest_days_per_week, clinic_start, clinic_end, overlap_time, max_start_time_change
    )

    # Only pass image paths on if the files actually exist
    plot_img = plot_path if plot_path and os.path.exists(plot_path) else None
    gantt_img = gantt_path if gantt_path and os.path.exists(gantt_path) else None

    # Load DataFrames from the CSV files. A missing path (optimizer failure)
    # or a load error yields None for the table rather than an error string.
    tables = []
    for csv_path in (staff_assignment_csv_path, schedule_csv_path):
        table = show_dataframe(csv_path) if csv_path else None
        if isinstance(table, str):
            # show_dataframe reports failures as text; surface it in the report.
            results = f"{results}\n{table}"
            table = None
        tables.append(table)
        if csv_path:
            # The CSV is no longer needed once read (or once reading failed).
            try:
                os.remove(csv_path)
            except OSError:
                pass
    staff_assignment_df, schedule_df = tables

    return results, staff_assignment_df, gantt_img, schedule_df, plot_img
# Define Gradio UI
# Dropdown choices for the clinic hours: every full hour as an AM/PM label.
am_pm_times = [f"{i:02d}:00 AM" for i in range(1, 13)] + [f"{i:02d}:00 PM" for i in range(1, 13)]
# Input order must match the parameter order of gradio_wrapper, and output
# order must match the order of the tuple it returns.
iface = gr.Interface(
    fn=gradio_wrapper,
    inputs=[
        gr.File(label="Upload CSV"),
        gr.Number(label="Beds per Staff", value=3),
        gr.Number(label="Max Hours per Staff", value=40),
        gr.Number(label="Hours per Cycle", value=4),
        gr.Number(label="Rest Days per Week", value=2),
        gr.Dropdown(label="Clinic Start Hour (AM/PM)", choices=am_pm_times, value="08:00 AM"),
        gr.Dropdown(label="Clinic End Hour (AM/PM)", choices=am_pm_times, value="08:00 PM"),
        gr.Number(label="Overlap Time", value=0),
        gr.Number(label="Max Start Time Change", value=2)
    ],
    outputs=[
        gr.Textbox(label="Optimization Results", elem_id="optimization_results"),
        gr.Dataframe(label="Staff Assignment Table", elem_id="staff_assignment_table"),
        gr.Image(label="Gantt Chart", elem_id="gantt_chart"),
        gr.Dataframe(label="CSV Schedule", elem_id="csv_schedule"),
        gr.Image(label="Schedule Visualization", elem_id="schedule_visualization"),
    ],
    title="Staff Scheduling Optimizer",
    description="Upload a CSV file with cycle data and configure parameters to generate an optimal staff schedule.",
    css="""
#staff_assignment_table {
width: 100%; /* Make the table occupy the full width of its container */
}
#csv_schedule {
width: 100%; /* Make the table occupy the full width of its container */
}
"""
)
# Launch the Gradio app only when run as a script, so importing this module
# (e.g. to reuse optimize_staffing) does not start a server.
if __name__ == "__main__":
    iface.launch(share=True)