QSBench commited on
Commit
c7ba23d
·
1 Parent(s): 9c18e8c
Files changed (1) hide show
  1. app.py +38 -36
app.py CHANGED
@@ -4,52 +4,54 @@ from datasets import load_dataset
4
  from sklearn.ensemble import RandomForestRegressor
5
  from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
 
7
 
8
  # Загружаем датасет
 
9
  ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
 
10
 
11
- # Определяем доступные сплиты (предполагаем train, val, test)
12
- available_splits = list(ds.keys())
13
- print("Available splits:", available_splits) # для отладки в логах
14
-
15
- # Функция для отображения данных выбранного сплита
16
  def show_data(split):
17
  try:
18
  df = pd.DataFrame(ds[split])
19
  return df.head(10)
20
  except Exception as e:
21
- # Возвращаем пустой DataFrame с сообщением об ошибке
22
- return pd.DataFrame({"Error": [str(e)]})
23
 
24
  # Функция для обучения модели и создания графика
25
  def train_model():
26
- try:
27
- # Проверяем, есть ли нужные сплиты
28
- if "train" not in ds or "test" not in ds:
29
- return None # или вернуть сообщение
30
- feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
31
- target_col = "ideal_expval_Z_global"
32
 
33
- X_train = pd.DataFrame(ds["train"])[feature_cols]
34
- y_train = pd.DataFrame(ds["train"])[target_col]
35
- X_test = pd.DataFrame(ds["test"])[feature_cols]
36
- y_test = pd.DataFrame(ds["test"])[target_col]
37
 
38
- model = RandomForestRegressor(n_estimators=100, random_state=42)
39
- model.fit(X_train, y_train)
40
- y_pred = model.predict(X_test)
41
- r2 = r2_score(y_test, y_pred)
 
42
 
43
- fig, ax = plt.subplots()
44
- ax.scatter(y_test, y_pred, alpha=0.5)
45
- ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
46
- ax.set_xlabel("True value")
47
- ax.set_ylabel("Predicted")
48
- ax.set_title(f"Predictions vs. Truth (= {r2:.4f})")
49
- return fig
50
- except Exception as e:
51
- # В случае ошибки возвращаем сообщение как график? проще вернуть None и показать текст
52
- raise e
 
 
 
 
 
 
 
53
 
54
  # Интерфейс
55
  with gr.Blocks(title="QSBench Demo Explorer") as demo:
@@ -64,17 +66,17 @@ with gr.Blocks(title="QSBench Demo Explorer") as demo:
64
 
65
  with gr.Tabs():
66
  with gr.TabItem("Data Explorer"):
67
- # Определяем список сплитов из загруженного датасета
68
- split_choices = available_splits if available_splits else ["train", "val", "test"]
69
  split_selector = gr.Dropdown(
70
- choices=split_choices,
71
  label="Choose a split",
72
- value=split_choices[0] if split_choices else "train"
73
  )
74
  data_table = gr.Dataframe(label="First 10 rows", interactive=False)
75
  split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
76
  # Загружаем данные по умолчанию
77
- demo.load(fn=lambda: show_data(split_choices[0] if split_choices else "train"), outputs=data_table)
78
 
79
  with gr.TabItem("Model Demo"):
80
  train_button = gr.Button("Train Random Forest")
 
4
  from sklearn.ensemble import RandomForestRegressor
5
  from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
7
+ import sys
8
 
9
  # Загружаем датасет
10
+ print("Loading dataset...")
11
  ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
12
+ print("Available splits:", list(ds.keys())) # выводим в логи
13
 
14
+ # Функция для отображения таблицы
 
 
 
 
15
  def show_data(split):
16
  try:
17
  df = pd.DataFrame(ds[split])
18
  return df.head(10)
19
  except Exception as e:
20
+ return f"Error loading data: {e}"
 
21
 
22
  # Функция для обучения модели и создания графика
23
  def train_model():
24
+ # Определяем доступные сплиты
25
+ splits = list(ds.keys())
26
+ if "test" not in splits:
27
+ return "Error: 'test' split not found in dataset", None
 
 
28
 
29
+ feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
30
+ target_col = "ideal_expval_Z_global"
 
 
31
 
32
+ # Проверяем наличие колонок
33
+ sample_df = pd.DataFrame(ds[splits[0]])
34
+ if not all(col in sample_df.columns for col in feature_cols + [target_col]):
35
+ missing = [col for col in feature_cols + [target_col] if col not in sample_df.columns]
36
+ return f"Error: missing columns: {missing}", None
37
 
38
+ X_train = pd.DataFrame(ds["train"])[feature_cols]
39
+ y_train = pd.DataFrame(ds["train"])[target_col]
40
+ X_test = pd.DataFrame(ds["test"])[feature_cols]
41
+ y_test = pd.DataFrame(ds["test"])[target_col]
42
+
43
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
44
+ model.fit(X_train, y_train)
45
+ y_pred = model.predict(X_test)
46
+ r2 = r2_score(y_test, y_pred)
47
+
48
+ fig, ax = plt.subplots()
49
+ ax.scatter(y_test, y_pred, alpha=0.5)
50
+ ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
51
+ ax.set_xlabel("True value")
52
+ ax.set_ylabel("Predicted")
53
+ ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
54
+ return fig
55
 
56
  # Интерфейс
57
  with gr.Blocks(title="QSBench Demo Explorer") as demo:
 
66
 
67
  with gr.Tabs():
68
  with gr.TabItem("Data Explorer"):
69
+ # Определяем доступные сплиты для выпадающего списка
70
+ available_splits = list(ds.keys())
71
  split_selector = gr.Dropdown(
72
+ choices=available_splits,
73
  label="Choose a split",
74
+ value=available_splits[0] if available_splits else None
75
  )
76
  data_table = gr.Dataframe(label="First 10 rows", interactive=False)
77
  split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
78
  # Загружаем данные по умолчанию
79
+ demo.load(fn=lambda: show_data(available_splits[0]), outputs=data_table)
80
 
81
  with gr.TabItem("Model Demo"):
82
  train_button = gr.Button("Train Random Forest")