| """ |
| MONET Tool - Skin lesion feature extraction using MONET model |
| Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import scipy.special |
| from PIL import Image |
| from typing import Optional, Dict, List |
| import torchvision.transforms as T |
|
|
|
|
| |
| MONET_FEATURES = [ |
| "MONET_ulceration_crust", |
| "MONET_hair", |
| "MONET_vasculature_vessels", |
| "MONET_erythema", |
| "MONET_pigmented", |
| "MONET_gel_water_drop_fluid_dermoscopy_liquid", |
| "MONET_skin_markings_pen_ink_purple_pen", |
| ] |
|
|
| |
| MONET_CONCEPT_TERMS = { |
| "MONET_ulceration_crust": ["ulceration", "crust", "crusting", "ulcer"], |
| "MONET_hair": ["hair", "hairy"], |
| "MONET_vasculature_vessels": ["blood vessels", "vasculature", "vessels", "telangiectasia"], |
| "MONET_erythema": ["erythema", "redness", "red"], |
| "MONET_pigmented": ["pigmented", "pigmentation", "melanin", "brown"], |
| "MONET_gel_water_drop_fluid_dermoscopy_liquid": ["dermoscopy gel", "fluid", "water drop", "immersion fluid"], |
| "MONET_skin_markings_pen_ink_purple_pen": ["pen marking", "ink", "surgical marking", "purple pen"], |
| } |
|
|
| |
| PROMPT_TEMPLATES = [ |
| "This is skin image of {}", |
| "This is dermatology image of {}", |
| "This is image of {}", |
| ] |
|
|
| |
| PROMPT_REFS = [ |
| ["This is skin image"], |
| ["This is dermatology image"], |
| ["This is image"], |
| ] |
|
|
|
|
| def get_transform(n_px=224): |
| """Get MONET preprocessing transform""" |
| def convert_image_to_rgb(image): |
| return image.convert("RGB") |
|
|
| return T.Compose([ |
| T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC), |
| T.CenterCrop(n_px), |
| convert_image_to_rgb, |
| T.ToTensor(), |
| T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), |
| ]) |
|
|
|
|
| class MonetTool: |
| """ |
| MONET tool for extracting concept presence scores from skin lesion images. |
| Uses the proper contrastive scoring method from the MONET paper. |
| """ |
|
|
| def __init__(self, device: Optional[str] = None, use_hf: bool = True): |
| """ |
| Args: |
| device: Device to run on (cuda, mps, cpu) |
| use_hf: Use HuggingFace implementation (True) or original CLIP (False) |
| """ |
| self.model = None |
| self.processor = None |
| self.device = device |
| self.use_hf = use_hf |
| self.loaded = False |
| self.transform = get_transform(224) |
|
|
| |
| self._concept_embeddings = {} |
|
|
| def load(self): |
| """Load MONET model""" |
| if self.loaded: |
| return |
|
|
| |
| if self.device is None: |
| if torch.cuda.is_available(): |
| self.device = "cuda:0" |
| elif torch.backends.mps.is_available(): |
| self.device = "mps" |
| else: |
| self.device = "cpu" |
|
|
| if self.use_hf: |
| |
| from transformers import AutoProcessor, AutoModelForZeroShotImageClassification |
|
|
| self.processor = AutoProcessor.from_pretrained("chanwkim/monet") |
| self.model = AutoModelForZeroShotImageClassification.from_pretrained("chanwkim/monet") |
| self.model.to(self.device) |
| self.model.eval() |
| else: |
| |
| import clip |
|
|
| self.model, _ = clip.load("ViT-L/14", device=self.device, jit=False) |
| self.model.load_state_dict( |
| torch.hub.load_state_dict_from_url( |
| "https://aimslab.cs.washington.edu/MONET/weight_clip.pt" |
| ) |
| ) |
| self.model.eval() |
|
|
| self.loaded = True |
|
|
| |
| self._precompute_concept_embeddings() |
|
|
| def _encode_text(self, text_list: List[str]) -> torch.Tensor: |
| """Encode text to embeddings""" |
| if self.use_hf: |
| inputs = self.processor(text=text_list, return_tensors="pt", padding=True) |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| with torch.no_grad(): |
| embeddings = self.model.get_text_features(**inputs) |
| else: |
| import clip |
| tokens = clip.tokenize(text_list, truncate=True).to(self.device) |
| with torch.no_grad(): |
| embeddings = self.model.encode_text(tokens) |
|
|
| return embeddings.cpu() |
|
|
| def _encode_image(self, image: Image.Image) -> torch.Tensor: |
| """Encode image to embedding""" |
| image_tensor = self.transform(image).unsqueeze(0).to(self.device) |
|
|
| if self.use_hf: |
| with torch.no_grad(): |
| embedding = self.model.get_image_features(image_tensor) |
| else: |
| with torch.no_grad(): |
| embedding = self.model.encode_image(image_tensor) |
|
|
| return embedding.cpu() |
|
|
| def _precompute_concept_embeddings(self): |
| """Pre-compute embeddings for all MONET concepts""" |
| for feature_name, concept_terms in MONET_CONCEPT_TERMS.items(): |
| self._concept_embeddings[feature_name] = self._get_concept_embedding(concept_terms) |
|
|
| def _get_concept_embedding(self, concept_terms: List[str]) -> Dict: |
| """ |
| Generate prompt embeddings for a concept using multiple templates. |
| |
| Args: |
| concept_terms: List of synonymous terms for the concept |
| |
| Returns: |
| dict with target and reference embeddings |
| """ |
| |
| prompt_target = [ |
| [template.format(term) for term in concept_terms] |
| for template in PROMPT_TEMPLATES |
| ] |
|
|
| |
| prompt_target_flat = [p for template_prompts in prompt_target for p in template_prompts] |
| target_embeddings = self._encode_text(prompt_target_flat) |
|
|
| |
| num_templates = len(PROMPT_TEMPLATES) |
| num_terms = len(concept_terms) |
| embed_dim = target_embeddings.shape[-1] |
| target_embeddings = target_embeddings.view(num_templates, num_terms, embed_dim) |
|
|
| |
| target_embeddings_norm = F.normalize(target_embeddings, dim=2) |
|
|
| |
| prompt_ref_flat = [p for ref_list in PROMPT_REFS for p in ref_list] |
| ref_embeddings = self._encode_text(prompt_ref_flat) |
| ref_embeddings = ref_embeddings.view(num_templates, -1, embed_dim) |
| ref_embeddings_norm = F.normalize(ref_embeddings, dim=2) |
|
|
| return { |
| "target_embedding_norm": target_embeddings_norm, |
| "ref_embedding_norm": ref_embeddings_norm, |
| } |
|
|
| def _calculate_concept_score( |
| self, |
| image_features_norm: torch.Tensor, |
| concept_embedding: Dict, |
| temp: float = 1 / np.exp(4.5944) |
| ) -> float: |
| """ |
| Calculate concept presence score using contrastive comparison. |
| |
| Args: |
| image_features_norm: Normalized image embedding [1, embed_dim] |
| concept_embedding: Dict with target and reference embeddings |
| temp: Temperature for softmax |
| |
| Returns: |
| Concept presence score (0-1) |
| """ |
| target_emb = concept_embedding["target_embedding_norm"].float() |
| ref_emb = concept_embedding["ref_embedding_norm"].float() |
|
|
| |
| target_similarity = target_emb @ image_features_norm.T.float() |
| ref_similarity = ref_emb @ image_features_norm.T.float() |
|
|
| |
| target_mean = target_similarity.mean(dim=1).squeeze() |
| ref_mean = ref_similarity.mean(dim=1).squeeze() |
|
|
| |
| scores = scipy.special.softmax( |
| np.array([target_mean.numpy() / temp, ref_mean.numpy() / temp]), |
| axis=0 |
| ) |
|
|
| |
| return float(scores[0].mean()) |
|
|
| def extract_features(self, image: Image.Image) -> Dict[str, float]: |
| """ |
| Extract MONET feature scores from a skin lesion image. |
| |
| Args: |
| image: PIL Image to analyze |
| |
| Returns: |
| dict with 7 MONET feature scores (0-1 range) |
| """ |
| if not self.loaded: |
| self.load() |
|
|
| |
| if image.mode != "RGB": |
| image = image.convert("RGB") |
|
|
| |
| image_features = self._encode_image(image) |
| image_features_norm = F.normalize(image_features, dim=1) |
|
|
| |
| features = {} |
| for feature_name in MONET_FEATURES: |
| concept_emb = self._concept_embeddings[feature_name] |
| score = self._calculate_concept_score(image_features_norm, concept_emb) |
| features[feature_name] = score |
|
|
| return features |
|
|
| def get_feature_vector(self, image: Image.Image) -> List[float]: |
| """Get MONET features as a list in the expected order.""" |
| features = self.extract_features(image) |
| return [features[f] for f in MONET_FEATURES] |
|
|
| def get_feature_tensor(self, image: Image.Image) -> torch.Tensor: |
| """Get MONET features as a PyTorch tensor.""" |
| return torch.tensor(self.get_feature_vector(image), dtype=torch.float32) |
|
|
| def describe_features(self, features: Dict[str, float], threshold: float = 0.6) -> str: |
| """Generate a natural language description of the MONET features.""" |
| descriptions = { |
| "MONET_ulceration_crust": "ulceration or crusting", |
| "MONET_hair": "visible hair", |
| "MONET_vasculature_vessels": "visible blood vessels", |
| "MONET_erythema": "erythema (redness)", |
| "MONET_pigmented": "pigmentation", |
| "MONET_gel_water_drop_fluid_dermoscopy_liquid": "dermoscopy gel/fluid", |
| "MONET_skin_markings_pen_ink_purple_pen": "pen markings", |
| } |
|
|
| present = [] |
| for feature, score in features.items(): |
| if score >= threshold: |
| desc = descriptions.get(feature, feature) |
| present.append(f"{desc} ({score:.0%})") |
|
|
| if not present: |
| |
| sorted_features = sorted(features.items(), key=lambda x: x[1], reverse=True)[:3] |
| present = [f"{descriptions.get(f, f)} ({s:.0%})" for f, s in sorted_features] |
|
|
| return "Detected features: " + ", ".join(present) |
|
|
| def analyze(self, image: Image.Image) -> Dict: |
| """Full analysis returning features, vector, and description.""" |
| features = self.extract_features(image) |
| vector = [features[f] for f in MONET_FEATURES] |
| description = self.describe_features(features) |
|
|
| return { |
| "features": features, |
| "vector": vector, |
| "description": description, |
| "feature_names": MONET_FEATURES, |
| } |
|
|
| def __call__(self, image: Image.Image) -> Dict: |
| """Shorthand for analyze()""" |
| return self.analyze(image) |
|
|
|
|
| |
| _monet_instance = None |
|
|
|
|
| def get_monet_tool() -> MonetTool: |
| """Get or create MONET tool instance""" |
| global _monet_instance |
| if _monet_instance is None: |
| _monet_instance = MonetTool() |
| return _monet_instance |
|
|
|
|
| if __name__ == "__main__": |
| import sys |
|
|
| print("MONET Tool Test (Correct Implementation)") |
| print("=" * 50) |
|
|
| tool = MonetTool(use_hf=True) |
| print("Loading model...") |
| tool.load() |
| print("Model loaded!") |
|
|
| if len(sys.argv) > 1: |
| image_path = sys.argv[1] |
| print(f"\nAnalyzing: {image_path}") |
| image = Image.open(image_path).convert("RGB") |
| result = tool.analyze(image) |
|
|
| print("\nMONET Features (Contrastive Scores):") |
| for name, score in result["features"].items(): |
| bar = "█" * int(score * 20) |
| print(f" {name}: {score:.3f} {bar}") |
|
|
| print(f"\nDescription: {result['description']}") |
| print(f"\nVector: {[f'{v:.3f}' for v in result['vector']]}") |
|
|