Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import time | |
| import logging | |
| import threading | |
| from flask import Flask, request, jsonify, send_file | |
| from flask_cors import CORS | |
| from PIL import Image | |
| import torch | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from ultralytics import YOLO | |
# PyTorch 2.6 changed the default of torch.load(weights_only=...) to True,
# which rejects fully pickled checkpoints (presumably the YOLO .pt file loaded
# below needs the old behaviour). Install a process-wide shim that restores
# weights_only=False as the *default* only: a caller that passes weights_only
# explicitly keeps its own choice.
#
# SECURITY: weights_only=False un-pickles arbitrary objects and can execute
# code embedded in a checkpoint. Only load model files from trusted sources.
#
# The _original_load guard keeps the patch idempotent if this module is
# imported more than once.
if not hasattr(torch, "_original_load"):
    import functools

    torch._original_load = torch.load

    @functools.wraps(torch._original_load)
    def patched_torch_load(*args, **kwargs):
        """Call the original torch.load with weights_only defaulting to False."""
        kwargs.setdefault("weights_only", False)
        return torch._original_load(*args, **kwargs)

    torch.load = patched_torch_load
# Process-wide logging configuration plus the application's named logger.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("PlantClassifierAPI")
# ---------------------------------------------------------------------------
# Hugging Face repository configuration
# ---------------------------------------------------------------------------
# Model weights are looked up locally first (see resolve_model_path). When a
# local file is missing or is only a Git LFS pointer, _download_from_hf tries
# these sources in order:
#   1. HF_MODEL_REPO - a dedicated HF *model* repo holding the real .pth/.pt
#      files at its root (e.g. "yourname/plantcare-models"). Skipped if unset.
#   2. SPACE_ID      - this Space's own repo, at "model/<filename>". HF Spaces
#      normally inject SPACE_ID; the hard-coded default is used otherwise.
#      Presumably this works because the Space repo keeps the real LFS objects
#      even when the Docker build context only contains pointer files.
# Set HF_TOKEN when either repository is private.
# ---------------------------------------------------------------------------
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "")
SPACE_ID = os.getenv("SPACE_ID", "Karim31003/classifier")
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf")
HF_TOKEN = os.getenv("HF_TOKEN", None)
# ---------------------------------------------------------------------------
# Flask application
# ---------------------------------------------------------------------------
app = Flask(__name__)
CORS(app)  # enable CORS on all routes using flask_cors' default settings

# Request bodies above 10 MB are rejected by Flask with HTTP 413.
app.config.update(MAX_CONTENT_LENGTH=10 * 1024 * 1024)

# DEBUG_SAVE_IMAGES=true keeps a copy of every upload under UPLOAD_FOLDER.
DEBUG_SAVE_IMAGES = os.getenv("DEBUG_SAVE_IMAGES", "false").lower() == "true"
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), "uploads")

# Uploads must match on BOTH extension and declared MIME type.
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}
ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "application/octet-stream"}
# ---------------------------------------------------------------------------
# Shared model state
# ---------------------------------------------------------------------------
# "yolo" holds the detector once loaded; "disease_models" maps a plant name to
# its (model, classes) tuple.
models_cache = {"yolo": None, "disease_models": {}}

# Per-plant load outcomes, returned by the health endpoint for debugging.
model_load_reports = []

# Inference always runs on the CPU.
device = torch.device("cpu")
logger.info(f"Using hardware device: {device}")

# Coordination for (re)loading models: the lock serialises loaders, the event
# is set whenever a load attempt finishes, and the error holds the last
# failure message (None when the last attempt succeeded).
models_lock = threading.RLock()
models_ready = threading.Event()
models_load_error = None

# Text that every Git LFS pointer file starts with.
LFS_POINTER_PREFIX = "version https://git-lfs.github.com/spec/v1"

