Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import gradio as gr | |
| import io | |
| import base64 | |
| import torch.nn.functional as F | |
| import warnings | |
| import os | |
# Startup: install a catch-all "ignore" filter so Python warnings do not
# clutter the logs, then announce that the app is booting.
warnings.simplefilter("ignore")
print("🔍 Starting Skin Lesion Analysis System...")
# --- MODEL CONFIGURATIONS ---
# Candidate checkpoints, grouped by category ("specialized" / "general").
# The loading loop writes the group name into each entry as config['category'].
#
# Per-entry keys:
#   name        - display name; also the key used in the loaded-model registry
#   id          - model id passed to from_pretrained
#   type        - 'vit' additionally enables the ViTImageProcessor /
#                 ViTForImageClassification loading fallback; 'custom' does not
#   accuracy    - hard-coded figure copied into model_performance; nothing in
#                 this file measures it.
#                 NOTE(review): source of these numbers unknown — confirm
#                 before presenting them as verified.
#   description - free-text note
#   emoji       - prefix used in log lines and in per-model report entries
MODEL_CONFIGS = {
    "specialized": [
        {
            'name': 'Syaha Skin Cancer',
            'id': 'syaha/skin_cancer_detection_model',
            'type': 'custom',
            'accuracy': 0.82,
            'description': 'CNN trained on HAM10000 dataset',
            'emoji': '🩺'
        },
        {
            'name': 'VRJBro Skin Detection',
            'id': 'VRJBro/skin-cancer-detection',
            'type': 'custom',
            'accuracy': 0.85,
            'description': 'Specialized detector (2024)',
            'emoji': '🎯'
        },
        {
            'name': 'Anwarkh1 Skin Cancer',
            'id': 'Anwarkh1/Skin_Cancer-Image_Classification',
            'type': 'vit',
            'accuracy': 0.89,
            'description': 'Multi-class skin lesion classifier',
            'emoji': '🧠'
        },
        {
            'name': 'Jhoppanne SMOTE',
            'id': 'jhoppanne/SkinCancerClassifier_smote-V0',
            'type': 'custom',
            'accuracy': 0.86,
            'description': 'ISIC 2024 model using SMOTE for class imbalance',
            'emoji': '⚖️'
        },
        {
            'name': 'ViT ISIC Binary',
            'id': 'ahishamm/vit-base-binary-isic-sharpened-patch-32',
            'type': 'vit',
            'accuracy': 0.89,
            'description': 'ViT model for binary ISIC lesion classification (benign/malignant)',
            'emoji': '🔬'
        },
        {
            'name': 'ViT ISIC Multi-class',
            'id': 'ahishamm/vit-base-isic-patch-16',
            'type': 'vit',
            'accuracy': 0.79,
            'description': 'ViT model for multi-class ISIC lesion classification',
            'emoji': '🔍'
        }
    ],
    # NOTE(review): these look like general-purpose (non-dermatology)
    # classifiers. predict_with_model only has a lesion-class mapping for
    # 7-way and 2-way outputs — confirm these models are meant to take part
    # in the consensus at all.
    "general": [
        {
            'name': 'ViT Base General',
            'id': 'google/vit-base-patch16-224',
            'type': 'vit',
            'accuracy': 0.78,
            'description': 'ViT base pre-trained on ImageNet-1k.',
            'emoji': '📈'
        },
        {
            'name': 'ResNet-50 (Microsoft)',
            'id': 'microsoft/resnet-50',
            'type': 'custom',
            'accuracy': 0.77,
            'description': 'Classic ResNet-50, robust and high-performing.',
            'emoji': '⚙️'
        },
        {
            'name': 'DeiT Base (Facebook)',
            'id': 'facebook/deit-base-patch16-224',
            'type': 'vit',
            'accuracy': 0.79,
            'description': 'Data-efficient Image Transformer, efficient and accurate.',
            'emoji': '💡'
        },
        {
            'name': 'MobileNetV2 (Google)',
            'id': 'google/mobilenet_v2_1.0_224',
            'type': 'custom',
            'accuracy': 0.72,
            'description': 'Lightweight model for mobile or low-resource environments.',
            'emoji': '📱'
        },
        {
            'name': 'Swin Tiny (Microsoft)',
            'id': 'microsoft/swin-tiny-patch4-window7-224',
            'type': 'custom',
            'accuracy': 0.81,
            'description': 'Swin Transformer (Tiny), efficient and powerful.',
            'emoji': '🌀'
        },
        # NOTE(review): the -in21k checkpoint may not ship a trained
        # classification head (it would then be randomly initialised) — confirm.
        {
            'name': 'ViT Base General (Fallback)',
            'id': 'google/vit-base-patch16-224-in21k',
            'type': 'vit',
            'accuracy': 0.75,
            'description': 'Generic ViT fallback model',
            'emoji': '🔄'
        }
    ]
}
# --- SAFE MODEL LOADING ---
# Registries filled by the loading loop: display name -> model record, and
# display name -> the hard-coded 'accuracy' value from its config.
loaded_models = {}
model_performance = {}


