| | """ |
| | MONET Tool - Skin lesion feature extraction using MONET model |
| | Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb |
| | """ |
| |
|
| | import os |
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import scipy.special |
| | from PIL import Image |
| | from typing import Optional, Dict, List |
| | import torchvision.transforms as T |
| |
|
| |
|
| | |
# Canonical names of the seven MONET concepts this tool scores.
# Order matters: get_feature_vector() / analyze() emit scores in exactly
# this order, so downstream consumers index by position.
MONET_FEATURES = [
    "MONET_ulceration_crust",
    "MONET_hair",
    "MONET_vasculature_vessels",
    "MONET_erythema",
    "MONET_pigmented",
    "MONET_gel_water_drop_fluid_dermoscopy_liquid",
    "MONET_skin_markings_pen_ink_purple_pen",
]
| |
|
| | |
# Synonym terms per concept. Each term is substituted into every entry of
# PROMPT_TEMPLATES, and the resulting prompt embeddings are averaged in
# _calculate_concept_score(), making the score robust to phrasing.
# Keys must match MONET_FEATURES exactly.
MONET_CONCEPT_TERMS = {
    "MONET_ulceration_crust": ["ulceration", "crust", "crusting", "ulcer"],
    "MONET_hair": ["hair", "hairy"],
    "MONET_vasculature_vessels": ["blood vessels", "vasculature", "vessels", "telangiectasia"],
    "MONET_erythema": ["erythema", "redness", "red"],
    "MONET_pigmented": ["pigmented", "pigmentation", "melanin", "brown"],
    "MONET_gel_water_drop_fluid_dermoscopy_liquid": ["dermoscopy gel", "fluid", "water drop", "immersion fluid"],
    "MONET_skin_markings_pen_ink_purple_pen": ["pen marking", "ink", "surgical marking", "purple pen"],
}
| |
|
| | |
# Prompt templates for the *target* (concept-present) text prompts.
# "{}" is filled with each synonym from MONET_CONCEPT_TERMS.
PROMPT_TEMPLATES = [
    "This is skin image of {}",
    "This is dermatology image of {}",
    "This is image of {}",
]
| |
|
| | |
# Concept-free *reference* prompts used as the contrastive baseline.
# Invariant: one entry per template in PROMPT_TEMPLATES, in the same order,
# so each target prompt is contrasted against its own template's reference.
PROMPT_REFS = [
    ["This is skin image"],
    ["This is dermatology image"],
    ["This is image"],
]
| |
|
| |
|
def get_transform(n_px: int = 224):
    """Build the MONET image preprocessing pipeline.

    MONET is a fine-tuned CLIP ViT-L/14, so images must be normalized with
    CLIP's channel statistics. The previous version used ImageNet statistics
    (0.485/0.456/0.406, 0.229/0.224/0.225), which shifts every embedding and
    degrades concept scores; this uses the CLIP values instead.

    Args:
        n_px: Side length of the square model input (224 for ViT-L/14).

    Returns:
        A torchvision transform mapping a PIL image to a normalized float
        tensor of shape (3, n_px, n_px).
    """
    def convert_image_to_rgb(image):
        # PIL images may be palette/grayscale/RGBA; the model expects 3-channel RGB.
        return image.convert("RGB")

    return T.Compose([
        # Resize the short side to n_px, then center-crop to a square,
        # matching CLIP's reference preprocessing order.
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(n_px),
        convert_image_to_rgb,
        T.ToTensor(),
        # CLIP normalization constants (not ImageNet's).
        T.Normalize((0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711)),
    ])
