Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from typing import List | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms.v2 as T | |
| from clearml import InputModel, Task | |
| from huggingface_hub import hf_hub_download | |
| from src.models.cnn_model import PlantCNN | |
| from src.models.resnet18_finetune import make_resnet18 | |
| from src.utils.class_names import CLASS_NAMES | |
# Process-wide memo of loaded models. Keys are tuples that start with a source
# label and end with the device string, so one model is kept per source/device.
_MODEL_CACHE: dict = {}
# Placeholder for a class-name cache; not referenced by the code visible here.
_CLASS_NAMES_CACHE = None
def _device_key(device: torch.device) -> str:
    """Return the device's string form (e.g. ``"cpu"``, ``"cuda:1"``) for use in cache keys."""
    return f"{device}"
def _get_device() -> torch.device:
    """Choose the inference device: CUDA if available, else Apple MPS, else CPU.

    The MPS probe is guarded with ``getattr`` because ``torch.backends.mps``
    may be absent on some torch builds.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def _build_val_transform(image_size: int = 256) -> T.Compose:
    """Build the evaluation-time preprocessing pipeline.

    Resizes to a square ``image_size`` x ``image_size``, converts to a tensor
    image, and casts to float32 with values scaled into [0, 1]. No
    normalization step is applied here.
    """
    square = (image_size, image_size)
    pipeline = [
        T.Resize(square),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ]
    return T.Compose(pipeline)
def _load_model_from_checkpoint(
    model_path: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load a model from a checkpoint file on local disk.

    Accepted checkpoint layouts:

    * a whole pickled ``nn.Module``, which is used as-is;
    * a dict with a ``"state_dict"`` entry, whose weights are loaded into a
      freshly built architecture;
    * a bare state dict, loaded into a freshly built architecture.

    Args:
        model_path: Path to the checkpoint file.
        num_classes: Number of output classes for the freshly built model.
        model_type: ``"resnet18"`` or ``"cnn"`` (case-insensitive).
        device: Device that checkpoint tensors are mapped to and that the
            model is moved onto.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing file.
        ValueError: If ``model_type`` is unknown, or if a non-wrapped
            checkpoint cannot be loaded as a state dict.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found - {model_path}")

    kind = model_type.lower()
    # Reject bad configuration before paying for reading the checkpoint.
    if kind not in ("resnet18", "cnn"):
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 2.6 the default weights_only=True rejects
    # whole pickled modules, so the nn.Module branch below only triggers on
    # older torch versions.
    ckpt = torch.load(model_path, map_location=device)

    if isinstance(ckpt, nn.Module):
        # A full module carries its own architecture; nothing to build.
        model = ckpt
    else:
        if kind == "resnet18":
            model = make_resnet18(num_classes=num_classes)
        else:
            model = PlantCNN(num_classes=num_classes)

        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            model.load_state_dict(ckpt["state_dict"])
        else:
            try:
                model.load_state_dict(ckpt)
            except Exception as err:
                # Keep the underlying reason in both the message and the chain.
                raise ValueError(f"Unexpected checkpoint format in - {model_path}. {err}") from err

    model.to(device)
    model.eval()
    return model
def _find_checkpoint_in_dir(directory: str) -> str:
    """Pick the checkpoint file to load from a downloaded model directory.

    Well-known checkpoint names are preferred, in the order listed below.
    Otherwise the alphabetically first ``.pt``/``.pth`` file is used, so the
    choice does not depend on ``os.listdir`` ordering.

    Raises:
        FileNotFoundError: If the directory holds no ``.pt``/``.pth`` file.
    """
    for name in ("best_baseline.pt", "best_model.pt", "best_baseline.pth", "best_model.pth"):
        candidate = os.path.join(directory, name)
        if os.path.isfile(candidate):
            return candidate
    model_files = sorted(
        f
        for f in os.listdir(directory)
        if f.endswith((".pt", ".pth")) and os.path.isfile(os.path.join(directory, f))
    )
    if not model_files:
        raise FileNotFoundError(f"No model file found in directory - {directory}")
    return os.path.join(directory, model_files[0])


def _load_model_from_clearml_model_id(
    model_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download a registered ClearML model and load its weights.

    ``InputModel.get_local_copy()`` may yield either a single file or a
    directory. For a directory, the checkpoint is chosen by
    ``_find_checkpoint_in_dir``.

    Args:
        model_id: ClearML model ID.
        num_classes: Number of output classes for the freshly built model.
        model_type: ``"resnet18"`` or ``"cnn"`` (case-insensitive).
        device: Device that tensors are mapped to and the model is moved onto.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is unknown (checked before downloading).
        FileNotFoundError: If the download yields nothing, or a downloaded
            directory contains no checkpoint file.
    """
    kind = model_type.lower()
    # Fail fast on bad configuration before any network round-trip.
    if kind not in ("resnet18", "cnn"):
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    model_obj = InputModel(model_id=model_id)
    downloaded_path = model_obj.get_local_copy()
    if not downloaded_path:
        raise FileNotFoundError(f"Failed to download model from ClearML Model ID - {model_id}")

    if os.path.isdir(downloaded_path):
        model_path = _find_checkpoint_in_dir(downloaded_path)
    else:
        model_path = downloaded_path

    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    else:
        model = PlantCNN(num_classes=num_classes)

    state_dict = torch.load(model_path, map_location=device)
    # Accept both {"state_dict": ...} wrappers and bare state dicts.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def _load_model_from_clearml_task_id(
    task_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Load model weights from an artifact attached to a ClearML task.

    The artifacts ``best_model``, ``best_baseline`` and ``model`` are tried in
    that order. The first one whose local copy is non-empty is used.

    Args:
        task_id: ClearML task ID whose artifacts hold the checkpoint.
        num_classes: Number of output classes for the freshly built model.
        model_type: ``"resnet18"`` or ``"cnn"`` (case-insensitive).
        device: Device that tensors are mapped to and the model is moved onto.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is unknown (checked before contacting ClearML).
        FileNotFoundError: If no listed artifact produced a local copy.
    """
    kind = model_type.lower()
    # Fail fast on bad configuration before any network round-trip.
    if kind not in ("resnet18", "cnn"):
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    source_task = Task.get_task(task_id=task_id)
    model_path = None
    for artifact_name in ("best_model", "best_baseline", "model"):
        if artifact_name in source_task.artifacts:
            model_path = source_task.artifacts[artifact_name].get_local_copy()
            if model_path:
                break
    # Falsy check: get_local_copy() may also return "", which would
    # otherwise reach torch.load and fail with a confusing error.
    if not model_path:
        raise FileNotFoundError(f"No model artifact found in Task ID - {task_id}")

    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    else:
        model = PlantCNN(num_classes=num_classes)

    state_dict = torch.load(model_path, map_location=device)
    # Accept both {"state_dict": ...} wrappers and bare state dicts.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def _load_model_from_huggingface(
    repo_id: str,
    filename: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download a checkpoint file from a Hugging Face Hub model repo and load it.

    Args:
        repo_id: Hub repository ID (``repo_type="model"``).
        filename: Checkpoint file name inside the repository.
        num_classes: Number of output classes for the freshly built model.
        model_type: ``"resnet18"`` or ``"cnn"`` (case-insensitive).
        device: Device that tensors are mapped to and the model is moved onto.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is unknown (checked before downloading).
        Errors from the download, ``torch.load`` or ``load_state_dict``
        propagate unchanged.
    """
    kind = model_type.lower()
    # Fail fast on bad configuration before a potentially large download.
    if kind not in ("resnet18", "cnn"):
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")

    model_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")

    if kind == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    else:
        model = PlantCNN(num_classes=num_classes)

    state_dict = torch.load(model_path, map_location=device)
    # Accept both {"state_dict": ...} wrappers and bare state dicts.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def _get_class_names() -> List[str]:
    """Return the class-name list used to label model outputs.

    The shared ``CLASS_NAMES`` object is returned directly rather than copied,
    so callers must not mutate it.
    """
    names: List[str] = CLASS_NAMES
    return names