def _load_classifier(model_id, processor_cls, model_cls, **kwargs):
    """Load an image processor and classification model for *model_id*.

    Extra keyword arguments (e.g. ``revision``) are forwarded to both
    ``from_pretrained`` calls.

    Raises:
        ValueError: if the checkpoint lacks the classification-head weights.
            transformers would initialise that head randomly, which makes
            every prediction from the model meaningless.
    """
    processor = processor_cls.from_pretrained(model_id, **kwargs)
    model, loading_info = model_cls.from_pretrained(
        model_id, output_loading_info=True, **kwargs
    )
    missing_head = [
        key for key in loading_info.get('missing_keys', [])
        if 'classifier' in key or key.startswith('head')
    ]
    if missing_head:
        raise ValueError(
            "classification head missing from checkpoint "
            f"(would be randomly initialized: {missing_head[:4]})"
        )
    return processor, model


def load_model_safe(config):
    """Load and smoke-test the model described by *config*.

    Tries the Auto* classes on the "main" and "no_float16_weights" revisions,
    then, for ``config['type'] == 'vit'``, the explicit ViT classes. The first
    successful load is put in eval mode and run once on a blank 224x224 image.
    (``revision=None`` resolves to the default "main" branch, so it is not
    tried separately.)

    Args:
        config: dict with 'id', 'type', 'name' and 'emoji'; optional
            'category' (defaults to 'general').

    Returns:
        ``{'processor', 'model', 'config', 'category'}`` on success, or None
        when every attempt or the smoke test fails. Errors are printed, not
        raised.
    """
    try:
        model_id = config['id']
        print(f"🔄 Loading {config['emoji']} {config['name']}...")

        attempts = [
            (f"revision '{rev}'", AutoImageProcessor,
             AutoModelForImageClassification, {'revision': rev})
            for rev in ("main", "no_float16_weights")
        ]
        if config['type'] == 'vit':
            attempts.append(("ViTImageProcessor/ViTForImageClassification",
                             ViTImageProcessor, ViTForImageClassification, {}))

        last_error = None
        for label, processor_cls, model_cls, kwargs in attempts:
            print(f"   Trying {label}")
            try:
                processor, model = _load_classifier(
                    model_id, processor_cls, model_cls, **kwargs
                )
                break
            except Exception as err:
                print(f"   Failed with {label}: {err}")
                last_error = err
        else:
            raise RuntimeError(
                "Failed to load model with all revisions."
            ) from last_error

        model.eval()
        # Warm-up pass: proves processor and model actually work together.
        test_input = processor(Image.new('RGB', (224, 224), color='white'),
                               return_tensors="pt")
        with torch.no_grad():
            model(**test_input)
        print(f"✅ {config['emoji']} {config['name']} loaded successfully")
        return {
            'processor': processor,
            'model': model,
            'config': config,
            'category': config.get('category', 'general')
        }
    except Exception as e:
        print(f"❌ {config.get('emoji', '')} {config.get('name', config.get('id'))} failed: {e}")
        return None
