Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import io | |
| import base64 | |
| from datetime import datetime | |
| from threading import Lock | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| import joblib | |
| from PIL import Image | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from supabase import create_client, Client | |
# =========================
# Flask app
# =========================
app = Flask(__name__)
CORS(app)  # cross-origin requests allowed on every route (flask_cors defaults)

# =========================
# Compute device: GPU when CUDA is available, otherwise CPU
# =========================
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# =========================
# Model artifact locations (a "models" folder next to this file)
# =========================
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
model_path, meta_path = (
    os.path.join(MODEL_DIR, artifact)
    for artifact in ("svm_densenet201_rbf.joblib", "metadata.json")
)
# =========================
# Shared state (models & config)
# =========================
# SVM classifier and its ordered label list; populated by load_model().
svm_model = None
class_names = None
# Square input resolution; load_model() may override it from metadata.json.
IMG_SIZE = 224

# DenseNet201 model, its convolutional trunk, and the pooling layer;
# populated by load_densenet().
densenet = feature_extractor = gap = None

# Evaluation transform pipeline; rebuilt once IMG_SIZE is known.
eval_tfms = None

# One-shot load flags plus the single lock both loaders take.
model_loaded = densenet_loaded = False
load_lock = Lock()
# =========================
# Supabase (optional: persists predictions when credentials are set)
# =========================
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_ANON_KEY")
supabase: Client = None

if not (supabase_url and supabase_key):
    print("⚠ Supabase credentials not found, predictions won't be saved to database")
else:
    try:
        supabase = create_client(supabase_url, supabase_key)
    except Exception as e:
        # Keep serving without persistence if the client cannot be created.
        print(f"⚠ Failed to initialize Supabase: {e}")
        supabase = None
    else:
        print("✓ Supabase client initialized")
| # ========================= | |
| # Helpers | |
| # ========================= | |
def format_class_name(raw_name: str) -> str:
    """Return the display form of an internal class label.

    The three known labels map to short names (``usia_3_bulan`` -> ``3 Bulan``,
    and likewise for 6 and 9). Any other value is returned unchanged.
    """
    display_names = {
        "usia_3_bulan": "3 Bulan",
        "usia_6_bulan": "6 Bulan",
        "usia_9_bulan": "9 Bulan",
    }
    if raw_name in display_names:
        return display_names[raw_name]
    return raw_name
def build_eval_transforms(img_size: int):
    """Return the evaluation preprocessing pipeline for a square input size.

    The pipeline resizes to ``(img_size, img_size)``, converts to a tensor, and
    normalizes each channel. The mean/std values are the standard ImageNet
    statistics expected by torchvision's pretrained models.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
def decode_base64_image(base64_string: str) -> Image.Image:
    """Decode a base64 image string into an RGB PIL image.

    Accepts either raw base64 or a data URL (``data:image/...;base64,<payload>``).
    When a comma is present, the text between the first and second comma is
    treated as the payload.
    """
    payload = base64_string.split(",")[1] if "," in base64_string else base64_string
    raw_bytes = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Transform a PIL image and prepend a batch dimension of size 1.

    If the module-level ``eval_tfms`` pipeline has not been built yet (for
    example, because metadata has not been loaded), build it from the current
    ``IMG_SIZE`` and cache it in that global.
    """
    global eval_tfms
    pipeline = eval_tfms
    if pipeline is None:
        pipeline = eval_tfms = build_eval_transforms(IMG_SIZE)
    return pipeline(image).unsqueeze(0)