# Weights live in ./model next to this file; fall back to this file's
# directory if that folder does not exist.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_preferred_model_dir = os.path.join(BASE_DIR, "model")
if os.path.isdir(_preferred_model_dir):
    MODEL_DIR = _preferred_model_dir
else:
    logger.warning(f"Model directory not found at {_preferred_model_dir}; falling back to {BASE_DIR}.")
    MODEL_DIR = BASE_DIR
| # --------------------------------------------------------------------------- | |
| # Git LFS / file diagnostics | |
| # --------------------------------------------------------------------------- | |
def _log_file_diagnostics(path: str, label: str) -> None:
    """Log *path*, its size in bytes and a 200-character text preview.

    A missing file is logged as a warning. If the file cannot be read, the
    read error is shown in place of the preview instead of being raised.
    """
    if os.path.exists(path):
        byte_count = os.path.getsize(path)
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as stream:
                preview = stream.read(200)
        except Exception as err:
            preview = f"<could not read: {err}>"
        details = (
            f"[DIAG] {label}\n"
            f" path : {path}\n"
            f" size : {byte_count:,} bytes\n"
            f" head : {preview!r}"
        )
        logger.info(details)
    else:
        logger.warning(f"[DIAG] {label}: file does not exist at {path}")
def _is_lfs_pointer(path: str) -> bool:
    """Tell whether *path* is a Git LFS pointer stub rather than real content.

    Missing or unreadable files are reported as not being pointers (False).
    """
    if not os.path.exists(path):
        return False
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as stream:
            leading_text = stream.read(256)
    except Exception:
        # e.g. permission errors or a directory at this path.
        return False
    return leading_text.startswith(LFS_POINTER_PREFIX)
def validate_model_artifact(model_path: str, artifact_name: str) -> None:
    """Final guard that *model_path* is a real checkpoint, not an LFS stub.

    Diagnostics for the file are always logged first, so the Space logs show
    what was actually found on disk.

    Raises:
        FileNotFoundError: *model_path* does not exist.
        RuntimeError: *model_path* is a Git LFS pointer file.
    """
    _log_file_diagnostics(model_path, artifact_name)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"{artifact_name} not found at: {model_path}")
    if not _is_lfs_pointer(model_path):
        return

    headline = f"{artifact_name} at {model_path} is a Git LFS pointer, not the actual model file.\n"
    remedies = (
        "Fix options:\n"
        " 1. Set HF_MODEL_REPO env var to your HF model repo ID so weights are downloaded automatically.\n"
        " 2. Run `git lfs pull` locally, verify files are binary, then push again.\n"
        " 3. Upload real weight files directly via the Hugging Face web UI."
    )
    raise RuntimeError(headline + remedies)
| # --------------------------------------------------------------------------- | |
| # Hugging Face model download helpers | |
| # --------------------------------------------------------------------------- | |
def _hf_hub_available() -> bool:
    """Report whether the optional huggingface_hub package can be imported.

    Logs a warning pointing at requirements.txt when it cannot.
    """
    try:
        import huggingface_hub  # noqa: F401
    except ImportError:
        logger.warning("huggingface_hub not installed — add it to requirements.txt.")
        return False
    else:
        return True
