from __future__ import annotations

import os
from typing import List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms.v2 as T
from clearml import InputModel, Task
from huggingface_hub import hf_hub_download
from PIL import Image

from src.models.cnn_model import PlantCNN
from src.models.resnet18_finetune import make_resnet18
from src.utils.class_names import CLASS_NAMES

# Cache of loaded models keyed by (source, model_type, identifier(s), num_classes, device).
# A failed load is cached as None so a broken source is not retried on every prediction.
_MODEL_CACHE = {}
_CLASS_NAMES_CACHE = None  # reserved for lazy class-name loading; currently unused


def _device_key(device: torch.device) -> str:
    """Return a stable string key for *device*, used in model-cache keys."""
    return str(device)


def _get_device() -> torch.device:
    """Select the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _build_val_transform(image_size: int = 256) -> T.Compose:
    """Inference transform: resize to a square, then convert to a float32 tensor in [0, 1]."""
    return T.Compose([
        T.Resize((image_size, image_size)),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ])


def _build_model(model_type: str, num_classes: int) -> nn.Module:
    """Instantiate the requested architecture (shared by all loaders).

    Raises:
        ValueError: if *model_type* is neither 'resnet18' nor 'cnn'.
    """
    kind = model_type.lower()
    if kind == "resnet18":
        return make_resnet18(num_classes=num_classes)
    if kind == "cnn":
        return PlantCNN(num_classes=num_classes)
    raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")


def _finalize(model: nn.Module, state_dict, device: torch.device) -> nn.Module:
    """Load weights into *model* (unwrapping a {'state_dict': ...} checkpoint),
    move it to *device*, and switch to eval mode."""
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _load_model_from_checkpoint(
    model_path: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load a model from a local checkpoint file.

    Accepts three checkpoint formats: a dict containing a 'state_dict' key,
    a fully pickled nn.Module, or a bare state dict.

    Raises:
        FileNotFoundError: if *model_path* does not exist.
        ValueError: for an unknown *model_type* or an unreadable checkpoint format.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found - {model_path}")
    ckpt = torch.load(model_path, map_location=device)
    model = _build_model(model_type, num_classes)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        model.load_state_dict(ckpt["state_dict"])
    elif isinstance(ckpt, nn.Module):
        # Whole-model checkpoint: use it directly instead of the freshly built one.
        model = ckpt
    else:
        try:
            model.load_state_dict(ckpt)
        except Exception as exc:
            # Chain the cause so the underlying load_state_dict error is visible.
            raise ValueError(f"Unexpected checkpoint format in - {model_path}.") from exc
    model.to(device)
    model.eval()
    return model


def _load_model_from_clearml_model_id(
    model_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download a model by ClearML Model ID and load it.

    Raises:
        FileNotFoundError: if the download fails or the downloaded directory
            contains no .pt/.pth file.
    """
    model_obj = InputModel(model_id=model_id)
    downloaded_path = model_obj.get_local_copy()
    if downloaded_path is None:
        raise FileNotFoundError(f"Failed to download model from ClearML Model ID - {model_id}")
    if os.path.isdir(downloaded_path):
        # Pick the first weight file in the directory. (Bug fix: the original
        # left model_path uninitialized here when no .pt/.pth file existed,
        # raising NameError instead of the intended FileNotFoundError. Its
        # named-candidate fallback was dead code — any best_*.pt/.pth would
        # already have been caught by the suffix scan.)
        model_files = [f for f in os.listdir(downloaded_path) if f.endswith((".pt", ".pth"))]
        if not model_files:
            raise FileNotFoundError(f"No model file found in directory - {downloaded_path}")
        model_path = os.path.join(downloaded_path, model_files[0])
    else:
        model_path = downloaded_path
    model = _build_model(model_type, num_classes)
    return _finalize(model, torch.load(model_path, map_location=device), device)