# Failure reasons for remote model sources, keyed like the _MODEL_CACHE entries.
# Kept separately so that _MODEL_CACHE still maps a failed source to None.
_MODEL_LOAD_ERRORS: dict = {}


def _load_first_available_model(
    model_type: str,
    num_classes: int,
    device: torch.device,
) -> nn.Module:
    """Return the model from the first configured source that loads successfully.

    Sources come from environment variables and are tried in this order:

    1. ClearML model ID (``CLEARML_MODEL_ID``)
    2. ClearML task ID (``CLEARML_TASK_ID``)
    3. Hugging Face Hub (``HF_REPO_ID`` + ``HF_FILENAME``)
    4. Local checkpoint (``MODEL_PATH``)

    Loaded models are memoised in ``_MODEL_CACHE``. A remote source that
    failed is cached as ``None`` and not retried for the life of the process.
    Its failure reason is kept in ``_MODEL_LOAD_ERRORS`` for reporting.

    Raises:
        FileNotFoundError: If every configured remote source failed and no
            local checkpoint file exists. The message lists each remote failure.
        Exception: Errors from loading the local checkpoint propagate.
    """
    device_k = _device_key(device)
    clearml_model_id = os.getenv("CLEARML_MODEL_ID", "").strip()
    clearml_task_id = os.getenv("CLEARML_TASK_ID", "").strip()
    hf_repo_id = os.getenv("HF_REPO_ID", "").strip()
    hf_filename = os.getenv("HF_FILENAME", "").strip()
    model_path = os.getenv("MODEL_PATH", "")

    # (cache key, loader function, loader positional args) per configured remote source.
    remote_sources = []
    if clearml_model_id:
        remote_sources.append((
            ("clearml_model", model_type, clearml_model_id, num_classes, device_k),
            _load_model_from_clearml_model_id,
            (clearml_model_id, num_classes, model_type, device),
        ))
    if clearml_task_id:
        remote_sources.append((
            ("clearml_task", model_type, clearml_task_id, num_classes, device_k),
            _load_model_from_clearml_task_id,
            (clearml_task_id, num_classes, model_type, device),
        ))
    if hf_repo_id and hf_filename:
        remote_sources.append((
            ("huggingface", model_type, hf_repo_id, hf_filename, num_classes, device_k),
            _load_model_from_huggingface,
            (hf_repo_id, hf_filename, num_classes, model_type, device),
        ))

    failures = []
    for cache_key, loader, args in remote_sources:
        if cache_key not in _MODEL_CACHE:
            try:
                _MODEL_CACHE[cache_key] = loader(*args)
            except Exception as err:
                # Best-effort: fall through to the next source, but remember
                # why this one failed so the final error can explain it.
                _MODEL_CACHE[cache_key] = None
                _MODEL_LOAD_ERRORS[cache_key] = f"{type(err).__name__}: {err}"
        model = _MODEL_CACHE[cache_key]
        if model is not None:
            return model
        reason = _MODEL_LOAD_ERRORS.get(cache_key, "load failed")
        failures.append(f"{cache_key[0]} ({reason})")

    if model_path and os.path.isfile(model_path):
        cache_key = ("local", model_type, model_path, num_classes, device_k)
        if cache_key not in _MODEL_CACHE:
            _MODEL_CACHE[cache_key] = _load_model_from_checkpoint(model_path, num_classes, model_type, device)
        return _MODEL_CACHE[cache_key]

    message = (
        f"All loading methods failed. Model ID - {clearml_model_id}, Task ID - {clearml_task_id}, "
        f"HF - {hf_repo_id}/{hf_filename}, Local path - {model_path}"
    )
    if failures:
        message += ". Errors: " + "; ".join(failures)
    raise FileNotFoundError(message)


