# classifier / app.py
# Last commit by Karim: 25fd2d7 "Updating files to the space"
# (Copied from the Hugging Face file viewer: Raw | History | Blame | Contribute | Delete, 27.1 kB)
import os
import io
import time
import logging
import threading
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from PIL import Image
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from ultralytics import YOLO
# Fix PyTorch 2.6+ weights_only loading issue.
# Restore weights_only=False as the *default* for every torch.load call in the
# process, so checkpoints that pickle Python objects alongside tensors load.
# Callers that explicitly pass weights_only=... keep their choice.
# SECURITY: weights_only=False unpickles arbitrary objects; only load
# checkpoints from trusted sources (the local model dir / our own HF repos).
if not hasattr(torch, "_original_load"):
    torch._original_load = torch.load

    def patched_torch_load(*args, **kwargs):
        """Call the original torch.load, defaulting ``weights_only`` to False."""
        kwargs.setdefault("weights_only", False)
        return torch._original_load(*args, **kwargs)

    torch.load = patched_torch_load
# Configure the root logger at INFO level with timestamped, level-tagged lines,
# then create the named logger used throughout this module.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("PlantClassifierAPI")
# ---------------------------------------------------------------------------
# Hugging Face repository configuration
# ---------------------------------------------------------------------------
# PRIMARY: Set HF_MODEL_REPO to a dedicated HF *model* repo containing the
# real binary .pth / .pt files (e.g. "yourname/plantcare-models").
#
# FALLBACK: If HF_MODEL_REPO is not set, the app will try to download the
# weights from the Space's own repository using SPACE_ID (which HF
# injects automatically into every running Space). This works because
# HF stores the real LFS bytes even though the Docker build context
# only receives the pointer files.
#
# Either way, set HF_TOKEN for private repos.
#
# All values are whitespace-stripped, and blank values count as "not set":
# a blank HF_TOKEN becomes None instead of being passed to huggingface_hub
# as an empty-string token, and a blank HF_CACHE_DIR falls back to /tmp/hf
# instead of an empty path.
# ---------------------------------------------------------------------------
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "").strip()
# SPACE_ID is auto-injected by HF Spaces; fall back to the known Space ID.
SPACE_ID = os.environ.get("SPACE_ID", "Karim31003/classifier").strip()
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "").strip() or "/tmp/hf"
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip() or None
# ---------------------------------------------------------------------------
# Flask application setup
app = Flask(__name__)
CORS(app)  # enable CORS on every route

# Request bodies above this size are refused by Flask with HTTP 413.
_MAX_UPLOAD_BYTES = 10 * 1024 * 1024
app.config["MAX_CONTENT_LENGTH"] = _MAX_UPLOAD_BYTES

# DEBUG_SAVE_IMAGES=true (any letter case) keeps a copy of each upload in
# UPLOAD_FOLDER, an "uploads" directory next to this file.
DEBUG_SAVE_IMAGES = os.environ.get("DEBUG_SAVE_IMAGES", "false").lower() == "true"
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), "uploads")

# Upload allow-lists consulted by allowed_file().
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}
ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "application/octet-stream"}
# Process-wide model registry: the YOLO detector plus a
# plant name -> (model, class_names) map for the disease classifiers.
models_cache = dict(yolo=None, disease_models={})
# Per-plant load outcomes, exposed by /health for debugging.
model_load_reports = []

# Inference always runs on the CPU.
device = torch.device("cpu")
logger.info(f"Using hardware device: {device}")

# Load coordination: the re-entrant lock serialises load attempts, the event
# is set once an attempt has finished, and the string records a load failure.
models_lock = threading.RLock()
models_ready = threading.Event()
models_load_error = None

# First line of every Git LFS pointer file.
LFS_POINTER_PREFIX = "version https://git-lfs.github.com/spec/v1"

# Weights normally live in ./model next to this file; use the file's own
# directory when that folder is absent.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_preferred_model_dir = os.path.join(BASE_DIR, "model")
if os.path.isdir(_preferred_model_dir):
    MODEL_DIR = _preferred_model_dir
