PD03 commited on
Commit
abd202e
·
verified ·
1 Parent(s): 94ab63d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -197
app.py CHANGED
@@ -57,120 +57,153 @@ class SAPARPredictor:
57
 
58
  def train_model(self, progress=gr.Progress()):
59
  """Train the ML model with progress tracking"""
60
- progress(0, desc="Generating synthetic data...")
61
-
62
- # Generate training data
63
- df = self.generate_synthetic_data(1000)
64
- time.sleep(1) # Simulate data generation time
65
-
66
- progress(0.1, desc="Preparing features and labels...")
67
-
68
- # Prepare features and labels
69
- feature_columns = ['invoice_amount', 'days_overdue', 'previous_delays',
70
- 'credit_score', 'industry_risk', 'seasonality']
71
- X = df[feature_columns].values
72
- y = df['paid_on_time'].values
73
-
74
- # Split data
75
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
76
-
77
- progress(0.2, desc="Building neural network...")
78
-
79
- # Create model
80
- self.model = tf.keras.Sequential([
81
- tf.keras.layers.Dense(32, activation='relu', input_shape=(6,)),
82
- tf.keras.layers.Dropout(0.2),
83
- tf.keras.layers.Dense(16, activation='relu'),
84
- tf.keras.layers.Dropout(0.2),
85
- tf.keras.layers.Dense(1, activation='sigmoid')
86
- ])
87
-
88
- self.model.compile(
89
- optimizer=tf.keras.optimizers.Adam(0.001),
90
- loss='binary_crossentropy',
91
- metrics=['accuracy']
92
- )
93
-
94
- progress(0.3, desc="Training model...")
95
-
96
- # Train model
97
- history = self.model.fit(
98
- X_train, y_train,
99
- epochs=50,
100
- batch_size=32,
101
- validation_split=0.2,
102
- verbose=0
103
- )
104
-
105
- progress(0.8, desc="Evaluating model...")
106
-
107
- # Make predictions on test set
108
- y_pred_proba = self.model.predict(X_test)
109
- y_pred = (y_pred_proba > 0.5).astype(int)
110
-
111
- # Calculate metrics
112
- accuracy = accuracy_score(y_test, y_pred)
113
- precision = precision_score(y_test, y_pred)
114
- recall = recall_score(y_test, y_pred)
115
- f1 = f1_score(y_test, y_pred)
116
-
117
- self.training_history = history.history
118
- self.is_trained = True
119
-
120
- progress(1.0, desc="Training completed!")
121
-
122
- # Create training visualization
123
- fig = go.Figure()
124
-
125
- epochs = list(range(1, len(history.history['accuracy']) + 1))
126
-
127
- fig.add_trace(go.Scatter(
128
- x=epochs,
129
- y=history.history['accuracy'],
130
- mode='lines+markers',
131
- name='Training Accuracy',
132
- line=dict(color='#007bff', width=3),
133
- marker=dict(size=6)
134
- ))
135
-
136
- fig.add_trace(go.Scatter(
137
- x=epochs,
138
- y=history.history['val_accuracy'],
139
- mode='lines+markers',
140
- name='Validation Accuracy',
141
- line=dict(color='#28a745', width=3),
142
- marker=dict(size=6)
143
- ))
144
-
145
- fig.update_layout(
146
- title='Model Training Progress',
147
- xaxis_title='Epoch',
148
- yaxis_title='Accuracy',
149
- template='plotly_white',
150
- height=400,
151
- hovermode='x unified'
152
- )
153
-
154
- # Create metrics summary
155
- metrics_text = f"""
156
- ## 🎯 Model Performance Metrics
157
-
158
- - **Accuracy**: {accuracy:.1%}
159
- - **Precision**: {precision:.1%}
160
- - **Recall**: {recall:.1%}
161
- - **F1 Score**: {f1:.1%}
162
-
163
- ✅ Model trained successfully on 1,000 synthetic SAP AR records!
164
- """
165
-
166
- return fig, metrics_text, gr.update(interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def generate_unpaid_invoices(self):
169
  """Generate sample unpaid invoices for prediction"""
170
  customers = ['SAP-CUST001', 'SAP-CUST002', 'SAP-CUST003', 'SAP-CUST004', 'SAP-CUST005']
171
 
172
  invoices = []
173
- for i in range(15):
174
  invoice_id = f"INV-{datetime.now().strftime('%Y%m%d')}-{i:03d}"
175
  customer = random.choice(customers)
176
  amount = random.randint(5000, 50000)
@@ -194,63 +227,123 @@ class SAPARPredictor:
194
  def make_predictions(self):
195
  """Make predictions on unpaid invoices"""
196
  if not self.is_trained:
197
- return None, "❌ Please train the model first!"
198
-
199
- # Generate unpaid invoices
200
- df = self.generate_unpaid_invoices()
201
-
202
- # Prepare features for prediction
203
- features = []
204
- for _, row in df.iterrows():
205
- features.append([
206
- row['Amount ($)'] / 50000, # Normalize
207
- row['Days Overdue'] / 120, # Normalize
208
- row['Previous Delays'] / 5, # Normalize
209
- row['Credit Score'] / 100, # Normalize
210
- row['Industry Risk'],
211
- row['Seasonality']
212
- ])
213
-
214
- # Make predictions
215
- predictions = self.model.predict(np.array(features))
216
-
217
- # Add predictions to dataframe
218
- df['Payment Probability'] = [f"{p[0]:.1%}" for p in predictions]
219
- df['Prediction'] = ['✅ Will Pay' if p[0] > 0.5 else '❌ Risk of Default' for p in predictions]
220
- df['Risk Level'] = ['🟢 Low' if p[0] > 0.7 else '🟡 Medium' if p[0] > 0.4 else '🔴 High' for p in predictions]
221
-
222
- # Format amount column
223
- df['Amount ($)'] = df['Amount ($)'].apply(lambda x: f"${x:,}")
224
-
225
- # Create probability distribution chart
226
- prob_values = [p[0] for p in predictions]
227
-
228
- fig = go.Figure(data=[
229
- go.Histogram(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  x=prob_values,
231
- nbinsx=20,
232
  marker_color='rgba(0, 123, 255, 0.7)',
233
  marker_line_color='rgba(0, 123, 255, 1)',
234
- marker_line_width=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  )
236
- ])
237
-
238
- fig.update_layout(
239
- title='Distribution of Payment Probabilities',
240
- xaxis_title='Payment Probability',
241
- yaxis_title='Number of Invoices',
242
- template='plotly_white',
243
- height=300
244
- )
245
-
246
- success_msg = f"🔮 Generated predictions for {len(df)} unpaid invoices!"
247
-
248
- return df, success_msg, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  # Initialize the predictor
251
  predictor = SAPARPredictor()
252
 
253
- # Create Gradio interface
254
  with gr.Blocks(
255
  theme=gr.themes.Soft(
256
  primary_hue="blue",
@@ -260,26 +353,30 @@ with gr.Blocks(
260
  title="SAP AR ML Prediction Demo",
261
  css="""
262
  .gradio-container {
263
- max-width: 1200px !important;
 
264
  }
265
  .main-header {
266
  text-align: center;
267
  margin-bottom: 2rem;
268
- }
269
- .metric-card {
270
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
271
- padding: 1rem;
272
- border-radius: 10px;
273
  color: white;
274
- text-align: center;
 
 
 
275
  }
276
  """
277
  ) as demo:
278
 
279
  gr.HTML("""
280
  <div class="main-header">
281
- <h1>🏢 SAP Account Receivable ML Prediction Demo</h1>
282
- <p style="font-size: 1.1rem; color: #666;">
 
 
283
  Machine Learning-powered invoice payment prediction system using TensorFlow
284
  </p>
285
  </div>
@@ -288,49 +385,71 @@ with gr.Blocks(
288
  with gr.Tabs() as tabs:
289
 
290
  with gr.Tab("🎯 Model Training", id=0):
291
- gr.Markdown("""
292
- ### Train ML Model
293
- Train a neural network on synthetic SAP AR data to predict invoice payment likelihood.
294
- The model uses features like invoice amount, days overdue, customer credit score, and more.
295
- """)
296
-
297
  with gr.Row():
298
- with gr.Column(scale=1):
 
 
 
 
 
 
 
299
  train_btn = gr.Button(
300
  "🚀 Train ML Model",
301
  variant="primary",
302
- size="lg"
 
303
  )
304
-
305
- with gr.Column(scale=2):
306
- metrics_display = gr.Markdown("")
307
-
308
- training_plot = gr.Plot(label="Training Progress")
309
- predict_btn = gr.Button(
310
- "🔮 Make Predictions",
311
- variant="secondary",
312
- interactive=False,
313
- size="lg"
314
- )
315
-
316
- with gr.Tab("📊 Predictions", id=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  gr.Markdown("""
318
- ### Invoice Payment Predictions
319
- View real-time predictions for unpaid invoices using the trained ML model.
320
  """)
321
 
322
- prediction_status = gr.Markdown("")
323
- predictions_df = gr.Dataframe(
324
- label="Invoice Predictions",
325
- interactive=False,
326
- wrap=True
327
- )
328
- probability_plot = gr.Plot(label="Probability Distribution")
 
 
 
 
 
 
329
 
330
  # Event handlers
331
  train_btn.click(
332
  fn=predictor.train_model,
333
- outputs=[training_plot, metrics_display, predict_btn]
 
334
  )
335
 
336
  predict_btn.click(
@@ -340,4 +459,4 @@ with gr.Blocks(
340
 
341
  # Launch the app
342
  if __name__ == "__main__":
343
- demo.launch()
 
57
 
58
  def train_model(self, progress=gr.Progress()):
59
  """Train the ML model with progress tracking"""
60
+ try:
61
+ progress(0, desc="🔄 Generating synthetic data...")
62
+
63
+ # Generate training data
64
+ df = self.generate_synthetic_data(1000)
65
+ time.sleep(0.5) # Simulate data generation time
66
+
67
+ progress(0.2, desc="📊 Preparing features and labels...")
68
+
69
+ # Prepare features and labels
70
+ feature_columns = ['invoice_amount', 'days_overdue', 'previous_delays',
71
+ 'credit_score', 'industry_risk', 'seasonality']
72
+ X = df[feature_columns].values
73
+ y = df['paid_on_time'].values
74
+
75
+ # Split data
76
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
77
+
78
+ progress(0.3, desc="🧠 Building neural network...")
79
+
80
+ # Create model
81
+ self.model = tf.keras.Sequential([
82
+ tf.keras.layers.Dense(32, activation='relu', input_shape=(6,)),
83
+ tf.keras.layers.Dropout(0.2),
84
+ tf.keras.layers.Dense(16, activation='relu'),
85
+ tf.keras.layers.Dropout(0.2),
86
+ tf.keras.layers.Dense(1, activation='sigmoid')
87
+ ])
88
+
89
+ self.model.compile(
90
+ optimizer=tf.keras.optimizers.Adam(0.001),
91
+ loss='binary_crossentropy',
92
+ metrics=['accuracy']
93
+ )
94
+
95
+ progress(0.4, desc="🎯 Training model (50 epochs)...")
96
+
97
+ # Train model
98
+ history = self.model.fit(
99
+ X_train, y_train,
100
+ epochs=50,
101
+ batch_size=32,
102
+ validation_split=0.2,
103
+ verbose=0
104
+ )
105
+
106
+ progress(0.8, desc="📈 Evaluating model performance...")
107
+
108
+ # Make predictions on test set
109
+ y_pred_proba = self.model.predict(X_test, verbose=0)
110
+ y_pred = (y_pred_proba > 0.5).astype(int)
111
+
112
+ # Calculate metrics
113
+ accuracy = accuracy_score(y_test, y_pred)
114
+ precision = precision_score(y_test, y_pred)
115
+ recall = recall_score(y_test, y_pred)
116
+ f1 = f1_score(y_test, y_pred)
117
+
118
+ self.training_history = history.history
119
+ self.is_trained = True
120
+
121
+ progress(1.0, desc="✅ Training completed successfully!")
122
+
123
+ # Create training visualization
124
+ fig = go.Figure()
125
+
126
+ epochs = list(range(1, len(history.history['accuracy']) + 1))
127
+
128
+ fig.add_trace(go.Scatter(
129
+ x=epochs,
130
+ y=history.history['accuracy'],
131
+ mode='lines+markers',
132
+ name='Training Accuracy',
133
+ line=dict(color='#007bff', width=4),
134
+ marker=dict(size=8)
135
+ ))
136
+
137
+ fig.add_trace(go.Scatter(
138
+ x=epochs,
139
+ y=history.history['val_accuracy'],
140
+ mode='lines+markers',
141
+ name='Validation Accuracy',
142
+ line=dict(color='#28a745', width=4),
143
+ marker=dict(size=8)
144
+ ))
145
+
146
+ fig.update_layout(
147
+ title={
148
+ 'text': '📊 Model Training Progress',
149
+ 'x': 0.5,
150
+ 'font': {'size': 20}
151
+ },
152
+ xaxis_title='Epoch',
153
+ yaxis_title='Accuracy',
154
+ template='plotly_white',
155
+ height=450,
156
+ hovermode='x unified',
157
+ legend=dict(
158
+ yanchor="bottom",
159
+ y=0.02,
160
+ xanchor="right",
161
+ x=0.98
162
+ )
163
+ )
164
+
165
+ # Create metrics cards HTML
166
+ metrics_html = f"""
167
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px; margin: 20px 0;">
168
+ <div style="background: linear-gradient(135deg, #007bff, #0056b3); color: white; padding: 20px; border-radius: 15px; text-align: center; box-shadow: 0 4px 15px rgba(0,123,255,0.3);">
169
+ <div style="font-size: 2.5rem; font-weight: bold; margin-bottom: 5px;">{accuracy:.1%}</div>
170
+ <div style="font-size: 1.1rem;">🎯 Accuracy</div>
171
+ </div>
172
+ <div style="background: linear-gradient(135deg, #28a745, #20c997); color: white; padding: 20px; border-radius: 15px; text-align: center; box-shadow: 0 4px 15px rgba(40,167,69,0.3);">
173
+ <div style="font-size: 2.5rem; font-weight: bold; margin-bottom: 5px;">{precision:.1%}</div>
174
+ <div style="font-size: 1.1rem;">🎯 Precision</div>
175
+ </div>
176
+ <div style="background: linear-gradient(135deg, #ffc107, #fd7e14); color: white; padding: 20px; border-radius: 15px; text-align: center; box-shadow: 0 4px 15px rgba(255,193,7,0.3);">
177
+ <div style="font-size: 2.5rem; font-weight: bold; margin-bottom: 5px;">{recall:.1%}</div>
178
+ <div style="font-size: 1.1rem;">📊 Recall</div>
179
+ </div>
180
+ <div style="background: linear-gradient(135deg, #17a2b8, #138496); color: white; padding: 20px; border-radius: 15px; text-align: center; box-shadow: 0 4px 15px rgba(23,162,184,0.3);">
181
+ <div style="font-size: 2.5rem; font-weight: bold; margin-bottom: 5px;">{f1:.1%}</div>
182
+ <div style="font-size: 1.1rem">⚖️ F1 Score</div>
183
+ </div>
184
+ </div>
185
+ <div style="background: #d4edda; border: 1px solid #c3e6cb; color: #155724; padding: 15px; border-radius: 10px; margin-top: 15px; text-align: center;">
186
+ <strong>✅ Model trained successfully on 1,000 synthetic SAP AR records!</strong><br>
187
+ <em>The model is now ready to make predictions on unpaid invoices.</em>
188
+ </div>
189
+ """
190
+
191
+ return fig, metrics_html, gr.update(interactive=True, variant="primary")
192
+
193
+ except Exception as e:
194
+ error_html = f"""
195
+ <div style="background: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; padding: 15px; border-radius: 10px; text-align: center;">
196
+ <strong>❌ Training failed:</strong> {str(e)}
197
+ </div>
198
+ """
199
+ return None, error_html, gr.update(interactive=False)
200
 
201
  def generate_unpaid_invoices(self):
202
  """Generate sample unpaid invoices for prediction"""
203
  customers = ['SAP-CUST001', 'SAP-CUST002', 'SAP-CUST003', 'SAP-CUST004', 'SAP-CUST005']
204
 
205
  invoices = []
206
+ for i in range(12):
207
  invoice_id = f"INV-{datetime.now().strftime('%Y%m%d')}-{i:03d}"
208
  customer = random.choice(customers)
209
  amount = random.randint(5000, 50000)
 
227
  def make_predictions(self):
228
  """Make predictions on unpaid invoices"""
229
  if not self.is_trained:
230
+ error_msg = """
231
+ <div style="background: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; padding: 15px; border-radius: 10px; text-align: center;">
232
+ <strong>❌ Please train the model first!</strong><br>
233
+ <em>Go to the Model Training tab and click "Train ML Model"</em>
234
+ </div>
235
+ """
236
+ return None, error_msg, None
237
+
238
+ try:
239
+ # Generate unpaid invoices
240
+ df = self.generate_unpaid_invoices()
241
+
242
+ # Prepare features for prediction
243
+ features = []
244
+ for _, row in df.iterrows():
245
+ features.append([
246
+ row['Amount ($)'] / 50000, # Normalize
247
+ row['Days Overdue'] / 120, # Normalize
248
+ row['Previous Delays'] / 5, # Normalize
249
+ row['Credit Score'] / 100, # Normalize
250
+ row['Industry Risk'],
251
+ row['Seasonality']
252
+ ])
253
+
254
+ # Make predictions
255
+ predictions = self.model.predict(np.array(features), verbose=0)
256
+
257
+ # Create results dataframe with better formatting
258
+ results_df = df.copy()
259
+ prob_values = [p[0] for p in predictions]
260
+
261
+ # Add prediction columns
262
+ results_df['Payment Probability'] = [f"{p:.1%}" for p in prob_values]
263
+ results_df['Prediction'] = ['✅ Will Pay' if p > 0.5 else '❌ Risk of Default' for p in prob_values]
264
+ results_df['Risk Level'] = ['🟢 Low Risk' if p > 0.7 else '🟡 Medium Risk' if p > 0.4 else '🔴 High Risk' for p in prob_values]
265
+
266
+ # Format amount column
267
+ results_df['Amount ($)'] = results_df['Amount ($)'].apply(lambda x: f"${x:,}")
268
+
269
+ # Reorder columns for better display
270
+ column_order = ['Invoice ID', 'Customer', 'Amount ($)', 'Days Overdue', 'Credit Score',
271
+ 'Payment Probability', 'Prediction', 'Risk Level']
272
+ results_df = results_df[column_order]
273
+
274
+ # Create probability distribution chart
275
+ fig = go.Figure()
276
+
277
+ # Create histogram
278
+ fig.add_trace(go.Histogram(
279
  x=prob_values,
280
+ nbinsx=15,
281
  marker_color='rgba(0, 123, 255, 0.7)',
282
  marker_line_color='rgba(0, 123, 255, 1)',
283
+ marker_line_width=2,
284
+ name='Payment Probability'
285
+ ))
286
+
287
+ # Add vertical lines for risk thresholds
288
+ fig.add_vline(x=0.4, line_dash="dash", line_color="orange",
289
+ annotation_text="Medium Risk Threshold")
290
+ fig.add_vline(x=0.7, line_dash="dash", line_color="green",
291
+ annotation_text="Low Risk Threshold")
292
+
293
+ fig.update_layout(
294
+ title={
295
+ 'text': '📊 Distribution of Payment Probabilities',
296
+ 'x': 0.5,
297
+ 'font': {'size': 18}
298
+ },
299
+ xaxis_title='Payment Probability',
300
+ yaxis_title='Number of Invoices',
301
+ template='plotly_white',
302
+ height=400,
303
+ showlegend=False
304
  )
305
+
306
+ # Count predictions by category
307
+ will_pay = sum(1 for p in prob_values if p > 0.5)
308
+ risk_default = len(prob_values) - will_pay
309
+ high_risk = sum(1 for p in prob_values if p <= 0.4)
310
+
311
+ success_msg = f"""
312
+ <div style="background: #d4edda; border: 1px solid #c3e6cb; color: #155724; padding: 20px; border-radius: 10px; margin: 15px 0;">
313
+ <div style="text-align: center; margin-bottom: 15px;">
314
+ <strong style="font-size: 1.2rem;">🔮 Prediction Results Generated Successfully!</strong>
315
+ </div>
316
+ <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px; text-align: center;">
317
+ <div style="background: rgba(40, 167, 69, 0.1); padding: 15px; border-radius: 8px; border: 2px solid #28a745;">
318
+ <div style="font-size: 2rem; font-weight: bold; color: #28a745;">{will_pay}</div>
319
+ <div style="font-weight: bold;">✅ Will Pay</div>
320
+ </div>
321
+ <div style="background: rgba(220, 53, 69, 0.1); padding: 15px; border-radius: 8px; border: 2px solid #dc3545;">
322
+ <div style="font-size: 2rem; font-weight: bold; color: #dc3545;">{risk_default}</div>
323
+ <div style="font-weight: bold;">❌ Risk of Default</div>
324
+ </div>
325
+ <div style="background: rgba(255, 193, 7, 0.1); padding: 15px; border-radius: 8px; border: 2px solid #ffc107;">
326
+ <div style="font-size: 2rem; font-weight: bold; color: #856404;">{high_risk}</div>
327
+ <div style="font-weight: bold;">🔴 High Risk</div>
328
+ </div>
329
+ </div>
330
+ </div>
331
+ """
332
+
333
+ return results_df, success_msg, fig
334
+
335
+ except Exception as e:
336
+ error_msg = f"""
337
+ <div style="background: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; padding: 15px; border-radius: 10px; text-align: center;">
338
+ <strong>❌ Prediction failed:</strong> {str(e)}
339
+ </div>
340
+ """
341
+ return None, error_msg, None
342
 
343
  # Initialize the predictor
344
  predictor = SAPARPredictor()
345
 
346
+ # Create Gradio interface with improved layout
347
  with gr.Blocks(
348
  theme=gr.themes.Soft(
349
  primary_hue="blue",
 
353
  title="SAP AR ML Prediction Demo",
354
  css="""
355
  .gradio-container {
356
+ max-width: 1400px !important;
357
+ margin: 0 auto !important;
358
  }
359
  .main-header {
360
  text-align: center;
361
  margin-bottom: 2rem;
362
+ padding: 2rem;
 
363
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
364
+ border-radius: 15px;
 
365
  color: white;
366
+ margin-bottom: 30px;
367
+ }
368
+ .tab-nav {
369
+ margin-bottom: 20px;
370
  }
371
  """
372
  ) as demo:
373
 
374
  gr.HTML("""
375
  <div class="main-header">
376
+ <h1 style="font-size: 2.5rem; margin-bottom: 15px; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);">
377
+ 🏢 SAP Account Receivable ML Prediction Demo
378
+ </h1>
379
+ <p style="font-size: 1.2rem; opacity: 0.95; margin: 0;">
380
  Machine Learning-powered invoice payment prediction system using TensorFlow
381
  </p>
382
  </div>
 
385
  with gr.Tabs() as tabs:
386
 
387
  with gr.Tab("🎯 Model Training", id=0):
 
 
 
 
 
 
388
  with gr.Row():
389
+ with gr.Column(scale=2):
390
+ gr.Markdown("""
391
+ ### 🚀 Train Your ML Model
392
+
393
+ This will create a neural network trained on **1,000 synthetic SAP AR records** to predict invoice payment likelihood.
394
+ The model analyzes multiple factors including invoice amount, days overdue, customer credit score, and payment history.
395
+ """)
396
+
397
  train_btn = gr.Button(
398
  "🚀 Train ML Model",
399
  variant="primary",
400
+ size="lg",
401
+ scale=1
402
  )
403
+
404
+ with gr.Column(scale=1):
405
+ gr.Markdown("""
406
+ ### 📋 Model Features
407
+ - Invoice Amount
408
+ - Days Overdue
409
+ - Previous Delays
410
+ - Credit Score
411
+ - Industry Risk
412
+ - Seasonality
413
+ """)
414
+
415
+ metrics_display = gr.HTML()
416
+
417
+ with gr.Row():
418
+ training_plot = gr.Plot(label="📈 Training Progress")
419
+
420
+ with gr.Row():
421
+ predict_btn = gr.Button(
422
+ "🔮 Generate Predictions",
423
+ variant="secondary",
424
+ interactive=False,
425
+ size="lg"
426
+ )
427
+
428
+ with gr.Tab("📊 Invoice Predictions", id=1):
429
  gr.Markdown("""
430
+ ### 🔮 Real-time Payment Predictions
431
+ View ML-powered predictions for unpaid invoices with probability scores and risk assessments.
432
  """)
433
 
434
+ prediction_status = gr.HTML()
435
+
436
+ with gr.Row():
437
+ with gr.Column(scale=2):
438
+ predictions_df = gr.Dataframe(
439
+ label="📋 Invoice Predictions",
440
+ interactive=False,
441
+ wrap=True,
442
+ height=400
443
+ )
444
+
445
+ with gr.Column(scale=1):
446
+ probability_plot = gr.Plot(label="📊 Probability Distribution")
447
 
448
  # Event handlers
449
  train_btn.click(
450
  fn=predictor.train_model,
451
+ outputs=[training_plot, metrics_display, predict_btn],
452
+ show_progress=True
453
  )
454
 
455
  predict_btn.click(
 
459
 
460
  # Launch the app
461
  if __name__ == "__main__":
462
+ demo.launch(share=True)