def _download_from_hf(local_filename: str) -> "str | None":
    """Fetch *local_filename* from Hugging Face and return its cached path.

    Sources are tried in this order, skipping any that is not configured:
      1. HF_MODEL_REPO, repo_type="model", file at the repo root.
      2. SPACE_ID, repo_type="space", file at "model/<local_filename>".
    Files are cached under HF_CACHE_DIR (created if needed) and reused across
    restarts.

    Returns:
        The local cached path, or None if huggingface_hub is unavailable, no
        source is configured, or every attempt fails. Each failed attempt is
        logged as a warning.
    """
    if not _hf_hub_available():
        return None
    from huggingface_hub import hf_hub_download

    os.makedirs(HF_CACHE_DIR, exist_ok=True)

    sources = []  # (repo_id, path inside the repo, repo_type)
    if HF_MODEL_REPO:
        sources.append((HF_MODEL_REPO, local_filename, "model"))
    if SPACE_ID:
        sources.append((SPACE_ID, f"model/{local_filename}", "space"))

    if not sources:
        logger.error(
            f"Cannot auto-download '{local_filename}': neither HF_MODEL_REPO nor SPACE_ID is set.\n"
            " • Set HF_MODEL_REPO to a dedicated model repo (e.g. 'yourname/plantcare-models'), OR\n"
            " • The SPACE_ID env var should be injected automatically by HF Spaces — "
            "if missing, add it manually in Space Settings → Variables."
        )
        return None

    for repo_id, repo_path, repo_type in sources:
        logger.info(f"Downloading '{repo_path}' from {repo_type} repo '{repo_id}' → {HF_CACHE_DIR} ...")
        try:
            cached_path = hf_hub_download(
                repo_id=repo_id,
                filename=repo_path,
                repo_type=repo_type,
                cache_dir=HF_CACHE_DIR,
                token=HF_TOKEN,
                force_download=False,  # reuse the cached copy on restarts
            )
            byte_count = os.path.getsize(cached_path)
        except Exception as exc:
            logger.warning(f"Download attempt failed for '{repo_path}' from '{repo_id}' ({repo_type}): {exc}")
            continue
        logger.info(f"Downloaded '{local_filename}' → {cached_path} ({byte_count:,} bytes)")
        return cached_path

    logger.error(f"All HF download attempts failed for '{local_filename}'.")
    return None
def _is_usable_model_file(path: str) -> bool:
    """Return True if *path* is a non-empty regular file and not a Git LFS pointer."""
    return os.path.isfile(path) and os.path.getsize(path) > 0 and not _is_lfs_pointer(path)


def resolve_model_path(local_filename: str) -> "str | None":
    """Return a usable on-disk path for the checkpoint *local_filename*, or None.

    Resolution order:
      1. MODEL_DIR/<local_filename>, if it is a usable file.
      2. A copy downloaded by _download_from_hf (HF model repo, then the Space repo).

    "Usable" means the file exists, is non-empty and is not a Git LFS pointer.
    An empty file can never be a valid checkpoint, so it is treated like a
    missing one. Without this check it would be returned here and only fail
    later inside torch.load, with no download attempted.
    """
    local_path = os.path.join(MODEL_DIR, local_filename)

    # Fast path: a real local binary.
    if _is_usable_model_file(local_path):
        logger.info(f"Using local model file: {local_path}")
        return local_path

    # Explain why the local copy is unusable before trying the network.
    _log_file_diagnostics(local_path, local_filename)
    if _is_lfs_pointer(local_path):
        logger.warning(f"'{local_filename}' is a Git LFS pointer — downloading real weights from HF.")
    elif os.path.exists(local_path):
        logger.warning(f"'{local_filename}' is empty or not a regular file — attempting HF download.")
    else:
        logger.warning(f"'{local_filename}' not found locally — attempting HF download.")

    downloaded = _download_from_hf(local_filename)
    if downloaded and _is_usable_model_file(downloaded):
        return downloaded
    if downloaded:
        logger.error(f"HF download for '{local_filename}' produced another LFS pointer or empty file.")
    return None
def find_model_path(base_dir, *candidates):
    """Return the path of the first usable checkpoint among *candidates*.

    Names are tried in order. For each one, resolve_model_path (MODEL_DIR
    lookup, then a possible Hugging Face download) is consulted first. A
    plain, non-LFS file of that name in *base_dir* is the fallback. A warning
    is logged when a resolved file is not the first (preferred) candidate.

    Raises:
        FileNotFoundError: no candidate yields a usable, non-LFS file.
    """
    preferred = candidates[0] if candidates else None
    for name in candidates:
        resolved = resolve_model_path(name)
        if resolved:
            if name != preferred:
                logger.warning(f"Using fallback model file '{name}' instead of '{preferred}'")
            return resolved

        # Covers the case where MODEL_DIR differs from base_dir.
        direct_path = os.path.join(base_dir, name)
        if os.path.exists(direct_path) and not _is_lfs_pointer(direct_path):
            return direct_path

    raise FileNotFoundError(
        f"No valid (non-LFS) model file found in {base_dir}. Tried: {', '.join(candidates)}"
    )
