QSBench commited on
Commit
9c18e8c
·
1 Parent(s): 00e8206
Files changed (1) hide show
  1. app.py +37 -23
app.py CHANGED
@@ -5,39 +5,51 @@ from sklearn.ensemble import RandomForestRegressor
5
  from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
7
 
8
- # Загружаем датасет один раз
9
  ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
10
 
11
- # Функция для отображения таблицы
 
 
 
 
12
  def show_data(split):
13
  try:
14
  df = pd.DataFrame(ds[split])
15
  return df.head(10)
16
  except Exception as e:
17
- return f"Error loading data: {e}"
 
18
 
19
  # Функция для обучения модели и создания графика
20
  def train_model():
21
- feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
22
- target_col = "ideal_expval_Z_global"
 
 
 
 
23
 
24
- X_train = pd.DataFrame(ds["train"])[feature_cols]
25
- y_train = pd.DataFrame(ds["train"])[target_col]
26
- X_test = pd.DataFrame(ds["test"])[feature_cols]
27
- y_test = pd.DataFrame(ds["test"])[target_col]
28
 
29
- model = RandomForestRegressor(n_estimators=100, random_state=42)
30
- model.fit(X_train, y_train)
31
- y_pred = model.predict(X_test)
32
- r2 = r2_score(y_test, y_pred)
33
 
34
- fig, ax = plt.subplots()
35
- ax.scatter(y_test, y_pred, alpha=0.5)
36
- ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
37
- ax.set_xlabel("True value")
38
- ax.set_ylabel("Predicted")
39
- ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
40
- return fig
 
 
 
41
 
42
  # Интерфейс
43
  with gr.Blocks(title="QSBench Demo Explorer") as demo:
@@ -52,15 +64,17 @@ with gr.Blocks(title="QSBench Demo Explorer") as demo:
52
 
53
  with gr.Tabs():
54
  with gr.TabItem("Data Explorer"):
 
 
55
  split_selector = gr.Dropdown(
56
- choices=["train", "validation", "test"],
57
  label="Choose a split",
58
- value="train"
59
  )
60
  data_table = gr.Dataframe(label="First 10 rows", interactive=False)
61
  split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
62
  # Загружаем данные по умолчанию
63
- demo.load(fn=lambda: show_data("train"), outputs=data_table)
64
 
65
  with gr.TabItem("Model Demo"):
66
  train_button = gr.Button("Train Random Forest")
 
5
  from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
7
 
8
+ # Загружаем датасет
9
  ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
10
 
11
+ # Определяем доступные сплиты (предполагаем train, val, test)
12
+ available_splits = list(ds.keys())
13
+ print("Available splits:", available_splits) # для отладки в логах
14
+
15
+ # Функция для отображения данных выбранного сплита
16
  def show_data(split):
17
  try:
18
  df = pd.DataFrame(ds[split])
19
  return df.head(10)
20
  except Exception as e:
21
+ # Возвращаем пустой DataFrame с сообщением об ошибке
22
+ return pd.DataFrame({"Error": [str(e)]})
23
 
24
  # Функция для обучения модели и создания графика
25
  def train_model():
26
+ try:
27
+ # Проверяем, есть ли нужные сплиты
28
+ if "train" not in ds or "test" not in ds:
29
+ return None # или вернуть сообщение
30
+ feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
31
+ target_col = "ideal_expval_Z_global"
32
 
33
+ X_train = pd.DataFrame(ds["train"])[feature_cols]
34
+ y_train = pd.DataFrame(ds["train"])[target_col]
35
+ X_test = pd.DataFrame(ds["test"])[feature_cols]
36
+ y_test = pd.DataFrame(ds["test"])[target_col]
37
 
38
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
39
+ model.fit(X_train, y_train)
40
+ y_pred = model.predict(X_test)
41
+ r2 = r2_score(y_test, y_pred)
42
 
43
+ fig, ax = plt.subplots()
44
+ ax.scatter(y_test, y_pred, alpha=0.5)
45
+ ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
46
+ ax.set_xlabel("True value")
47
+ ax.set_ylabel("Predicted")
48
+ ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
49
+ return fig
50
+ except Exception as e:
51
+ # В случае ошибки возвращаем сообщение как график? проще вернуть None и показать текст
52
+ raise e
53
 
54
  # Интерфейс
55
  with gr.Blocks(title="QSBench Demo Explorer") as demo:
 
64
 
65
  with gr.Tabs():
66
  with gr.TabItem("Data Explorer"):
67
+ # Определяем список сплитов из загруженного датасета
68
+ split_choices = available_splits if available_splits else ["train", "val", "test"]
69
  split_selector = gr.Dropdown(
70
+ choices=split_choices,
71
  label="Choose a split",
72
+ value=split_choices[0] if split_choices else "train"
73
  )
74
  data_table = gr.Dataframe(label="First 10 rows", interactive=False)
75
  split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
76
  # Загружаем данные по умолчанию
77
+ demo.load(fn=lambda: show_data(split_choices[0] if split_choices else "train"), outputs=data_table)
78
 
79
  with gr.TabItem("Model Demo"):
80
  train_button = gr.Button("Train Random Forest")