LovnishVerma commited on
Commit
96747d6
·
verified ·
1 Parent(s): b969798

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -1,27 +1,32 @@
1
- from flask import Flask, render_template, request
2
- from sklearn.linear_model import LogisticRegression
3
  import pickle
 
 
 
 
4
 
5
  app = Flask(__name__)
 
6
 
7
- #predicting using model saved using pickle
8
- model = pickle.load(open("iris_model.pkl", "rb"))
 
 
 
 
 
 
 
9
 
10
- @app.route('/')
11
- def home():
12
- return render_template("index.html")
13
-
14
- @app.route('/predict', methods=["POST"])
15
  def predict():
16
- try:
17
- swidth = float(request.form.get("swidth"))
18
- sheight = float(request.form.get("sheight"))
19
- pwidth = float(request.form.get("pwidth"))
20
- pheight = float(request.form.get("pheight"))
21
- prediction = model.predict([[swidth, sheight, pwidth, pheight]])
22
- return render_template("index.html", data=prediction[0])
23
- except Exception as e:
24
- return render_template("index.html", data=f"Error: {str(e)}")
25
 
26
  if __name__ == "__main__":
27
- app.run()
 
 
 
1
  import pickle
2
+ from flask import Flask, request, jsonify
3
+ from flask_cors import CORS
4
+ from sklearn.datasets import load_iris
5
+ from sklearn.ensemble import RandomForestClassifier
6
 
7
  app = Flask(__name__)
8
+ CORS(app) # Allow frontend calls from GitHub Pages
9
 
10
+ # --- Train or load model ---
11
+ try:
12
+ model = pickle.load(open("model.pkl", "rb"))
13
+ except:
14
+ iris = load_iris()
15
+ X, y = iris.data, iris.target
16
+ model = RandomForestClassifier()
17
+ model.fit(X, y)
18
+ pickle.dump(model, open("model.pkl", "wb"))
19
 
20
+ # --- API route ---
21
+ @app.route("/predict", methods=["POST"])
 
 
 
22
  def predict():
23
+ data = request.json
24
+ features = [data["sepal_length"], data["sepal_width"],
25
+ data["petal_length"], data["petal_width"]]
26
+ prediction = model.predict([features])[0]
27
+
28
+ target_names = load_iris().target_names
29
+ return jsonify({"prediction": target_names[prediction]})
 
 
30
 
31
  if __name__ == "__main__":
32
+ app.run(host="0.0.0.0", port=7860)