| # --------------------------------------------------------------------------- | |
| # Model architecture definitions | |
| # --------------------------------------------------------------------------- | |
class MangoCNN(torch.nn.Module):
    """Four-stage CNN used as the mango disease classifier.

    Each stage is a 3x3 convolution (padding 1), batch-norm, ReLU and 2x2
    max-pool, so the four stages shrink the spatial size by a factor of 16.
    fc1 expects 256 * 14 * 14 features, which corresponds to 224x224 inputs
    (the size get_preprocess uses for mango). The attribute names are the
    saved state-dict keys and must not change.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        nn = torch.nn
        # Registration order (conv1, bn1, conv2, ...) matches the checkpoint layout.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        relu = torch.nn.functional.relu
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in stages:
            x = self.pool(relu(norm(conv(x))))
        flat = x.view(x.size(0), -1)
        hidden = relu(self.fc1(flat))
        return self.fc2(hidden)
def get_preprocess(plant_type):
    """Build the transform applied to a crop before disease classification.

    Potato crops are resized to 380x380 and every other plant to 224x224.
    Crops are then converted to tensors and normalised with the standard
    ImageNet channel mean/std.
    """
    side = 380 if plant_type == "potato" else 224
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
# Per plant: checkpoint file names (preferred first, then fallback) and the
# architecture key understood by _build_disease_model. "soyabean" is accepted
# as an alternate spelling of "soybean".
_DISEASE_MODEL_SPECS = {
    "apple": (("apple_efficientnet_best.pth", "apple_model.pth"), "efficientnet_b0"),
    "tomato": (("tomato_efficientnet_best.pth", "tomato_model.pth"), "efficientnet_b0"),
    "potato": (("potato_efficientnet_best.pth", "potato_model.pth"), "efficientnet_b4"),
    "cucumber": (("best_cucumber_resnet18.pth", "best_model.pth"), "resnet18"),
    "lemon": (("best_lemon_model_ema.pth", "lemon_model.pth"), "efficientnet_b0_timm"),
    "mango": (("best_mango_model.pth", "mango_model.pth"), "mango_cnn"),
    "soybean": (("soybean_disease_model_efficientnetb2.pth", "soybean_model.pth"), "efficientnet_b2"),
}
_DISEASE_MODEL_SPECS["soyabean"] = _DISEASE_MODEL_SPECS["soybean"]

# Class lists used when a checkpoint does not carry its own.
_FALLBACK_DISEASE_CLASSES = {
    "cucumber": (
        'Anthracnose', 'Bacterial Wilt', 'Belly Rot', 'Downy Mildew',
        'Fresh Cucumber', 'Fresh Leaf', 'Gummy Stem Blight', 'Pythium Fruit Rot',
    ),
    "lemon": (
        'Anthracnose', 'Bacterial Blight', 'Citrus Canker', 'Curl Virus',
        'Deficiency Leaf', 'Dry Leaf', 'Healthy Leaf', 'Sooty Mould', 'Spider Mites',
    ),
    "mango": (
        'Anthracnose', 'Bacterial Canker', 'Cutting Weevil', 'Die Back',
        'Gall Midge', 'Healthy', 'Powdery Mildew', 'Sooty Mould',
    ),
    "soybean": (
        'Bacterial Blight', 'Cercospora Leaf Blight', 'Healthy',
        'Soybean Rust', 'Sudden Death Syndrome',
    ),
}
_FALLBACK_DISEASE_CLASSES["soyabean"] = _FALLBACK_DISEASE_CLASSES["soybean"]


def _resolve_disease_classes(checkpoint, plant_type):
    """Return the class names stored in *checkpoint*, else a fresh copy of the plant's fallback list (possibly empty)."""
    if isinstance(checkpoint, dict):
        for key in ("classes", "class_names"):
            if key in checkpoint:
                return checkpoint[key]
    return list(_FALLBACK_DISEASE_CLASSES.get(plant_type, ()))