| # ========================= | |
| # Loading: SVM + Metadata | |
| # ========================= | |
def load_model():
    """Load the SVM classifier and its metadata once, guarded by ``load_lock``.

    Called lazily on the first /classify request, and by /reload-model after it
    resets the globals. Updates the module globals ``svm_model``,
    ``class_names``, ``IMG_SIZE``, ``eval_tfms`` and ``model_loaded``:

    * If ``model_path`` exists, ``svm_model`` is loaded with joblib. Otherwise
      it stays None, which makes /classify run in simulation mode.
    * If ``meta_path`` exists, ``class_names`` and ``img_size`` are read from
      the JSON file. A missing, null, or empty ``class_names`` (or a missing or
      null ``img_size``) falls back to the defaults.
    * If metadata is absent, the three default age classes and 224px are used.
    * Any other error is printed with a traceback. The same defaults are then
      applied and ``svm_model`` is reset to None.

    ``model_loaded`` is True after every normal return, so the work is not
    repeated until something resets it. Errors from ``os.makedirs`` propagate.
    """
    global svm_model, class_names, IMG_SIZE, model_loaded, eval_tfms
    default_classes = ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]

    # Cheap check without the lock, then re-check under it: another thread may
    # have finished loading while this one was waiting.
    if model_loaded:
        return
    with load_lock:
        if model_loaded:
            return
        os.makedirs(MODEL_DIR, exist_ok=True)
        try:
            print(f"🔍 Checking model directory: {MODEL_DIR}")
            print(f" Model path: {model_path}")
            print(f" Metadata path: {meta_path}")
            print(f" Model exists: {os.path.exists(model_path)}")
            print(f" Metadata exists: {os.path.exists(meta_path)}")
            if os.path.exists(MODEL_DIR):
                files = os.listdir(MODEL_DIR)
                print(f" Files in models/: {files}")

            # ---- Load SVM ----
            if os.path.exists(model_path):
                print("⏳ Loading SVM model...")
                # NOTE(review): joblib.load unpickles the file; only deploy
                # trusted model artifacts to MODEL_DIR.
                svm_model = joblib.load(model_path)
                print("✓ SVM model loaded successfully")
            else:
                print(f"⚠ Model file not found at {model_path}")
                print(" Using simulation mode until model is uploaded")
                svm_model = None

            # ---- Load Metadata ----
            if os.path.exists(meta_path):
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                # `or` also replaces an explicit null / empty list (and a null
                # img_size). Those values would otherwise be kept as-is and
                # break real-mode prediction later.
                class_names = meta.get("class_names") or list(default_classes)
                IMG_SIZE = int(meta.get("img_size") or 224)
                print(f"✓ Metadata loaded: class_names={class_names}, IMG_SIZE={IMG_SIZE}")
            else:
                class_names = list(default_classes)
                IMG_SIZE = 224
                print(f"⚠ Metadata not found, using default classes: {class_names}, IMG_SIZE={IMG_SIZE}")

            # IMG_SIZE may have changed, so the cached transforms must be rebuilt.
            eval_tfms = build_eval_transforms(IMG_SIZE)
            model_loaded = True
        except Exception as e:
            print(f"❌ Error loading model: {str(e)}")
            import traceback
            traceback.print_exc()
            # Fall back to simulation mode with default settings.
            svm_model = None
            class_names = list(default_classes)
            IMG_SIZE = 224
            eval_tfms = build_eval_transforms(IMG_SIZE)
            model_loaded = True
| # ========================= | |
| # Loading: DenseNet201 | |
| # ========================= | |
def load_densenet():
    """Load the pretrained DenseNet201 once and keep its conv trunk on ``device``.

    Sets four module globals:

    * ``densenet`` — the full model, switched to eval mode;
    * ``feature_extractor`` — ``densenet.features`` moved to ``device``;
    * ``gap`` — a 1x1 adaptive average-pooling layer;
    * ``densenet_loaded`` — the one-shot flag.

    Takes the same ``load_lock`` as ``load_model``. Per the log message, the
    first call may be slow, since it fetches the default pretrained weights.
    """
    global densenet, feature_extractor, gap, densenet_loaded
    if not densenet_loaded:
        with load_lock:
            # Another thread may have completed the load while we waited.
            if not densenet_loaded:
                print("⏳ Loading DenseNet201 (first time may take a while)...")
                backbone = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
                backbone.eval()
                densenet = backbone
                feature_extractor = densenet.features.to(device)
                gap = nn.AdaptiveAvgPool2d((1, 1)).to(device)
                densenet_loaded = True
                print("✓ DenseNet201 loaded successfully")
def extract_features(img_tensor: torch.Tensor) -> np.ndarray:
    """Compute globally pooled DenseNet201 features for a batch of images.

    Loads DenseNet201 on first use. It then applies the conv trunk, a ReLU, and
    global average pooling, and flattens to one row per input image.

    Args:
        img_tensor: A batched image tensor such as the output of
            ``preprocess_image``. Presumably (N, 3, H, W); TODO confirm for
            other callers.

    Returns:
        A CPU NumPy array of shape (N, C), where C is the trunk's channel count.
    """
    load_densenet()
    # Inference only. The pretrained parameters have requires_grad=True, so
    # without no_grad the output would track gradients (wasting memory), and
    # Tensor.numpy() raises on tensors that require grad.
    with torch.no_grad():
        feats = feature_extractor(img_tensor.to(device))
        feats = torch.relu(feats)
        feats = gap(feats)
        feats = torch.flatten(feats, 1)
    return feats.cpu().numpy()