else:
    logger.warning(f"Model directory not found at {_preferred_model_dir}; falling back to {BASE_DIR}.")
    MODEL_DIR = BASE_DIR
# ---------------------------------------------------------------------------
# Git LFS / file diagnostics
# ---------------------------------------------------------------------------
def _log_file_diagnostics(path: str, label: str) -> None:
    """Emit one diagnostic log entry describing the file at *path*.

    For an existing file, logs its path, byte size and a repr of its first 200
    characters (read as UTF-8 with undecodable bytes dropped). A missing file
    produces a warning instead. Read errors are reported inside the entry.
    """
    if os.path.exists(path):
        byte_count = os.path.getsize(path)
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as handle:
                preview = handle.read(200)
        except Exception as err:
            preview = f"<could not read: {err}>"
        logger.info(
            f"[DIAG] {label}\n"
            f" path : {path}\n"
            f" size : {byte_count:,} bytes\n"
            f" head : {preview!r}"
        )
    else:
        logger.warning(f"[DIAG] {label}: file does not exist at {path}")
def _is_lfs_pointer(path: str) -> bool:
    """Tell whether *path* is a Git LFS pointer stub instead of real content.

    Looks for LFS_POINTER_PREFIX at the start of the first 256 characters.
    Missing or unreadable paths are reported as not-a-pointer (False).
    """
    if not os.path.exists(path):
        return False
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as stub:
            leading_text = stub.read(256)
    except Exception:
        return False
    return leading_text.startswith(LFS_POINTER_PREFIX)
def validate_model_artifact(model_path: str, artifact_name: str) -> None:
    """Refuse a missing model file or an un-fetched Git LFS pointer.

    Diagnostics for *model_path* are always logged first, so the Space logs
    show what was found even when this raises.

    Raises:
        FileNotFoundError: *model_path* does not exist.
        RuntimeError: *model_path* is a Git LFS pointer; the message lists
            remediation options.
    """
    _log_file_diagnostics(model_path, artifact_name)
    missing = not os.path.exists(model_path)
    if missing:
        raise FileNotFoundError(f"{artifact_name} not found at: {model_path}")
    if not _is_lfs_pointer(model_path):
        return
    raise RuntimeError(
        f"{artifact_name} at {model_path} is a Git LFS pointer, not the actual model file.\n"
        + "Fix options:\n"
        + " 1. Set HF_MODEL_REPO env var to your HF model repo ID so weights are downloaded automatically.\n"
        + " 2. Run `git lfs pull` locally, verify files are binary, then push again.\n"
        + " 3. Upload real weight files directly via the Hugging Face web UI."
    )
# ---------------------------------------------------------------------------
# Hugging Face model download helpers
# ---------------------------------------------------------------------------
def _hf_hub_available() -> bool:
    """Report whether the optional ``huggingface_hub`` package can be imported.

    Logs a warning suggesting requirements.txt when it cannot.
    """
    try:
        import huggingface_hub  # noqa: F401  (imported only to probe availability)
    except ImportError:
        logger.warning("huggingface_hub not installed — add it to requirements.txt.")
        available = False
    else:
        available = True
    return available
