QSBench commited on
Commit
c752e48
·
verified ·
1 Parent(s): 0553f08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -26
app.py CHANGED
@@ -2,6 +2,7 @@ import ast
2
  import logging
3
  import re
4
  from typing import Dict, List, Optional, Tuple
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import numpy as np
@@ -9,6 +10,7 @@ import pandas as pd
9
  from datasets import load_dataset
10
  from sklearn.ensemble import HistGradientBoostingClassifier
11
  from sklearn.impute import SimpleImputer
 
12
  from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
13
  from sklearn.model_selection import train_test_split
14
  from sklearn.pipeline import Pipeline
@@ -271,8 +273,8 @@ def build_dataset_profile(df: pd.DataFrame) -> str:
271
  """Build a short dataset summary for the explorer tab."""
272
  return (
273
  f"### Dataset profile\n\n"
274
- f"**Rows:** {len(df):,} \n"
275
- f"**Columns:** {len(df.columns):,} \n"
276
  f"**Classes:** {', '.join(CLASS_ORDER)}"
277
  )
278
 
@@ -292,9 +294,9 @@ def refresh_explorer(dataset_key: str, split_name: str) -> Tuple[gr.update, pd.D
292
  profile_box = build_dataset_profile(df)
293
  summary_box = (
294
  f"### Split summary\n\n"
295
- f"**Dataset:** `{dataset_key}` \n"
296
- f"**Label:** `{REPO_CONFIG[dataset_key]['label']}` \n"
297
- f"**Available splits:** {', '.join(splits)} \n"
298
  f"**Preview rows:** {len(display_df)}"
299
  )
300
  return (
@@ -307,15 +309,16 @@ def refresh_explorer(dataset_key: str, split_name: str) -> Tuple[gr.update, pd.D
307
  )
308
 
309
 
310
- def sync_feature_picker(_dataset_key: str) -> gr.update:
311
- """Refresh the feature list from the combined dataset."""
312
- df = load_combined_dataset()
313
  features = get_available_feature_columns(df)
314
  defaults = default_feature_selection(features)
315
  return gr.update(choices=features, value=defaults)
316
 
317
 
318
  def train_classifier(
 
319
  feature_columns: List[str],
320
  test_size: float,
321
  n_estimators: int,
@@ -326,7 +329,7 @@ def train_classifier(
326
  if not feature_columns:
327
  return None, "### ❌ Please select at least one feature."
328
 
329
- df = load_combined_dataset()
330
  required_cols = feature_columns + ["noise_label"]
331
  train_df = df.dropna(subset=required_cols).copy()
332
  train_df = train_df[train_df["noise_label"].isin(CLASS_ORDER)]
@@ -341,7 +344,6 @@ def train_classifier(
341
  depth = int(max_depth) if max_depth and int(max_depth) > 0 else None
342
  max_iter = int(n_estimators)
343
 
344
- # --- Stratified split ---
345
  try:
346
  X_train, X_test, y_train, y_test = train_test_split(
347
  X, y, test_size=test_size, random_state=seed, stratify=y
@@ -351,7 +353,6 @@ def train_classifier(
351
  X, y, test_size=test_size, random_state=seed
352
  )
353
 
354
- # --- Pipeline with class_weight='balanced' ---
355
  model = Pipeline(
356
  steps=[
357
  ("imputer", SimpleImputer(strategy="median")),
@@ -363,9 +364,9 @@ def train_classifier(
363
  max_depth=depth,
364
  random_state=seed,
365
  min_samples_leaf=1,
366
- class_weight="balanced", # ← главное улучшение
367
- learning_rate=0.1, # можно поиграть (0.05-0.2)
368
- max_bins=255, # стандартное хорошее значение
369
  ),
370
  ),
371
  ]
@@ -378,8 +379,16 @@ def train_classifier(
378
  macro_f1 = float(f1_score(y_test, y_pred, average="macro", zero_division=0))
379
  weighted_f1 = float(f1_score(y_test, y_pred, average="weighted", zero_division=0))
380
 
381
- classifier = model.named_steps["classifier"]
382
- importances = getattr(classifier, "feature_importances_", None)
 
 
 
 
 
 
 
 
383
 
384
  fig = make_classification_figure(y_test.to_numpy(), y_pred, CLASS_ORDER, list(feature_columns), importances)
385
 
@@ -389,19 +398,18 @@ def train_classifier(
389
  labels=CLASS_ORDER,
390
  zero_division=0,
391
  )
392
-
393
  results = (
394
  "### Classification results\n\n"
395
- f"**Rows used:** {len(train_df):,} \n"
396
- f"**Test size:** {test_size:.0%} \n"
397
- f"**Accuracy:** {accuracy:.4f} \n"
398
- f"**Macro F1:** {macro_f1:.4f} \n"
 
399
  f"**Weighted F1:** {weighted_f1:.4f}\n\n"
400
  "```text\n"
401
  f"{report}"
402
  "```"
403
  )
404
-
405
  return fig, results
406
 
407
 
@@ -439,6 +447,11 @@ with gr.Blocks(title=APP_TITLE) as demo:
439
  transpiled_qasm = gr.Code(label="Transpiled QASM", language=None)
440
 
441
  with gr.TabItem("🧠 Classification"):
 
 
 
 
 
442
  feature_picker = gr.CheckboxGroup(label="Input features", choices=[])
443
  test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test split")
444
  n_estimators = gr.Slider(50, 400, value=200, step=10, label="Trees")
@@ -470,11 +483,11 @@ with gr.Blocks(title=APP_TITLE) as demo:
470
  [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
471
  )
472
 
473
- dataset_dropdown.change(sync_feature_picker, [dataset_dropdown], [feature_picker])
474
 
475
  run_btn.click(
476
  train_classifier,
477
- [feature_picker, test_size, n_estimators, max_depth, seed],
478
  [plot, metrics],
479
  )
480
 
@@ -483,8 +496,8 @@ with gr.Blocks(title=APP_TITLE) as demo:
483
  [dataset_dropdown, split_dropdown],
484
  [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
485
  )
486
- demo.load(sync_feature_picker, [dataset_dropdown], [feature_picker])
487
 
488
 
489
  if __name__ == "__main__":
490
- demo.launch(theme=gr.themes.Soft(), css=CUSTOM_CSS)
 
2
  import logging
3
  import re
4
  from typing import Dict, List, Optional, Tuple
5
+
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
  import numpy as np
 
10
  from datasets import load_dataset
11
  from sklearn.ensemble import HistGradientBoostingClassifier
12
  from sklearn.impute import SimpleImputer
13
+ from sklearn.inspection import permutation_importance
14
  from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
15
  from sklearn.model_selection import train_test_split
16
  from sklearn.pipeline import Pipeline
 
273
  """Build a short dataset summary for the explorer tab."""
274
  return (
275
  f"### Dataset profile\n\n"
276
+ f"**Rows:** {len(df):,} \n"
277
+ f"**Columns:** {len(df.columns):,} \n"
278
  f"**Classes:** {', '.join(CLASS_ORDER)}"
279
  )
280
 
 
294
  profile_box = build_dataset_profile(df)
295
  summary_box = (
296
  f"### Split summary\n\n"
297
+ f"**Dataset:** `{dataset_key}` \n"
298
+ f"**Label:** `{REPO_CONFIG[dataset_key]['label']}` \n"
299
+ f"**Available splits:** {', '.join(splits)} \n"
300
  f"**Preview rows:** {len(display_df)}"
301
  )
302
  return (
 
309
  )
310
 
311
 
312
+ def sync_feature_picker(dataset_key: str) -> gr.update:
313
+ """Refresh the feature list from the selected dataset."""
314
+ df = load_single_dataset(dataset_key)
315
  features = get_available_feature_columns(df)
316
  defaults = default_feature_selection(features)
317
  return gr.update(choices=features, value=defaults)
318
 
319
 
320
  def train_classifier(
321
+ dataset_key: str,
322
  feature_columns: List[str],
323
  test_size: float,
324
  n_estimators: int,
 
329
  if not feature_columns:
330
  return None, "### ❌ Please select at least one feature."
331
 
332
+ df = load_single_dataset(dataset_key)
333
  required_cols = feature_columns + ["noise_label"]
334
  train_df = df.dropna(subset=required_cols).copy()
335
  train_df = train_df[train_df["noise_label"].isin(CLASS_ORDER)]
 
344
  depth = int(max_depth) if max_depth and int(max_depth) > 0 else None
345
  max_iter = int(n_estimators)
346
 
 
347
  try:
348
  X_train, X_test, y_train, y_test = train_test_split(
349
  X, y, test_size=test_size, random_state=seed, stratify=y
 
353
  X, y, test_size=test_size, random_state=seed
354
  )
355
 
 
356
  model = Pipeline(
357
  steps=[
358
  ("imputer", SimpleImputer(strategy="median")),
 
364
  max_depth=depth,
365
  random_state=seed,
366
  min_samples_leaf=1,
367
+ class_weight="balanced",
368
+ learning_rate=0.1,
369
+ max_bins=255,
370
  ),
371
  ),
372
  ]
 
379
  macro_f1 = float(f1_score(y_test, y_pred, average="macro", zero_division=0))
380
  weighted_f1 = float(f1_score(y_test, y_pred, average="weighted", zero_division=0))
381
 
382
+ perm = permutation_importance(
383
+ model,
384
+ X_test,
385
+ y_test,
386
+ n_repeats=8,
387
+ random_state=seed,
388
+ scoring="f1_macro",
389
+ n_jobs=-1,
390
+ )
391
+ importances = perm.importances_mean
392
 
393
  fig = make_classification_figure(y_test.to_numpy(), y_pred, CLASS_ORDER, list(feature_columns), importances)
394
 
 
398
  labels=CLASS_ORDER,
399
  zero_division=0,
400
  )
 
401
  results = (
402
  "### Classification results\n\n"
403
+ f"**Rows used:** {len(train_df):,} \n"
404
+ f"**Dataset:** `{dataset_key}` \n"
405
+ f"**Test size:** {test_size:.0%} \n"
406
+ f"**Accuracy:** {accuracy:.4f} \n"
407
+ f"**Macro F1:** {macro_f1:.4f} \n"
408
  f"**Weighted F1:** {weighted_f1:.4f}\n\n"
409
  "```text\n"
410
  f"{report}"
411
  "```"
412
  )
 
413
  return fig, results
414
 
415
 
 
447
  transpiled_qasm = gr.Code(label="Transpiled QASM", language=None)
448
 
449
  with gr.TabItem("🧠 Classification"):
450
+ class_dataset_dropdown = gr.Dropdown(
451
+ list(REPO_CONFIG.keys()),
452
+ value="clean",
453
+ label="Dataset",
454
+ )
455
  feature_picker = gr.CheckboxGroup(label="Input features", choices=[])
456
  test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="Test split")
457
  n_estimators = gr.Slider(50, 400, value=200, step=10, label="Trees")
 
483
  [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
484
  )
485
 
486
+ class_dataset_dropdown.change(sync_feature_picker, [class_dataset_dropdown], [feature_picker])
487
 
488
  run_btn.click(
489
  train_classifier,
490
+ [class_dataset_dropdown, feature_picker, test_size, n_estimators, max_depth, seed],
491
  [plot, metrics],
492
  )
493
 
 
496
  [dataset_dropdown, split_dropdown],
497
  [split_dropdown, explorer_df, raw_qasm, transpiled_qasm, profile_box, summary_box],
498
  )
499
+ demo.load(sync_feature_picker, [class_dataset_dropdown], [feature_picker])
500
 
501
 
502
  if __name__ == "__main__":
503
+ demo.launch(theme=gr.themes.Soft(), css=CUSTOM_CSS)