Spaces:
Sleeping
Sleeping
| import os | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from flask import Flask, jsonify, render_template_string, request | |
| from flask_cors import CORS | |
| from sklearn.datasets import load_iris | |
| from sklearn.metrics import accuracy_score | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.tree import DecisionTreeClassifier | |
# Flask application object. CORS(app) enables cross-origin requests so that
# browser clients hosted elsewhere can call the JSON endpoints.
app = Flask(__name__)
CORS(app)

# Module-level state, filled in by load_or_train_model():
#   model         - the fitted classifier used by the prediction endpoint
#   iris          - the scikit-learn iris dataset bunch
#   feature_names - feature names taken from the dataset
#   target_names  - class (species) names taken from the dataset
model = iris = feature_names = target_names = None
def load_or_train_model(model_path='iris_decision_tree_model.pkl'):
    """Load existing model or train new one if not exists.

    Populates the module-level globals ``iris``, ``feature_names``,
    ``target_names`` and ``model``. The classifier is loaded from
    ``model_path`` when that file exists and holds a DecisionTreeClassifier.
    Otherwise the file is missing, unreadable or has unexpected content, and
    a new tree is trained on an 80/20 split of the iris data instead. Its
    held-out accuracy is printed, and it is saved to ``model_path`` for the
    next start-up.

    Args:
        model_path: joblib file used to cache the trained model. A relative
            path is resolved against the current working directory.
    """
    global model, iris, feature_names, target_names

    # Load iris dataset
    iris = load_iris()
    feature_names = iris.feature_names
    target_names = iris.target_names

    model = _load_saved_model(model_path)
    if model is None:
        model = _train_and_save_model(iris.data, iris.target, model_path)


def _load_saved_model(model_path):
    """Return the cached classifier stored at model_path, or None if it cannot be used."""
    if not os.path.exists(model_path):
        return None
    try:
        # NOTE: joblib.load unpickles the file, which can run arbitrary code;
        # model_path must only ever point at a file this app wrote itself.
        loaded = joblib.load(model_path)
    except Exception as exc:
        # e.g. a truncated file, or one pickled by an incompatible library
        # version: retrain instead of failing start-up.
        print(f"Could not load model from {model_path} ({exc!r}); retraining")
        return None
    if not isinstance(loaded, DecisionTreeClassifier):
        # model_info() calls tree-specific methods (get_depth, get_n_leaves).
        print(f"Unexpected object in {model_path} ({type(loaded).__name__}); retraining")
        return None
    print("Model loaded from file")
    return loaded


