| | import os |
| | import json |
| | import io |
| | import base64 |
| | from datetime import datetime |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from torchvision import transforms, models |
| | import joblib |
| | from PIL import Image |
| | from flask import Flask, request, jsonify |
| | from flask_cors import CORS |
| | from supabase import create_client, Client |
| |
|
# Flask application. CORS(app) is called without options, i.e. with flask_cors's
# defaults (NOTE(review): these allow any origin — restrict if the frontend is known).
app = Flask(__name__)
CORS(app)

# Torch device used for the feature extractor: GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Model artifacts are looked up in a "models" directory next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
model_path, meta_path = (
    os.path.join(MODEL_DIR, "svm_densenet201_rbf.joblib"),
    os.path.join(MODEL_DIR, "metadata.json"),
)

# Filled in by load_model(); svm_model stays None while running in simulation mode.
svm_model = None
class_names = None
IMG_SIZE = 224  # default square input size; load_model() may override it from metadata
| |
|
# Optional persistence: a Supabase client is created only when both
# SUPABASE_URL and SUPABASE_ANON_KEY are set in the environment.
supabase_url = os.environ.get('SUPABASE_URL')
supabase_key = os.environ.get('SUPABASE_ANON_KEY')
supabase: Client = None  # remains None when unconfigured or when initialization fails

if not (supabase_url and supabase_key):
    print("⚠ Supabase credentials not found, predictions won't be saved to database")
else:
    try:
        supabase = create_client(supabase_url, supabase_key)
        # The success message stays inside the try: any failure here disables persistence.
        print("✓ Supabase client initialized")
    except Exception as e:
        print(f"⚠ Failed to initialize Supabase: {e}")
        supabase = None
| |
|
def load_model():
    """Load the SVM classifier and its metadata from MODEL_DIR into module globals.

    Updates the globals ``svm_model`` (``None`` when the file is missing or cannot
    be loaded, which makes /classify run in simulation mode), ``class_names``,
    ``IMG_SIZE`` and ``eval_tfms``. The model file and the metadata file are
    handled independently, so a broken metadata.json no longer discards a model
    that loaded fine. Errors are printed, never raised.
    """
    global svm_model, class_names, IMG_SIZE, eval_tfms

    default_classes = ["3 Bulan", "6 Bulan", "9 Bulan"]

    # --- SVM model ---------------------------------------------------------
    try:
        if os.path.exists(model_path):
            # NOTE(review): joblib.load unpickles arbitrary objects; only load
            # model files that come from a trusted source.
            svm_model = joblib.load(model_path)
            print("✓ SVM model loaded successfully")
        else:
            print(f"⚠ Model file not found at {model_path}")
            print(" Using simulation mode until model is uploaded")
            svm_model = None
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        svm_model = None

    # --- Metadata (class names, input size) --------------------------------
    try:
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            # Parse into locals first so a bad value cannot leave the globals
            # half-updated.
            names = meta.get("class_names", default_classes)
            if not isinstance(names, list) or not names:
                raise ValueError("'class_names' must be a non-empty list")
            size = int(meta.get("img_size", 224))
            class_names, IMG_SIZE = names, size
            print(f"✓ Metadata loaded: {class_names}")
        else:
            class_names = default_classes
            print(f"⚠ Metadata not found, using default classes: {class_names}")
    except Exception as e:
        print(f"Error loading metadata: {str(e)}")
        class_names = default_classes

    # predict_with_model() maps probability column i to class_names[i], so a
    # count mismatch means wrong labels or an IndexError at request time.
    model_classes = getattr(svm_model, "classes_", None)
    if model_classes is not None and len(model_classes) != len(class_names):
        print(f"⚠ Model has {len(model_classes)} classes but "
              f"{len(class_names)} class names are configured: {class_names}")

    # eval_tfms is built at import time with the default IMG_SIZE; rebuild its
    # leading Resize so preprocessing follows the size from the metadata.
    # Assumes eval_tfms is a Compose whose first step is that Resize.
    eval_tfms = transforms.Compose(
        [transforms.Resize((IMG_SIZE, IMG_SIZE)), *eval_tfms.transforms[1:]]
    )