def _download_from_hf(local_filename: str) -> "str | None":
    """
    Download *local_filename* from the best available HF source.

    Sources are tried in priority order:
      1. HF_MODEL_REPO: dedicated model repo, file at repo root
         (repo_type="model").
      2. SPACE_ID: the Space's own repo, path inside repo "model/<filename>"
         (repo_type="space"). HF stores the real LFS bytes there even though
         the Docker build context only receives pointer files.

    A source that fails to download, or that yields a Git LFS pointer instead
    of real weights, is logged and the next source is tried.

    Returns:
        The local cached path on success; None if huggingface_hub is missing,
        the cache directory cannot be created, or every source failed.
    """
    if not _hf_hub_available():
        return None
    from huggingface_hub import hf_hub_download

    try:
        os.makedirs(HF_CACHE_DIR, exist_ok=True)
    except OSError as exc:
        logger.error(f"Cannot create HF cache directory '{HF_CACHE_DIR}': {exc}")
        return None

    # Build candidate list: (repo_id, repo_path, repo_type)
    candidates = []
    if HF_MODEL_REPO:
        candidates.append((HF_MODEL_REPO, local_filename, "model"))
    if SPACE_ID:
        candidates.append((SPACE_ID, f"model/{local_filename}", "space"))
    if not candidates:
        logger.error(
            f"Cannot auto-download '{local_filename}': neither HF_MODEL_REPO nor SPACE_ID is set.\n"
            " • Set HF_MODEL_REPO to a dedicated model repo (e.g. 'yourname/plantcare-models'), OR\n"
            " • The SPACE_ID env var should be injected automatically by HF Spaces — "
            "if missing, add it manually in Space Settings → Variables."
        )
        return None

    for repo_id, repo_path, repo_type in candidates:
        try:
            logger.info(
                f"Downloading '{repo_path}' from {repo_type} repo '{repo_id}' → {HF_CACHE_DIR} ..."
            )
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=repo_path,
                repo_type=repo_type,
                cache_dir=HF_CACHE_DIR,
                token=HF_TOKEN,
                force_download=False,  # reuse cached copy on restarts
            )
            is_pointer = _is_lfs_pointer(local_path)
            size = os.path.getsize(local_path)
        except Exception as exc:
            logger.warning(
                f"Download attempt failed for '{repo_path}' from '{repo_id}' ({repo_type}): {exc}"
            )
            continue
        if is_pointer:
            # A pointer committed as a regular file: the caller would reject it
            # anyway, so keep looking at the remaining sources instead.
            logger.warning(
                f"'{repo_path}' from {repo_type} repo '{repo_id}' is a Git LFS pointer, "
                "not real weights; trying the next source."
            )
            continue
        logger.info(f"Downloaded '{local_filename}' → {local_path} ({size:,} bytes)")
        return local_path

    logger.error(f"All HF download attempts failed for '{local_filename}'.")
    return None
def resolve_model_path(local_filename: str) -> "str | None":
    """
    Return the best usable path for *local_filename*, or None.

    A path is usable when it is a regular, non-empty file that is not a Git
    LFS pointer. Lookup order:
      1. The local copy in MODEL_DIR.
      2. A copy downloaded from HF (model repo or Space repo).
      3. None if neither is usable.
    """

    def _usable(path):
        # Real weights are never zero bytes, a directory, or an LFS pointer stub.
        return (
            path is not None
            and os.path.isfile(path)
            and os.path.getsize(path) > 0
            and not _is_lfs_pointer(path)
        )

    local_path = os.path.join(MODEL_DIR, local_filename)
    # Fast path — usable local binary exists
    if _usable(local_path):
        logger.info(f"Using local model file: {local_path}")
        return local_path

    # Diagnostics before attempting download
    _log_file_diagnostics(local_path, local_filename)
    if _is_lfs_pointer(local_path):
        logger.warning(f"'{local_filename}' is a Git LFS pointer — downloading real weights from HF.")
    elif os.path.exists(local_path):
        logger.warning(
            f"'{local_filename}' exists locally but is empty or not a regular file — attempting HF download."
        )
    else:
        logger.warning(f"'{local_filename}' not found locally — attempting HF download.")

    downloaded = _download_from_hf(local_filename)
    if _usable(downloaded):
        return downloaded
    if downloaded:
        logger.error(f"HF download for '{local_filename}' produced another LFS pointer or empty file.")
    return None