def _extract_state_dict(checkpoint):
    """Return a bare, loadable state dict from a raw or wrapped checkpoint.

    Accepts:
    - a plain state dict;
    - a dict wrapping one under "model_state_dict" or "state_dict";
    - a pickled nn.Module.

    Strips a "module." key prefix (as written by DataParallel / AveragedModel)
    and drops the "n_averaged" counter that AveragedModel (EMA/SWA) checkpoints
    contain.

    Raises:
        TypeError: the checkpoint does not contain a state dict.
    """
    state = checkpoint
    if isinstance(state, torch.nn.Module):
        state = state.state_dict()
    if isinstance(state, dict):
        for key in ("model_state_dict", "state_dict"):
            if isinstance(state.get(key), dict):
                state = state[key]
                break
    if not isinstance(state, dict):
        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
        if k != "n_averaged"
    }


def _build_disease_model(model_type, num_classes):
    """Instantiate an untrained network of *model_type* with *num_classes* outputs; None if the type is unknown."""
    if model_type in ("efficientnet_b0", "efficientnet_b2"):
        factory = {"efficientnet_b0": models.efficientnet_b0, "efficientnet_b2": models.efficientnet_b2}[model_type]
        model = factory(weights=None)
        model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
        return model
    if model_type == "efficientnet_b4":
        import timm
        return timm.create_model('efficientnet_b4', pretrained=False, num_classes=num_classes)
    if model_type == "efficientnet_b0_timm":
        import timm
        model = timm.create_model('efficientnet_b0', pretrained=False)
        model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(1280, num_classes),
        )
        return model
    if model_type == "resnet18":
        model = models.resnet18(weights=None)
        model.fc = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, num_classes),
        )
        return model
    if model_type == "mango_cnn":
        return MangoCNN(num_classes=num_classes)
    return None


def load_disease_model(plant_type, base_dir):
    """Load the disease classifier for *plant_type* (case-insensitive).

    Returns:
        (model, classes): the model on `device` in eval mode and its class
        names, or (None, None) if the plant type is not supported.

    Raises:
        FileNotFoundError / RuntimeError: no usable checkpoint file was found
            (from find_model_path / validate_model_artifact).
        ValueError: neither the checkpoint nor the fallback table provides
            class names. This is reported explicitly; otherwise it would
            surface as an opaque 0-output size mismatch.
        TypeError: the checkpoint does not contain a state dict.
    """
    plant_type = plant_type.lower()
    spec = _DISEASE_MODEL_SPECS.get(plant_type)
    if spec is None:
        logger.warning(f"No disease classification mapping defined for plant type: {plant_type}")
        return None, None
    filenames, model_type = spec

    model_path = find_model_path(base_dir, *filenames)
    # Safety net: find_model_path already rejects LFS pointers.
    validate_model_artifact(model_path, f"{plant_type} disease model")

    logger.info(f"Loading {plant_type} disease model from {os.path.basename(model_path)} ...")
    checkpoint = torch.load(model_path, map_location="cpu")

    classes = _resolve_disease_classes(checkpoint, plant_type)
    if not classes:
        raise ValueError(
            f"No class names for {plant_type}: checkpoint {os.path.basename(model_path)} "
            "has no 'classes'/'class_names' entry and no fallback list is defined."
        )

    model = _build_disease_model(model_type, len(classes))
    if model is None:
        logger.error(f"Unknown model architecture type: {model_type}")
        return None, None
    model.load_state_dict(_extract_state_dict(checkpoint))

    model = model.to(device)
    model.eval()
    logger.info(f"Successfully loaded {plant_type} model ({len(classes)} classes).")
    return model, classes