| |
|
| |
|
class MonetTool:
    """
    MONET tool for extracting concept presence scores from skin lesion images.
    Uses the proper contrastive scoring method from the MONET paper.

    Typical use:
        tool = MonetTool()
        result = tool(pil_image)   # lazily loads the model on first call
    """

    def __init__(self, device: Optional[str] = None, use_hf: bool = True):
        """
        Args:
            device: Device to run on (cuda, mps, cpu)
            use_hf: Use HuggingFace implementation (True) or original CLIP (False)
        """
        # Backbone and (HF-only) processor are populated in load(), not here,
        # so constructing the tool stays cheap and network-free.
        self.model = None
        self.processor = None
        self.device = device
        self.use_hf = use_hf
        self.loaded = False
        self.transform = get_transform(224)

        # Cache: {feature_name: {"target_embedding_norm", "ref_embedding_norm"}}.
        # Filled by _precompute_concept_embeddings() after the model loads.
        self._concept_embeddings = {}

    def load(self) -> None:
        """Load MONET model (idempotent; safe to call multiple times)."""
        if self.loaded:
            return

        # Device resolution priority: SKINPRO_TOOL_DEVICE env var >
        # constructor argument > autodetect (cuda > mps > cpu).
        forced = os.environ.get("SKINPRO_TOOL_DEVICE")
        if forced:
            self.device = forced
        elif self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda:0"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        if self.use_hf:
            # HuggingFace path: MONET published as a CLIP-style checkpoint.
            # Import here so transformers is only required when this path is used.
            from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

            self.processor = AutoProcessor.from_pretrained("chanwkim/monet")
            self.model = AutoModelForZeroShotImageClassification.from_pretrained("chanwkim/monet")
            self.model.to(self.device)
            self.model.eval()
        else:
            # Original-CLIP path: load ViT-L/14, then overwrite its weights
            # with the MONET fine-tuned checkpoint.
            import clip

            self.model, _ = clip.load("ViT-L/14", device=self.device, jit=False)
            self.model.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    "https://aimslab.cs.washington.edu/MONET/weight_clip.pt"
                )
            )
            self.model.eval()

        self.loaded = True

        # Text prompts are image-independent, so embed them once up front.
        self._precompute_concept_embeddings()

    def _encode_text(self, text_list: List[str]) -> torch.Tensor:
        """Encode a list of prompts to text embeddings (returned on CPU)."""
        if self.use_hf:
            inputs = self.processor(text=text_list, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                embeddings = self.model.get_text_features(**inputs)
        else:
            import clip
            # truncate=True guards against prompts longer than CLIP's context.
            tokens = clip.tokenize(text_list, truncate=True).to(self.device)
            with torch.no_grad():
                embeddings = self.model.encode_text(tokens)

        # Move to CPU so cached embeddings don't pin accelerator memory.
        return embeddings.cpu()

    def _encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode a single PIL image to an embedding of shape [1, embed_dim] (CPU)."""
        # NOTE(review): both back-ends use self.transform here, bypassing the
        # HF processor's own image preprocessing — confirm this matches the
        # checkpoint's expected normalization.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        if self.use_hf:
            with torch.no_grad():
                embedding = self.model.get_image_features(image_tensor)
        else:
            with torch.no_grad():
                embedding = self.model.encode_image(image_tensor)

        return embedding.cpu()

    def _precompute_concept_embeddings(self) -> None:
        """Pre-compute prompt embeddings for all MONET concepts (requires loaded model)."""
        for feature_name, concept_terms in MONET_CONCEPT_TERMS.items():
            self._concept_embeddings[feature_name] = self._get_concept_embedding(concept_terms)

    def _get_concept_embedding(self, concept_terms: List[str]) -> Dict:
        """
        Generate prompt embeddings for a concept using multiple templates.

        Args:
            concept_terms: List of synonymous terms for the concept

        Returns:
            dict with target and reference embeddings, both L2-normalized:
            target [num_templates, num_terms, embed_dim],
            reference [num_templates, num_refs, embed_dim]
        """
        # One prompt per (template, synonym) pair.
        prompt_target = [
            [template.format(term) for term in concept_terms]
            for template in PROMPT_TEMPLATES
        ]

        # Flatten for a single batched text-encoder call.
        prompt_target_flat = [p for template_prompts in prompt_target for p in template_prompts]
        target_embeddings = self._encode_text(prompt_target_flat)

        # Restore the [templates, terms, dim] structure lost by flattening.
        num_templates = len(PROMPT_TEMPLATES)
        num_terms = len(concept_terms)
        embed_dim = target_embeddings.shape[-1]
        target_embeddings = target_embeddings.view(num_templates, num_terms, embed_dim)

        # L2-normalize so later matmuls against the normalized image embedding
        # yield cosine similarities.
        target_embeddings_norm = F.normalize(target_embeddings, dim=2)

        # Same treatment for the concept-free reference prompts
        # (PROMPT_REFS is aligned one-to-one with PROMPT_TEMPLATES).
        prompt_ref_flat = [p for ref_list in PROMPT_REFS for p in ref_list]
        ref_embeddings = self._encode_text(prompt_ref_flat)
        ref_embeddings = ref_embeddings.view(num_templates, -1, embed_dim)
        ref_embeddings_norm = F.normalize(ref_embeddings, dim=2)

        return {
            "target_embedding_norm": target_embeddings_norm,
            "ref_embedding_norm": ref_embeddings_norm,
        }

    def _calculate_concept_score(
        self,
        image_features_norm: torch.Tensor,
        concept_embedding: Dict,
        temp: float = 1 / np.exp(4.5944)
    ) -> float:
        """
        Calculate concept presence score using contrastive comparison.

        Args:
            image_features_norm: Normalized image embedding [1, embed_dim]
            concept_embedding: Dict with target and reference embeddings
            temp: Temperature for softmax. The default 1/exp(4.5944)
                  presumably mirrors the model's learned logit scale from the
                  MONET reference implementation — confirm against the tutorial.

        Returns:
            Concept presence score (0-1)
        """
        target_emb = concept_embedding["target_embedding_norm"].float()
        ref_emb = concept_embedding["ref_embedding_norm"].float()

        # Cosine similarities (both operands are pre-normalized):
        # target -> [templates, terms, 1], ref -> [templates, refs, 1].
        target_similarity = target_emb @ image_features_norm.T.float()
        ref_similarity = ref_emb @ image_features_norm.T.float()

        # Average across synonyms / reference prompts within each template.
        target_mean = target_similarity.mean(dim=1).squeeze()
        ref_mean = ref_similarity.mean(dim=1).squeeze()

        # Two-way softmax per template: P(concept) vs P(plain image).
        scores = scipy.special.softmax(
            np.array([target_mean.numpy() / temp, ref_mean.numpy() / temp]),
            axis=0
        )

        # Row 0 is the concept probability; average over templates.
        return float(scores[0].mean())

    def extract_features(self, image: Image.Image) -> Dict[str, float]:
        """
        Extract MONET feature scores from a skin lesion image.

        Args:
            image: PIL Image to analyze

        Returns:
            dict with 7 MONET feature scores (0-1 range)
        """
        # Lazy-load so constructing the tool never triggers a model download.
        if not self.loaded:
            self.load()

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode the image once; every concept score reuses this embedding.
        image_features = self._encode_image(image)
        image_features_norm = F.normalize(image_features, dim=1)

        features = {}
        for feature_name in MONET_FEATURES:
            concept_emb = self._concept_embeddings[feature_name]
            score = self._calculate_concept_score(image_features_norm, concept_emb)
            features[feature_name] = score

        return features

    def get_feature_vector(self, image: Image.Image) -> List[float]:
        """Get MONET features as a list in the expected order (MONET_FEATURES)."""
        features = self.extract_features(image)
        return [features[f] for f in MONET_FEATURES]

    def get_feature_tensor(self, image: Image.Image) -> torch.Tensor:
        """Get MONET features as a PyTorch tensor of shape [7]."""
        return torch.tensor(self.get_feature_vector(image), dtype=torch.float32)

    def describe_features(self, features: Dict[str, float], threshold: float = 0.6) -> str:
        """Generate a natural language description of the MONET features.

        Features scoring >= threshold are listed; if none pass, the top three
        by score are reported instead so the output is never empty.
        """
        descriptions = {
            "MONET_ulceration_crust": "ulceration or crusting",
            "MONET_hair": "visible hair",
            "MONET_vasculature_vessels": "visible blood vessels",
            "MONET_erythema": "erythema (redness)",
            "MONET_pigmented": "pigmentation",
            "MONET_gel_water_drop_fluid_dermoscopy_liquid": "dermoscopy gel/fluid",
            "MONET_skin_markings_pen_ink_purple_pen": "pen markings",
        }

        present = []
        for feature, score in features.items():
            if score >= threshold:
                desc = descriptions.get(feature, feature)
                present.append(f"{desc} ({score:.0%})")

        if not present:
            # Fallback: report the three highest-scoring features.
            sorted_features = sorted(features.items(), key=lambda x: x[1], reverse=True)[:3]
            present = [f"{descriptions.get(f, f)} ({s:.0%})" for f, s in sorted_features]

        return "Detected features: " + ", ".join(present)

    def analyze(self, image: Image.Image) -> Dict:
        """Full analysis returning features, vector, and description."""
        features = self.extract_features(image)
        vector = [features[f] for f in MONET_FEATURES]
        description = self.describe_features(features)

        return {
            "features": features,
            "vector": vector,
            "description": description,
            "feature_names": MONET_FEATURES,
        }

    def __call__(self, image: Image.Image) -> Dict:
        """Shorthand for analyze()"""
        return self.analyze(image)
| |
|
| |
|
| | |
# Process-wide MonetTool singleton, created lazily by get_monet_tool().
_monet_instance = None


def get_monet_tool() -> MonetTool:
    """Return the shared MonetTool instance, constructing it on first call."""
    global _monet_instance
    if _monet_instance is not None:
        return _monet_instance
    _monet_instance = MonetTool()
    return _monet_instance
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # CLI smoke test: load the model, then (optionally) score one image
    # whose path is given as the first command-line argument.
    print("MONET Tool Test (Correct Implementation)")
    print("=" * 50)

    tool = MonetTool(use_hf=True)
    print("Loading model...")
    tool.load()
    print("Model loaded!")

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        print(f"\nAnalyzing: {image_path}")
        image = Image.open(image_path).convert("RGB")
        result = tool.analyze(image)

        print("\nMONET Features (Contrastive Scores):")
        for name, score in result["features"].items():
            # Render each score as a 0-20 character bar for quick reading.
            bar = "█" * int(score * 20)
            print(f"  {name}: {score:.3f} {bar}")

        print(f"\nDescription: {result['description']}")
        print(f"\nVector: {[f'{v:.3f}' for v in result['vector']]}")
| |
|