cladyles commited on
Commit
4ca5c57
·
verified ·
1 Parent(s): 9b1461d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -4,30 +4,32 @@ import joblib
4
  import numpy as np
5
  from sklearn.ensemble import RandomForestRegressor
6
 
7
- # 模拟加载或创建模型
8
- try:
9
- model = joblib.load('rf_features.joblib')
10
- except:
11
- # 模拟数据训练
12
- X = pd.DataFrame(np.random.rand(100, 2), columns=['feature1', 'feature2'])
13
- y = X['feature1'] * 3 + X['feature2'] * 2 + np.random.rand(100)
14
- model = joblib.load("rf_features.joblib")
15
- model.fit(X, y)
16
- joblib.dump(model, 'model.pkl')
17
 
18
- # 预测函数
19
- def predict_energy(feature1, feature2):
20
- input_data = pd.DataFrame([[feature1, feature2]], columns=['feature1', 'feature2'])
21
- prediction = model.predict(input_data)
22
- return f'预测能耗为:{prediction[0]:.2f}'
 
 
 
 
 
 
 
 
23
 
24
- # Gradio 接口定义
25
  demo = gr.Interface(
26
  fn=predict_energy,
27
- inputs=[gr.Number(label='Feature 1'), gr.Number(label='Feature 2')],
28
- outputs='text',
29
- title='Energy Predictor',
30
- description='根据输入特征预测能耗'
31
  )
32
 
33
- demo.launch()
 
 
 
4
  import numpy as np
5
  from sklearn.ensemble import RandomForestRegressor
6
 
7
+ # ——— 加载你训练好的模型和特征名 ———
8
+ model = joblib.load("rf_model.joblib")
9
+ feature_names = joblib.load("rf_features.joblib") # 应当是像 ["feature1", "feature2", ...]
 
 
 
 
 
 
 
10
 
11
+ # ——— 定义预测函数 ———
12
+ def predict_energy(*inputs):
13
+ # 将输入打包成 DataFrame
14
+ df = pd.DataFrame([inputs], columns=feature_names)
15
+ # 调用模型预测
16
+ pred = model.predict(df)
17
+ return f"预测能耗为:{pred[0]:.2f}"
18
+
19
+ # ——— 构建 Gradio 界面 ———
20
+ # 动态为每个特征创建一个 Number 输入框
21
+ input_components = [
22
+ gr.Number(label=name) for name in feature_names
23
+ ]
24
 
 
25
  demo = gr.Interface(
26
  fn=predict_energy,
27
+ inputs=input_components,
28
+ outputs="text",
29
+ title="🏡 Energy Usage Predictor",
30
+ description="基于 RandomForest 已训练模型 (rf_model.joblib),输入若干特征,输出能耗预测结果"
31
  )
32
 
33
+ # ——— 启动 ———
34
+ if __name__ == "__main__":
35
+ demo.launch()