| |
|
# Pretrained DenseNet-201 (torchvision default weights) put in eval mode; only its
# convolutional `features` trunk is used, as a frozen feature extractor.
densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT).eval()
feature_extractor = densenet.features.to(device)
# Global average pooling: collapses each feature map to a single value.
gap = nn.AdaptiveAvgPool2d(output_size=(1, 1)).to(device)

# Inference preprocessing: square resize, tensor conversion, per-channel normalization
# with the standard ImageNet mean/std. Note that IMG_SIZE is read here, once, at
# import time.
eval_tfms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
)
| |
|
def decode_base64_image(base64_string):
    """Decode a base64 image payload (raw or data-URL form) into an RGB PIL image.

    If the string contains a comma (e.g. ``data:image/png;base64,<payload>``), only
    the segment after the first comma (up to any second comma) is decoded.
    Errors from ``base64.b64decode`` or PIL propagate to the caller.
    """
    payload = base64_string.split(',')[1] if ',' in base64_string else base64_string
    buffer = io.BytesIO(base64.b64decode(payload))
    return Image.open(buffer).convert("RGB")
| |
|
def preprocess_image(image):
    """Run ``image`` through the ``eval_tfms`` pipeline and prepend a batch axis of size 1."""
    tensor = eval_tfms(image)
    return tensor.unsqueeze(dim=0)
| |
|
@torch.no_grad()
def extract_features(img_tensor):
    """Compute pooled backbone features for a batch and return them as a NumPy array.

    Steps: move to ``device`` -> ``feature_extractor`` -> ReLU -> ``gap`` -> flatten
    to one row per batch item -> copy back to CPU. Runs without autograd.
    """
    batch = img_tensor.to(device)
    pooled = gap(torch.relu(feature_extractor(batch)))
    flat = pooled.view(pooled.shape[0], -1)
    return flat.cpu().numpy()
| |
|
def simulate_prediction():
    """Fabricate a random prediction for simulation mode (no SVM loaded).

    Draws one probability vector from a flat Dirichlet distribution over
    ``class_names`` using NumPy's global RNG, and picks its argmax.

    Returns:
        (label, confidence, probabilities) with ``confidence`` as a Python float.
    """
    concentration = np.ones(len(class_names))
    (probs,) = np.random.dirichlet(concentration, size=1)
    top = int(np.argmax(probs))
    return class_names[top], float(probs[top]), probs
| |
|
def predict_with_model(features):
    """Score ``features`` with the loaded SVM and return its top prediction.

    Only the first row of ``predict_proba`` is used. Probability column ``i`` is
    labelled ``class_names[i]``. NOTE(review): this assumes the SVM's ``classes_``
    order matches ``class_names`` — confirm against the training script.

    Returns:
        (label, confidence, probabilities) with ``confidence`` as a Python float.
    """
    scores = svm_model.predict_proba(features)[0]
    best = int(np.argmax(scores))
    return class_names[best], float(scores[best]), scores