def _load_model_from_clearml_task_id(
    task_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Fetch a model artifact from a ClearML task and load it.

    Artifacts are tried in priority order: 'best_model', 'best_baseline', 'model'.

    Raises:
        FileNotFoundError: if none of the expected artifacts can be retrieved.
    """
    source_task = Task.get_task(task_id=task_id)
    model_path = None
    for artifact_name in ("best_model", "best_baseline", "model"):
        if artifact_name in source_task.artifacts:
            model_path = source_task.artifacts[artifact_name].get_local_copy()
            if model_path:
                break
    if not model_path:
        raise FileNotFoundError(f"No model artifact found in Task ID - {task_id}")
    model = _build_model(model_type, num_classes)
    return _finalize(model, torch.load(model_path, map_location=device), device)


def _load_model_from_huggingface(
    repo_id: str,
    filename: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download *filename* from a Hugging Face Hub model repo and load it."""
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    model = _build_model(model_type, num_classes)
    return _finalize(model, torch.load(model_path, map_location=device), device)


def _get_class_names() -> List[str]:
    """Return the ordered list of class labels the classifier predicts over."""
    return CLASS_NAMES


def _cached_best_effort(cache_key, loader) -> nn.Module | None:
    """Run *loader* once per *cache_key*, caching the result in _MODEL_CACHE.

    A failure is cached as None (best-effort semantics) so the caller falls
    through to the next model source instead of retrying a broken one.
    """
    if cache_key not in _MODEL_CACHE:
        try:
            _MODEL_CACHE[cache_key] = loader()
        except Exception:
            # Deliberate swallow: predict_image tries several sources in order.
            _MODEL_CACHE[cache_key] = None
    return _MODEL_CACHE.get(cache_key)


def predict_image(img: Image.Image, k: int = 5) -> pd.DataFrame:
    """Predict the top-k disease classes for a single PIL image.

    Model sources are tried in order — ClearML Model ID, ClearML Task ID,
    Hugging Face Hub, local checkpoint — each configured via environment
    variables (MODEL_TYPE, MODEL_PATH, CLEARML_MODEL_ID, CLEARML_TASK_ID,
    HF_REPO_ID, HF_FILENAME). Loaded models are cached per source and device.

    Returns a DataFrame with columns: Disease, Probability. On any failure a
    single-row DataFrame is returned whose Disease cell holds the error text.
    (Docstring fix: the original claimed columns Img/Rank/Disease/Probability/
    Model, which the code never produced.)
    """
    if img is None:
        return pd.DataFrame({"Disease": [], "Probability": []})
    try:
        class_names = _get_class_names()
        if not class_names:
            raise ValueError("class_names list is empty.")

        model_type = os.getenv("MODEL_TYPE", "resnet18")
        model_path = os.getenv("MODEL_PATH", "")
        clearml_model_id = os.getenv("CLEARML_MODEL_ID", "")
        clearml_task_id = os.getenv("CLEARML_TASK_ID", "")
        hf_repo_id = os.getenv("HF_REPO_ID", "")
        hf_filename = os.getenv("HF_FILENAME", "")

        device = _get_device()
        device_k = _device_key(device)
        num_classes = len(class_names)

        transform = _build_val_transform(image_size=256)
        x = transform(img.convert("RGB")).unsqueeze(0).to(device)

        model = None
        # Source 1: ClearML Model ID.
        if clearml_model_id.strip():
            model = _cached_best_effort(
                ("clearml_model", model_type, clearml_model_id, num_classes, device_k),
                lambda: _load_model_from_clearml_model_id(
                    clearml_model_id, num_classes, model_type, device),
            )
        # Source 2: ClearML Task ID.
        if model is None and clearml_task_id.strip():
            model = _cached_best_effort(
                ("clearml_task", model_type, clearml_task_id, num_classes, device_k),
                lambda: _load_model_from_clearml_task_id(
                    clearml_task_id, num_classes, model_type, device),
            )
        # Source 3: Hugging Face Hub.
        if model is None and hf_repo_id.strip() and hf_filename.strip():
            model = _cached_best_effort(
                ("huggingface", model_type, hf_repo_id, hf_filename, num_classes, device_k),
                lambda: _load_model_from_huggingface(
                    hf_repo_id, hf_filename, num_classes, model_type, device),
            )
        # Source 4: local checkpoint — errors here propagate (not best-effort),
        # matching the original behavior.
        if model is None:
            if model_path and os.path.isfile(model_path):
                cache_key = ("local", model_type, model_path, num_classes, device_k)
                if cache_key not in _MODEL_CACHE:
                    _MODEL_CACHE[cache_key] = _load_model_from_checkpoint(
                        model_path, num_classes, model_type, device)
                model = _MODEL_CACHE[cache_key]
            else:
                raise FileNotFoundError(
                    f"All loading methods failed. Model ID - {clearml_model_id}, "
                    f"Task ID - {clearml_task_id}, HF - {hf_repo_id}/{hf_filename}, "
                    f"Local path - {model_path}"
                )

        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=1)[0]

        topk = min(int(k), len(class_names))
        top_probs, top_indices = torch.topk(probs, k=topk)
        return pd.DataFrame({
            "Disease": [class_names[idx.item()] for idx in top_indices],
            "Probability": [float(prob.item()) for prob in top_probs],
        })
    except Exception as e:
        # UI-facing boundary: surface the error inside the returned table
        # rather than crashing the caller (e.g. a Gradio app).
        return pd.DataFrame({"Disease": [f"Error: {str(e)}"], "Probability": [0.0]})