import matplotlib.pyplot as plt
import numpy as np


def create_prediction_plot(top_preds, title='Top Disease Predictions'):
    """Create a horizontal bar chart of predictions.

    Args:
        top_preds: Iterable of ``(label, probability)`` pairs. The first pair
            is drawn at the top of the chart. Probabilities are plotted on a
            fixed 0..1 axis, so they are presumably expected in [0, 1].
            Any iterable is accepted, including one-shot generators.
        title: Title shown above the chart. Defaults to the original
            fixed title.

    Returns:
        The matplotlib ``Figure`` holding the chart. It is created through
        ``pyplot``, so it stays registered with pyplot until the caller
        closes it (e.g. ``plt.close(fig)``).
    """
    # Materialize once: iterating ``top_preds`` twice would silently yield
    # nothing the second time if the caller passed a generator.
    pairs = list(top_preds)
    labels = [label for label, _ in pairs]
    # Force float dtype: a matplotlib Colormap treats *integer* inputs as
    # lookup-table indices rather than values in [0, 1], which would give
    # wrong colors for integer confidences such as 0 or 1.
    probs = np.asarray([prob for _, prob in pairs], dtype=float)

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 6))

    # Create horizontal bar chart, colored by confidence (red = low,
    # green = high on the RdYlGn colormap).
    y_pos = np.arange(len(labels))
    colors = plt.cm.RdYlGn(probs)
    ax.barh(y_pos, probs, color=colors, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # First (top) prediction at the top of the chart
    ax.set_xlabel('Confidence Score', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlim([0, 1])

    # Add numeric value labels just to the right of each bar end.
    for i, prob in enumerate(probs):
        ax.text(prob + 0.01, i, f'{prob:.3f}', va='center', fontsize=10)

    plt.tight_layout()
    return fig