def find_model_path(base_dir, *candidates):
    """Return a usable path for the first candidate file name that resolves.

    Names are tried in order. Each goes first through resolve_model_path
    (local MODEL_DIR copy, then HF download), then is checked as a plain
    non-LFS file under *base_dir*. A warning is logged when a name other than
    the first wins via resolve_model_path.

    Raises:
        FileNotFoundError: no candidate yields a usable file.
    """
    preferred = candidates[0] if candidates else None
    for name in candidates:
        resolved = resolve_model_path(name)
        if resolved:
            if name != preferred:
                logger.warning(f"Using fallback model file '{name}' instead of '{preferred}'")
            return resolved
        # Covers edge-cases where MODEL_DIR != base_dir
        direct = os.path.join(base_dir, name)
        if os.path.exists(direct) and not _is_lfs_pointer(direct):
            return direct
    tried = ", ".join(candidates)
    raise FileNotFoundError(
        f"No valid (non-LFS) model file found in {base_dir}. Tried: {tried}"
    )
# ---------------------------------------------------------------------------
# Model architecture definitions
# ---------------------------------------------------------------------------
class MangoCNN(torch.nn.Module):
    """Four-block convolutional classifier used for the mango disease model.

    Each block is conv3x3 -> batch-norm -> ReLU -> 2x2 max-pool, growing the
    channels 3 -> 32 -> 64 -> 128 -> 256. Two fully connected layers follow.
    fc1 expects a 256 x 14 x 14 feature map, which corresponds to 224 x 224
    inputs (224 / 2**4 = 14).

    Attribute names and their creation order must stay as they are: they
    determine the keys of the state dict loaded from the checkpoint.
    """

    def __init__(self, num_classes=8):
        super().__init__()
        nn = torch.nn
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        relu = torch.nn.functional.relu
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in stages:
            x = self.pool(relu(norm(conv(x))))
        features = torch.flatten(x, 1)
        return self.fc2(relu(self.fc1(features)))
def get_preprocess(plant_type):
    """Build the input transform pipeline for *plant_type*'s classifier.

    Potato uses 380 x 380 inputs and every other plant uses 224 x 224. The
    resized image is converted to a tensor and normalised with the standard
    ImageNet channel mean/std.
    """
    if plant_type == "potato":
        edge = 380
    else:
        edge = 224
    steps = [
        transforms.Resize((edge, edge)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Per-plant checkpoint lookup: (preferred file, fallback file, architecture tag).
_DISEASE_MODEL_SPECS = {
    "apple": ("apple_efficientnet_best.pth", "apple_model.pth", "efficientnet_b0"),
    "tomato": ("tomato_efficientnet_best.pth", "tomato_model.pth", "efficientnet_b0"),
    "potato": ("potato_efficientnet_best.pth", "potato_model.pth", "efficientnet_b4"),
    "cucumber": ("best_cucumber_resnet18.pth", "best_model.pth", "resnet18"),
    "lemon": ("best_lemon_model_ema.pth", "lemon_model.pth", "efficientnet_b0_timm"),
    "mango": ("best_mango_model.pth", "mango_model.pth", "mango_cnn"),
    "soybean": ("soybean_disease_model_efficientnetb2.pth", "soybean_model.pth", "efficientnet_b2"),
}

# Alternative spellings accepted for plant names.
_PLANT_ALIASES = {"soyabean": "soybean"}

# Class names for checkpoints saved without a "classes"/"class_names" entry.
_DEFAULT_DISEASE_CLASSES = {
    "cucumber": [
        'Anthracnose', 'Bacterial Wilt', 'Belly Rot', 'Downy Mildew',
        'Fresh Cucumber', 'Fresh Leaf', 'Gummy Stem Blight', 'Pythium Fruit Rot'
    ],
    "lemon": [
        'Anthracnose', 'Bacterial Blight', 'Citrus Canker', 'Curl Virus',
        'Deficiency Leaf', 'Dry Leaf', 'Healthy Leaf', 'Sooty Mould', 'Spider Mites'
    ],
    "mango": [
        'Anthracnose', 'Bacterial Canker', 'Cutting Weevil', 'Die Back',
        'Gall Midge', 'Healthy', 'Powdery Mildew', 'Sooty Mould'
    ],
    "soybean": [
        'Bacterial Blight', 'Cercospora Leaf Blight', 'Healthy',
        'Soybean Rust', 'Sudden Death Syndrome'
    ],
}


def _extract_state_dict(checkpoint):
    """Return the parameter dict held by *checkpoint*.

    Accepts a bare state dict, or a wrapper dict that stores it under
    "model_state_dict" / "state_dict" next to metadata such as class names.
    Anything else is returned unchanged.
    """
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _build_disease_architecture(model_type, num_classes):
    """Instantiate the untrained network for *model_type*; None if the tag is unknown."""
    torchvision_efficientnets = {
        "efficientnet_b0": models.efficientnet_b0,
        "efficientnet_b2": models.efficientnet_b2,
    }
    if model_type in torchvision_efficientnets:
        model = torchvision_efficientnets[model_type](weights=None)
        in_features = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(in_features, num_classes)
        return model
    if model_type == "efficientnet_b4":
        import timm
        return timm.create_model('efficientnet_b4', pretrained=False, num_classes=num_classes)
    if model_type == "efficientnet_b0_timm":
        import timm
        model = timm.create_model('efficientnet_b0', pretrained=False)
        model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(1280, num_classes),
        )
        return model
    if model_type == "resnet18":
        model = models.resnet18(weights=None)
        model.fc = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, num_classes),
        )
        return model
    if model_type == "mango_cnn":
        return MangoCNN(num_classes=num_classes)
    return None


