Spaces:
Running
Running
fix
Browse files
app.py
CHANGED
|
@@ -5,39 +5,51 @@ from sklearn.ensemble import RandomForestRegressor
|
|
| 5 |
from sklearn.metrics import r2_score
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
|
| 8 |
-
# Загружаем датасет
|
| 9 |
ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def show_data(split):
|
| 13 |
try:
|
| 14 |
df = pd.DataFrame(ds[split])
|
| 15 |
return df.head(10)
|
| 16 |
except Exception as e:
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
# Функция для обучения модели и создания графика
|
| 20 |
def train_model():
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# Интерфейс
|
| 43 |
with gr.Blocks(title="QSBench Demo Explorer") as demo:
|
|
@@ -52,15 +64,17 @@ with gr.Blocks(title="QSBench Demo Explorer") as demo:
|
|
| 52 |
|
| 53 |
with gr.Tabs():
|
| 54 |
with gr.TabItem("Data Explorer"):
|
|
|
|
|
|
|
| 55 |
split_selector = gr.Dropdown(
|
| 56 |
-
choices=
|
| 57 |
label="Choose a split",
|
| 58 |
-
value="train"
|
| 59 |
)
|
| 60 |
data_table = gr.Dataframe(label="First 10 rows", interactive=False)
|
| 61 |
split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
|
| 62 |
# Загружаем данные по умолчанию
|
| 63 |
-
demo.load(fn=lambda: show_data("train"), outputs=data_table)
|
| 64 |
|
| 65 |
with gr.TabItem("Model Demo"):
|
| 66 |
train_button = gr.Button("Train Random Forest")
|
|
|
|
| 5 |
from sklearn.metrics import r2_score
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
|
| 8 |
+
# Загружаем датасет
|
| 9 |
ds = load_dataset("QSBench/QSBench-Core-v1.0.0-demo")
|
| 10 |
|
| 11 |
+
# Определяем доступные сплиты (предполагаем train, val, test)
|
| 12 |
+
available_splits = list(ds.keys())
|
| 13 |
+
print("Available splits:", available_splits) # для отладки в логах
|
| 14 |
+
|
| 15 |
+
# Функция для отображения данных выбранного сплита
|
| 16 |
def show_data(split):
|
| 17 |
try:
|
| 18 |
df = pd.DataFrame(ds[split])
|
| 19 |
return df.head(10)
|
| 20 |
except Exception as e:
|
| 21 |
+
# Возвращаем пустой DataFrame с сообщением об ошибке
|
| 22 |
+
return pd.DataFrame({"Error": [str(e)]})
|
| 23 |
|
| 24 |
# Функция для обучения модели и создания графика
|
| 25 |
def train_model():
|
| 26 |
+
try:
|
| 27 |
+
# Проверяем, есть ли нужные сплиты
|
| 28 |
+
if "train" not in ds or "test" not in ds:
|
| 29 |
+
return None # или вернуть сообщение
|
| 30 |
+
feature_cols = ["total_gates", "gate_entropy", "meyer_wallach"]
|
| 31 |
+
target_col = "ideal_expval_Z_global"
|
| 32 |
|
| 33 |
+
X_train = pd.DataFrame(ds["train"])[feature_cols]
|
| 34 |
+
y_train = pd.DataFrame(ds["train"])[target_col]
|
| 35 |
+
X_test = pd.DataFrame(ds["test"])[feature_cols]
|
| 36 |
+
y_test = pd.DataFrame(ds["test"])[target_col]
|
| 37 |
|
| 38 |
+
model = RandomForestRegressor(n_estimators=100, random_state=42)
|
| 39 |
+
model.fit(X_train, y_train)
|
| 40 |
+
y_pred = model.predict(X_test)
|
| 41 |
+
r2 = r2_score(y_test, y_pred)
|
| 42 |
|
| 43 |
+
fig, ax = plt.subplots()
|
| 44 |
+
ax.scatter(y_test, y_pred, alpha=0.5)
|
| 45 |
+
ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
|
| 46 |
+
ax.set_xlabel("True value")
|
| 47 |
+
ax.set_ylabel("Predicted")
|
| 48 |
+
ax.set_title(f"Predictions vs. Truth (R² = {r2:.4f})")
|
| 49 |
+
return fig
|
| 50 |
+
except Exception as e:
|
| 51 |
+
# В случае ошибки возвращаем сообщение как график? проще вернуть None и показать текст
|
| 52 |
+
raise e
|
| 53 |
|
| 54 |
# Интерфейс
|
| 55 |
with gr.Blocks(title="QSBench Demo Explorer") as demo:
|
|
|
|
| 64 |
|
| 65 |
with gr.Tabs():
|
| 66 |
with gr.TabItem("Data Explorer"):
|
| 67 |
+
# Определяем список сплитов из загруженного датасета
|
| 68 |
+
split_choices = available_splits if available_splits else ["train", "val", "test"]
|
| 69 |
split_selector = gr.Dropdown(
|
| 70 |
+
choices=split_choices,
|
| 71 |
label="Choose a split",
|
| 72 |
+
value=split_choices[0] if split_choices else "train"
|
| 73 |
)
|
| 74 |
data_table = gr.Dataframe(label="First 10 rows", interactive=False)
|
| 75 |
split_selector.change(fn=show_data, inputs=split_selector, outputs=data_table)
|
| 76 |
# Загружаем данные по умолчанию
|
| 77 |
+
demo.load(fn=lambda: show_data(split_choices[0] if split_choices else "train"), outputs=data_table)
|
| 78 |
|
| 79 |
with gr.TabItem("Model Demo"):
|
| 80 |
train_button = gr.Button("Train Random Forest")
|