def predict_image(img: Image.Image, k: int = 5) -> pd.DataFrame:
    """Predict the top-k disease classes for a single PIL image.

    The architecture comes from ``MODEL_TYPE`` (default ``"resnet18"``). The
    weights come from the first configured source (see
    ``_load_first_available_model``).

    Args:
        img: Input image, converted to RGB before preprocessing. ``None``
            yields an empty result.
        k: Number of classes to return, capped at the number of classes.

    Returns:
        DataFrame with columns ``Disease`` (class name) and ``Probability``
        (softmax probability as a Python float), ordered from most to least
        probable. Any error is reported as a single row whose ``Disease`` is
        ``"Error: <message>"`` and whose ``Probability`` is 0.0.
    """
    if img is None:
        return pd.DataFrame({"Disease": [], "Probability": []})

    try:
        class_names = _get_class_names()
        if not class_names:
            raise ValueError("class_names list is empty.")
        num_classes = len(class_names)
        model_type = os.getenv("MODEL_TYPE", "resnet18")
        device = _get_device()

        transform = _build_val_transform(image_size=256)
        x = transform(img.convert("RGB")).unsqueeze(0).to(device)

        model = _load_first_available_model(model_type, num_classes, device)

        with torch.no_grad():
            probs = torch.softmax(model(x), dim=1)[0]
        top_probs, top_indices = torch.topk(probs, k=min(int(k), num_classes))

        return pd.DataFrame({
            "Disease": [class_names[i] for i in top_indices.tolist()],
            "Probability": [float(p) for p in top_probs.tolist()],
        })
    except Exception as e:
        # UI boundary: surface the failure as a result row rather than raising.
        return pd.DataFrame({"Disease": [f"Error: {str(e)}"], "Probability": [0.0]})