| # --------------------------------------------------------------------------- | |
| # Startup model pre-loading | |
| # --------------------------------------------------------------------------- | |
def preload_all_models():
    """Load the YOLO detector and every plant disease classifier into models_cache.

    Runs while holding models_lock and records per-plant outcomes in
    model_load_reports. A YOLO failure is fatal for the attempt:
    models_load_error is set and no classifiers are loaded. A failing
    classifier is only reported. models_ready is set when the attempt
    finishes, whatever the outcome.

    The detector is published to models_cache["yolo"] only after every
    classifier has been attempted. ensure_models_loaded's lock-free fast path
    treats a non-None detector as "fully loaded". Publishing it earlier would
    let requests run while classifiers are still loading and mislabel those
    plants as unsupported.
    """
    global models_load_error, model_load_reports
    with models_lock:
        start_time = time.time()
        logger.info("=== Starting model pre-load ===")
        logger.info(f"HF_MODEL_REPO : '{HF_MODEL_REPO or '(not set)'}'")
        logger.info(f"SPACE_ID : '{SPACE_ID or '(not set)'}'")
        logger.info(f"HF_CACHE_DIR : '{HF_CACHE_DIR}'")
        model_load_reports = []
        try:
            model_files = os.listdir(MODEL_DIR)
        except OSError:
            model_files = []
        logger.info(f"MODEL_DIR: {MODEL_DIR} | files: {model_files}")

        # 1. YOLO detector, held in a local until the classifiers are done.
        try:
            yolo_path = find_model_path(MODEL_DIR, "best.pt", "detector.pt")
            validate_model_artifact(yolo_path, "YOLO model")
            logger.info(f"Loading YOLO model from {yolo_path} ...")
            detector = YOLO(yolo_path)
            logger.info("YOLO model loaded successfully.")
        except Exception as exc:
            models_load_error = str(exc)
            logger.critical(f"YOLO model loading failed: {exc}", exc_info=True)
            models_ready.set()
            return

        # 2. Plant disease classifiers; a failure only affects that plant.
        supported_plants = ["apple", "tomato", "potato", "cucumber", "lemon", "mango", "soybean"]
        for plant in supported_plants:
            try:
                model, classes = load_disease_model(plant, MODEL_DIR)
                if model is not None:
                    models_cache["disease_models"][plant] = (model, classes)
                    if plant == "soybean":
                        # Also register the alternate spelling the detector may emit.
                        models_cache["disease_models"]["soyabean"] = (model, classes)
                    model_load_reports.append({
                        "plant": plant,
                        "status": "loaded",
                        "classes_count": len(classes)
                    })
            except Exception as exc:
                logger.error(f"Failed to load model for {plant}: {exc}", exc_info=True)
                model_load_reports.append({
                    "plant": plant,
                    "status": "error",
                    "error": str(exc)
                })

        # Clear any stale error first, then publish the detector. From here on
        # ensure_models_loaded's fast path succeeds.
        models_load_error = None
        models_cache["yolo"] = detector

        elapsed = time.time() - start_time
        loaded = [r["plant"] for r in model_load_reports if r["status"] == "loaded"]
        failed = [r["plant"] for r in model_load_reports if r["status"] == "error"]
        logger.info(
            f"=== Model pre-load done in {elapsed:.1f}s | "
            f"loaded: {loaded} | failed: {failed} ==="
        )
        models_ready.set()
