# Source: SkinProAI / models/monet_tool.py (12.7 kB)
# Commit 1a97904 (cgoodmaker):
#   Force MCP tool models to CPU to avoid GPU VRAM contention with MedGemma
"""
MONET Tool - Skin lesion feature extraction using MONET model
Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb
"""
import os
import torch
import torch.nn.functional as F
import numpy as np
import scipy.special
from PIL import Image
from typing import Optional, Dict, List
import torchvision.transforms as T
# The 7 MONET feature columns expected by ConvNeXt.
# The order is significant: feature vectors are emitted in exactly this order.
MONET_FEATURES: List[str] = [
    "MONET_ulceration_crust",
    "MONET_hair",
    "MONET_vasculature_vessels",
    "MONET_erythema",
    "MONET_pigmented",
    "MONET_gel_water_drop_fluid_dermoscopy_liquid",
    "MONET_skin_markings_pen_ink_purple_pen",
]

# Synonymous concept terms per MONET feature. Every term is substituted into
# every prompt template (multiple synonyms improve detection).
MONET_CONCEPT_TERMS: Dict[str, List[str]] = {
    "MONET_ulceration_crust": ["ulceration", "crust", "crusting", "ulcer"],
    "MONET_hair": ["hair", "hairy"],
    "MONET_vasculature_vessels": [
        "blood vessels",
        "vasculature",
        "vessels",
        "telangiectasia",
    ],
    "MONET_erythema": ["erythema", "redness", "red"],
    "MONET_pigmented": ["pigmented", "pigmentation", "melanin", "brown"],
    "MONET_gel_water_drop_fluid_dermoscopy_liquid": [
        "dermoscopy gel",
        "fluid",
        "water drop",
        "immersion fluid",
    ],
    "MONET_skin_markings_pen_ink_purple_pen": [
        "pen marking",
        "ink",
        "surgical marking",
        "purple pen",
    ],
}

# Target prompt templates (from MONET paper); "{}" is replaced by a concept term.
PROMPT_TEMPLATES: List[str] = [
    "This is skin image of {}",
    "This is dermatology image of {}",
    "This is image of {}",
]

