# iris / app.py
# Author: Kevinyogap
# Initial commit (824c383)
import math
import os

import joblib
import numpy as np
import pandas as pd
from flask import Flask, jsonify, render_template_string, request
from flask_cors import CORS
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Flask application object; CORS(app) applies flask_cors to every route.
app = Flask(__name__)
CORS(app)

# Module-level state shared by the route handlers. All four names start as
# None and are assigned by load_or_train_model().
model = iris = feature_names = target_names = None
def load_or_train_model():
    """Load the persisted model, or train (and try to persist) a new one.

    Populates the module-level ``model``, ``iris``, ``feature_names`` and
    ``target_names``. The model file is treated as a cache: if it exists
    but cannot be loaded, or if a freshly trained model cannot be written,
    a message is printed and the in-memory model is used instead of
    aborting start-up.
    """
    global model, iris, feature_names, target_names

    # The dataset is always loaded because the handlers need its metadata
    # even when the model itself comes from disk.
    iris = load_iris()
    feature_names = iris.feature_names
    target_names = iris.target_names

    model_path = 'iris_decision_tree_model.pkl'

    if os.path.exists(model_path):
        try:
            # NOTE(security): joblib files are pickles, and unpickling can
            # execute arbitrary code. Only keep trusted files at this path.
            model = joblib.load(model_path)
        except Exception as exc:
            # Unpickling can fail in many ways (truncated file, version
            # mismatch, missing class); all are recoverable by retraining.
            print(f"Could not load model from {model_path}: {exc}")
        else:
            print("Model loaded from file")
            return

    print("Training new model...")
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=42
    )
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X_train, y_train)

    try:
        joblib.dump(model, model_path)
    except OSError as exc:
        # Saving is best-effort: a read-only working directory should not
        # prevent the API from serving the model it just trained.
        print(f"Could not save model to {model_path}: {exc}")

    # Held-out accuracy is informational only; it is printed, not stored.
    accuracy = accuracy_score(y_test, model.predict(X_test))
    print(f"Model trained with accuracy: {accuracy:.4f}")
# Load (or train) the model once at import time so the route handlers can use it.
load_or_train_model()
@app.route('/')
def home():
    """Serve the landing page.

    The page is a self-contained HTML document with a four-field form.
    Its inline script POSTs the values as JSON to ``/predict`` and shows
    the returned species and confidence, or an alert on error.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>Iris Flower Classification API</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; background-color: #f5f5f5; }
.container { max-width: 800px; margin: 0 auto; background: white; padding: 30px; border-radius: 10px; box-shadow: 0 0 20px rgba(0,0,0,0.1); }
h1 { color: #2c3e50; text-align: center; }
.feature { margin: 10px 0; }
.feature label { display: inline-block; width: 200px; font-weight: bold; }
.feature input { padding: 8px; width: 200px; border: 1px solid #ddd; border-radius: 4px; }
button { background: #3498db; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; margin: 10px 5px; }
button:hover { background: #2980b9; }
.result { margin: 20px 0; padding: 15px; background: #ecf0f1; border-radius: 5px; }
.examples { background: #f8f9fa; padding: 15px; border-radius: 5px; margin: 20px 0; }
</style>
</head>
<body>
<div class="container">
<h1>🌸 Iris Flower Classification API</h1>
<p>Masukkan nilai fitur bunga Iris untuk memprediksi spesiesnya:</p>
<form id="irisForm">
<div class="feature">
<label>Sepal Length (cm):</label>
<input type="number" step="0.1" id="sepal_length" placeholder="e.g., 5.1" required>
</div>
<div class="feature">
<label>Sepal Width (cm):</label>
<input type="number" step="0.1" id="sepal_width" placeholder="e.g., 3.5" required>
</div>
<div class="feature">
<label>Petal Length (cm):</label>
<input type="number" step="0.1" id="petal_length" placeholder="e.g., 1.4" required>
</div>
<div class="feature">
<label>Petal Width (cm):</label>
<input type="number" step="0.1" id="petal_width" placeholder="e.g., 0.2" required>
</div>
<button type="submit">Prediksi Spesies</button>
<button type="button" onclick="loadExample(1)">Contoh Setosa</button>
<button type="button" onclick="loadExample(2)">Contoh Versicolor</button>
<button type="button" onclick="loadExample(3)">Contoh Virginica</button>
</form>
<div id="result" class="result" style="display:none;">
<h3>Hasil Prediksi:</h3>
<p id="prediction"></p>
<p id="confidence"></p>
</div>
<div class="examples">
<h3>Contoh Data:</h3>
<p><strong>Setosa:</strong> Sepal Length: 5.1, Sepal Width: 3.5, Petal Length: 1.4, Petal Width: 0.2</p>
<p><strong>Versicolor:</strong> Sepal Length: 7.0, Sepal Width: 3.2, Petal Length: 4.7, Petal Width: 1.4</p>
<p><strong>Virginica:</strong> Sepal Length: 6.3, Sepal Width: 3.3, Petal Length: 6.0, Petal Width: 2.5</p>
</div>
</div>
<script>
function loadExample(type) {
if (type === 1) {
document.getElementById('sepal_length').value = 5.1;
document.getElementById('sepal_width').value = 3.5;
document.getElementById('petal_length').value = 1.4;
document.getElementById('petal_width').value = 0.2;
} else if (type === 2) {
document.getElementById('sepal_length').value = 7.0;
document.getElementById('sepal_width').value = 3.2;
document.getElementById('petal_length').value = 4.7;
document.getElementById('petal_width').value = 1.4;
} else if (type === 3) {
document.getElementById('sepal_length').value = 6.3;
document.getElementById('sepal_width').value = 3.3;
document.getElementById('petal_length').value = 6.0;
document.getElementById('petal_width').value = 2.5;
}
}
document.getElementById('irisForm').addEventListener('submit', function(e) {
e.preventDefault();
const data = {
sepal_length: parseFloat(document.getElementById('sepal_length').value),
sepal_width: parseFloat(document.getElementById('sepal_width').value),
petal_length: parseFloat(document.getElementById('petal_length').value),
petal_width: parseFloat(document.getElementById('petal_width').value)
};
fetch('/predict', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(data)
})
.then(response => response.json())
.then(data => {
if (data.error) {
alert('Error: ' + data.error);
} else {
document.getElementById('prediction').innerHTML =
`<strong>Spesies: ${data.species}</strong>`;
document.getElementById('confidence').innerHTML =
`Confidence: ${data.confidence}`;
document.getElementById('result').style.display = 'block';
}
})
.catch(error => {
alert('Error: ' + error);
});
});
</script>
</body>
</html>
"""
# Field names expected in the /predict request body, in model feature order.
_REQUIRED_FIELDS = ('sepal_length', 'sepal_width', 'petal_length', 'petal_width')


