import os
import io
import time
import logging
import threading

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from PIL import Image
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from ultralytics import YOLO

# Fix PyTorch 2.6+ weights_only loading issue
# NOTE(security): forcing weights_only=False makes every torch.load() call in
# this process unpickle arbitrary objects.  Only load checkpoints from sources
# you trust (your own HF model repo / Space).
if not hasattr(torch, "_original_load"):
    torch._original_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False."""
    kwargs["weights_only"] = False
    return torch._original_load(*args, **kwargs)


torch.load = patched_torch_load

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger("PlantClassifierAPI")

# ---------------------------------------------------------------------------
# Hugging Face repository configuration
# ---------------------------------------------------------------------------
# PRIMARY: Set HF_MODEL_REPO to a dedicated HF *model* repo containing the
#          real binary .pth / .pt files (e.g. "yourname/plantcare-models").
#
# FALLBACK: If HF_MODEL_REPO is not set, the app will try to download the
#           weights from the Space's own repository using SPACE_ID (which HF
#           injects automatically into every running Space). This works because
#           HF stores the real LFS bytes even though the Docker build context
#           only receives the pointer files.
#
# Either way, set HF_TOKEN for private repos.
# ---------------------------------------------------------------------------
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "")
# SPACE_ID is auto-injected by HF Spaces; fallback to the known Space ID
SPACE_ID = os.environ.get("SPACE_ID", "Karim31003/classifier")
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/tmp/hf")
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# ---------------------------------------------------------------------------

# Initialize Flask App
app = Flask(__name__)
CORS(app)

# Limit request payload to 10MB
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024

# Configuration options
DEBUG_SAVE_IMAGES = os.environ.get("DEBUG_SAVE_IMAGES", "false").lower() == "true"
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), "uploads")
ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png'}
ALLOWED_MIME_TYPES = {'image/jpeg', 'image/png', 'application/octet-stream'}

# Global cache for loaded models
models_cache = {
    "yolo": None,
    "disease_models": {}
}

# Detailed load reports for debugging
model_load_reports = []

# Force CPU device
device = torch.device("cpu")
logger.info(f"Using hardware device: {device}")

# Re-entrant so ensure_models_loaded() can call preload_all_models() while
# already holding the lock.
models_lock = threading.RLock()
models_ready = threading.Event()
models_load_error = None

LFS_POINTER_PREFIX = "version https://git-lfs.github.com/spec/v1"

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "model")
if not os.path.isdir(MODEL_DIR):
    logger.warning(f"Model directory not found at {MODEL_DIR}; falling back to {BASE_DIR}.")
    MODEL_DIR = BASE_DIR


# ---------------------------------------------------------------------------
# Git LFS / file diagnostics
# ---------------------------------------------------------------------------

def _log_file_diagnostics(path: str, label: str) -> None:
    """Log path, size, and the first 200 characters of a file for debugging."""
    if not os.path.exists(path):
        logger.warning(f"[DIAG] {label}: file does not exist at {path}")
        return
    size = os.path.getsize(path)
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            head = fh.read(200)
    except Exception as exc:
        # Diagnostics are best-effort: report why the head could not be read
        # instead of silently logging an empty string.
        head = f"<unreadable: {exc}>"
    logger.info(
        f"[DIAG] {label}\n"
        f" path : {path}\n"
        f" size : {size:,} bytes\n"
        f" head : {repr(head)}"
    )


def _is_lfs_pointer(path: str) -> bool:
    """Return True if the file is a Git LFS pointer text file."""
    if not os.path.exists(path):
        return False
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            header = fh.read(256)
        return header.startswith(LFS_POINTER_PREFIX)
    except Exception:
        # An unreadable file is treated as "not a pointer"; callers validate
        # existence separately.
        return False


def validate_model_artifact(model_path: str, artifact_name: str) -> None:
    """
    Ensure the model artifact is a real binary checkpoint, not a Git LFS pointer.
    Logs full diagnostics before raising so the Space logs are informative.

    Raises:
        FileNotFoundError: if nothing exists at *model_path*.
        RuntimeError: if the file is a Git LFS pointer.
    """
    _log_file_diagnostics(model_path, artifact_name)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"{artifact_name} not found at: {model_path}")

    if _is_lfs_pointer(model_path):
        raise RuntimeError(
            f"{artifact_name} at {model_path} is a Git LFS pointer, not the actual model file.\n"
            "Fix options:\n"
            " 1. Set HF_MODEL_REPO env var to your HF model repo ID so weights are downloaded automatically.\n"
            " 2. Run `git lfs pull` locally, verify files are binary, then push again.\n"
            " 3. Upload real weight files directly via the Hugging Face web UI."
        )


# ---------------------------------------------------------------------------
# Hugging Face model download helpers
# ---------------------------------------------------------------------------

def _hf_hub_available() -> bool:
    """Return True if huggingface_hub is importable."""
    try:
        import huggingface_hub  # noqa: F401
        return True
    except ImportError:
        logger.warning("huggingface_hub not installed — add it to requirements.txt.")
        return False


def _download_from_hf(local_filename: str) -> "str | None":
    """
    Download *local_filename* from the best available HF source, in priority order:

    1. HF_MODEL_REPO (dedicated model repo, file at repo root)
       repo_type = "model"
    2. SPACE_ID (the Space's own repo — LFS bytes are stored in HF's backend
       even though the Docker build only sees pointer files)
       repo_type = "space", path inside repo = "model/"

    Returns the local cached path on success, None on failure.
    """
    if not _hf_hub_available():
        return None

    from huggingface_hub import hf_hub_download

    os.makedirs(HF_CACHE_DIR, exist_ok=True)

    # Build candidate list: (repo_id, repo_path, repo_type)
    candidates = []
    if HF_MODEL_REPO:
        candidates.append((HF_MODEL_REPO, local_filename, "model"))
    if SPACE_ID:
        candidates.append((SPACE_ID, f"model/{local_filename}", "space"))

    if not candidates:
        logger.error(
            f"Cannot auto-download '{local_filename}': neither HF_MODEL_REPO nor SPACE_ID is set.\n"
            " • Set HF_MODEL_REPO to a dedicated model repo (e.g. 'yourname/plantcare-models'), OR\n"
            " • The SPACE_ID env var should be injected automatically by HF Spaces — "
            "if missing, add it manually in Space Settings → Variables."
        )
        return None

    for repo_id, repo_path, repo_type in candidates:
        try:
            logger.info(
                f"Downloading '{repo_path}' from {repo_type} repo '{repo_id}' → {HF_CACHE_DIR} ..."
            )
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=repo_path,
                repo_type=repo_type,
                cache_dir=HF_CACHE_DIR,
                token=HF_TOKEN,
                force_download=False,  # reuse cached copy on restarts
            )
            size = os.path.getsize(local_path)
            logger.info(f"Downloaded '{local_filename}' → {local_path} ({size:,} bytes)")
            return local_path
        except Exception as exc:
            # Any failure (network, auth, missing file) just moves on to the
            # next candidate repository.
            logger.warning(
                f"Download attempt failed for '{repo_path}' from '{repo_id}' ({repo_type}): {exc}"
            )

    logger.error(f"All HF download attempts failed for '{local_filename}'.")
    return None


def resolve_model_path(local_filename: str) -> "str | None":
    """
    Return the best usable path for *local_filename*:

    1. Local file in MODEL_DIR — if it exists and is a real binary, use it directly.
    2. Otherwise attempt to download from HF (model repo or Space repo).
    3. Return None if nothing works.
    """
    local_path = os.path.join(MODEL_DIR, local_filename)

    # Fast path — local binary exists
    if os.path.exists(local_path) and not _is_lfs_pointer(local_path):
        logger.info(f"Using local model file: {local_path}")
        return local_path

    # Diagnostics before attempting download
    _log_file_diagnostics(local_path, local_filename)

    if _is_lfs_pointer(local_path):
        logger.warning(f"'{local_filename}' is a Git LFS pointer — downloading real weights from HF.")
    else:
        logger.warning(f"'{local_filename}' not found locally — attempting HF download.")

    downloaded = _download_from_hf(local_filename)
    if downloaded and os.path.exists(downloaded) and not _is_lfs_pointer(downloaded):
        return downloaded

    if downloaded:
        logger.error(f"HF download for '{local_filename}' produced another LFS pointer or empty file.")
    return None


def find_model_path(base_dir, *candidates):
    """
    Return the first existing, non-LFS model file path from candidate names.

    Each candidate is first resolved through ``resolve_model_path`` (local file
    in MODEL_DIR, then HF download); if that fails, ``base_dir/candidate`` is
    checked directly.

    Raises:
        FileNotFoundError: if no candidate yields a usable file.
    """
    for candidate in candidates:
        candidate_path = os.path.join(base_dir, candidate)

        # Prefer a resolved (possibly downloaded) path over a raw local path
        resolved = resolve_model_path(candidate)
        if resolved:
            if candidate != candidates[0]:
                logger.warning(f"Using fallback model file '{candidate}' instead of '{candidates[0]}'")
            return resolved

        # Fall back to bare existence check (covers edge-cases where MODEL_DIR != base_dir)
        if os.path.exists(candidate_path) and not _is_lfs_pointer(candidate_path):
            return candidate_path

    raise FileNotFoundError(
        f"No valid (non-LFS) model file found in {base_dir}. Tried: {', '.join(candidates)}"
    )


# ---------------------------------------------------------------------------
# Model architecture definitions
# ---------------------------------------------------------------------------

class MangoCNN(torch.nn.Module):
    """Four conv/batch-norm/max-pool stages followed by two linear layers.

    ``fc1`` expects 256 * 14 * 14 features, i.e. a 224x224 input halved by
    four 2x2 poolings.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(128)
        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = torch.nn.BatchNorm2d(256)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(256 * 14 * 14, 512)
        self.fc2 = torch.nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the raw class logits for the image batch *x*."""
        relu = torch.nn.functional.relu
        x = self.pool(relu(self.bn1(self.conv1(x))))
        x = self.pool(relu(self.bn2(self.conv2(x))))
        x = self.pool(relu(self.bn3(self.conv3(x))))
        x = self.pool(relu(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return x


def get_preprocess(plant_type):
    """Returns model-specific image preprocessing pipeline.

    Potato images are resized to 380x380; every other plant uses 224x224.
    """
    size = 380 if plant_type == "potato" else 224
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

# plant type -> ((preferred filename, fallback filename), architecture key)
_DISEASE_MODEL_SPECS = {
    "apple": (("apple_efficientnet_best.pth", "apple_model.pth"), "efficientnet_b0"),
    "tomato": (("tomato_efficientnet_best.pth", "tomato_model.pth"), "efficientnet_b0"),
    "potato": (("potato_efficientnet_best.pth", "potato_model.pth"), "efficientnet_b4"),
    "cucumber": (("best_cucumber_resnet18.pth", "best_model.pth"), "resnet18"),
    "lemon": (("best_lemon_model_ema.pth", "lemon_model.pth"), "efficientnet_b0_timm"),
    "mango": (("best_mango_model.pth", "mango_model.pth"), "mango_cnn"),
    "soybean": (
        ("soybean_disease_model_efficientnetb2.pth", "soybean_model.pth"),
        "efficientnet_b2",
    ),
}
_DISEASE_MODEL_SPECS["soyabean"] = _DISEASE_MODEL_SPECS["soybean"]

# Class lists used when the checkpoint itself does not carry one.
_FALLBACK_CLASSES = {
    "cucumber": [
        'Anthracnose', 'Bacterial Wilt', 'Belly Rot', 'Downy Mildew',
        'Fresh Cucumber', 'Fresh Leaf', 'Gummy Stem Blight', 'Pythium Fruit Rot'
    ],
    "lemon": [
        'Anthracnose', 'Bacterial Blight', 'Citrus Canker', 'Curl Virus',
        'Deficiency Leaf', 'Dry Leaf', 'Healthy Leaf', 'Sooty Mould', 'Spider Mites'
    ],
    "mango": [
        'Anthracnose', 'Bacterial Canker', 'Cutting Weevil', 'Die Back',
        'Gall Midge', 'Healthy', 'Powdery Mildew', 'Sooty Mould'
    ],
    "soybean": [
        'Bacterial Blight', 'Cercospora Leaf Blight', 'Healthy',
        'Soybean Rust', 'Sudden Death Syndrome'
    ],
}
_FALLBACK_CLASSES["soyabean"] = _FALLBACK_CLASSES["soybean"]


def load_disease_model(plant_type, base_dir):
    """Load and return a (model, classes) tuple for the given plant type.

    Returns ``(None, None)`` when no model mapping exists for *plant_type*.

    Raises:
        FileNotFoundError: if no usable weight file can be found or downloaded.
        RuntimeError: if the weight file is a Git LFS pointer.
        ValueError: if no class list is available for the model.
    """
    plant_type = plant_type.lower()

    spec = _DISEASE_MODEL_SPECS.get(plant_type)
    if spec is None:
        logger.warning(f"No disease classification mapping defined for plant type: {plant_type}")
        return None, None
    filenames, model_type = spec
    model_path = find_model_path(base_dir, *filenames)

    # validate_model_artifact is now only a final guard – resolve_model_path already handles
    # the LFS-pointer case before we get here, but keep it as a safety net.
    validate_model_artifact(model_path, f"{plant_type} disease model")

    logger.info(f"Loading {plant_type} disease model from {os.path.basename(model_path)} ...")
    checkpoint = torch.load(model_path, map_location="cpu")

    # Resolve class list
    if isinstance(checkpoint, dict) and "classes" in checkpoint:
        classes = checkpoint["classes"]
    elif isinstance(checkpoint, dict) and "class_names" in checkpoint:
        classes = checkpoint["class_names"]
    else:
        classes = list(_FALLBACK_CLASSES.get(plant_type, []))

    if not classes:
        # Without class names the classifier head would get zero outputs and
        # loading the weights would fail with an opaque shape/key error.
        raise ValueError(
            f"No class list available for {plant_type} disease model: the checkpoint "
            "has no 'classes'/'class_names' entry and no fallback list is defined."
        )

    # Instantiate architecture and load weights
    if model_type == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        in_features = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(in_features, len(classes))
        model.load_state_dict(checkpoint["model_state_dict"])

    elif model_type == "efficientnet_b2":
        model = models.efficientnet_b2(weights=None)
        in_features = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(in_features, len(classes))
        model.load_state_dict(checkpoint)

    elif model_type == "efficientnet_b4":
        import timm
        model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=len(classes))
        model.load_state_dict(checkpoint["model_state_dict"])

    elif model_type == "efficientnet_b0_timm":
        import timm
        model = timm.create_model('efficientnet_b0', pretrained=False)
        model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(1280, len(classes))
        )
        # Strip the "module." key prefix and drop the "n_averaged" entry
        # before loading the weights.
        new_state_dict = {
            k[7:] if k.startswith("module.") else k: v
            for k, v in checkpoint.items()
            if k != "n_averaged"
        }
        model.load_state_dict(new_state_dict)

    elif model_type == "resnet18":
        model = models.resnet18(weights=None)
        model.fc = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, len(classes))
        )
        model.load_state_dict(checkpoint)

    elif model_type == "mango_cnn":
        model = MangoCNN(num_classes=len(classes))
        model.load_state_dict(checkpoint)

    else:
        logger.error(f"Unknown model architecture type: {model_type}")
        return None, None

    model = model.to(device)
    model.eval()
    logger.info(f"Successfully loaded {plant_type} model ({len(classes)} classes).")
    return model, classes


# ---------------------------------------------------------------------------
# Startup model pre-loading
# ---------------------------------------------------------------------------

def preload_all_models():
    """Load and cache all detection and classification models at server startup.

    Runs entirely under ``models_lock``.  A YOLO failure is recorded in
    ``models_load_error`` and aborts the pre-load; a failure of an individual
    disease model is only recorded in ``model_load_reports``.
    ``models_ready`` is set in both cases.
    """
    global models_load_error, model_load_reports

    with models_lock:
        start_time = time.time()
        logger.info("=== Starting model pre-load ===")
        logger.info(f"HF_MODEL_REPO : '{HF_MODEL_REPO or '(not set)'}'")
        logger.info(f"SPACE_ID : '{SPACE_ID or '(not set)'}'")
        logger.info(f"HF_CACHE_DIR : '{HF_CACHE_DIR}'")

        model_load_reports = []

        try:
            model_files = os.listdir(MODEL_DIR)
        except Exception:
            # Listing is informational only; never let it abort the pre-load.
            model_files = []
        logger.info(f"MODEL_DIR: {MODEL_DIR} | files: {model_files}")

        # ------------------------------------------------------------------ #
        # 1. YOLO detector                                                   #
        # ------------------------------------------------------------------ #
        try:
            yolo_path = find_model_path(MODEL_DIR, "best.pt", "detector.pt")
            validate_model_artifact(yolo_path, "YOLO model")
            logger.info(f"Loading YOLO model from {yolo_path} ...")
            models_cache["yolo"] = YOLO(yolo_path)
            logger.info("YOLO model loaded successfully.")
        except Exception as exc:
            models_load_error = str(exc)
            logger.critical(f"YOLO model loading failed: {exc}", exc_info=True)
            models_ready.set()
            return

        # ------------------------------------------------------------------ #
        # 2. Plant disease CNN models                                        #
        # ------------------------------------------------------------------ #
        supported_plants = ["apple", "tomato", "potato", "cucumber", "lemon", "mango", "soybean"]
        for plant in supported_plants:
            try:
                model, classes = load_disease_model(plant, MODEL_DIR)
                if model is not None:
                    models_cache["disease_models"][plant] = (model, classes)
                    if plant == "soybean":
                        # Alternate spelling shares the same model entry.
                        models_cache["disease_models"]["soyabean"] = (model, classes)
                    model_load_reports.append({
                        "plant": plant,
                        "status": "loaded",
                        "classes_count": len(classes)
                    })
            except Exception as exc:
                logger.error(f"Failed to load model for {plant}: {exc}", exc_info=True)
                model_load_reports.append({
                    "plant": plant,
                    "status": "error",
                    "error": str(exc)
                })

        elapsed = time.time() - start_time
        loaded = [r["plant"] for r in model_load_reports if r["status"] == "loaded"]
        failed = [r["plant"] for r in model_load_reports if r["status"] == "error"]
        logger.info(
            f"=== Model pre-load done in {elapsed:.1f}s | "
            f"loaded: {loaded} | failed: {failed} ==="
        )

        models_load_error = None
        models_ready.set()


def ensure_models_loaded():
    """Trigger model loading on demand if not already done.

    Returns True when the YOLO model is cached and no load error is recorded.
    """
    global models_load_error

    if models_cache["yolo"] is not None:
        models_ready.set()
        return True

    with models_lock:
        # Re-check after acquiring the lock: another thread may have finished
        # loading while this one was waiting.
        if models_cache["yolo"] is not None:
            models_ready.set()
            return True
        try:
            preload_all_models()
        except Exception as exc:
            models_load_error = str(exc)
            logger.error(f"Model loading failed: {exc}", exc_info=True)
            return False

    return models_cache["yolo"] is not None and models_load_error is None


# ---------------------------------------------------------------------------
# Flask helpers
# ---------------------------------------------------------------------------

def allowed_file(filename, mime_type):
    """Return True if both the file extension and the MIME type are allowed."""
    ext = os.path.splitext(filename)[1].lower()
    valid = ext in ALLOWED_EXTENSIONS and mime_type in ALLOWED_MIME_TYPES
    if not valid:
        logger.warning(f"File rejected — name: '{filename}', ext: '{ext}', mime: '{mime_type}'")
    return valid


# ---------------------------------------------------------------------------
# Flask routes
# ---------------------------------------------------------------------------

@app.errorhandler(413)
def request_entity_too_large(error):
    """Return a JSON body for requests exceeding MAX_CONTENT_LENGTH."""
    return jsonify({"success": False, "message": "File size exceeds the 10 MB limit"}), 413


@app.route("/", methods=["GET"])
def index():
    """Serve the front-end page."""
    return send_file("main.html")


@app.route("/health", methods=["GET"])
def health_check():
    """Report model-loading state and configuration as JSON."""
    loaded_disease_models = list(models_cache["disease_models"].keys())
    # "soyabean" is only an alias of "soybean"; do not list the model twice.
    if "soyabean" in loaded_disease_models and "soybean" in loaded_disease_models:
        loaded_disease_models.remove("soyabean")

    return jsonify({
        "status": "healthy" if models_cache["yolo"] is not None and models_load_error is None else "starting",
        "device": str(device),
        "yolo_loaded": models_cache["yolo"] is not None,
        "loaded_disease_models": loaded_disease_models,
        "model_load_reports": model_load_reports,
        "debug_save_images": DEBUG_SAVE_IMAGES,
        "models_loading": models_cache["yolo"] is None,
        "load_error": models_load_error,
        "hf_model_repo": HF_MODEL_REPO or "(not configured)",
        "space_id": SPACE_ID or "(not detected)",
    }), 200


@app.route("/predict", methods=["POST"])
def predict():
    """Two-stage inference: YOLO plant detection → per-plant disease classification.

    Expects a multipart upload under the form-data key ``image``.  Responds
    with 503 while models are unavailable, 400 for a missing, disallowed or
    undecodable file, 500 for unexpected inference errors, and 200 otherwise
    (including the "no plant detected" case).
    """
    if not ensure_models_loaded():
        return jsonify({
            "success": False,
            "message": models_load_error or "Models are still loading. Please retry shortly."
        }), 503

    if 'image' not in request.files:
        return jsonify({"success": False, "message": "No file in form-data key 'image'"}), 400

    file = request.files['image']
    if file.filename == '':
        return jsonify({"success": False, "message": "No file selected for upload"}), 400

    mime_type = file.content_type
    logger.info(f"Received file: '{file.filename}', MIME: '{mime_type}'")

    if not allowed_file(file.filename, mime_type):
        return jsonify({
            "success": False,
            "message": "Invalid file format. Only JPG, JPEG, and PNG are allowed."
        }), 400

    try:
        img_bytes = file.read()

        if DEBUG_SAVE_IMAGES:
            os.makedirs(UPLOAD_FOLDER, exist_ok=True)
            debug_filename = f"img_{int(time.time() * 1000)}{os.path.splitext(file.filename)[1]}"
            debug_path = os.path.join(UPLOAD_FOLDER, debug_filename)
            with open(debug_path, "wb") as fh:
                fh.write(img_bytes)
            logger.info(f"Debug: saved image to {debug_path}")

        try:
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except OSError as exc:
            # Pillow raises OSError (incl. UnidentifiedImageError) for data it
            # cannot decode; that is a bad upload, not a server fault.
            logger.warning(f"Uploaded file could not be decoded as an image: {exc}")
            return jsonify({
                "success": False,
                "message": "Uploaded file is not a valid image"
            }), 400

        yolo_model = models_cache["yolo"]
        if yolo_model is None:
            return jsonify({"success": False, "message": "YOLO model is not loaded"}), 500

        device_str = "cuda" if device.type == "cuda" else "cpu"
        results = yolo_model.predict(
            source=img, conf=0.25, iou=0.45,
            save=False, save_txt=False, verbose=False,
            device=device_str
        )

        result = results[0]
        if len(result.boxes) == 0:
            return jsonify({"success": False, "message": "No plant detected in the image"}), 200

        detections = []
        for box in result.boxes:
            class_id = int(box.cls[0])
            detection_confidence = float(box.conf[0])
            plant_type = yolo_model.names[class_id].lower()
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

            if plant_type in models_cache["disease_models"]:
                disease_model, classes = models_cache["disease_models"][plant_type]
                cropped_img = img.crop((x1, y1, x2, y2))
                preprocess = get_preprocess(plant_type)
                img_tensor = preprocess(cropped_img).unsqueeze(0).to(device)

                with torch.no_grad():
                    outputs = disease_model(img_tensor)
                    # Batch of one: softmax over the class axis of the first row.
                    probs = torch.nn.functional.softmax(outputs[0], dim=0)
                    pred_idx = torch.argmax(probs).item()

                disease_name = classes[pred_idx]
                disease_confidence = float(probs[pred_idx].item())
            else:
                disease_name = "Unknown Disease (Unsupported Plant)"
                disease_confidence = 0.0
                logger.warning(f"No disease model cached for '{plant_type}'.")

            detections.append({
                "plant_type": plant_type,
                "detection_confidence": round(detection_confidence, 4),
                "box": [x1, y1, x2, y2],
                "disease": disease_name,
                "disease_confidence": round(disease_confidence, 4)
            })

        return jsonify({"success": True, "detections": detections}), 200

    except Exception as exc:
        logger.error(f"Inference pipeline failed: {exc}", exc_info=True)
        return jsonify({
            "success": False,
            "message": f"Server error during inference: {exc}"
        }), 500


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Kick off model loading in a background thread so the HTTP server starts immediately.
    threading.Thread(target=preload_all_models, daemon=True).start()

    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)