# Reference prompts: the concept-free baseline for contrastive scoring.
# One inner list per entry of PROMPT_TEMPLATES, in the same order.
PROMPT_REFS: List[List[str]] = [
    ["This is skin image"],
    ["This is dermatology image"],
    ["This is image"],
]
def get_transform(n_px=224):
    """Build the image preprocessing pipeline used before MONET encoding.

    Args:
        n_px: Size in pixels passed to both the bicubic resize and the
            center crop.

    Returns:
        A torchvision ``Compose`` that resizes, center-crops, converts the
        image to RGB, converts it to a tensor and normalizes each channel.
    """
    # NOTE(review): these are the widely used ImageNet mean/std values, while
    # OpenAI CLIP's own preprocessing normalizes with
    # (0.48145466, 0.4578275, 0.40821073) / (0.26862954, 0.26130258, 0.27577711).
    # Confirm against the MONET tutorial which statistics the model expects.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    def convert_image_to_rgb(image):
        return image.convert("RGB")

    steps = [
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(n_px),
        convert_image_to_rgb,
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)
class MonetTool:
    """
    MONET tool for extracting concept presence scores from skin lesion images.

    Each feature in MONET_FEATURES is scored by comparing the image embedding
    with "target" prompts (templates filled with concept terms) and
    "reference" prompts (the bare templates), then taking a softmax between
    the two. Uses the proper contrastive scoring method from the MONET paper.

    The model is loaded lazily: construction is cheap, and `load()` is either
    called explicitly or triggered by the first `extract_features()` call.
    """

    # Human-readable label for each MONET feature, used by describe_features().
    _FEATURE_DESCRIPTIONS = {
        "MONET_ulceration_crust": "ulceration or crusting",
        "MONET_hair": "visible hair",
        "MONET_vasculature_vessels": "visible blood vessels",
        "MONET_erythema": "erythema (redness)",
        "MONET_pigmented": "pigmentation",
        "MONET_gel_water_drop_fluid_dermoscopy_liquid": "dermoscopy gel/fluid",
        "MONET_skin_markings_pen_ink_purple_pen": "pen markings",
    }

    def __init__(self, device: Optional[str] = None, use_hf: bool = True):
        """
        Args:
            device: Device to run on (cuda, mps, cpu). None means auto-detect
                when the model is loaded. The SKINPRO_TOOL_DEVICE environment
                variable, when set, takes precedence over this argument.
            use_hf: Use HuggingFace implementation (True) or original CLIP (False)
        """
        self.model = None
        self.processor = None
        self.device = device
        self.use_hf = use_hf
        self.loaded = False
        self.transform = get_transform(224)
        # Cache for concept embeddings, keyed by MONET feature name
        self._concept_embeddings = {}

    def load(self):
        """Load MONET model and pre-compute the concept prompt embeddings.

        Returns immediately if a previous call already completed. `loaded`
        is set only after every step succeeded, so a failed attempt (for
        example a download error) can be retried by calling `load()` again.
        """
        if self.loaded:
            return

        self.device = self._resolve_device()
        self._load_model()
        # Pre-compute concept embeddings for all MONET features
        self._precompute_concept_embeddings()
        # Set the flag last: extract_features() treats `loaded` as meaning
        # the embedding cache is complete.
        self.loaded = True

    def _resolve_device(self) -> str:
        """Pick the device: env override, then explicit argument, then auto-detect."""
        # Respect SKINPRO_TOOL_DEVICE override for GPU sharing
        forced = os.environ.get("SKINPRO_TOOL_DEVICE")
        if forced:
            return forced
        if self.device is not None:
            return self.device
        if torch.cuda.is_available():
            return "cuda:0"
        # torch.backends.mps does not exist on every torch build.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _load_model(self):
        """Create the model (and HF processor) on `self.device` in eval mode."""
        if self.use_hf:
            # HuggingFace implementation
            from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

            processor = AutoProcessor.from_pretrained("chanwkim/monet")
            model = AutoModelForZeroShotImageClassification.from_pretrained("chanwkim/monet")
            model.to(self.device)
            self.processor = processor
        else:
            # Original CLIP implementation
            import clip

            model, _ = clip.load("ViT-L/14", device=self.device, jit=False)
            # Deserialize the checkpoint onto CPU regardless of the device it
            # was saved from; load_state_dict() then copies the values into
            # the model's own parameters on self.device.
            state_dict = torch.hub.load_state_dict_from_url(
                "https://aimslab.cs.washington.edu/MONET/weight_clip.pt",
                map_location="cpu",
            )
            model.load_state_dict(state_dict)
        model.eval()
        # Publish the model only once it is fully initialised.
        self.model = model

    def _encode_text(self, text_list: List[str]) -> torch.Tensor:
        """Encode text prompts to embeddings; the result is moved to CPU."""
        with torch.no_grad():
            if self.use_hf:
                inputs = self.processor(text=text_list, return_tensors="pt", padding=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                embeddings = self.model.get_text_features(**inputs)
            else:
                import clip

                tokens = clip.tokenize(text_list, truncate=True).to(self.device)
                embeddings = self.model.encode_text(tokens)
        return embeddings.cpu()

    def _encode_image(self, image: Image.Image) -> torch.Tensor:
        """Preprocess a single image and encode it; the result is moved to CPU."""
        # unsqueeze(0) adds the batch dimension expected by the encoders.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            if self.use_hf:
                embedding = self.model.get_image_features(pixel_values=image_tensor)
            else:
                embedding = self.model.encode_image(image_tensor)
        return embedding.cpu()

    def _precompute_concept_embeddings(self):
        """Pre-compute embeddings for all MONET concepts"""
        for feature_name, concept_terms in MONET_CONCEPT_TERMS.items():
            self._concept_embeddings[feature_name] = self._get_concept_embedding(concept_terms)

    def _get_concept_embedding(self, concept_terms: List[str]) -> Dict:
        """
        Generate prompt embeddings for a concept using multiple templates.

        Args:
            concept_terms: List of synonymous terms for the concept

        Returns:
            dict with L2-normalized embeddings:
                "target_embedding_norm": [num_templates, num_terms, embed_dim]
                "ref_embedding_norm": [num_templates, refs_per_template, embed_dim]
        """
        num_templates = len(PROMPT_TEMPLATES)
        num_terms = len(concept_terms)

        # Target prompts: "This is skin image of {term}", template-major order
        # so the flat encoder output reshapes to [template, term, dim].
        target_prompts = [
            template.format(term)
            for template in PROMPT_TEMPLATES
            for term in concept_terms
        ]
        target_embeddings = self._encode_text(target_prompts)
        embed_dim = target_embeddings.shape[-1]
        # reshape (not view) so a non-contiguous encoder output is also accepted
        target_embeddings = target_embeddings.reshape(num_templates, num_terms, embed_dim)
        target_embeddings_norm = F.normalize(target_embeddings, dim=2)

        # Reference prompts: "This is skin image"
        ref_prompts = [prompt for ref_list in PROMPT_REFS for prompt in ref_list]
        ref_embeddings = self._encode_text(ref_prompts)
        ref_embeddings = ref_embeddings.reshape(num_templates, -1, embed_dim)
        ref_embeddings_norm = F.normalize(ref_embeddings, dim=2)

        return {
            "target_embedding_norm": target_embeddings_norm,
            "ref_embedding_norm": ref_embeddings_norm,
        }

    def _calculate_concept_score(
        self,
        image_features_norm: torch.Tensor,
        concept_embedding: Dict,
        temp: float = 1 / np.exp(4.5944)
    ) -> float:
        """
        Calculate concept presence score using contrastive comparison.

        Args:
            image_features_norm: Normalized image embedding [1, embed_dim]
            concept_embedding: Dict with target and reference embeddings
                (as returned by `_get_concept_embedding`)
            temp: Temperature for softmax; similarities are divided by it

        Returns:
            Concept presence score (0-1)
        """
        target_emb = concept_embedding["target_embedding_norm"].float()
        ref_emb = concept_embedding["ref_embedding_norm"].float()
        image_column = image_features_norm.T.float()  # [embed_dim, 1]

        # Similarity: [num_templates, num_terms, embed_dim] @ [embed_dim, 1]
        #   -> [num_templates, num_terms, 1]
        target_similarity = target_emb @ image_column
        ref_similarity = ref_emb @ image_column

        # Mean over terms for each template
        target_mean = target_similarity.mean(dim=1).squeeze()  # [num_templates]
        ref_mean = ref_similarity.mean(dim=1).squeeze()  # [num_templates]

        # Softmax between target and reference (contrastive scoring):
        # row 0 holds the target probabilities, row 1 the reference ones.
        scores = scipy.special.softmax(
            np.array([target_mean.numpy() / temp, ref_mean.numpy() / temp]),
            axis=0
        )

        # Return mean of target scores across templates
        return float(scores[0].mean())

    def extract_features(self, image: Image.Image) -> Dict[str, float]:
        """
        Extract MONET feature scores from a skin lesion image.

        Loads the model first if `load()` has not completed yet.

        Args:
            image: PIL Image to analyze

        Returns:
            dict with 7 MONET feature scores (0-1 range)
        """
        if not self.loaded:
            self.load()

        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Get image embedding
        image_features = self._encode_image(image)
        image_features_norm = F.normalize(image_features, dim=1)

        # Calculate score for each MONET feature
        return {
            feature_name: self._calculate_concept_score(
                image_features_norm, self._concept_embeddings[feature_name]
            )
            for feature_name in MONET_FEATURES
        }

    def get_feature_vector(self, image: Image.Image) -> List[float]:
        """Get MONET features as a list in the expected order."""
        features = self.extract_features(image)
        return [features[f] for f in MONET_FEATURES]

    def get_feature_tensor(self, image: Image.Image) -> torch.Tensor:
        """Get MONET features as a PyTorch tensor."""
        return torch.tensor(self.get_feature_vector(image), dtype=torch.float32)

    def describe_features(self, features: Dict[str, float], threshold: float = 0.6) -> str:
        """Generate a natural language description of the MONET features.

        Args:
            features: Mapping of feature name to score
            threshold: Minimum score for a feature to be listed

        Returns:
            "Detected features: ..." listing every feature scoring at least
            `threshold`; if none does, the three highest-scoring features
            are listed instead. Unknown feature names are shown verbatim.
        """
        labels = self._FEATURE_DESCRIPTIONS
        present = [
            f"{labels.get(name, name)} ({score:.0%})"
            for name, score in features.items()
            if score >= threshold
        ]

        if not present:
            # Show top features even if below threshold
            top_features = sorted(features.items(), key=lambda x: x[1], reverse=True)[:3]
            present = [f"{labels.get(name, name)} ({score:.0%})" for name, score in top_features]

        return "Detected features: " + ", ".join(present)

    def analyze(self, image: Image.Image) -> Dict:
        """Full analysis returning features, vector, and description.

        Returns:
            dict with keys "features" (name -> score), "vector" (scores in
            MONET_FEATURES order), "description" (text summary) and
            "feature_names" (a copy of MONET_FEATURES).
        """
        features = self.extract_features(image)
        vector = [features[f] for f in MONET_FEATURES]
        description = self.describe_features(features)

        return {
            "features": features,
            "vector": vector,
            "description": description,
            # Copy so callers cannot mutate the module-level constant.
            "feature_names": list(MONET_FEATURES),
        }

    def __call__(self, image: Image.Image) -> Dict:
        """Shorthand for analyze()"""
        return self.analyze(image)
# Module-wide MonetTool instance, created on first request
_monet_instance = None


def get_monet_tool() -> MonetTool:
    """Return the shared MONET tool, constructing it on the first call.

    The instance is created with default arguments; this function does not
    call `load()` on it.
    """
    global _monet_instance
    if _monet_instance is not None:
        return _monet_instance
    _monet_instance = MonetTool()
    return _monet_instance
if __name__ == "__main__":
    import sys

    # Manual smoke test: load the model, then score the image whose path is
    # given as the first command-line argument (if any).
    print("MONET Tool Test (Correct Implementation)")
    print("=" * 50)

    tool = MonetTool(use_hf=True)
    print("Loading model...")
    tool.load()
    print("Model loaded!")

    cli_args = sys.argv[1:]
    if cli_args:
        image_path = cli_args[0]
        print(f"\nAnalyzing: {image_path}")

        image = Image.open(image_path).convert("RGB")
        result = tool.analyze(image)

        print("\nMONET Features (Contrastive Scores):")
        for name, score in result["features"].items():
            # Text bar: one block character per 0.05 of score (20 at 1.0)
            bar = "█" * int(score * 20)
            print(f" {name}: {score:.3f} {bar}")

        print(f"\nDescription: {result['description']}")
        formatted_vector = [f"{v:.3f}" for v in result["vector"]]
        print(f"\nVector: {formatted_vector}")