Spaces:
Running
Running
| import torch | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageClassification, AutoConfig | |
| import requests | |
| from io import BytesIO | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if any) into os.environ.
load_dotenv()
# Hub access token from the environment / .env; None when unset.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Network settings for Hub downloads.
os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '300'  # 5 minutes
os.environ['TRANSFORMERS_OFFLINE'] = '0'  # request online mode
# NOTE: huggingface_hub / transformers read these variables at import time, and
# both were imported above, so the env vars alone only affect code that
# re-reads the environment later. Push the timeout into the already-loaded
# huggingface_hub constants as well (best effort: this is an internal module
# and may be absent in some versions).
try:
    from huggingface_hub import constants as _hf_constants
except ImportError:
    pass
else:
    _hf_constants.HF_HUB_DOWNLOAD_TIMEOUT = 300
class SkinDiseaseClassifier:
    """Classify a skin image into one of ``CLASS_NAMES`` with a fine-tuned ViT.

    On construction, the checkpoint ``healthy.pth`` is downloaded from the
    Hugging Face Hub repo ``repo_id``. The model is placed on CUDA (in fp16)
    when available, otherwise on CPU.
    """

    # Output labels, indexed by the model's predicted class index.
    # NOTE(review): assumed to match the training label order and to have one
    # entry per output of the checkpoint's classifier head -- confirm.
    CLASS_NAMES = [
        "Acne", "Basal Cell Carcinoma", "Benign Keratosis-like Lesions", "Chickenpox", "Eczema", "Healthy Skin",
        "Measles", "Melanocytic Nevi", "Melanoma", "Monkeypox", "Psoriasis Lichen Planus and related diseases",
        "Seborrheic Keratoses and other Benign Tumors", "Tinea Ringworm Candidiasis and other Fungal Infections",
        "Vitiligo", "Warts Molluscum and other Viral Infections",
    ]

    def __init__(self, repo_id="muhammadnoman76/skin-disease-classifier", cache_dir=None):
        """Load the model and preprocessing pipeline.

        Args:
            repo_id: Hub repository containing ``healthy.pth``.
            cache_dir: Optional directory for Hub downloads. ``None`` uses the
                huggingface_hub default cache.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.repo_id = repo_id
        self.cache_dir = cache_dir  # must be set before load_trained_model()
        self.model = self.load_trained_model()
        self.transform = self.get_inference_transform()

    def load_trained_model(self):
        """Download the checkpoint and rebuild the ViT classifier it was saved from.

        Returns:
            The model on ``self.device`` in eval mode (converted to fp16 on CUDA).
        """
        model_path = hf_hub_download(
            repo_id=self.repo_id,
            filename="healthy.pth",
            token=HUGGINGFACE_TOKEN,
            cache_dir=self.cache_dir,
        )
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        state_dict = checkpoint['model_state_dict']

        # The saved head is Sequential(Linear, ReLU, Dropout, Linear):
        # entry 0 gives the hidden width, entry 3 the number of classes.
        hidden_features = state_dict['classifier.0.weight'].size(0)
        num_classes = state_dict['classifier.3.weight'].size(0)

        config = AutoConfig.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_labels=num_classes,
            cache_dir=self.cache_dir,
        )
        # Build the architecture from the config alone. The strict
        # load_state_dict below overwrites every parameter, so downloading the
        # pretrained backbone weights first would be wasted work.
        model = AutoModelForImageClassification.from_config(config)

        in_features = model.classifier.in_features
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden_features),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(hidden_features, num_classes),
        )
        model.load_state_dict(state_dict)

        model = model.to(self.device)
        if self.device.type == 'cuda':
            model = model.half()
        model.eval()
        return model

    def get_inference_transform(self):
        """Return the preprocessing applied to each RGB image before inference.

        The pipeline resizes to 256, center-crops to 224, converts to a tensor,
        and normalizes with the conventional ImageNet mean/std.
        NOTE(review): presumably matches the training-time preprocessing -- confirm.
        """
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_image(self, image_input, timeout=30):
        """Return *image_input* as an RGB ``PIL.Image.Image``.

        Args:
            image_input: A PIL image, an ``http://``/``https://`` URL, a local
                file path, or a file-like object with ``read()``.
            timeout: Seconds to wait on a URL download (unused for other inputs).

        Raises:
            Exception: Wraps any failure; the message is prefixed with
                ``"Error loading image: "``.
        """
        try:
            if isinstance(image_input, Image.Image):
                image = image_input
            elif isinstance(image_input, str):
                if image_input.startswith(('http://', 'https://')):
                    # Bound the wait, and report HTTP errors (404, 500, ...)
                    # directly instead of handing an error page to PIL.
                    response = requests.get(image_input, timeout=timeout)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                else:
                    if not os.path.exists(image_input):
                        raise FileNotFoundError(f"Image file not found: {image_input}")
                    image = Image.open(image_input)
            elif hasattr(image_input, 'read'):
                image = Image.open(image_input)
            else:
                raise ValueError("Unsupported image input type")
            return image.convert('RGB')
        except Exception as e:
            raise Exception(f"Error loading image: {str(e)}") from e

    def _logits_to_prediction(self, logits):
        """Map a single-row logits tensor to ``(class_name, confidence_percent)``."""
        # Softmax in fp32: on CUDA the logits are fp16, whose ~3 significant
        # digits would otherwise leak into the 2-decimal percentage.
        probabilities = F.softmax(logits.float(), dim=1)
        confidence, predicted = torch.max(probabilities, dim=1)
        return self.CLASS_NAMES[predicted.item()], round(confidence.item() * 100, 2)

    def predict(self, image_input, confidence_threshold=0.3):
        """Classify one image.

        Args:
            image_input: Anything accepted by ``load_image``.
            confidence_threshold: Currently unused; kept for API compatibility.

        Returns:
            ``(class_name, confidence_percentage)``: the top class and its
            softmax probability in percent, rounded to 2 decimals.

        Raises:
            Exception: Wraps any failure; the message is prefixed with
                ``"Error during prediction: "``.
        """
        try:
            image = self.load_image(image_input)
            image_tensor = self.transform(image).unsqueeze(0)  # add batch dimension
            if self.device.type == 'cuda':
                # The model is fp16 on CUDA, so the input must be fp16 too.
                image_tensor = image_tensor.half()
            image_tensor = image_tensor.to(self.device)
            with torch.inference_mode():
                logits = self.model(pixel_values=image_tensor).logits
                return self._logits_to_prediction(logits)
        except Exception as e:
            raise Exception(f"Error during prediction: {str(e)}") from e