| # ========================= | |
| # Prediction | |
| # ========================= | |
def simulate_prediction():
    """Return a random ``(label, confidence, probabilities)`` triple.

    Used when no SVM is loaded. Probabilities are drawn from a flat Dirichlet
    over ``class_names``, or over the three default age classes when
    ``class_names`` is unset or empty. The label is the arg-max class and the
    confidence is its probability (as a float).
    """
    labels = class_names or ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
    probs = np.random.dirichlet(np.ones(len(labels)), size=1)[0]
    best = int(np.argmax(probs))
    return labels[best], float(probs[best]), probs
def predict_with_model(features: np.ndarray):
    """Score features with the loaded SVM.

    Returns ``(label, confidence, probabilities)``, computed from the first row
    of ``features`` only.

    The label is looked up by column position in ``class_names``. This assumes
    the SVM's probability columns follow the same order as ``class_names``.
    NOTE(review): confirm against the fitted model's ``classes_``.
    """
    scores = svm_model.predict_proba(features)[0]
    top = int(np.argmax(scores))
    return class_names[top], float(scores[top]), scores
| # ========================= | |
| # Database Save | |
| # ========================= | |
def save_to_database(pred_label, confidence, prob_dict, mode, image_data_url=None):
    """Insert one prediction into the Supabase ``predictions`` table, best effort.

    Args:
        pred_label: Predicted class; /classify passes the display-formatted name.
        confidence: Probability of ``pred_label``; stored as a float.
        prob_dict: Mapping of class name to probability, stored as given.
        mode: Prediction mode string, e.g. "real" or "simulation".
        image_data_url: Optional image payload as received from the client.
            Its first 1000 characters go to ``image_data``; the full string
            goes to ``image_url``.

    Returns:
        The first row returned by the insert, or None in three cases: Supabase
        is not configured, no row came back, or the insert failed. Failures are
        printed, never raised.
    """
    from datetime import timezone  # the module imports only the `datetime` class

    if not supabase:
        return None
    try:
        prediction_data = {
            "predicted_class": pred_label,
            "confidence": float(confidence),
            "probabilities": prob_dict,
            "mode": mode,
            # Timezone-aware UTC (ISO 8601 with "+00:00"). datetime.utcnow() is
            # deprecated since Python 3.12 and produced an offset-less value.
            "created_at": datetime.now(timezone.utc).isoformat(),
        }
        if image_data_url:
            # Short prefix of the payload, truncated for safety.
            prediction_data["image_data"] = image_data_url[:1000]
            # Full payload, kept so the image can be displayed later.
            prediction_data["image_url"] = image_data_url
        result = supabase.table("predictions").insert(prediction_data).execute()
        return result.data[0] if result.data else None
    except Exception as e:
        print(f"⚠ Failed to save to database: {e}")
        return None
| # ========================= | |
| # Routes | |
| # ========================= | |
@app.route("/")
def home():
    """Landing endpoint: return the service name, status, version and endpoint list.

    Previously this function was never registered with Flask, so ``/`` returned 404.
    """
    return jsonify({
        "service": "Seedling Classifier API",
        "status": "running",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "classify": "/classify (POST)",
            "reload_model": "/reload-model (POST)",
            "warmup": "/warmup (POST)",
        },
        "note": "Open /health to verify. Use POST /classify with JSON {image: base64DataURL}."
    })
@app.route("/health")
def health_check():
    """Report service status without triggering any model loading.

    ``model_loaded`` and ``densenet_loaded`` reflect whether the SVM and the
    DenseNet trunk are currently in memory. ``classes`` lists display names,
    using the default age classes until metadata has been loaded.
    """
    default_classes = ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
    current_classes = class_names if class_names else default_classes
    display_classes = [format_class_name(c) for c in current_classes]
    return jsonify({
        "status": "healthy",
        "model_loaded": svm_model is not None,
        "densenet_loaded": feature_extractor is not None,
        "device": str(device),
        "classes": display_classes,
        "ready": True
    })