def load_disease_model(plant_type, base_dir):
    """Load the disease classifier for *plant_type*.

    Args:
        plant_type: Plant name (case-insensitive); "soyabean" is accepted as
            an alias of "soybean".
        base_dir: Directory passed to find_model_path when locating weights.

    Returns:
        ``(model, classes)``, with the model moved to ``device`` and in eval
        mode, or ``(None, None)`` when no mapping exists for *plant_type* (or
        the architecture tag is unknown).

    Raises:
        FileNotFoundError / RuntimeError: the checkpoint cannot be located or
            is a Git LFS pointer.
        ValueError: no class names are available for the checkpoint.
        Any error from torch.load / load_state_dict if the weights do not fit
            the architecture.
    """
    plant_type = plant_type.lower()
    spec_key = _PLANT_ALIASES.get(plant_type, plant_type)
    spec = _DISEASE_MODEL_SPECS.get(spec_key)
    if spec is None:
        logger.warning(f"No disease classification mapping defined for plant type: {plant_type}")
        return None, None
    preferred_file, fallback_file, model_type = spec
    model_path = find_model_path(base_dir, preferred_file, fallback_file)

    # resolve_model_path already rejects LFS pointers; this is a final safety net.
    validate_model_artifact(model_path, f"{plant_type} disease model")
    logger.info(f"Loading {plant_type} disease model from {os.path.basename(model_path)} ...")
    checkpoint = torch.load(model_path, map_location="cpu")

    # Class names: prefer those stored in the checkpoint, else the built-in list.
    if isinstance(checkpoint, dict) and "classes" in checkpoint:
        classes = checkpoint["classes"]
    elif isinstance(checkpoint, dict) and "class_names" in checkpoint:
        classes = checkpoint["class_names"]
    else:
        classes = list(_DEFAULT_DISEASE_CLASSES.get(spec_key, []))
    if len(classes) == 0:
        # Without class names the head would get 0 outputs and load_state_dict
        # would fail with an opaque size-mismatch error.
        raise ValueError(
            f"No class names for the {plant_type} model: {os.path.basename(model_path)} "
            "has no 'classes'/'class_names' entry and no default list is defined."
        )

    model = _build_disease_architecture(model_type, len(classes))
    if model is None:
        logger.error(f"Unknown model architecture type: {model_type}")
        return None, None

    state_dict = _extract_state_dict(checkpoint)
    if model_type == "efficientnet_b0_timm":
        # EMA checkpoint (presumably saved from torch.optim.swa_utils.AveragedModel):
        # strip the "module." key prefix and drop the "n_averaged" counter.
        state_dict = {
            k[7:] if k.startswith("module.") else k: v
            for k, v in state_dict.items()
            if k != "n_averaged"
        }
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()
    logger.info(f"Successfully loaded {plant_type} model ({len(classes)} classes).")
    return model, classes
