AdrianHR commited on
Commit
96bd3d6
·
1 Parent(s): cdba802

feat: Add app.py to run the gradio application

Browse files

App.py is reused from the iris part i.e. task1 of lab1.
A model is trained on data using the code skeletons from
the assignment. The model is hosted on hopsworks.

The app runs a gradio application where you can put
in values for the four features and then get the
predicted label from the model of that imaginary flower.

Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import requests
4
+ import hopsworks
5
+ import joblib
6
+ import pandas as pd
7
+
8
+ project = hopsworks.login()
9
+ fs = project.get_feature_store()
10
+
11
+
12
+ mr = project.get_model_registry()
13
+ model = mr.get_model("iris_model", version=1)
14
+ model_dir = model.download()
15
+ model = joblib.load(model_dir + "/iris_model.pkl")
16
+ print("Model downloaded")
17
+
18
+ def iris(sepal_length, sepal_width, petal_length, petal_width):
19
+ print("Calling function")
20
+ # df = pd.DataFrame([[sepal_length],[sepal_width],[petal_length],[petal_width]],
21
+ df = pd.DataFrame([[sepal_length,sepal_width,petal_length,petal_width]],
22
+ columns=['sepal_length','sepal_width','petal_length','petal_width'])
23
+ print("Predicting")
24
+ print(df)
25
+ # 'res' is a list of predictions returned as the label.
26
+ res = model.predict(df)
27
+ # We add '[0]' to the result of the transformed 'res', because 'res' is a list, and we only want
28
+ # the first element.
29
+ # print("Res: {0}").format(res)
30
+ print(res)
31
+ flower_url = "https://raw.githubusercontent.com/featurestoreorg/serverless-ml-course/main/src/01-module/assets/" + res[0] + ".png"
32
+ img = Image.open(requests.get(flower_url, stream=True).raw)
33
+ return img
34
+
35
+ demo = gr.Interface(
36
+ fn=iris,
37
+ title="Iris Flower Predictive Analytics",
38
+ description="Experiment with sepal/petal lengths/widths to predict which flower it is.",
39
+ allow_flagging="never",
40
+ inputs=[
41
+ gr.inputs.Number(default=2.0, label="sepal length (cm)"),
42
+ gr.inputs.Number(default=1.0, label="sepal width (cm)"),
43
+ gr.inputs.Number(default=2.0, label="petal length (cm)"),
44
+ gr.inputs.Number(default=1.0, label="petal width (cm)"),
45
+ ],
46
+ outputs=gr.Image(type="pil"))
47
+
48
+ demo.launch(debug=True)
49
+