pgurazada1 commited on
Commit
20af462
·
verified ·
1 Parent(s): 8e5a70e

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +68 -0
inference.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Obtain Predictions for Machine Failure Predictor Model using Gradio Client
3
+ ======================================================================
4
+
5
+ This script connects to a deployed machine failure predictor model using Gradio Client,
6
+ fetches the dataset, preprocesses the data, and generates predictions for a
7
+ sample of test data using the deployed model. The resulting predictions are
8
+ stored in a list. A time delay of one second is added after each prediction
9
+ submission to avoid overloading the model server.
10
+ """
11
+
12
+ import time
13
+
14
+ from gradio_client import Client
15
+
16
+ from sklearn.datasets import fetch_openml
17
+ from sklearn.model_selection import train_test_split
18
+
19
+
20
+ client = Client("pgurazada1/machine-failure-predictor")
21
+
22
+ dataset = fetch_openml(data_id=42890, as_frame=True, parser="auto")
23
+
24
+ data_df = dataset.data
25
+
26
+ target = 'Machine failure'
27
+ numeric_features = [
28
+ 'Air temperature [K]',
29
+ 'Process temperature [K]',
30
+ 'Rotational speed [rpm]',
31
+ 'Torque [Nm]',
32
+ 'Tool wear [min]'
33
+ ]
34
+ categorical_features = ['Type']
35
+
36
+ X = data_df[numeric_features + categorical_features]
37
+ y = data_df[target]
38
+
39
+ Xtrain, Xtest, ytrain, ytest = train_test_split(
40
+ X, y,
41
+ test_size=0.2,
42
+ random_state=42
43
+ )
44
+
45
+ Xtest_sample = Xtest.sample(100)
46
+
47
+ Xtest_sample_rows = list(Xtest_sample.itertuples(index=False, name=None))
48
+
49
+ batch_predictions = []
50
+
51
+ for row in Xtest_sample_rows:
52
+ try:
53
+ job = client.submit(
54
+ air_temperature=row[0],
55
+ process_temperature=row[1],
56
+ rotational_speed=row[2],
57
+ torque=row[3],
58
+ tool_wear=row[4],
59
+ type=row[5],
60
+ api_name="/predict"
61
+ )
62
+
63
+ batch_predictions.append(job.result())
64
+
65
+ time.sleep(1)
66
+
67
+ except Exception as e:
68
+ print(e)