# ---------------------------------------------------------------------------
# Startup model pre-loading
# ---------------------------------------------------------------------------
def preload_all_models():
    """Load and cache the YOLO detector and every plant disease classifier.

    Runs under ``models_lock``. ``models_ready`` is cleared on entry and set
    again when the attempt finishes, whether it succeeded or not. A retry after
    an earlier failure therefore does not look "ready" while it is still
    loading.

    If YOLO fails, ``models_load_error`` is set and the disease models are not
    attempted. Individual disease-model failures are logged and recorded in
    ``model_load_reports`` without aborting the run.
    """
    global models_load_error, model_load_reports
    with models_lock:
        models_ready.clear()
        start_time = time.time()
        logger.info("=== Starting model pre-load ===")
        logger.info(f"HF_MODEL_REPO : '{HF_MODEL_REPO or '(not set)'}'")
        logger.info(f"SPACE_ID : '{SPACE_ID or '(not set)'}'")
        logger.info(f"HF_CACHE_DIR : '{HF_CACHE_DIR}'")
        model_load_reports = []
        try:
            model_files = os.listdir(MODEL_DIR)
        except Exception:
            model_files = []
        logger.info(f"MODEL_DIR: {MODEL_DIR} | files: {model_files}")

        # ------------------------------------------------------------------ #
        # 1. YOLO detector                                                   #
        # ------------------------------------------------------------------ #
        try:
            yolo_path = find_model_path(MODEL_DIR, "best.pt", "detector.pt")
            validate_model_artifact(yolo_path, "YOLO model")
            logger.info(f"Loading YOLO model from {yolo_path} ...")
            models_cache["yolo"] = YOLO(yolo_path)
            logger.info("YOLO model loaded successfully.")
        except Exception as exc:
            models_load_error = str(exc)
            logger.critical(f"YOLO model loading failed: {exc}", exc_info=True)
            models_ready.set()
            return

        # ------------------------------------------------------------------ #
        # 2. Plant disease CNN models                                        #
        # ------------------------------------------------------------------ #
        supported_plants = ["apple", "tomato", "potato", "cucumber", "lemon", "mango", "soybean"]
        for plant in supported_plants:
            try:
                model, classes = load_disease_model(plant, MODEL_DIR)
                if model is not None:
                    models_cache["disease_models"][plant] = (model, classes)
                    # Also register the alternative spelling used by some detectors.
                    if plant == "soybean":
                        models_cache["disease_models"]["soyabean"] = (model, classes)
                    model_load_reports.append({
                        "plant": plant,
                        "status": "loaded",
                        "classes_count": len(classes)
                    })
            except Exception as exc:
                logger.error(f"Failed to load model for {plant}: {exc}", exc_info=True)
                model_load_reports.append({
                    "plant": plant,
                    "status": "error",
                    "error": str(exc)
                })

        elapsed = time.time() - start_time
        loaded = [r["plant"] for r in model_load_reports if r["status"] == "loaded"]
        failed = [r["plant"] for r in model_load_reports if r["status"] == "error"]
        logger.info(
            f"=== Model pre-load done in {elapsed:.1f}s | "
            f"loaded: {loaded} | failed: {failed} ==="
        )
        models_load_error = None
        models_ready.set()
def ensure_models_loaded():
    """Make sure a complete model load has happened; trigger one if needed.

    Returns True only when a load attempt has fully finished
    (``models_ready`` is set) and the YOLO detector is cached. Checking YOLO
    alone is not enough: YOLO is cached before the disease classifiers load,
    and requests in that window would get "unsupported plant" results.

    If a load is running on another thread, this blocks on ``models_lock``
    until it finishes. If no complete load exists, it runs
    preload_all_models() itself.

    Returns:
        True if models are usable, False if loading failed.
    """
    global models_load_error
    if models_ready.is_set() and models_cache["yolo"] is not None:
        return True
    with models_lock:
        # Re-check: another thread may have completed the load while we waited.
        if models_ready.is_set() and models_cache["yolo"] is not None:
            return True
        try:
            preload_all_models()
        except Exception as exc:
            models_load_error = str(exc)
            logger.error(f"Model loading failed: {exc}", exc_info=True)
            return False
    return models_cache["yolo"] is not None and models_load_error is None
