| import os |
| import json |
| import io |
| import base64 |
| from datetime import datetime |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms, models |
| import joblib |
| from PIL import Image |
| from flask import Flask, request, jsonify |
| from flask_cors import CORS |
| from supabase import create_client, Client |
|
|
# Flask application; flask_cors is applied app-wide with its default settings.
# NOTE(review): flask_cors defaults permit any origin — confirm that is intended.
app = Flask(__name__)
CORS(app)

# Run inference on the GPU when torch can see one, otherwise on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")
print(f"Using device: {device}")
|
|
# Model artefacts live in a "models" directory next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
model_path = os.path.join(MODEL_DIR, "svm_densenet201_rbf.joblib")  # joblib-serialized SVM
meta_path = os.path.join(MODEL_DIR, "metadata.json")  # optional class_names / img_size

# Runtime state filled in by load_model(); None means "not loaded (yet)".
svm_model = None
class_names = None
IMG_SIZE = 224  # default square input size; metadata.json may override it

# Supabase persistence is optional: both variables must be set for a client to be created.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_ANON_KEY")
supabase: Client = None
|
|
def _create_supabase_client():
    """Build a Supabase client from the environment, or return None.

    None is returned (after printing a warning) when either credential is missing
    or when client creation raises.
    """
    if not (supabase_url and supabase_key):
        print("⚠ Supabase credentials not found, predictions won't be saved to database")
        return None
    try:
        client = create_client(supabase_url, supabase_key)
    except Exception as e:
        print(f"⚠ Failed to initialize Supabase: {e}")
        return None
    print("✓ Supabase client initialized")
    return client


supabase = _create_supabase_client()
|
|
def load_model():
    """Load the SVM classifier and its metadata from MODEL_DIR into module globals.

    Sets:
        svm_model: the joblib-loaded classifier, or None when the file is missing or
            loading fails (in which case /classify falls back to simulation mode).
        class_names: from metadata.json, or the default three classes.
        IMG_SIZE: from metadata.json ``img_size`` (coerced to int), else unchanged.
        eval_tfms: rebuilt so the current IMG_SIZE is what preprocessing uses.

    Errors are printed rather than raised.
    """
    global svm_model, class_names, IMG_SIZE, eval_tfms

    default_classes = ["3 Bulan", "6 Bulan", "9 Bulan"]

    try:
        if os.path.exists(model_path):
            # NOTE(security): joblib.load unpickles the file and can execute arbitrary
            # code; model_path must only ever contain trusted artefacts.
            svm_model = joblib.load(model_path)
            print("✓ SVM model loaded successfully")
        else:
            print(f"⚠ Model file not found at {model_path}")
            print(" Using simulation mode until model is uploaded")
            svm_model = None

        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            # `or` also guards against an explicit null / empty list in the metadata,
            # which would otherwise break every prediction path downstream.
            class_names = meta.get("class_names") or default_classes
            IMG_SIZE = int(meta.get("img_size", 224))
            print(f"✓ Metadata loaded: {class_names}")
        else:
            class_names = default_classes
            print(f"⚠ Metadata not found, using default classes: {class_names}")

    except Exception as e:
        print(f"Error loading model: {e}")
        svm_model = None
        class_names = default_classes

    # eval_tfms is first built at import time with the default IMG_SIZE; rebuild it
    # here so an img_size supplied by metadata.json actually takes effect.
    eval_tfms = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
|
|
# DenseNet-201 with torchvision's default pretrained weights, put in eval mode.
# Only its `.features` submodule is used below, followed by global average pooling.
densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
densenet.eval()
gap = nn.AdaptiveAvgPool2d((1, 1)).to(device)
feature_extractor = densenet.features.to(device)

# Per-channel normalization statistics passed to transforms.Normalize.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Evaluation preprocessing: square resize to IMG_SIZE, to tensor, normalize.
eval_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
|
|
def decode_base64_image(base64_string):
    """Decode a base64 image string into an RGB PIL image.

    A data-URL prefix (anything up to the first comma, e.g. ``data:image/png;base64,``)
    is dropped before decoding. Decoding and image-parsing errors propagate.
    """
    payload = base64_string.split(',')[1] if ',' in base64_string else base64_string
    raw_bytes = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
|
|
def preprocess_image(image):
    """Apply the evaluation transform and prepend a batch dimension of size 1."""
    tensor = eval_tfms(image)
    return tensor.unsqueeze(0)
|
|
@torch.no_grad()
def extract_features(img_tensor):
    """Compute pooled backbone features for a batch of preprocessed images.

    Steps: move to ``device``, run ``feature_extractor``, apply ReLU, global-average
    pool with ``gap``, then flatten to (batch, -1). Autograd is disabled and the
    result is returned as a NumPy array on the CPU.
    """
    batch = img_tensor.to(device)
    pooled = gap(torch.relu(feature_extractor(batch)))
    flat = pooled.view(pooled.size(0), -1)
    return flat.cpu().numpy()
|
|
def simulate_prediction():
    """Return a random (label, confidence, probabilities) triple for simulation mode.

    Probabilities come from a flat Dirichlet over ``class_names`` (so they sum to 1);
    the label is the class with the largest drawn probability.
    """
    n_classes = len(class_names)
    probs = np.random.dirichlet(np.ones(n_classes), size=1)[0]
    best = int(np.argmax(probs))
    return class_names[best], float(probs[best]), probs
