# sene-disease-api / model_core.py
# Uploaded by mohamedsamake8322 (commit 950be73, "Upload 5 files", verified)
"""
model_core.py
--------------
Module centralisant toute la logique IA pour le modèle de production
(Phase 2 - Swin Base, metric learning).
Utilisé par:
- disease_api.py (FastAPI backend)
- 04_app_streamlit.py (interface de debug Streamlit)
Objectifs:
- Charger le modèle et les artefacts UNE SEULE FOIS (chemins cohérents)
- Garantir exactement les mêmes prédictions partout
- Eviter la duplication de code (prétraitement, prototypes, FAISS, etc.)
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
import torch
import sys
sys.path.append(".")
from huggingface_hub import hf_hub_download
try:
import faiss # type: ignore
except Exception: # pragma: no cover - FAISS optionnel
faiss = None
# Import from local metric_training_core (same HF Spaces directory)
from metric_training_core import DiagnosticModel # type: ignore
# Base directories shared by the FastAPI backend and the Streamlit app.
MODELS_PATH_PHASE2 = Path("./outputs/phase2_swin_base_production/models")
DATASET_ROOT_LOCAL = Path("./dataset_final")
# No GPU is available on Streamlit Cloud. Forcing CPU inference avoids
# useless allocations and keeps the memory footprint predictable.
DEVICE = torch.device("cpu")
# Hugging Face repository holding the model artifacts.
MODEL_REPO = "mohamedsamake8322/plant-diseaseS-swin-faiss"
def load_phase2_model_and_metadata(
    models_path: Path = MODELS_PATH_PHASE2,
) -> Tuple[torch.nn.Module, Optional[Any], Dict[str, Any], Optional[np.ndarray], Optional[np.ndarray], torch.device]:
    """
    Load the phase-2 model (metric_model.pt), its metadata, the optional
    FAISS index and the class prototypes.

    Returns:
        - model: DiagnosticModel in eval mode on CPU
        - index: FAISS index, or None when FAISS/the index file is unavailable
        - metadata: dict with heavy training-time fields stripped
        - prototypes: (C, D) float32 array, or None
        - prototype_labels: (C,) int array aligned with prototypes, or None
        - device: the inference device (CPU)

    Raises:
        RuntimeError: when the model artifacts are neither available locally
            nor downloadable from the Hugging Face Hub.
    """
    metric_model_path = models_path / "metric_model.pt"
    metadata_path = models_path / "metadata.pkl"
    # 1) Prefer local files (useful for offline development).
    if not metric_model_path.exists() or not metadata_path.exists():
        # 2) Fallback: download the artifacts from the Hugging Face Hub.
        try:
            metric_model_path = Path(
                hf_hub_download(repo_id=MODEL_REPO, filename="metric_model.pt")
            )
            metadata_path = Path(
                hf_hub_download(repo_id=MODEL_REPO, filename="metadata.pkl")
            )
        except Exception as e:
            raise RuntimeError(
                f"Impossible de trouver les artefacts modèles localement dans {models_path} "
                f"et le téléchargement depuis Hugging Face a échoué: {e}"
            ) from e
    import pickle

    # NOTE(review): pickle.load and torch.load can execute arbitrary code on
    # untrusted files — acceptable here because the artifacts come from our
    # own model repository; never point MODEL_REPO at an untrusted source.
    with open(metadata_path, "rb") as f:
        metadata = pickle.load(f)
    # Keep the production memory footprint small (Streamlit Cloud): drop the
    # heavy training-time fields that inference never reads.
    for heavy_key in (
        "embeddings",
        "train_embeddings",
        "val_embeddings",
        "test_embeddings",
        "logits",
        "train_metrics",
        "val_metrics",
    ):
        # dict.pop with a default never raises; no try/except needed.
        metadata.pop(heavy_key, None)
    checkpoint = torch.load(metric_model_path, map_location=DEVICE)
    cfg = checkpoint.get("config", {})
    model_name = cfg.get("model_name", "swin_base_patch4_window7_224")
    embedding_dim = cfg.get("embedding_dim", metadata.get("embedding_dim", 768))
    image_size = cfg.get("image_size", 224)
    model = DiagnosticModel(
        model_name=model_name, embedding_dim=embedding_dim, image_size=image_size
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # The checkpoint is no longer needed; release it to reduce peak memory.
    del checkpoint
    model = model.to(DEVICE)
    model.eval()
    # --- Optional FAISS index (CPU only: no GPU on Streamlit Cloud) ---
    index = None
    if faiss is not None:
        faiss_index_path = metadata.get("faiss_index_path") or str(
            models_path / "faiss_index.bin"
        )
        faiss_index_file: Optional[Path] = Path(faiss_index_path)
        if not faiss_index_file.exists():
            # Try to fetch the index from the Hub; give up silently on failure
            # (the index is optional — prototype matching still works).
            try:
                faiss_index_file = Path(
                    hf_hub_download(repo_id=MODEL_REPO, filename="faiss_index.bin")
                )
            except Exception:
                # BUG FIX: the previous sentinel Path("") resolves to "." which
                # *does* exist, so faiss.read_index("") was attempted and only
                # failed gracefully by accident. Use None to skip explicitly.
                faiss_index_file = None
        if faiss_index_file is not None and faiss_index_file.exists():
            try:
                index = faiss.read_index(str(faiss_index_file))
            except Exception:
                index = None
    # --- Class prototypes (optional) ---
    prototypes = None
    prototype_labels = None
    if "prototypes" in metadata and "prototype_labels" in metadata:
        prototypes = np.asarray(metadata["prototypes"], dtype="float32")
        prototype_labels = np.asarray(metadata["prototype_labels"], dtype=int)
    return model, index, metadata, prototypes, prototype_labels, DEVICE
def preprocess_image_pil(image: Image.Image, size: int = 224) -> torch.Tensor:
    """Convert a PIL image into a normalized (1, 3, size, size) float tensor.

    Applies the standard ImageNet mean/std normalization used at training
    time, so API and Streamlit inference stay bit-for-bit consistent.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    resized = rgb.resize((size, size))
    pixels = np.array(resized).astype("float32") / 255.0
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype="float32")
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype="float32")
    normalized = (pixels - imagenet_mean) / imagenet_std
    tensor = torch.from_numpy(normalized.astype("float32"))
    # HWC -> CHW, then add the leading batch dimension.
    return tensor.permute(2, 0, 1).unsqueeze(0).float()