| |
|
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report whether the SVM is loaded, the torch device and the class list."""
    status = {
        'status': 'healthy',
        'model_loaded': svm_model is not None,
        'device': str(device),
        'classes': class_names,
    }
    return jsonify(status)
| |
|
def save_to_database(pred_label, confidence, prob_dict, mode, image_data_url=None):
    """Insert one prediction row into the Supabase ``predictions`` table (best effort).

    Args:
        pred_label: Predicted class name.
        confidence: Probability of the predicted class.
        prob_dict: Mapping of class name -> probability.
        mode: ``'real'`` or ``'simulation'``.
        image_data_url: Optional image string; only its first 1000 characters are stored.

    Returns:
        The first row of ``result.data`` from the insert, or ``None`` when no client
        is configured, nothing was returned, or the insert failed. Failures are
        printed, never raised.
    """
    from datetime import timezone  # module header only imports ``datetime``

    if not supabase:
        return None

    try:
        prediction_data = {
            'predicted_class': pred_label,
            'confidence': confidence,
            'probabilities': prob_dict,
            'mode': mode,
            # Timezone-aware UTC ISO-8601 timestamp (ends in "+00:00");
            # datetime.utcnow() is naive and deprecated since Python 3.12.
            'created_at': datetime.now(timezone.utc).isoformat(),
        }

        if image_data_url:
            # NOTE(review): only a 1000-char prefix is stored, which is not a
            # decodable image for real uploads — confirm this is intentional.
            prediction_data['image_data'] = image_data_url[:1000]

        result = supabase.table('predictions').insert(prediction_data).execute()
        return result.data[0] if result.data else None
    except Exception as e:
        print(f"⚠ Failed to save to database: {e}")
        return None
| |
|
@app.route('/classify', methods=['POST'])
def classify_image():
    """Classify an image posted as JSON ``{"image": "<base64 or data URL>"}``.

    On success returns 200 with ``predicted_class``, ``confidence``,
    ``probabilities`` (class name -> probability), ``mode`` ('real' when the SVM
    is loaded, otherwise 'simulation'), ``saved_to_db`` and, when the row was
    stored, its ``id``.

    Returns 400 when the body is not a JSON object with a string ``image`` field,
    or when the payload cannot be decoded as an image. Returns 500 (with the
    error message) for any other failure.
    """
    try:
        # load_model() only runs from the __main__ block. When the module is
        # served by a WSGI server instead, load on first use so class_names is
        # never None here.
        if class_names is None:
            load_model()

        # silent=True: a malformed body or a non-JSON content type yields None
        # (-> 400) instead of an HTTP exception that the catch-all below would
        # report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image data provided'}), 400

        image_base64 = data['image']
        if not isinstance(image_base64, str):
            return jsonify({
                'error': 'Invalid image data',
                'message': "'image' must be a base64-encoded string"
            }), 400

        try:
            image = decode_base64_image(image_base64)
        except (ValueError, OSError) as e:
            # Client error: bad base64 raises binascii.Error (a ValueError);
            # PIL raises UnidentifiedImageError / truncated-image errors (OSError).
            return jsonify({'error': 'Invalid image data', 'message': str(e)}), 400

        img_tensor = preprocess_image(image)

        # Evaluate the model state once so the reported mode always matches the
        # branch actually taken.
        mode = 'real' if svm_model is not None else 'simulation'
        if mode == 'real':
            features = extract_features(img_tensor)
            pred_label, confidence, probabilities = predict_with_model(features)
        else:
            pred_label, confidence, probabilities = simulate_prediction()

        prob_dict = {name: float(probabilities[i]) for i, name in enumerate(class_names)}

        db_record = save_to_database(pred_label, confidence, prob_dict, mode, image_base64)

        response = {
            'predicted_class': pred_label,
            'confidence': confidence,
            'probabilities': prob_dict,
            'mode': mode
        }

        if db_record:
            response['id'] = db_record.get('id')
            response['saved_to_db'] = True
        else:
            response['saved_to_db'] = False

        return jsonify(response)

    except Exception as e:
        print(f"⚠ Classification failed: {e}")
        return jsonify({
            'error': 'Classification failed',
            'message': str(e)
        }), 500
| |
|
@app.route('/reload-model', methods=['POST'])
def reload_model():
    """Re-run load_model() and report the resulting model/class state.

    load_model() prints and swallows its own loading errors, so a failed load
    normally shows up as ``model_loaded: false`` with status 200. A 500 with the
    error message is returned only if something in this handler raises.
    """
    try:
        load_model()
        summary = {
            'status': 'success',
            'model_loaded': svm_model is not None,
            'classes': class_names,
        }
        return jsonify(summary)
    except Exception as err:
        return jsonify({'status': 'error', 'message': str(err)}), 500
| |
|
if __name__ == '__main__':
    # Create MODEL_DIR if it is missing; load_model() falls back to simulation
    # mode when the model file is absent.
    os.makedirs(MODEL_DIR, exist_ok=True)
    load_model()

    # Listen on all interfaces; the port comes from $PORT (default 5000).
    port = int(os.getenv('PORT', '5000'))
    app.run(debug=False, host='0.0.0.0', port=port)
| |
|