|
|
def predict_with_model(features):
    """Classify pre-extracted features with the loaded SVM.

    Returns (label, confidence, probabilities) for the first row of ``features``.
    """
    # NOTE(review): assumes predict_proba column i corresponds to class_names[i]
    # (i.e. svm_model.classes_ ordered like class_names) — confirm against training.
    row = svm_model.predict_proba(features)[0]
    top = int(np.argmax(row))
    return class_names[top], float(row[top]), row
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness plus model-loaded flag, torch device and current class names."""
    status = {
        'status': 'healthy',
        'model_loaded': svm_model is not None,
        'device': str(device),
        'classes': class_names,
    }
    return jsonify(status)
|
|
def save_to_database(pred_label, confidence, prob_dict, mode, image_data_url=None):
    """Insert one prediction row into the Supabase ``predictions`` table.

    Best-effort: returns None without raising when no Supabase client is configured
    or when the insert fails (the failure is printed).

    Args:
        pred_label: predicted class name.
        confidence: probability of the predicted class.
        prob_dict: mapping of class name -> probability.
        mode: prediction mode string stored as-is.
        image_data_url: optional image string; only its first 1000 characters are stored.

    Returns:
        The first row of ``result.data`` from the insert, or None.
    """
    from datetime import timezone  # module-level import only brings in `datetime`

    if supabase is None:
        return None

    try:
        prediction_data = {
            'predicted_class': pred_label,
            'confidence': confidence,
            'probabilities': prob_dict,
            'mode': mode,
            # Timezone-aware UTC ISO timestamp (ends in +00:00). datetime.utcnow()
            # returns a naive value and is deprecated since Python 3.12.
            'created_at': datetime.now(timezone.utc).isoformat(),
        }

        if image_data_url:
            # NOTE(review): truncation keeps rows small, but for longer inputs the
            # stored value is not a decodable image — confirm this is intended.
            prediction_data['image_data'] = image_data_url[:1000]

        result = supabase.table('predictions').insert(prediction_data).execute()
        return result.data[0] if result.data else None
    except Exception as e:
        print(f"⚠ Failed to save to database: {e}")
        return None
|
|
@app.route('/classify', methods=['POST'])
def classify_image():
    """Classify a base64-encoded image posted as JSON ``{"image": "..."}``.

    The image string may be raw base64 or a data URL. When an SVM model is loaded,
    DenseNet features are classified by it; otherwise a simulated prediction is
    returned. The result is saved to Supabase on a best-effort basis.

    Returns:
        200 with predicted_class, confidence, probabilities, mode and saved_to_db
        (plus id when the row was stored);
        400 when the body carries no usable image;
        500 when classification fails unexpectedly.
    """
    try:
        # load_model() always assigns class_names, so None means it never ran —
        # e.g. when the app is imported by a WSGI server instead of run as __main__.
        # Without this, simulate_prediction()/prob_dict would fail on len(None).
        if class_names is None:
            load_model()

        # silent=True yields None for a missing or non-JSON body instead of raising,
        # so that client error is reported as 400 rather than 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image data provided'}), 400

        image_base64 = data['image']
        try:
            image = decode_base64_image(image_base64)
        except (TypeError, ValueError, OSError) as e:
            # Bad client payload: non-string input raises TypeError, binascii.Error
            # is a ValueError, and PIL's UnidentifiedImageError is an OSError.
            return jsonify({'error': 'Invalid image data', 'message': str(e)}), 400

        img_tensor = preprocess_image(image)

        # Decide once so the reported mode always matches the branch taken.
        use_model = svm_model is not None
        if use_model:
            features = extract_features(img_tensor)
            pred_label, confidence, probabilities = predict_with_model(features)
        else:
            pred_label, confidence, probabilities = simulate_prediction()
        mode = 'real' if use_model else 'simulation'

        prob_dict = {name: float(probabilities[i]) for i, name in enumerate(class_names)}

        db_record = save_to_database(pred_label, confidence, prob_dict, mode, image_base64)

        response = {
            'predicted_class': pred_label,
            'confidence': confidence,
            'probabilities': prob_dict,
            'mode': mode
        }

        if db_record:
            response['id'] = db_record.get('id')
            response['saved_to_db'] = True
        else:
            response['saved_to_db'] = False

        return jsonify(response)

    except Exception as e:
        return jsonify({
            'error': 'Classification failed',
            'message': str(e)
        }), 500
|
|
@app.route('/reload-model', methods=['POST'])
def reload_model():
    """Re-run load_model() and report whether a model is now loaded and the classes.

    Any exception raised while reloading or serializing is returned as a 500 with
    its message.
    """
    # NOTE(review): no authentication is checked here; any client that can reach the
    # service can trigger a reload — consider protecting this endpoint.
    try:
        load_model()
        payload = {
            'status': 'success',
            'model_loaded': svm_model is not None,
            'classes': class_names,
        }
        return jsonify(payload)
    except Exception as e:
        failure = {'status': 'error', 'message': str(e)}
        return jsonify(failure), 500
|
|
if __name__ == '__main__':
    # Make sure the artefact directory exists, then load whatever is in it.
    os.makedirs(MODEL_DIR, exist_ok=True)
    load_model()

    # Serve on all interfaces; the port comes from $PORT and defaults to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)), debug=False)
|
|