QSBench committed on
Commit
dcea369
·
1 Parent(s): c7ba23d
Files changed (1) hide show
  1. app.py +39 -36
app.py CHANGED
@@ -4,54 +4,58 @@ from datasets import load_dataset
4
  from sklearn.ensemble import RandomForestRegressor
5
  from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
7
- import sys
8
 
9
- # Загружаем датасет
10
  print("Loading dataset...")
11
- ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
12
- print("Available splits:", list(ds.keys())) # выводим в логи
 
 
 
 
 
 
 
 
13
 
14
  # Функция для отображения таблицы
15
  def show_data(split):
16
- try:
17
- df = pd.DataFrame(ds[split])
18
- return df.head(10)
19
- except Exception as e:
20
- return f"Error loading data: {e}"
21
 
22
  # Функция для обучения модели и создания графика
23
  def train_model():
24
- # Определяем доступные сплиты
25
- splits = list(ds.keys())
26
- if "test" not in splits:
27
- return "Error: 'test' split not found in dataset", None
28
-
29
  feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
30
  target_col = "ideal_expval_Z_global"
31
-
32
  # Проверяем наличие колонок
33
- sample_df = pd.DataFrame(ds[splits[0]])
34
- if not all(col in sample_df.columns for col in feature_cols + [target_col]):
35
- missing = [col for col in feature_cols + [target_col] if col not in sample_df.columns]
36
- return f"Error: missing columns: {missing}", None
37
-
38
- X_train = pd.DataFrame(ds["train"])[feature_cols]
39
- y_train = pd.DataFrame(ds["train"])[target_col]
40
- X_test = pd.DataFrame(ds["test"])[feature_cols]
41
- y_test = pd.DataFrame(ds["test"])[target_col]
42
-
43
  model = RandomForestRegressor(n_estimators=100, random_state=42)
44
  model.fit(X_train, y_train)
45
  y_pred = model.predict(X_test)
46
  r2 = r2_score(y_test, y_pred)
47
-
48
  fig, ax = plt.subplots()
49
  ax.scatter(y_test, y_pred, alpha=0.5)
50
  ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
51
  ax.set_xlabel("True value")
52
  ax.set_ylabel("Predicted")
53
  ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
54
- return fig
55
 
56
  # Интерфейс
57
  with gr.Blocks(title="QSBench Demo Explorer") as demo:
@@ -63,26 +67,25 @@ with gr.Blocks(title="QSBench Demo Explorer") as demo:
63
  👉 **Full datasets (up to 200k samples, noisy versions, 10‑qubit transpilation packs) are available for purchase.**
64
  [Visit the QSBench website](https://qsbench.github.io/)
65
  """)
66
-
67
  with gr.Tabs():
68
  with gr.TabItem("Data Explorer"):
69
- # Определяем доступные сплиты для выпадающего списка
70
- available_splits = list(ds.keys())
71
  split_selector = gr.Dropdown(
72
- choices=available_splits,
73
  label="Choose a split",
74
- value=available_splits[0] if available_splits else None
75
  )
76
  data_table = gr.Dataframe(label="First 10 rows", interactive=False)
77
  split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
78
  # Загружаем данные по умолчанию
79
- demo.load(fn=lambda: show_data(available_splits[0]), outputs=data_table)
80
-
81
  with gr.TabItem("Model Demo"):
82
  train_button = gr.Button("Train Random Forest")
83
  plot_output = gr.Plot()
84
- train_button.click(fn=train_model, outputs=plot_output)
85
-
 
86
  gr.Markdown("---")
87
  gr.Markdown("""
88
  ### Get the full datasets
 
4
  from sklearn.ensemble import RandomForestRegressor
5
  from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
 
7
 
8
# Load the demo dataset; all rows are published under the single 'train' split.
print("Loading dataset...")
ds_all = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
# Materialise that one published split (it holds every row) as a DataFrame.
df_all = pd.DataFrame(ds_all['train'])

# Partition the rows by the value stored in the 'split' column, giving each
# partition a fresh 0-based index.
splits = {
    split_name: df_all[df_all['split'] == split_name].reset_index(drop=True)
    for split_name in df_all['split'].unique()
}

print("Available splits:", list(splits.keys()))
20
 
21
  # Функция для отображения таблицы
22
  def show_data(split):
23
+ if split in splits:
24
+ return splits[split].head(10)
25
+ else:
26
+ return f"Split '{split}' not found"
 
27
 
28
  # Функция для обучения модели и создания графика
29
  def train_model():
30
+ # Проверяем, что есть нужные сплиты
31
+ if 'train' not in splits or 'test' not in splits:
32
+ return None, "Error: train or test split not found in dataset"
33
+
 
34
  feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
35
  target_col = "ideal_expval_Z_global"
36
+
37
  # Проверяем наличие колонок
38
+ if not all(col in splits['train'].columns for col in feature_cols + [target_col]):
39
+ missing = [col for col in feature_cols + [target_col] if col not in splits['train'].columns]
40
+ return None, f"Error: missing columns: {missing}"
41
+
42
+ X_train = splits['train'][feature_cols]
43
+ y_train = splits['train'][target_col]
44
+ X_test = splits['test'][feature_cols]
45
+ y_test = splits['test'][target_col]
46
+
 
47
  model = RandomForestRegressor(n_estimators=100, random_state=42)
48
  model.fit(X_train, y_train)
49
  y_pred = model.predict(X_test)
50
  r2 = r2_score(y_test, y_pred)
51
+
52
  fig, ax = plt.subplots()
53
  ax.scatter(y_test, y_pred, alpha=0.5)
54
  ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
55
  ax.set_xlabel("True value")
56
  ax.set_ylabel("Predicted")
57
  ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
58
+ return fig, f"R² score: {r2:.4f}"
59
 
60
  # User interface
61
  with gr.Blocks(title="QSBench Demo Explorer") as demo:
 
67
  👉 **Full datasets (up to 200k samples, noisy versions, 10‑qubit transpilation packs) are available for purchase.**
68
  [Visit the QSBench website](https://qsbench.github.io/)
69
  """)
70
+
71
  with gr.Tabs():
72
  with gr.TabItem("Data Explorer"):
 
 
73
  split_selector = gr.Dropdown(
74
+ choices=list(splits.keys()),
75
  label="Choose a split",
76
+ value=list(splits.keys())[0] if splits else None
77
  )
78
  data_table = gr.Dataframe(label="First 10 rows", interactive=False)
79
  split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
80
  # Загружаем данные по умолчанию
81
+ demo.load(fn=lambda: show_data(list(splits.keys())[0]), outputs=data_table)
82
+
83
  with gr.TabItem("Model Demo"):
84
  train_button = gr.Button("Train Random Forest")
85
  plot_output = gr.Plot()
86
+ text_output = gr.Textbox(label="Result", interactive=False)
87
+ train_button.click(fn=train_model, outputs=[plot_output, text_output])
88
+
89
  gr.Markdown("---")
90
  gr.Markdown("""
91
  ### Get the full datasets