# ---------------------------------------------------------------------------
# Flask helpers
# ---------------------------------------------------------------------------
def allowed_file(filename, mime_type):
    """Return True if an upload's extension and MIME type are both allow-listed.

    The extension check is case-insensitive. The MIME type is compared
    case-insensitively (media types are case-insensitive) and without
    parameters, e.g. ``image/jpeg; name=x`` counts as ``image/jpeg``. A missing
    (None) filename or type is rejected rather than raising. Rejections are
    logged with the raw values received.
    """
    ext = os.path.splitext(filename or "")[1].lower()
    base_mime = (mime_type or "").split(";", 1)[0].strip().lower()
    valid = ext in ALLOWED_EXTENSIONS and base_mime in ALLOWED_MIME_TYPES
    if not valid:
        logger.warning(f"File rejected — name: '{filename}', ext: '{ext}', mime: '{mime_type}'")
    return valid
# ---------------------------------------------------------------------------
# Flask routes
# ---------------------------------------------------------------------------
@app.errorhandler(413)
def request_entity_too_large(error):
    """Return HTTP 413 (body over MAX_CONTENT_LENGTH) in the API's JSON error shape."""
    payload = {"success": False, "message": "File size exceeds the 10 MB limit"}
    return jsonify(payload), 413
@app.route("/", methods=["GET"])
def index():
    """Serve the bundled front-end page."""
    page = "main.html"
    return send_file(page)
@app.route("/health", methods=["GET"])
def health_check():
    """Report service readiness and model-loading diagnostics as JSON (HTTP 200).

    ``status`` is "healthy" only when a load attempt has fully finished
    (``models_ready``), the YOLO detector is loaded, and no load error is
    recorded. While the disease classifiers are still loading after YOLO, it
    stays "starting" and ``models_loading`` is True.
    """
    loaded_disease_models = list(models_cache["disease_models"].keys())
    # "soyabean" is just an alias of "soybean" in the cache; list it once.
    if "soyabean" in loaded_disease_models and "soybean" in loaded_disease_models:
        loaded_disease_models.remove("soyabean")
    yolo_loaded = models_cache["yolo"] is not None
    load_finished = models_ready.is_set()
    healthy = yolo_loaded and load_finished and models_load_error is None
    return jsonify({
        "status": "healthy" if healthy else "starting",
        "device": str(device),
        "yolo_loaded": yolo_loaded,
        "loaded_disease_models": loaded_disease_models,
        "model_load_reports": model_load_reports,
        "debug_save_images": DEBUG_SAVE_IMAGES,
        "models_loading": not yolo_loaded or not load_finished,
        "load_error": models_load_error,
        "hf_model_repo": HF_MODEL_REPO or "(not configured)",
        "space_id": SPACE_ID or "(not detected)",
    }), 200
def _save_debug_image(img_bytes, original_filename):
    """Write the raw upload to UPLOAD_FOLDER as img_<epoch ms><original extension>."""
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    extension = os.path.splitext(original_filename)[1]
    debug_path = os.path.join(UPLOAD_FOLDER, f"img_{int(time.time() * 1000)}{extension}")
    with open(debug_path, "wb") as fh:
        fh.write(img_bytes)
    logger.info(f"Debug: saved image to {debug_path}")