def _train_and_save_model(X, y, model_path):
    """Fit a DecisionTreeClassifier on 80% of (X, y), print test accuracy, and cache it."""
    print("Training new model...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    clf = DecisionTreeClassifier(random_state=42)
    clf.fit(X_train, y_train)

    # Save model. Failing to write the cache (e.g. on a read-only filesystem)
    # must not stop the app, because the in-memory model is still usable.
    try:
        joblib.dump(clf, model_path)
    except OSError as exc:
        print(f"Could not save model to {model_path} ({exc!r})")

    # Print accuracy on the held-out 20%
    accuracy = accuracy_score(y_test, clf.predict(X_test))
    print(f"Model trained with accuracy: {accuracy:.4f}")
    return clf


# Initialize model on startup
load_or_train_model()
@app.route('/')
def home():
    """Serve the single-page HTML form.

    The page's script POSTs the four entered feature values as JSON to
    /predict and shows the returned species and confidence.
    """
    html = """
<!DOCTYPE html>
<html>
<head>
<title>Iris Flower Classification API</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; background-color: #f5f5f5; }
.container { max-width: 800px; margin: 0 auto; background: white; padding: 30px; border-radius: 10px; box-shadow: 0 0 20px rgba(0,0,0,0.1); }
h1 { color: #2c3e50; text-align: center; }
.feature { margin: 10px 0; }
.feature label { display: inline-block; width: 200px; font-weight: bold; }
.feature input { padding: 8px; width: 200px; border: 1px solid #ddd; border-radius: 4px; }
button { background: #3498db; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; margin: 10px 5px; }
button:hover { background: #2980b9; }
.result { margin: 20px 0; padding: 15px; background: #ecf0f1; border-radius: 5px; }
.examples { background: #f8f9fa; padding: 15px; border-radius: 5px; margin: 20px 0; }
</style>
</head>
<body>
<div class="container">
<h1>🌸 Iris Flower Classification API</h1>
<p>Masukkan nilai fitur bunga Iris untuk memprediksi spesiesnya:</p>
<form id="irisForm">
<div class="feature">
<label>Sepal Length (cm):</label>
<input type="number" step="0.1" id="sepal_length" placeholder="e.g., 5.1" required>
</div>
<div class="feature">
<label>Sepal Width (cm):</label>
<input type="number" step="0.1" id="sepal_width" placeholder="e.g., 3.5" required>
</div>
<div class="feature">
<label>Petal Length (cm):</label>
<input type="number" step="0.1" id="petal_length" placeholder="e.g., 1.4" required>
</div>
<div class="feature">
<label>Petal Width (cm):</label>
<input type="number" step="0.1" id="petal_width" placeholder="e.g., 0.2" required>
</div>
<button type="submit">Prediksi Spesies</button>
<button type="button" onclick="loadExample(1)">Contoh Setosa</button>
<button type="button" onclick="loadExample(2)">Contoh Versicolor</button>
<button type="button" onclick="loadExample(3)">Contoh Virginica</button>
</form>
<div id="result" class="result" style="display:none;">
<h3>Hasil Prediksi:</h3>
<p id="prediction"></p>
<p id="confidence"></p>
</div>
<div class="examples">
<h3>Contoh Data:</h3>
<p><strong>Setosa:</strong> Sepal Length: 5.1, Sepal Width: 3.5, Petal Length: 1.4, Petal Width: 0.2</p>
<p><strong>Versicolor:</strong> Sepal Length: 7.0, Sepal Width: 3.2, Petal Length: 4.7, Petal Width: 1.4</p>
<p><strong>Virginica:</strong> Sepal Length: 6.3, Sepal Width: 3.3, Petal Length: 6.0, Petal Width: 2.5</p>
</div>
</div>
<script>
function loadExample(type) {
if (type === 1) {
document.getElementById('sepal_length').value = 5.1;
document.getElementById('sepal_width').value = 3.5;
document.getElementById('petal_length').value = 1.4;
document.getElementById('petal_width').value = 0.2;
} else if (type === 2) {
document.getElementById('sepal_length').value = 7.0;
document.getElementById('sepal_width').value = 3.2;
document.getElementById('petal_length').value = 4.7;
document.getElementById('petal_width').value = 1.4;
} else if (type === 3) {
document.getElementById('sepal_length').value = 6.3;
document.getElementById('sepal_width').value = 3.3;
document.getElementById('petal_length').value = 6.0;
document.getElementById('petal_width').value = 2.5;
}
}
document.getElementById('irisForm').addEventListener('submit', function(e) {
e.preventDefault();
const data = {
sepal_length: parseFloat(document.getElementById('sepal_length').value),
sepal_width: parseFloat(document.getElementById('sepal_width').value),
petal_length: parseFloat(document.getElementById('petal_length').value),
petal_width: parseFloat(document.getElementById('petal_width').value)
};
fetch('/predict', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(data)
})
.then(response => response.json())
.then(data => {
if (data.error) {
alert('Error: ' + data.error);
} else {
document.getElementById('prediction').innerHTML =
`<strong>Spesies: ${data.species}</strong>`;
document.getElementById('confidence').innerHTML =
`Confidence: ${data.confidence}`;
document.getElementById('result').style.display = 'block';
}
})
.catch(error => {
alert('Error: ' + error);
});
});
</script>
</body>
</html>
"""
    return html
@app.route('/predict', methods=['POST'])
def predict_iris():
    """Predict the iris species for one set of flower measurements.

    Expects a JSON object body with numeric ``sepal_length``,
    ``sepal_width``, ``petal_length`` and ``petal_width`` values. The values
    must be finite and non-negative.

    Returns:
        JSON containing:
        - ``species``: the predicted class name.
        - ``species_code``: its integer label.
        - ``confidence``: its probability, formatted as a percentage string.
        - ``all_probabilities``: a mapping of class name to percentage string.
        - ``input_features``: an echo of the four inputs.

        Invalid input yields HTTP 400 and unexpected failures yield HTTP 500,
        both in the form ``{"error": message}``.
    """
    required_fields = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    try:
        # Get data from the request. silent=True returns None for a missing,
        # malformed or non-JSON body, so those cases get the 400 below
        # instead of an exception that would surface as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        # A JSON array or string would otherwise slip through the membership
        # test below ('in' on a str is a substring check) and fail later.
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        # Validate input
        for field in required_fields:
            if field not in data:
                return jsonify({'error': f'Missing field: {field}'}), 400
            value = data[field]
            # bool is a subclass of int, so true/false must be rejected explicitly.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                return jsonify({'error': f'Invalid value for {field}. Must be a number'}), 400
            # Python's JSON parser accepts NaN/Infinity literals; the model cannot use them.
            if not np.isfinite(value):
                return jsonify({'error': f'Invalid value for {field}. Must be a finite number'}), 400

        # Convert to a single-row numpy array, columns in required_fields order
        features = np.array([[data[field] for field in required_fields]])

        # Range check (optional): negative measurements are rejected, zero is allowed
        if (features < 0).any():
            return jsonify({'error': 'All feature values must be non-negative'}), 400

        # Predict
        prediction = model.predict(features)[0]
        prediction_proba = model.predict_proba(features)[0]

        # Map the predicted label to its species name. Using the label as an
        # index into target_names / the probability row assumes the class
        # labels are 0..n-1, as the load_iris targets are.
        species = target_names[prediction]
        confidence = f"{prediction_proba[prediction]:.2%}"

        # Extra info for debugging: the probability of every class
        probabilities = {
            name: f"{prob:.2%}"
            for name, prob in zip(target_names, prediction_proba)
        }

        return jsonify({
            'species': species,
            'species_code': int(prediction),
            'confidence': confidence,
            'all_probabilities': probabilities,
            'input_features': {field: data[field] for field in required_fields},
        })
    except Exception as e:
        # Last-resort boundary: keep the traceback in the server log.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500
# NOTE(review): this route path is chosen here because none was registered;
# confirm it matches what API clients call.
@app.route('/model_info')
def model_info():
    """Endpoint that returns information about the loaded model.

    The response includes the feature and class names, per-feature
    importances, tree depth, leaf count and the number of samples the tree
    was fit on. Any failure (e.g. no model loaded) yields HTTP 500 with
    ``{"error": message}``.
    """
    try:
        # Get feature importance, keyed by feature name
        feature_info = {
            name: float(importance)
            for name, importance in zip(feature_names, model.feature_importances_)
        }
        return jsonify({
            'model_type': 'Decision Tree Classifier',
            'features': list(feature_names),
            'target_classes': list(target_names),
            'feature_importance': feature_info,
            'tree_depth': int(model.get_depth()),
            'number_of_leaves': int(model.get_n_leaves()),
            # len(iris.data) counts the whole dataset, but load_or_train_model
            # fits on an 80% split. The root node's sample count is the number
            # of rows the tree was actually fit on.
            'training_samples': int(model.tree_.n_node_samples[0]),
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
# NOTE(review): this route path is chosen here because none was registered;
# confirm it matches what monitoring/clients call.
@app.route('/health')
def health():
    """Report that the API process is running and whether the global model is loaded."""
    return jsonify({
        'status': 'OK',
        'message': 'Iris Classification API is running',
        'model_loaded': model is not None
    }), 200
if __name__ == '__main__':
    # Development-server entry point. Werkzeug's debugger lets anyone who can
    # reach it execute code, so it must not be on by default while listening
    # on all interfaces; set FLASK_DEBUG=1 to opt in for local work.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    port = int(os.environ.get('PORT', '7860'))
    app.run(debug=debug, host='0.0.0.0', port=port)