@app.route("/classify", methods=["POST"])
def classify_image():
    """Classify one image posted as JSON ``{"image": "<base64 or data URL>"}``.

    Loads the SVM and metadata on first use. When an SVM is loaded, it uses the
    DenseNet201 + SVM pipeline; otherwise it returns a random "simulation"
    prediction. The result is saved to Supabase when that is configured.

    Returns:
        200 with ``predicted_class``, ``confidence``, ``probabilities`` (keyed by
        display name), ``mode``, ``saved_to_db``, and ``id`` when a row was
        saved.

        400 when the body is not a JSON object with an ``image`` key, when
        ``image`` is not a string, or when it is not decodable base64 image data.

        500 for any other failure.
    """
    data = request.get_json(silent=True)
    # Checking the type also rejects JSON arrays and scalars, which previously
    # failed with a 500.
    if not isinstance(data, dict) or "image" not in data:
        return jsonify({"error": "No image data provided"}), 400

    image_base64 = data["image"]
    if not isinstance(image_base64, str):
        return jsonify({
            "error": "Invalid image data",
            "message": "'image' must be a base64-encoded string"
        }), 400

    try:
        try:
            image = decode_base64_image(image_base64)
        except (ValueError, OSError) as e:
            # Client error, not a server fault. Malformed base64 raises
            # binascii.Error (a ValueError); undecodable or truncated image
            # bytes raise PIL errors that subclass OSError.
            return jsonify({"error": "Invalid image data", "message": str(e)}), 400

        # Lazily load the SVM and metadata; this also fixes IMG_SIZE for the
        # transforms used below.
        if not model_loaded:
            load_model()
        img_tensor = preprocess_image(image)

        # Use the real model if available, else simulation mode.
        if svm_model is not None:
            features = extract_features(img_tensor)
            pred_label, confidence, probabilities = predict_with_model(features)
            mode = "real"
        else:
            pred_label, confidence, probabilities = simulate_prediction()
            mode = "simulation"

        _classes = class_names if class_names else ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
        prob_dict = {
            format_class_name(name): float(probabilities[i])
            for i, name in enumerate(_classes)
        }
        formatted_pred_label = format_class_name(pred_label)

        db_record = save_to_database(formatted_pred_label, confidence, prob_dict, mode, image_base64)

        response = {
            "predicted_class": formatted_pred_label,
            "confidence": float(confidence),
            "probabilities": prob_dict,
            "mode": mode,
            "saved_to_db": bool(db_record),
        }
        if db_record:
            response["id"] = db_record.get("id")
        return jsonify(response)
    except Exception as e:
        return jsonify({
            "error": "Classification failed",
            "message": str(e)
        }), 500
@app.route("/reload-model", methods=["POST"])
def reload_model_route():
    """Discard the loaded SVM and metadata, then load them again from disk.

    DenseNet is left as-is. Returns JSON with ``status``, whether an SVM is now
    loaded, and the display names of the current classes. On failure, returns
    500 with the error message.
    """
    global model_loaded, svm_model, class_names, eval_tfms
    try:
        # Reset the shared state under the lock so a concurrent load_model()
        # cannot observe a half-reset state.
        with load_lock:
            model_loaded = False
            svm_model = None
            class_names = None
            eval_tfms = None
        # load_model() acquires load_lock itself, and threading.Lock is not
        # re-entrant. It must therefore run only after the block above has
        # released the lock; calling it while holding the lock would deadlock.
        load_model()
        display_classes = [format_class_name(c) for c in class_names] if class_names else []
        return jsonify({
            "status": "success",
            "model_loaded": svm_model is not None,
            "classes": display_classes
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500
@app.route("/warmup", methods=["POST"])
def warmup():
    """Preload DenseNet201 so the first real /classify call does not pay for it.

    Returns JSON with ``status``, ``densenet_loaded`` and the device name. On
    failure, returns 500 with the error message.
    """
    try:
        load_densenet()
        return jsonify({
            "status": "success",
            "densenet_loaded": feature_extractor is not None,
            "device": str(device)
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500
| # ========================= | |
| # Local run (optional) | |
| # ========================= | |
if __name__ == "__main__":
    # Local development entry point using Flask's built-in server. When
    # deployed, the app is presumably served by a WSGI server such as gunicorn
    # (load_model's docstring mentions it); verify against the deployment setup.
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("🚀 Starting locally...")
    # Models load lazily on first use. To preload at startup, call load_model()
    # and load_densenet() here.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)