# Source: Hugging Face repo upload by JAMM032 (commit 97fcc90, verified)
from __future__ import annotations
import os
from typing import List
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms.v2 as T
from clearml import InputModel, Task
from huggingface_hub import hf_hub_download
from src.models.cnn_model import PlantCNN
from src.models.resnet18_finetune import make_resnet18
from src.utils.class_names import CLASS_NAMES
# Cache of loaded models, keyed by tuples of
# (source, model_type, identifier(s), num_classes, device) built in
# predict_image. A failed load is stored as None so the same source is
# not retried on every prediction call.
_MODEL_CACHE = {}
# Declared for caching class names, but never read or written in this
# file (_get_class_names returns CLASS_NAMES directly).
_CLASS_NAMES_CACHE = None
def _device_key(device: torch.device) -> str:
return str(device)
def _get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
def _build_val_transform(image_size: int = 256) -> T.Compose:
    """Build the validation-time preprocessing pipeline.

    Resizes to a square of *image_size*, converts to a tensor image, and
    scales to float32 in [0, 1].
    """
    steps = [
        T.Resize((image_size, image_size)),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ]
    return T.Compose(steps)
def _load_model_from_checkpoint(
model_path: str,
num_classes: int,
model_type: str,
device: torch.device,
) -> nn.Module:
if not os.path.isfile(model_path):
raise FileNotFoundError(f"Model file not found - {model_path}")
ckpt = torch.load(model_path, map_location=device)
if model_type.lower() == "resnet18":
model = make_resnet18(num_classes=num_classes)
elif model_type.lower() == "cnn":
model = PlantCNN(num_classes=num_classes)
else:
raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")
if isinstance(ckpt, dict) and "state_dict" in ckpt:
model.load_state_dict(ckpt["state_dict"])
elif isinstance(ckpt, nn.Module):
model = ckpt
else:
try:
model.load_state_dict(ckpt)
except Exception:
raise ValueError(f"Unexpected checkpoint format in - {model_path}. ")
model.to(device)
model.eval()
return model
def _load_model_from_clearml_model_id(
    model_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download a model from the ClearML model registry and load it.

    Args:
        model_id: ClearML registry model ID.
        num_classes: Number of output classes for the classifier head.
        model_type: Architecture name, "resnet18" or "cnn" (case-insensitive).
        device: Device the weights are mapped to and the model moved onto.

    Returns:
        The model in eval mode on *device*.

    Raises:
        FileNotFoundError: If the download fails or no .pt/.pth file can
            be located in the downloaded directory.
        ValueError: If *model_type* is unknown.
    """
    model_obj = InputModel(model_id=model_id)
    downloaded_path = model_obj.get_local_copy()
    if downloaded_path is None:
        raise FileNotFoundError(f"Failed to download model from ClearML Model ID - {model_id}")
    # BUG FIX: model_path must be initialized here; previously it was
    # unbound (UnboundLocalError) when the downloaded directory contained
    # no matching model file.
    model_path = None
    if os.path.isdir(downloaded_path):
        # Prefer any .pt/.pth file in the directory.
        model_files = [f for f in os.listdir(downloaded_path) if f.endswith((".pt", ".pth"))]
        if model_files:
            model_path = os.path.join(downloaded_path, model_files[0])
        else:
            # Fall back to conventional checkpoint names.
            for name in ["best_baseline.pt", "best_model.pt", "best_baseline.pth", "best_model.pth"]:
                candidate = os.path.join(downloaded_path, name)
                if os.path.isfile(candidate):
                    model_path = candidate
                    break
        if model_path is None:
            raise FileNotFoundError(f"No model file found in directory - {downloaded_path}")
    else:
        model_path = downloaded_path
    if model_type.lower() == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif model_type.lower() == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")
    state_dict = torch.load(model_path, map_location=device)
    # Unwrap training checkpoints that nest the weights under "state_dict".
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def _load_model_from_clearml_task_id(
    task_id: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Fetch a model artifact from a ClearML task and load it onto *device*.

    Probes the task's artifacts under a few conventional names and loads
    the first one that yields a local copy.
    """
    source_task = Task.get_task(task_id=task_id)
    model_path = None
    for artifact_name in ("best_model", "best_baseline", "model"):
        if artifact_name in source_task.artifacts:
            model_path = source_task.artifacts[artifact_name].get_local_copy()
            if model_path:
                break
    if model_path is None:
        raise FileNotFoundError(f"No model artifact found in Task ID - {task_id}")
    arch = model_type.lower()
    if arch == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif arch == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")
    checkpoint = torch.load(model_path, map_location=device)
    # Unwrap training checkpoints that nest the weights under "state_dict".
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    return model
def _load_model_from_huggingface(
    repo_id: str,
    filename: str,
    num_classes: int,
    model_type: str,
    device: torch.device,
) -> nn.Module:
    """Download *filename* from a Hugging Face Hub model repo and load it.

    Raises ValueError for an unknown *model_type*; returns the model in
    eval mode on *device*.
    """
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    arch = model_type.lower()
    if arch == "resnet18":
        model = make_resnet18(num_classes=num_classes)
    elif arch == "cnn":
        model = PlantCNN(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model type - {model_type}. Must be 'resnet18' or 'cnn'.")
    weights = torch.load(local_path, map_location=device)
    # Unwrap training checkpoints that nest the weights under "state_dict".
    if isinstance(weights, dict) and "state_dict" in weights:
        weights = weights["state_dict"]
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    return model
def _get_class_names() -> List[str]:
    """Return the project-wide list of class labels (shared CLASS_NAMES constant)."""
    return CLASS_NAMES
def _load_cached(cache_key, loader):
    """Memoize *loader*'s result in _MODEL_CACHE; a failed load caches None.

    Caching the failure means a broken remote source is not retried on
    every prediction, and returning None lets the caller fall through to
    the next model source.
    """
    if cache_key not in _MODEL_CACHE:
        try:
            _MODEL_CACHE[cache_key] = loader()
        except Exception:
            _MODEL_CACHE[cache_key] = None
    return _MODEL_CACHE.get(cache_key)


def predict_image(img: Image.Image, k: int = 5) -> pd.DataFrame:
    """Predict the top-k disease classes for a single PIL image.

    Model weights are resolved from the first source that works, in order:
    ClearML Model ID, ClearML Task ID, Hugging Face Hub, local checkpoint
    (controlled by the CLEARML_MODEL_ID, CLEARML_TASK_ID, HF_REPO_ID /
    HF_FILENAME and MODEL_PATH / MODEL_TYPE environment variables).
    Loaded models are cached per (source, type, identifier, classes, device).

    Args:
        img: Input image; converted to RGB internally. None yields an
            empty result frame.
        k: Number of top predictions to return (capped at the number of
            classes).

    Returns:
        DataFrame with columns "Disease" and "Probability". On any
        failure a single row carrying an "Error: ..." message with
        probability 0.0 is returned instead of raising, so UI callers
        always receive a frame.
    """
    if img is None:
        return pd.DataFrame({"Disease": [], "Probability": []})
    try:
        class_names = _get_class_names()
        if not class_names:
            raise ValueError("class_names list is empty.")
        model_type = os.getenv("MODEL_TYPE", "resnet18")
        model_path = os.getenv("MODEL_PATH", "")
        clearml_model_id = os.getenv("CLEARML_MODEL_ID", "")
        clearml_task_id = os.getenv("CLEARML_TASK_ID", "")
        hf_repo_id = os.getenv("HF_REPO_ID", "")
        hf_filename = os.getenv("HF_FILENAME", "")

        device = _get_device()
        device_k = _device_key(device)
        num_classes = len(class_names)

        transform = _build_val_transform(image_size=256)
        x = transform(img.convert("RGB")).unsqueeze(0).to(device)

        model = None
        # 1) ClearML model registry
        if clearml_model_id.strip():
            model = _load_cached(
                ("clearml_model", model_type, clearml_model_id, num_classes, device_k),
                lambda: _load_model_from_clearml_model_id(clearml_model_id, num_classes, model_type, device),
            )
        # 2) ClearML task artifact
        if model is None and clearml_task_id.strip():
            model = _load_cached(
                ("clearml_task", model_type, clearml_task_id, num_classes, device_k),
                lambda: _load_model_from_clearml_task_id(clearml_task_id, num_classes, model_type, device),
            )
        # 3) Hugging Face Hub
        if model is None and hf_repo_id.strip() and hf_filename.strip():
            model = _load_cached(
                ("huggingface", model_type, hf_repo_id, hf_filename, num_classes, device_k),
                lambda: _load_model_from_huggingface(hf_repo_id, hf_filename, num_classes, model_type, device),
            )
        # 4) Local checkpoint — unlike the remote sources, a load error
        # here propagates to the except below (matching original behavior).
        if model is None:
            if model_path and os.path.isfile(model_path):
                cache_key = ("local", model_type, model_path, num_classes, device_k)
                if cache_key not in _MODEL_CACHE:
                    _MODEL_CACHE[cache_key] = _load_model_from_checkpoint(model_path, num_classes, model_type, device)
                model = _MODEL_CACHE[cache_key]
            else:
                raise FileNotFoundError(
                    f"All loading methods failed. Model ID - {clearml_model_id}, Task ID - {clearml_task_id}, HF - {hf_repo_id}/{hf_filename}, Local path - {model_path}"
                )

        with torch.no_grad():
            probs = torch.softmax(model(x), dim=1)[0]
        topk = min(int(k), len(class_names))
        top_probs, top_indices = torch.topk(probs, k=topk)
        return pd.DataFrame({
            "Disease": [class_names[idx.item()] for idx in top_indices],
            "Probability": [float(p.item()) for p in top_probs],
        })
    except Exception as e:
        # Surface the failure as data so a UI table can display it.
        return pd.DataFrame({"Disease": [f"Error: {str(e)}"], "Probability": [0.0]})