Spaces:
Sleeping
Sleeping
| import joblib | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
# Load the pre-trained fund-type encoder and KMeans segmentation model once,
# at import time, so every request handled by the app reuses the same objects.
# The paths are resolved relative to this script rather than the current
# working directory, so the app also starts when launched from elsewhere.
# NOTE(review): joblib.load unpickles its input; only point it at trusted files.
from pathlib import Path

_MODEL_DIR = Path(__file__).resolve().parent
encoder = joblib.load(_MODEL_DIR / "fund_type_encoder.pkl")
model = joblib.load(_MODEL_DIR / "investor_segmentation.pkl")
def analyze_investors(file):
    """Segment the investors in an uploaded CSV and plot the result.

    Each row's ``Fund_Type_Viewed_Most`` value is transformed with the
    module-level ``encoder`` and assigned a cluster by the module-level
    ``model``. Every cluster is then named after the fund type that occurs
    most often among its rows.

    Args:
        file: The value produced by the ``gr.File`` input. This is either a
            path string or an object exposing the path as ``.name``.

    Returns:
        A matplotlib ``Figure`` with two panels. The left panel is a donut chart
        of how many rows fall under each cluster label. The right panel is a text
        box listing the summed ``News_Reads_Per_Week`` per cluster label.

    Raises:
        gr.Error: If no file was uploaded, the file is empty or has no rows,
            or a required column is missing.
    """
    # Standalone Figure (not pyplot): see the comment where the figure is built.
    from matplotlib.figure import Figure

    if file is None:
        raise gr.Error("Please upload an investors CSV file.")
    # Gradio may pass a str path (possibly a str subclass carrying .name) or a
    # tempfile-like wrapper exposing .name; accept both.
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError as err:
        raise gr.Error("Uploaded CSV file is empty.") from err

    # Validate up front so users get a readable message instead of a KeyError
    # from deep inside the encoder or groupby.
    required = ["Fund_Type_Viewed_Most", "News_Reads_Per_Week"]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise gr.Error(
            f"Uploaded CSV is missing required column(s): {', '.join(missing)}"
        )
    if df.empty:
        raise gr.Error("Uploaded CSV contains no investor rows.")

    # Predict each investor's cluster from the encoded fund type.
    X_encoded = encoder.transform(df[["Fund_Type_Viewed_Most"]])
    df["Cluster"] = model.predict(X_encoded)

    # Name each cluster after its most frequent fund type so the chart shows
    # readable labels instead of numeric cluster ids.
    cluster_labels = (
        df.groupby("Cluster")["Fund_Type_Viewed_Most"]
        .agg(lambda s: s.value_counts().idxmax())
        .to_dict()
    )
    df["Cluster_Label"] = df["Cluster"].map(cluster_labels).astype(str)

    # Data for the two panels: row counts per label, summed weekly reads per label.
    cluster_counts = df["Cluster_Label"].value_counts()
    reads_per_label = df.groupby("Cluster_Label")["News_Reads_Per_Week"].sum()
    summary_text = "\n".join(
        f"{label}: {reads} reads/week" for label, reads in reads_per_label.items()
    )

    # Build a standalone Figure instead of calling plt.subplots(). Pyplot keeps
    # every figure it creates registered until it is explicitly closed, so a
    # long-running server would accumulate one figure per request.
    fig = Figure(figsize=(12, 6))
    ax1, ax2 = fig.subplots(1, 2)

    # Donut chart: a pie whose wedges have width 0.3, leaving a hole in the middle.
    ax1.pie(
        cluster_counts,
        labels=cluster_counts.index,
        autopct="%1.1f%%",
        startangle=90,
        wedgeprops={"width": 0.3},
    )
    ax1.set_title("Fund Type Distribution")
    ax1.axis("equal")

    # Right panel: no axes, just a rounded text box anchored to the top right.
    ax2.axis("off")
    box_style = dict(boxstyle="round", facecolor="lightgrey", alpha=0.4)
    ax2.text(0.95, 0.9, summary_text, fontsize=12, va="top", ha="right", bbox=box_style)
    ax2.set_title("Total News Reads per Week", loc="right")
    return fig
# Build the Gradio UI: one CSV upload in, one matplotlib figure out.
# The interface is bound to the module-level name `demo`, which is the name
# Gradio's reload mode looks for by default. It is still launched
# unconditionally, exactly as before.
demo = gr.Interface(
    fn=analyze_investors,
    # Restrict the picker/upload to CSV files; analyze_investors parses with read_csv.
    inputs=gr.File(label="Upload Investors CSV", file_types=[".csv"]),
    outputs=gr.Plot(label="Investors Clusters Visualization"),
)
demo.launch()