Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,6 +9,8 @@ import gradio as gr
|
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import seaborn as sns
|
| 11 |
import io
|
|
|
|
|
|
|
| 12 |
from PIL import Image
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings('ignore')
|
|
@@ -19,7 +21,7 @@ def load_and_preprocess_data(file):
|
|
| 19 |
data = pd.read_csv(file.name)
|
| 20 |
|
| 21 |
# Convert suits and ranks to numerical values
|
| 22 |
-
suit_order = {'spades': 0, 'hearts': 1, '
|
| 23 |
rank_order = {'ace': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5, '7': 6, '8': 7, '9': 8, '10': 9,
|
| 24 |
'jack': 10, 'queen': 11, 'king': 12}
|
| 25 |
|
|
@@ -100,7 +102,19 @@ def plot_accuracy_chart(accuracies):
|
|
| 100 |
plt.close()
|
| 101 |
return img
|
| 102 |
|
| 103 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
def train_model(file, n_estimators, learning_rate, max_depth, subsample, progress=gr.Progress()):
|
| 105 |
progress(0, desc="Starting...")
|
| 106 |
results = []
|
|
@@ -110,7 +124,7 @@ def train_model(file, n_estimators, learning_rate, max_depth, subsample, progres
|
|
| 110 |
progress(0.1, desc="Loading and preprocessing data...")
|
| 111 |
data, error = load_and_preprocess_data(file)
|
| 112 |
if error:
|
| 113 |
-
return error, None, None
|
| 114 |
|
| 115 |
# Create features
|
| 116 |
progress(0.2, desc="Engineering features...")
|
|
@@ -126,13 +140,13 @@ def train_model(file, n_estimators, learning_rate, max_depth, subsample, progres
|
|
| 126 |
}
|
| 127 |
|
| 128 |
# Scale features
|
| 129 |
-
progress(0.3, desc="Scaling features
|
| 130 |
-
scaler = StandardScaler()
|
| 131 |
features_scaled = scaler.fit_transform(features)
|
| 132 |
features_scaled = pd.DataFrame(features_scaled, columns=features.columns)
|
| 133 |
|
| 134 |
accuracies = {}
|
| 135 |
confusion_matrices = []
|
|
|
|
| 136 |
|
| 137 |
# Train models
|
| 138 |
for i, (target_name, target) in enumerate(targets.items()):
|
|
@@ -165,6 +179,9 @@ def train_model(file, n_estimators, learning_rate, max_depth, subsample, progres
|
|
| 165 |
verbose=False
|
| 166 |
)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
# Evaluate
|
| 169 |
y_pred = model.predict(X_test)
|
| 170 |
accuracy = accuracy_score(y_test, y_pred)
|
|
@@ -179,20 +196,23 @@ def train_model(file, n_estimators, learning_rate, max_depth, subsample, progres
|
|
| 179 |
cm_plot = plot_confusion_matrix(y_test, y_pred, f"Confusion Matrix - {target_name}")
|
| 180 |
confusion_matrices.append(cm_plot)
|
| 181 |
|
| 182 |
-
progress(0.9, desc="Generating visualizations...")
|
| 183 |
# Generate accuracy bar chart
|
| 184 |
accuracy_plot = plot_accuracy_chart(accuracies)
|
| 185 |
|
|
|
|
|
|
|
|
|
|
| 186 |
progress(1.0, desc="Completed!")
|
| 187 |
-
return "\n".join(results), accuracy_plot, confusion_matrices
|
| 188 |
|
| 189 |
except Exception as e:
|
| 190 |
-
return f"Error during training: {str(e)}", None, None
|
| 191 |
|
| 192 |
# Gradio interface
|
| 193 |
with gr.Blocks() as demo:
|
| 194 |
gr.Markdown("# Card Game Prediction Model Training")
|
| 195 |
-
gr.Markdown("Upload the training dataset and configure hyperparameters to train the model. Track progress
|
| 196 |
|
| 197 |
file_input = gr.File(label="Upload TRAINING_CARD_DATA.csv")
|
| 198 |
n_estimators = gr.Slider(50, 300, value=100, step=10, label="Number of Estimators")
|
|
@@ -205,11 +225,12 @@ with gr.Blocks() as demo:
|
|
| 205 |
output_text = gr.Textbox(label="Training Results")
|
| 206 |
accuracy_plot = gr.Image(label="Accuracy Comparison")
|
| 207 |
confusion_plots = gr.Gallery(label="Confusion Matrices")
|
|
|
|
| 208 |
|
| 209 |
train_button.click(
|
| 210 |
fn=train_model,
|
| 211 |
inputs=[file_input, n_estimators, learning_rate, max_depth, subsample],
|
| 212 |
-
outputs=[output_text, accuracy_plot, confusion_plots]
|
| 213 |
)
|
| 214 |
|
| 215 |
demo.launch()
|
|
|
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import seaborn as sns
|
| 11 |
import io
|
| 12 |
+
import zipfile
|
| 13 |
+
import joblib
|
| 14 |
from PIL import Image
|
| 15 |
import warnings
|
| 16 |
warnings.filterwarnings('ignore')
|
|
|
|
| 21 |
data = pd.read_csv(file.name)
|
| 22 |
|
| 23 |
# Convert suits and ranks to numerical values
|
| 24 |
+
suit_order = {'spades': 0, 'hearts': 1, 'logs': 2, 'diamonds': 3}
|
| 25 |
rank_order = {'ace': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5, '7': 6, '8': 7, '9': 8, '10': 9,
|
| 26 |
'jack': 10, 'queen': 11, 'king': 12}
|
| 27 |
|
|
|
|
| 102 |
plt.close()
|
| 103 |
return img
|
| 104 |
|
| 105 |
+
# Function to create a ZIP file of models
|
| 106 |
+
def create_model_zip(models):
|
| 107 |
+
zip_buffer = io.BytesIO()
|
| 108 |
+
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
| 109 |
+
for model_name, model in models.items():
|
| 110 |
+
model_buffer = io.BytesIO()
|
| 111 |
+
joblib.dump(model, model_buffer)
|
| 112 |
+
model_buffer.seek(0)
|
| 113 |
+
zip_file.writestr(f"{model_name}_model.pkl", model_buffer.getvalue())
|
| 114 |
+
zip_buffer.seek(0)
|
| 115 |
+
return zip_buffer
|
| 116 |
+
|
| 117 |
+
# Training function with progress tracking and model saving
|
| 118 |
def train_model(file, n_estimators, learning_rate, max_depth, subsample, progress=gr.Progress()):
|
| 119 |
progress(0, desc="Starting...")
|
| 120 |
results = []
|
|
|
|
| 124 |
progress(0.1, desc="Loading and preprocessing data...")
|
| 125 |
data, error = load_and_preprocess_data(file)
|
| 126 |
if error:
|
| 127 |
+
return error, None, None, None
|
| 128 |
|
| 129 |
# Create features
|
| 130 |
progress(0.2, desc="Engineering features...")
|
|
|
|
| 140 |
}
|
| 141 |
|
| 142 |
# Scale features
|
| 143 |
+
progress(0.3, desc="Scaling features scaler = StandardScaler()
|
|
|
|
| 144 |
features_scaled = scaler.fit_transform(features)
|
| 145 |
features_scaled = pd.DataFrame(features_scaled, columns=features.columns)
|
| 146 |
|
| 147 |
accuracies = {}
|
| 148 |
confusion_matrices = []
|
| 149 |
+
trained_models = {}
|
| 150 |
|
| 151 |
# Train models
|
| 152 |
for i, (target_name, target) in enumerate(targets.items()):
|
|
|
|
| 179 |
verbose=False
|
| 180 |
)
|
| 181 |
|
| 182 |
+
# Save model
|
| 183 |
+
trained_models[target_name] = model
|
| 184 |
+
|
| 185 |
# Evaluate
|
| 186 |
y_pred = model.predict(X_test)
|
| 187 |
accuracy = accuracy_score(y_test, y_pred)
|
|
|
|
| 196 |
cm_plot = plot_confusion_matrix(y_test, y_pred, f"Confusion Matrix - {target_name}")
|
| 197 |
confusion_matrices.append(cm_plot)
|
| 198 |
|
| 199 |
+
progress(0.9, desc="Generating visualizations and model archive...")
|
| 200 |
# Generate accuracy bar chart
|
| 201 |
accuracy_plot = plot_accuracy_chart(accuracies)
|
| 202 |
|
| 203 |
+
# Create ZIP file of models
|
| 204 |
+
model_zip = create_model_zip(trained_models)
|
| 205 |
+
|
| 206 |
progress(1.0, desc="Completed!")
|
| 207 |
+
return "\n".join(results), accuracy_plot, confusion_matrices, model_zip
|
| 208 |
|
| 209 |
except Exception as e:
|
| 210 |
+
return f"Error during training: {str(e)}", None, None, None
|
| 211 |
|
| 212 |
# Gradio interface
|
| 213 |
with gr.Blocks() as demo:
|
| 214 |
gr.Markdown("# Card Game Prediction Model Training")
|
| 215 |
+
gr.Markdown("Upload the training dataset and configure hyperparameters to train the model. Track progress, view results, and download trained models.")
|
| 216 |
|
| 217 |
file_input = gr.File(label="Upload TRAINING_CARD_DATA.csv")
|
| 218 |
n_estimators = gr.Slider(50, 300, value=100, step=10, label="Number of Estimators")
|
|
|
|
| 225 |
output_text = gr.Textbox(label="Training Results")
|
| 226 |
accuracy_plot = gr.Image(label="Accuracy Comparison")
|
| 227 |
confusion_plots = gr.Gallery(label="Confusion Matrices")
|
| 228 |
+
model_download = gr.File(label="Download Trained Models (ZIP)")
|
| 229 |
|
| 230 |
train_button.click(
|
| 231 |
fn=train_model,
|
| 232 |
inputs=[file_input, n_estimators, learning_rate, max_depth, subsample],
|
| 233 |
+
outputs=[output_text, accuracy_plot, confusion_plots, model_download]
|
| 234 |
)
|
| 235 |
|
| 236 |
demo.launch()
|