def _is_finite_number(value):
    """Return True if *value* is a real, finite int or float (not a bool)."""
    # bool is a subclass of int, so JSON true/false must be rejected explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    try:
        # Rejects NaN and +/-Infinity, which Python's JSON parser accepts.
        return math.isfinite(value)
    except OverflowError:
        # Integers too large to convert to float cannot be model inputs.
        return False


def _validate_features(data):
    """Extract the four measurements from the request mapping *data*.

    Returns ``(values, None)`` on success, where ``values`` lists the
    measurements in model feature order, or ``(None, message)`` describing
    the first validation failure.
    """
    values = []
    for field in _REQUIRED_FIELDS:
        if field not in data:
            return None, f'Missing field: {field}'
        if not _is_finite_number(data[field]):
            return None, f'Invalid value for {field}. Must be a number'
        values.append(data[field])

    # Zero is accepted, so the message says "non-negative" rather than
    # "positive".
    if any(value < 0 for value in values):
        return None, 'All feature values must be non-negative'
    return values, None


@app.route('/predict', methods=['POST'])
def predict_iris():
    """Classify one iris flower from a JSON body of four measurements.

    Expects a JSON object with numeric ``sepal_length``, ``sepal_width``,
    ``petal_length`` and ``petal_width``.

    Returns:
        200 with the predicted species, its class code, the confidence,
        the per-class probabilities and an echo of the input;
        400 with ``{'error': ...}`` for a missing, malformed or invalid body;
        500 with ``{'error': ...}`` if the model itself fails.
    """
    # silent=True yields None for a wrong Content-Type or malformed JSON
    # instead of raising an HTTP error, so client mistakes are answered
    # with 400 rather than being reported as a server error.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    values, error = _validate_features(data)
    if error is not None:
        return jsonify({'error': error}), 400

    # Only the model calls are guarded: validation failures are handled
    # above, so anything raised here is a genuine server-side problem.
    try:
        features = np.array([values], dtype=float)
        prediction = int(model.predict(features)[0])
        prediction_proba = model.predict_proba(features)[0]

        species = str(target_names[prediction])
        confidence = f"{prediction_proba[prediction]:.2%}"
        # Per-class probabilities, included to help with debugging.
        probabilities = {
            str(name): f"{prob:.2%}"
            for name, prob in zip(target_names, prediction_proba)
        }
    except Exception as e:
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500

    return jsonify({
        'species': species,
        'species_code': prediction,
        'confidence': confidence,
        'all_probabilities': probabilities,
        'input_features': dict(zip(_REQUIRED_FIELDS, values)),
    })
@app.route('/model-info', methods=['GET'])
def model_info():
    """Return metadata about the loaded decision tree as JSON.

    The payload contains the model type, feature and class names,
    per-feature importances, tree depth, leaf count and the number of
    samples the tree was fitted on. Any failure is reported as
    ``{'error': ...}`` with status 500.
    """
    try:
        feature_info = {
            name: float(importance)
            for name, importance in zip(feature_names, model.feature_importances_)
        }
        return jsonify({
            'model_type': 'Decision Tree Classifier',
            'features': list(feature_names),
            'target_classes': list(target_names),
            'feature_importance': feature_info,
            'tree_depth': int(model.get_depth()),
            'number_of_leaves': int(model.get_n_leaves()),
            # The root node records every sample passed to fit(). The model
            # is fitted on a train split, so len(iris.data) would overstate
            # the training-set size.
            'training_samples': int(model.tree_.n_node_samples[0]),
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Report that the service is up and whether a model object is set."""
    model_ready = model is not None
    payload = {
        'status': 'OK',
        'message': 'Iris Classification API is running',
        'model_loaded': model_ready,
    }
    return jsonify(payload), 200
if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server run
    # arbitrary Python, so it must not be on by default while listening on
    # all interfaces. Opt in for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in (
        '1', 'true', 'yes', 'on',
    )
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)