def _classify_crop(img, plant_type, box):
    """Run *plant_type*'s disease classifier on the *box* region of *img*.

    Returns ``(disease_name, confidence)``. For plants without a cached
    classifier it returns the "unsupported plant" placeholder with 0.0.
    """
    cached = models_cache["disease_models"].get(plant_type)
    if cached is None:
        logger.warning(f"No disease model cached for '{plant_type}'.")
        return "Unknown Disease (Unsupported Plant)", 0.0
    disease_model, classes = cached

    # Clamp to the image and keep at least a 1-pixel region. Boxes already
    # inside the image are unchanged; a degenerate (zero-width/height) box can
    # no longer make the resize step fail.
    x1, y1, x2, y2 = box
    width, height = img.size
    left = min(max(x1, 0), width - 1)
    top = min(max(y1, 0), height - 1)
    right = min(max(x2, left + 1), width)
    bottom = min(max(y2, top + 1), height)
    cropped_img = img.crop((left, top, right, bottom))

    img_tensor = get_preprocess(plant_type)(cropped_img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = disease_model(img_tensor)
        probs = torch.nn.functional.softmax(outputs[0], dim=0)
        pred_idx = torch.argmax(probs).item()
    return classes[pred_idx], float(probs[pred_idx].item())


@app.route("/predict", methods=["POST"])
def predict():
    """Two-stage inference: YOLO plant detection → per-plant disease classification.

    Expects a multipart upload under form key ``image`` (JPG/JPEG/PNG).
    Responses:
      * 200 ``{"success": True, "detections": [...]}``: one entry per box with
        plant_type, detection_confidence, box [x1, y1, x2, y2], disease and
        disease_confidence (confidences rounded to 4 decimals);
      * 200 with success False when no plant is detected;
      * 400 for a missing/invalid upload or an image that cannot be decoded;
      * 503 when models are unavailable; 500 on unexpected errors.
    """
    if not ensure_models_loaded():
        return jsonify({
            "success": False,
            "message": models_load_error or "Models are still loading. Please retry shortly."
        }), 503
    if 'image' not in request.files:
        return jsonify({"success": False, "message": "No file in form-data key 'image'"}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({"success": False, "message": "No file selected for upload"}), 400
    mime_type = file.content_type
    logger.info(f"Received file: '{file.filename}', MIME: '{mime_type}'")
    if not allowed_file(file.filename, mime_type):
        return jsonify({
            "success": False,
            "message": "Invalid file format. Only JPG, JPEG, and PNG are allowed."
        }), 400
    try:
        img_bytes = file.read()
        if DEBUG_SAVE_IMAGES:
            _save_debug_image(img_bytes, file.filename)
        try:
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as exc:
            # Undecodable, truncated or oversized images are client errors,
            # not server failures.
            logger.warning(f"Could not decode uploaded image '{file.filename}': {exc}")
            return jsonify({"success": False, "message": "Invalid or corrupted image file"}), 400

        yolo_model = models_cache["yolo"]
        if yolo_model is None:
            return jsonify({"success": False, "message": "YOLO model is not loaded"}), 500
        device_str = "cuda" if device.type == "cuda" else "cpu"
        results = yolo_model.predict(
            source=img, conf=0.25, iou=0.45,
            save=False, save_txt=False, verbose=False, device=device_str
        )
        result = results[0]
        if len(result.boxes) == 0:
            return jsonify({"success": False, "message": "No plant detected in the image"}), 200

        detections = []
        for box in result.boxes:
            class_id = int(box.cls[0])
            detection_confidence = float(box.conf[0])
            plant_type = yolo_model.names[class_id].lower()
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            disease_name, disease_confidence = _classify_crop(img, plant_type, (x1, y1, x2, y2))
            detections.append({
                "plant_type": plant_type,
                "detection_confidence": round(detection_confidence, 4),
                "box": [x1, y1, x2, y2],
                "disease": disease_name,
                "disease_confidence": round(disease_confidence, 4)
            })
        return jsonify({"success": True, "detections": detections}), 200
    except Exception as exc:
        logger.error(f"Inference pipeline failed: {exc}", exc_info=True)
        return jsonify({
            "success": False,
            "message": f"Server error during inference: {exc}"
        }), 500
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load models on a daemon thread so the HTTP server starts (and can answer
    # /health) immediately; /predict goes through ensure_models_loaded().
    loader = threading.Thread(target=preload_all_models, daemon=True)
    loader.start()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)