def ensure_models_loaded():
    """Make sure the models are available, loading them synchronously if needed.

    Returns True immediately when the YOLO detector is already cached.
    Otherwise it takes models_lock, which blocks while another thread is
    inside preload_all_models, re-checks, and runs preload_all_models itself.
    It then returns True only if the detector ended up loaded and no load
    error was recorded. An exception escaping the loader is stored in
    models_load_error and reported as False.
    """
    global models_load_error

    def _detector_ready():
        if models_cache["yolo"] is None:
            return False
        models_ready.set()
        return True

    if _detector_ready():
        return True

    with models_lock:
        if _detector_ready():
            return True
        try:
            preload_all_models()
        except Exception as exc:
            models_load_error = str(exc)
            logger.error(f"Model loading failed: {exc}", exc_info=True)
            return False
        return models_cache["yolo"] is not None and models_load_error is None
| # --------------------------------------------------------------------------- | |
| # Flask helpers | |
| # --------------------------------------------------------------------------- | |
def allowed_file(filename, mime_type):
    """Accept an upload only when both its extension and MIME type are whitelisted.

    The extension check is case-insensitive. Rejected uploads are logged.
    """
    extension = os.path.splitext(filename)[1].lower()
    if extension in ALLOWED_EXTENSIONS and mime_type in ALLOWED_MIME_TYPES:
        return True
    logger.warning(f"File rejected — name: '{filename}', ext: '{extension}', mime: '{mime_type}'")
    return False
| # --------------------------------------------------------------------------- | |
| # Flask routes | |
| # --------------------------------------------------------------------------- | |
# Without this registration Flask returns its default HTML 413 page instead of
# this JSON body when an upload exceeds MAX_CONTENT_LENGTH.
@app.errorhandler(413)
def request_entity_too_large(error):
    """Return a JSON 413 response for request bodies over the 10 MB limit."""
    return jsonify({"success": False, "message": "File size exceeds the 10 MB limit"}), 413
# NOTE(review): this view had no route decorator, so it was never served.
# "/" is assumed from its purpose; confirm against the deployed client.
@app.route("/")
def index():
    """Serve the single-page web UI (main.html)."""
    return send_file("main.html")
# NOTE(review): this view had no route decorator, so it was never served.
# "/health" is assumed; confirm the path the frontend / monitors poll.
@app.route("/health", methods=["GET"])
def health_check():
    """Report model-loading state and configuration as JSON (always HTTP 200).

    "status" is "healthy" once the detector is loaded with no recorded error,
    "error" when the last load attempt recorded an error, and "starting"
    otherwise. A failed load used to be reported as "starting" with
    models_loading=True indefinitely.
    """
    loaded_disease_models = list(models_cache["disease_models"].keys())
    # "soyabean" is registered as an alias of the same "soybean" model; list it once.
    if "soyabean" in loaded_disease_models and "soybean" in loaded_disease_models:
        loaded_disease_models.remove("soyabean")

    yolo_loaded = models_cache["yolo"] is not None
    if yolo_loaded and models_load_error is None:
        status = "healthy"
    elif models_load_error is not None:
        status = "error"
    else:
        status = "starting"

    return jsonify({
        "status": status,
        "device": str(device),
        "yolo_loaded": yolo_loaded,
        "loaded_disease_models": loaded_disease_models,
        "model_load_reports": model_load_reports,
        "debug_save_images": DEBUG_SAVE_IMAGES,
        "models_loading": not yolo_loaded and models_load_error is None,
        "load_error": models_load_error,
        "hf_model_repo": HF_MODEL_REPO or "(not configured)",
        "space_id": SPACE_ID or "(not detected)",
    }), 200
def _classify_detection(img, box_xyxy, plant_type):
    """Classify the disease in one detected region of *img*.

    Returns (disease_name, confidence). Plants without a cached classifier
    yield ("Unknown Disease (Unsupported Plant)", 0.0) and log a warning.
    """
    cached = models_cache["disease_models"].get(plant_type)
    if cached is None:
        logger.warning(f"No disease model cached for '{plant_type}'.")
        return "Unknown Disease (Unsupported Plant)", 0.0

    disease_model, classes = cached
    crop = img.crop(box_xyxy)
    img_tensor = get_preprocess(plant_type)(crop).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = disease_model(img_tensor)
    probs = torch.nn.functional.softmax(outputs[0], dim=0)
    pred_idx = torch.argmax(probs).item()
    return classes[pred_idx], float(probs[pred_idx].item())


