|
|
import gradio as gr |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import os |
|
|
from typing import Tuple, List, Dict |
|
|
|
|
|
def load_model() -> Tuple[torch.nn.Module, List[str]]: |
|
|
""" |
|
|
Loads the model and class labels. |
|
|
Returns: |
|
|
model: The loaded PyTorch model. |
|
|
labels: List of class labels. |
|
|
""" |
|
|
model_path = "skinconvnext_scripted.pt" |
|
|
labels_path = "labels.txt" |
|
|
|
|
|
if not os.path.exists(model_path): |
|
|
raise FileNotFoundError(f"Model not found: {model_path}") |
|
|
if not os.path.exists(labels_path): |
|
|
raise FileNotFoundError("File labels.txt not found.") |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
model = torch.jit.load(model_path, map_location=device) |
|
|
model.eval() |
|
|
|
|
|
with open(labels_path, "r") as f: |
|
|
labels = [line.strip() for line in f.readlines()] |
|
|
|
|
|
return model, labels |
|
|
|
|
|
# Load the model and labels once at import time; predict() reuses them.
model, labels = load_model()
|
|
|
|
|
|
|
|
# Input pipeline applied to every uploaded image before inference.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed input size expected by the model
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
    # Channel-wise normalization (standard ImageNet statistics).
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
|
|
|
def predict(image: Image.Image) -> Dict[str, float]:
    """
    Classify a skin image and return per-class probabilities.

    Args:
        image (PIL.Image): The input image.

    Returns:
        Dict[str, float]: Class names mapped to softmax probabilities,
        sorted from most to least likely. On failure, returns
        {"error": message} so the UI surfaces the problem instead of crashing.
    """
    try:
        # The model was already placed on the right device by load_model()
        # (map_location), so only the input tensor is transferred per call.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        rgb = image.convert("RGB")
        batch = preprocess(rgb).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
        probs = torch.nn.functional.softmax(logits[0], dim=0)

        # NOTE(review): zip() silently truncates on length mismatch —
        # assumes labels.txt matches the model's output head size.
        scores = {label: float(p) for label, p in zip(labels, probs)}
        return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))
    except Exception as e:
        # Best-effort: report the error in the UI rather than raising.
        return {"error": str(e)}
|
|
|
|
|
# UI copy rendered as Markdown above the interface.
title = "🔥 Skin-AI"
description = (
    "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
    "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.\n\n"
    "### 🚀 How to Use:\n\n"
    "1️⃣ Upload or take a photo of the affected skin area.\n\n"
    "2️⃣ Click the 'Submit' button.\n\n"
    "3️⃣ The app will return the probabilities for possible skin conditions.\n\n"
    "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.\n\n"
    "### 🛠 Technologies Used:\n"
    "- PyTorch (Lightning)\n"
    "- Gradio\n"
    "- Hugging Face Spaces\n\n"
    # Fixed: a trailing space inside the link target broke the Markdown link.
    "🔗 Source Code: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI)"
)
|
|
|
|
|
|
|
|
# Sample images shipped alongside the app; each inner list is one input row.
examples = [[f"example{idx}.jpg"] for idx in (1, 2)]
|
|
|
|
|
def update_submit_state(image):
    """Return a component update enabling Submit only when an image is set."""
    has_image = image is not None
    return gr.update(interactive=has_image)
|
|
|
|
|
with gr.Blocks() as interface:
    # Header: title and usage instructions rendered as Markdown.
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")

    # Submit starts disabled; it is enabled by the change handler below
    # once an image has been provided.
    submit_button = gr.Button("Submit", interactive=False)
    output_label = gr.Label(num_top_classes=3, label="Prediction")

    # Re-evaluate the button state whenever the image input changes.
    image_input.change(fn=update_submit_state, inputs=image_input, outputs=submit_button)
    submit_button.click(fn=predict, inputs=image_input, outputs=output_label)

    # Clicking an example fills the image input.
    gr.Examples(examples, inputs=image_input)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the default Hugging Face
    # Spaces port — see the description above).
    interface.launch(server_name="0.0.0.0", server_port=7860)