print("\n📦 Loading models...")
for category, configs in MODEL_CONFIGS.items():
    for config in configs:
        # load_model_safe tags the returned record with config['category'].
        config['category'] = category
        model_data = load_model_safe(config)
        if model_data:
            loaded_models[config['name']] = model_data
            model_performance[config['name']] = config.get('accuracy', 0.8)

if not loaded_models:
    print("❌ No model could be loaded. Using fallback models...")
    # Fallbacks go through the same loader as the main models, so they get the
    # same loading fallbacks and the warm-up forward pass before registration.
    fallback_configs = [
        {'id': 'google/vit-base-patch16-224-in21k', 'type': 'vit'},
        {'id': 'microsoft/resnet-50', 'type': 'custom'},
    ]
    for fallback in fallback_configs:
        fallback_id = fallback['id']
        print(f"🔄 Trying fallback: {fallback_id}")
        model_data = load_model_safe({
            **fallback,
            'name': f'Fallback {fallback_id}',
            'emoji': '🏥',
            'category': 'general',
        })
        if model_data:
            loaded_models[f'Fallback-{fallback_id.split("/")[-1]}'] = model_data
            print(f"✅ Fallback model {fallback_id} loaded")
            break
        print(f"❌ Fallback {fallback_id} failed")

if not loaded_models:
    # The UI still launches, but there is nothing to run predictions with.
    print("⚠️ No models could be loaded; analyses will return no predictions.")
# --- SKIN LESION CLASSES ---
# The seven lesion categories; list position is the class index used everywhere else.
CLASSES = [
    "Actinic Keratosis / Bowen (AKIEC)",
    "Basal Cell Carcinoma (BCC)",
    "Benign Keratosis (BKL)",
    "Dermatofibroma (DF)",
    "Malignant Melanoma (MEL)",
    "Melanocytic Nevus (NV)",
    "Vascular Lesion (VASC)",
]

# Triage row per class, in CLASSES order: (risk level, display colour, recommendation).
_RISK_ROWS = (
    ('High', '#ff6b35', 'Referral in 48h'),
    ('Critical', '#cc0000', 'Immediate referral'),
    ('Low', '#44ff44', 'Routine check'),
    ('Low', '#44ff44', 'Routine check'),
    ('Critical', '#990000', 'URGENT - Oncology'),
    ('Low', '#66ff66', 'Follow-up in 6 months'),
    ('Moderate', '#ffaa00', 'Check-up in 3 months'),
)
RISK_LEVELS = {
    position: {'level': level, 'color': color, 'urgency': urgency}
    for position, (level, color, urgency) in enumerate(_RISK_ROWS)
}

