grfdjiwsd commited on
Commit
7b2a83e
·
verified ·
1 Parent(s): 356e235

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py CHANGED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import sklearn.datasets as d
4
+ from sklearn.linear_model import *
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.preprocessing import StandardScaler
7
+ from sklearn.metrics import *
8
+ from sklearn.utils import all_estimators
9
+
10
+
11
+ import inspect
12
+ import pandas as pd
13
+ import sklearn.metrics as m
14
+
15
+ def predict(dataset, model, split, metrics):
16
+ pass
17
+ models = [cls for cls in all_estimators() if cls[0] == model]
18
+ if len(models) == 0:
19
+ return "Model not found"
20
+ model = models[0][1]()
21
+ data = getattr(d, dataset)()
22
+ X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=float(split))
23
+ scaler = StandardScaler()
24
+ X_train = scaler.fit_transform(X_train)
25
+ X_test = scaler.transform(X_test)
26
+ model.fit(X_train, y_train)
27
+ pred = model.predict(X_test)
28
+ df = pd.DataFrame(data.data, columns=data.feature_names)
29
+ df["target"] = data.target
30
+ scores = []
31
+ for metric in metrics:
32
+ try:
33
+ if hasattr(m, metric):
34
+ scores.append((metric, getattr(m, metric)(y_test, pred)))
35
+ except:
36
+ pass
37
+ scoress = pd.DataFrame(scores, columns=["metric", "score"])
38
+ return gr.Dataframe(scoress, headers=scoress.columns.tolist(), datatype=["numeric"] * len(df.columns))
39
+
40
+ demo = gr.Interface(fn=predict, inputs=[
41
+ gr.Dropdown([ name for name, obj in inspect.getmembers(d)
42
+ if inspect.isfunction(obj) and not name.startswith("_")], value="load_breast_cancer", label="Dataset"),
43
+ gr.Dropdown([name for name, cls in all_estimators()], value="RandomForestClassifier", label="Model"),
44
+ gr.Textbox(value="0.2", label="TrainTest Split"),
45
+ gr.CheckboxGroup([n for n in dir(m) if callable(getattr(m, n)) and not n.startswith("_")], label="metrics", value="accuracy_score")
46
+ ], outputs="dataframe")
47
+
48
+ demo.launch(share=True, debug=True)