def map_image_path_to_local(raw_path: str) -> str:
    """
    Remap an absolute Colab path (/content/drive/MyDrive/dataset_final/...)
    onto the local ./dataset_final/... tree when the original path is missing.

    Returns the input string unchanged when no local match can be found.
    """
    original = Path(raw_path)
    if original.exists():
        return str(original)
    segments = original.parts
    if "dataset_final" in segments:
        anchor = segments.index("dataset_final")
        tail = segments[anchor + 1:]
        relative = Path(*tail) if tail else Path(".")
        local_candidate = DATASET_ROOT_LOCAL / relative
        if local_candidate.exists():
            return str(local_candidate)
    return raw_path
def _class_name_for(idx_to_class: Dict[Any, Any], label: int) -> str:
if label in idx_to_class:
return idx_to_class[label]
if str(label) in idx_to_class:
return idx_to_class[str(label)]
return f"class_{label}"
def infer_on_image(
    model: torch.nn.Module,
    index: Optional[Any],
    metadata: Dict[str, Any],
    prototypes: Optional[np.ndarray],
    prototype_labels: Optional[np.ndarray],
    image: Image.Image,
    device: torch.device,
    top_k: int = 5,
    unknown_threshold: float = 0.55,
) -> Dict[str, Any]:
    """
    Shared inference pipeline (prototype matching + optional FAISS neighbors).

    Args:
        model: embedding model in eval mode.
        index: optional FAISS index over the training embeddings.
        metadata: must contain "image_paths", "labels" and "idx_to_class".
        prototypes: (C, D) L2-normalized class prototypes, or None.
        prototype_labels: (C,) class ids aligned with `prototypes`, or None.
        image: input PIL image.
        device: device to run the forward pass on.
        top_k: number of prototype / neighbor entries to return.
        unknown_threshold: similarity floor below which the prediction is
            reported as "UNKNOWN DISEASE".

    Returns a dict with keys:
        predicted_label, predicted_disease, predicted_similarity, is_unknown,
        topk_prototypes, topk_neighbors.
    """
    image_paths = metadata["image_paths"]
    labels = metadata["labels"]
    idx_to_class = metadata["idx_to_class"]
    # --- Embedding ---
    img_tensor = preprocess_image_pil(
        image, size=metadata.get("image_size", 224)
    ).to(device)
    with torch.no_grad():
        emb = model(img_tensor).cpu().numpy().astype("float32")  # (1, D)
    # L2-normalize so the dot products below are cosine similarities.
    emb_norm = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12)
    # --- Prototype ranking ---
    proto_ranking: List[Dict[str, Any]] = []
    predicted_label: Optional[int] = None
    predicted_score: Optional[float] = None
    if prototypes is not None and prototype_labels is not None:
        sims = (prototypes @ emb_norm.T).squeeze(axis=1)  # (C,)
        order = np.argsort(sims)[::-1]
        for rank, ci in enumerate(order[: max(top_k, 5)]):
            class_id = int(prototype_labels[ci])
            proto_ranking.append(
                {
                    "rank": rank + 1,
                    "label": class_id,
                    "disease": _class_name_for(idx_to_class, class_id),
                    "similarity": float(sims[ci]),
                }
            )
    if proto_ranking:
        predicted_label = proto_ranking[0]["label"]
        predicted_score = proto_ranking[0]["similarity"]
    # Without prototypes there is no score, hence no "unknown" verdict.
    is_unknown = (
        predicted_score is not None
        and float(predicted_score) < float(unknown_threshold)
    )
    # --- FAISS nearest neighbors (optional) ---
    neighbors: List[Dict[str, Any]] = []
    if index is not None:
        distances, indices = index.search(emb_norm.astype("float32"), k=top_k)
        for rank, (dist, idx) in enumerate(zip(distances[0], indices[0])):
            # BUG FIX: FAISS pads `indices` with -1 when the index holds
            # fewer than top_k vectors; labels[-1] / image_paths[-1] would
            # then silently return the LAST training sample. Skip padding.
            if idx < 0:
                continue
            label_i = labels[idx]
            neighbors.append(
                {
                    "rank": rank + 1,
                    "label": int(label_i),
                    "disease": _class_name_for(idx_to_class, label_i),
                    "similarity": float(dist),
                    "image_path": map_image_path_to_local(image_paths[idx]),
                }
            )
    if is_unknown:
        predicted_disease: Optional[str] = "UNKNOWN DISEASE"
    elif predicted_label is not None:
        predicted_disease = _class_name_for(idx_to_class, predicted_label)
    else:
        predicted_disease = None
    return {
        "predicted_label": int(predicted_label) if predicted_label is not None else None,
        "predicted_disease": predicted_disease,
        "predicted_similarity": predicted_score,
        "is_unknown": is_unknown,
        "topk_prototypes": proto_ranking[:top_k],
        "topk_neighbors": neighbors,
    }
def infer_batch(
    model: torch.nn.Module,
    images: List[Image.Image],
    device: torch.device,
    image_size: int,
) -> np.ndarray:
    """
    Batched inference for future multi-image extensions.

    Returns an (N, D) numpy array of L2-normalized embeddings; an empty
    (0, 0) float32 array when no images are supplied.
    """
    if not images:
        return np.empty((0, 0), dtype="float32")
    batch = torch.cat(
        [preprocess_image_pil(img, size=image_size) for img in images], dim=0
    ).to(device)
    with torch.no_grad():
        embeddings = model(batch).cpu().numpy().astype("float32")
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / (norms + 1e-12)
# Explicit public surface consumed by disease_api.py and 04_app_streamlit.py.
__all__ = [
    "MODELS_PATH_PHASE2",
    "DATASET_ROOT_LOCAL",
    "DEVICE",
    "load_phase2_model_and_metadata",
    "preprocess_image_pil",
    "map_image_path_to_local",
    "infer_on_image",
    "infer_batch",
]