"""Flask API serving an Iris species classifier backed by a scikit-learn decision tree.

At import time the model is loaded from ``MODEL_PATH`` (resolved against the
current working directory) or, if that file does not exist, trained on the
Iris dataset and saved there.
"""
import os

import joblib
import numpy as np
import pandas as pd
from flask import Flask, jsonify, render_template_string, request
from flask_cors import CORS
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

app = Flask(__name__)
CORS(app)

# Persisted model file; a relative path, so it resolves against the working directory.
MODEL_PATH = 'iris_decision_tree_model.pkl'

# JSON fields expected by /predict, in the column order of the feature array
# that is passed to the model.
FEATURE_FIELDS = ('sepal_length', 'sepal_width', 'petal_length', 'petal_width')

# Global variables for model and iris data, populated by load_or_train_model().
model = None          # fitted DecisionTreeClassifier
iris = None           # dataset object returned by load_iris()
feature_names = None  # iris.feature_names
target_names = None   # iris.target_names, indexed by class label


def load_or_train_model():
    """Load the persisted model, or train and persist a new one if none exists.

    Always loads the Iris dataset and sets the globals ``iris``,
    ``feature_names`` and ``target_names``. Sets ``model`` either by loading
    ``MODEL_PATH`` or by fitting a DecisionTreeClassifier on an 80/20
    train/test split (which is then saved to ``MODEL_PATH``).
    """
    global model, iris, feature_names, target_names

    iris = load_iris()
    feature_names = iris.feature_names
    target_names = iris.target_names

    if os.path.exists(MODEL_PATH):
        # NOTE: joblib.load unpickles the file; only ever load model files you trust.
        model = joblib.load(MODEL_PATH)
        print("Model loaded from file")
        return

    print("Training new model...")
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=42
    )
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)
    joblib.dump(model, MODEL_PATH)

    # Report accuracy on the held-out 20% test split.
    accuracy = accuracy_score(y_test, model.predict(X_test))
    print(f"Model trained with accuracy: {accuracy:.4f}")


# Initialize model on startup
load_or_train_model()


@app.route('/')
def home():
    """Serve the landing-page text (Indonesian) listing example inputs per species."""
    # NOTE(review): this page contains only plain text — no HTML markup or
    # input form is present here; confirm whether markup was lost.
    html = """ Iris Flower Classification API

🌸 Iris Flower Classification API

Masukkan nilai fitur bunga Iris untuk memprediksi spesiesnya:

Contoh Data:

Setosa: Sepal Length: 5.1, Sepal Width: 3.5, Petal Length: 1.4, Petal Width: 0.2

Versicolor: Sepal Length: 7.0, Sepal Width: 3.2, Petal Length: 4.7, Petal Width: 1.4

Virginica: Sepal Length: 6.3, Sepal Width: 3.3, Petal Length: 6.0, Petal Width: 2.5

"""
    return html


def _is_number(value):
    """Return True if *value* is an int or float; bools are rejected."""
    # bool is a subclass of int, so JSON true/false would otherwise pass.
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _validate_payload(data):
    """Return an error message for an invalid /predict payload, or None if valid."""
    if not data:
        return 'No data provided'
    if not isinstance(data, dict):
        return 'Request body must be a JSON object'
    for field in FEATURE_FIELDS:
        if field not in data:
            return f'Missing field: {field}'
        if not _is_number(data[field]):
            return f'Invalid value for {field}. Must be a number'
    return None


@app.route('/predict', methods=['POST'])
def predict_iris():
    """Predict the Iris species from the four feature values in the JSON body.

    Expects a JSON object with numeric ``sepal_length``, ``sepal_width``,
    ``petal_length`` and ``petal_width``. Returns the predicted species name
    and code, its probability, the probability of every class, and the echoed
    input features. Responds 400 for missing/invalid input and 500 for
    unexpected errors.
    """
    # silent=True: a malformed body or non-JSON Content-Type yields None (-> 400)
    # instead of raising, which the catch-all below would report as a 500.
    data = request.get_json(silent=True)
    error = _validate_payload(data)
    if error is not None:
        return jsonify({'error': error}), 400

    try:
        # Convert to a (1, 4) numpy array in FEATURE_FIELDS order.
        features = np.array([[data[field] for field in FEATURE_FIELDS]], dtype=float)

        # JSON parsing accepts NaN/Infinity; the model cannot handle them.
        if not np.isfinite(features).all():
            return jsonify({'error': 'All feature values must be finite numbers'}), 400
        if (features < 0).any():
            return jsonify({'error': 'All feature values must be non-negative'}), 400

        prediction = model.predict(features)[0]
        prediction_proba = model.predict_proba(features)[0]

        # predict_proba columns follow model.classes_, so map labels through it
        # rather than assuming column index == class label.
        class_labels = list(model.classes_)
        species = str(target_names[prediction])
        confidence = f"{prediction_proba[class_labels.index(prediction)]:.2%}"

        # Extra info for debugging: probability of every class.
        probabilities = {
            str(target_names[label]): f"{prob:.2%}"
            for label, prob in zip(class_labels, prediction_proba)
        }

        return jsonify({
            'species': species,
            'species_code': int(prediction),
            'confidence': confidence,
            'all_probabilities': probabilities,
            'input_features': {field: data[field] for field in FEATURE_FIELDS},
        })
    except Exception as e:
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500


@app.route('/model-info', methods=['GET'])
def model_info():
    """Return model metadata: type, features, classes, importances and tree size."""
    try:
        feature_info = {
            name: float(importance)
            for name, importance in zip(feature_names, model.feature_importances_)
        }
        return jsonify({
            'model_type': 'Decision Tree Classifier',
            'features': list(feature_names),
            'target_classes': [str(name) for name in target_names],
            'feature_importance': feature_info,
            'tree_depth': model.get_depth(),
            'number_of_leaves': model.get_n_leaves(),
            # Samples the tree was actually fitted on (root-node count), not the
            # full dataset size; when trained here, only the 80% split is used.
            'training_samples': int(model.tree_.n_node_samples[0]),
        })
    except Exception as e:
        app.logger.exception('Reading model info failed')
        return jsonify({'error': str(e)}), 500


@app.route('/health', methods=['GET'])
def health():
    """Liveness check reporting whether a model is loaded."""
    return jsonify({
        'status': 'OK',
        'message': 'Iris Classification API is running',
        'model_loaded': model is not None,
    }), 200


if __name__ == '__main__':
    # The Werkzeug debugger permits arbitrary code execution, so it must not be
    # on by default for a server bound to all interfaces; opt in with FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug, host='0.0.0.0', port=7860)