# NOTE(review): this view had no route decorator, so it was never served.
# "/predict" (POST, multipart field "image") is assumed; confirm against the
# client in main.html.
@app.route("/predict", methods=["POST"])
def predict():
    """Two-stage inference: YOLO plant detection, then per-plant disease classification.

    Expects a multipart upload under form key "image" (JPG/PNG).

    Responses:
        200 with {"success": True, "detections": [...]} on success, or
            {"success": False, ...} when no plant is detected.
        400 for a missing, disallowed or undecodable file.
        503 when the models could not be loaded.
        500 for unexpected inference errors.
    """
    if not ensure_models_loaded():
        return jsonify({
            "success": False,
            "message": models_load_error or "Models are still loading. Please retry shortly."
        }), 503

    if 'image' not in request.files:
        return jsonify({"success": False, "message": "No file in form-data key 'image'"}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({"success": False, "message": "No file selected for upload"}), 400

    mime_type = file.content_type
    logger.info(f"Received file: '{file.filename}', MIME: '{mime_type}'")
    if not allowed_file(file.filename, mime_type):
        return jsonify({
            "success": False,
            "message": "Invalid file format. Only JPG, JPEG, and PNG are allowed."
        }), 400

    try:
        img_bytes = file.read()
        if DEBUG_SAVE_IMAGES:
            os.makedirs(UPLOAD_FOLDER, exist_ok=True)
            debug_filename = f"img_{int(time.time() * 1000)}{os.path.splitext(file.filename)[1]}"
            debug_path = os.path.join(UPLOAD_FOLDER, debug_filename)
            with open(debug_path, "wb") as fh:
                fh.write(img_bytes)
            logger.info(f"Debug: saved image to {debug_path}")

        # Bad client data is a 400, not a server error. PIL's
        # UnidentifiedImageError and truncated-image errors are OSError subclasses.
        try:
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except OSError as exc:
            logger.warning(f"Could not decode uploaded image '{file.filename}': {exc}")
            return jsonify({
                "success": False,
                "message": "Uploaded file is not a valid JPG or PNG image."
            }), 400

        yolo_model = models_cache["yolo"]
        if yolo_model is None:
            return jsonify({"success": False, "message": "YOLO model is not loaded"}), 500

        device_str = "cuda" if device.type == "cuda" else "cpu"
        results = yolo_model.predict(
            source=img, conf=0.25, iou=0.45,
            save=False, save_txt=False, verbose=False, device=device_str
        )
        result = results[0]
        if len(result.boxes) == 0:
            return jsonify({"success": False, "message": "No plant detected in the image"}), 200

        detections = []
        for box in result.boxes:
            class_id = int(box.cls[0])
            detection_confidence = float(box.conf[0])
            plant_type = yolo_model.names[class_id].lower()
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            disease_name, disease_confidence = _classify_detection(img, (x1, y1, x2, y2), plant_type)
            detections.append({
                "plant_type": plant_type,
                "detection_confidence": round(detection_confidence, 4),
                "box": [x1, y1, x2, y2],
                "disease": disease_name,
                "disease_confidence": round(disease_confidence, 4)
            })
        return jsonify({"success": True, "detections": detections}), 200
    except Exception as exc:
        logger.error(f"Inference pipeline failed: {exc}", exc_info=True)
        return jsonify({
            "success": False,
            "message": f"Server error during inference: {exc}"
        }), 500
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Load models on a daemon thread so the HTTP server starts accepting
    # requests right away instead of waiting for all weights.
    loader = threading.Thread(target=preload_all_models, daemon=True)
    loader.start()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)