Spaces:
Running
Running
fix
Browse files
app.py
CHANGED
|
@@ -4,52 +4,54 @@ from datasets import load_dataset
|
|
| 4 |
from sklearn.ensemble import RandomForestRegressor
|
| 5 |
from sklearn.metrics import r2_score
|
| 6 |
import matplotlib.pyplot as plt
|
|
|
|
| 7 |
|
| 8 |
# Загружаем датасет
|
|
|
|
| 9 |
ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
|
|
|
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
available_splits = list(ds.keys())
|
| 13 |
-
print("Available splits:", available_splits) # для отладки в логах
|
| 14 |
-
|
| 15 |
-
# Функция для отображения данных выбранного сплита
|
| 16 |
def show_data(split):
|
| 17 |
try:
|
| 18 |
df = pd.DataFrame(ds[split])
|
| 19 |
return df.head(10)
|
| 20 |
except Exception as e:
|
| 21 |
-
|
| 22 |
-
return pd.DataFrame({"Error": [str(e)]})
|
| 23 |
|
| 24 |
# Функция для обучения модели и создания графика
|
| 25 |
def train_model():
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
|
| 31 |
-
target_col = "ideal_expval_Z_global"
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
X_test = pd.DataFrame(ds["test"])[feature_cols]
|
| 36 |
-
y_test = pd.DataFrame(ds["test"])[target_col]
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# Интерфейс
|
| 55 |
with gr.Blocks(title="QSBench Demo Explorer") as demo:
|
|
@@ -64,17 +66,17 @@ with gr.Blocks(title="QSBench Demo Explorer") as demo:
|
|
| 64 |
|
| 65 |
with gr.Tabs():
|
| 66 |
with gr.TabItem("Data Explorer"):
|
| 67 |
-
# Определяем сп
|
| 68 |
-
|
| 69 |
split_selector = gr.Dropdown(
|
| 70 |
-
choices=
|
| 71 |
label="Choose a split",
|
| 72 |
-
value=
|
| 73 |
)
|
| 74 |
data_table = gr.Dataframe(label="First 10 rows", interactive=False)
|
| 75 |
split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
|
| 76 |
# Загружаем данные по умолчанию
|
| 77 |
-
demo.load(fn=lambda: show_data(
|
| 78 |
|
| 79 |
with gr.TabItem("Model Demo"):
|
| 80 |
train_button = gr.Button("Train Random Forest")
|
|
|
|
| 4 |
from sklearn.ensemble import RandomForestRegressor
|
| 5 |
from sklearn.metrics import r2_score
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
+
import sys
|
| 8 |
|
| 9 |
# Загружаем датасет
|
| 10 |
+
print("Loading dataset...")
|
| 11 |
ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
|
| 12 |
+
print("Available splits:", list(ds.keys())) # выводим в логи
|
| 13 |
|
| 14 |
+
# Функция для отображения таблицы
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def show_data(split):
    """Return the first 10 rows of the requested dataset split.

    Args:
        split: Name of the split to look up in the module-level ``ds``.

    Returns:
        A ``pandas.DataFrame`` with at most 10 rows of the split. If the
        split cannot be loaded, a one-row DataFrame whose single ``"Error"``
        column carries the failure message.
    """
    try:
        df = pd.DataFrame(ds[split])
    except Exception as e:
        # This function feeds a gr.Dataframe output. A bare string returned
        # there is interpreted as a path to a CSV file, so the error has to
        # be wrapped in a DataFrame to be displayed instead of triggering a
        # second failure. The broad catch is deliberate: this is the UI
        # boundary and any loading problem should be shown to the user.
        return pd.DataFrame({"Error": [f"Error loading data: {e}"]})
    return df.head(10)
|
|
|
|
| 21 |
|
| 22 |
# Функция для обучения модели и создания графика
|
| 23 |
def train_model():
    """Fit a random forest on the train split and plot test predictions.

    The model is trained on ``ds["train"]`` and evaluated on ``ds["test"]``
    using three feature columns and one target column. The resulting scatter
    plot shows predicted versus true values with the R² score in the title.

    Returns:
        On success, the matplotlib figure holding the scatter plot.
        On a validation failure (missing split or missing column), a
        ``(message, None)`` tuple describing the problem.

    NOTE(review): the success path returns a bare figure while the failure
    paths return a 2-tuple. Confirm which shape the click handler's outputs
    expect and align the other path accordingly.
    """
    # Both splits are indexed below, so both must exist. "test" is checked
    # first so its message wins when both are absent.
    for split_name in ("test", "train"):
        if split_name not in ds.keys():
            return f"Error: '{split_name}' split not found in dataset", None

    feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
    target_col = "ideal_expval_Z_global"
    required_cols = feature_cols + [target_col]

    # Materialize each split once and reuse it for features and target.
    train_df = pd.DataFrame(ds["train"])
    test_df = pd.DataFrame(ds["test"])

    # Validate columns on the splits that are actually used.
    missing = [
        col
        for col in required_cols
        if col not in train_df.columns or col not in test_df.columns
    ]
    if missing:
        return f"Error: missing columns: {missing}", None

    X_train = train_df[feature_cols]
    y_train = train_df[target_col]
    X_test = test_df[feature_cols]
    y_test = test_df[target_col]

    # Fixed seed keeps the demo reproducible between clicks.
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    r2 = r2_score(y_test, y_pred)

    fig, ax = plt.subplots()
    ax.scatter(y_test, y_pred, alpha=0.5)
    # Dashed diagonal marks where a perfect prediction would fall.
    lo, hi = y_test.min(), y_test.max()
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_xlabel("True value")
    ax.set_ylabel("Predicted")
    ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
    return fig
|
| 55 |
|
| 56 |
# Интерфейс
|
| 57 |
with gr.Blocks(title="QSBench Demo Explorer") as demo:
|
|
|
|
| 66 |
|
| 67 |
with gr.Tabs():
|
| 68 |
with gr.TabItem("Data Explorer"):
|
| 69 |
+
# Определяем доступные сплиты для выпадающего списка
|
| 70 |
+
available_splits = list(ds.keys())
|
| 71 |
split_selector = gr.Dropdown(
|
| 72 |
+
choices=available_splits,
|
| 73 |
label="Choose a split",
|
| 74 |
+
value=available_splits[0] if available_splits else None
|
| 75 |
)
|
| 76 |
data_table = gr.Dataframe(label="First 10 rows", interactive=False)
|
| 77 |
split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
|
| 78 |
# Загружаем данные по умолчанию
|
| 79 |
+
demo.load(fn=lambda: show_data(available_splits[0]), outputs=data_table)
|
| 80 |
|
| 81 |
with gr.TabItem("Model Demo"):
|
| 82 |
train_button = gr.Button("Train Random Forest")
|