# Class indices flagged as malignant: AKIEC, BCC, MEL.
MALIGNANT_INDICES = [0, 1, 4]
| # --- PREDICTION FUNCTION --- | |
| def predict_with_model(image, model_data): | |
| try: | |
| config = model_data['config'] | |
| image_resized = image.resize((224, 224), Image.LANCZOS) | |
| processor = model_data['processor'] | |
| model = model_data['model'] | |
| inputs = processor(image_resized, return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0] | |
| probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0] | |
| # Handling models with unexpected output dimensions | |
| if len(probabilities) == 7: | |
| mapped_probs = probabilities | |
| elif len(probabilities) == 2: | |
| mapped_probs = np.zeros(7) | |
| mapped_probs[4] = probabilities[1] * 0.5 | |
| mapped_probs[1] = probabilities[1] * 0.3 | |
| mapped_probs[0] = probabilities[1] * 0.2 | |
| mapped_probs[5] = probabilities[0] * 0.6 | |
| mapped_probs[2] = probabilities[0] * 0.2 | |
| mapped_probs[3] = probabilities[0] * 0.1 | |
| mapped_probs[6] = probabilities[0] * 0.1 | |
| mapped_probs /= np.sum(mapped_probs) | |
| else: | |
| mapped_probs = np.ones(7) / 7 | |
| predicted_idx = int(np.argmax(mapped_probs)) | |
| confidence = float(mapped_probs[predicted_idx]) | |
| return { | |
| 'model': f"{config['emoji']} {config['name']}", | |
| 'class': CLASSES[predicted_idx], | |
| 'confidence': confidence, | |
| 'probabilities': mapped_probs, | |
| 'is_malignant': predicted_idx in MALIGNANT_INDICES, | |
| 'predicted_idx': predicted_idx, | |
| 'success': True, | |
| 'category': model_data['category'] | |
| } | |
| except Exception as e: | |
| print(f"❌ Error in {config['name']}: {e}") | |
| return {'model': config['name'], 'success': False, 'error': str(e)} | |
| # --- CONSENSUS ANALYSIS FUNCTION --- | |
| def analyze_lesion(img): | |
| if img is None: | |
| return "<h3>⚠️ Please upload an image</h3>" | |
| predictions = [] | |
| for model_name, model_data in loaded_models.items(): | |
| if model_data.get('category') != 'dummy': | |
| pred = predict_with_model(img, model_data) | |
| if pred.get('success'): | |
| predictions.append(pred) | |
| if not predictions: | |
| return "<h3>❌ No valid predictions</h3>" | |
| class_votes, confidence_sum = {}, {} | |
| for pred in predictions: | |
| c = pred['class'] | |
| conf = pred['confidence'] | |
| class_votes[c] = class_votes.get(c, 0) + 1 | |
| confidence_sum[c] = confidence_sum.get(c, 0) + conf | |
| consensus_class = max(class_votes, key=class_votes.get) | |
| avg_conf = confidence_sum[consensus_class] / class_votes[consensus_class] | |
| consensus_idx = CLASSES.index(consensus_class) | |
| risk_info = RISK_LEVELS[consensus_idx] | |
| return f""" | |
| <h2>🏥 Skin Lesion Analysis Report</h2> | |
| <h3>Consensus Diagnosis: {consensus_class}</h3> | |
| <p>Average Confidence: <b>{avg_conf:.1%}</b></p> | |
| <p>Risk Level: <b style='color:{risk_info['color']}'>{risk_info['level']}</b></p> | |
| <p>Recommendation: {risk_info['urgency']}</p> | |
| <hr> | |
| <h4>Model Details:</h4> | |
| {''.join([f"<p>{p['model']}: {p['class']} ({p['confidence']:.1%})</p>" for p in predictions])} | |
| <hr> | |
| <p style='color:gray;'>⚠️ This AI tool is for educational and research purposes only. Always consult a dermatologist for accurate medical diagnosis.</p> | |
| """ | |
# --- GRADIO INTERFACE ---
# Header HTML shown above the upload widget.
APP_DESCRIPTION = """
<h2 style="text-align:center;">🩺 AI-Powered Skin Lesion Analyzer 🩺</h2>
<p style="text-align:center;">Upload a clear skin lesion image. The system runs several deep learning models (both skin-specialized and general vision models) and provides a consensus diagnosis with confidence and risk level.</p>
<p style="text-align:center; color:gray;">⚠️ Research prototype only. Not a substitute for professional medical advice.</p>
"""

# One PIL image in, one HTML report out; bound to `demo` so tooling can find it.
demo = gr.Interface(
    fn=analyze_lesion,
    inputs=gr.Image(type="pil", label="Upload a Skin Lesion Image"),
    outputs=gr.HTML(label="AI Analysis Report"),
    title="Skin Lesion Analysis AI",
    description=APP_DESCRIPTION,
    theme="soft",
)
demo.launch(debug=True)